## Introduction

We will use Python 3, [NumPy](https://numpy.org/), and [PyTorch](https://pytorch.org/) packages for implementation. To avoid unexpected issues with PyTorch 2.0, we recommend using PyTorch version 1.x.

In this coding project, you will implement four generative models (an energy-based model, a flow-based model, a variational auto-encoder, and a generative adversarial network) to generate MNIST images.

**We will implement a generative adversarial network, specifically a [Deep Convolutional Generative Adversarial Network](https://arxiv.org/abs/1511.06434) (DCGAN), in this notebook.**

If you use Colab for this project, uncomment the cell below, change `GOOGLE_DRIVE_PATH` to your project folder, and run the cell to mount your Google Drive so that the notebook can find the required files (e.g., utils.py). If you run the notebook locally, you can skip this cell.

In [None]:
# ### uncomment this cell if you're using Google colab
# from google.colab import drive
# drive.mount('/content/drive')

# ### change GOOGLE_DRIVE_PATH to the path of your CP3 folder
# GOOGLE_DRIVE_PATH = '/content/drive/MyDrive/Colab Notebooks/DL23SP/CP3'
# %cd $GOOGLE_DRIVE_PATH

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

%matplotlib inline

# Optional: default figure size as (width, height) in inches.
rcParams['figure.figsize'] = (6, 4)
# Show the image currently stored at ./gan/sample.png (presumably an example
# sample grid shipped with the assignment; train() later overwrites sample grids under ./gan/).
preview = mpimg.imread('./gan/sample.png')
plt.imshow(preview)

## Set Up Code

If you use Colab in this coding project, please make sure to mount your drive before running the cells below.

In [None]:
# Enable IPython's autoreload extension so edits to local modules such as
# utils.py are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Smoke test: if this import succeeds, utils.py is reachable from the current
# working directory (see the Colab mounting instructions at the top otherwise).
from utils import hello
hello()

Finally, run the following cell to import the base classes and helpers used in the implementation (whether or not you use Colab).

In [None]:
from collections import deque
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from utils import save_model, load_model, train_set

# Seed every RNG used by the notebook (NumPy, PyTorch CPU, PyTorch CUDA) and
# ask cuDNN for deterministic kernels so runs are reproducible.
seed = 42
for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seeder(seed)
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Generator

Please implement your own generator module, which should be a fully convolutional network with linear projection heads.

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator mapping (noise, digit label) to a 28x28 image.

    ``forward`` returns raw (pre-sigmoid) outputs of shape (N, 1, 28, 28);
    pixel intensities are obtained by applying ``torch.sigmoid``, as done in
    ``sample_images`` and ``make_dataset``.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.label_size = label_size = 10
        self.latent_size = latent_size
        self.hidden_size = 512  # kept for interface parity; not used by the network
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        ngf = 64
        # Spatial size for the 1x1 input: 1 -> 3 -> 7 -> 15 -> 28, using
        # out = (in - 1) * stride - 2 * padding + kernel + output_padding.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_size + label_size, ngf * 4, 3, 2, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, 1, 3, 2, 2, 1, bias=False),
        )
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def forward(self, z, label):
        """Generate raw image outputs for the requested digits.

        Args:
            z: latent noise of shape (N, latent_size).
            label: integer digit labels with N elements in total, e.g. shape
                (N,) or (N, 1).

        Returns:
            Tensor of shape (N, 1, 28, 28) before the output sigmoid.
        """
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        # Build the one-hot on the inputs' device and dtype rather than on the
        # notebook-global ``device``, so the module works wherever it lives.
        onehot = F.one_hot(label.reshape(-1).long(), num_classes=self.label_size).to(z.dtype)
        x = torch.cat((z, onehot), dim=1).view(-1, self.latent_size + self.label_size, 1, 1)
        return self.main(x)
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    @torch.no_grad()
    def sample_images(self, label, save=True, save_dir='./gan'):
        """Sample one image per entry of ``label`` and optionally save them as a grid.

        Switches the module to eval mode. Returns sigmoid-mapped images of shape
        (N, 1, 28, 28); when ``save`` is set they are also written to
        ``<save_dir>/sample.png`` with ``int(sqrt(N))`` images per row.
        """
        self.eval()
        n_samples = label.shape[0]
        samples = self(torch.randn(
            n_samples, self.latent_size).to(label.device), label)
        imgs = torch.sigmoid(samples).view(n_samples, 1, 28, 28)
        if save:
            os.makedirs(save_dir, exist_ok=True)
            torchvision.utils.save_image(imgs, os.path.join(
                save_dir, 'sample.png'), nrow=int(np.sqrt(n_samples)))
        return imgs

    @torch.no_grad()
    def make_dataset(self, n_samples_per_class=10, save=True, save_dir='./gan/generated/'):
        """Generate ``n_samples_per_class`` images per digit and optionally save them.

        Files are written to ``<save_dir>/<digit>/<digit>_<index>.png``. For each
        digit, the mean per-pixel standard deviation across its samples is
        printed as a rough diversity check. Switches the module to eval mode.
        """
        self.eval()
        device = next(self.parameters()).device
        for digit in range(self.label_size):
            label = torch.full((n_samples_per_class,), digit, dtype=torch.long, device=device)
            z = torch.randn(n_samples_per_class, self.latent_size).to(device)
            # Use the same sigmoid mapping as training and sample_images().
            # Clipping the raw outputs to [0, 1] instead produced images the
            # discriminator never saw (every negative output became pure black).
            imgs = torch.sigmoid(self(z, label)).view(n_samples_per_class, 1, 28, 28)
            print(f"Standard deviation of number {digit}: {torch.std(imgs, dim=0).mean().item()}")
            if save:
                digit_dir = os.path.join(save_dir, str(digit))
                os.makedirs(digit_dir, exist_ok=True)
                for j in range(n_samples_per_class):
                    torchvision.utils.save_image(imgs[j], os.path.join(digit_dir, "{}_{:>03d}.png".format(digit, j)))

## Discriminator

Please implement your own discriminator module, which should also be a fully convolutional network with linear projection heads.

**Hint**: Pay attention to the shape of your output and the shape of the label.

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (image, label) pairs.

    ``forward`` returns 2-way logits of shape (N, 2), meant for
    ``nn.CrossEntropyLoss``. The training loop uses class index 1 for real
    pairs and 0 for generated ones.
    """

    def __init__(self):
        super().__init__()
        self.label_size = label_size = 10
        self.input_size = input_size = 28 * 28
        self.hidden_size = 512  # kept for interface parity; not used by the network
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        ndf = 28
        # Spatial size: 28 -> 14 -> 7 -> 3 -> 1 (kernel 4, stride 2, padding 1).
        # No final Sigmoid: CrossEntropyLoss applies log-softmax itself, and
        # squashing the logits into (0, 1) caps the attainable confidence.
        self.main = nn.Sequential(
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 2, 4, 2, 1, bias=False),
        )
        # Fuses the flattened image and its label vector back into one 28x28 map.
        self.prep = nn.Linear(input_size + label_size, input_size)
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def forward(self, img, label):
        """Score a batch of (image, label) pairs.

        Args:
            img: images of shape (N, 1, 28, 28).
            label: per-sample label vectors with ``label_size`` entries each,
                e.g. a one-hot matrix of shape (N, label_size). Integer
                encodings are accepted and cast to ``img``'s dtype.

        Returns:
            Logits of shape (N, 2).
        """
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        n, c = img.shape[0], img.shape[1]
        # Cast explicitly instead of relying on torch.cat's dtype promotion
        # (F.one_hot yields int64 while the images are floating point).
        label = label.to(img.dtype).reshape(n, c, self.label_size)
        feature = torch.cat((img.reshape(n, c, self.input_size), label), dim=-1)
        feature = self.prep(feature).view(n, c, 28, 28)
        return self.main(feature).view(-1, 2)
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

## Classifier

Please implement your own classifier module, which should also be a fully convolutional network with linear projection heads.

In [None]:
class Classifier(nn.Module):
    """Convolutional digit classifier for single-channel 28x28 images.

    ``forward`` returns unnormalised class logits of shape (N, label_size),
    as expected by ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.label_size = label_size = 10
        self.input_size = 28 * 28  # kept for interface parity; not used by the network
        self.hidden_size = 512  # kept for interface parity; not used by the network
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        ndf = 28
        # Spatial size: 28 -> 14 -> 7 -> 3 -> 1 (kernel 4, stride 2, padding 1).
        # No final Sigmoid: CrossEntropyLoss applies log-softmax itself. Bounding
        # the logits to (0, 1) would cap the top-class probability at
        # e / (e + 9) ~= 0.23 for 10 classes.
        self.main = nn.Sequential(
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, label_size, 4, 2, 1, bias=False),
        )
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

    def forward(self, img):
        """Return class logits of shape (N, label_size) for images of shape (N, 1, 28, 28)."""
        ##############################################################################
        #                  TODO: You need to complete the code here                  #
        ##############################################################################
        # YOUR CODE HERE
        return self.main(img).view(-1, self.label_size)
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################
        ##############################################################################
        #                              END OF YOUR CODE                              #
        ##############################################################################

## Training

We have implemented the skeleton of the train function. Please complete the missing loss computation.

In [None]:
def paint(d_list, g_list, c_list, save_path="./g.png"):
    """Plot the three training-loss curves on a new figure and save it.

    Args:
        d_list: discriminator losses, one entry per training iteration.
        g_list: generator losses, one entry per training iteration.
        c_list: classifier losses, one entry per training iteration.
        save_path: output image path (defaults to the previous ``./g.png``).
    """
    # Start a fresh figure so the curves are not drawn over whatever figure
    # happens to be active (e.g. an earlier imshow preview).
    plt.figure()
    plt.plot(d_list, label='train loss of discriminator')
    plt.plot(g_list, label='train loss of generator')
    plt.plot(c_list, label='train loss of classifier')
    plt.legend()
    # train() appends one value per mini-batch, so the x-axis counts
    # iterations, not samples.
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Training Curves')
    plt.savefig(save_path)

In [None]:
def train(n_epochs, generator, discriminator, classifier, train_loader, optimizer_g, optimizer_d, optimizer_c, device=torch.device('cuda'), save_interval=10):
    """Jointly train the conditional generator, discriminator and auxiliary classifier.

    Each iteration performs three updates in order:
      1. classifier: cross-entropy on real images and their labels;
      2. discriminator: real (x_real, y_real) pairs -> class 1, generated
         (x_fake, y_fake) pairs -> class 0, both conditioned on exact one-hots;
      3. generator: fool the updated discriminator (target class 1) and make
         the classifier predict the requested label.

    Losses are summed over each batch (``reduction='sum'``). The progress bar
    and the recorded curves report per-sample values.

    Every ``save_interval`` epochs, the generator and discriminator are
    checkpointed with ``save_model``, and a 10x10 sample grid is written to
    ``./gan/<epoch>/``. After training, the per-iteration curves are plotted
    with ``paint``.
    """
    generator.to(device)
    discriminator.to(device)
    classifier.to(device)
    # Loop-invariant: expects raw logits and integer class targets.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    D_LOSS = []
    G_LOSS = []
    C_LOSS = []
    for epoch in range(n_epochs):
        train_g_loss = train_d_loss = train_c_loss = 0
        n_seen = 0  # number of samples processed so far this epoch
        pbar = tqdm(total=len(train_loader.dataset))
        for x, y in train_loader:
            # compute loss
            batch_size = x.shape[0]
            n_seen += batch_size
            x_real = x.to(device)
            y_real = y.to(device)
            ##############################################################################
            #                  TODO: You need to complete the code here                  #
            ##############################################################################
            # YOUR CODE HERE
            # sample_images()/make_dataset() switch modules to eval mode, so
            # re-enable training mode on every iteration.
            generator.train()
            discriminator.train()
            classifier.train()

            real_target = torch.ones_like(y_real)
            fake_target = torch.zeros_like(y_real)

            z = torch.randn(batch_size, generator.latent_size).to(device)
            y_fake = y_real
            # Pass num_classes explicitly: otherwise F.one_hot sizes the encoding
            # from the largest label in the batch and breaks when digit 9 is absent.
            onehot_real = F.one_hot(y_real, num_classes=generator.label_size)
            onehot_fake = F.one_hot(y_fake, num_classes=generator.label_size)

            # 1) Classifier update on real data.
            c_loss = criterion(classifier(x_real), y_real)
            optimizer_c.zero_grad()
            c_loss.backward()
            optimizer_c.step()

            # 2) Discriminator update. Fakes get the same exact one-hot
            # conditioning as reals; soft classifier outputs would let the
            # discriminator separate real from fake by the label input alone.
            # x_fake is detached so this step does not backprop into the generator.
            x_fake = torch.sigmoid(generator(z, y_fake))
            d_loss_real = criterion(discriminator(x_real, onehot_real), real_target)
            d_loss_fake = criterion(discriminator(x_fake.detach(), onehot_fake), fake_target)
            d_loss = d_loss_real + d_loss_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # 3) Generator update against the freshly updated discriminator.
            # x_fake's graph is still intact because step 2 used a detached copy.
            g_loss_disc = criterion(discriminator(x_fake, onehot_fake), real_target)
            g_loss_class = criterion(classifier(x_fake), y_fake)
            g_loss = g_loss_class + g_loss_disc
            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()
            ##############################################################################
            #                              END OF YOUR CODE                              #
            ##############################################################################

            # Record per-sample losses so the smaller final batch does not dip the curves.
            D_LOSS.append(d_loss.item() / batch_size)
            G_LOSS.append(g_loss.item() / batch_size)
            C_LOSS.append(c_loss.item() / batch_size)

            train_g_loss += g_loss.item()
            train_d_loss += d_loss.item()
            train_c_loss += c_loss.item()

            pbar.update(batch_size)
            pbar.set_description('Train Epoch {}, Generator Loss: {:.6f}, Discriminator Loss: {:.6f}, Classifier Loss: {:.6f}'.format(
                epoch + 1, train_g_loss / n_seen, train_d_loss / n_seen, train_c_loss / n_seen))
        pbar.close()

        if (epoch + 1) % save_interval == 0:
            os.makedirs(f'./gan/{epoch + 1}', exist_ok=True)
            save_model(f'./gan/{epoch + 1}/gan.pth', generator, optimizer_g,
                       discriminator=discriminator, optimizer_d=optimizer_d)

            # sample and save images
            label = torch.arange(10).repeat(10).to(device)
            generator.sample_images(
                label, save=True, save_dir=f"./gan/{epoch + 1}/")
            save_model('./gan/gan_best.pth', generator, optimizer_g, discriminator, optimizer_d)
    paint(d_list=D_LOSS, g_list=G_LOSS, c_list=C_LOSS)

## Enjoy

Tune your hyperparameters and make your conditional DCGAN work. Good luck!

In [None]:
# Shuffled MNIST mini-batches of 128; the final partial batch is kept.
train_loader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    num_workers=8,
)

# The generator takes a 100-dimensional latent code.
g = Generator(100)
d = Discriminator()
c = Classifier()

# One Adam optimizer per model, all with the same learning rate.
optimizer_g, optimizer_d, optimizer_c = (
    torch.optim.Adam(net.parameters(), lr=0.0002) for net in (g, d, c)
)

train(30, g, d, c, train_loader, optimizer_g, optimizer_d, optimizer_c, device=device)

# Reload the last checkpoint written by train() and export generated digits.
checkpoint = load_model('./gan/gan_best.pth')
state_dict = checkpoint[0]
g.load_state_dict(state_dict)
g.make_dataset(n_samples_per_class=100)

## Evaluation

Make sure your code runs fine with the following cell!

In [None]:
# Run evaluate_cgen.py with the --gan flag in a shell (the `!` prefix is IPython syntax).
# NOTE(review): presumably it scores the images written by g.make_dataset() to
# ./gan/generated/ — confirm against evaluate_cgen.py.
!python evaluate_cgen.py --gan