In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable
from torch.utils.serialization import load_lua
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models, datasets

import numpy as np

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

import time
import copy
import os

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
matplotlib.rc('figure', figsize=(12, 5))

use_cuda = torch.cuda.is_available()
use_cuda

## Image loading and processing

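Pretrained torchvision vision models expect every image to be normalised with a per-channel mean vector and standard-deviation vector (defined below); `tensor_to_image` undoes this normalisation for display.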

In [None]:
# Per-channel normalisation statistics, shaped (3, 1, 1) so they broadcast
# over a (3, H, W) image tensor.
mean_vec = torch.FloatTensor([ 0.485, 0.456, 0.406 ]).view(3,1,1)
std_vec = torch.FloatTensor([ 0.229, 0.224, 0.225 ]).view(3,1,1)

# Convert a normalised (3, H, W) tensor back to a PIL image for viewing.
# The un-normalised values are clamped to [0, 1] because ToPILImage scales
# float tensors by 255 and converts to 8-bit: out-of-range values (common
# for raw network outputs) would otherwise wrap around into wrong colours.
tensor_to_image = transforms.Compose([
    transforms.Lambda(lambda x: (x * std_vec + mean_vec).clamp(0, 1)),
    transforms.ToPILImage()
])

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to a multiple of ``upscale_factor``.

    A crop that divides evenly by the upscale factor survives a
    downsample/upsample round trip with exactly the same size.
    """
    leftover = crop_size % upscale_factor
    valid_size = crop_size - leftover
    return valid_size

# create low resolution square images
def input_transform(crop_size, upscale_factor):
    """Build the transform producing the low-resolution network input.

    Steps: square centre crop of ``crop_size``, resize to
    ``crop_size // upscale_factor``, convert to a tensor, then normalise
    with ``mean_vec`` / ``std_vec``.
    NOTE: newer torchvision versions name ``transforms.Scale`` ``Resize``.
    """
    low_res_size = crop_size // upscale_factor
    pipeline = [
        transforms.CenterCrop(crop_size),
        transforms.Scale(low_res_size),  # downsample
        transforms.ToTensor(),
        transforms.Normalize(mean_vec, std_vec),
    ]
    return transforms.Compose(pipeline)

# maintain high resolution square images
def target_transform(crop_size):
    """Build the transform producing the full-resolution training target.

    Steps: square centre crop of ``crop_size``, convert to a tensor, then
    normalise with ``mean_vec`` / ``std_vec`` (no resizing).
    """
    pipeline = [
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean_vec, std_vec),
    ]
    return transforms.Compose(pipeline)

In [None]:
# load each image as a duplicate pair
class DoubleImageDataset(Dataset):
    """Dataset yielding ``(input, target)`` pairs built from one image file.

    Every regular file directly inside ``img_dir`` is treated as an image.
    Input and target start as identical RGB copies and are then passed
    through their own optional transforms (e.g. low-res input, high-res
    target).

    Args:
        img_dir: directory containing the image files.
        input_transform: optional callable applied to the input image.
        target_transform: optional callable applied to the target image.
    """

    def __init__(self, img_dir, input_transform=None, target_transform=None):
        super(DoubleImageDataset, self).__init__()

        # Sorted so the index -> file mapping is deterministic (os.listdir
        # order is arbitrary); sub-directories are skipped because they
        # cannot be opened as images.
        candidates = (os.path.join(img_dir, f) for f in os.listdir(img_dir))
        self.image_fnames = sorted(p for p in candidates if os.path.isfile(p))
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # Context manager releases the file handle promptly (matters with
        # many DataLoader workers); convert() returns an in-memory copy.
        with Image.open(self.image_fnames[index]) as img:
            input_img = img.convert('RGB')
        target_img = input_img.copy()

        if self.input_transform:
            input_img = self.input_transform(input_img)
        if self.target_transform:
            target_img = self.target_transform(target_img)

        return input_img, target_img

    def __len__(self):
        return len(self.image_fnames)

In [None]:
# Directory of training images. Override it with the IMAGENET_TRAIN_DIR
# environment variable instead of editing this machine-specific default.
dirname = os.environ.get('IMAGENET_TRAIN_DIR',
                         '/home/samir/Downloads/ILSVRC2012_img_train/train')
upscale_factor = 2
# largest crop <= 256 that divides evenly by upscale_factor
crop_size = calculate_valid_crop_size(256, upscale_factor)

train_dataset = DoubleImageDataset(
    dirname,
    input_transform=input_transform(crop_size, upscale_factor),
    target_transform=target_transform(crop_size))

data_loader = DataLoader(
    train_dataset,
    batch_size=8, # TODO: increase batch size
    shuffle=True,
    num_workers=7,
    pin_memory=True)  # pinned host memory allows asynchronous .cuda() copies

# preview some images
low_res, high_res = next(iter(data_loader))

In [None]:
# low-resolution input of the first sample in the batch
tensor_to_image(low_res[0, :, :, :])

In [None]:
# high-resolution target of the first sample in the batch
tensor_to_image(high_res[0, :, :, :])

## Use pretrained model to visualise activations

- Load pretrained model
- Pass through some test images to inspect activations at certain layers

(VGG-specific) We want to see what selected layers of the convolutional feature extractor output — mostly taken right after a ReLU.

In [None]:
# Pretrained VGG16-BN convolutional feature extractor, used only as a fixed
# feature extractor for the perceptual loss. Put it in eval mode so its
# BatchNorm layers use the pretrained running statistics rather than
# per-batch statistics, and do not overwrite them when run. The deep copies
# made from it below inherit this mode.
pretrained_model = models.vgg16_bn(pretrained=True).features
pretrained_model.eval()
pretrained_model

In [None]:
# Build one truncated copy of the feature extractor per selected layer,
# paired with that layer's weight in the perceptual loss.
children = list(pretrained_model.children())
layer_idxs = [0, 2, 5, 9, 12, 16]  # last layer kept in each truncated copy
layer_weights = [0.05, 0.05, 0.1, 0.1, 0.35, 0.35]
# NOTE(review): the markdown says activations are taken after ReLU; in
# torchvision's vgg16_bn index 0 is the first Conv2d, not a ReLU — confirm
# that is intended.

conv_models = []
for _layer_idx, _weight in zip(layer_idxs, layer_weights):
    _prefix = nn.Sequential(*children[:_layer_idx + 1])
    conv_models.append((copy.deepcopy(_prefix), _weight))

# each truncated model must be an independent copy
assert conv_models[0][0] is not conv_models[1][0]
conv_models

In [None]:
# Run one high-res sample through each truncated VGG copy and keep its
# activation maps (batch dimension squeezed away) for plotting below.
test_img = high_res[0].unsqueeze(0)

out_imgs = []
for conv_model, _ in conv_models:
    conv_model = conv_model.cuda()
    # Eval mode: BatchNorm uses the stored running statistics and does not
    # update them, so this preview cannot alter the same modules that the
    # perceptual loss reuses later.
    conv_model.eval()
    out_img = conv_model(Variable(test_img).cuda())
    out_img = out_img.cpu().squeeze().data
    out_imgs.append(out_img)

In [None]:
# Show activation map `filter_idx` from each truncated model on a 2x3 grid.
# zip() stops at the shorter sequence, so a different number of layers
# cannot index past the end of out_imgs.
filter_idx = 30
gs = matplotlib.gridspec.GridSpec(2, 3)

for g, activation, layer_idx in zip(gs, out_imgs, layer_idxs):
    ax = plt.subplot(g)
    ax.imshow(activation[filter_idx])
    ax.set_title('layer {}, filter {}'.format(layer_idx, filter_idx))
    ax.set_xticks([])
    ax.set_yticks([])

## Upsampling model 

- Takes in a low resolution image and upscales it to the same size as its target
- The network is trained so that its output matches the high-resolution target

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a skip connection and final ReLU.

    Spatial size is preserved (padding=1). With ``n_in == n_out`` the skip is
    the identity and no extra parameters exist. Otherwise a 1x1 conv +
    BatchNorm projection maps the input to ``n_out`` channels, so the
    residual addition has matching shapes.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
    """

    def __init__(self, n_in, n_out):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(n_in, n_out, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(n_out)
        self.conv2 = nn.Conv2d(n_out, n_out, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(n_out)
        self.relu = nn.ReLU(inplace=True)

        # Without a projection, `out += residual` fails whenever the channel
        # counts differ. Only created in that case, so state_dicts of
        # equal-channel blocks are unchanged.
        if n_in != n_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(n_in, n_out, 1, bias=False),
                nn.BatchNorm2d(n_out))
        else:
            self.shortcut = None

    def forward(self, x):
        # skip path: identity, or projection when channel counts differ
        residual = x if self.shortcut is None else self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class UpsampleBlock(nn.Module):
    """Upsample a feature map, then refine it with conv + BatchNorm + ReLU.

    Bilinear upsampling only interpolates existing values. The following
    3x3 conv gives the network learnable weights to sharpen/refine the
    interpolated features.

    Args:
        channels: number of input (and output) feature channels.
        scale_factor: spatial upsampling factor.
    """

    def __init__(self, channels=64, scale_factor=2):
        super(UpsampleBlock, self).__init__()

        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class UpsamplingNetwork(nn.Module):
    """Super-resolution network: low-res RGB image in, upscaled RGB image out.

    Pipeline: 9x9 conv "receptor" (3 -> 64 channels) -> four 64-channel
    residual blocks -> one UpsampleBlock -> 9x9 conv "reducer" back to 3
    channels followed by Tanh.

    NOTE(review): Tanh bounds the output to [-1, 1], while targets are
    normalised with mean_vec/std_vec and can fall outside that range —
    confirm this output activation/scaling is intended.
    """

    def __init__(self):
        super(UpsamplingNetwork, self).__init__()

        # submodules are created in the same order (and under the same
        # attribute names) so initialisation and state_dict keys are stable
        self.receptor = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))

        self.resblock1 = ResidualBlock(64, 64)
        self.resblock2 = ResidualBlock(64, 64)
        self.resblock3 = ResidualBlock(64, 64)
        self.resblock4 = ResidualBlock(64, 64)

        # TODO: experiment with the number of upsample blocks
        self.upsampler1 = UpsampleBlock()

        self.reducer = nn.Sequential(
            nn.Conv2d(64, 3, 9, padding=4, bias=False),
            nn.Tanh())

    def forward(self, x):
        features = self.receptor(x)
        residual_stack = (self.resblock1, self.resblock2,
                          self.resblock3, self.resblock4)
        for block in residual_stack:
            features = block(features)
        upsampled = self.upsampler1(features)
        return self.reducer(upsampled)

In [None]:
# build the super-resolution network and move it to the GPU
upsampling_net = UpsamplingNetwork()
upsampling_net = upsampling_net.cuda()

## Perceptual loss model

- Construct perceptual loss model from required pretrained model layers
- In the forward pass, run both the predicted and the target image through every truncated conv model and sum the weighted MSE between their activations.
- Test outputs

**TODO:** does the input need to be clamped?
**TODO:** the target branch of the perceptual loss needs no gradients — make it volatile/detached.

In [None]:
class PerceptualLossNetwork(nn.Module):
    """Weighted sum of MSE losses between activations of two image batches.

    Args:
        activation_models: list of ``(model, weight)`` pairs. Each model maps
            an image batch to an activation tensor. Its parameters are frozen
            here, and it is kept in eval mode permanently.

    ``forward(x, target)`` returns
    ``sum(weight * MSE(model(x), model(target)))`` over all pairs.
    """

    def __init__(self, activation_models):
        super(PerceptualLossNetwork, self).__init__()

        for model, _ in activation_models:
            for param in model.parameters():
                param.requires_grad = False

        self.activation_models = activation_models
        # A plain list is invisible to nn.Module, so .cuda(), .eval(),
        # .parameters() and state_dict() would silently skip these models;
        # register them as submodules too.
        self.feature_models = nn.ModuleList(
            [model for model, _ in activation_models])
        self.loss_fn = nn.MSELoss()
        # Fixed feature extractors: BatchNorm must use its stored running
        # statistics rather than batch statistics, and must not update them.
        self.feature_models.eval()

    def train(self, mode=True):
        """Set training mode on this module but keep the feature models in eval."""
        super(PerceptualLossNetwork, self).train(mode)
        self.feature_models.eval()
        return self

    def forward(self, x, target):
        total_loss = 0.0
        for model, weight in self.activation_models:
            x_activations = model(x)
            # the target is a constant reference: no graph needs to be kept
            target_activations = model(target).detach()
            total_loss += self.loss_fn(x_activations, target_activations) * weight
        return total_loss

In [None]:
# perceptual loss over the truncated VGG feature extractors, on the GPU
percept_net = PerceptualLossNetwork(conv_models)
percept_net = percept_net.cuda()

## Define optimiser

In [None]:
# Adam over the upsampling network only; the perceptual-loss feature models
# are frozen (requires_grad=False) and never optimised.
# TODO: try other optimisers, e.g. optim.LBFGS (its step() needs a closure).
learning_rate = 1e-4
optimiser = optim.Adam(upsampling_net.parameters(), lr=learning_rate)

## Training

In [None]:
num_epochs = 1

print('Training ...')
start_time = time.time()
weights = upsampling_net.state_dict()

for epoch in range(num_epochs):
    print('Epoch {:3d}/{:3d}'.format(epoch, num_epochs))
    epoch_loss = 0

    for i, (input_img, target_img) in enumerate(data_loader, 1):

        # load batch
        input_img = Variable(input_img, requires_grad=False).cuda(async=True)
        target_img = Variable(target_img, requires_grad=False).cuda(async=True)

        # clear gradients
        optimiser.zero_grad()

        # forward pass
        predicted_img = upsampling_net(input_img)

        # get loss
        percept_loss = percept_net(predicted_img, target_img)
        
        # compute gradients
        percept_loss.backward()

        # optimise
        optimiser.step()
        
        print('Batch loss: {:4f}'.format(percept_loss.data[0]))

In [None]:
# torch.save(upsampling_net.state_dict(), './model-35min.state')

In [None]:
# minutes elapsed since training started (measured when this cell runs)
elapsed_seconds = time.time() - start_time
end_time = elapsed_seconds / 60
end_time

## Test upsampling network

In [None]:
# fetch a fresh (low-res, high-res) batch to evaluate the trained network on
sample_batch = next(iter(data_loader))
low_res, high_res = sample_batch

In [None]:
idx = 4
# NOTE(review): ad-hoc gain applied to the network output before display,
# presumably to compensate for the Tanh-bounded output range — confirm.
display_gain = 1.5

# Inference in eval mode, so BatchNorm uses its running statistics and this
# preview does not update them (which would alter the trained model).
# The previous mode is restored afterwards.
was_training = upsampling_net.training
upsampling_net.eval()
y = upsampling_net(Variable(low_res).cuda())
upsampling_net.train(was_training)

y = y.cpu().data[idx] * display_gain
print(y.min(), y.max())
tensor_to_image(y)