In [1]:
import os
import logging
import argparse
import utils
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, models
from dataset import PSFDataset, ToTensor, MinMaxNorm
import numpy as np
from model_inception import Net

In [2]:
# Variables

n_zernike = 20
split = 0.1
batch_size = 64 # Increase stability of convergence?
dataset_size = 10000
num_epochs = 300
lr = 0.001

model_dir = 'models/baseline_norm/'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
data_dir = 'psfs/'

In [3]:
# GPU support
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Logs
log_path = os.path.join(model_dir, 'logs.log')
utils.set_logger(log_path)

In [4]:
# Load dataset:
dataset = PSFDataset(root_dir=data_dir, size=dataset_size,
                         transform=transforms.Compose([MinMaxNorm(), ToTensor()]))
# Ensure reproducibility:
random_seed = 42
shuffle_dataset = True
    
# Split train-test:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
    
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
    
train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=train_sampler)
val_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=val_sampler)

logging.info('Train set size: %i | Validation set size: %i' % (batch_size*len(train_dataloader), 
                                                              batch_size*len(val_dataloader)))

Train set size: 9024 | Validation set size: 1024


In [5]:
from collections import OrderedDict

# Load convolutional network
model = Net()
#state_dict = torch.load(os.path.join(model_dir, 'checkpoint.pth'))
#new_state_dict = OrderedDict()
#for k, v in state_dict.items():
#    name = k[7:] # remove module.
#    new_state_dict[name] = v
#model.load_state_dict(new_state_dict)
if torch.cuda.device_count() > 1:
    logging.info("Model deployed on %d GPUs" % (torch.cuda.device_count()))
    model = nn.DataParallel(model)
model.to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

def adjust_learning_rate(optimizer, epoch):
    lr = 0.01 * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Model deployed on 2 GPUs


In [None]:
start_time = time.time()
for epoch in range(num_epochs):
    
    adjust_learning_rate(optimizer, epoch)
    
    logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
    logging.info('-' * 10)
    
    running_loss = 0.0
    log_every = len(train_dataloader)-1
    epoch_time = time.time()

    # Training
    model.train()
    for i_batch, sample_batched in enumerate(train_dataloader):

        zernike = sample_batched['zernike'].type(torch.FloatTensor)
        image = sample_batched['image'].type(torch.FloatTensor)
        image = image.to(device)
        zernike = zernike.to(device)

        # Forward pass, backward pass, optimize
        outputs = model(image)
        loss = criterion(outputs, zernike)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += float(loss)
        # Print statistics
        if (i_batch + 1) % (log_every) == 0:
            logging.info('train loss: %.3f time: %.3f s' %
                      (running_loss / log_every, time.time() - epoch_time))
            running_loss = 0.0
            epoch_time = time.time()

    model.eval()
    val_loss = 0.0
    for i_batch, sample_batched in enumerate(val_dataloader):

        zernike = sample_batched['zernike'].type(torch.FloatTensor)
        image = sample_batched['image'].type(torch.FloatTensor)
        image = image.to(device)
        zernike = zernike.to(device)

        outputs = model(image)
        loss = criterion(outputs, zernike)
        val_loss += float(loss)

    # Save best val metrics in a json file in the model directory
    accuracy = val_loss / len(val_dataloader)
    metrics_json_path = os.path.join(model_dir, "metrics.json")
    metrics = utils.Params(metrics_json_path)
    if not metrics.hasKey(metrics_json_path, 'accuracy') or metrics.accuracy > accuracy:
        metrics.accuracy = accuracy
        metrics.save(metrics_json_path)
        checkpoint_path = os.path.join(model_dir, 'checkpoint.pth')
        torch.save(model.state_dict(), checkpoint_path)
        
    logging.info('val loss: %.3f ' % (val_loss / len(val_dataloader)))
    
logging.info('Training finished in %.3f s' % (time.time() - start_time))

Epoch 0/299
----------
train loss: 10679.921 time: 17.919 s
val loss: 6660.946 
Epoch 1/299
----------
train loss: 8253.802 time: 15.383 s
val loss: 5494.860 
Epoch 2/299
----------
train loss: 7689.164 time: 15.314 s
val loss: 5162.320 
Epoch 3/299
----------
train loss: 7380.335 time: 15.668 s
val loss: 5045.515 
Epoch 4/299
----------
train loss: 7202.824 time: 15.257 s
val loss: 4578.617 
Epoch 5/299
----------
train loss: 7109.747 time: 15.720 s
val loss: 4118.489 
Epoch 6/299
----------
train loss: 6998.736 time: 15.528 s
val loss: 4594.954 
Epoch 7/299
----------
train loss: 6982.121 time: 15.549 s
val loss: 4318.254 
Epoch 8/299
----------
train loss: 6920.840 time: 15.590 s
val loss: 5311.163 
Epoch 9/299
----------
train loss: 6848.010 time: 15.322 s
val loss: 3970.080 
Epoch 10/299
----------
train loss: 6774.841 time: 15.483 s
val loss: 3851.949 
Epoch 11/299
----------
train loss: 6755.358 time: 15.370 s
val loss: 4488.349 
Epoch 12/299
----------
train loss: 6694.703 time