# MNIST conditional GAN (cGAN)

In [1]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST

import torchvision 
import torchvision.transforms as tv_transforms
import torchvision.datasets as tv_datasets
import torchvision.utils as tv_utils

from torch.utils.tensorboard import SummaryWriter

from fid import calculate_activation_statistics
from models_classifier import MnistCNN
from models_cgan import apply_sn, Discriminator, ConditionalDiscriminator, ConditionalGenerator, ConditionalResidualBlock, ConditionalBatchNorm2d, conv3x3
from inception import InceptionV3
from datasets import ColorMNIST
from plot_tools import plot_im
from utils import makedirs_exists_ok, seed_rng, set_cuda_visible_devices, load_weights_from_file, bin_index

In [20]:
# ---------------------------------------------------------------------------
# Experiment configuration. Everything below is a plain module-level constant
# read by the later cells of this notebook.
# ---------------------------------------------------------------------------

# Run name; it is embedded in the output directories below and in the file
# names of the sample images saved during training.
model_name = 'cgan_jpt'
# Root directory handed to the dataset constructor.
data_root = './data'
# Per-run output directories (created later with os.makedirs).
model_root = f'./models/{model_name}'
figure_root = f'./figures/{model_name}'
log_root = f'./logs/{model_name}'

# Target size passed to the Resize transform.
image_size = 32
# Mini-batch size used by the training DataLoader.
batch_size = 32
# Seed given to torch.manual_seed / torch.cuda.manual_seed.
seed = 1
# Value written to the CUDA_VISIBLE_DEVICES environment variable.
gpu_id = '0'
# Number of DataLoader worker processes.
n_workers = 8
# NOTE(review): not read anywhere in the visible cells -- presumably a
# checkpoint path for resuming; confirm before relying on it.
load_weights = ''
# Adam hyper-parameters shared by the generator and discriminator optimizers.
lr = 0.0002
beta1 = 0
beta2 = 0.9
# Number of passes over the training set.
n_epochs = 20
# Print losses and dump sample images every `log_interval` iterations.
log_interval = 100

# Which label conditions the GAN: 'digit' uses the digit label as-is,
# 'color' uses the color value binned into `num_classes` bins via bin_index.
target_type = 'color'
# Dimensionality of the latent noise vector fed to the generator.
dim_z = 50
# Number of conditioning classes (size of the class embedding / number of bins).
num_classes = 10
# Number of image channels.
im_channels = 3
# NOTE(review): not read anywhere in the visible cells -- confirm whether a
# later cell still depends on it.
conditional = True

In [7]:
class OrdinalConditionalDiscriminator(Discriminator):
    """Conditional discriminator whose conditioning variable is ordinal.

    The class-dependent part of the logit for class ``c`` is the cumulative
    sum, over classes ``0..c``, of the inner products between the pooled image
    feature and the per-class embeddings.

    Taken from:
        https://github.com/batmanlab/Explanation_by_Progressive_Exaggeration/blob/master/src/explainer.py
    """

    def __init__(self, conv_channels, conv_dnsample, num_classes, use_sn=True):
        """Build the discriminator.

        Args:
            conv_channels: channel widths, forwarded to ``Discriminator``; the
                last entry is also the width of the class embedding.
            conv_dnsample: per-block down-sampling flags, forwarded to
                ``Discriminator``.
            num_classes: number of ordinal classes.
            use_sn: forwarded to ``Discriminator`` and to ``apply_sn`` for the
                class embedding.

        Example configuration noted by the original author
        (Projection cGAN, ImageNet):
            conv_channels = [3, 64, 128, 256, 512, 1024, 1024]
            conv_dnsample = [True, True, True, True, True, False]
        """
        super().__init__(conv_channels, conv_dnsample, use_sn=use_sn)

        # Kept on the instance so forward() uses the class count this module
        # was built with rather than a module-level global of the same name.
        self.num_classes = num_classes
        self.c_embed = apply_sn(nn.Embedding(num_classes, conv_channels[-1]), use_sn)

    def forward(self, x, c):
        """Return one real/fake logit per sample.

        Args:
            x: images, batch_size x im_channels x h x w.
            c: integer class indices with batch_size elements (any shape that
                flattens to batch_size).

        Returns:
            Tensor of shape batch_size x 1.
        """
        c = c.view(-1)

        # self.residual_blocks, self.nonlinearity and self.linear come from
        # the Discriminator base class, which is not shown in this file.
        x = self.residual_blocks(x)
        x = self.nonlinearity(x)
        x = torch.sum(x, dim=(2, 3))   # global sum pooling -> batch_size x channels

        # sigmoid^-1(p(real/fake|x,c)) =
        #     log(p_data(x)/p_model(x)) +
        #     log(p_data(c|x)/p_model(c|x))
        all_classes = torch.arange(0, self.num_classes, dtype=torch.long, device=x.device)
        # batch_size x num_classes: inner product of the feature with every
        # class embedding ...
        W = x @ self.c_embed(all_classes).T
        # ... accumulated along the class axis, which encodes the ordering.
        W = torch.cumsum(W, dim=1)
        # Select the accumulated score belonging to each sample's own class.
        f_1 = W.gather(dim=1, index=c.view(-1, 1))
        # Class-independent term.
        f_2 = self.linear(x)

        return f_1 + f_2
    

class ConditionalAutoencoder(nn.Module):
    """Class-conditional residual autoencoder.

    A stack of down-sampling conditional residual blocks (the encoder) is
    followed by a stack of up-sampling ones (the decoder). Every normalisation
    layer is a ``ConditionalBatchNorm2d`` driven by the class label.
    """

    def __init__(self, enc_channels, dec_channels, num_classes,
                 dim_z = 128,
                 im_channels = 3):
        """Build the encoder and decoder stacks.

        Args:
            enc_channels: encoder channel widths, e.g. [64, 128, 256, 256];
                consecutive pairs define one down-sampling block each.
            dec_channels: decoder channel widths, e.g. [256, 128, 64, 64];
                consecutive pairs define one up-sampling block each.
            num_classes: number of classes for the conditional batch norm.
            dim_z: accepted but not used by this class.
            im_channels: number of channels of the input/output images.
        """
        super().__init__()

        num_enc = len(enc_channels) - 1
        num_dec = len(dec_channels) - 1
        assert(num_enc > 0)
        assert(num_dec > 0)

        self.n_enc_blks = num_enc
        self.n_dec_blks = num_dec
        self.bottom_width = 4
        self.nonlinearity = nn.ReLU()

        def norm_layer(num_features):
            # Factory handed to every residual block.
            return ConditionalBatchNorm2d(num_features, num_classes)

        self.normalization_initial = norm_layer(im_channels)
        self.conv_initial = conv3x3(im_channels, enc_channels[0])

        # Encoder: one down-sampling block per consecutive channel pair.
        # Blocks are registered under fixed names that forward() looks up.
        for idx, (c_in, c_out) in enumerate(zip(enc_channels[:-1], enc_channels[1:])):
            enc_block = ConditionalResidualBlock(
                c_in, c_out,
                resample="dn",
                norm_layer=norm_layer,
                nonlinearity=self.nonlinearity,
                resblk_1st=(idx == 0),
            )
            self.add_module(f'residual_block_enc_{idx}', enc_block)

        # Decoder: one up-sampling block per consecutive channel pair.
        for idx, (c_in, c_out) in enumerate(zip(dec_channels[:-1], dec_channels[1:])):
            dec_block = ConditionalResidualBlock(
                c_in, c_out,
                resample="up",
                norm_layer=norm_layer,
                nonlinearity=self.nonlinearity,
            )
            self.add_module(f'residual_block_dec_{idx}', dec_block)

        self.normalization_final = norm_layer(dec_channels[-1])
        self.conv_final = conv3x3(dec_channels[-1], im_channels)
        self.nonlinearity_final = nn.Tanh()

    def forward(self, x, c):
        """Encode and decode a batch of images.

        Args:
            x: images, batch_size x im_channels x h x w.
            c: class indices with batch_size elements.

        Returns:
            Tuple ``(x_hat, z)``: the reconstruction, with the same shape as
            ``x``, and the encoder output that was fed to the decoder.
        """
        labels = c.view(-1)

        # Shapes below are for enc_channels = [64, 128, 256, 256],
        # dec_channels = [256, 128, 64, 64], im_channels = 3, 32x32 input.
        # 3x32x32
        h = self.normalization_initial(x, labels)
        h = self.conv_initial(self.nonlinearity(h))
        # 64x32x32
        for idx in range(self.n_enc_blks):
            enc_block = getattr(self, f'residual_block_enc_{idx}')
            h = enc_block(h, labels)
        # 256x4x4 -- the bottleneck is returned alongside the reconstruction
        z = h
        for idx in range(self.n_dec_blks):
            dec_block = getattr(self, f'residual_block_dec_{idx}')
            h = dec_block(h, labels)
        # 64x32x32
        h = self.normalization_final(h, labels)
        h = self.conv_final(self.nonlinearity(h))
        x_hat = self.nonlinearity_final(h)
        # 3x32x32
        return x_hat, z

In [8]:
##############################
## conditional D
##############################
# Shape check: run one random batch through the discriminator.

num_classes = 10

conv_channels = [3, 64, 128, 256]
conv_dnsample = [True, True, True]
D = OrdinalConditionalDiscriminator(conv_channels, conv_dnsample, num_classes)


x = torch.rand((50, 3, 32, 32))
# Random integer class labels in [0, num_classes).
c = torch.empty((50, 1), dtype=torch.long).random_(0, num_classes)
out = D(x, c)

print(x.shape, out.shape)

##############################
## conditional autoencoder
##############################
# Shape check: run one random batch through the autoencoder.

num_classes = 10
enc_channels = [64, 128, 256, 256]
dec_channels = [256, 128, 64, 64]
# Previously misspelled `im_channnels`, so the constructor call below silently
# picked up the global `im_channels` from the configuration cell instead.
im_channels = 3

G = ConditionalAutoencoder(enc_channels, dec_channels, num_classes, im_channels=im_channels)

x = torch.rand((50, 3, 32, 32))
c = torch.empty((50, 1), dtype=torch.long).random_(0, num_classes)
xhat, z = G(x, c)

print(x.shape, xhat.shape, z.shape)

torch.Size([50, 3, 32, 32]) torch.Size([50, 1])
torch.Size([50, 3, 32, 32]) torch.Size([50, 3, 32, 32]) torch.Size([50, 256, 4, 4])


In [9]:
# Restrict the visible GPUs *before* the first CUDA query below: the variable
# is only honoured if it is set before CUDA is initialised in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

# Output directories for sample images, checkpoints and TensorBoard logs.
os.makedirs(figure_root, exist_ok=True)
os.makedirs(model_root,  exist_ok=True)
os.makedirs(log_root,    exist_ok=True)

writer = SummaryWriter(log_root)
writer.flush()

# Seed the random number generators for reproducibility.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resize, convert to a tensor, then apply (x - 0.5) / 0.5 to every channel.
transforms = tv_transforms.Compose([
    tv_transforms.Resize(image_size),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize((0.5,), (0.5,)),
])

# ColorMNIST is a project-local dataset; the training loop unpacks each batch
# as (image, digit label, color value).
train_loader = torch.utils.data.DataLoader(
    ColorMNIST(root=data_root, download=True, train=True, transform=transforms),
    batch_size=batch_size, shuffle=True, num_workers=n_workers, pin_memory=True)

In [22]:
# Generator channel widths. These get their own name: previously they were
# stored in `conv_channels`, which was overwritten with the discriminator
# widths further down *before* the generator was constructed, so the generator
# was silently built from the discriminator's channel list.
gen_channels = [256, 256, 128, 64]
conv_upsample = [True, True, True]

# Channel widths for the (currently disabled) autoencoder generator.
enc_channels = [64, 128, 256, 256]
dec_channels = [256, 128, 64, 64]

# Discriminator channel widths and per-block down-sampling flags.
conv_channels = [im_channels, 64, 128, 256]
conv_dnsample = [True, True, True]

# Alternative generator, kept for reference:
# G = ConditionalAutoencoder(enc_channels, dec_channels, num_classes=num_classes, dim_z=dim_z, im_channels=im_channels).to(device)
G = ConditionalGenerator(gen_channels, conv_upsample, num_classes=num_classes, dim_z=dim_z, im_channels=im_channels).to(device)
D = OrdinalConditionalDiscriminator(conv_channels, conv_dnsample, num_classes, use_sn=True).to(device)


In [23]:
# Target values for the discriminator in this notebook: 0 marks real samples
# and 1 marks generated ones.
real_label = 0
fake_label = 1

# The discriminator returns raw logits, so the sigmoid is folded into the loss.
criterion_D = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, sharing the same hyper-parameters.
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))

# Fixed generator inputs so that the sample grids saved during training are
# comparable over time: 100 noise vectors with class labels 0..9 repeated
# ten times.
fixed_z = torch.randn(100, dim_z).to(device)
fixed_c = torch.arange(10).repeat(10).to(device)

In [None]:
# Fail early, with a useful message, on an unsupported conditioning target.
if target_type not in ('digit', 'color'):
    raise ValueError(f"unknown target_type {target_type!r}; expected 'digit' or 'color'")

for epoch in range(n_epochs):
    for it, (x_real, c_digit, c_color) in enumerate(train_loader):

        # The last batch may be smaller than the configured batch size. A
        # separate name is used so the configured `batch_size` is not clobbered.
        n_batch = x_real.size(0)
        # BCEWithLogitsLoss needs floating-point targets; without an explicit
        # dtype, torch.full derives an integer dtype from the integer labels.
        real_labels = torch.full((n_batch, 1), real_label, dtype=torch.float, device=device)
        fake_labels = torch.full((n_batch, 1), fake_label, dtype=torch.float, device=device)

        # Conditioning label: the digit itself, or the color value binned
        # into `num_classes` bins.
        if target_type == 'digit':
            c_real = c_digit
        else:
            c_real = bin_index(c_color, num_classes)

        ##############################################################
        # Update Discriminator
        ##############################################################

        # a minibatch of samples from data distribution
        x_real, c_real = x_real.to(device), c_real.to(device)

        y = D(x_real, c_real)
        loss_D_real = criterion_D(y, real_labels)

        # a minibatch of samples from the model distribution
        z = torch.randn(n_batch, dim_z).to(device)
        c_fake = torch.empty(n_batch, dtype=torch.long).random_(0, num_classes).to(device)

        # Detached: only D is updated in this step, so there is no reason to
        # backpropagate through G (its gradients were discarded anyway by the
        # zero_grad() preceding the generator update).
        x_fake = G(z, c_fake).detach()
        y = D(x_fake, c_fake)
        loss_D_fake = criterion_D(y, fake_labels)

        # backprop
        optimizer_D.zero_grad()
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        ##############################################################
        # Update Generator/Encoder
        ##############################################################

        # a fresh minibatch of samples from the model distribution
        z = torch.randn(n_batch, dim_z).to(device)
        c_fake = torch.empty(n_batch, dtype=torch.long).random_(0, num_classes).to(device)
        x_fake = G(z, c_fake)
        y = D(x_fake, c_fake)
        # The generator is trained to make D assign the "real" target to fakes.
        loss_G = criterion_D(y, real_labels)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        ##############################################################
        # print
        ##############################################################

        loss_D = loss_D.item()
        loss_G = loss_G.item()
        loss_total = loss_D + loss_G

        global_step = epoch*len(train_loader)+it
        writer.add_scalar('loss/total', loss_total, global_step)
        writer.add_scalar('loss/D', loss_D, global_step)
        writer.add_scalar('loss/G', loss_G, global_step)

        if it % log_interval == log_interval-1:
            n_seen = min((it+1)*batch_size, len(train_loader.dataset))
            print(f'[{epoch+1}/{n_epochs}]\t'
                  f'[{n_seen}/{len(train_loader.dataset)} ({100.*(it+1)/len(train_loader):.0f}%)]\t'
                  f'loss: {loss_total:.4}\t'
                  f'loss_D: {loss_D:.4}\t'
                  f'loss_G: {loss_G:.4}\t')

            # Preview samples from the fixed noise/labels; no autograd graph
            # is needed for them.
            with torch.no_grad():
                x_fake = G(fixed_z, fixed_c)
            tv_utils.save_image(x_fake,
                os.path.join(figure_root,
                    f'{model_name}_fake_samples_epoch={epoch}_it={it}.png'), nrow=10, normalize=True)

            writer.add_image('mnist', tv_utils.make_grid(x_fake, nrow=10, normalize=True), global_step)

# Make sure every pending TensorBoard event is written to disk.
writer.flush()
        

#     torch.save(G.state_dict(), os.path.join(model_root, f'G_epoch_{epoch}.pt'))
#     torch.save(D.state_dict(), os.path.join(model_root, f'D_epoch_{epoch}.pt'))

