# Reproduction of 'Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network'
> This blog post describes the reproduction of the paper: 'Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network'. It explains the main point of the paper, tries to reproduce the results of Table 1 of the paper and discusses its reproducibility.

- toc: true
- branch: master
- badges: true
- comments: true
- author: Luuk Balkenende & Sieger Falkena & Luc Kloosterlaan
- categories: [fastpages, jupyter]

---

### Running the experiment on Google Colab
This notebook runs remotely on the Google Colab platform. Therefore, to save and access the trained model, we needed to mount Google Drive. We used the following code snippet to mount our Drive inside the Colab runtime.

In [0]:
from torchvision import transforms
from google.colab import drive
from torch.utils import data
from PIL import Image

import matplotlib.pyplot as plt
import PIL.Image as pil_image
import torch.nn as nn
import random as rnd
import torchvision
import numpy as np
import torch
import time
import cv2
import os

In [3]:
# Mount Google Drive inside the Colab runtime and work from the project folder.
path = '/content/gdrive/My Drive/deep_learning_group_7'
drive.mount('/content/gdrive')
os.chdir(path)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [4]:
from IPython.display import HTML
# Render the HTML file interactive_image.html, resolved relative to the current
# working directory (the project folder after os.chdir(path) in the cell above).
# This must remain the last expression of the cell so that Jupyter displays it.
HTML(filename="interactive_image.html")

# ESPCN Architecture

Below, you can find the ESPCN architecture as defined in the paper. Note that we use `tanh` as the activation function, since the authors indicate that this leads to better results.

Most important to notice here is that the sub-pixel convolution layer is split into two layers: a normal convolutional layer (`conv3`) and a layer that performs the periodic-shuffling (sub-pixel) operation (`upsample`).

In [0]:
r = 3  # upscaling factor used throughout the notebook


class SuperResConvNet(nn.Module):
    """ESPCN: two tanh feature-extraction convolutions followed by a sub-pixel convolution.

    The sub-pixel convolution layer is implemented as a regular convolution
    (``conv3``) that produces ``num_channels * upscale_factor**2`` feature maps,
    followed by ``nn.PixelShuffle`` (``upsample``), which rearranges them into
    an image that is ``upscale_factor`` times larger in height and width.

    Args:
        upscale_factor: spatial upscaling factor. ``None`` (the default) uses
            the module-level ``r`` at construction time.
        num_channels: number of image channels in and out (default 1: only the
            Y / luminance channel is super-resolved).
    """

    def __init__(self, upscale_factor=None, num_channels=1):
        super(SuperResConvNet, self).__init__()
        if upscale_factor is None:
            upscale_factor = r
        # Layer names are kept (conv1/conv2/conv3/upsample) so saved state_dicts still load.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, num_channels * upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.PixelShuffle(upscale_factor)

    def forward(self, y):
        """Map (N, num_channels, H, W) to (N, num_channels, H*s, W*s), s = upscale_factor.

        The paddings keep H and W unchanged through the three convolutions;
        only the pixel shuffle enlarges the image.
        """
        y = torch.tanh(self.conv1(y))
        y = torch.tanh(self.conv2(y))
        return self.upsample(self.conv3(y))


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU")

ConvNet = SuperResConvNet()
ConvNet.to(device)

#Functions

In [0]:
# Benchmark folders relative to `path`. Entry 0 (T91) is the training set;
# the remaining entries are the x3 test sets.
_benchmark_root = '/CVPR2016_ESPCN_OurBenchMarkResult/Ours/'
slide_subfolders = [_benchmark_root + 'yang91/T91/'] + [
    _benchmark_root + 'x3.0/' + testset + '/'
    for testset in ('Set5', 'Set14', 'BSD300', 'BSD500', 'SuperText136')
]

# Width and height (in low-resolution pixels) of the square training patches.
x = 17

class Slide:
    """One benchmark image folder and, for the training folder, its patch index.

    Args:
        path: root directory the sub-folder is relative to.
        slide_subfolders: a single sub-folder string (one entry of the
            module-level list) appended to ``path`` to form ``self.dir``.
        count: index of the folder; only ``count == 0`` builds ``self.patchlist``.

    Attributes:
        dir: ``path + slide_subfolders``.
        namelist: PNG base names in ``dir`` with low-resolution files filtered out.
        patchlist: ``[img_name, row, col]`` entries, one per training patch
            (only present when ``count == 0``).
    """

    def __init__(self, path, slide_subfolders, count):
        self.count = count
        self.dir = path + slide_subfolders

        self.namelist = self.deleteLRImage()
        if count == 0:
            self.patchlist = self.patchlist_get()

    @staticmethod
    def removePng(f):
        """Return ``f`` without its last four characters (the ``.png`` extension)."""
        return f[:-4]

    @staticmethod
    def _read_image(filename, code=None):
        """Read ``filename`` with OpenCV (BGR order) and optionally convert it with ``code``.

        ``cv2.imread`` returns ``None`` instead of raising for a missing or
        unreadable file; turn that into an explicit error that names the file.
        """
        img = cv2.imread(filename)
        if img is None:
            raise FileNotFoundError('Could not read image: ' + filename)
        return img if code is None else cv2.cvtColor(img, code)

    @staticmethod
    def _stride_lr():
        """Patch stride in LR pixels: patch size ``x`` minus sum(kernel_size % 2) of the conv layers (5, 3, 3)."""
        return x - np.sum((5 % 2, 3 % 2, 3 % 2))

    def getList(self):
        """Return the base names (without ``.png``) of all PNG files in ``self.dir``."""
        return [Slide.removePng(f) for f in os.listdir(self.dir) if f.endswith(".png")]

    def deleteLRImage(self):
        """Return ``getList()`` without the low-resolution images.

        The LR images were generated beforehand and are stored in the same
        folder. Any name containing ``"lr"`` or ``"lowRes"`` is dropped.
        """
        # NOTE(review): "lr" is matched anywhere in the name, so an HR image whose
        # name happens to contain "lr" would be dropped as well.
        return [name for name in self.getList()
                if "lr" not in name and "lowRes" not in name]

    def getPatchList(self, img_name):
        """Return ``[img_name, row, col]`` for every training patch of one image.

        Patch positions are counted on the LR image ``<img_name>_lr.png``.
        """
        lr_img = self._read_image(self.dir + img_name + '_lr.png')  # only the shape is needed
        stride_lr = self._stride_lr()
        n_rows = lr_img.shape[0] // stride_lr
        n_cols = lr_img.shape[1] // stride_lr
        # One row/column fewer than size // stride: conservative, and keeps every
        # x-by-x LR patch fully inside the image.
        return [[img_name, i, j] for i in range(n_rows - 1) for j in range(n_cols - 1)]

    def patchlist_get(self):
        """Collect the patches of every image in ``namelist`` and report how many were found."""
        patch_list = []
        for img_name in self.namelist:
            patch_list.extend(self.getPatchList(img_name))
        print('Found', len(patch_list), 'trainable patches out of', len(self.namelist), 'images.')
        return patch_list

    def createPatch(self, name):
        """Return the Y-channel ``(lr_patch, hr_patch)`` pair for one ``[img_name, row, col]`` entry.

        The LR patch is ``x`` by ``x`` pixels; the HR patch is ``x*r`` by ``x*r``
        pixels taken at ``r`` times the LR offset.
        """
        img_name, row, col = name[0], name[1], name[2]

        # Load both images and keep only the Y channel of YCrCb.
        filename = self.dir + img_name
        hr_img = self._read_image(filename + '.png', cv2.COLOR_BGR2YCrCb)[:, :, 0]
        lr_img = self._read_image(filename + '_lr.png', cv2.COLOR_BGR2YCrCb)[:, :, 0]

        stride_lr = self._stride_lr()
        stride_hr = stride_lr * r

        hr_top, hr_left = stride_hr * row, stride_hr * col
        hr_patch = hr_img[hr_top:hr_top + x * r, hr_left:hr_left + x * r]

        lr_top, lr_left = stride_lr * row, stride_lr * col
        lr_patch = lr_img[lr_top:lr_top + x, lr_left:lr_left + x]

        return lr_patch, hr_patch

    def image_operations_testing(self, img_n):
        """Return the Y channels ``(lr_img, hr_img)`` of a test image, cropped to match.

        The LR file is ``img_n`` with its last 9 characters replaced by
        ``-lowRes.png``. LR*r and HR differ by a few pixels because of the
        sub-sampling, so one row and one column are cut from the LR image and
        the HR image is cropped to exactly ``r`` times that size.
        """
        hr_img = self._read_image(self.dir + img_n + '.png', cv2.COLOR_BGR2YCrCb)[:, :, 0]
        lr_img = self._read_image(self.dir + img_n[:-9] + '-lowRes.png', cv2.COLOR_BGR2YCrCb)[:, :, 0]

        lr_img = lr_img[:-1, :-1]
        hr_img = hr_img[:lr_img.shape[0] * r, :lr_img.shape[1] * r]

        return lr_img, hr_img

    def colored_prediction(self, img_n):
        """Return ``(lr_img, bicubic, hr_img)`` as PIL images in YCbCr mode.

        ``bicubic`` is the LR image resized to the HR size with bicubic
        interpolation; the LR file name is ``img_n`` with its last 9
        characters replaced by ``-lowRes.png``.
        """
        filename = self.dir + img_n
        hr_img = Image.open(filename + '.png').convert('YCbCr')
        lr_img = Image.open(filename[:-9] + '-lowRes.png').convert('YCbCr')

        bicubic = lr_img.resize(hr_img.size, Image.BICUBIC)

        return lr_img, bicubic, hr_img

def getSlideList(slide_subfolders, path):
    """Build one ``Slide`` per sub-folder, passing the folder's index as ``count``.

    Only the first folder (index 0) gets a training patch list.
    """
    return [Slide(path, subfolder, index) for index, subfolder in enumerate(slide_subfolders)]

slides = getSlideList(slide_subfolders, path)

### PSNR calculation

In [0]:
def psnr_from_mselist(mse_list, max_val=255.0):
  """Return the PSNR in dB computed from the mean of a list of MSE values.

  The MSEs are averaged first and a single PSNR is computed from that mean:
  20 * log10(max_val / sqrt(mean(mse))).

  Args:
    mse_list: iterable of MSE values on the scale implied by ``max_val``
      (``calc_mse`` in this notebook produces MSEs on the 0-255 scale).
    max_val: peak signal value; 255 for 8-bit images.

  Returns:
    PSNR in dB; ``inf`` when the mean MSE is 0.
  """
  # float64 avoids precision loss when the MSEs come in as float32.
  mse = np.mean(np.asarray(mse_list, dtype=np.float64))
  if mse == 0:
    return float('inf')
  return 20 * np.log10(max_val / np.sqrt(mse))

def calc_mse(img_pred, img_hr):
    """Return the mean squared error between two images on the 0-255 scale.

    Both inputs are expected to hold intensities scaled to [0, 1] (as produced
    by ``ToTensor``); they are multiplied by 255 before the difference is taken.
    The inputs are cast to float64 first, so integer (e.g. uint8) arrays cannot
    overflow in ``* 255`` or in the squaring.

    Returns:
        The MSE as a NumPy float64 scalar (supports ``.item()``).
    """
    pred = np.asarray(img_pred, dtype=np.float64)
    target = np.asarray(img_hr, dtype=np.float64)
    return np.mean((pred * 255 - target * 255) ** 2)

# Obtaining low resolution images

In [0]:
# ## DONT RUN THIS CELL, LR IMAGES ARE ALREADY MADE AND PRESENT IN DIRECTORY
# def createLowRes(img_name, dir_91):
#   'saves low resolution image'
#   # Call HR image
#   filename = dir_91 + img_name + '.png'
#   hr_img = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)

#   # Blur HR image
#   blur_img = cv2.GaussianBlur(hr_img,(5,5),0)

#   # Apply subsampling
#   lr_img = blur_img[::r,::r]
#   #print('Shape HR image: ', hr_img.shape)
#   #print('Shape LR image: ', lr_img.shape)

#   # Save lr_img
#   im = Image.fromarray(lr_img)
#   im.save(dir_91 + img_name + '_lr.png')

# for i in range(len(slides[0].namelist)):
#   dir_91 = path + slide_subfolders[0]
#   createLowRes(slides[0].namelist[i], dir_91)

---
# Training
-loss \
-optimizer

-stopping criterion (done differently than in the paper) \
-learning rate discussion \

Actually, we should also add validation here: use roughly 80% of the patches for training and 20% for validation. I have not added that yet; we could still do it after tomorrow morning. (Update: a training/validation split is now made in the data loader cell below, controlled by `train_val_ratio`.)

## Hyperparameter settings

In [0]:
## Loss and optimizer: pixel-wise MSE minimised with Adam, starting at lr = 0.01.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(ConvNet.parameters(), lr=0.01)

## Stopping criterion: train for a fixed number of epochs.
num_epoch = 5000

## Learning-rate schedule: multiply the lr by `factor_value` once the monitored
## loss has not improved (relative threshold `threshold_mu`) for more than
## `patience` epochs, never going below min_lr.
# The paper specifies neither value; PyTorch's defaults are threshold=1e-4 and factor=0.1.
threshold_mu = 1e-6
factor_value = 0.8
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=factor_value,
    patience=2,
    threshold=threshold_mu,
    min_lr=0.0001,
    eps=1e-08,
    verbose=True,
)

## Mini-batch size and fraction of the patches used for training (the rest is validation).
batch_size = 16
train_val_ratio = 0.95

--- 
### Data loader
-why and how we used a data loader/generator

In [0]:
class DataGenerator(data.Dataset):
    """Map-style Dataset of ``(lr_patch, hr_patch)`` tensor pairs for training the ESPCN.

    Args:
        slides: a ``Slide`` that provides ``patchlist`` (built only for the
            folder with ``count == 0``) and ``createPatch``.
    """

    def __init__(self, slides):
        self.slides = slides
        # ToTensor turns each H x W uint8 Y-channel patch into a (1, H, W) float tensor in [0, 1].
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    def __len__(self):
        return len(self.slides.patchlist)

    def __getitem__(self, idx):
        patch_id = self.slides.patchlist[idx]
        # createPatch yields (lr, hr); convert both, keeping that order.
        return tuple(self.transform(patch) for patch in self.slides.createPatch(patch_id))

# Wrap the training folder (slides[0], T91) in the Dataset defined above.
full_dataset = DataGenerator(slides[0])

# Split the patches into a training and a validation subset.
train_size = int(train_val_ratio * len(full_dataset))
validation_size = len(full_dataset) - train_size
training_set, validation_set = torch.utils.data.random_split(full_dataset, [train_size, validation_size])

# `shuffle` must be a real bool: any non-empty string (including 'False') is
# truthy and would enable shuffling.
# num_workers parallelizes the loading of batches on the CPU (tied to batch_size here).
training_generator      = data.DataLoader(training_set, batch_size=batch_size, num_workers=batch_size, shuffle=True)
validation_generator    = data.DataLoader(validation_set, batch_size=1, num_workers=1, shuffle=False)

### Training the model

In [0]:
# Histories used by the plotting and saving cells below.
loss_list = []          # per-batch losses of the current phase; reset after each phase
epoch_loss_list = []    # mean training loss per epoch
val_loss_list = []      # mean validation loss per epoch
validation_psnr = []    # validation PSNR (dB) per epoch

# Start stopwatch
t0 = time.time()

for epoch in range(num_epoch):
    # ---------------- training phase ----------------
    ConvNet.train()
    for lr_patch, hr_patch in training_generator:
        # Transfer training data to the active device
        lr_patch, hr_patch = lr_patch.to(device), hr_patch.to(device)

        # Forward pass and MSE loss against the HR patch
        outputs = ConvNet(lr_patch)
        loss = criterion(outputs, hr_patch)
        loss_list.append(loss.item())

        # Backprop and Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean of the per-batch training losses of this epoch
    epoch_loss = np.sum(loss_list) / len(training_generator)
    epoch_loss_list.append(epoch_loss)
    if epoch % 50 == 0:
        print("Epoch", epoch, "loss: {}".format(epoch_loss))

    # Checkpoint every 1000 epochs in case Google Colab stops the runtime
    if epoch % 1000 == 999:
        model_name = 'test3_mu_' + str(threshold_mu) + '_' + str(epoch+1) + '_epochs'
        path_model = path+'/introductory_notebooks/saved_models/' + model_name
        torch.save(ConvNet.state_dict(), path_model)
        print('Model saved as: ', model_name)

    # The learning-rate scheduler monitors the training loss
    scheduler.step(epoch_loss)
    loss_list = []

    # ---------------- validation phase ----------------
    ConvNet.eval()
    # MSE of every validation patch; the epoch PSNR is computed from their mean
    epoch_mse = []
    with torch.no_grad():
        for lr_patch, hr_patch in validation_generator:
            lr_patch, hr_patch = lr_patch.to(device), hr_patch.to(device)

            img_pred = ConvNet(lr_patch)
            loss = criterion(img_pred, hr_patch)
            loss_list.append(loss.item())

            # The validation loader uses batch_size=1, so [0] is the only sample
            epoch_mse.append(calc_mse(img_pred[0].cpu().numpy(), hr_patch[0].cpu().numpy()))

        validation_psnr.append(psnr_from_mselist(np.array(epoch_mse)))

    # Mean validation loss of this epoch
    val_loss = np.sum(loss_list) / len(validation_generator)
    val_loss_list.append(val_loss)
    loss_list = []

# Read the clock once so that the total and per-epoch timings agree
training_time = time.time() - t0
print('Training took {} seconds'.format(training_time))
print('Seconds per epoch:', training_time / num_epoch)

### Plot training results

In [0]:
# Show training and validation loss of the current model in memory.
# The x-axes are derived from the recorded histories instead of `num_epoch`, so
# the plot still works if training was interrupted early or `num_epoch` was
# changed afterwards (np.arange(num_epoch) would then have the wrong length).
plt.figure(figsize=(8, 8))

plt.subplot(2, 1, 2)
plt.title('PSNR for validation data')
plt.plot(np.arange(len(validation_psnr)), validation_psnr)

plt.subplot(2, 1, 1)
plt.title('Loss per epoch')
plt.plot(np.arange(len(epoch_loss_list)), epoch_loss_list, label='Training loss')
plt.plot(np.arange(len(val_loss_list)), val_loss_list, label='Validation loss')
plt.yscale("log")
plt.legend()

plt.show()

### Saving the model

In [0]:
# Save the model currently in memory together with its training histories.
# NOTE(review): the 'mu_1e-2' tag is hard-coded and does not follow
# `threshold_mu` (1e-6 in the hyperparameter cell); check it before saving.
model_name = path + '/introductory_notebooks/mu/' + 'mu_1e-2'
histories = (
    ('_validation_psnr', validation_psnr),
    ('_epoch_loss_list', epoch_loss_list),
    ('_val_loss_list', val_loss_list),
)
for suffix, values in histories:
    path_name = model_name + suffix
    np.save(path_name, values)  # np.save appends '.npy' to the file name
path_model = model_name + '_model'
torch.save(ConvNet.state_dict(), path_model)
print('Model saved as: ', model_name)

# Testing


#### Loading the model

In [0]:
# Load previously trained weights into ConvNet.
# map_location lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only runtime (and vice versa).
model_name = 'final_test_5000_epochsweights'
path_model = path + '/introductory_notebooks/combined/' + model_name
ConvNet.load_state_dict(torch.load(path_model, map_location=device))

### Test image functions

In [0]:
def generate_figure(lr_img, sr_img, hr_img):
  """Plot the LR input, the super-resolved output and the HR original side by side.

  Each argument is a batched YCbCr tensor; only its first item (index 0) is
  shown, converted to RGB. Used as a visual check of how well the model behaves.
  """
  to_ycbcr_pil = transforms.ToPILImage('YCbCr')
  f = plt.figure(figsize=(8*3, 8))
  for position, img in enumerate((lr_img, sr_img, hr_img), start=1):
    f.add_subplot(1, 3, position)
    plt.imshow(to_ycbcr_pil(img[0]).convert('RGB'))
  

def save_tensor_as_image(lr_img, sr_img, hr_img, i):
  """Save the first item of each batched YCbCr tensor as an RGB PNG.

  Files are written to ``<path>/introductory_notebooks/saved_images/`` as
  ``lr_img_<i>.png``, ``sr_img_<i>.png`` and ``hr_img_<i>.png``.
  """
  out_dir = path + '/introductory_notebooks/saved_images/'
  for prefix, img in (('lr_img_', lr_img), ('sr_img_', sr_img), ('hr_img_', hr_img)):
    transforms.ToPILImage('YCbCr')(img[0]).convert('RGB').save(out_dir + prefix + str(i) + '.png')

#### Data loader

In [0]:
class DataGenerator_test(data.Dataset):
  """Map-style Dataset over the whole test images of one benchmark folder.

  Each item is ``(lr_img, bicubic, hr_img)``: the three YCbCr PIL images from
  ``Slide.colored_prediction`` converted to float tensors of shape (3, H, W).
  """

  def __init__(self, slides):
    self.slides = slides
    self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

  def __len__(self):
    return len(self.slides.namelist)

  def __getitem__(self, idx):
    images = self.slides.colored_prediction(self.slides.namelist[idx])
    lr_img, bicubic, hr_img = (self.transform(img) for img in images)
    return lr_img, bicubic, hr_img

## Put DataGenerator in DataLoader
def DataGenerator(slides):
  """Return a DataLoader over the test images of ``slides`` (batch size 1, fixed order).

  NOTE: this function reuses the name of the training ``DataGenerator``
  Dataset class and shadows it once this cell has been run.
  """
  return data.DataLoader(DataGenerator_test(slides), batch_size=1, shuffle=False)

### Testing the model

In [0]:
# Create predictions on the test sets.
# Index into `slides` (same order as slide_subfolders):
#0: Yang91 (training set, skipped here)
#1: Set5
#2: Set14
#3: BSD300
#4: BSD500
#5: SuperText136

rnd.seed(4)
testsets = ['Set5','Set14','BSD300','BSD500','SuperText136']

# Testing is done on the CPU: move the model once (not on every image) and
# switch it to evaluation mode.
ConvNet.cpu()
ConvNet.eval()

for testset_name, test_slide in zip(testsets, slides[1:]):
  test_generator = DataGenerator(test_slide)
  # Index of the image to visualise. randrange excludes its upper bound, so
  # `show` is always a valid index (randint(0, n) is inclusive and could pick
  # n, in which case no figure would be shown). -1 never matches an index.
  show = rnd.randrange(len(test_generator)) if len(test_generator) > 0 else -1

  start_time = time.time()
  mse_list = []
  with torch.no_grad():
    for i, (lr_img, bicubic, hr_img) in enumerate(test_generator):  # tensors of shape [1, 3, H, W] (YCbCr)
      # Only the Y (intensity) channel goes through the network: [1, 1, h, w]
      to_network = lr_img[:, 0:1, :, :]

      # Clamp to the valid intensity range [0, 1]: ToPILImage multiplies floats
      # by 255 and casts to uint8, so out-of-range values would not saturate.
      img_pred = ConvNet(to_network)[0].clamp(0, 1)

      # Substitute the bicubic Y channel with the network output (Cb/Cr stay bicubic).
      # NOTE(review): requires the HR image to be exactly r times the LR size,
      # since `bicubic` has the HR size and the prediction has r times the LR size.
      bicubic[:, 0, :, :] = img_pred

      mse = calc_mse(img_pred.numpy(), hr_img.numpy()[:, 0, :, :])
      mse_list.append(float(mse))

      if i == show:
        generate_figure(lr_img, bicubic, hr_img)
        save_tensor_as_image(lr_img, bicubic, hr_img, i)

  elapsed_time = time.time() - start_time
  print(testset_name, ': ', psnr_from_mselist(mse_list))
  print('  evaluated in {:.1f} seconds'.format(elapsed_time))