# MNIST CGAN

In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST

import torchvision 
import torchvision.transforms as tv_transforms
import torchvision.datasets as tv_datasets
import torchvision.utils as tv_utils

from torch.utils.tensorboard import SummaryWriter

from fid import calculate_activation_statistics
from models import MnistCNN, Discriminator, Generator
from inception import InceptionV3
from datasets import ColorMNIST
from plot_tools import plot_im
from utils import makedirs_exists_ok, seed_rng, set_cuda_visible_devices, load_weights_from_file

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
model_name = 'cgan_jpt'
data_root = './data'
model_root = './models/mnist_cgan'
figure_root = './figures/mnist_cgan'
log_root = './logs/mnist_cgan'

# Every output directory (and the dataset root) is namespaced by model_name.
data_root = os.path.join(data_root, model_name)
model_root = os.path.join(model_root, model_name)
figure_root = os.path.join(figure_root, model_name)
log_root = os.path.join(log_root, model_name)

# Data / optimisation settings.
image_size = 32          # images are resized to this size before training
batch_size = 32
seed = 9
gpu_id = '0'
n_workers = 4            # DataLoader worker processes
load_weights = ''        # NOTE(review): not read by any visible cell
lr = 0.0002
beta1 = 0                # Adam betas = (beta1, beta2)
beta2 = 0.9
n_epochs = 20
log_interval = 100       # iterations between console logs / saved sample grids
target_type = 'color'    # NOTE(review): not read by the training loop, which conditions on the dataset's label output

# Model settings (shared by generator and discriminator where applicable).
n_features = 32
G_dim_z = 32             # latent noise dimensionality
G_bottom_width = 4
n_classes = 10
im_channels = 3          # number of image channels passed to both networks
# LeakyReLU has no learnable parameters, so passing one instance to both
# networks does not tie any weights.
model_activation = nn.LeakyReLU(0.2)

In [None]:
# Create the output directories used by this run.
for _dir in (data_root, model_root, figure_root, log_root):
    makedirs_exists_ok(_dir)

# TensorBoard writer for loss scalars and sample grids.
writer = SummaryWriter(log_root)
writer.flush()

# Seed RNGs (via utils.seed_rng) and select the compute device.
seed_rng(seed)
device = set_cuda_visible_devices(gpu_id)

In [None]:
# Resize, convert to tensors in [0, 1], then map to [-1, 1]; the single
# mean/std value is broadcast over every channel.
transforms = tv_transforms.Compose([
    tv_transforms.Resize(size=image_size),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training and test splits of ColorMNIST, downloaded on first use.
train_loader = torch.utils.data.DataLoader(
    dataset=ColorMNIST(root=data_root, download=True, train=True, transform=transforms),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_workers,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=ColorMNIST(root=data_root, download=True, train=False, transform=transforms),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_workers,
    pin_memory=True,
)

In [None]:
# Conditional generator and discriminator, built with the shared width,
# class count, channel count and activation, then moved to the device.
G = Generator(
    n_features, G_dim_z, G_bottom_width,
    num_classes=n_classes,
    im_channels=im_channels,
    activation=model_activation,
).to(device)
D = Discriminator(
    n_features,
    num_classes=n_classes,
    im_channels=im_channels,
    activation=model_activation,
).to(device)
print(G, D)

In [None]:

# One Adam optimizer per network, with identical learning rate and betas.
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))

# Hinge loss. D is trained to score real images >= +1 and fake images <= -1;
# G is trained to raise D's score on its own samples.
def criterion_G(D_xf, D_xr):
    """Generator hinge loss: ``-mean(D(x_fake))``.

    Args:
        D_xf: discriminator scores on generated samples.
        D_xr: scores on real samples; unused, accepted so both criteria share
            the ``(D_xf, D_xr)`` signature.
    """
    return -torch.mean(D_xf)


def criterion_D(D_xf, D_xr):
    """Discriminator hinge loss.

    Computes ``mean(relu(1 - D(x_real))) + mean(relu(1 + D(x_fake)))``, which
    is zero once every real score is >= 1 and every fake score is <= -1.

    Args:
        D_xf: discriminator scores on generated samples.
        D_xr: discriminator scores on real samples.
    """
    return torch.mean(torch.relu(1. - D_xr)) + torch.mean(torch.relu(1. + D_xf))


# Alternative: non-saturating (softplus) loss with the same sign convention.
# criterion_G = lambda D_xf, D_xr: torch.mean(F.softplus(-D_xf))
# criterion_D = lambda D_xf, D_xr: torch.mean(F.softplus(-D_xr)) + torch.mean(F.softplus(D_xf))

# NOTE(review): not used by the active loss functions above.
loss_bce = nn.BCELoss()

def sample_from_G(G, batch_size, dim_z, device, n_classes, z_distribution='normal'):
    """Draw a batch of class-conditional samples from the generator ``G``.

    Args:
        G: conditional generator, called as ``G(z, c)``.
        batch_size: number of samples to draw.
        dim_z: dimensionality of the latent noise vector.
        device: device on which the noise and labels are created.
        n_classes: labels are drawn uniformly from ``[0, n_classes)``.
        z_distribution: latent prior, either ``'normal'`` (standard normal,
            the default) or ``'uniform'`` (uniform on ``[-1, 1)``).

    Returns:
        ``(x_fake, c)``: the generator output and the ``torch.long`` labels it
        was conditioned on.

    Raises:
        ValueError: if ``z_distribution`` is not one of the supported values.
    """
    # Latent noise.
    z = torch.empty(batch_size, dim_z, dtype=torch.float32, device=device)
    if z_distribution == 'normal':
        z.normal_()
    elif z_distribution == 'uniform':
        z.uniform_(-1.0, 1.0)
    else:
        raise ValueError(
            f"z_distribution must be 'normal' or 'uniform', got {z_distribution!r}")

    # Conditioning labels; these (not any outer-scope variable) are what the
    # generator sees and what the caller receives.
    c = torch.from_numpy(np.random.randint(low=0, high=n_classes, size=(batch_size,)))
    c = c.to(device=device, dtype=torch.long)

    x_fake = G(z, c)
    return x_fake, c
    

In [None]:
# Fixed noise and labels reused at every logging step so sample grids are
# comparable over training. Labels are 0..9 repeated 10 times, so with
# nrow=10 each grid row shows the ten classes in order.
fixed_z = torch.empty(100, G_dim_z, dtype=torch.float32, device=device).normal_()
fixed_y = torch.arange(10).repeat(10).type(torch.long).to(device)

for epoch in range(n_epochs):
    for it, (x, y, _) in enumerate(train_loader):
        # The last batch may be smaller than batch_size; keep the actual size
        # in its own name so the global batch_size setting is not overwritten.
        cur_batch_size = x.size(0)

        # Condition on the dataset's second output (presumably the digit label).
        # NOTE(review): target_type is not consulted here.
        x, y = x.to(device), y.long().to(device)

        # ---------------- Generator step ----------------
        x_fake, c = sample_from_G(G, cur_batch_size, G_dim_z, device, n_classes)
        D_xf = D(x_fake, c)
        loss_G = criterion_G(D_xf, None)

        G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # ---------------- Discriminator step ----------------
        # Fakes are sampled without autograd: D's update has no use for
        # gradients w.r.t. G, so building G's graph here is wasted work.
        with torch.no_grad():
            x_fake, c = sample_from_G(G, cur_batch_size, G_dim_z, device, n_classes)
        D_xf = D(x_fake, c)
        D_xr = D(x, y)
        loss_D = criterion_D(D_xf, D_xr)

        # Also clears the gradients that the generator step left in D.
        D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---------------- Logging ----------------
        loss_D = loss_D.item()
        loss_G = loss_G.item()
        loss_total = loss_D + loss_G

        global_step = epoch * len(train_loader) + it
        writer.add_scalar('loss/total', loss_total, global_step)
        writer.add_scalar('loss/D', loss_D, global_step)
        writer.add_scalar('loss/G', loss_G, global_step)

        if it % log_interval == log_interval - 1:
            n_seen = min((it + 1) * batch_size, len(train_loader.dataset))
            print(f'[{epoch+1}/{n_epochs}]\t'
                  f'[{n_seen}/{len(train_loader.dataset)} ({100.*(it+1)/len(train_loader):.0f}%)]\t'
                  f'loss: {loss_total:.4}\t'
                  f'loss_D: {loss_D:.4}\t'
                  f'loss_G: {loss_G:.4}\t')

            with torch.no_grad():
                x_fake = G(fixed_z, fixed_y)
            # Real images were mapped to [-1, 1] by Normalize((0.5,), (0.5,));
            # invert that so save_image / add_image, which expect [0, 1], do not
            # clip the negative half. Assumes G's output lives in the same range
            # as the normalized data — TODO confirm G's output layer.
            x_vis = (x_fake * 0.5 + 0.5).clamp(0, 1)
            tv_utils.save_image(
                x_vis,
                os.path.join(figure_root,
                             f'{model_name}_fake_samples_epoch={epoch}_it={it}.png'),
                nrow=10)
            writer.add_image('mnist', tv_utils.make_grid(x_vis, nrow=10), global_step)

    # Per-epoch checkpointing (disabled):
    # torch.save(G.state_dict(), os.path.join(model_root, f'G_epoch_{epoch}.pt'))
    # torch.save(D.state_dict(), os.path.join(model_root, f'D_epoch_{epoch}.pt'))


In [None]:
# Inspect the most recent generator output. A bare expression is only echoed
# when it is the last statement of a cell, so print the value range explicitly.
print(f'x_fake range: [{torch.min(x_fake).item():.4f}, {torch.max(x_fake).item():.4f}]')
# Detach so plotting never sees a tensor that still requires grad.
plot_im(x_fake.detach())