# Progressive GAN

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Generator
* TODO: Replace `nn.UpsamplingBilinear2d` with `F.interpolate()`
* TODO: Move to inner class

In [11]:
class GeneratorBlock(nn.Module):
    """Upsample 2x, then blend a two-conv refinement path with the upsample.

    The result is ``(1 - alpha) * skip + alpha * convs(upsampled)``, where
    ``skip`` is the upsampled input (passed through a 1x1 projection only
    when the channel counts differ). ``alpha=0`` therefore returns just the
    skip path and ``alpha=1`` just the convolutional path.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both 3x3 convolutions and by the
            block as a whole.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # The blend in forward() adds the skip and conv paths, so both must
        # have out_channels. A 1x1 projection is created only when the counts
        # differ; equal-channel blocks keep exactly their original parameters
        # (and the same RNG consumption during init).
        if in_channels == out_channels:
            self.skip_proj = nn.Identity()
        else:
            self.skip_proj = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x, alpha=0):
        """Return the alpha-weighted blend of the skip and conv paths.

        Args:
            x: input feature map with ``in_channels`` channels.
            alpha: weight of the convolutional path; ``1 - alpha`` weights
                the skip path.
        """
        # Identical resampling to nn.UpsamplingBilinear2d(scale_factor=2),
        # which PyTorch documents as interpolate(..., align_corners=True).
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        skip = self.skip_proj(x) * (1 - alpha)
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        return skip + x * alpha

In [19]:
class Generator(nn.Module):
    """Decode a 512-channel latent map into a 3-channel image.

    With a 1x1 latent, ``conv1`` (kernel 4, padding 3) yields a 4x4 map.
    Each of the three GeneratorBlocks doubles the spatial size (4 -> 32),
    and the final 3x3 ``image`` conv maps 512 channels down to 3.
    """

    def __init__(self):
        super().__init__()
        width = 512
        # Creation order is kept so weight initialisation under a fixed seed
        # is unchanged.
        self.conv1 = nn.Conv2d(width, width, 4, padding=3)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1)

        self.block1 = GeneratorBlock(width, width)
        self.block2 = GeneratorBlock(width, width)
        self.block3 = GeneratorBlock(width, width)

        self.image = nn.Conv2d(width, 3, 3, padding=1)

    def to_rgb(self, in_channels):
        """Return a brand-new 1x1 conv mapping ``in_channels`` to 3 channels.

        The module is not registered on ``self``, and ``forward`` does not
        call this helper; callers must attach the result themselves.
        """
        return nn.Conv2d(in_channels, 3, kernel_size=1, padding=0)

    def forward(self, x):
        """Run the stem convs, the three upsampling blocks, and the RGB conv."""
        for stem_conv in (self.conv1, self.conv2):
            x = F.leaky_relu(stem_conv(x))

        # NOTE(review): blocks are called without ``alpha``; with
        # GeneratorBlock's default of 0, their conv layers do not contribute
        # to the output. Confirm this is the intended training phase.
        for block in (self.block1, self.block2, self.block3):
            x = block(x)

        return self.image(x)

# Discriminator

In [32]:
class DiscriminatorBlock(nn.Module):
    """Two 3x3 conv + LeakyReLU layers followed by a 2x spatial downsample.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

    def forward(self, x, alpha=0):
        """Return the block's features at half the input resolution.

        ``alpha`` is accepted but currently unused.
        """
        # Without a nonlinearity between them, the two convs (and the linear
        # resample) would collapse into a single linear map. This mirrors
        # GeneratorBlock, which applies leaky_relu after each conv.
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        # Identical resampling to nn.UpsamplingBilinear2d(scale_factor=0.5).
        # NOTE(review): ProGAN downsamples with 2x2 average pooling; consider
        # F.avg_pool2d(x, 2) if matching the paper matters.
        return F.interpolate(
            x, scale_factor=0.5, mode="bilinear", align_corners=True
        )

In [33]:
class Discriminator(nn.Module):
    """Score a batch of 3-channel images with one unnormalised value each.

    The layer sizes are chosen for 32x32 inputs. The three blocks halve the
    spatial size (32 -> 4), the 4x4 ``out`` conv reduces that to 1x1 with
    128 channels, and ``linear`` maps those 128 features to one score. No
    sigmoid is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 1, padding=0)

        self.block1 = DiscriminatorBlock(16, 32)
        self.block2 = DiscriminatorBlock(32, 64)
        self.block3 = DiscriminatorBlock(64, 128)

        self.out = nn.Conv2d(128, 128, 4, padding=0)
        self.linear = nn.Linear(128, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for the images in ``x``."""
        # Activations after the from-RGB and final convs (matching the
        # Generator's leaky_relu after its convs); otherwise these layers
        # stack into a purely linear map.
        x = F.leaky_relu(self.conv1(x))
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = F.leaky_relu(self.out(x))
        # flatten(1) preserves the batch dimension. An input whose features
        # don't reduce to 1x1 now raises a shape error in ``linear`` instead
        # of silently being folded into extra "batch" rows.
        return self.linear(x.flatten(1))

In [34]:
# Build both networks (generator first); each starts from freshly
# initialised weights.
g, d = Generator(), Discriminator()

In [35]:
# Two random latent codes, each a 512-channel 1x1 map, decoded to images.
z = torch.randn((2, 512, 1, 1))
gen_img = g(z)

In [36]:
# Shape of the generated batch: (batch, channels, height, width).
print(gen_img.size())

torch.Size([2, 3, 32, 32])


In [37]:
# Score the generated images with the discriminator.
out = d(gen_img)

In [38]:
# Discriminator output shape for the generated batch.
print(out.size())

torch.Size([2, 1])
