In [3]:
import os
import time
import utils
import torch
import logging
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from dataset2 import PSFDataset, ToTensor, MinMaxNorm
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torchvision.models as models
import torch.nn as nn
import numpy as np
import torch.optim as optim

In [8]:
# Others
random_seed = 42
# random_seed also drives the train/val shuffle in the data-split cell. Torch
# has its own RNG (fc-head initialisation, dropout masks, SubsetRandomSampler
# order), so seed it as well for reproducible runs.
torch.manual_seed(random_seed)

# Data variables
data_dir = 'data/'
dataset_size = 10000     # number of samples PSFDataset is asked to expose
shuffle_dataset = True
split = [0.9, 0.1]      # [Train, Val]; the split cell reads split[1], train gets the rest

# Train variables
model_dir = 'models/baseline'   # logs.log, metrics.json and checkpoint.pth go here
num_epochs = 300
batch_size = 128
# Initial Adam learning rate. NOTE: adjust_learning_rate resets the optimizer's
# rate every epoch from its own base of 0.01, so changing lr here alone does
# not change the schedule used in training.
lr = 0.01

In [9]:
# GPU support
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Logs: logs.log, metrics.json and checkpoint.pth are all written into
# model_dir, so create it up front instead of relying on it already existing.
os.makedirs(model_dir, exist_ok=True)
log_path = os.path.join(model_dir, 'logs.log')
utils.set_logger(log_path)

In [7]:
# Load and split dataset in training and validation sets.
# NOTE(review): the recorded training run failed in a dataloader worker with
# FileNotFoundError for 'data/psf_4561.fits' - confirm data_dir really holds
# dataset_size samples before training.

dataset = PSFDataset(root_dir=data_dir, size=dataset_size,
                     transform=transforms.Compose([MinMaxNorm(), ToTensor()]))

indices = list(range(dataset_size))
# Validation size = floor(split[1] * dataset_size); training takes the rest.
s = int(np.floor(split[1] * dataset_size))
if shuffle_dataset:
    # A dedicated RandomState yields the same permutation as
    # np.random.seed(random_seed) + np.random.shuffle, without resetting
    # numpy's global RNG for every other cell.
    np.random.RandomState(random_seed).shuffle(indices)
train_indices, val_indices = indices[s:], indices[:s]

train_sampler, val_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)

# Both loaders read the same dataset; each sampler restricts it to one split.
train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=train_sampler)
val_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=val_sampler)

logging.info('Train set size: %i | Validation set size: %i',
             len(train_indices), len(val_indices))

INFO:root:Train set size: 9000 | Validation set size: 1000


In [10]:
# Transfer learning
# 299x299x3 input (3 and 1 images are the same)
# Extra pixels added to fit input size with 0 value

model = models.inception_v3(pretrained=True)
# Alternative backbone: model = models.resnet50(pretrained=True)

# Freeze weights of conv layers (use as features extractor)
for param in model.parameters():
    param.requires_grad = False

# In train mode torchvision's Inception v3 returns (logits, aux_logits) while
# aux_logits is enabled, and MSELoss cannot consume that tuple. The regression
# only uses the main head, so switch the auxiliary output off. (hasattr keeps
# this a no-op for backbones such as ResNet that have no auxiliary head.)
if hasattr(model, 'aux_logits'):
    model.aux_logits = False

# Replace fc layers: 2048 backbone features -> 20 regression outputs
model.fc = nn.Sequential(
                        nn.Linear(2048, 1024),
                        nn.ReLU(inplace=True),
                        nn.Dropout(p=0.5),
                        nn.Linear(1024, 256),
                        nn.ReLU(inplace=True),
                        nn.Dropout(p=0.5),
                        nn.Linear(256, 20)
                    )

print(model)

# Deploy on multiple GPUs
if torch.cuda.device_count() > 1:
    logging.info("Model deployed on %d GPUs" % (torch.cuda.device_count()))
    model = nn.DataParallel(model)
model.to(device)

# Loss function and optimizer. Only the new fc head still requires grad, so
# hand Adam just those parameters instead of the whole frozen backbone.
criterion = nn.MSELoss()
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

def adjust_learning_rate(optimizer, epoch, base_lr=0.01, decay=0.1, step=30):
    """Apply a step-decay learning-rate schedule to every param group.

    The rate is ``base_lr * decay ** (epoch // step)``; with the defaults it is
    divided by 10 every 30 epochs (see the ResNet paper).

    NOTE(review): over 300 epochs the defaults drive the rate down to ~1e-11,
    which effectively stops training long before the last epoch.

    Args:
        optimizer: a torch optimizer (anything exposing ``param_groups``).
        epoch: zero-based epoch index.
        base_lr: learning rate for epochs ``0 .. step - 1``.
        decay: multiplicative factor applied at each step boundary.
        step: number of epochs between decays; must be positive.

    Returns:
        The learning rate that was written into every param group.

    Raises:
        ValueError: if ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError('step must be positive, got %r' % (step,))
    new_lr = base_lr * (decay ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr

INFO:root:Model deployed on 2 GPUs


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, t

In [11]:
def _batch_to_device(sample_batched):
    """Return (image, zernike) from a dataloader batch as float32 tensors on `device`.

    `sample_batched` is the dict yielded by the PSF dataloaders: its 'image'
    entry is the network input and its 'zernike' entry the regression target.
    """
    image = sample_batched['image'].to(device, dtype=torch.float32)
    zernike = sample_batched['zernike'].to(device, dtype=torch.float32)
    return image, zernike


def _main_output(outputs):
    """Return the primary prediction from a forward pass.

    torchvision's Inception v3 can return a (logits, aux_logits) tuple in train
    mode; only the main head feeds the loss here.
    """
    if isinstance(outputs, tuple):
        return outputs[0]
    return outputs


start_time = time.time()
for epoch in range(num_epochs):

    adjust_learning_rate(optimizer, epoch)

    logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
    logging.info('-' * 10)

    running_loss = 0.0
    # Log twice per epoch; max(1, ...) avoids a modulo by zero when the
    # training loader has fewer than 2 batches.
    log_every = max(1, len(train_dataloader) // 2)
    interval_start = time.time()

    # Training
    # NOTE(review): model.train() also puts the frozen backbone's BatchNorm
    # layers in training mode, so their running statistics keep changing even
    # though their weights are frozen - confirm this is intended.
    model.train()
    for i_batch, sample_batched in enumerate(train_dataloader):
        image, zernike = _batch_to_device(sample_batched)

        # Forward pass, backward pass, optimize
        outputs = _main_output(model(image))
        loss = criterion(outputs, zernike)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Print statistics (mean batch loss and wall time since the last log)
        if (i_batch + 1) % log_every == 0:
            logging.info('train loss: %.3f time: %.3f s' %
                         (running_loss / log_every, time.time() - interval_start))
            running_loss = 0.0
            interval_start = time.time()

    # Validation: no parameter updates, so skip building the autograd graph.
    model.eval()
    val_loss = 0.0
    n_val_samples = 0
    with torch.no_grad():
        for sample_batched in val_dataloader:
            image, zernike = _batch_to_device(sample_batched)
            outputs = _main_output(model(image))
            batch_n = image.size(0)
            # criterion returns a per-batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            val_loss += criterion(outputs, zernike).item() * batch_n
            n_val_samples += batch_n

    # Mean validation MSE per sample (lower is better). It is kept under the
    # historical name/key 'accuracy' so existing metrics.json files still match.
    accuracy = val_loss / n_val_samples

    # Save best val metrics in a json file in the model directory
    metrics_json_path = os.path.join(model_dir, "metrics.json")
    metrics = utils.Params(metrics_json_path)
    if not metrics.hasKey(metrics_json_path, 'accuracy') or metrics.accuracy > accuracy:
        metrics.accuracy = accuracy
        metrics.save(metrics_json_path)
        checkpoint_path = os.path.join(model_dir, 'checkpoint.pth')
        # NOTE: when model is wrapped in nn.DataParallel the saved keys carry
        # a 'module.' prefix; load accordingly.
        torch.save(model.state_dict(), checkpoint_path)

    logging.info('val loss: %.3f ' % accuracy)

logging.info('Training finished in %.3f s' % (time.time() - start_time))

INFO:root:Epoch 0/299
INFO:root:----------


FileNotFoundError: Traceback (most recent call last):
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/mnt/diskss/povanberg/phase-retrieval/dataset2.py", line 27, in __getitem__
    sample_hdu = fits.open(sample_name)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 151, in fitsopen
    lazy_load_hdus, **kwargs)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 387, in fromfile
    lazy_load_hdus=lazy_load_hdus, **kwargs)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 974, in _readfrom
    fileobj = _File(fileobj, mode=mode, memmap=memmap, cache=cache)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/utils/decorators.py", line 488, in wrapper
    return function(*args, **kwargs)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/file.py", line 175, in __init__
    self._open_filename(fileobj, mode, overwrite)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/file.py", line 531, in _open_filename
    self._file = fileobj_open(self.name, IO_FITS_MODES[mode])
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/util.py", line 388, in fileobj_open
    return open(filename, mode, buffering=0)
FileNotFoundError: [Errno 2] No such file or directory: 'data/psf_4561.fits'


Process Process-4:
Process Process-1:
Process Process-2:
Process Process-3:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessi