<a href="https://colab.research.google.com/github/quickgrid/AI-Resources/blob/master/paper-implementations/pytorch/srgan/Pytorch_SRGAN_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### References
- https://www.youtube.com/watch?v=7FO9qDOhRCc

In [None]:
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Conv2d followed by optional batch norm and optional activation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        discriminator: If True the activation is LeakyReLU(0.2); otherwise a
            per-channel PReLU is used.
        use_activation: If False the activation is skipped in ``forward``.
        use_bn: If True a BatchNorm2d follows the convolution and the
            convolution is created without a bias term.
        **kwargs: Forwarded to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            discriminator=False,
            use_activation=True,
            use_bn=True,
            **kwargs
    ):
        super(ConvBlock, self).__init__()

        self.use_activation = use_activation
        # The convolution bias is disabled whenever batch norm is enabled.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(num_features=out_channels) if use_bn else nn.Identity()

        # Section 2.1 explains leaky relu slope amount.
        # Based on Figure 4, leaky relu is applied only in the discriminator and prelu in the generator.
        # The activation module is always created (even when unused) so the
        # set of registered parameters does not depend on `use_activation`.
        self.act = (
            nn.LeakyReLU(negative_slope=0.2, inplace=True) if discriminator
            else nn.PReLU(num_parameters=out_channels)
        )

    def forward(self, X):
        """Return ``act(bn(conv(X)))``, or ``bn(conv(X))`` when the activation is disabled."""
        out = self.bn(self.conv(X))
        return self.act(out) if self.use_activation else out


class UpSampleBlock(nn.Module):
    """Upsample a feature map with a 3x3 convolution, PixelShuffle and PReLU.

    The convolution produces ``in_channels * scale_factor ** 2`` channels and
    the pixel shuffle is built with ``scale_factor``; the PReLU holds one
    parameter per ``in_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        scale_factor: Factor handed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * (scale_factor ** 2)
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.activation = nn.PReLU(num_parameters=in_channels)

    def forward(self, X):
        """Apply convolution, pixel shuffle and activation, in that order."""
        expanded = self.conv(X)
        shuffled = self.pixel_shuffle(expanded)
        return self.activation(shuffled)


class ResidualBlock(nn.Module):
    """Two channel-preserving ConvBlocks wrapped in a skip connection.

    Both blocks are built with a 3x3 kernel, stride 1, padding 1 and batch
    norm enabled; the first is created with ``use_activation=True`` and the
    second with ``use_activation=False``.

    Args:
        in_channels: Channel count used for both the input and the output.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Settings common to both ConvBlocks; only `use_activation` differs.
        shared = dict(
            in_channels=in_channels,
            out_channels=in_channels,
            discriminator=False,
            use_bn=True,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.block1 = ConvBlock(use_activation=True, **shared)
        self.block2 = ConvBlock(use_activation=False, **shared)

    def forward(self, X):
        """Return ``X`` plus the result of running it through both blocks."""
        return X + self.block2(self.block1(X))



class Generator(nn.Module):
    """SRGAN generator assembled from ConvBlock, ResidualBlock and UpSampleBlock.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        num_channels: Width of the internal feature maps.
        num_blocks: Number of ResidualBlocks in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()

        # 9x9 stem without batch norm.
        self.initial = ConvBlock(
            in_channels=in_channels,
            out_channels=num_channels,
            kernel_size=9,
            stride=1,
            padding=4,
            use_bn=False,
        )
        self.residuals = nn.Sequential(
            *(ResidualBlock(in_channels=num_channels) for _ in range(num_blocks))
        )
        # 3x3 block built with batch norm on and the activation off.
        self.conv_block = ConvBlock(
            in_channels=num_channels,
            out_channels=num_channels,
            discriminator=False,
            use_activation=False,
            use_bn=True,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Two upsample stages, each with scale factor 2.
        self.upsamples = nn.Sequential(
            *(UpSampleBlock(in_channels=num_channels, scale_factor=2) for _ in range(2))
        )
        self.final = nn.Conv2d(
            in_channels=num_channels,
            out_channels=in_channels,
            kernel_size=9,
            stride=1,
            padding=4,
        )

    def forward(self, X):
        """Map an input batch to an upsampled image squashed by ``tanh``."""
        stem = self.initial(X)
        # Long skip connection: the stem output is added back after the trunk.
        trunk = self.conv_block(self.residuals(stem))
        upsampled = self.upsamples(stem + trunk)
        return torch.tanh(self.final(upsampled))


class Discriminator(nn.Module):
    """SRGAN discriminator: a stack of ConvBlocks followed by a dense classifier.

    Args:
        in_channels: Channels of the input image.
        features: Output channel count of each successive ConvBlock. Blocks at
            odd positions use stride 2, blocks at even positions use stride 1.

    The classifier pools to 6x6, flattens, and ends in a single-unit linear
    layer; no sigmoid is applied inside this module.
    """

    def __init__(
        self,
        in_channels=3,
        features=(64, 64, 128, 128, 256, 256, 512, 512)
    ):
        super(Discriminator, self).__init__()

        blocks = []
        channels = in_channels
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels=channels,
                    out_channels=feature,
                    discriminator=True,
                    use_activation=True,
                    # The first block is built without batch norm.
                    use_bn=idx != 0,
                    kernel_size=3,
                    padding=1,
                    # Stride alternates 1, 2, 1, 2, ...
                    stride=1 + idx % 2
                )
            )
            channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(6, 6)),
            nn.Flatten(),
            # Sized from the last conv width instead of a hard-coded 512 so
            # that custom `features` sequences produce a matching classifier.
            nn.Linear(in_features=channels * 6 * 6, out_features=1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=1024, out_features=1),
        )

    def forward(self, X):
        """Return one score per input sample, shape ``(batch, 1)``."""
        X = self.blocks(X)
        X = self.classifier(X)
        return X

In [None]:
def test():
    """Smoke test: run random input through Generator and Discriminator and print the output shapes."""
    low_resolution = 30
    batch_shape = (5, 3, low_resolution, low_resolution)

    with torch.cuda.amp.autocast():
        low_res_batch = torch.randn(batch_shape)
        generator = Generator()
        generated = generator(low_res_batch)
        discriminator = Discriminator()
        scores = discriminator(generated)

        for result in (generated, scores):
            print(result.shape)

In [None]:
# Run the shape smoke test only when this code is executed as the main module.
if __name__ == "__main__":
    test()



torch.Size([5, 3, 120, 120])
torch.Size([5, 1])


In [1]:
import torch.nn as nn
from torchvision.models import vgg19


class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 feature maps of two image batches.

    Args:
        device: Device the VGG19 feature extractor is moved to.
    """

    def __init__(
        self,
        device
    ):
        super(VGGLoss, self).__init__()

        # phi_5_4 in the paper's notation: the feature map of the 4th convolution
        # (after activation) before the 5th maxpooling. Slicing to [:36] keeps the
        # first 36 layers; the 37th layer is that maxpooling and is excluded.
        # `eval()` takes no arguments, so the device move is a separate `.to()` call.
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(device)

        # Pixelwise MSE Loss
        self.loss = nn.MSELoss()

        # Vgg should not be trained, so no need to update weights.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``."""
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)

        # Equation 5, VGG Loss implementation.
        return self.loss(vgg_input_features, vgg_target_features)


In [None]:
import os
import numpy
from torch.utils.data import DataLoader, Dataset
from PIL import Image


class MyImageDataset(Dataset):
    def __init__(
        self,
        root_dir
    ):
        super(MyImageDataset, self).__init__()
        self.data = []
        self.root_dir = root_dir
        self.class_names = os.listdir(rootdir)

        for idx, names in enumerate(self.class_names):
            file_names = os.listdir(os.path.join(root_dir, names))
            self.data.append(list(zip(file_names, [idx] * len(file_names))))

    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_file, label = self.data[index]
        image_class_dir_path = os.path.join(self.root_dir, self.class_names[label])

        image = np.array(Image.open(os.path.join(image_class_dir_path, img_file)))
        image = 