<a href="https://colab.research.google.com/github/zachmurphy1/facemask-faster-rcnn/blob/main/Super_Resolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super resolution
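This notebook implements and trains a super-resolution (SR) network based on the SRResNet architecture for 4× upscaling. The training data are 5,000 128×128 px face images from the Flickr-Faces-HQ (FFHQ) dataset; the validation data are 1,000 FFHQ images that are not in the training set. Low-resolution network inputs (32×32 px) are created on the fly by bicubically downscaling each image by a factor of 4.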

## Input
Training and validation images (PNG files):
```
sr_training/128_train/
sr_training/128_val/
```

## Output
Trained SR network, pickled as a whole `SRNetwork` module:
```
sr_training/sr_model.pkl
```

In [None]:
# Imports
import numpy as np
import pickle
import sys, os
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from PIL import Image
from bs4 import BeautifulSoup
import torch, torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn

In [None]:
# Mount Google Drive and make the project folder the working directory, so the
# relative paths used throughout this notebook (sr_training/..., facemask_data/...)
# resolve against it.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
%cd /content/gdrive/My\ Drive/facemask-faster-rcnn/

# Directories of the 128x128 FFHQ face crops used for training and validation
# (relative to the project folder above; note the trailing slash).
SRDATADIR = 'sr_training/128_train/'
SRDATADIR_VAL = 'sr_training/128_val/'

Mounted at /content/gdrive


## Dataset class

In [None]:
class SRDataset(Dataset):
  """Pairs of (4x-downscaled input, full-resolution target) face images.

  Each item is built from one PNG in the selected split directory. The image
  is converted to an RGB float tensor in [0, 1] and augmented (training split
  only). It is then bicubically downscaled by 4x to form the network input.

  Args:
    mode: 'train' reads from SRDATADIR, 'val' reads from SRDATADIR_VAL.
    data_dir: optional directory that overrides the one selected by `mode`.

  Raises:
    ValueError: if `mode` is neither 'train' nor 'val'.
  """

  def __init__(self, mode='train', data_dir=None):
    if mode == 'train':
      default_dir = SRDATADIR
    elif mode == 'val':
      default_dir = SRDATADIR_VAL
    else:
      raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
    self.mode = mode
    self.data_dir = data_dir if data_dir is not None else default_dir

    # Index images by sorted file name rather than assuming the files are
    # numbered 00000..N-1 in every split. Non-PNG files (e.g. Drive metadata)
    # are skipped so that __len__ matches the number of loadable items.
    self.files = sorted(
        name for name in next(os.walk(self.data_dir))[2]
        if name.lower().endswith('.png'))

    self.to_tensor = transforms.ToTensor()
    # Augment only the training split so validation loss is deterministic and
    # comparable across epochs.
    # NOTE: ColorJitter() with its default arguments (all 0) leaves the image
    # unchanged; pass brightness/contrast/saturation/hue to enable jitter.
    if mode == 'train':
      self.augment = transforms.Compose(
          [transforms.RandomHorizontalFlip(), transforms.ColorJitter()])
    else:
      self.augment = None
    # Downscale factor for the network input (the network upscales 4x).
    self.scale = 0.25

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    """Return (low_res, high_res) tensors for the idx-th image of the split."""
    path = os.path.join(self.data_dir, self.files[idx])
    with Image.open(path) as im:
      hr = self.to_tensor(im.convert('RGB'))

    if self.augment is not None:
      hr = self.augment(hr)

    # Low-res input: bicubic downscale of the (augmented) target.
    lr_size = (int(hr.shape[1] * self.scale), int(hr.shape[2] * self.scale))
    downscale = transforms.Resize(
        lr_size, interpolation=transforms.InterpolationMode.BICUBIC)
    lr = downscale(hr)
    return lr, hr

## SRResNet class

In [None]:
class Bblock(nn.Module):
  """SRResNet residual block: conv-BN-PReLU-conv-BN plus an identity skip.

  Works on 64-channel feature maps. The 3x3 convolutions use padding 1, so the
  output has the same shape as the input. Submodule names and creation order
  are part of the saved-model format (state_dict keys / pickles); keep them
  stable.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.bn1 = nn.BatchNorm2d(64)
    self.prelu = nn.PReLU(num_parameters=64)  # one learnable slope per channel
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(64)

  def forward(self, x):
    residual = self.bn2(self.conv2(self.prelu(self.bn1(self.conv1(x)))))
    return residual + x


class Upscale(nn.Module):
  """2x sub-pixel upsampling stage: conv to 4x channels, PixelShuffle(2), PReLU.

  Maps (N, 64, H, W) -> (N, 64, 2H, 2W). The 3x3 conv expands 64 -> 256
  channels, and PixelShuffle(2) rearranges that 4x channel surplus into a
  grid that is 2x larger in each spatial dimension.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1)
    self.pixelShuffle = nn.PixelShuffle(upscale_factor=2)
    self.prelu = nn.PReLU()  # a single slope shared across all channels

  def forward(self, x):
    shuffled = self.pixelShuffle(self.conv1(x))
    return self.prelu(shuffled)


class SRNetwork(nn.Module):
  """SRResNet generator for 4x super-resolution of RGB images.

  Pipeline:
    1. 9x9 conv + PReLU feature extraction.
    2. A trunk of 16 residual `Bblock`s.
    3. 3x3 conv + BN, with a long skip back to the extracted features.
    4. Two 2x `Upscale` stages.
    5. A final 9x9 conv back to 3 channels.

  Input (N, 3, H, W) -> output (N, 3, 4H, 4W). There is no final activation,
  so output values can fall outside [0, 1].
  Submodule names and creation order are part of the saved-model format.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4)
    self.prelu = nn.PReLU()
    self.Bres = nn.Sequential(*[Bblock() for _ in range(16)])
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(64)
    self.upscale1 = Upscale()
    self.upscale2 = Upscale()
    self.conv3 = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4)

  def forward(self, x):
    features = self.prelu(self.conv1(x))
    trunk = self.bn2(self.conv2(self.Bres(features)))
    upsampled = self.upscale2(self.upscale1(trunk + features))
    return self.conv3(upsampled)


## Image show function

In [None]:
def showImg(images, titles=None):
  """Display CHW image tensors side by side in a single row.

  Args:
    images: sequence of (C, H, W) float tensors on any device. Values are
      clamped to [0, 1] for display; tensors that require grad are accepted.
    titles: optional sequence of per-panel titles, one per image.
  """
  n = len(images)
  if n == 0:
    return
  # squeeze=False always returns a 2-D array of Axes, so a single image works
  # as well (plt.subplots(1, 1) would otherwise return a bare Axes object).
  # The height is scaled to one row of roughly square panels.
  fig, axes = plt.subplots(1, n, figsize=(20, 20 / n), squeeze=False)
  for i, (ax, img) in enumerate(zip(axes[0], images)):
    # CHW -> HWC for matplotlib. detach() lets autograd-tracked model outputs
    # be converted to NumPy.
    img_hwc = img.detach().permute(1, 2, 0).clamp(0, 1).cpu()
    ax.imshow(img_hwc)
    ax.axis('off')
    if titles is not None:
      ax.set_title(titles[i])
  plt.show()

## Training loop

In [None]:
# Hyperparameters
lr = 1e-4
batch_size = 16
max_epochs = 160

# Use the GPU when present; all tensors below follow `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model
model = SRNetwork().to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

# Loss: pixel-wise MSE between the SR output and the full-resolution image
loss_fxn = nn.MSELoss()

srData = SRDataset(mode='train')
srValData = SRDataset(mode='val')
srLoader = DataLoader(srData, batch_size=batch_size, pin_memory=True, shuffle=True)
# Evaluation does not need shuffling. A fixed order also makes the preview
# shown after each epoch comparable from epoch to epoch.
srValLoader = DataLoader(srValData, batch_size=batch_size, pin_memory=True, shuffle=False)


# Train loop
minibatch_losses = []  # per-batch training loss (Python floats)
epoch_losses = []      # mean training loss per epoch
val_losses = []        # mean validation loss per epoch
for epoch in range(max_epochs):
  # Training mode: BatchNorm normalizes with batch statistics and updates its
  # running statistics.
  model.train()
  epoch_loss = 0.0
  for batch_idx, (x, y) in enumerate(srLoader, start=1):
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    # Forward, loss, backprop, update
    optimizer.zero_grad()
    yhat = model(x)
    loss = loss_fxn(yhat, y)
    loss.backward()
    optimizer.step()

    # .item() yields a plain float. Accumulating the loss tensor itself would
    # keep each batch's autograd history alive for the whole run.
    batch_loss = loss.item()
    minibatch_losses.append(batch_loss)
    epoch_loss += batch_loss
    sys.stdout.write('\rEpoch {} (Batch {}/{}) Loss: {:.8f}'.format(epoch, batch_idx, len(srLoader), batch_loss))
    sys.stdout.flush()
  sys.stdout.write('\n')

  # Validation: eval mode makes BatchNorm use (and not update) its running
  # statistics; no_grad because nothing here is backpropagated.
  model.eval()
  val_loss = 0.0
  with torch.no_grad():
    for x, y in srValLoader:
      x = x.to(device, non_blocking=True)
      y = y.to(device, non_blocking=True)
      yhat = model(x)
      val_loss += loss_fxn(yhat, y).item()

  # Print
  epoch_loss /= len(srLoader)
  val_loss /= len(srValLoader)
  print('Epoch', epoch, 'Train Loss', epoch_loss, 'Val Loss', val_loss)
  # Preview from the last val batch: low-res input, SR output, target.
  showImg([x[0], yhat[0], y[0]])
  epoch_losses.append(epoch_loss)
  val_losses.append(val_loss)
    


## Save model

In [None]:
# Persist the trained network to the output path listed at the top of the
# notebook. The whole nn.Module object is pickled (`pickle` is imported in the
# imports cell), so loading it requires the SRNetwork/Bblock/Upscale classes
# to be importable. To reuse it later without retraining, open the file with
# 'rb' and pickle.load it; only unpickle files you trust.
SR_MODEL_PATH = 'sr_training/sr_model.pkl'
with open(SR_MODEL_PATH, 'wb') as f:
  pickle.dump(model, f)

## Test model on face mask images

In [None]:
# Visual check: upscale 100 face-mask dataset images 4x and show each original
# next to its super-resolved version.
model.eval()  # use BatchNorm running statistics, not per-image batch stats
model_device = next(model.parameters()).device
to_tensor = transforms.ToTensor()
s = 0
with torch.no_grad():  # inference only; avoids storing activations for backprop
  for i in range(s, s + 100):
    print(i)
    with Image.open('facemask_data/images/maksssksksss{}.png'.format(i)) as im:
      img = to_tensor(im.convert('RGB'))
    sr_pred = model(img.unsqueeze(0).to(model_device))
    showImg([img, sr_pred[0]])

## Save an example

In [None]:
# Save one low-res / super-resolved pair to disk as an illustration.
i = 64
model.eval()  # use BatchNorm running statistics for inference
model_device = next(model.parameters()).device
to_tensor = transforms.ToTensor()
to_image = transforms.ToPILImage()
with Image.open('facemask_data/images/maksssksksss{}.png'.format(i)) as im:
  img = to_tensor(im.convert('RGB'))
with torch.no_grad():  # inference only
  sr_pred = model(img.unsqueeze(0).to(model_device))

img = to_image(img)
img.save('sr_example_LR.png')
# The network output is unbounded; clip to [0, 1] before converting to 8-bit.
sr_pred = to_image(torch.clip(sr_pred[0], 0, 1))
sr_pred.save('sr_example_HR.png')