# Progressive GAN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize, Lambda
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Data

In [None]:
def transform_in_size(size):
    """Build the preprocessing pipeline for ``size`` x ``size`` training images.

    Images are resized, converted to float tensors by ``ToTensor`` and then
    rescaled from ToTensor's [0, 1] range to [-1, 1].
    """
    steps = [
        Resize((size, size)),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),  # [0, 1] -> [-1, 1]
    ]
    return Compose(steps)

def get_dataloader(size, batch_size=32, root='~/pytorch'):
    """Return a shuffled CIFAR-10 training loader with images resized to ``size``.

    Args:
        size: side length images are resized to (see ``transform_in_size``).
        batch_size: samples per batch; defaults to the 32 used throughout
            this notebook.
        root: dataset directory; the dataset is downloaded there if missing.

    ``drop_last=True`` makes every batch contain exactly ``batch_size`` samples.
    """
    dataset = CIFAR10(root, train=True, download=True, transform=transform_in_size(size))
    return DataLoader(dataset, batch_size=batch_size, drop_last=True, shuffle=True)

# CUDA

In [None]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Learning on: {device}")

# Generator
* TODO: Replace the deprecated `nn.UpsamplingBilinear2d` with `F.interpolate()`
* TODO: Move to inner class

In [None]:
class GeneratorBlock(nn.Module):
    """Generator stage that doubles spatial resolution.

    The input feature map is bilinearly upsampled by 2x and then passed through
    two stages of 3x3 conv (padding 1) -> LeakyReLU -> BatchNorm, mapping
    ``in_channels`` to ``out_channels`` while keeping the upsampled size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return only the processed feature map (``out_channels`` channels)."""
        _, x = self.forward_with_img(x)

        return x

    def forward_with_img(self, x):
        """Return ``(upsampled_input, processed_features)``.

        ``upsampled_input`` still has ``in_channels`` channels; the generator
        uses it as the skip path while fading this block in.
        """
        # F.interpolate replaces the deprecated nn.UpsamplingBilinear2d;
        # align_corners=True reproduces that module's output exactly.
        img = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = F.leaky_relu(self.conv1(img))
        x = self.bn1(x)
        x = F.leaky_relu(self.conv2(x))
        x = self.bn2(x)

        return img, x

In [None]:
class Generator(nn.Module):
    """Progressively grown generator.

    Starts from a single block that turns a ``(N, channels, 1, 1)`` latent into
    a 4x4 feature map. Each ``append_layer`` call moves the current block into
    ``trained_blocks`` and adds a ``GeneratorBlock`` that doubles resolution.

    Attributes:
        channels: channel width of the features produced by ``new_block``.
        rgb: 1x1 conv + Tanh head applied to ``new_block``'s output.
        rgb_skip: the previous ``rgb`` head, applied to the upsampled skip path
            while the newest block is faded in; ``None`` until the first
            ``append_layer`` call.
    """

    def __init__(self, channels=512):
        super().__init__()
        self.channels = channels
        self.trained_blocks = nn.ModuleList()
        self.new_block = self.create_initial_layer(channels)
        self.rgb = self.to_rgb(channels)
        self.rgb_skip = None

    def to_rgb(self, in_channels):
        """Map ``in_channels`` feature maps to a 3-channel image in [-1, 1]."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 3, 1, padding=0),
            nn.Tanh())

    def create_initial_layer(self, channels):
        """Build the first block, which expands a 1x1 latent to a 4x4 map."""
        module = nn.Sequential(
            # Kernel 4 with padding 3 turns a 1x1 input into a 4x4 output.
            nn.Conv2d(channels, channels, 4, padding=3),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU())

        return module

    def append_layer(self, channels):
        """Grow the generator by one 2x-upsampling block with ``channels`` outputs."""
        self.trained_blocks.append(self.new_block)
        # The upsampled skip path keeps the previous width, so it needs the
        # previous RGB head rather than the new one (which expects ``channels``).
        self.rgb_skip = self.rgb
        self.new_block = GeneratorBlock(self.channels, channels)
        self.rgb = self.to_rgb(channels)
        self.channels = channels

    def forward(self, x, alpha=1.0):
        """Generate images from latent ``x``.

        Args:
            x: latent batch fed to the first block.
            alpha: fade-in weight of the newest block. Below 1.0 the output is
                ``(1 - alpha) * rgb_skip(upsampled) + alpha * rgb(new_block)``.

        Raises:
            RuntimeError: if ``alpha < 1.0`` before any layer was appended.
        """
        for block in self.trained_blocks:
            x = block(x)

        if alpha >= 1.0:
            return self.rgb(self.new_block(x))

        if self.rgb_skip is None:
            raise RuntimeError("alpha < 1.0 requires at least one appended layer")

        img, x = self.new_block.forward_with_img(x)
        # Blend the old (skip) image with the new block's image.
        return self.rgb_skip(img) * (1 - alpha) + self.rgb(x) * alpha

# Discriminator

In [None]:
class DiscriminatorBlock(nn.Module):
    """Discriminator stage that halves spatial resolution.

    Two stages of 3x3 conv (padding 1) -> LeakyReLU -> BatchNorm map
    ``in_channels`` to ``out_channels``; a 2x2 average pool then halves height
    and width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = norm(F.leaky_relu(conv(x)))
        return self.downsample(x)

In [None]:
class Discriminator(nn.Module):
    """Progressively grown discriminator.

    ``new_layer`` is the input-side block. ``trained_layers`` run after it, in
    order, ending with the initial head that reduces a 4x4 map to one logit per
    image. ``prepend_layer`` pushes the current input block into
    ``trained_layers`` and installs a ``DiscriminatorBlock`` that halves
    resolution.
    """

    def __init__(self, channels=512):
        super().__init__()
        self.in_channels = self.out_channels = channels
        self.trained_layers = nn.ModuleList()
        self.new_layer = self.create_initial_layer(channels)
        self.downsample = nn.AvgPool2d(2, 2)
        self.rgb = self.from_rgb(channels)
        self.rgb_skip = self.from_rgb(channels)

    def create_initial_layer(self, channels):
        """Build the 4x4 -> single-logit head."""
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(),
            # A 4x4 kernel without padding collapses the 4x4 map to 1x1.
            nn.Conv2d(channels, channels, 4, padding=0),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(channels, 1),
        ]
        return nn.Sequential(*layers)

    def prepend_layer(self, channels):
        """Add a new input block taking ``channels`` features at 2x resolution."""
        self.out_channels, self.in_channels = self.in_channels, channels

        self.trained_layers.insert(0, self.new_layer)
        self.new_layer = DiscriminatorBlock(self.in_channels, self.out_channels)
        self.rgb = self.from_rgb(self.in_channels)
        # Skip path: downsampled image projected straight to the old input width.
        self.rgb_skip = self.from_rgb(self.out_channels)

    def from_rgb(self, channels):
        """1x1 conv projecting a 3-channel image to ``channels`` features."""
        return nn.Conv2d(3, channels, kernel_size=1, padding=0)

    def forward(self, x, alpha=1.0):
        """Return discriminator logits; ``alpha < 1`` fades in ``new_layer``."""
        fading = alpha < 1.0
        skip = self.rgb_skip(self.downsample(x)) if fading else None

        h = self.new_layer(self.rgb(x))
        if fading:
            h = skip * (1 - alpha) + h * alpha

        for layer in self.trained_layers:
            h = layer(h)

        return h

In [None]:
def train_batch(gen=None, disc=None, batch_size=32):
    """Run one generator -> discriminator pass and print the tensor shapes.

    Despite its name this performs no training; it is a shape smoke test.

    Args:
        gen: generator to sample from; defaults to the notebook global ``g``.
        disc: discriminator to score with; defaults to the notebook global ``d``.
        batch_size: number of latents to sample.

    Returns:
        The discriminator output for the generated batch.
    """
    gen = g if gen is None else gen
    disc = d if disc is None else disc
    # Create the latent on the generator's device; a CPU tensor fails once the
    # models have been moved to CUDA.
    z = torch.randn(batch_size, 512, 1, 1, device=next(gen.parameters()).device)
    with torch.no_grad():  # shape check only, no graph needed
        gen_img = gen(z)
        print(f"Image size: {gen_img.shape}")
        out = disc(gen_img)

    print(out.shape)
    return out

In [None]:
def train_epoch(g, d, g_optim, d_optim, size, alpha=1.0, dataloader=None, latent_dim=512):
    """Run one epoch of adversarial training at image resolution ``size``.

    Each batch performs one discriminator update (real vs. fake, BCE on logits)
    followed by one non-saturating generator update. Mean losses since the last
    report are printed every 200 iterations.

    Args:
        g: generator, called as ``g(z, alpha)`` with ``z`` of shape
            ``(batch, latent_dim, 1, 1)``.
        d: discriminator, called as ``d(images, alpha)``; returns logits.
        g_optim: optimizer over ``g``'s parameters.
        d_optim: optimizer over ``d``'s parameters.
        size: image side length; only used to build the default dataloader.
        alpha: fade-in weight forwarded to BOTH networks, so the generator and
            discriminator blend their newest layer by the same amount.
        dataloader: optional iterable of ``(images, labels)`` batches;
            defaults to ``get_dataloader(size)``.
        latent_dim: number of latent channels sampled per image.
    """
    criterion = nn.BCEWithLogitsLoss()
    # Put data and latents wherever the models live.
    dev = next(d.parameters()).device
    if dataloader is None:
        dataloader = get_dataloader(size)

    g_loss_total = 0.0
    d_loss_total = 0.0
    it = 0
    for real_img, _ in dataloader:
        real_img = real_img.to(dev)
        batch_size = real_img.size(0)

        # Discriminator step. The fake batch is detached so this backward pass
        # does not propagate into the generator.
        out_real = d(real_img, alpha)
        fake_img = g(torch.randn((batch_size, latent_dim, 1, 1), device=dev), alpha)
        out_fake = d(fake_img.detach(), alpha)
        d_loss = 0.5 * (criterion(out_real, torch.ones_like(out_real))
                        + criterion(out_fake, torch.zeros_like(out_fake)))

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Generator step: train G so that D labels fresh fakes as real.
        fake_img = g(torch.randn((batch_size, latent_dim, 1, 1), device=dev), alpha)
        out = d(fake_img, alpha)
        g_loss = criterion(out, torch.ones_like(out))

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        it += 1
        g_loss_total += g_loss.item()
        d_loss_total += d_loss.item()
        if it % 200 == 0:
            print(f"G_LOSS: {g_loss_total / it:.4f} - D_LOSS: {d_loss_total / it:.4f}")
            g_loss_total = 0.0
            d_loss_total = 0.0
            it = 0
            
        

In [None]:
def train_new_layer(g, d, size):
    """Grow both networks by one resolution step and train through the fade-in.

    The generator gains a 512-channel block and the discriminator a 256-channel
    input block. Fresh Adam optimizers are built over the enlarged parameter
    sets, and one epoch at resolution ``size`` is run for each alpha in
    0.0, 0.1, ..., 1.0.
    """
    g.append_layer(512)
    d.prepend_layer(256)
    g, d = g.to(device), d.to(device)

    adam_kwargs = dict(lr=0.001, betas=(0.0, 0.99))
    d_optim = Adam(d.parameters(), **adam_kwargs)
    g_optim = Adam(g.parameters(), **adam_kwargs)

    for alpha in np.linspace(0.0, 1.0, 11):
        train_epoch(g, d, g_optim, d_optim, size, alpha)
            
def train_initial_layer(g, d):
    """Train the initial 4x4 networks for one epoch at full strength (alpha=1)."""
    d_optim, g_optim = (
        Adam(net.parameters(), lr=0.001, betas=(0.0, 0.99)) for net in (d, g)
    )
    train_epoch(g, d, g_optim, d_optim, 4)

In [None]:
# Fresh generator and discriminator at the initial 4x4 resolution.
g, d = Generator().to(device), Discriminator().to(device)

In [None]:
# Stage 0: train the 4x4 networks.
train_initial_layer(g=g, d=d)

In [None]:
# Stage 1: grow to 8x8.
train_new_layer(g, d, size=8)

In [None]:
# Stage 2: grow to 16x16.
train_new_layer(g, d, size=16)

In [None]:
# Stage 3: grow to the full 32x32 CIFAR-10 resolution.
train_new_layer(g, d, size=32)

In [None]:
# Sample a single image; no autograd graph is needed for inference.
with torch.no_grad():
    image = g(torch.randn((1, 512, 1, 1), device=device))

In [None]:
def show_image(img):
    """Display an image tensor with matplotlib.

    Expects values in [-1, 1] (the range of the dataset transform and the
    generator's Tanh output) and a channel-first layout once singleton
    dimensions, such as a batch dimension of 1, are squeezed away.
    """
    rescaled = ((img + 1) / 2).squeeze()
    hwc = rescaled.permute(1, 2, 0).cpu().detach().numpy()
    plt.imshow(hwc)
    

In [None]:
# Display the generated sample.
show_image(img=image)

In [None]:
# Score 100 generated images with the discriminator. no_grad avoids keeping
# the activations of 100 full-resolution samples alive for backprop.
with torch.no_grad():
    image = g(torch.randn((100, 512, 1, 1), device=device))
    out = d(image)

In [None]:
# True where the discriminator's probability of "real" exceeds 0.5.
f = torch.sigmoid(out) > 0.5

In [None]:
# Print the boolean mask of generated samples the discriminator scores as real.
print(f)

In [None]:
# Real-image loader at 32x32, for visual comparison with generated samples.
dl = get_dataloader(size=32)

In [None]:
# Take the first batch of real images from `dl` and display one of them.
# Note: i[1] is the SECOND image of the batch; `image` holds the whole batch
# moved to `device`.
for i, _ in dl:
    image = i.to(device)
    show_image(i[1])
    break