# Dependencies

In [20]:
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim

# VGG16

In [2]:
# Two separate VGG16 instances from torchvision.
# pretrained=True asks torchvision for its pretrained weights (presumably the
# ImageNet ones; may trigger a download on first use — verify for your torchvision version).
vgg16_trained = models.vgg16(pretrained=True)
# No arguments: torchvision's default construction, i.e. no pretrained weights requested.
vgg16_untrained = models.vgg16()

In [3]:
def modify_model(model, input_channels, output_units):
    '''
    Adapt a VGG-style model, in place, to a new input depth and output size.

    The first layer of `model.features` and the layer at index 6 of
    `model.classifier` are replaced by newly constructed layers, so any
    weights those two layers held before (e.g. pretrained ones) are discarded.

    Parameters

    model: instance of a pytorch model to be modified
    input_channels: channels of input tensor
    output_units: number of units in the last layer

    Returns

    None (the model is modified in place)
    '''
    # Read the widths from the layers being replaced instead of hard-coding the
    # VGG16 values; fall back to those values (64 / 4096) when a layer does not
    # expose the attribute, which keeps the previous behaviour for such models.
    conv_out_channels = getattr(model.features[0], 'out_channels', 64)
    head_in_features = getattr(model.classifier[6], 'in_features', 4096)

    model.features[0] = nn.Conv2d(input_channels, conv_out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)
    model.classifier[6] = nn.Linear(head_in_features, output_units)

In [4]:
# Re-shape the pretrained network: 2-channel input, 512-unit output layer.
modify_model(vgg16_trained, input_channels=2, output_units=512)

Test with a random input tensor:

In [5]:
# Smoke test: one random sample with 2 channels and 256x256 spatial size,
# i.e. a tensor of shape (1, 2, 256, 256).
x = torch.randn(1, 2, 256, 256)
with torch.no_grad():  # only the output shape is inspected, so skip gradient tracking
    output = vgg16_trained(x)
print(output.shape)

torch.Size([1, 512])


Features:

In [6]:
# Display the feature-extractor stack; layer (0) should now be the 2-input-channel conv.
vgg16_trained.features

Sequential(
  (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_si

In [7]:
# Display the classifier head; layer (6) should now end in 512 output features.
vgg16_trained.classifier

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=512, bias=True)
)

Do the same with the untrained version:

In [8]:
# Same re-shaping for the untrained network: 2-channel input, 512-unit output layer.
modify_model(vgg16_untrained, input_channels=2, output_units=512)

# Settings

Collect the training settings (loss criterion, optimizer and learning rate) in a dictionary:

In [12]:
# Training configuration. 'criterion' and 'optimizer' hold the classes
# themselves (they are not called here); 'lr' is the learning rate.
settings = {
    'criterion': nn.MSELoss,
    'optimizer': optim.Adam,
    'lr': 0.001,
}

In [17]:
# Flag used below to decide whether model and batches are moved to the GPU.
use_cuda = torch.cuda.is_available()

Move the model to the GPU if CUDA is available:

In [18]:
# NOTE(review): `model` is not assigned in any cell visible above (only
# `vgg16_trained` and `vgg16_untrained` are) — presumably it was bound in a
# cell that is no longer shown; confirm this survives Restart Kernel -> Run All.
if use_cuda:
    model = model.cuda()  # rebind `model` to the result of moving it to the GPU

Define the train and test functions:

In [21]:
def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    '''
    Train `model` for `n_epochs`, run a validation pass after every epoch and
    write a checkpoint whenever the validation loss improves.

    Parameters

    n_epochs: number of passes over loaders['train']
    loaders: mapping with 'train' and 'valid' entries, each iterable over (data, target) batches
    model: pytorch model to optimise (its parameters are updated in place)
    optimizer: optimizer instance used for zero_grad() / step()
    criterion: callable returning the loss tensor for (output, target)
    use_cuda: if True, every batch is moved to the GPU before the forward pass
    save_path: destination of the state_dict saved when the validation loss decreases

    Returns

    The trained model (state after the last epoch; the best-validation weights
    are the ones stored at save_path)
    '''
    # float('inf') instead of np.Inf, which no longer exists in NumPy 2.0.
    valid_loss_min = float('inf')

    for epoch in range(1, n_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        valid_loss = 0.
        # Loss of the most recent batch, printed under the "Mean Phase Error"
        # label. Initialised so an empty loader cannot leave them unbound.
        mean_train_error = 0.
        mean_valid_error = 0.

        # train
        model.train()  # training mode (affects layers such as dropout)
        for batch_idx, (data, target) in enumerate(loaders['train']):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()  # clear gradients left over from the previous batch
            output = model(data)  # make prediction
            loss = criterion(output, target)  # calculate loss
            loss.backward()  # backpropagation
            optimizer.step()  # update weights
            mean_train_error = loss.item()  # plain Python float, detached from the graph
            # incremental mean of the batch losses seen so far in this epoch
            train_loss += (1 / (batch_idx + 1)) * (mean_train_error - train_loss)

        # validation
        model.eval()  # evaluation mode (affects layers such as dropout)
        with torch.no_grad():  # no backward pass here, so skip gradient tracking
            for batch_idx, (data, target) in enumerate(loaders['valid']):
                if use_cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)  # make prediction
                loss = criterion(output, target)  # calculate loss
                mean_valid_error = loss.item()
                valid_loss += (1 / (batch_idx + 1)) * (mean_valid_error - valid_loss)

        print('Epoch {}, Training Loss: {:.6f}, Mean Phase Error: {:.6f}, Validation Loss: {:.6f}, Mean Phase Error: {:.6f}'.format(epoch,
                                                                                                                                    train_loss,
                                                                                                                                    mean_train_error,
                                                                                                                                    valid_loss,
                                                                                                                                    mean_valid_error))
        print('Computation time: {:.4f} sec'.format(time.time() - start_time))

        if valid_loss < valid_loss_min:
            print('Validation loss has decreased from {:.6f} -> {:.6f}'.format(valid_loss_min, valid_loss))
            valid_loss_min = valid_loss
            # checkpoint only the weights, at the caller-supplied location
            torch.save(model.state_dict(), save_path)

    return model