# Progressive GAN

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Generator
* TODO: Change UpsamplingBilinear to interpolate()
* TODO: Move to inner class

In [127]:
class GeneratorBlock(nn.Module):
    """Generator stage: 2x bilinear upsampling followed by two 3x3 convolutions.

    Both convolutions use ``padding=1`` (spatial size is preserved) and each
    is followed by a leaky ReLU with PyTorch's default negative slope.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

    def forward(self, x):
        """Return the convolved features of the upsampled input."""
        _, features = self.forward_with_img(x)
        return features

    def forward_with_img(self, x):
        """Return ``(upsampled_input, features)``.

        ``upsampled_input`` is ``x`` at twice the spatial resolution with its
        channels untouched; ``features`` is that tensor after both
        convolution + leaky-ReLU steps.
        """
        # F.interpolate replaces the deprecated nn.UpsamplingBilinear2d;
        # mode="bilinear" with align_corners=True reproduces it exactly.
        img = F.interpolate(
            x, scale_factor=2.0, mode="bilinear", align_corners=True
        )
        features = F.leaky_relu(self.conv1(img))
        features = F.leaky_relu(self.conv2(features))
        return img, features

In [171]:
class Generator(nn.Module):
    """Progressively grown generator.

    The network starts as a single stage that turns an ``(N, channels, 1, 1)``
    input into a 4x4 feature map (the first convolution uses a 4x4 kernel with
    ``padding=3``), followed by a 1x1 to-RGB convolution. ``append_layer``
    adds one resolution-doubling ``GeneratorBlock`` at a time.

    Args:
        channels: Channel count of the input and of the initial stage.
    """

    def __init__(self, channels=512):
        super().__init__()
        self.channels = channels
        self.trained_blocks = nn.ModuleList()
        self.new_block = self.create_initial_layer(channels)
        self.rgb = self.to_rgb(channels)
        # To-RGB head of the previous resolution, needed for fade-in.
        # Stays None until append_layer() has been called at least once.
        self.prev_rgb = None

    def to_rgb(self, in_channels):
        """Build a 1x1 convolution mapping ``in_channels`` features to 3 channels."""
        return nn.Conv2d(in_channels, 3, 1, padding=0)

    def create_initial_layer(self, channels):
        """Build the first stage: 4x4 conv (padding 3) and 3x3 conv, each with LeakyReLU."""
        module = nn.Sequential(
            nn.Conv2d(channels, channels, 4, padding=3),
            nn.LeakyReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.LeakyReLU())

        return module

    def append_layer(self, channels):
        """Grow the generator by one resolution-doubling stage.

        The current newest stage joins ``trained_blocks`` and its to-RGB head
        is kept as ``prev_rgb`` so ``forward`` can fade the new stage in.

        Args:
            channels: Output channel count of the new stage.
        """
        self.trained_blocks.append(self.new_block)
        # The old head matches the channel count of the features entering the
        # new block; the freshly created head below does not.
        self.prev_rgb = self.rgb
        self.new_block = GeneratorBlock(self.channels, channels)
        self.rgb = self.to_rgb(channels)
        self.channels = channels

    def forward(self, x, alpha=1.0):
        """Generate an image batch.

        Args:
            x: Input to the first stage.
            alpha: Fade-in weight of the newest stage. For ``alpha < 1.0`` the
                result is ``(1 - alpha) * upsampled_previous_image +
                alpha * new_image``. Ignored while the generator has no
                previous stage to fade from.

        Returns:
            A 3-channel image batch at the current resolution.
        """
        for block in self.trained_blocks:
            x = block(x)

        new_img = self.rgb(self.new_block(x))

        if alpha < 1.0 and self.prev_rgb is not None:
            # Skip path: render the previous stage's image and upsample it to
            # the new resolution. Converting to RGB before upsampling gives
            # the same result as the reverse order for a 1x1 conv and is
            # cheaper.
            prev_img = F.interpolate(
                self.prev_rgb(x),
                scale_factor=2.0,
                mode="bilinear",
                align_corners=True,
            )
            return prev_img * (1 - alpha) + new_img * alpha

        return new_img

# Discriminator

In [219]:
class DiscriminatorBlock(nn.Module):
    """Discriminator stage: two 3x3 convolutions followed by 2x bilinear downsampling.

    Both convolutions use ``padding=1`` (spatial size is preserved); the
    result is then resized to half the spatial resolution.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

    def forward(self, x):
        """Return the convolved input at half the spatial resolution."""
        # NOTE(review): there is no activation between or after the two
        # convolutions, so this block is an affine map. Confirm whether a
        # leaky ReLU was intended here.
        features = self.conv2(self.conv1(x))
        # F.interpolate replaces the deprecated nn.UpsamplingBilinear2d;
        # mode="bilinear" with align_corners=True reproduces it exactly.
        return F.interpolate(
            features, scale_factor=0.5, mode="bilinear", align_corners=True
        )

In [340]:
class Discriminator(nn.Module):
    """Progressively grown discriminator.

    The network starts as a single final stage that maps a 4x4 feature map to
    one score per sample, fed by a 1x1 from-RGB convolution. ``prepend_layer``
    adds one resolution-halving ``DiscriminatorBlock`` at a time on the input
    side.

    Args:
        channels: Channel count of the final stage.
    """

    def __init__(self, channels=512):
        super().__init__()
        # Input channel count of the stage currently closest to the image.
        self.channels = channels
        self.trained_layers = nn.ModuleList()
        self.new_layer = self.create_initial_layer(channels)
        self.rgb = self.from_rgb(channels)
        # From-RGB head of the previous resolution, needed for fade-in.
        # Stays None until prepend_layer() has been called at least once.
        self.prev_rgb = None

    def create_initial_layer(self, channels):
        """Build the final stage: 3x3 conv, 4x4 conv (no padding), flatten, linear.

        The 4x4 convolution reduces a 4x4 map to 1x1, so the flattened vector
        has ``channels`` entries, matching the linear layer's input size.
        """
        # NOTE(review): no activations between these layers — confirm intended.
        module = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.Conv2d(channels, channels, 4, padding=0),
            nn.Flatten(),
            nn.Linear(channels, 1))

        return module

    def prepend_layer(self, channels):
        """Grow the discriminator by one resolution-halving input stage.

        The current input-side stage moves to the front of ``trained_layers``
        and its from-RGB head is kept as ``prev_rgb`` so ``forward`` can fade
        the new stage in.

        Args:
            channels: Input channel count of the new stage, i.e. the output
                width of its from-RGB convolution.
        """
        self.trained_layers.insert(0, self.new_layer)
        # The old head produces the channel count the previous stage expects.
        self.prev_rgb = self.rgb
        self.new_layer = DiscriminatorBlock(channels, self.channels)
        self.rgb = self.from_rgb(channels)
        # Track the new input width so the next prepended block outputs the
        # width this block expects.
        self.channels = channels

    def from_rgb(self, channels):
        """Build a 1x1 convolution mapping a 3-channel image to ``channels`` features."""
        return nn.Conv2d(3, channels, 1, padding=0)

    def forward(self, x, alpha=1.0):
        """Score an image batch.

        Args:
            x: 3-channel image batch at the current resolution.
            alpha: Fade-in weight of the newest stage. For ``alpha < 1.0`` the
                newest stage's features are blended with the previous from-RGB
                head applied to a 2x-downsampled copy of ``x``. Ignored while
                the discriminator has no previous stage to fade from.

        Returns:
            Output of the final stage (one value per sample).
        """
        features = self.new_layer(self.rgb(x))

        if alpha < 1.0 and self.prev_rgb is not None:
            # Skip path: shrink the image to the previous resolution and feed
            # it through the head that was trained at that resolution.
            small = F.interpolate(
                x, scale_factor=0.5, mode="bilinear", align_corners=True
            )
            skip = self.prev_rgb(small)
            features = skip * (1 - alpha) + features * alpha

        for layer in self.trained_layers:
            features = layer(features)

        return features

In [341]:
# Instantiate both networks at their initial stage with the default 512 channels.
g = Generator()
d = Discriminator()

In [342]:
# Random input batch: 2 samples, 512 channels, 1x1 spatial size.
z = torch.randn(2, 512, 1, 1)
gen_img = g(z)

In [343]:
# Sanity check: the initial generator stage should emit 3-channel 4x4 images.
print(gen_img.shape)

torch.Size([2, 3, 4, 4])


In [344]:
# Grow the generator by one stage with 128 output channels.
g.append_layer(128)

In [345]:
# Score the images generated earlier (gen_img was produced before the generator was grown).
out = d(gen_img)

torch.Size([2, 1])


In [346]:
# Sanity check: the discriminator should return one value per sample.
print(out.shape)

torch.Size([2, 1])


In [347]:
# Grow the discriminator by one input-side stage taking 512 channels.
d.prepend_layer(512)

In [350]:
# Forward pass with alpha=0.5 (fade-in branch) on a random batch of 8x8 RGB images.
out = d(torch.randn(2, 3, 8, 8), 0.5)

torch.Size([2, 512, 4, 4])
Loop: torch.Size([2, 1])


In [351]:
# Sanity check: the grown discriminator should still return one value per sample.
print(out.shape)

torch.Size([2, 1])
