<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Abstract" data-toc-modified-id="Abstract-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Abstract</a></span></li><li><span><a href="#Imports" data-toc-modified-id="Imports-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Imports</a></span></li><li><span><a href="#Helpers" data-toc-modified-id="Helpers-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Helpers</a></span></li><li><span><a href="#Dataset" data-toc-modified-id="Dataset-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Dataset</a></span></li><li><span><a href="#Model-Definition" data-toc-modified-id="Model-Definition-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Model Definition</a></span></li><li><span><a href="#Model-Setup" data-toc-modified-id="Model-Setup-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Model Setup</a></span></li><li><span><a href="#Training" data-toc-modified-id="Training-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Training</a></span></li><li><span><a href="#Testing" data-toc-modified-id="Testing-8"><span class="toc-item-num">8&nbsp;&nbsp;</span>Testing</a></span></li></ul></div>

# Abstract

We are building a real-time GAN for super-resolution of streaming video.

Inspired by:

- https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/esrgan/models.py (ESRGAN)
- https://github.com/yjn870/ESPCN-pytorch (ESPCN)
- https://github.com/ewrfcas/Face-Super-Resolution


# Imports 

In [2]:
import glob
import random
import os
import sys
import itertools
import numpy as np
import math
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable

# Helpers

In [5]:
class DotDict(dict):
    """A dict whose keys can also be used as attributes.

    ``d.key`` reads ``d["key"]`` and yields None when the key is absent (the
    same as ``dict.get``) instead of raising AttributeError. ``d.key = v`` stores
    ``d["key"] = v``. ``del d.key`` removes the key and raises KeyError if it is
    missing.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
    
def num_trainable_params(model):
    """Count the scalar parameters of ``model`` that will receive gradients.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        int: total number of elements across all parameters whose
        ``requires_grad`` is True (0 when there are none).
    """
    # Tensor.numel() returns a plain int and is 1 for 0-dim parameters, whereas
    # np.prod(p.size()) returns the float 1.0 for an empty shape and would turn
    # the whole total into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Dataset

In [19]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def denormalize(tensors):
    """Undo ``transforms.Normalize(mean, std)`` for a batch of image tensors.

    Args:
        tensors: batch shaped (N, C, ...) with C == 3 channels, normalized with
            the module-level ``mean`` / ``std``.

    Returns:
        A new tensor with the same shape, dtype and device. Values are mapped
        back to the [0, 1] range produced by ``transforms.ToTensor`` and clamped
        there. The input tensor is not modified.
    """
    # Reshape the per-channel statistics to (1, C, 1, ..., 1) so they broadcast
    # over the batch and spatial dimensions.
    stat_shape = (1, -1) + (1,) * (tensors.dim() - 2)
    channel_mean = torch.as_tensor(mean, dtype=tensors.dtype, device=tensors.device).view(stat_shape)
    channel_std = torch.as_tensor(std, dtype=tensors.dtype, device=tensors.device).view(stat_shape)
    # Denormalized images live in [0, 1], which save_image(normalize=False)
    # expects; an upper bound of 255 would never clip anything.
    return torch.clamp(tensors * channel_std + channel_mean, 0, 1)


class ImageDataset(Dataset):
    """Paired low-/high-resolution views of every image in a flat directory.

    Every file matching ``<root>/*.*`` (not recursive) is opened with PIL and
    converted to RGB. It is then resized twice: to ``hr_shape`` for the
    high-resolution target, and to ``hr_shape // scale_factor`` for the
    low-resolution input. Both views are converted to tensors and normalized
    with the module-level ``mean`` / ``std``.

    Args:
        root: directory containing the images.
        hr_shape: ``(height, width)`` of the high-resolution view.
        scale_factor: integer downscaling factor for the low-resolution view
            (default 4).
    """

    def __init__(self, root, hr_shape, scale_factor=4):
        hr_height, hr_width = hr_shape
        # Transforms for low resolution images and high resolution images.
        # Resize takes (height, width), so each dimension uses its own size.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // scale_factor, hr_width // scale_factor), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        self.files = sorted(glob.glob(os.path.join(root, "*.*")))

    def __getitem__(self, index):
        """Return ``{"lr": tensor, "hr": tensor}`` for the image at ``index``."""
        # The context manager closes the file handle once the pixels have been
        # copied. convert("RGB") guarantees three channels, so grayscale or RGBA
        # files do not break Normalize(mean, std).
        with Image.open(self.files[index % len(self.files)]) as handle:
            img = handle.convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)

# Model Definition

In [60]:
class FeatureExtractor(nn.Module):
    """Frozen VGG19 feature extractor used for the perceptual (content) loss.

    Wraps the first 35 modules of an ImageNet-pretrained ``vgg19().features``.
    In torchvision's VGG19 layout, that prefix ends at conv5_4 before its ReLU,
    hence the ``vgg19_54`` name. Inputs are expected to be normalized with the
    ImageNet mean/std, as ``ImageDataset`` does.

    The VGG weights are frozen. Gradients still flow back to the input image,
    which is what the generator needs. However, no ``.grad`` buffers are
    allocated or accumulated for the VGG parameters, which no optimizer updates.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
        # ``weights=VGG19_Weights.IMAGENET1K_V1``; kept for older versions.
        vgg19_model = vgg19(pretrained=True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])
        for param in self.vgg19_54.parameters():
            param.requires_grad_(False)

    def forward(self, img):
        return self.vgg19_54(img)

class Generator(nn.Module):
    """ESPCN-style sub-pixel super-resolution generator.

    All convolutions run at the low (input) resolution. The final convolution
    emits ``num_channels * scale_factor ** 2`` channels, which ``nn.PixelShuffle``
    rearranges into an image ``scale_factor`` times larger in height and width.
    No output activation is applied; the training loop compares the raw output
    directly with the Normalize()-ed high-resolution images.

    Args:
        scale_factor: integer upscaling factor.
        num_channels: number of image channels in input and output (default 3).
    """

    def __init__(self, scale_factor, num_channels=3):
        super().__init__()
        # An earlier variant (no BatchNorm, Tanh after every conv, and a final
        # Sigmoid after PixelShuffle) was tried before this configuration.
        layers = [
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=5 // 2),
            nn.BatchNorm2d(64),
            # NOTE(review): ReLU immediately followed by Tanh squashes activations
            # into [0, 1). Kept as-is, because removing either module would shift
            # the nn.Sequential indices and break loading saved checkpoints.
            nn.ReLU(),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=3 // 2),
            nn.Tanh(),
            nn.Conv2d(32, 16, kernel_size=3, padding=3 // 2),
            nn.Tanh(),
            # Sub-pixel layer: scale_factor**2 output maps per image channel.
            nn.Conv2d(16, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor),
        ]
        self.model = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every Conv2d in the network; all biases are zeroed.

        The last convolution (the one feeding PixelShuffle) gets weights drawn
        from N(0, 0.001**2), so the initial output is close to zero. Every other
        convolution gets N(0, 2 / fan_out), with
        fan_out = out_channels * kernel_h * kernel_w.

        Previously the small-std layer was selected by ``in_channels == 32``,
        presumably copied from ESPCN, where the sub-pixel conv has 32 inputs.
        With the extra 32 -> 16 conv in this model, that rule picked the middle
        conv instead of the sub-pixel conv.
        """
        convs = [m for m in self.modules() if isinstance(m, nn.Conv2d)]
        last_index = len(convs) - 1
        for index, conv in enumerate(convs):
            if index == last_index:
                nn.init.normal_(conv.weight.data, mean=0.0, std=0.001)
            else:
                fan_out = conv.out_channels * conv.weight.data[0][0].numel()
                nn.init.normal_(conv.weight.data, mean=0.0, std=math.sqrt(2 / fan_out))
            nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = self.model(x)
        return x


class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a map of logits.

    The network has four conv(3x3) + LeakyReLU(0.2) stages; the last three use
    stride 2 and BatchNorm. A final 3x3 conv produces one logit per spatial
    patch. No sigmoid is applied; the logits are meant for BCEWithLogitsLoss.

    Args:
        input_shape: ``(channels, height, width)`` of the images it scores.

    Attributes:
        output_shape: ``(1, patch_h, patch_w)``, the per-image output shape,
            computed with the same conv arithmetic the layers use so it always
            matches ``forward``'s output.
    """

    # (out_channels, stride, use_batchnorm) for each conv + LeakyReLU stage.
    _STAGES = ((64, 1, False), (128, 2, True), (256, 2, True), (256, 2, True))

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape

        layers = []
        patch_h, patch_w = in_height, in_width
        prev_channels = in_channels
        for out_channels, stride, use_bn in self._STAGES:
            layers.append(nn.Conv2d(prev_channels, out_channels, kernel_size=3, stride=stride, padding=1))
            if use_bn:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            # A stride-2 conv rounds up (ceil(size / 2)), so dividing the input
            # size by 8 and truncating is wrong for sizes not divisible by 8.
            patch_h = self._conv_out_size(patch_h, stride)
            patch_w = self._conv_out_size(patch_w, stride)
            prev_channels = out_channels

        # 3x3, stride-1, padding-1 conv: one logit per patch, spatial size kept.
        layers.append(nn.Conv2d(prev_channels, 1, kernel_size=3, stride=1, padding=1))

        self.output_shape = (1, patch_h, patch_w)
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _conv_out_size(size, stride, kernel_size=3, padding=1):
        """Spatial output size of a Conv2d along one dimension."""
        return (size + 2 * padding - kernel_size) // stride + 1

    def forward(self, img):
        return self.model(img)

# Model Setup

In [62]:
os.makedirs("images/training", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

# Hyper-parameters, formerly argparse options; kept as a plain dict so the
# notebook runs top to bottom. Entries marked "unused" are ESRGAN leftovers not
# read by any cell shown here.
opt = {
    'epoch': 0,                   # epoch to start from; non-zero resumes from saved_models/
    'n_epochs': 200,              # number of epochs of training
    'dataset_name': 'img_align_celeba',  # image folder under ./data
    'batch_size': 4,              # size of the batches
    'lr': 0.0001,                 # adam: learning rate
    'b1': 0.9,                    # adam: beta1 (decay of first-order momentum)
    'b2': 0.999,                  # adam: beta2 (decay of second-order momentum)
    'decay_epoch': 100,           # epoch from which to start lr decay (unused)
    'n_cpu': 8,                   # DataLoader worker processes
    'hr_height': 256,             # high-res image height
    'hr_width': 256,              # high-res image width
    'channels': 3,                # number of image channels
    'sample_interval': 100,       # batches between saved sample grids
    'checkpoint_interval': 5000,  # batches between model checkpoints
    'residual_blocks': 23,        # ESRGAN generator depth (unused)
    'warmup_batches': 500,        # batches trained with pixel-wise loss only
    'lambda_adv': 5e-3,           # adversarial loss weight
    'lambda_pixel': 1e-2,         # pixel-wise loss weight
    'scale_factor': 4,            # generator upscaling factor
}

opt = DotDict(opt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hr_shape = (opt.hr_height, opt.hr_width)

# Initialize generator and discriminator
generator = Generator(num_channels=opt.channels, scale_factor=opt.scale_factor).to(device)
discriminator = Discriminator(input_shape=(opt.channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

if opt.epoch != 0:
    # Load pretrained models. map_location lets checkpoints saved on a GPU be
    # resumed on a CPU-only machine (and vice versa).
    generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % opt.epoch, map_location=device))
    discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % opt.epoch, map_location=device))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Legacy tensor-type alias, still used for casting batches in older code paths.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

dataloader = DataLoader(
    ImageDataset("./data/%s" % opt.dataset_name, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

In [51]:
# Trainable parameter count of the generator (shown as the cell's output).
num_trainable_params(model=generator)

35168

# Training

In [63]:
# ----------
#  Training
# ----------
# For the first opt.warmup_batches batches only the generator is trained, on
# the pixel-wise loss, and sampling/checkpointing are skipped (the warm-up
# branch `continue`s). After that, each batch trains the generator on
# content + adversarial + pixel losses, then trains the discriminator with the
# relativistic average GAN loss.

for epoch in range(opt.epoch, opt.n_epochs):
    for i, imgs in enumerate(dataloader):

        batches_done = epoch * len(dataloader) + i

        # Configure model input: float32 tensors on the training device.
        imgs_lr = imgs["lr"].to(device=device, dtype=torch.float32)
        imgs_hr = imgs["hr"].to(device=device, dtype=torch.float32)

        # Adversarial ground truths, shaped like the discriminator's output map.
        target_shape = (imgs_lr.size(0), *discriminator.output_shape)
        valid = torch.ones(target_shape, device=device)
        fake = torch.zeros(target_shape, device=device)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)

        # Measure pixel-wise loss against ground truth
        loss_pixel = criterion_pixel(gen_hr, imgs_hr)

        if batches_done < opt.warmup_batches:
            # Warm-up (pixel-wise loss only)
            loss_pixel.backward()
            optimizer_G.step()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), loss_pixel.item())
            )
            continue

        # Extract validity predictions from discriminator
        pred_real = discriminator(imgs_hr).detach()
        pred_fake = discriminator(gen_hr)

        # Adversarial loss (relativistic average GAN)
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr).detach()
        loss_content = criterion_content(gen_features, real_features)

        # Total generator loss
        loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel

        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr.detach())

        # Adversarial loss for real and fake images (relativistic average GAN)
        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_content.item(),
                loss_GAN.item(),
                loss_pixel.item(),
            )
        )

        if batches_done % opt.sample_interval == 0:
            # Save image grid: upsampled LR input next to the generator output.
            # Upsample by the configured factor (not a hard-coded 4) so both
            # halves have the same height, and detach so no graph is kept alive.
            imgs_lr_up = nn.functional.interpolate(imgs_lr, scale_factor=opt.scale_factor)
            img_grid = denormalize(torch.cat((imgs_lr_up, gen_hr.detach()), -1))
            save_image(img_grid, "images/training/%d.png" % batches_done, nrow=1, normalize=False)

        if batches_done % opt.checkpoint_interval == 0:
            # Save model checkpoints. Files are keyed by epoch (matching the
            # resume logic in the setup cell), so a later checkpoint within the
            # same epoch overwrites an earlier one.
            torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
            torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" %epoch)

[Epoch 0/200] [Batch 0/50650] [G pixel: 1.219584]
[Epoch 0/200] [Batch 1/50650] [G pixel: 1.089705]
[Epoch 0/200] [Batch 2/50650] [G pixel: 1.120722]
[Epoch 0/200] [Batch 3/50650] [G pixel: 1.370141]
[Epoch 0/200] [Batch 4/50650] [G pixel: 1.253313]
[Epoch 0/200] [Batch 5/50650] [G pixel: 1.280681]
[Epoch 0/200] [Batch 6/50650] [G pixel: 1.175487]
[Epoch 0/200] [Batch 7/50650] [G pixel: 1.268566]
[Epoch 0/200] [Batch 8/50650] [G pixel: 1.187997]
[Epoch 0/200] [Batch 9/50650] [G pixel: 1.160830]
[Epoch 0/200] [Batch 10/50650] [G pixel: 0.998945]
[Epoch 0/200] [Batch 11/50650] [G pixel: 1.075892]
[Epoch 0/200] [Batch 12/50650] [G pixel: 1.079527]
[Epoch 0/200] [Batch 13/50650] [G pixel: 1.222022]
[Epoch 0/200] [Batch 14/50650] [G pixel: 1.061572]
[Epoch 0/200] [Batch 15/50650] [G pixel: 1.294962]
[Epoch 0/200] [Batch 16/50650] [G pixel: 0.869440]
[Epoch 0/200] [Batch 17/50650] [G pixel: 1.150122]
[Epoch 0/200] [Batch 18/50650] [G pixel: 0.898998]
[Epoch 0/200] [Batch 19/50650] [G pixel: 

[Epoch 0/200] [Batch 161/50650] [G pixel: 0.420105]
[Epoch 0/200] [Batch 162/50650] [G pixel: 0.503436]
[Epoch 0/200] [Batch 163/50650] [G pixel: 0.388756]
[Epoch 0/200] [Batch 164/50650] [G pixel: 0.418133]
[Epoch 0/200] [Batch 165/50650] [G pixel: 0.438609]
[Epoch 0/200] [Batch 166/50650] [G pixel: 0.425950]
[Epoch 0/200] [Batch 167/50650] [G pixel: 0.503539]
[Epoch 0/200] [Batch 168/50650] [G pixel: 0.365685]
[Epoch 0/200] [Batch 169/50650] [G pixel: 0.418709]
[Epoch 0/200] [Batch 170/50650] [G pixel: 0.430653]
[Epoch 0/200] [Batch 171/50650] [G pixel: 0.360724]
[Epoch 0/200] [Batch 172/50650] [G pixel: 0.444838]
[Epoch 0/200] [Batch 173/50650] [G pixel: 0.450646]
[Epoch 0/200] [Batch 174/50650] [G pixel: 0.469654]
[Epoch 0/200] [Batch 175/50650] [G pixel: 0.449924]
[Epoch 0/200] [Batch 176/50650] [G pixel: 0.509331]
[Epoch 0/200] [Batch 177/50650] [G pixel: 0.392790]
[Epoch 0/200] [Batch 178/50650] [G pixel: 0.413990]
[Epoch 0/200] [Batch 179/50650] [G pixel: 0.378979]
[Epoch 0/200

[Epoch 0/200] [Batch 319/50650] [G pixel: 0.436049]
[Epoch 0/200] [Batch 320/50650] [G pixel: 0.455440]
[Epoch 0/200] [Batch 321/50650] [G pixel: 0.350699]
[Epoch 0/200] [Batch 322/50650] [G pixel: 0.369248]
[Epoch 0/200] [Batch 323/50650] [G pixel: 0.396323]
[Epoch 0/200] [Batch 324/50650] [G pixel: 0.367824]
[Epoch 0/200] [Batch 325/50650] [G pixel: 0.613468]
[Epoch 0/200] [Batch 326/50650] [G pixel: 0.360417]
[Epoch 0/200] [Batch 327/50650] [G pixel: 0.359117]
[Epoch 0/200] [Batch 328/50650] [G pixel: 0.540572]
[Epoch 0/200] [Batch 329/50650] [G pixel: 0.472352]
[Epoch 0/200] [Batch 330/50650] [G pixel: 0.329671]
[Epoch 0/200] [Batch 331/50650] [G pixel: 0.438570]
[Epoch 0/200] [Batch 332/50650] [G pixel: 0.420260]
[Epoch 0/200] [Batch 333/50650] [G pixel: 0.456399]
[Epoch 0/200] [Batch 334/50650] [G pixel: 0.477672]
[Epoch 0/200] [Batch 335/50650] [G pixel: 0.283108]
[Epoch 0/200] [Batch 336/50650] [G pixel: 0.378916]
[Epoch 0/200] [Batch 337/50650] [G pixel: 0.349797]
[Epoch 0/200

[Epoch 0/200] [Batch 478/50650] [G pixel: 0.393337]
[Epoch 0/200] [Batch 479/50650] [G pixel: 0.368473]
[Epoch 0/200] [Batch 480/50650] [G pixel: 0.472915]
[Epoch 0/200] [Batch 481/50650] [G pixel: 0.319962]
[Epoch 0/200] [Batch 482/50650] [G pixel: 0.395060]
[Epoch 0/200] [Batch 483/50650] [G pixel: 0.331254]
[Epoch 0/200] [Batch 484/50650] [G pixel: 0.310075]
[Epoch 0/200] [Batch 485/50650] [G pixel: 0.373957]
[Epoch 0/200] [Batch 486/50650] [G pixel: 0.369219]
[Epoch 0/200] [Batch 487/50650] [G pixel: 0.403492]
[Epoch 0/200] [Batch 488/50650] [G pixel: 0.398776]
[Epoch 0/200] [Batch 489/50650] [G pixel: 0.385531]
[Epoch 0/200] [Batch 490/50650] [G pixel: 0.398330]
[Epoch 0/200] [Batch 491/50650] [G pixel: 0.343642]
[Epoch 0/200] [Batch 492/50650] [G pixel: 0.320224]
[Epoch 0/200] [Batch 493/50650] [G pixel: 0.375526]
[Epoch 0/200] [Batch 494/50650] [G pixel: 0.321037]
[Epoch 0/200] [Batch 495/50650] [G pixel: 0.276862]
[Epoch 0/200] [Batch 496/50650] [G pixel: 0.415510]
[Epoch 0/200

KeyboardInterrupt: 

# Testing