# SRGAN
Notebook to reproduce the results of the SRGAN paper following https://medium.com/analytics-vidhya/super-resolution-gan-srgan-5e10438aec0c 

In [1]:
!git clone https://github.com/vishal1905/Super-Resolution.git

Cloning into 'Super-Resolution'...
remote: Enumerating objects: 10096, done.[K
remote: Total 10096 (delta 0), reused 0 (delta 0), pack-reused 10096[K
Receiving objects: 100% (10096/10096), 65.77 MiB | 26.36 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [1]:
import os
import torch
import torch.nn as n
import torch.nn.functional as f
import numpy as np
import os
from torchsummary import summary
import torch.optim as optim
from tqdm import tqdm
from torchvision import models
import torchvision
import cv2
from matplotlib import pyplot as plt
from PIL import Image

cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
base_dir = "Super-Resolution/celeba-dataset/img_align_celeba/img_align_celeba/"
images = os.listdir(base_dir)
imageList = images[:1500]

## Define generator

In [3]:
class Generator(n.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = n.Conv2d(3,64,9,padding=4,bias=False)
        self.conv2 = n.Conv2d(64,64,3,padding=1,bias=False)
        self.conv3_1 = n.Conv2d(64,256,3,padding=1,bias=False)
        self.conv3_2 = n.Conv2d(64,256,3,padding=1,bias=False)
        self.conv4 = n.Conv2d(64,3,9,padding=4,bias=False)
        self.bn = n.BatchNorm2d(64)
        self.ps = n.PixelShuffle(2)
        self.prelu = n.PReLU()
        
    def forward(self,x):
        block1 = self.prelu(self.conv1(x))
        block2 = torch.add(self.bn(self.conv2(self.prelu(self.bn(self.conv2(block1))))),block1)
        block3 = torch.add(self.bn(self.conv2(self.prelu(self.bn(self.conv2(block2))))),block2)
        block4 = torch.add(self.bn(self.conv2(self.prelu(self.bn(self.conv2(block3))))),block3)
        block5 = torch.add(self.bn(self.conv2(self.prelu(self.bn(self.conv2(block4))))),block4)
        block6 = torch.add(self.bn(self.conv2(self.prelu(self.bn(self.conv2(block5))))),block5)
        block7 = torch.add(self.bn(self.conv2(block6)),block1)
        block8 = self.prelu(self.ps(self.conv3_1(block7)))
        block9 = self.prelu(self.ps(self.conv3_2(block8)))
        block10 = self.conv4(block9)
        return block10

In [4]:
gen = Generator().to(cuda).float()

## Discriminator

In [5]:
class Discriminator(n.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = n.Conv2d(3,64,3,padding=1,bias=False)
        self.conv2 = n.Conv2d(64,64,3,stride=2,padding=1,bias=False)
        self.bn2 = n.BatchNorm2d(64)
        self.conv3 = n.Conv2d(64,128,3,padding=1,bias=False)
        self.bn3 = n.BatchNorm2d(128)
        self.conv4 = n.Conv2d(128,128,3,stride=2,padding=1,bias=False)
        self.bn4 = n.BatchNorm2d(128)
        self.conv5 = n.Conv2d(128,256,3,padding=1,bias=False)
        self.bn5 = n.BatchNorm2d(256)
        self.conv6 = n.Conv2d(256,256,3,stride=2,padding=1,bias=False)
        self.bn6 = n.BatchNorm2d(256)
        self.conv7 = n.Conv2d(256,512,3,padding=1,bias=False)
        self.bn7 = n.BatchNorm2d(512)
        self.conv8 = n.Conv2d(512,512,3,stride=2,padding=1,bias=False)
        self.bn8 = n.BatchNorm2d(512)
        self.fc1 = n.Linear(512*16*16,1024)
        self.fc2 = n.Linear(1024,1)
        self.drop = n.Dropout2d(0.3)
        
    def forward(self,x):
        block1 = f.leaky_relu(self.conv1(x))
        block2 = f.leaky_relu(self.bn2(self.conv2(block1)))
        block3 = f.leaky_relu(self.bn3(self.conv3(block2)))
        block4 = f.leaky_relu(self.bn4(self.conv4(block3)))
        block5 = f.leaky_relu(self.bn5(self.conv5(block4)))
        block6 = f.leaky_relu(self.bn6(self.conv6(block5)))
        block7 = f.leaky_relu(self.bn7(self.conv7(block6)))
        block8 = f.leaky_relu(self.bn8(self.conv8(block7)))
        block8 = block8.view(-1,block8.size(1)*block8.size(2)*block8.size(3))
        block9 = f.leaky_relu(self.fc1(block8),)
        block10 = torch.sigmoid(self.drop(self.fc2(block9)))
        return block9,block10

In [6]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [7]:
disc = Discriminator().to(cuda).float()

In [8]:
count_parameters(gen), count_parameters(disc)

(363009, 138906945)

## Setup training

In [9]:
vgg = models.vgg19(pretrained=True).to(cuda)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/stephan/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




### Losses

In [10]:
gen_loss = n.BCELoss()
vgg_loss = n.MSELoss()
mse_loss = n.MSELoss()
disc_loss = n.BCELoss()

In [11]:
gen_optimizer = optim.Adam(gen.parameters(),lr=0.0001)
disc_optimizer = optim.Adam(disc.parameters(),lr=0.0001)

### Load images

### New image loaders

In [12]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False 
    model.eval()
    return model

In [13]:
def imagePostProcess(imagedir, modelPath):
    """
    Show model output on unseen images 
    Parameters:
    ----------
    imagedir: str
        List of paths to unseen images
    
    """
    imagelist=[]
    original_images = []
    for img in imagedir:
        img_original = cv2.imread(os.path.join(hr_path,img))
        img_original = cv2.resize(img_original, (256,256))
        img = degrade_resolution(img_original)
        imagelist.append(img)
        original_images.append(img_original)
    original_images = np.array(original_images)
    imagearray = np.array(imagelist)/255
    imagearrayPT = np.moveaxis(imagearray,3,1)

    model = load_checkpoint(modelPath)
    im_tensor = torch.from_numpy(imagearrayPT).float()
    out_tensor = model(im_tensor)
    out = out_tensor.numpy()
    out = np.moveaxis(out,1,3)
    out = np.clip(out,0,1)
    
    return original_images, imagearray, out

In [14]:
def show_samples(image_dir, model_path):

    # Load images and run through
    original_images, low_res, out = imagePostProcess(image_dir, model_path)

    # Get the number of samples to plot
    n_samples = len(image_dir)
    figure, axes = plt.subplots(n_samples, 3)
    for i in range(n_samples):
        axes[i,0].imshow(original_images[i,...][...,::-1])

    for i in range(n_samples):
        axes[i,1].imshow(low_res[i,...][...,::-1])

    for i in range(n_samples):
        axes[i,2].imshow(out[i,...][...,::-1])
    plt.axis("off")
    plt.show()
    plt.close()

In [15]:
def degrade_resolution(image):
    """
    Degrade image resolution
    """
    resized = cv2.resize(cv2.GaussianBlur(image,(5,5),cv2.BORDER_DEFAULT),(64,64)) 
    return resized

In [16]:
class Dataset(torch.utils.data.Dataset):
  'Characterizes a dataset for PyTorch'

  def __init__(self, image_list):
        'Initialization'
        self.image_list = image_list
        self.base_dir = "Super-Resolution/celeba-dataset/img_align_celeba/img_align_celeba/"

  def __len__(self):
        return len(self.image_list)

  def __getitem__(self, index):
        # Select sample
        image_file = self.image_list[index]

        # Load the original (high res) image
        high_res = cv2.imread(self.base_dir+image_file)
        high_res = cv2.resize(high_res, (256,256))

        # Degrade to low res
        low_res = degrade_resolution(high_res)

        # Normalise
        high_res = torch.from_numpy(high_res/255)
        low_res = torch.from_numpy(low_res/255)

        # Channels to second dim
        high_res = high_res.permute(2,0,1)
        low_res = low_res.permute(2,0,1)

        return low_res.cuda().float(), high_res.cuda().float()

In [17]:
image_dataset = Dataset(imageList)

In [18]:
training_generator = torch.utils.data.DataLoader(image_dataset,
                                                 shuffle = True,
                                                 batch_size = 32
                                                )

In [19]:
import os 
base_path = os.getcwd()

#lr_path = os.path.join(base_path,"trainImages")
hr_path = base_dir
#valid_path = os.path.join(base_path,"SR_valid")
weight_file = os.path.join(base_path,"SRPT_weights")
out_path = os.path.join(base_path,"out")

if not os.path.exists(weight_file):
    os.makedirs(weight_file)

if not os.path.exists(out_path):
    os.makedirs(out_path)


In [20]:
#batch_count=60
epochs = 100

for epoch in range(epochs):
    d1loss_list=[]
    d2loss_list=[]
    gloss_list=[]
    vloss_list=[]
    mloss_list=[]
    
    for lr_images, hr_images in tqdm(training_generator):
                
        disc.zero_grad()

        gen_out = gen(lr_images)
        _,f_label = disc(gen_out)
        _,r_label = disc(hr_images)

        d1_loss = (disc_loss(f_label,torch.zeros_like(f_label,dtype=torch.float)))
        d2_loss = (disc_loss(r_label,torch.ones_like(r_label,dtype=torch.float)))
        d2_loss.backward()
        d1_loss.backward(retain_graph=True)
        disc_optimizer.step()

        gen.zero_grad()      
        g_loss = gen_loss(f_label.data,torch.ones_like(f_label,dtype=torch.float))
        v_loss = vgg_loss(vgg.features[:7](gen_out),vgg.features[:7](hr_images))
        m_loss = mse_loss(gen_out,hr_images)
        
        generator_loss = g_loss + v_loss + m_loss
        generator_loss.backward()
        gen_optimizer.step()
        
        d1loss_list.append(d1_loss.item())
        d2loss_list.append(d2_loss.item())
        
        gloss_list.append(g_loss.item())
        vloss_list.append(v_loss.item())
        mloss_list.append(m_loss.item())

    print("d1_loss: "+str(np.mean(d1loss_list))+"  d2_loss:"+str(np.mean(d2loss_list)))
    print("genLoss: "+str(np.mean(gloss_list))+"  vggLoss: "+str(np.mean(vloss_list))+"  MeanLoss: "+str(np.mean(mloss_list)))
    
    if(epoch%3==0):
        
        checkpoint = {'model': Generator(),
              'input_size': 64,
              'output_size': 256,
              'state_dict': gen.state_dict()}
        torch.save(checkpoint,os.path.join(weight_file,"SR"+str(epoch+1)+".pth"))
        torch.cuda.empty_cache()
        
        show_samples(images[-2:],os.path.join(weight_file,"SR"+str(epoch+1)+".pth"))

  0%|          | 0/47 [00:01<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 11.17 GiB total capacity; 3.76 GiB already allocated; 234.69 MiB free; 4.15 GiB reserved in total by PyTorch)