# In [258]:
import torch
import torchsummary
from torch import nn
from torchvision.models import vgg19
import numpy as np

# In [259]:
class ResBlock(nn.Module):
    """Residual block: (conv -> batch norm -> PReLU -> conv -> batch norm) + identity.

    Both convolutions map ``channels`` -> ``channels`` and share the same
    ``kernel_size``, ``padding`` and ``stride``. The input is added to the
    result unchanged, so the configuration must preserve spatial size
    (e.g. kernel_size=3, padding=1, stride=1); otherwise the addition fails.
    """

    def __init__(self, kernel_size, channels, padding, stride):
        super().__init__()
        shared_conv_args = dict(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        # Attribute names (and creation order) are kept so state_dict keys and
        # seeded initialisation match existing checkpoints.
        self.conv1 = nn.Conv2d(**shared_conv_args)
        self.batch_norm1 = nn.BatchNorm2d(num_features=channels)
        self.prelu1 = nn.PReLU()
        self.conv2 = nn.Conv2d(**shared_conv_args)
        self.batch_norm2 = nn.BatchNorm2d(num_features=channels)

    def forward(self, inp):
        """Return ``inp`` plus the learned residual computed from it."""
        residual = self.prelu1(self.batch_norm1(self.conv1(inp)))
        residual = self.batch_norm2(self.conv2(residual))
        return inp + residual

# In [260]:
class SubPixelConv(nn.Module):
    """Sub-pixel convolution upsampling stage: conv -> PixelShuffle -> PReLU.

    The convolution produces ``out_channels`` feature maps. ``nn.PixelShuffle``
    rearranges them into ``out_channels // upscale_factor**2`` channels at
    ``upscale_factor`` times the spatial resolution.

    Args:
        kernel_size, in_channels, out_channels, padding, stride: forwarded to
            the ``nn.Conv2d``.
        upscale_factor: pixel-shuffle factor. It defaults to 2, the value that
            was previously hard-coded.

    Raises:
        ValueError: if ``upscale_factor < 1`` or if ``out_channels`` is not
            divisible by ``upscale_factor ** 2``. Without this check the
            mismatch would only fail inside ``forward``.
    """

    def __init__(self, kernel_size, in_channels, out_channels, padding, stride, upscale_factor=2):
        super(SubPixelConv, self).__init__()
        if upscale_factor < 1:
            raise ValueError(f"upscale_factor must be >= 1, got {upscale_factor}")
        if out_channels % (upscale_factor ** 2) != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by "
                f"upscale_factor**2 ({upscale_factor ** 2})"
            )
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, padding=padding, stride=stride)
        self.pixel_shuffler1 = nn.PixelShuffle(upscale_factor)
        self.prelu1 = nn.PReLU()

    def forward(self, inp):
        """Convolve, pixel-shuffle to a larger spatial size, then apply PReLU."""
        x = self.conv1(inp)
        x = self.pixel_shuffler1(x)
        output = self.prelu1(x)
        return output

# In [261]:
class Generator(nn.Module):
    """SRGAN-style super-resolution generator.

    Pipeline:
    1. A 9x9 conv (3 -> 64 channels) followed by PReLU.
    2. A chain of ``num_res_blocks`` ResBlocks (64 channels).
    3. A 3x3 conv followed by batch norm.
    4. A long skip connection that adds the step-1 output back in.
    5. Two SubPixelConv stages (64 -> 256 channels, pixel-shuffled back to
       64). With SubPixelConv's pixel-shuffle factor of 2, each stage doubles
       height and width.
    6. A final 9x9 conv (64 -> 3 channels).

    The final conv output is returned as-is. There is no output activation,
    so the output range is not constrained.

    Args:
        num_res_blocks: number of residual blocks. It defaults to 16, the
            previously hard-coded count. Blocks are registered individually as
            ``resblock1`` ... ``resblockN``, not as an ``nn.Sequential``. This
            keeps parameter names, state_dict keys and creation order
            unchanged.

    Raises:
        ValueError: if ``num_res_blocks`` is negative.
    """

    def __init__(self, num_res_blocks=16):
        super(Generator, self).__init__()
        if num_res_blocks < 0:
            raise ValueError(f"num_res_blocks must be >= 0, got {num_res_blocks}")
        self.num_res_blocks = num_res_blocks
        self.conv1 = nn.Conv2d(in_channels=3, kernel_size=9, out_channels=64, padding=4, stride=1)
        self.prelu1 = nn.PReLU()
        for index in range(1, num_res_blocks + 1):
            setattr(self, f"resblock{index}", ResBlock(channels=64, kernel_size=3, padding=1, stride=1))
        self.conv2 = nn.Conv2d(in_channels=64, kernel_size=3, out_channels=64, padding=1, stride=1)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64)
        self.subpixelconv1 = SubPixelConv(in_channels=64, kernel_size=3, out_channels=256, padding=1, stride=1)
        self.subpixelconv2 = SubPixelConv(in_channels=64, kernel_size=3, out_channels=256, padding=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=64, kernel_size=9, out_channels=3, padding=4, stride=1)

    def forward(self, inp):
        """Map a 3-channel image batch to a 3-channel upsampled image batch."""
        x = self.conv1(inp)
        x = self.prelu1(x)
        y = x
        for index in range(1, self.num_res_blocks + 1):
            y = getattr(self, f"resblock{index}")(y)
        y = self.conv2(y)
        y = self.batch_norm1(y)
        # Long skip connection around the whole residual trunk.
        y = self.subpixelconv1(x + y)
        y = self.subpixelconv2(y)
        output = self.conv3(y)
        return output

# In [262]:
class PerceptualLoss(nn.Module):
    """VGG19 feature-matching loss with frozen pretrained weights.

    Both images are walked layer by layer through the first child of a
    pretrained torchvision ``vgg19``. Just before every ``nn.MaxPool2d``, the
    MSE between the two current feature maps is added to the loss. Each map
    is divided by 12.75 first, which scales each term by 1/12.75**2.

    The ground-truth features are detached, so gradients flow only through
    ``generated_image``. If the VGG contains no MaxPool2d, the int 0 is
    returned.

    NOTE(review): `pretrained=True` is deprecated in newer torchvision
    releases in favour of `weights=`; it is left as-is here.
    NOTE(review): the images are fed to VGG without any normalisation. Confirm
    that callers pre-normalise them the way the pretrained weights expect.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.vgg = next(vgg19(pretrained=True).children())
        self.vgg.requires_grad_(False)

    def forward(self, generated_image, ground_truth):
        """Return the summed, rescaled per-stage MSE between VGG features."""
        scale = 12.75
        target_feat = ground_truth
        pred_feat = generated_image
        loss = 0
        for layer in self.vgg._modules.values():
            if isinstance(layer, nn.MaxPool2d):
                loss = loss + self.mse_loss(pred_feat / scale, target_feat.detach() / scale)
            target_feat = layer(target_feat)
            pred_feat = layer(pred_feat)
        return loss

# In [263]:
class ConvBlock(nn.Module):
    """Discriminator building block: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    Args:
        kernel_size, in_channels, out_channels, padding, stride: forwarded to
            the ``nn.Conv2d``. Batch norm runs over ``out_channels``.
    """

    def __init__(self, kernel_size, in_channels, out_channels, padding, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.batch_norm1 = nn.BatchNorm2d(num_features=out_channels)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, inp):
        """Apply conv, batch norm and leaky ReLU in sequence."""
        return self.leaky_relu1(self.batch_norm1(self.conv1(inp)))

# In [284]:
class Discriminator(nn.Module):
    """SRGAN-style discriminator that maps a 3-channel image batch to (batch, 1) probabilities.

    Pipeline:
    1. A 3x3 conv (3 -> 64 channels) followed by LeakyReLU(0.2).
    2. Seven ConvBlocks (64 -> 512 channels). Four of them use stride 2.
    3. Flatten, then Linear -> 1024, LeakyReLU(0.2), Linear -> 1, Sigmoid.

    Args:
        image_size: input spatial size, either an int (square) or a
            ``(height, width)`` pair. It determines ``dense1.in_features``,
            which was previously hard-coded to 18432 (512 * 6 * 6). That value
            is what the default of 96 yields, so ``Discriminator()`` is
            unchanged. Inputs of any other size still fail at ``dense1``.

    Raises:
        ValueError: if either side of ``image_size`` is < 1.
    """

    def __init__(self, image_size=96):
        super(Discriminator, self).__init__()
        if isinstance(image_size, int):
            height, width = image_size, image_size
        else:
            height, width = image_size
        if height < 1 or width < 1:
            raise ValueError(f"image_size must be positive, got {image_size!r}")
        # conv1 and the stride-1 blocks (kernel 3, padding 1) keep the spatial
        # size. Each of the four stride-2 blocks maps a side n to
        # (n + 2*1 - 3) // 2 + 1 == (n - 1) // 2 + 1.
        for _ in range(4):
            height = (height - 1) // 2 + 1
            width = (width - 1) // 2 + 1

        self.conv1 = nn.Conv2d(in_channels=3, kernel_size=3, out_channels=64, padding=1, stride=1)
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)
        self.convblock1 = ConvBlock(kernel_size=3, in_channels=64, out_channels=64, padding=1, stride=2)
        self.convblock2 = ConvBlock(kernel_size=3, in_channels=64, out_channels=128, padding=1, stride=1)
        self.convblock3 = ConvBlock(kernel_size=3, in_channels=128, out_channels=128, padding=1, stride=2)
        self.convblock4 = ConvBlock(kernel_size=3, in_channels=128, out_channels=256, padding=1, stride=1)
        self.convblock5 = ConvBlock(kernel_size=3, in_channels=256, out_channels=256, padding=1, stride=2)
        self.convblock6 = ConvBlock(kernel_size=3, in_channels=256, out_channels=512, padding=1, stride=1)
        self.convblock7 = ConvBlock(kernel_size=3, in_channels=512, out_channels=512, padding=1, stride=2)
        self.dense1 = nn.Linear(in_features=512 * height * width, out_features=1024)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2)
        self.dense2 = nn.Linear(in_features=1024, out_features=1)
        self.sigmoid1 = nn.Sigmoid()

    def forward(self, inp):
        """Return per-sample probabilities of shape (batch, 1)."""
        x = self.conv1(inp)
        x = self.leaky_relu1(x)
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        x = self.convblock6(x)
        x = self.convblock7(x)
        # Flatten everything but the batch dimension for the dense head.
        x = x.view(x.size(0), -1)
        x = self.dense1(x)
        x = self.leaky_relu2(x)
        x = self.dense2(x)
        output = self.sigmoid1(x)
        return output

# In [None]:
class AdversarialLoss(nn.Module):
    """Generator-side adversarial loss: -mean(log D(G(x))).

    This replaces a previous body copied from PerceptualLoss. That body
    referenced an undefined ``ground_truth`` and the nonexistent ``self.vgg``
    and ``self.mse_loss``, so every call raised.

    Args:
        discriminator: optional module mapping images to probabilities in
            [0, 1] (e.g. a network ending in ``nn.Sigmoid``).
            - If given, ``forward`` runs it on ``generated_image``.
            - If None (the default), ``forward`` treats its argument as the
              discriminator's probabilities directly.
            When a discriminator is given, it is registered as a submodule.
            Unless its parameters are frozen, gradients from this loss also
            accumulate in them, so zero them before the discriminator's own
            update.
    """

    def __init__(self, discriminator=None):
        super(AdversarialLoss, self).__init__()
        self.discriminator = discriminator
        # BCE against an all-ones target equals -mean(log p). BCELoss also
        # clamps log(p) at -100, so p == 0 gives a large finite loss, not inf.
        self.bce_loss = nn.BCELoss()

    def forward(self, generated_image):
        """Return the scalar loss pushing D's output on generated images toward 1."""
        if self.discriminator is None:
            probabilities = generated_image
        else:
            probabilities = self.discriminator(generated_image)
        return self.bce_loss(probabilities, torch.ones_like(probabilities))