In [1]:
import os
import logging
import argparse
import utils
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, models
from dataset import PSFDataset, ToTensor, Normalize
import numpy as np

In [2]:
torch.set_default_tensor_type('torch.FloatTensor')
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_bn = nn.BatchNorm2d(256)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.fc1 = nn.Linear(256 * 16 * 16, 4096)
        self.fc1_bn = nn.BatchNorm1d(4096)
        self.fc2 = nn.Linear(4096, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 20)
 
    def forward(self, x):
        # 128x128x2
        x = self.pool(self.conv1_bn(self.conv1(x)))       
        x = F.relu(x) 
        # 64x64x64
        x = self.pool(self.conv2_bn(self.conv2(x)))         
        x = F.relu(x)                  
        # 32x32x128          
        x = self.pool(self.conv3_bn(self.conv3(x)))         
        x = F.relu(x)    
        # 16x16x128 
        x = x.view(-1, 256 * 16 * 16)  
        # 1x(16x16x128)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.dropout(x)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.dropout(x)
        x = self.fc3(x)                  
        return x

In [3]:
# Variables

n_zernike = 20
split = 0.1
batch_size = 128
dataset_size = 10000
num_epochs = 500
lr = 0.001

model_dir = 'models/baseline_v3/'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
data_dir = 'psfs/'

In [4]:
# GPU support
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Logs
log_path = os.path.join(model_dir, 'logs.log')
utils.set_logger(log_path)

In [5]:
# Load dataset:
dataset = PSFDataset(root_dir=data_dir, size=dataset_size,
                         transform=transforms.Compose([Normalize(), ToTensor()]))
    
# Ensure reproducibility:
random_seed = 42
shuffle_dataset = True
    
# Split train-test:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
    
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
    
train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=train_sampler)
val_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=val_sampler)

logging.info('Train set size: %i | Validation set size: %i' % (batch_size*len(train_dataloader), 
                                                              batch_size*len(val_dataloader)))

Train set size: 9088 | Validation set size: 1024


In [8]:
from collections import OrderedDict

# Load convolutional network
model = Net()
#state_dict = torch.load(os.path.join(model_dir, 'checkpoint.pth'))
#new_state_dict = OrderedDict()
#for k, v in state_dict.items():
#    name = k[7:] # remove module.
#    new_state_dict[name] = v
#model.load_state_dict(new_state_dict)
print(model)
if torch.cuda.device_count() > 1:
    logging.info("Model deployed on %d GPUs" % (torch.cuda.device_count()))
    model = nn.DataParallel(model)
model.to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = lr, momentum=0.9)

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 20 epochs"""
    lr = 0.001 * (0.1 ** (epoch // 20))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Model deployed on 2 GPUs


Net(
  (conv1): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=65536, out_features=4096, bias=True)
  (fc1_bn): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=4096, out_features=512, bias=True)
  (fc2_bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=512, out_features=20, bias=True)
)


In [None]:
start_time = time.time()
for epoch in range(num_epochs):
    
    adjust_learning_rate(optimizer, epoch)
    
    logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
    logging.info('-' * 10)
    
    running_loss = 0.0
    log_every = len(train_dataloader) // 3
    epoch_time = time.time()

    # Training
    model.train()
    for i_batch, sample_batched in enumerate(train_dataloader):

        zernike = sample_batched['zernike'].type(torch.FloatTensor)
        image = sample_batched['image'].type(torch.FloatTensor)
        image = image.to(device)
        zernike = zernike.to(device)

        # Forward pass, backward pass, optimize
        outputs = model(image)
        loss = criterion(outputs, zernike)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += float(loss)
        # Print statistics
        if (i_batch + 1) % (log_every) == 0:
            logging.info('estimate train loss: %.3f time: %.3f s' %
                      (running_loss / log_every, time.time() - epoch_time))
            running_loss = 0.0
            epoch_time = time.time()

    model.eval()
    train_loss = 0.0
    i=0.0
    for i_batch, sample_batched in enumerate(train_dataloader):
 
        if i_batch < (len(train_dataloader) // 10): 
            i += 1
            zernike = sample_batched['zernike'].type(torch.FloatTensor)
            image = sample_batched['image'].type(torch.FloatTensor)
            image = image.to(device)
            zernike = zernike.to(device)

            outputs = model(image)
            loss = criterion(outputs, zernike)
            train_loss += float(loss)

    logging.info('train loss: %.3f ' % (train_loss / i))
    
    val_loss = 0.0
    for i_batch, sample_batched in enumerate(val_dataloader):

        zernike = sample_batched['zernike'].type(torch.FloatTensor)
        image = sample_batched['image'].type(torch.FloatTensor)
        image = image.to(device)
        zernike = zernike.to(device)

        outputs = model(image)
        loss = criterion(outputs, zernike)
        val_loss += float(loss)

    # Save best val metrics in a json file in the model directory
    accuracy = val_loss / len(val_dataloader)
    metrics_json_path = os.path.join(model_dir, "metrics.json")
    metrics = utils.Params(metrics_json_path)
    if not metrics.hasKey(metrics_json_path, 'accuracy') or metrics.accuracy > accuracy:
        metrics.accuracy = accuracy
        metrics.save(metrics_json_path)
        checkpoint_path = os.path.join(model_dir, 'checkpoint.pth')
        torch.save(model.state_dict(), checkpoint_path)
        
    logging.info('val loss: %.3f ' % (val_loss / len(val_dataloader)))
    
logging.info('Training finished in %.3f s' % (time.time() - start_time))

Epoch 0/499
----------
estimate train loss: 12945.113 time: 11.915 s
estimate train loss: 10647.262 time: 8.238 s
estimate train loss: 5880.720 time: 8.280 s
train loss: 8613.001 
val loss: 8741.327 
Epoch 1/499
----------
estimate train loss: 2121.509 time: 9.110 s
estimate train loss: 1530.002 time: 8.196 s
estimate train loss: 1138.952 time: 8.356 s
train loss: 1409.811 
val loss: 1604.202 
Epoch 2/499
----------
estimate train loss: 956.928 time: 9.279 s
estimate train loss: 879.276 time: 8.212 s
estimate train loss: 817.155 time: 8.314 s
train loss: 735.710 
val loss: 897.998 
Epoch 3/499
----------
estimate train loss: 749.451 time: 9.109 s
estimate train loss: 685.846 time: 8.135 s
estimate train loss: 696.223 time: 8.157 s
train loss: 895.794 
val loss: 1067.667 
Epoch 4/499
----------
estimate train loss: 621.890 time: 9.202 s
estimate train loss: 621.112 time: 8.324 s
estimate train loss: 588.216 time: 8.174 s
train loss: 566.430 
val loss: 746.896 
Epoch 5/499
----------
est

estimate train loss: 238.583 time: 8.407 s
train loss: 21.182 
val loss: 230.091 
Epoch 43/499
----------
estimate train loss: 232.542 time: 9.132 s
estimate train loss: 223.530 time: 8.370 s
estimate train loss: 238.304 time: 8.368 s
train loss: 27.380 
val loss: 236.175 
Epoch 44/499
----------
estimate train loss: 241.554 time: 9.164 s
estimate train loss: 235.638 time: 8.366 s
estimate train loss: 233.544 time: 8.556 s
train loss: 22.203 
val loss: 231.516 
Epoch 45/499
----------
estimate train loss: 238.893 time: 9.192 s
estimate train loss: 237.345 time: 8.301 s
estimate train loss: 239.755 time: 8.452 s
train loss: 20.762 
val loss: 228.920 
Epoch 46/499
----------
estimate train loss: 224.026 time: 9.229 s
estimate train loss: 221.073 time: 8.298 s
estimate train loss: 227.766 time: 8.462 s
train loss: 24.113 
val loss: 232.961 
Epoch 47/499
----------
estimate train loss: 228.992 time: 9.225 s
estimate train loss: 242.082 time: 8.356 s
estimate train loss: 227.285 time: 8.321

estimate train loss: 246.722 time: 8.153 s
estimate train loss: 225.027 time: 8.271 s
train loss: 20.186 
val loss: 229.749 
Epoch 86/499
----------
estimate train loss: 227.796 time: 9.047 s
estimate train loss: 221.702 time: 8.315 s
estimate train loss: 224.411 time: 8.429 s
train loss: 17.269 
val loss: 226.192 
Epoch 87/499
----------
estimate train loss: 233.842 time: 9.061 s
estimate train loss: 224.437 time: 8.310 s
estimate train loss: 235.389 time: 8.467 s
train loss: 22.171 
val loss: 231.904 
Epoch 88/499
----------
estimate train loss: 233.387 time: 9.163 s
estimate train loss: 222.797 time: 8.115 s
estimate train loss: 222.885 time: 8.276 s
train loss: 24.222 
val loss: 233.281 
Epoch 89/499
----------
estimate train loss: 226.760 time: 9.227 s
estimate train loss: 235.879 time: 8.333 s
estimate train loss: 242.039 time: 8.358 s
train loss: 30.761 
val loss: 243.515 
Epoch 90/499
----------
estimate train loss: 228.025 time: 9.164 s
estimate train loss: 231.390 time: 8.353

val loss: 235.778 
Epoch 128/499
----------
estimate train loss: 238.557 time: 9.222 s
estimate train loss: 226.676 time: 8.309 s
estimate train loss: 236.284 time: 8.392 s
train loss: 28.192 
val loss: 237.595 
Epoch 129/499
----------
estimate train loss: 240.423 time: 9.205 s
estimate train loss: 234.147 time: 8.167 s
estimate train loss: 249.170 time: 8.176 s
train loss: 16.427 
val loss: 226.028 
Epoch 130/499
----------
estimate train loss: 239.935 time: 9.312 s
estimate train loss: 240.852 time: 8.546 s
estimate train loss: 227.189 time: 8.650 s
train loss: 23.247 
val loss: 232.244 
Epoch 131/499
----------
estimate train loss: 237.593 time: 9.110 s
estimate train loss: 227.672 time: 8.088 s
estimate train loss: 237.709 time: 8.287 s
train loss: 20.597 
val loss: 228.918 
Epoch 132/499
----------
estimate train loss: 224.133 time: 9.174 s
estimate train loss: 217.313 time: 8.345 s
estimate train loss: 232.823 time: 8.470 s
train loss: 19.810 
val loss: 228.853 
Epoch 133/499
--

estimate train loss: 227.766 time: 8.456 s
estimate train loss: 237.266 time: 8.492 s
train loss: 21.608 
val loss: 231.083 
Epoch 171/499
----------
estimate train loss: 220.418 time: 9.122 s
estimate train loss: 237.679 time: 8.275 s
estimate train loss: 247.454 time: 8.336 s
train loss: 17.691 
val loss: 227.246 
Epoch 172/499
----------
estimate train loss: 237.732 time: 9.063 s
estimate train loss: 231.742 time: 8.126 s
estimate train loss: 233.777 time: 8.175 s
train loss: 20.873 
val loss: 232.698 
Epoch 173/499
----------
estimate train loss: 239.147 time: 9.217 s
estimate train loss: 230.916 time: 8.368 s
estimate train loss: 227.957 time: 8.346 s
train loss: 19.783 
val loss: 229.248 
Epoch 174/499
----------
estimate train loss: 240.941 time: 9.067 s
estimate train loss: 225.974 time: 8.075 s
estimate train loss: 232.606 time: 8.157 s
train loss: 20.690 
val loss: 228.648 
Epoch 175/499
----------
estimate train loss: 229.137 time: 9.191 s
estimate train loss: 237.718 time: 

estimate train loss: 233.926 time: 8.237 s
estimate train loss: 217.163 time: 8.257 s
train loss: 17.037 
val loss: 226.713 
Epoch 214/499
----------
estimate train loss: 220.420 time: 9.134 s
estimate train loss: 246.812 time: 8.273 s
estimate train loss: 225.461 time: 8.284 s
train loss: 29.833 
val loss: 240.429 
Epoch 215/499
----------
estimate train loss: 219.284 time: 9.177 s
estimate train loss: 224.170 time: 8.280 s
estimate train loss: 242.357 time: 8.365 s
train loss: 21.445 
val loss: 230.418 
Epoch 216/499
----------
estimate train loss: 232.042 time: 9.180 s
estimate train loss: 222.188 time: 8.384 s
estimate train loss: 218.908 time: 8.270 s
train loss: 19.354 
val loss: 230.603 
Epoch 217/499
----------
estimate train loss: 221.305 time: 9.164 s
estimate train loss: 230.012 time: 8.303 s
estimate train loss: 227.347 time: 8.222 s
train loss: 25.910 
val loss: 236.802 
Epoch 218/499
----------
estimate train loss: 230.144 time: 9.394 s
estimate train loss: 214.587 time: 

val loss: 227.943 
Epoch 256/499
----------
estimate train loss: 230.765 time: 9.188 s
estimate train loss: 225.589 time: 8.358 s
estimate train loss: 233.231 time: 8.458 s
train loss: 15.512 
val loss: 225.731 
Epoch 257/499
----------
estimate train loss: 233.624 time: 9.260 s
estimate train loss: 227.716 time: 8.301 s
estimate train loss: 242.303 time: 8.414 s
train loss: 22.511 
val loss: 230.723 
Epoch 258/499
----------
estimate train loss: 224.657 time: 9.124 s
estimate train loss: 222.702 time: 8.342 s
estimate train loss: 229.798 time: 8.385 s
train loss: 23.675 
val loss: 232.147 
Epoch 259/499
----------
estimate train loss: 252.779 time: 9.137 s
estimate train loss: 235.111 time: 8.078 s
estimate train loss: 217.943 time: 8.161 s
train loss: 20.737 
val loss: 230.399 
Epoch 260/499
----------
estimate train loss: 239.508 time: 9.066 s
estimate train loss: 225.036 time: 8.285 s
estimate train loss: 225.556 time: 8.397 s
train loss: 18.879 
val loss: 229.584 
Epoch 261/499
--

estimate train loss: 229.087 time: 8.294 s
estimate train loss: 222.914 time: 8.342 s
train loss: 19.274 
val loss: 229.194 
Epoch 299/499
----------
estimate train loss: 222.937 time: 9.022 s
estimate train loss: 225.085 time: 8.310 s
estimate train loss: 223.158 time: 8.301 s
train loss: 24.068 
val loss: 234.441 
Epoch 300/499
----------
estimate train loss: 233.079 time: 9.085 s
estimate train loss: 243.712 time: 8.287 s
estimate train loss: 244.835 time: 8.179 s
train loss: 33.944 
val loss: 245.453 
Epoch 301/499
----------
estimate train loss: 236.563 time: 9.175 s
estimate train loss: 243.898 time: 8.316 s
estimate train loss: 227.084 time: 8.372 s
train loss: 19.819 
val loss: 228.632 
Epoch 302/499
----------
estimate train loss: 228.310 time: 9.110 s
estimate train loss: 226.496 time: 8.315 s
estimate train loss: 232.413 time: 8.255 s
train loss: 21.799 
val loss: 233.015 
Epoch 303/499
----------
estimate train loss: 238.299 time: 9.153 s
estimate train loss: 228.743 time: 

val loss: 231.412 
Epoch 341/499
----------
estimate train loss: 246.758 time: 9.131 s
estimate train loss: 228.624 time: 8.092 s
estimate train loss: 244.706 time: 8.259 s
train loss: 27.088 
val loss: 237.808 
Epoch 342/499
----------
estimate train loss: 226.429 time: 9.164 s
estimate train loss: 237.873 time: 8.305 s
estimate train loss: 233.471 time: 8.399 s
train loss: 28.138 
val loss: 238.034 
Epoch 343/499
----------
estimate train loss: 211.381 time: 9.028 s
estimate train loss: 235.433 time: 8.302 s
estimate train loss: 233.618 time: 8.429 s
train loss: 21.950 
val loss: 231.636 
Epoch 344/499
----------
estimate train loss: 235.457 time: 9.092 s
estimate train loss: 245.577 time: 8.344 s
estimate train loss: 247.705 time: 8.460 s
train loss: 28.360 
val loss: 239.678 
Epoch 345/499
----------
estimate train loss: 234.525 time: 9.125 s
estimate train loss: 218.775 time: 8.310 s
estimate train loss: 227.604 time: 8.443 s
train loss: 22.013 
val loss: 229.365 
Epoch 346/499
--