## Notes

- Need to read convolution arithmetic guide
- Perceptual loss network does not have to be a whole module.
- Add learning rate decay
- Models and parameters must be moved to the GPU before the optimiser is created.
- Use pinned memory

**How the optimiser works with a closure**

Let's say we've created an optimiser as such:

`optimiser = optim.SGD(model.parameters())`

When `loss.backward()` is called, the model's Variables are given gradient values so that when we call `optimiser.step()`, the optimiser updates the parameters of the model with their respective gradients.

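If a closure argument is given to `optimiser.step(closure)`, the optimiser is free to call the closure several times before the actual optimisation step is done, for example because it keeps a history of the losses or some other internal state (as `optim.LBFGS` does). The closure should therefore clear the gradients, recompute the loss, call `loss.backward()` and return the loss.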

In [None]:
# ! ls $dirname'/sample' | tail -100
# Image.open(dirname+'/sample/'+'n09332890_8608.JPEG')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable
from torch.utils.serialization import load_lua
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models, datasets

import numpy as np

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

import time
import copy
import os

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
matplotlib.rc('figure', figsize=(12, 5))

In [None]:
# Record whether a CUDA device is available and echo the result in the notebook.
# NOTE(review): the cells visible in this file call .cuda() unconditionally and
# never consult this flag.
use_cuda = torch.cuda.is_available()
use_cuda

## Loading and preprocessing images

In [None]:
# Per-channel mean and std; these values are specific to the pretrained models.
# Both are reshaped to (3, 1, 1) so that they broadcast over height and width.
mean_vec = torch.FloatTensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std_vec = torch.FloatTensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Preprocess a PIL image: rescale, centre-crop to 256x256, convert to a tensor
# and normalise it with the statistics above.
# ToTensor yields a (channels, height, width) tensor; no batch dimension is
# added at this stage.
image_to_tensor = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean_vec, std_vec),
])

# Inverse mapping: undo the normalisation, then convert back to a PIL image.
tensor_to_image = transforms.Compose([
    transforms.Lambda(lambda tensor: tensor * std_vec + mean_vec),
    transforms.ToPILImage(),
])

In [None]:
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round crop_size down to a multiple of upscale_factor.

    For positive arguments the result is the largest value not exceeding
    crop_size that upscale_factor divides evenly.
    """
    remainder = crop_size % upscale_factor
    return crop_size - remainder

def input_transform(crop_size, upscale_factor):
    """Build the transform that produces the low-resolution network input.

    The image is centre-cropped to crop_size, rescaled to
    crop_size // upscale_factor, converted to a tensor and normalised with
    the module-level mean_vec / std_vec.
    """
    low_res_size = crop_size // upscale_factor
    steps = [
        transforms.CenterCrop(crop_size),
        transforms.Scale(low_res_size),  # downsample
        transforms.ToTensor(),
        transforms.Normalize(mean_vec, std_vec),
    ]
    return transforms.Compose(steps)

def target_transform(crop_size):
    """Build the transform that produces the full-resolution target.

    The image is centre-cropped to crop_size, converted to a tensor and
    normalised with the module-level mean_vec / std_vec; no rescaling is done.
    """
    steps = [
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean_vec, std_vec),
    ]
    return transforms.Compose(steps)

In [None]:
class ImagesOnlyDataset(Dataset):
    """Dataset that yields every entry of a directory as an RGB image.

    Args:
        root: directory to list; every entry is treated as an image file.
        transform: optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, transform=None):
        super(ImagesOnlyDataset, self).__init__()

        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always refers to the same file across runs and machines.
        self.images = [os.path.join(root, f) for f in sorted(os.listdir(root))]
        self.transform = transform

    def __getitem__(self, index):
        """Return image number `index`, converted to RGB and transformed."""
        # Image.open is lazy and keeps the file open. convert() returns an
        # independent in-memory image, so the file can be closed right away.
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.images)

In [None]:
class DoubleImageDataset(Dataset):
    """Dataset that yields an (input, target) pair built from one image file.

    Args:
        img_dir: directory to list; every entry is treated as an image file.
        input_transform: optional callable applied to the input copy.
        target_transform: optional callable applied to the target copy.
    """

    def __init__(self, img_dir, input_transform=None, target_transform=None):
        super(DoubleImageDataset, self).__init__()

        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always refers to the same file across runs and machines.
        self.image_fnames = [
            os.path.join(img_dir, f) for f in sorted(os.listdir(img_dir))
        ]
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return (input_img, target_img) for image number `index`."""
        # Image.open is lazy and keeps the file open. convert() returns an
        # independent in-memory image, so the file can be closed right away.
        with Image.open(self.image_fnames[index]) as raw:
            input_img = raw.convert('RGB')
        # Separate copy so the two transforms never operate on the same object.
        target_img = input_img.copy()

        if self.input_transform:
            input_img = self.input_transform(input_img)
        if self.target_transform:
            target_img = self.target_transform(target_img)

        return input_img, target_img

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.image_fnames)

In [None]:
# Machine-specific location of the training images; adjust for your setup.
dirname = '/home/samir/Downloads/ILSVRC2012_img_train/train'
upscale_factor = 2
# 256 is already divisible by 2, so crop_size stays 256 here.
crop_size = calculate_valid_crop_size(256, upscale_factor)

# Each item is a (low-resolution input, full-resolution target) pair built
# from the same file: the input is rescaled to crop_size // upscale_factor,
# the target is only cropped.
train_dataset = DoubleImageDataset(
    dirname,
    input_transform=input_transform(crop_size, upscale_factor),
    target_transform=target_transform(crop_size))

In [None]:
# Shuffled mini-batches of (input, target) pairs loaded by 6 worker processes.
# NOTE(review): pin_memory is left at its default; the notes at the top of this
# file list pinned memory as a to-do.
data_loader = DataLoader(
    train_dataset,
    batch_size=4, # INCREASE BATCH SIZE
    shuffle=True,
    num_workers=6)

In [None]:
# Pull a single batch from the loader for visual inspection in the next cells.
low_res, high_res = next(iter(data_loader))

In [None]:
# Display the first low-resolution input of the batch.
tensor_to_image(low_res[0])

In [None]:
# Display the matching full-resolution target.
tensor_to_image(high_res[0])

## Construct network

**Blocks required:**

- Semantic network - trainable network used for objective inference
- Perceptual loss network - outputs the loss between the activations of two inputs at some layer
- Training loss function
- Optimiser

In [None]:
# Pretrained network whose early layers are reused below as the feature
# extractor; the VGG alternative is kept commented out.
# pretrained_model = models.vgg16_bn(pretrained=True).features
pretrained_model = models.resnet50(pretrained=True)

In [None]:
# Keep only the first five child modules of the pretrained network.
# (For torchvision's ResNet these are presumably conv1, bn1, relu, maxpool and
# layer1 — verify against the installed version.)
pretrained_model_subset = nn.Sequential(*list(pretrained_model.children())[:5])

In [None]:
# Move the truncated network to the GPU.
pretrained_model_subset = pretrained_model_subset.cuda()

In [None]:
# Visualise channel 0 of the activations for sample 3 of the batch.
# NOTE(review): tensor_to_image also applies the RGB de-normalisation
# (x * std_vec + mean_vec) to this activation map — confirm that is intended.
tensor_to_image(pretrained_model_subset(Variable(low_res).cuda()).cpu().data[3][0])

In [None]:
# The input image that produced the activation map shown above.
tensor_to_image(low_res[3])

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in a skip connection.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.

    When n_in == n_out the skip path is the identity and `shortcut` is None.
    When the channel counts differ, the input is projected with a 1x1
    convolution followed by batch norm so that it can be added to the
    main path.
    """

    def __init__(self, n_in, n_out):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(n_in, n_out, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(n_out)
        self.conv2 = nn.Conv2d(n_out, n_out, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(n_out)
        self.relu = nn.ReLU(inplace=True)

        # Created last, and only when needed, so that the equal-channel case
        # has exactly the same parameters and initialisation as before.
        if n_in != n_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(n_in, n_out, 1, bias=False),
                nn.BatchNorm2d(n_out))
        else:
            self.shortcut = None

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))."""
        # Skip path: identity, or the 1x1 projection when channels differ.
        residual = x if self.shortcut is None else self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        # NOTE: the author flagged this final ReLU as possibly unnecessary; it
        # is kept so the block's output is unchanged.
        out = self.relu(out)
        return out


class UpsampleBlock(nn.Module):
    """Bilinear upsampling followed by a 3x3 conv, batch norm and ReLU.

    Args:
        size: spatial size (height, width) that the input is resized to.
            Defaults to (256, 256).
        n_channels: number of input and output channels. Defaults to 64.

    NOTE: `size` is an absolute target, not a scale factor, so stacking two
    blocks with the same size only enlarges the feature map once.
    """

    def __init__(self, size=(256, 256), n_channels=64):
        super(UpsampleBlock, self).__init__()

        self.upsample = nn.Upsample(size=size, mode='bilinear')
        self.conv = nn.Conv2d(n_channels, n_channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Resize x to the configured size, then apply conv -> BN -> ReLU."""
        x = self.upsample(x)
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class SemanticNetwork(nn.Module):
    """Trainable network mapping a low-resolution image to an upsampled one.

    Pipeline: 9x9 conv stem -> four residual blocks -> two upsample blocks ->
    9x9 conv back to 3 channels with Tanh, rescaled to the range [0, 255].
    """

    def __init__(self):
        super(SemanticNetwork, self).__init__()

        # Stem: lift the 3 input channels to 64 feature channels.
        self.receptor = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Feature processing at the input resolution.
        self.resblock1 = ResidualBlock(64, 64)
        self.resblock2 = ResidualBlock(64, 64)
        self.resblock3 = ResidualBlock(64, 64)
        self.resblock4 = ResidualBlock(64, 64)

        # TODO from the author: try using only one upsample block.
        self.upsampler1 = UpsampleBlock()
        self.upsampler2 = UpsampleBlock()

        # Head: collapse the features back to 3 channels, squashed by Tanh.
        self.reducer = nn.Sequential(
            nn.Conv2d(64, 3, 9, padding=4, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run x through every stage in order and rescale the Tanh output."""
        stages = (
            self.receptor,
            self.resblock1,
            self.resblock2,
            self.resblock3,
            self.resblock4,
            self.upsampler1,
            self.upsampler2,
            self.reducer,
        )
        for stage in stages:
            x = stage(x)

        # Tanh gives [-1, 1]; shift and scale to [0, 255].
        # NOTE(review): the targets in this file are normalised with
        # mean_vec / std_vec, which is a different range — confirm the
        # intended output scale.
        return (x + 1) * 127.5

In [None]:
class PerceptualLossNetwork(nn.Module):
    """MSE between the activations of two inputs in a frozen feature extractor.

    Args:
        pretrained_model: module whose first `activation_layer` children are
            used as the feature extractor. Those children are modified in
            place: their parameters are frozen and they are put in eval mode.
        activation_layer: number of leading child modules to keep.
    """

    def __init__(self, pretrained_model, activation_layer):
        super(PerceptualLossNetwork, self).__init__()

        layers = list(pretrained_model.children())[:activation_layer]
        self.net = nn.Sequential(*layers)
        for param in self.net.parameters():
            param.requires_grad = False
        # A fixed loss network must not use batch statistics or update its
        # BatchNorm running averages, so keep it in inference mode.
        self.net.eval()

        self.loss_criterion = nn.MSELoss()

    def train(self, mode=True):
        """Set the module's mode while keeping the feature extractor in eval."""
        super(PerceptualLossNetwork, self).train(mode)
        self.net.eval()
        return self

    def forward(self, x, target):
        """Return the MSE between the activations of x and of target."""

        # PERCEPTUAL LOSS IS MULTIPLE LAYERS (with weights)
        # (author's note: only a single activation layer is compared here)

        x_activations = self.net(x)
        target_activations = self.net(target)

        loss = self.loss_criterion(x_activations, target_activations)
        return loss

In [None]:
# Number of leading child modules the loss network should keep.
# NOTE(review): pretrained_model_subset was built with [:5], so it has only
# five children and the [:7] slice inside PerceptualLossNetwork keeps all of
# them — pass pretrained_model instead if a later layer is really wanted.
activation_layer = 4+ 3 # MIGHT NEED A LATER ACTIVATION LAYER

semantic_net = SemanticNetwork()

percept_net = PerceptualLossNetwork(pretrained_model_subset, activation_layer)
# The constructor already freezes the feature extractor's parameters; this
# loop repeats that for every parameter of the module.
for param in percept_net.parameters():
    param.requires_grad = False

In [None]:
# move to GPU
# Both networks must be on the GPU before the optimiser is created (see the
# notes at the top of this file).
semantic_net = semantic_net.cuda()
percept_net = percept_net.cuda()

## Define loss and optimisation functions

In [None]:
# Only the semantic network is optimised; the perceptual loss network is frozen.
optimiser = optim.Adam(semantic_net.parameters(), lr=1e-3) # CHANGE OPTIMISER

## Training

In [None]:
# Debug handles: train() stores the most recent input (x), target (y) and
# prediction (z) batches here so that later cells can inspect them.
x = None
y = None
z = None


def _loss_value(loss):
    """Return a scalar loss as a Python number on old and new PyTorch."""
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def train(num_epochs=5, max_batches=1):
    """Train semantic_net against the perceptual loss.

    Args:
        num_epochs: number of passes over data_loader.
        max_batches: maximum number of batches processed per epoch. The
            default of 1 keeps the original single-batch debugging behaviour;
            pass None to iterate over the whole loader.

    Side effects:
        Updates the parameters of semantic_net through the module-level
        optimiser, prints progress, and stores the last input, target and
        prediction batches in the globals x, y and z.
    """

    global x
    global y
    global z

    print('Training ...')
    start_time = time.time()

    for epoch in range(num_epochs):
        print('Epoch {:3d}/{:3d}'.format(epoch, num_epochs))
        epoch_loss = 0
        batches_seen = 0

        for i, (input_img, target_img) in enumerate(data_loader):

            if max_batches is not None and i >= max_batches:
                break

            # load batch
            input_img = Variable(input_img).cuda()
            target_img = Variable(target_img).cuda()

            print(input_img.shape)
            x = input_img

            # clear gradients
            optimiser.zero_grad()

            # forward pass
            predicted_img = semantic_net(input_img)

            print(predicted_img.shape)
            z = predicted_img

            percept_loss = percept_net(predicted_img, target_img)

            print(target_img.shape)
            y = target_img

            # backward pass: without it the parameters have no gradients and
            # optimiser.step() leaves the network untouched
            percept_loss.backward()

            # optimise
            optimiser.step()

            batch_loss = _loss_value(percept_loss)
            epoch_loss += batch_loss
            batches_seen += 1
            print('Batch loss {:4f}'.format(batch_loss))

        if batches_seen:
            print('Epoch mean loss {:4f}'.format(epoch_loss / batches_seen))

    print('Finished in {:.1f}s'.format(time.time() - start_time))

In [None]:
# Run a single epoch.
train(1)

In [None]:
# Display the first target image of the last batch seen by train().
tensor_to_image(y.cpu().data[0])

In [None]:
# Visualise channel 0 of the loss network's activations for the first predicted
# image. NOTE(review): tensor_to_image also applies the RGB de-normalisation to
# this activation map — confirm that is intended.
tensor_to_image(percept_net.net(z).cpu().data[0][0])