In [None]:
import sys
import aotools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
from collections import OrderedDict

sys.path.insert(0, '../..')
from model_custom import Net
from load import load
from train_v2 import train
from utils import *
from dataset2 import *
import monitoring

In [None]:
model = Net()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    # Replicate across all visible GPUs; the wrapper's parameter names gain a "module." prefix.
    model = nn.DataParallel(model)
# Move to the target device in every case. Previously this only happened in the
# multi-GPU branch, which left a single-GPU run on the CPU.
model.to(device)

# Load weights.
# NOTE(review): the checkpoint presumably comes from a DataParallel-wrapped model
# (keys prefixed with "module.") — confirm. The prefix is stripped only where it is
# actually present, and the weights are loaded into the unwrapped network, so the same
# checkpoint loads whether or not DataParallel is in use.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from trusted sources.
model_dir = 'custom0_lr1e-05/model.pth'
state_dict = torch.load(model_dir, map_location=device)
prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)
net = model.module if isinstance(model, nn.DataParallel) else model
net.load_state_dict(new_state_dict)

# Smoke test: one forward pass on random input to check output shapes.
# The input is created on the model's device. eval() + no_grad() keep this pass from
# building an autograd graph or updating any normalization statistics (in case Net has
# such layers). The model is put back into training mode for the training cell below.
x = torch.randn(32, 2, 128, 128, device=device)
model.eval()
with torch.no_grad():
    output, aux_output = model(x)
model.train()
print(output.shape)
print(aux_output.shape)

In [None]:
# Data set: samples read from data_dir, passed through Normalize then ToTensor.
data_dir = '../../dataset/'
dataset_size = 100000
psf_transform = transforms.Compose([
    Normalize(data_dir),
    ToTensor(),
])
dataset = psf_dataset(root_dir=data_dir, size=dataset_size, transform=psf_transform)

# Loss and optimizer. The learning rate is the same value as in the loaded
# checkpoint's directory name (custom0_lr1e-05).
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# GPU monitoring is currently disabled. To enable it, create
# monitoring.monitoringGPU(120) before train(...) and call .stop() on it afterwards.
train_kwargs = dict(
    split=[0.9, 0.1],     # train / validation fractions
    batch_size=256,
    n_epochs=250,
    random_seed=42,
    model_dir='./',       # output directory handed to train()
    visdom=True,
)
train(model, dataset, optimizer, criterion, **train_kwargs)

In [None]:
# Read back the metrics from './', the model_dir handed to train() above.
metrics = get_metrics(model_dir='./')

# Default learning curve, y-axis limited to [0, 10000].
plot_learningcurve(metrics, name='lrcurve.pdf', ylim=[0, 10000])

# zernike=True variant, y-axis limited to [0, 1000].
plot_learningcurve(
    metrics,
    zernike=True,
    name='lrcurve_zernike.pdf',
    ylim=[0, 1000],
)