In [None]:
import sys; sys.path.append('../')

import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as tf
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt

# Use the GPU when PyTorch reports CUDA as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from python.deepgen.models import dcgan
from python.deepgen.training import adversarial, loss
from python.deepgen.utils.image import Unnormalize
from python.deepgen.utils.init import Initializer
from python.deepgen.datasets import  common

In [None]:
# Hyperparameters shared by every experiment in this notebook.
latent_channels = 128  # channel dimension of the sampled latent tensors
image_channels = 3     # second argument to dcgan.vanilla_dcgan (images are rendered as RGB)
hidden_layers = 3      # forwarded to dcgan.vanilla_dcgan
base_channels = 64     # forwarded to dcgan.vanilla_dcgan
lr = 2e-4              # learning rate for both Adam optimizers

# Per-layer-type initialisation spec consumed by Initializer: each parameter
# name maps to an (init function, keyword arguments) pair. Both convolution
# flavours get the same spec; the comprehension builds a fresh dict for each.
# NOTE(review): presumably applied as fn(param, **kwargs) — confirm in
# deepgen.utils.init.Initializer.
init_map = {
    layer_type: {
        'weight': (nn.init.normal_, {'mean':0.0, 'std':0.02}),
        'bias': (nn.init.constant_, {'val': 0.0}),
    }
    for layer_type in (nn.Conv2d, nn.ConvTranspose2d)
}

init_fn = Initializer(init_map)

def build_model():
    """Create a fresh DCGAN and one Adam optimizer per sub-network.

    The networks come from ``dcgan.vanilla_dcgan`` configured with the
    notebook-level hyperparameters, and are moved to ``device`` before the
    optimizers are constructed.

    Returns:
        tuple: ``(model, optimizers)`` where ``model`` is the object returned
        by ``dcgan.vanilla_dcgan`` (indexed here by ``'gen'`` and ``'disc'``)
        and ``optimizers`` is a dict with the same two keys holding the
        corresponding ``optim.Adam`` instances.
    """
    parts = ('gen', 'disc')

    model = dcgan.vanilla_dcgan(
        latent_channels,
        image_channels,
        hidden_layers,
        base_channels=base_channels,
        init_fn=init_fn,
    )

    # Move both networks first so the optimizers are built on the moved parameters.
    for part in parts:
        model[part].to(device)

    optimizers = {
        part: optim.Adam(model[part].parameters(), lr, betas=(0.5, 0.999))
        for part in parts
    }
    return model, optimizers


# CIFAR-10

In [None]:
# NOTE(review): presumably the normalised CIFAR-10 dataset plus the mean/std
# used to normalise it — confirm in deepgen.datasets.common.cifar10n.
cifar10n, cifar10_mean, cifar10_std = common.cifar10n()

batch_size, shuffle, num_workers = 32, True, 0
dataloader = torch.utils.data.DataLoader(
    cifar10n,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
)
# Pull one batch and cut it along dim 0 into size-1 chunks, one per preview image.
samples = next(iter(dataloader)).split(1)

# Preview grid: the column count is the single tunable value and the row count
# is derived from it, so the two cannot drift apart.
num_cols = 4
num_rows = int(np.ceil(batch_size / num_cols))
# squeeze=False keeps `axes` a 2-D array for every grid size, so .ravel() below
# is always valid (even for a 1x1 grid).
fig, axes = plt.subplots(num_rows, num_cols, figsize=(2*num_cols, 2*num_rows),
                         squeeze=False)

unnorm_cifar10 = Unnormalize(cifar10_mean, cifar10_std)

# Converts a single-sample batch into a PIL image:
#   1. undo the dataset normalisation (reusing the instance built above),
#   2. drop the leading size-1 dim and clamp to [0, 1] — ToPILImage scales float
#      tensors by 255 and casts to uint8 without clamping, so out-of-range values
#      would otherwise wrap around,
#   3. convert to an RGB PIL image.
# The same transform is also handed to the trainer in the cells below.
output_transform = tf.Compose([
    unnorm_cifar10,
    tf.Lambda(lambda x: x.squeeze(0).clamp(0, 1)),
    tf.ToPILImage('RGB'),
])

# Draw one sample per axis; zip stops at the shorter of the two sequences.
for sample, ax in zip(samples, axes.ravel()):
    ax.imshow(output_transform(sample))
# Hide ticks and frames on every axis, including any left empty.
for ax in axes.ravel():
    ax.axis('off')

fig.tight_layout()


## BCE loss (BCEWithLogitsLoss)

In [None]:
# Fresh networks and optimizers for the binary-cross-entropy experiment.
model, optims = build_model()

# The loss is applied to raw logits, averaged over the batch.
criterions = [nn.BCEWithLogitsLoss(reduction='mean')]

samplers = {
    # Uniform noise in [-1, 1) with shape (bs, latent_channels, 1, 1).
    'latent': lambda bs: 2 * torch.rand(bs, latent_channels, 1, 1) - 1,
    # A pair of (bs, 1, 1, 1) tensors: all ones, then all zeros.
    # NOTE(review): presumably the (real, fake) targets — confirm in adversarial.Trainer.
    'labels': lambda bs: (torch.ones((bs, 1, 1, 1)), torch.zeros((bs, 1, 1, 1))),
}

trainer = adversarial.Trainer(model, optims, samplers, criterions, device)

num_epochs = 32
output_dir = '../models/dcgan_cifar10_bce'

In [None]:
# Run the BCE experiment for `num_epochs` epochs over `dataloader`.
# NOTE(review): `output_dir` and `output_transform` are presumably used to save
# checkpoints / rendered samples, and the meaning of the trailing positional
# `True` is not visible from this notebook — confirm in
# adversarial.Trainer.train and consider passing it by keyword.
history = trainer.train(dataloader, num_epochs, output_dir, output_transform, True)

## WGAN-GP

In [None]:
# Fresh networks and optimizers for the WGAN-GP experiment; this rebinds
# `model`, `optims` and `trainer` from the BCE cell.
model, optims = build_model()

# Wasserstein criterion plus a gradient-penalty regulariser constructed with 10.
# NOTE(review): 10 is presumably the penalty weight — confirm in loss.GradPenalty.
criterions = [loss.Wasserstein1()]
reg = [loss.GradPenalty(10)]

samplers = {
    # Uniform noise in [-1, 1) with shape (bs, latent_channels, 1, 1).
    'latent': lambda bs: 2 * torch.rand(bs, latent_channels, 1, 1) - 1,
    # A pair of (bs, 1, 1, 1) tensors: all +1, then all -1.
    'labels': lambda bs: (torch.ones((bs, 1, 1, 1)), -torch.ones((bs, 1, 1, 1))),
}

trainer = adversarial.Trainer(model, optims, samplers, criterions, device, reg=reg)

num_epochs = 16
output_dir = '../models/dcgan_cifar10_wgp'

In [None]:
# Run the WGAN-GP experiment for `num_epochs` epochs over `dataloader`.
# NOTE(review): this rebinds `history`, so the value returned by the BCE run
# is no longer reachable afterwards; use a distinct name if both are needed.
# NOTE(review): the meaning of the trailing positional `True` is not visible
# from this notebook — confirm in adversarial.Trainer.train.
history = trainer.train(dataloader, num_epochs, output_dir, output_transform, True)