# Chapter 9: CycleGAN

# Imports

In [171]:
import torch
import torch.nn as nn

# Generator

In [172]:
class Generator(nn.Module):
    """Encoder-decoder image generator with skip connections (U-Net style).

    Maps an ``(N, 3, H, W)`` batch to an ``(N, 3, H, W)`` batch whose values
    lie in ``[-1, 1]`` (final ``Tanh``). The encoder halves the spatial size
    four times (every ``conv_layer`` uses stride 2). Each decoder stage
    concatenates its upsampled features with the matching encoder output, so
    ``H`` and ``W`` must be divisible by 16 for the skip tensors to line up.
    """

    class TransSkip(nn.Module):
        """Decoder stage: upsample x2, conv, then concatenate a skip tensor.

        Returns ``out_channel + skip.shape[1]`` channels at twice the input
        spatial size; ``skip`` must already be at that doubled size.
        """

        def __init__(self, in_channel, out_channel, kernel_size):
            super().__init__()
            self.trans1 = nn.Sequential(
                # Same operation as the deprecated nn.UpsamplingBilinear2d.
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                # kernel_size // 2 keeps the spatial size for odd kernels.
                nn.Conv2d(in_channel, out_channel, kernel_size, stride=1,
                          padding=kernel_size // 2),
                nn.LeakyReLU(negative_slope=0.2),
                # num_features must equal the conv's output channel count
                # (previously out_channel / 2: a float and the wrong size).
                nn.InstanceNorm2d(out_channel))

        def forward(self, x, skip):
            trans = self.trans1(x)
            return torch.cat([trans, skip], dim=1)

    def __init__(self):
        super().__init__()
        # Encoder: (N, 3, H, W) -> (N, 256, H/16, W/16).
        self.conv1 = self.conv_layer(3, 32, 3)
        self.conv2 = self.conv_layer(32, 64, 3)
        self.conv3 = self.conv_layer(64, 128, 3)
        self.conv4 = self.conv_layer(128, 256, 3)

        # Decoder: each in_channel is the previous stage's conv output plus
        # the skip channels it concatenated (256, 128+128, 64+64).
        self.trans1 = self.TransSkip(256, 128, 3)
        self.trans2 = self.TransSkip(256, 64, 3)
        self.trans3 = self.TransSkip(128, 32, 3)
        # TODO: check whether also concatenating the raw input image as a final
        # skip improves results (would need TransSkip(64, 64, 3) as trans4 and
        # a conv5 taking 64 + 3 = 67 input channels).
        self.trans4 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv5 = nn.Sequential(
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh())

    def forward(self, x):
        """Translate an ``(N, 3, H, W)`` batch; H and W divisible by 16."""
        skip_1 = x = self.conv1(x)   # (N, 32, H/2, W/2)
        skip_2 = x = self.conv2(x)   # (N, 64, H/4, W/4)
        skip_3 = x = self.conv3(x)   # (N, 128, H/8, W/8)
        x = self.conv4(x)            # (N, 256, H/16, W/16)
        x = self.trans1(x, skip_3)   # (N, 256, H/8, W/8)
        x = self.trans2(x, skip_2)   # (N, 128, H/4, W/4)
        x = self.trans3(x, skip_1)   # (N, 64, H/2, W/2)
        x = self.trans4(x)           # (N, 64, H, W)
        return self.conv5(x)         # (N, 3, H, W), values in [-1, 1]

    def conv_layer(self, in_channels, out_channels, kernel_size):
        """Return a stride-2 Conv2d -> LeakyReLU -> InstanceNorm2d encoder stage."""
        return nn.Sequential(
            # Padding follows kernel_size instead of a fixed 1 (same for k=3).
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=2,
                      padding=kernel_size // 2),
            # NOTE(review): default slope 0.01 differs from the 0.2 used in
            # TransSkip -- confirm this is intentional.
            nn.LeakyReLU(),
            nn.InstanceNorm2d(out_channels))

In [173]:
# Instantiate the generator with fresh, untrained weights.
generator = Generator()

In [174]:
# Smoke test: push a batch of 8 random 3x128x128 tensors through the generator.
out = generator(torch.randn(8, 3, 128, 128))

In [175]:
# Should match the input shape: (8, 3, 128, 128).
out.shape

torch.Size([8, 3, 128, 128])

In [176]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that emits a single-channel score map.

    Four stride-2 ``Conv2d`` + ``LeakyReLU(0.2)`` stages widen the channels
    3 -> 64 -> 128 -> 256 -> 512. A final 4x4 conv (no activation) then
    reduces them to one channel. A 128x128 input yields an ``(N, 1, 9, 9)``
    map.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256, 512)
        # Registers conv1..conv4 in order, so state_dict keys and parameter
        # creation order are unchanged.
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{idx}", self.conv_layer(c_in, c_out, 3))
        self.conv5 = nn.Conv2d(512, 1, 4, stride=1, padding=2)

    def forward(self, x):
        stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        for stage in stages:
            x = stage(x)
        return x

    def conv_layer(self, in_channels, out_channels, kernel_size):
        """Build one downsampling stage: stride-2 conv then LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=2, padding=int(kernel_size / 2))
        return nn.Sequential(conv, nn.LeakyReLU(negative_slope=0.2))

In [177]:
# Instantiate the discriminator with fresh, untrained weights.
discriminator = Discriminator()

In [178]:
# Smoke test: score a batch of 8 random 3x128x128 tensors.
out = discriminator(torch.randn(8, 3, 128, 128))

In [179]:
# One single-channel score map per image; 128x128 inputs give (8, 1, 9, 9).
out.shape

torch.Size([8, 1, 9, 9])