In [3]:
import os
import logging
import argparse
import utils
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, models
from dataset import PSFDataset, ToTensor, MinMaxNorm
import numpy as np
from model_batchNorm import Net

In [4]:
# Variables

n_zernike = 20
split = 0.1
batch_size = 256 # Increase stability of convergence?
dataset_size = 10000
num_epochs = 300
lr = 0.001

model_dir = 'models/baseline_norm/'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
data_dir = 'psfs/'

In [5]:
# GPU support
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Logs
log_path = os.path.join(model_dir, 'logs.log')
utils.set_logger(log_path)

In [6]:
# Load dataset:
dataset = PSFDataset(root_dir=data_dir, size=dataset_size,
                         transform=transforms.Compose([MinMaxNorm(), ToTensor()]))
# Ensure reproducibility:
random_seed = 42
shuffle_dataset = True
    
# Split train-test:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
    
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
    
train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=train_sampler)
val_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=val_sampler)

logging.info('Train set size: %i | Validation set size: %i' % (batch_size*len(train_dataloader), 
                                                              batch_size*len(val_dataloader)))

Train set size: 9216 | Validation set size: 1024


In [7]:
from collections import OrderedDict

# Load convolutional network
model = Net()
#state_dict = torch.load(os.path.join(model_dir, 'checkpoint.pth'))
#new_state_dict = OrderedDict()
#for k, v in state_dict.items():
#    name = k[7:] # remove module.
#    new_state_dict[name] = v
#model.load_state_dict(new_state_dict)
print(model)
if torch.cuda.device_count() > 1:
    logging.info("Model deployed on %d GPUs" % (torch.cuda.device_count()))
    model = nn.DataParallel(model)
model.to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay=1e-3)

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 20 epochs"""
    lr = 0.001 * (0.1 ** (epoch // 50))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Model deployed on 2 GPUs


Net(
  (conv1): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv11_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv22): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv22_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv33): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv33_bn): BatchNorm2d(128, eps=1e-05, moment

In [9]:
start_time = time.time()
for epoch in range(num_epochs):
    
    adjust_learning_rate(optimizer, epoch)
    
    logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
    logging.info('-' * 10)
    
    running_loss = 0.0
    log_every = len(train_dataloader) // 3
    epoch_time = time.time()

    # Training
    model.train()
    for i_batch, sample_batched in enumerate(train_dataloader):

        zernike = sample_batched['zernike'].type(torch.FloatTensor)
        image = sample_batched['image'].type(torch.FloatTensor)
        image = image.to(device)
        zernike = zernike.to(device)

        # Forward pass, backward pass, optimize
        outputs = model(image)
        loss = criterion(outputs, zernike)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += float(loss)
        # Print statistics
        if (i_batch + 1) % (log_every) == 0:
            logging.info('estimate train loss: %.3f time: %.3f s' %
                      (running_loss / log_every, time.time() - epoch_time))
            running_loss = 0.0
            epoch_time = time.time()

    model.eval()
    val_loss = 0.0
    for i_batch, sample_batched in enumerate(val_dataloader):

        zernike = sample_batched['zernike'].type(torch.FloatTensor)
        image = sample_batched['image'].type(torch.FloatTensor)
        image = image.to(device)
        zernike = zernike.to(device)

        outputs = model(image)
        loss = criterion(outputs, zernike)
        val_loss += float(loss)

    # Save best val metrics in a json file in the model directory
    accuracy = val_loss / len(val_dataloader)
    metrics_json_path = os.path.join(model_dir, "metrics.json")
    metrics = utils.Params(metrics_json_path)
    if not metrics.hasKey(metrics_json_path, 'accuracy') or metrics.accuracy > accuracy:
        metrics.accuracy = accuracy
        metrics.save(metrics_json_path)
        checkpoint_path = os.path.join(model_dir, 'checkpoint.pth')
        torch.save(model.state_dict(), checkpoint_path)
        
    logging.info('val loss: %.3f ' % (val_loss / len(val_dataloader)))
    
logging.info('Training finished in %.3f s' % (time.time() - start_time))

Epoch 0/299
----------
estimate train loss: 12494.660 time: 4.598 s
estimate train loss: 9675.897 time: 3.301 s
estimate train loss: 7085.382 time: 3.060 s
val loss: 18868.649 
Epoch 1/299
----------
estimate train loss: 5290.200 time: 4.564 s
estimate train loss: 4339.108 time: 3.324 s
estimate train loss: 3767.070 time: 3.145 s
val loss: 8297.976 
Epoch 2/299
----------
estimate train loss: 3324.518 time: 4.728 s
estimate train loss: 2977.151 time: 3.456 s
estimate train loss: 2808.671 time: 3.167 s
val loss: 2758.469 
Epoch 3/299
----------
estimate train loss: 2492.915 time: 4.610 s
estimate train loss: 2364.687 time: 3.443 s
estimate train loss: 2225.074 time: 3.076 s
val loss: 4672.418 
Epoch 4/299
----------
estimate train loss: 2074.808 time: 4.595 s
estimate train loss: 1894.315 time: 3.378 s
estimate train loss: 1798.366 time: 3.207 s
val loss: 3041.575 
Epoch 5/299
----------
estimate train loss: 1680.863 time: 4.571 s
estimate train loss: 1520.133 time: 3.339 s
estimate tra

estimate train loss: 268.924 time: 3.127 s
val loss: 417.080 
Epoch 48/299
----------
estimate train loss: 262.004 time: 4.705 s
estimate train loss: 265.442 time: 3.327 s
estimate train loss: 270.016 time: 3.106 s
val loss: 418.611 
Epoch 49/299
----------
estimate train loss: 264.932 time: 4.576 s
estimate train loss: 261.213 time: 3.369 s
estimate train loss: 269.139 time: 3.169 s
val loss: 420.758 
Epoch 50/299
----------
estimate train loss: 262.904 time: 4.714 s
estimate train loss: 265.233 time: 3.311 s
estimate train loss: 260.120 time: 3.113 s
val loss: 419.114 
Epoch 51/299
----------
estimate train loss: 267.430 time: 4.713 s
estimate train loss: 260.523 time: 3.378 s
estimate train loss: 271.432 time: 3.158 s
val loss: 417.781 
Epoch 52/299
----------
estimate train loss: 264.136 time: 4.478 s
estimate train loss: 266.977 time: 3.390 s
estimate train loss: 259.135 time: 3.138 s
val loss: 418.724 
Epoch 53/299
----------
estimate train loss: 258.483 time: 4.526 s
estimate tr

estimate train loss: 260.605 time: 3.350 s
estimate train loss: 260.145 time: 3.112 s
val loss: 413.850 
Epoch 96/299
----------
estimate train loss: 259.120 time: 4.594 s
estimate train loss: 259.225 time: 3.407 s
estimate train loss: 258.577 time: 3.156 s
val loss: 413.418 
Epoch 97/299
----------
estimate train loss: 256.891 time: 4.690 s
estimate train loss: 260.926 time: 3.514 s
estimate train loss: 257.341 time: 3.216 s
val loss: 412.825 
Epoch 98/299
----------
estimate train loss: 260.258 time: 4.609 s
estimate train loss: 257.646 time: 3.288 s
estimate train loss: 257.505 time: 3.127 s
val loss: 413.112 
Epoch 99/299
----------
estimate train loss: 253.604 time: 4.659 s
estimate train loss: 262.094 time: 3.343 s
estimate train loss: 266.250 time: 3.064 s
val loss: 415.647 
Epoch 100/299
----------
estimate train loss: 255.050 time: 4.665 s
estimate train loss: 261.768 time: 3.408 s
estimate train loss: 263.462 time: 3.105 s
val loss: 413.946 
Epoch 101/299
----------
estimate 

  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/mnt/diskss/povanberg/phase-retrieval/dataset.py", line 29, in __getitem__
    image = np.stack((sample_hdu[1].data, sample_hdu[2].data))
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 300, in __getitem__
    self._positive_index_of(key))
KeyboardInterrupt
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 300, in __getitem__
    self._positive_index_of(key))
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 1033, in _try_while_unread_hdus
    if self._read_next_hdu():
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 1033, in _try_while_unread_hdus
  

KeyboardInterrupt: 