<a href="https://colab.research.google.com/github/SagarJiyani3010/GAN-Tutorials/blob/SRGAN/SRGAN/Notebooks/SRGAN_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -q kaggle

In [None]:
# Opens a file picker; upload kaggle.json (the next cell copies it into ~/.kaggle).
from google.colab import files
files.upload()

In [None]:
# -p: do not fail when the directory already exists (e.g. when this cell is re-run).
! mkdir -p ~/.kaggle
! cp kaggle.json ~/.kaggle/
# Restrict the API token file to the current user.
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
! kaggle datasets download -d akhileshdkapse/super-image-resolution

In [None]:
# -o: overwrite existing files without prompting, so a re-run does not block on input.
# NOTE(review): no drive.mount() is visible here; unless Drive is mounted elsewhere,
# this path is ordinary (non-persistent) Colab storage.
! unzip -o super-image-resolution.zip -d /content/drive/MyDrive/ColabNotebooks/GANs/SRGAN/

# Libraries

In [None]:
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import math
from torchvision.models.vgg import vgg16
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from torch.autograd import Variable

# Autograd anomaly detection is a debugging aid that adds overhead to every
# backward pass; switch this on only while hunting a NaN / in-place-op error.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f9a68e5bc10>

# Generator

In [None]:
class ResidualBlock(nn.Module):
    """SRGAN residual block: ``x + BN(Conv(PReLU(BN(Conv(x)))))``.

    Two 3x3 convolutions (padding 1, so height/width are preserved) with batch
    norm and a PReLU in between. The block input is added back to the result
    through an identity skip connection, so output shape equals input shape.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        # The identity skip is what makes this a residual block; without it the
        # block is just a plain two-layer conv stack.
        return x + self.conv(x)

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampler: 3x3 conv -> PixelShuffle -> PReLU.

    The convolution expands the channel count by ``up_scale ** 2``. PixelShuffle
    then folds those extra channels into space, so the output keeps
    ``in_channels`` channels while height and width grow by ``up_scale``.

    Args:
        in_channels: channels of the input (and of the output).
        up_scale: integer spatial upscale applied by this block.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBlock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        layers = [
            nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class Generator(nn.Module):
    """SRGAN generator: upsamples an RGB batch by ``scale_factor``.

    Layout:
    - block1: 9x9 conv + PReLU.
    - block2-6: five residual blocks.
    - block7: 3x3 conv + BN.
    - A long skip adds block1's output to block7's output.
    - block8: ``log2(scale_factor)`` x2 upsample blocks, then a 9x9 conv to
      3 channels.

    The output is squashed to (0, 1) with ``(tanh + 1) / 2``.

    Args:
        scale_factor: overall upscale. It must be a positive power of two,
            because every upsample block doubles height and width.

    Raises:
        ValueError: if ``scale_factor`` is not a positive power of two.
    """

    def __init__(self, scale_factor):
        # Only powers of two are reachable with x2 upsample blocks; anything else
        # used to be truncated silently to a smaller upscale (e.g. 3 -> x2).
        if scale_factor <= 0 or not float(math.log2(scale_factor)).is_integer():
            raise ValueError(
                "scale_factor must be a positive power of two, got %r" % (scale_factor,)
            )
        upsample_block_num = int(math.log2(scale_factor))
        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Long skip connection around the whole residual trunk.
        block8 = self.block8(block1 + block7)

        # Map the unbounded conv output into (0, 1), the range of the targets.
        return (torch.tanh(block8) + 1) / 2

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """SRGAN discriminator: maps an image batch to one probability per image.

    The network is:
    - An initial 3x3 conv (3 -> 64) with LeakyReLU(0.2) and no batch norm.
    - Seven conv + BN + LeakyReLU(0.2) stages; every other stage has stride 2.
    - Global average pooling.
    - Two 1x1 convolutions (512 -> 1024 -> 1) as the classifier head.

    A sigmoid turns the single logit per image into a value in (0, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # (in_channels, out_channels, stride) of each conv + BN + LeakyReLU stage.
        stages = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        for c_in, c_out, stride in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        # The module order determines the state_dict keys (net.0, net.2, ...);
        # keep it stable so saved checkpoints keep loading.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x).view(x.size(0))
        return torch.sigmoid(logits)

# Loss Functions

In [None]:
class GeneratorLoss(nn.Module):
    """Composite SRGAN generator objective.

    The returned scalar is::

        pixel_mse + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * tv

    The terms are:
    - pixel_mse: MSE between generated and target images.
    - adversarial: ``mean(1 - out_labels)``, where out_labels are the
      discriminator's scores for the generated images.
    - perceptual: MSE between frozen VGG16 feature maps of the generated and
      target images.
    - tv: total-variation penalty on the generated images.
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg_features = vgg16(pretrained=True).features
        # Keep up to the first 31 modules of VGG16's `features` as a fixed
        # feature extractor: eval mode, no gradients into its weights.
        self.loss_network = nn.Sequential(*list(vgg_features)[:31]).eval()
        for weight in self.loss_network.parameters():
            weight.requires_grad = False
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        adversarial = torch.mean(1 - out_labels)
        # NOTE(review): images are fed to VGG without ImageNet mean/std
        # normalization; confirm this is intended.
        perceptual = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        pixel = self.mse_loss(out_images, target_images)
        smoothness = self.tv_loss(out_images)
        return pixel + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * smoothness

class TVLoss(nn.Module):
    """Squared total-variation penalty, averaged per image.

    The penalty is computed as follows:
    - Sum the squared differences between vertically adjacent pixels and
      divide by the number of such pairs per image.
    - Do the same for horizontally adjacent pixels.
    - Add the two terms, multiply by ``2 * tv_loss_weight`` and divide by
      the batch size.

    Args:
        tv_loss_weight: scalar multiplier applied to the penalty.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        n_images = x.size()[0]
        height = x.size()[2]
        width = x.size()[3]

        lower_rows = x[:, :, 1:, :]
        right_cols = x[:, :, :, 1:]
        vertical_diff = lower_rows - x[:, :, :height - 1, :]
        horizontal_diff = right_cols - x[:, :, :, :width - 1]

        h_tv = vertical_diff.pow(2).sum()
        w_tv = horizontal_diff.pow(2).sum()
        count_h = self.tensor_size(lower_rows)
        count_w = self.tensor_size(right_cols)
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / n_images

    @staticmethod
    def tensor_size(t):
        """Return the element count of dims 1-3 of ``t`` (i.e. per item of dim 0)."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]

# Config

In [None]:
# --- Runtime -----------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Data ----------------------------------------------------------------------
INPUT_DIR = '/content/drive/MyDrive/ColabNotebooks/GANs/SRGAN/Data/LR'   # low-resolution inputs
TARGET_DIR = '/content/drive/MyDrive/ColabNotebooks/GANs/SRGAN/Data/HR'  # high-resolution targets
# Held-out split (not wired up yet):
# INPUT_DIR_TEST = '../data/LR_test/'
# TARGET_DIR_TEST = '../data/HR_test/'

# --- Model ---------------------------------------------------------------------
UPSCALE_FACTOR = 4
# NOTE(review): CROP_SIZE, mean and std are not referenced anywhere else in the
# code shown, and `transform` below does not normalize. The mean/std values
# match the usual ImageNet channel statistics.
CROP_SIZE = 88
mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

# --- Optimization --------------------------------------------------------------
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
NUM_EPOCHS = 150

# --- Checkpoints ---------------------------------------------------------------
# NOTE(review): these four names are not read by the training code shown.
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

# --- Augmentation: random vertical / horizontal flips, then conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Dataset

In [None]:
class MapDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution training.

    Files in ``input_dir`` (low-res) and ``target_dir`` (high-res) are paired by
    their position in each directory's *sorted* listing. Matching pairs
    therefore need names that sort identically in both folders.

    Args:
        input_dir: directory of low-resolution images.
        target_dir: directory of high-resolution images.
        transform_fn: transform applied to both images of a pair; ``None``
            (default) uses the module-level ``transform``. Its random draws are
            replayed for the target so both images of a pair get the same
            augmentation.
        input_size: ``(width, height)`` the low-res image is resized to.
        target_size: ``(width, height)`` the high-res image is resized to.
    """

    def __init__(self, input_dir, target_dir, transform_fn=None,
                 input_size=(64, 64), target_size=(256, 256)):
        self.input_dir = input_dir
        self.target_dir = target_dir
        # os.listdir order is arbitrary and can differ between the two folders;
        # sorting makes the index-based pairing deterministic.
        self.input_files = sorted(os.listdir(self.input_dir))
        self.target_files = sorted(os.listdir(self.target_dir))
        self.transform_fn = transform_fn
        self.input_size = input_size
        self.target_size = target_size

    def __len__(self):
        # Only indices present in both folders can form a pair.
        return min(len(self.input_files), len(self.target_files))

    @staticmethod
    def _load_rgb(path, size):
        """Open ``path``, convert it to 3-channel RGB and resize it to ``size``."""
        # The with-block closes the file handle; convert() returns an in-memory copy.
        with Image.open(path) as img:
            # The models take 3-channel input; grayscale/RGBA files would
            # otherwise yield 1- or 4-channel tensors.
            return img.convert("RGB").resize(size)

    def __getitem__(self, index=0):
        tf = self.transform_fn if self.transform_fn is not None else transform

        input_image = self._load_rgb(
            os.path.join(self.input_dir, self.input_files[index]), self.input_size)
        target_image = self._load_rgb(
            os.path.join(self.target_dir, self.target_files[index]), self.target_size)

        # torchvision's random transforms (e.g. the flips) sample from torch's
        # global RNG. Replaying the same RNG state for the target gives both
        # images identical flips, so the LR/HR pair stays aligned.
        rng_state = torch.get_rng_state()
        input_image = tf(input_image)
        torch.set_rng_state(rng_state)
        target_image = tf(target_image)

        return input_image, target_image


# Training

In [None]:
def train_fn(disc, gen, loader, generator_criterion, opt_disc, opt_gen, epoch):
    """Run one SRGAN training epoch over ``loader`` and save the generator.

    For every ``(low_res, high_res)`` batch, the discriminator is updated first
    with ``d_loss = 1 - D(x) + D(G(z))``. The generator is then updated with
    ``generator_criterion``. Running averages are shown on a tqdm bar, and
    ``gen.state_dict()`` is written to ``super_res_gen.pth`` at the end.

    Args:
        disc: discriminator module, stepped with ``opt_disc``.
        gen: generator module, stepped with ``opt_gen``.
        loader: iterable of ``(low_res, high_res)`` tensor batches.
        generator_criterion: callable ``(fake_out, fake_img, real_img) -> loss``.
        opt_disc: optimizer for ``disc``.
        opt_gen: optimizer for ``gen``.
        epoch: epoch index, used only in the progress-bar label.

    Returns:
        dict of epoch-average ``d_loss``, ``g_loss``, ``d_score`` and ``g_score``.
        All values are 0.0 when ``loader`` is empty.
    """
    train_bar = tqdm(loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    gen.train()
    disc.train()

    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target.to(DEVICE)
        z = data.to(DEVICE)

        # ---- Update Discriminator ----
        fake_img = gen(z)
        disc.zero_grad()
        real_out = disc(real_img).mean()
        # detach(): the D loss must not backpropagate into G. This also removes
        # the need for retain_graph, since G's graph is untouched by this backward.
        fake_out = disc(fake_img.detach()).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward()
        opt_disc.step()

        # ---- Update Generator ----
        # G's parameters have not changed since fake_img was produced, so it is
        # reused. D is re-run so G is scored by the freshly updated discriminator.
        gen.zero_grad()
        fake_out = disc(fake_img).mean()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        opt_gen.step()

        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        seen = running_results['batch_sizes']
        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / seen,
            running_results['g_loss'] / seen,
            running_results['d_score'] / seen,
            running_results['g_score'] / seen))

    gen.eval()
    out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    os.makedirs(out_path, exist_ok=True)
    # NOTE(review): saved unconditionally every epoch; SAVE_MODEL is not consulted.
    torch.save(gen.state_dict(), "super_res_gen.pth")

    seen = running_results['batch_sizes']
    return {key: (running_results[key] / seen if seen else 0.0)
            for key in ('d_loss', 'g_loss', 'd_score', 'g_score')}



def main():
    """Build models, loss, optimizers and data, then train for NUM_EPOCHS epochs.

    Returns:
        dict of per-epoch metric lists. ``d_loss``, ``g_loss``, ``d_score`` and
        ``g_score`` are appended whenever ``train_fn`` returns a dict of epoch
        averages. ``psnr`` and ``ssim`` are not computed here and stay empty.
    """
    disc = Discriminator().to(DEVICE)
    gen = Generator(UPSCALE_FACTOR).to(DEVICE)
    generator_criterion = GeneratorLoss().to(DEVICE)

    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE)

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    # NOTE(review): LOAD_MODEL / CHECKPOINT_* are not consulted; training always
    # starts from freshly initialized weights.
    train_dataset = MapDataset(INPUT_DIR, TARGET_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        # Reshuffle every epoch instead of always presenting the same fixed
        # batches in file-list order.
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    for epoch in range(NUM_EPOCHS):
        epoch_stats = train_fn(disc, gen, train_loader, generator_criterion, opt_disc, opt_gen, epoch)
        # Record per-epoch averages only when train_fn provides them.
        if isinstance(epoch_stats, dict):
            for key, value in epoch_stats.items():
                if key in results:
                    results[key].append(value)

    return results


if __name__ == "__main__":
    main()

[0/150] Loss_D: 0.9101 Loss_G: 0.0813 D(x): 0.5637 D(G(z)): 0.4423: 100%|██████████| 7/7 [00:17<00:00,  2.48s/it]
[1/150] Loss_D: 0.6657 Loss_G: 0.0814 D(x): 0.6413 D(G(z)): 0.2791: 100%|██████████| 7/7 [00:17<00:00,  2.50s/it]
[2/150] Loss_D: 0.4163 Loss_G: 0.0810 D(x): 0.7866 D(G(z)): 0.1856: 100%|██████████| 7/7 [00:17<00:00,  2.52s/it]
[3/150] Loss_D: 0.2203 Loss_G: 0.0790 D(x): 0.8864 D(G(z)): 0.0933: 100%|██████████| 7/7 [00:17<00:00,  2.55s/it]
[4/150] Loss_D: 0.1073 Loss_G: 0.0792 D(x): 0.9414 D(G(z)): 0.0432: 100%|██████████| 7/7 [00:18<00:00,  2.58s/it]
[5/150] Loss_D: 0.0531 Loss_G: 0.0785 D(x): 0.9717 D(G(z)): 0.0224: 100%|██████████| 7/7 [00:18<00:00,  2.59s/it]
[6/150] Loss_D: 0.0246 Loss_G: 0.0779 D(x): 0.9881 D(G(z)): 0.0115: 100%|██████████| 7/7 [00:18<00:00,  2.62s/it]
[7/150] Loss_D: 0.0129 Loss_G: 0.0774 D(x): 0.9940 D(G(z)): 0.0063: 100%|██████████| 7/7 [00:18<00:00,  2.63s/it]
[8/150] Loss_D: 0.0080 Loss_G: 0.0778 D(x): 0.9962 D(G(z)): 0.0038: 100%|██████████| 7/7