In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable
from torch.utils.serialization import load_lua
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models, datasets

import numpy as np

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

import time
import copy
import os

# Render matplotlib figures inline, at retina resolution, with a wide default size.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
matplotlib.rc('figure', figsize=(12, 5))

# Record whether a CUDA device is available; the bare name shows it as the cell output.
# NOTE(review): the cells shown below call .cuda() unconditionally and never read
# use_cuda, so this flag is informational only.
use_cuda = torch.cuda.is_available()
use_cuda

In [None]:
# Per-channel RGB mean / standard deviation used by normalise() and denormalise().
# The values are the ones torchvision documents for its pretrained models.
# Shaped (3, 1, 1) so they broadcast over (3, H, W) image tensors.
mean_vec = torch.FloatTensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std_vec = torch.FloatTensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

# Callable that turns a tensor into a PIL image so a cell can display it.
tensor_to_image = transforms.Compose([transforms.ToPILImage()])

def ycbcr_to_rgb(y, cb, cr):
    """Merge separate Y, Cb and Cr planes into a single RGB tensor.

    `y` is squeezed, scaled by 255 and clamped to [0, 255] before being turned
    into an 8-bit greyscale PIL image. `cb` and `cr` are converted with
    `ToPILImage` and resized (bicubic) to the size of the Y image, so they may
    be lower resolution than `y`.

    Returns the merged image converted to RGB and passed through `ToTensor`.
    """
    to_pil = transforms.ToPILImage()

    # Luma plane: tensor -> 8-bit 'L' image.
    luma_plane = (y.squeeze() * 255.0).clamp(0, 255)
    luma_img = Image.fromarray(np.uint8(luma_plane), mode='L')

    # Chroma planes are brought up to the luma resolution.
    target_size = luma_img.size
    chroma_imgs = [
        to_pil(plane).resize(target_size, Image.BICUBIC)
        for plane in (cb, cr)
    ]

    merged = Image.merge('YCbCr', [luma_img] + chroma_imgs)
    return transforms.ToTensor()(merged.convert('RGB'))

def normalise(x):
    """Standardise an RGB tensor: subtract `mean_vec`, then divide by `std_vec`."""
    centred = x - mean_vec
    return centred / std_vec

def denormalise(x):
    """Undo `normalise` on an RGB tensor: multiply by `std_vec`, then add `mean_vec`."""
    rescaled = x * std_vec
    return rescaled + mean_vec

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round `crop_size` down to the nearest multiple of `upscale_factor`."""
    _, remainder = divmod(crop_size, upscale_factor)
    return crop_size - remainder

def input_transform(crop_size, upscale_factor):
    """Build the low-resolution input pipeline.

    The returned transform centre-crops to `crop_size`, scales the crop down
    to `crop_size // upscale_factor`, and converts the result to a tensor.
    """
    low_res_size = crop_size // upscale_factor
    steps = [
        transforms.CenterCrop(crop_size),
        transforms.Scale(low_res_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def target_transform(crop_size):
    """Build the high-resolution target pipeline.

    The returned transform centre-crops to `crop_size` and converts the crop
    to a tensor without any rescaling.
    """
    steps = [
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

In [None]:
# load each image as a duplicate pair
class DoubleImageDataset(Dataset):
    """Dataset yielding low-resolution inputs and high-resolution targets per image.

    Every regular file directly inside `img_dir` is treated as one image.

    Args:
        img_dir: directory containing the image files.
        input_transform: callable applied to each of the Y, Cb and Cr planes to
            produce the low-resolution inputs. Defaults to `ToTensor()`.
        target_transform: callable applied to the RGB image and to the Y plane
            to produce the high-resolution targets. Defaults to `ToTensor()`.

    Each item is the tuple
    `(y_low_res, cb_low_res, cr_low_res, rgb_high_res, y_high_res)`, where
    `rgb_high_res` has additionally been passed through `normalise`.
    """

    def __init__(self, img_dir, input_transform=None, target_transform=None):
        super(DoubleImageDataset, self).__init__()

        # Sorted so that index -> file is reproducible between runs. Anything
        # that is not a regular file (e.g. a sub-directory) is skipped because
        # it cannot be opened as an image.
        candidates = (os.path.join(img_dir, f) for f in sorted(os.listdir(img_dir)))
        self.image_fnames = [path for path in candidates if os.path.isfile(path)]

        # __getitem__ calls both transforms unconditionally and the results are
        # batched / normalised as tensors, so a missing transform falls back to
        # a plain tensor conversion instead of failing on `None(...)`.
        self.input_transform = (
            input_transform if input_transform is not None else transforms.ToTensor())
        self.target_transform = (
            target_transform if target_transform is not None else transforms.ToTensor())

    def __getitem__(self, index):
        # TODO: could this be moved into the transform composition as a lambda?
        # The context manager closes the file handle as soon as the pixel data
        # has been copied by convert().
        with Image.open(self.image_fnames[index]) as img:
            rgb_img = img.convert('RGB')

        # Go through RGB first so that source modes without a direct YCbCr
        # conversion still work; for RGB sources the result is unchanged.
        y_img, cb_img, cr_img = rgb_img.convert('YCbCr').split()

        y_low_res = self.input_transform(y_img)
        cb_low_res = self.input_transform(cb_img)
        cr_low_res = self.input_transform(cr_img)

        # The RGB target is normalised; the Y target is left as produced by
        # target_transform.
        rgb_high_res = normalise(self.target_transform(rgb_img))
        y_high_res = self.target_transform(y_img)

        return y_low_res, cb_low_res, cr_low_res, rgb_high_res, y_high_res

    def __len__(self):
        return len(self.image_fnames)

In [None]:
# Directory holding the training images. Set the IMAGENET_TRAIN_DIR environment
# variable to point somewhere else without editing the notebook; the original
# location is kept as the default.
dirname = os.environ.get(
    'IMAGENET_TRAIN_DIR',
    '/home/samir/Downloads/ILSVRC2012_img_train/train')
upscale_factor = 2
# 256 rounded down to a multiple of the upscale factor so the low-resolution
# size divides the crop exactly.
crop_size = calculate_valid_crop_size(256, upscale_factor)

train_dataset = DoubleImageDataset(
    dirname,
    input_transform=input_transform(crop_size, upscale_factor),
    target_transform=target_transform(crop_size))

data_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=7,
    pin_memory=True)

# preview some images: pull one batch for the display cells that follow
y_low_res, cb_low_res, cr_low_res, rgb_high_res, y_high_res = next(iter(data_loader))

In [None]:
# Index of the batch element shown by the preview cells.
idx = 0
# Undo the normalisation before viewing the high-resolution RGB target.
tensor_to_image(denormalise(rgb_high_res[idx]))

In [None]:
# High-resolution Y (luma) target for the same sample.
tensor_to_image(y_high_res[idx])

In [None]:
# Low-resolution Y plane, i.e. the network input.
tensor_to_image(y_low_res[idx])

In [None]:
# Low-resolution Cb plane.
tensor_to_image(cb_low_res[idx])

In [None]:
# Low-resolution Cr plane.
tensor_to_image(cr_low_res[idx])

In [None]:
# Recombine the three low-resolution planes into RGB and display the result.
img = ycbcr_to_rgb(y_low_res[idx], cb_low_res[idx], cr_low_res[idx])
transforms.ToPILImage()(img)

In [None]:
class UpsamplingNetwork(nn.Module):
    """Three-layer convolutional upsampler for a single-channel image.

    Two ReLU-activated convolutions (1 -> 64 -> 32 channels) are followed by a
    convolution producing `upscale_factor ** 2` channels, which `PixelShuffle`
    rearranges into one channel whose height and width are `upscale_factor`
    times larger than the input's.
    """

    def __init__(self, upscale_factor):
        super(UpsamplingNetwork, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(
            in_channels=32, out_channels=upscale_factor ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialise_weights()

    def forward(self, x):
        """Return the upscaled single-channel prediction for batch `x`."""
        features = x
        for hidden_conv in (self.conv1, self.conv2):
            features = self.relu(hidden_conv(features))
        return self.pixel_shuffle(self.conv3(features))

    def _initialise_weights(self):
        """Orthogonally initialise the conv weights.

        The two ReLU-activated layers use the ReLU gain; the final layer uses
        the default gain.
        """
        relu_gain = init.calculate_gain('relu')
        for hidden_conv in (self.conv1, self.conv2):
            init.orthogonal(hidden_conv.weight, relu_gain)
        init.orthogonal(self.conv3.weight)

In [None]:
# Instantiate the upsampler for the configured scale factor and move it to the GPU.
upsampling_net = UpsamplingNetwork(upscale_factor).cuda()

In [None]:
# Keep only the `features` part of the pretrained VGG16 and move it to the GPU.
# The bare name prints the layer list, which is useful for choosing extraction indices.
pretrained_model = models.vgg16(pretrained=True).features.cuda()
pretrained_model

In [None]:
class PerceptualLossNetwork(nn.Module):
    """Weighted feature-space MSE between two inputs using a frozen network.

    The children of `pretrained_model` are split into consecutive stages, each
    ending at one of the indices in `extraction_layers`. Both inputs are passed
    through the stages in turn, and after each stage the MSE between the two
    feature maps is added to the total, scaled by the matching weight.

    Args:
        pretrained_model: module whose direct children form the feature
            extractor. All of its parameters are frozen in place
            (`requires_grad = False`).
        extraction_layers: ascending child indices (inclusive) after which
            features are compared.
        weights: one loss weight per extraction layer. As with `zip`, surplus
            entries in the longer sequence are ignored.
    """

    def __init__(self, pretrained_model, extraction_layers, weights):
        super(PerceptualLossNetwork, self).__init__()

        for param in pretrained_model.parameters():
            param.requires_grad = False

        children = list(pretrained_model.children())

        stages = []
        stage_weights = []
        start = 0
        for end, weight in zip(extraction_layers, weights):
            stages.append(nn.Sequential(*children[start:end + 1]))
            stage_weights.append(weight)
            start = end + 1

        # Stored in a ModuleList, and NOT under the name `modules`, so that:
        #   * nn.Module.modules() remains a callable method, and
        #   * the stages are registered sub-modules that follow
        #     .cuda() / .to() / state_dict() / parameters().
        self.stages = nn.ModuleList(stages)
        self.stage_weights = stage_weights

        self.loss_fn = nn.MSELoss()

    def forward(self, x, target):
        """Return the weighted sum of per-stage feature MSEs between `x` and `target`."""
        total_loss = 0.0
        for stage, weight in zip(self.stages, self.stage_weights):
            # Each stage continues from the previous stage's output.
            x = stage(x)
            target = stage(target)
            total_loss = total_loss + self.loss_fn(x, target) * weight
        return total_loss

In [None]:
# Perceptual loss with a single comparison point: features are taken after
# child index 29 of `pretrained_model` and weighted by 1.
percept_net = PerceptualLossNetwork(
    pretrained_model,
    [29],
    [1]).cuda()

In [None]:
# Only the upsampling network's parameters are optimised.
optimiser = optim.Adam(upsampling_net.parameters(), lr=1e-3, weight_decay=1e-3)
# Pixel-space loss between the predicted and target Y planes.
pixel_loss_fn = nn.MSELoss().cuda()

In [None]:
num_epochs = 1

# Normalisation statistics as (1, 3, 1, 1) GPU Variables so they broadcast
# against (B, 3, H, W) batches inside the autograd graph.
mean_gpu = Variable(mean_vec.view(1, 3, 1, 1)).cuda()
std_gpu = Variable(std_vec.view(1, 3, 1, 1)).cuda()

print('Training ...')
start_time = time.time()

for epoch in range(num_epochs):
    # Display epochs 1-based so the final epoch reads "N/N".
    print('Epoch {:3d}/{:3d}'.format(epoch + 1, num_epochs))

    for i, batch in enumerate(data_loader, 1):

        y_low_res, cb_low_res, cr_low_res, rgb_high_res, y_high_res = batch

        y_low_res = Variable(y_low_res).cuda()
        y_high_res = Variable(y_high_res).cuda()
        rgb_high_res = Variable(rgb_high_res).cuda()
        cb_gpu = Variable(cb_low_res).cuda()
        cr_gpu = Variable(cr_low_res).cuda()

        optimiser.zero_grad()

        # upsample
        y_high_res_pred = upsampling_net(y_low_res)
        pixel_loss = pixel_loss_fn(y_high_res_pred, y_high_res)

        # Convert the predicted YCbCr image to RGB *inside* the autograd graph.
        # Going through `.data` / PIL here would detach the prediction, and the
        # perceptual loss would then contribute no gradient to upsampling_net.
        # Chroma is upsampled bilinearly to the prediction's size, and the
        # JFIF YCbCr -> RGB equations are applied with all planes in [0, 1].
        out_size = tuple(y_high_res_pred.size()[2:])
        cb_centred = F.upsample(cb_gpu, size=out_size, mode='bilinear') - 0.5
        cr_centred = F.upsample(cr_gpu, size=out_size, mode='bilinear') - 0.5

        red = y_high_res_pred + cr_centred * 1.402
        green = y_high_res_pred - cb_centred * 0.344136 - cr_centred * 0.714136
        blue = y_high_res_pred + cb_centred * 1.772

        rgb_high_res_pred = torch.cat([red, green, blue], 1).clamp(0, 1)
        # Same normalisation the dataset applies to the RGB target.
        rgb_high_res_pred = (rgb_high_res_pred - mean_gpu) / std_gpu

        # compute perceptual loss
        percept_loss = percept_net(rgb_high_res_pred, rgb_high_res)

        # calculate loss
        loss = pixel_loss + percept_loss
        loss.backward()

        optimiser.step()

        # if (i % 100 == 0):
        print('#: {:3d} Losses: pixel: {:3f}, percept: {:3f}'.format(
            i, pixel_loss.data[0], percept_loss.data[0]))

print('Training finished in {:.1f}s'.format(time.time() - start_time))

## Test model

In [None]:
# Draw a fresh batch for qualitative inspection; idx selects the sample shown below.
# NOTE(review): this batch comes from the training loader, not a held-out set.
idx = 0
y_low_res, cb_low_res, cr_low_res, rgb_high_res, y_high_res = next(iter(data_loader))

In [None]:
# Low-resolution Y plane that will be fed to the network.
tensor_to_image(y_low_res[idx])

In [None]:
# High-resolution RGB ground truth, with the normalisation undone for viewing.
tensor_to_image(denormalise(rgb_high_res[idx]))

In [None]:
# Add a batch dimension, run the sample through the network on the GPU, ...
prediction = upsampling_net(Variable(y_low_res[idx].unsqueeze(0)).cuda())
# ... then bring the single result back to the CPU as a plain tensor.
prediction = prediction.cpu().data[0]

In [None]:
# The network output is unbounded. Clamp it to [0, 1] before the 8-bit
# conversion, otherwise out-of-range values wrap around and show up as
# speckle artefacts in the displayed image.
transforms.ToPILImage()(prediction.clamp(0, 1))

In [None]:
# Combine the predicted Y plane with the (resized) low-resolution chroma planes
# and display the resulting RGB image.
img = ycbcr_to_rgb(prediction, cb_low_res[idx], cr_low_res[idx])
transforms.ToPILImage()(img)