In [None]:
import sys; sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import matplotlib.pyplot as plt

import torchvision.transforms as tf
from torchvision.datasets import CIFAR10

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs this notebook draws from (DataLoader shuffling, weight
# initialisation, latent sampling) so a fresh "Run All" is repeatable.
# torch.manual_seed also seeds every CUDA device; note that some CUDA kernels
# are non-deterministic, so GPU runs may still differ slightly.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
from python.deepgen.models import progressive
from python.deepgen.datasets import common
from python.deepgen.utils import image, init
from python.deepgen.training import adversarial, loss

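# CIFAR-10

Load CIFAR-10 once per level of detail (4×4 up to 32×32) and preview a few training images at each resolution.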

In [None]:
batch_size = 32
shuffle = True
num_workers = 8
num_lods = 4
samples_per_lod = 4  # images previewed per level of detail (one figure row each)
dataloaders = {}
output_transforms = {}
samples = []

# Build one DataLoader per level of detail (LOD); LOD k (1-based) serves
# images resized to 4 * 2**(k-1) pixels square.
for lod in range(num_lods):
    resolution = 4 * 2**lod
    transform = tf.Compose([
        tf.Resize((resolution, resolution)),
        tf.ToTensor(),
    ])
    cifar10n, cifar_mean, cifar_std = common.cifar10n(transform)
    dataloader = torch.utils.data.DataLoader(
        cifar10n, batch_size, shuffle, num_workers=num_workers)
    # Maps a single normalised (1, C, H, W) tensor back to a displayable image.
    output_transform = tf.Compose([
        image.Unnormalize(cifar_mean, cifar_std),
        # ToPILImage converts float tensors with mul(255).byte() and does not
        # clamp, so values outside [0, 1] (common for generator outputs) would
        # overflow uint8 and show up as speckle artifacts.
        tf.Lambda(lambda x: x.clamp(0, 1)),
        tf.Lambda(lambda x: x.squeeze(0)),
        tf.ToPILImage('RGB')
    ])
    # Preview the first few images of one (shuffled) batch at this resolution.
    first_batch = torch.split(next(iter(dataloader)), 1)[:samples_per_lod]
    samples.extend([output_transform(s) for s in first_batch])

    dataloaders[lod+1] = dataloader
    output_transforms[lod+1] = output_transform


# One row per LOD, coarsest at the top.
num_cols = samples_per_lod
num_rows = len(samples)//num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(2*num_cols, 2*num_rows),
                         squeeze=False)
for sample, ax in zip(samples, axes.ravel()):
    ax.imshow(sample)

for ax in axes.ravel():
    ax.axis('off')
for lod, row_axes in enumerate(axes):
    res = 4 * 2**lod
    row_axes[0].set_title(f'{res}x{res}', loc='left', fontsize='small')
fig.tight_layout()


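# Progressive GAN

Build the generator/discriminator pair, then train it one level of detail at a time, from the coarsest resolution up.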

In [None]:
epochs = 4
# Iteration budget per level of detail (LOD): `epochs` passes over a loader.
# Every LOD loader is built from CIFAR-10 with the same batch size, so the
# last one's length stands in for all of them.
epoch_iters = epochs * len(dataloaders[list(dataloaders)[-1]])
# Halve the budget: the pair is handed to the builder as (fade-out iterations,
# stable iterations) for each growth step.
fadeout_iters = epoch_iters // 2
stable_iters = epoch_iters - fadeout_iters
transition_iters = (fadeout_iters, stable_iters)

# Kaiming-normal weights (leaky-ReLU gain) and zero biases for both regular
# and transposed convolutions; a fresh spec dict is built per layer type.
init_map = {
    layer_type: {
        'weight': (nn.init.kaiming_normal_, {'nonlinearity': 'leaky_relu'}),
        'bias': (nn.init.constant_, {'val': 0.0}),
    }
    for layer_type in (nn.Conv2d, nn.ConvTranspose2d)
}
init_fn = init.Initializer(init_map)

latent_dims = 128

# NOTE(review): the positional arguments after latent_dims look like
# (image channels=3, final resolution=32, ?=2, number of LODs=4) — confirm
# against progressive._progressive_gan_builder before changing them.
model, scheduler = progressive._progressive_gan_builder(
    transition_iters, latent_dims, 3, 32, 2, 4, init_fn)

for part in ('gen', 'disc'):
    model[part].to(device)

# Same Adam settings for both networks: lr 1e-3, beta1 = 0 (no first-moment
# momentum), beta2 = 0.99, eps 1e-8 — as used in the ProGAN paper.
adam_settings = {'lr': 1e-3, 'betas': (0.0, 0.99), 'eps': 1e-8}
gen_opt = optim.Adam(model['gen'].parameters(), **adam_settings)
disc_opt = optim.Adam(model['disc'].parameters(), **adam_settings)
optims = {'gen': gen_opt, 'disc': disc_opt}


In [None]:
# Wasserstein-1 adversarial loss with a gradient-penalty regulariser;
# 10 is presumably the penalty weight — confirm in loss.GradPenalty.
criterions = [loss.Wasserstein1()]
reg = [loss.GradPenalty(10)]

# The progressive-growing scheduler is registered under the 'iteration' hook,
# presumably so the trainer steps the fade-in schedule every iteration.
callbacks = {'iteration': [scheduler]}


def _sample_latent(bs):
    """Return `bs` standard-normal latent codes shaped (bs, latent_dims, 1, 1)."""
    return torch.randn(bs, latent_dims, 1, 1)


def _sample_labels(bs):
    """Return a pair of target tensors (all +1, all -1), each (bs, 1, 1, 1)."""
    return torch.ones(bs, 1, 1, 1), -torch.ones(bs, 1, 1, 1)


samplers = {'latent': _sample_latent, 'labels': _sample_labels}
trainer = adversarial.Trainer(
    model, optims, samplers, criterions, device, reg=reg, callbacks=callbacks)


In [None]:
# Train one level of detail (LOD) at a time, in loader order (coarsest first),
# then draw a small batch from the generator after each stage.
num_preview = 16  # generator samples kept per LOD
lod_samples = {}
output_dir = '../models/progressive_gan_cifar10/'
for lod, dataloader in dataloaders.items():
    output_transform = output_transforms[lod]
    scale_output_dir = os.path.join(output_dir, str(lod))
    trainer.train(dataloader, epochs, scale_output_dir, output_transform, True)

    # Inference only: skip building an autograd graph for the sampling pass.
    with torch.no_grad():
        sample_latent = samplers['latent'](num_preview)
        sample = model['gen'](sample_latent.to(device)).cpu()
    lod_samples[lod] = [output_transform(s) for s in torch.split(sample, 1)]
