In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable
from torch.utils.serialization import load_lua
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models, datasets

import numpy as np

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

import time
import copy
import os

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
matplotlib.rc('figure', figsize=(12, 5))

use_cuda = torch.cuda.is_available()
use_cuda

In [None]:
# Normalisation statistics, currently disabled. They are 3-channel values,
# whereas the dataset yields single-channel Y (luma) images, so they would need
# adapting before being re-enabled.
# mean_vec = torch.FloatTensor([ 0.485, 0.456, 0.406 ]).view(3,1,1)
# std_vec = torch.FloatTensor([ 0.229, 0.224, 0.225 ]).view(3,1,1)

# Preview helper: turns an image tensor (as produced by ToTensor) back into a
# PIL image so it renders inline. No un-normalisation step is applied, because
# the data transforms do not normalise.
_preview_steps = [
    # transforms.Lambda(lambda x: x * std_vec + mean_vec),  # undo Normalize
    transforms.ToPILImage(),
]
tensor_to_image = transforms.Compose(_preview_steps)

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``.

    This guarantees that a crop of this size can be downsampled by an integer
    factor and then upsampled back to exactly the same size.
    """
    _, remainder = divmod(crop_size, upscale_factor)
    return crop_size - remainder

def input_transform(crop_size, upscale_factor):
    """Build the transform that produces the low-resolution network input.

    The image is centre-cropped to a ``crop_size`` x ``crop_size`` square and
    then downsampled to ``crop_size // upscale_factor`` per side. No
    normalisation is applied.

    Args:
        crop_size: side length of the square crop; should be divisible by
            ``upscale_factor`` (see ``calculate_valid_crop_size``).
        upscale_factor: integer downsampling factor.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a tensor.
    """
    # ``transforms.Scale`` was deprecated in favour of ``transforms.Resize`` and
    # later removed from torchvision; only fall back to it on very old releases.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    return transforms.Compose([
        transforms.CenterCrop(crop_size),
        resize(crop_size // upscale_factor),  # downsample
        transforms.ToTensor(),
    ])

def target_transform(crop_size):
    """Build the transform that produces the high-resolution training target.

    Only a centre crop is applied, so the target keeps full resolution. Use the
    same ``crop_size`` as the matching input transform so that both views cover
    the same region of the image. No normalisation is applied.
    """
    steps = [
        transforms.CenterCrop(crop_size),  # same square region as the input
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

In [None]:
class DoubleImageDataset(Dataset):
    """Dataset that yields an (input, target) pair built from a single image.

    Each file is opened, converted to YCbCr and reduced to its Y (luma) channel.
    Two copies of that channel are then passed through ``input_transform`` and
    ``target_transform`` respectively.

    Args:
        img_dir: directory of images; every regular file in it is treated as an
            image (sub-directories are skipped).
        input_transform: optional callable applied to the input copy.
        target_transform: optional callable applied to the target copy.
    """

    def __init__(self, img_dir, input_transform=None, target_transform=None):
        super(DoubleImageDataset, self).__init__()

        # os.listdir returns entries in arbitrary order; sort them so that the
        # index -> file mapping is stable across runs and machines. Skip
        # sub-directories, which Image.open cannot read.
        candidate_paths = (os.path.join(img_dir, f) for f in os.listdir(img_dir))
        self.image_fnames = sorted(p for p in candidate_paths if os.path.isfile(p))
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # The context manager closes the underlying file handle promptly.
        # convert() loads the pixel data into a new image, so the result
        # remains valid after the file is closed.
        with Image.open(self.image_fnames[index]) as img:
            input_img, _, _ = img.convert('YCbCr').split()
        target_img = input_img.copy()

        if self.input_transform is not None:
            input_img = self.input_transform(input_img)
        if self.target_transform is not None:
            target_img = self.target_transform(target_img)

        return input_img, target_img

    def __len__(self):
        return len(self.image_fnames)

In [None]:
# Training images (BSDS300 train split). Set the BSDS300_TRAIN_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
dirname = os.environ.get(
    'BSDS300_TRAIN_DIR',
    '/home/samir/Documents/pytorch-examples/super_resolution/dataset/BSDS300/images/train/')
upscale_factor = 3
# Largest crop <= 256 that divides evenly by upscale_factor (255 when it is 3).
crop_size = calculate_valid_crop_size(256, upscale_factor)

train_dataset = DoubleImageDataset(
    dirname,
    input_transform=input_transform(crop_size, upscale_factor),
    target_transform=target_transform(crop_size))

data_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=7,
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory=use_cuda)

# preview some images
low_res, high_res = next(iter(data_loader))

In [None]:
# Low-resolution network input for the first sample of the preview batch.
low_res_sample = low_res[0]
tensor_to_image(low_res_sample)

In [None]:
# Matching high-resolution target for the same sample.
high_res_sample = high_res[0]
tensor_to_image(high_res_sample)

In [None]:
class UpsamplingNetwork(nn.Module):
    """Four convolutions followed by a sub-pixel (PixelShuffle) upsampling layer.

    Maps an (N, in_channels, H, W) batch to (N, in_channels, H * r, W * r),
    where r is ``upscale_factor``. Every convolution preserves spatial size
    (stride 1, padding = kernel // 2). The last convolution emits
    ``in_channels * r**2`` feature maps, which PixelShuffle rearranges into an
    r-times larger image.

    Args:
        upscale_factor: integer spatial upscaling factor r.
        in_channels: number of image channels. Defaults to 1 (the luma-only
            data used in this notebook).
    """

    def __init__(self, upscale_factor, in_channels=1):
        super(UpsamplingNetwork, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, in_channels * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialise_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # No activation on the final layer: its output is the predicted image.
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialise_weights(self):
        # Orthogonal initialisation. Layers followed by ReLU get the ReLU gain;
        # the final output layer uses the default gain of 1.
        # ``init.orthogonal`` is the deprecated pre-0.4 name of the in-place
        # ``init.orthogonal_``; prefer the new name whenever it exists.
        orthogonal_ = getattr(init, 'orthogonal_', None) or init.orthogonal
        gain = init.calculate_gain('relu')
        orthogonal_(self.conv1.weight, gain)
        orthogonal_(self.conv2.weight, gain)
        orthogonal_(self.conv3.weight, gain)
        orthogonal_(self.conv4.weight)

In [None]:
# Build the model and move it to the GPU only when one is available, so the
# notebook still runs (slowly) on CPU-only machines.
upsampling_net = UpsamplingNetwork(upscale_factor)
if use_cuda:
    upsampling_net = upsampling_net.cuda()

In [None]:
# Adam over all parameters of the super-resolution network.
learning_rate = 1e-3
optimiser = optim.Adam(params=upsampling_net.parameters(), lr=learning_rate)

In [None]:
# Pixel-wise mean squared error between the predicted and target images.
# MSELoss holds no parameters or buffers, so it needs no .cuda() call; it
# computes on whichever device its inputs live on.
loss_fn = nn.MSELoss()

In [None]:
num_epochs = 30

print('Training ...')
start_time = time.time()
upsampling_net.train()

for epoch in range(1, num_epochs + 1):
    running_loss = 0.0
    num_batches = 0

    for input_img, target_img in data_loader:
        # Load the batch. Variable wrappers have been unnecessary since PyTorch 0.4.
        if use_cuda:
            # ``non_blocking`` replaces ``async``, which is a reserved keyword
            # from Python 3.7 and a SyntaxError when used as an argument name.
            input_img = input_img.cuda(non_blocking=True)
            target_img = target_img.cuda(non_blocking=True)

        optimiser.zero_grad()
        predicted_img = upsampling_net(input_img)
        loss = loss_fn(predicted_img, target_img)
        loss.backward()
        optimiser.step()

        # ``loss.item()`` replaces ``loss.data[0]``. Indexing the 0-dim loss
        # tensor is deprecated in PyTorch 0.4 and an error in later releases.
        running_loss += loss.item()
        num_batches += 1

    # One summary line per epoch keeps the cell output readable.
    mean_loss = running_loss / max(num_batches, 1)
    print('Epoch {:3d}/{:3d}  mean batch loss: {:.4f}'.format(epoch, num_epochs, mean_loss))

print('Finished training in {:.1f}s'.format(time.time() - start_time))