In [None]:
%matplotlib inline

import os

# Generate images with condition labels
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import nn

from tqdm import trange
import numpy as np
import time as t

In [None]:
# Run on the GPU when one is available; later cells move models and tensors to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyperparameters passed to the DCGAN wrapper and the data loaders.
epochs = 30
batch_size = 200

# Seed PyTorch's RNG so that Restart & Run All reproduces the same noise samples,
# weight initialisation and data shuffling order.
SEED = 42
torch.manual_seed(SEED)

### Tips

1. If you are trying to debug the model, train the model for 1 epoch and visualize the generator and discriminator loss.
2. First try to train the GAN on a smaller dataset (e.g. only digits 0 and 1); if the results look reasonable, then train the model on the entire dataset.

Note: A single epoch of training should take 30-45 seconds on a GPU when `batch_size=200`. While tuning the `batch_size` and `epochs` may lead to better results, we do not expect you to tune these hyperparameters.

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST train / test splits, downloaded into ./data on first use.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Only the training loader is shuffled; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class Generator(torch.nn.Module):
    """DCGAN generator: maps latent noise to single-channel 28x28 images.

    Input is a batch of latent vectors shaped ``(N, latent_dim, 1, 1)``; output
    is ``(N, 1, 28, 28)`` with values in ``[-1, 1]`` (tanh).

    Args:
        latent_dim: Number of channels of the latent noise vector.
    """

    def __init__(self, latent_dim=100):

        super(Generator, self).__init__()

        self.net = nn.Sequential(
            # (N, latent_dim, 1, 1) -> (N, 256, 7, 7)
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # (N, 256, 7, 7) -> (N, 128, 14, 14)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # (N, 128, 14, 14) -> (N, 1, 28, 28)
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            # Squash to [-1, 1] so generated pixels share the range of images
            # normalised with mean 0.5 / std 0.5.
            nn.Tanh(),
        )

    def forward(self, x):
        """The forward function should return batch of images."""
        return self.net(x)


class Discriminator(torch.nn.Module):
    """DCGAN discriminator: scores single-channel 28x28 images as real or fake.

    Input is ``(N, 1, 28, 28)``; output is one raw logit per image shaped
    ``(N, 1, 1, 1)``. No sigmoid is applied, so pair it with a logits-based
    loss.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            # (N, 1, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 128, 7, 7) -> (N, 1, 1, 1): one logit per image
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0),
        )

    def forward(self, x):
        """The forward function should return the logits."""
        return self.net(x)

class DCGAN(object):
    """Training wrapper that pairs a Generator with a Discriminator.

    Holds both networks, the adversarial loss and one Adam optimizer per
    network, and exposes ``train`` (alternating discriminator / generator
    updates) and ``generate_img`` (sampling from the generator).

    Args:
        epochs: Number of passes over the training loader.
        batch_size: Nominal batch size; stored for reference. The training
            loop sizes labels and noise from each actual batch instead.
    """

    # Channels of the latent noise fed to the generator as (N, LATENT_DIM, 1, 1).
    LATENT_DIM = 100

    def __init__(self, epochs, batch_size):

        # Both networks live on the module-level `device`. They must define
        # parameters, otherwise the Adam constructors below raise.
        self.G = Generator().to(device)
        self.D = Discriminator().to(device)
        # The discriminator returns raw logits, so use the BCE variant that
        # applies the sigmoid internally.
        self.loss = nn.BCEWithLogitsLoss()

        self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))

        self.epochs = epochs
        self.batch_size = batch_size

        self.number_of_images = 10

    def _model_device(self):
        """Return the device the generator's parameters currently live on."""
        return next(self.G.parameters()).device

    def _sample_noise(self, n, dev):
        """Draw ``n`` standard-normal latent vectors shaped (n, LATENT_DIM, 1, 1)."""
        return torch.randn(n, self.LATENT_DIM, 1, 1, device=dev)

    def train(self, train_loader):
        """Alternate one discriminator and one generator update per batch.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` batches; the
                labels are ignored.

        Returns:
            ``(disc_loss, genr_loss)``: two lists of Python floats with one
            entry per training iteration.
        """
        disc_loss = []
        genr_loss = []

        dev = self._model_device()

        for epoch in trange(self.epochs):

            for images, _ in train_loader:

                images = images.to(dev)
                # Size everything from the real batch so a smaller final batch
                # does not mismatch the label tensors.
                n = images.size(0)
                real_labels = torch.ones(n, device=dev)
                fake_labels = torch.zeros(n, device=dev)

                # Step 1: Train discriminator
                # Same Gaussian prior as the generator step and generate_img callers.
                z = self._sample_noise(n, dev)

                # BCE loss on real images. reshape(-1) keeps a 1-D (n,) tensor
                # even when n == 1, unlike a dim-less squeeze.
                real_logits = self.D(images).reshape(-1)
                d_loss_real = self.loss(real_logits, real_labels)

                # BCE loss on fake images. detach() stops gradients flowing
                # into the generator, which is not updated in this step.
                fake_images = self.G(z).detach()
                fake_logits = self.D(fake_images).reshape(-1)
                d_loss_fake = self.loss(fake_logits, fake_labels)

                # Optimize discriminator
                d_loss = d_loss_real + d_loss_fake
                self.d_optimizer.zero_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Step 2: Train Generator
                z = self._sample_noise(n, dev)

                fake_images = self.G(z)
                fake_logits = self.D(fake_images).reshape(-1)
                # The generator is rewarded when its fakes are scored as real.
                g_loss = self.loss(fake_logits, real_labels)

                # Gradients this backward pass leaves on D are cleared by
                # d_optimizer.zero_grad() at the start of the next iteration.
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                disc_loss.append(d_loss.item())
                genr_loss.append(g_loss.item())

        return disc_loss, genr_loss

    def generate_img(self, z, number_of_images):
        """Generate images from latent vectors.

        Args:
            z: Latent batch accepted by the generator.
            number_of_images: Keep at most this many of the generated samples.

        Returns:
            List of NumPy arrays, each reshaped to ``(28, 28)``.
        """
        # Sampling needs no autograd graph.
        with torch.no_grad():
            samples = self.G(z)[:number_of_images].cpu().numpy()
        return [sample.reshape(28, 28) for sample in samples]

In [None]:
# Build the DCGAN wrapper using the module-level `epochs` and `batch_size` settings.
model = DCGAN(epochs, batch_size)

In [None]:
# Train on the MNIST training loader; keeps the per-iteration discriminator and
# generator losses for plotting.
disc_loss, genr_loss = model.train(train_loader)

### Plot the generated images 

In [None]:
# Sample 100 latent vectors and lay the generated digits out on a 10x10 grid.
z = torch.randn(100, 100, 1, 1).to(device)
images = model.generate_img(z, 100)

fig = plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(10, 10)
gs.update(wspace=0.05, hspace=0.05)

for idx in range(100):
    # Row-major placement: image idx goes to grid cell (idx // 10, idx % 10).
    i, j = divmod(idx, 10)
    sample = images[idx]
    ax = fig.add_subplot(gs[i, j])
    ax.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')

    ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

### Plot the discriminator and generator loss.

In [None]:
def plot_gan_losses(disc_loss, genr_loss):
    """Plot discriminator and generator loss curves side by side.

    Args:
        disc_loss: Sequence of discriminator loss values, one per iteration.
        genr_loss: Sequence of generator loss values, one per iteration.
    """
    fig = plt.figure(figsize=(20, 8))

    # (subplot position, panel title, series) for the left and right panels.
    panels = (
        (121, 'Discriminator Loss', disc_loss),
        (122, 'Generator Loss', genr_loss),
    )
    for position, title, losses in panels:
        ax = fig.add_subplot(position)
        ax.set_title(title, fontsize=16)
        ax.set_xlabel('Iteration', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.plot(losses)

    plt.show()

In [None]:
# Show the loss curves recorded during training.
plot_gan_losses(disc_loss, genr_loss)