In [1]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms

sys.path.insert(0, '../..')
from load import load
from train import train, eval
from utils import plot_learningcurve, get_metrics
from dataset import psf_dataset, splitDataLoader, ToTensor, Normalize

In [6]:
# Pretrained ResNet-50 adapted to 2-channel 128x128 inputs (via a transposed-conv
# front-end) and a 20-output fully-connected head.
model = models.resnet50(pretrained=True)

# Set to True to train only the layers added below (deconv front-end and new head):
# the pretrained parameters are frozen before those layers are created.
FREEZE_BACKBONE = False
if FREEZE_BACKBONE:
    for param in model.parameters():
        param.requires_grad = False

# Add deconv layer
# Input size 2x128x128 -> 3x255x255: (128 - 1) * 2 - 2 * 1 + 3 = 255, giving the
# pretrained 3-channel conv1 an input it can consume.
first_conv_layer = [nn.ConvTranspose2d(2, 3, kernel_size=3, stride=2, padding=1, dilation=1, groups=1, bias=True),
                    model.conv1]
model.conv1 = nn.Sequential(*first_conv_layer)

# Infer the width of the features that reach `model.fc` instead of hard-coding it.
# The width depends on the installed torchvision ResNet's pooling layer. A fixed
# 7x7 AvgPool2d on the 8x8 layer4 map gives 2048 x 2 x 2 = 8192, which is what the
# previously hard-coded value assumed. AdaptiveAvgPool2d((1, 1)) gives 2048.
# An empty nn.Sequential returns its input unchanged, so the dummy pass stops right
# before the head. eval() + no_grad() keep BatchNorm statistics and autograd untouched.
model.fc = nn.Sequential()
model.eval()
with torch.no_grad():
    fc_in_features = model(torch.zeros(1, 2, 128, 128)).shape[1]
model.train()

# Fit classifier (regression head with 20 outputs)
model.fc = nn.Sequential(
                        nn.Linear(fc_in_features, 512),
                        nn.ReLU(inplace=True),
                        nn.BatchNorm1d(512),
                        nn.Linear(512, 512),
                        nn.ReLU(inplace=True),
                        nn.BatchNorm1d(512),
                        nn.Linear(512, 20)
                    )

# GPU support
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
# Always place the model on `device`. A single-GPU machine skips the DataParallel
# branch but still needs the weights on the GPU. This must also happen before the
# optimizer is constructed from model.parameters().
model.to(device)

In [3]:
# Loss function and optimizer
# Loss function and optimizer: mean-squared error on the model outputs,
# minimised with Adam at a small learning rate.
LEARNING_RATE = 1e-4
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [4]:
# Data set
# Data set: `dataset_size` samples read from `data_dir`. Each sample goes through
# Normalize(data_dir) first and then ToTensor, in the order listed in the Compose.
data_dir = '../../dataset/'
dataset_size = 10000
psf_transform = transforms.Compose([
    Normalize(data_dir),
    ToTensor(),
])
dataset = psf_dataset(root_dir=data_dir, size=dataset_size, transform=psf_transform)

In [5]:
# Training run: 0.9 / 0.1 train/validation split, batch size 64, 200 epochs,
# fixed random seed. Checkpoints go to `model_dir`. The `visdom` and `decay`
# switches are passed straight through to the project's train() helper.
train_kwargs = dict(
    split=[0.9, 0.1],
    batch_size=64,
    n_epoch=200,
    random_seed=42,
    model_dir='./',
    visdom=True,
    decay=True,
)
train(model, dataset, optimizer, criterion, **train_kwargs)

Training started on cuda:0
Visdom successfully connected to server
[1/200] Train loss: 22031.668177 
[1/200] Validation loss: 21943.794312 
[1/200] Time: 47.021369 s
------------------------------
[2/200] Train loss: 21761.193332 
[2/200] Validation loss: 21836.933350 
[2/200] Time: 44.964117 s
------------------------------
[3/200] Train loss: 21548.282289 
[3/200] Validation loss: 21470.177246 
[3/200] Time: 45.035586 s
------------------------------
Process Process-27:
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process Process-28:
Process Process-25:
Process Process-26:
Traceback (most recent call last):
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 1031, in _try_while_unread_hdus
    return func(*args, **kwargs)
Traceback (most recent call last):
IndexError: list index out of range
  File "/mnt/diskss/povanberg/minico

KeyboardInterrupt: 

  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "../../dataset.py", line 31, in __getitem__
    image = np.stack((sample_hdu[1].data, sample_hdu[2].data)).astype(np.float32)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 300, in __getitem__
    self._positive_index_of(key))
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 1033, in _try_while_unread_hdus
    if self._read_next_hdu():
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py", line 1074, in _read_next_hdu
    hdu = _BaseHDU.readfrom(fileobj, **kwargs)


<Figure size 432x288 with 0 Axes>

  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/base.py", line 328, in readfrom
    **kwargs)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/base.py", line 442, in _readfrom_internal
    hdu = cls(data=DELAYED, header=header, **new_kwargs)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/image.py", line 1103, in __init__
    scale_back=scale_back, ver=ver)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/image.py", line 43, in __init__
    super().__init__(data=data, header=header)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/base.py", line 869, in __init__
    super().__init__(data=data, header=header)
  File "/mnt/diskss/povanberg/miniconda3/envs/pytorch/lib/python3.7/site-packages/astropy/io/fits/hdu/base.py", line 

In [None]:
# Plot the learning curve from the metrics returned by get_metrics(),
# restricted to x in [1, 8] (epochs) and y in [-1e5, 1e6].
epoch_window = [1, 8]
loss_window = [-1e5, 1e6]
metrics = get_metrics()
plot_learningcurve(metrics, xlim=epoch_window, ylim=loss_window)