## Conditional GAN on MNIST

In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import os
import numpy as np

Define hyperparameters

In [None]:
# Image layout as [channels, height, width]; an MNIST digit is 1x28x28.
image_size = [1, 28, 28]
# Length of the random noise vector fed to the generator.
latent_dim = 96
# Size of the learned embedding for each digit class label.
label_emb_dim = 32

use_gpu = torch.cuda.is_available()

num_epoch = 100
batch_size = 32

# Sample grids produced during training are written here.
save_dir = 'cgan_images'
os.makedirs(save_dir, exist_ok=True)

Define the generator model

In [None]:
class Generator(nn.Module):
    """Conditional generator: noise vector + digit label -> image.

    The label is looked up in a learned embedding table (10 classes) and
    concatenated to the noise vector. The result goes through an MLP whose
    final Sigmoid keeps every pixel in (0, 1).
    """

    def __init__(self):
        super().__init__()

        self.embedding = nn.Embedding(10, label_emb_dim)

        # Three Linear -> BatchNorm -> GELU stages widening 128 -> 256 -> 512.
        layers = []
        in_features = latent_dim + label_emb_dim
        for out_features in (128, 256, 512):
            layers += [nn.Linear(in_features, out_features), nn.BatchNorm1d(out_features), nn.GELU()]
            in_features = out_features

        # Final widening without BatchNorm, then projection to one flat image.
        layers += [
            nn.Linear(512, 1024),
            nn.GELU(),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        z: [batch_size, latent_dim] noise.
        labels: integer class indices in [0, 10), one per sample.
        Returns a tensor of shape [batch_size, *image_size].
        """
        conditioned = torch.cat([z, self.embedding(labels)], dim=-1)
        flat = self.model(conditioned)
        return flat.reshape(conditioned.shape[0], *image_size)

generator = Generator()

Define the discriminator model

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: (image, digit label) -> probability of "real".

    The flattened image is concatenated with a learned label embedding and
    scored by an MLP ending in Linear(32, 1) + Sigmoid. Spectral normalization
    wraps every Linear layer except the first one.
    """

    def __init__(self):
        super().__init__()

        self.embedding = nn.Embedding(10, label_emb_dim)

        layers = [nn.Linear(np.prod(image_size, dtype=np.int32) + label_emb_dim, 512), nn.GELU()]
        # Narrowing 512 -> 256 -> 128 -> 64 -> 32, each spectrally normalized.
        for fan_in, fan_out in zip((512, 256, 128, 64), (256, 128, 64, 32)):
            layers += [nn.utils.spectral_norm(nn.Linear(fan_in, fan_out)), nn.GELU()]
        layers += [nn.utils.spectral_norm(nn.Linear(32, 1)), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, image, labels):
        """Score each image given its label.

        image: [batch_size, 1, 28, 28] (any shape with batch first is flattened).
        labels: integer class indices in [0, 10), one per sample.
        Returns a [batch_size, 1] tensor of probabilities.
        """
        flat_image = image.reshape(image.shape[0], -1)
        features = torch.cat([flat_image, self.embedding(labels)], dim=-1)
        return self.model(features)

discriminator = Discriminator()

Load the MNIST training data

In [None]:
# ToTensor scales pixels to [0, 1]. No Normalize step is applied: the generator
# ends in Sigmoid, so real images stay in the same [0, 1] range as generated ones.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(28),
        torchvision.transforms.ToTensor(),
    ]
)

dataset = torchvision.datasets.MNIST(
    "mnist_data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Define the optimizers and the loss function

In [None]:
# Separate Adam optimizers for the two networks, with identical hyperparameters.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=3e-4, betas=(0.4, 0.8), weight_decay=1e-4)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=3e-4, betas=(0.4, 0.8), weight_decay=1e-4)

# The discriminator already outputs a probability (Sigmoid), so plain BCE is used.
loss_fn = nn.BCELoss()
# BCE targets: 1 means "real", 0 means "fake". Sized for one full mini-batch.
labels_one = torch.ones(batch_size, 1)
labels_zero = torch.zeros(batch_size, 1)

if use_gpu:
    print('training on gpu')
    # Module.cuda() moves parameters in place, so the optimizers above still
    # reference the right tensors.
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    loss_fn = loss_fn.cuda()
    # Each target must be moved from itself. Copying labels_zero into
    # labels_one would turn every "real" target into 0 and invert training.
    labels_one = labels_one.to('cuda')
    labels_zero = labels_zero.to('cuda')

Training loop:

In [None]:
# Alternate one generator update and one discriminator update per mini-batch.
steps_per_epoch = len(dataloader)
for epoch in range(num_epoch):
    print(f'epoch {epoch}: ')
    for i, (gt_images, labels) in enumerate(dataloader):
        # Global step, used for logging and for sample-image file names.
        step = steps_per_epoch * epoch + i

        # Size noise and targets by the actual batch. The last batch is smaller
        # than batch_size when the dataset size is not a multiple of it.
        cur_batch = gt_images.shape[0]
        real_target = labels_one[:cur_batch]
        fake_target = labels_zero[:cur_batch]

        z = torch.randn(cur_batch, latent_dim)

        if use_gpu:
            gt_images = gt_images.to('cuda')
            labels = labels.to('cuda')
            z = z.to('cuda')

        # Condition the generator on the same labels as the real batch.
        pred_images = generator(z, labels)

        # ---- Generator step ----
        g_optimizer.zero_grad()
        # Small L1 term pulling each sample toward the real image with the same label.
        recons_loss = torch.abs(pred_images - gt_images).mean()
        # Adversarial term: the generator wants D to score its samples as real (1).
        g_loss = recons_loss * 0.05 + loss_fn(discriminator(pred_images, labels), real_target)
        g_loss.backward()
        g_optimizer.step()

        # ---- Discriminator step ----
        # zero_grad also discards the gradients that g_loss.backward() left on D.
        d_optimizer.zero_grad()
        real_loss = loss_fn(discriminator(gt_images, labels), real_target)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_loss = loss_fn(discriminator(pred_images.detach(), labels), fake_target)
        d_loss = real_loss + fake_loss

        # Training is considered stable when real_loss and fake_loss both decrease.

        d_loss.backward()
        d_optimizer.step()

        if i % 50 == 0:
            print(f'step:{step}, recons_loss:{recons_loss.item()}, g_loss:{g_loss.item()},'
                  f'd_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}')

        if i % 800 == 0:
            # Save a 4x4 grid of the first 16 generated samples.
            image = pred_images[:16].detach()
            torchvision.utils.save_image(image, f"{save_dir}/image_{step}.png", nrow=4)