# Generative Adversarial Networks (GANs)

In [1]:
import torch

from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change!

<torch._C.Generator at 0x23bb4de0c18>

In [2]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots the images in a uniform grid.

    Parameters:
        image_tensor: tensor holding a batch of images. It is reshaped to
            (-1, *size), so flattened images (e.g. generator output) are accepted.
        num_images: how many images from the front of the batch to draw.
        size: (channels, height, width) of a single image.
        nrow: number of images placed in each row of the grid.
    '''
    # Detach from the autograd graph and move to CPU before handing to matplotlib.
    # reshape (rather than view) also accepts non-contiguous tensors.
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    # make_grid returns (C, H, W); matplotlib's imshow expects (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()
    

## Generator Network

In [3]:
def gen_block(dim_in, dim_out):
    '''
    Build one generator stage: Linear(dim_in -> dim_out), then BatchNorm1d over
    the dim_out features, then an in-place ReLU.
    '''
    projection = nn.Linear(dim_in, dim_out)
    normalization = nn.BatchNorm1d(dim_out)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(projection, normalization, activation)

In [4]:
class GeneratorNetwork(nn.Module):
    '''
    MLP generator mapping noise vectors to flattened images.

    Four gen_block stages widen the features noise_dim -> h -> 2h -> 4h -> 8h
    (h = hidden_dim). A final Linear layer then projects to im_dim, and a
    Sigmoid squashes every output value into [0, 1].

    Parameters:
        noise_dim: length of each input noise vector.
        im_dim: number of values per flattened output image (784 = 28 * 28).
        hidden_dim: width of the first hidden stage.
    '''

    def __init__(self, noise_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()
        # Feature width at each stage boundary.
        widths = [noise_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        stages = [gen_block(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]
        # A single nn.Sequential named `gen` keeps indices such as gen[0][0]
        # (the first Linear layer) stable for code that inspects the weights.
        self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], im_dim), nn.Sigmoid())

    def forward(self, noise):
        '''Run a batch of noise vectors (presumably (n, noise_dim), as get_noise produces) through the network.'''
        return self.gen(noise)

    # Needed for grading
    def get_gen(self):
        '''Return the underlying nn.Sequential.'''
        return self.gen

In [5]:
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Sample an (n_samples, z_dim) tensor of standard-normal noise on `device`.
    '''
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)

In [6]:
def get_discriminator_block(input_dim, output_dim):
    '''
    Build one discriminator stage: Linear(input_dim -> output_dim) followed by
    LeakyReLU with negative slope 0.2.
    '''
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(negative_slope=0.2),
    ]
    return nn.Sequential(*layers)

In [7]:
class Discriminator(nn.Module):
    '''
    MLP discriminator scoring flattened images with a single raw logit each.

    Three get_discriminator_block stages narrow the features
    im_dim -> 4h -> 2h -> h (h = hidden_dim). A final Linear layer then maps
    to one output. No sigmoid is applied, so the output is a logit.
    '''

    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        blocks = [get_discriminator_block(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]
        # The last Linear collapses each example to a single value.
        self.disc = nn.Sequential(*blocks, nn.Linear(hidden_dim, 1))

    def forward(self, image):
        '''Return one logit per input row.'''
        return self.disc(image)

    # Needed for grading
    def get_disc(self):
        '''Return the underlying nn.Sequential.'''
        return self.disc

In [8]:
# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits;
# the Discriminator's last layer is a plain nn.Linear with no activation.
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200       # passes over the dataset in the training loop
z_dim = 64           # length of each noise vector fed to the generator
display_step = 500   # divisor used for the running-mean losses in the training loop
batch_size = 128
lr = 0.00001         # Adam learning rate, shared by generator and discriminator

# Load MNIST dataset as tensors
# ToTensor converts each image to a float tensor in [0, 1], the same range as the
# generator's final Sigmoid. download=True fetches the data into '.' if missing.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)   # reshuffled at the start of every epoch

### DO NOT EDIT ###
# Uncomment the 'cuda' line (and drop the 'cpu' one) to train on a GPU.
#device = 'cuda'
device = 'cpu'

In [9]:
# Instantiate each network on the target device with its own Adam optimizer.
# Construction order (generator first) is kept so the seeded RNG stream is unchanged.
gen = GeneratorNetwork(noise_dim=z_dim)
gen = gen.to(device)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)

disc = Discriminator()
disc = disc.to(device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)

In [10]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    '''
    Discriminator loss: the average of the BCE on generated images (target 0)
    and the BCE on the real batch (target 1).

    The generated images are detached, so backpropagating this loss does not
    produce gradients for the generator.
    '''
    generated = gen(get_noise(num_images, z_dim, device=device)).detach()

    fake_logits = disc(generated)
    fake_term = criterion(fake_logits, torch.zeros_like(fake_logits))

    real_logits = disc(real)
    real_term = criterion(real_logits, torch.ones_like(real_logits))

    return 0.5 * (fake_term + real_term)

In [11]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    '''
    Generator loss: BCE between the discriminator's logits on freshly generated
    images and an all-ones target, i.e. how far the generator is from fooling
    the discriminator.
    '''
    z = get_noise(num_images, z_dim, device=device)
    logits = disc(gen(z))
    target = torch.ones_like(logits)
    return criterion(logits, target)

In [None]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True # Whether the generator should be tested
gen_loss = False
error = False
for epoch in range(n_epochs):

    print('Epoch: ', epoch+1, '/',n_epochs)

    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Update gradients. The fakes are detached inside get_disc_loss and a new
        # graph is built every step, so there is no need to retain this graph.
        disc_loss.backward()

        # Update optimizer
        disc_opt.step()

        # For testing purposes, to keep track of the generator weights
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Sanity checks: the first generator layer must have actually changed.
        if test_generator:
            try:
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                assert torch.any(gen.gen[0][0].weight.detach().clone() != old_generator_weights)
            except AssertionError:
                error = True
                print("Runtime tests have failed")

        # Running means over the current display window.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Report and reset the running means once per window; without the reset
        # they would keep growing and stop being means.
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            mean_generator_loss = 0
            mean_discriminator_loss = 0

        cur_step += 1
        

Epoch:  1 / 200


  0%|          | 0/469 [00:00<?, ?it/s]

Epoch:  2 / 200


  0%|          | 0/469 [00:00<?, ?it/s]