# Hands-On Machine Learning

## Session 9: Generative Adversarial Networks
by Justus Schock, Christoph Haarburger, Laxmi Gupta

### Goals of this Session

In this session you will...
* learn how to implement a vanilla "Goodfellow-like" generative adversarial network
* train a GAN
* learn how to implement a convolutional GAN

We'll be working with the MNIST dataset, which you already know from previous sessions.

In [None]:
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import os
%matplotlib inline

In [None]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)
np.random.seed(0)

## Hyperparameters

In [None]:
latent_size = 100
hidden_size = 256
image_size = 784
num_epochs = 10

**Task:** Set up the necessary transforms that are applied to the input images. You've already done this in the previous sessions.

We'd like to transform our data to `Tensor`s and normalize it with a mean of 0.5 and a standard deviation of 0.5, which maps the pixel values from [0, 1] to [-1, 1].

After defining the transforms, we can initialize the dataset.

In [None]:
from torchvision import transforms
from torchvision.datasets import MNIST

# expanduser('~') also works on systems where the HOME variable is not set
dset_path = os.path.join(os.path.expanduser('~'), 'datasets')
# Make sure the dataset is actually available; download it if it is missing
try:
    torchvision.datasets.MNIST(root=dset_path,
                               download=False)
except RuntimeError:
    os.makedirs(dset_path, exist_ok=True)
    torchvision.datasets.MNIST(root=dset_path,
                               download=True)


transform = # your job
dataset = # your job, dataset is located at dset_path


Now we can set up the `DataLoader`:

In [None]:
batch_size = 128
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True, drop_last=True)

**Task:** Visualize some samples of a single batch using the `plot_batch()` helper function.

In [None]:
from utils import plot_batch



## Build Discriminator

**Task:** Implement the discriminator as a sequential model of three `nn.Linear` layers, with `nn.ReLU` activations after the first two and an `nn.Sigmoid` activation after the last one, so that the network outputs a single probability per image.

We will feed the images into the network as flattened arrays.
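
A minimal sketch of such a network, assuming the layer widths from the hyperparameter cell (`image_size = 784`, `hidden_size = 256`). The name `discriminator_sketch` is hypothetical and only used here to avoid clashing with your own solution.

```python
import torch
import torch.nn as nn

image_size = 784    # flattened 28x28 image
hidden_size = 256

# Three linear layers; ReLU in between, sigmoid on the single output unit
discriminator_sketch = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
)

# A batch of 4 random "images" yields 4 probabilities of shape (4, 1)
x = torch.randn(4, image_size)
p = discriminator_sketch(x)
print(p.shape)
```

The output shape `(batch, 1)` matches the label tensors used with `nn.BCELoss` later on.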

In [None]:
discriminator = nn.Sequential(
)


## Build Generator

**Task:** Implement the Generator as a sequential model of two `nn.Linear` layers with `nn.ReLU()` activation and a final `nn.Linear` layer with `nn.Tanh` activation that maps from a random vector to the original flattened image dimension.
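
A minimal sketch of such a generator, again assuming the sizes from the hyperparameter cell (`latent_size = 100`, `hidden_size = 256`, `image_size = 784`). The name `generator_sketch` is hypothetical.

```python
import torch
import torch.nn as nn

latent_size = 100
hidden_size = 256
image_size = 784

# Two hidden linear layers with ReLU, then a linear layer with tanh output
generator_sketch = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),   # output in [-1, 1], matching the normalized images
)

z = torch.randn(4, latent_size)   # random latent vectors
fake = generator_sketch(z)
print(fake.shape)
```

The `tanh` output range of [-1, 1] is the reason the real images are normalized to the same range.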

In [None]:
generator = nn.Sequential(
)


### General Setup for Training

With the following line we can dynamically determine whether to run our model and optimization on the CPU or the GPU.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**Task:** Send both `generator` and `discriminator` to the `device`. You can do that simply by calling `{mymodule}.to(device)`

**Task:** Now we can initialize the optimizers for both generator and discriminator. We'll use `Adam` with a learning rate of `0.0005` here.
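
The two tasks above could be sketched as follows. The tiny `nn.Linear` modules are stand-ins (an assumption) so that the snippet runs on its own; in the notebook you would use your own `generator` and `discriminator`. The optimizer names `d_optimizer` and `g_optimizer` are the ones the training loop expects.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in modules, only here to make the sketch self-contained
generator = nn.Linear(100, 784)
discriminator = nn.Linear(784, 1)

# nn.Module.to moves the parameters in place
generator.to(device)
discriminator.to(device)

# One Adam optimizer per network, each seeing only its own parameters
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
```

Keeping two separate optimizers is what allows the loop to update only one network per step.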

Now we only need to initialize the binary cross-entropy (BCE) loss before getting started with the training loop 😊

For a label y and a discriminator output D(x), the loss is `BCE_Loss` = - y * log(D(x)) - (1-y) * log(1 - D(x)). When it is computed on real images, the second term is always zero since real_labels == 1.
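
To make the formula concrete, here is a small check that `nn.BCELoss` with all-ones labels reduces to the mean of -log(D(x)). The discriminator outputs in `d_out` are made-up numbers for illustration.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs D(x) for three real images
d_out = torch.tensor([[0.9], [0.6], [0.2]])
real_labels = torch.ones_like(d_out)

# Library loss (default reduction is the mean over the batch)
loss = criterion(d_out, real_labels)

# Manual version: only the first term survives because y == 1
manual = -torch.log(d_out).mean()

print(loss.item(), manual.item())
```

For fake images with labels of 0, the same reasoning leaves only the -log(1 - D(G(z))) term.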

In [None]:
criterion = nn.BCELoss()

### Training Loop
Finally, we can write the training loop and look at actual results. Since this is quite a bit of code, please perform the implementation in the following steps:

**Task:** Train your GAN and watch your fake images get better and better. It might take up to 20 minutes to produce good results.
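
The generator update in the loop trains G to maximize log(D(G(z))) rather than to minimize log(1 - D(G(z))). The following sketch illustrates why: early in training D(G(z)) is close to 0, and there the second objective gives a much smaller gradient. The logit value of -5 is an arbitrary choice for illustration.

```python
import torch

# Discriminator logit on a fake sample; very negative means D(G(z)) is near 0
logit_sat = torch.tensor(-5.0, requires_grad=True)
logit_ns = torch.tensor(-5.0, requires_grad=True)

# Saturating objective: minimize log(1 - D(G(z)))
torch.log(1 - torch.sigmoid(logit_sat)).backward()

# Non-saturating objective: minimize -log(D(G(z)))
(-torch.log(torch.sigmoid(logit_ns))).backward()

grad_sat = logit_sat.grad.abs().item()   # equals sigmoid(-5), a tiny value
grad_ns = logit_ns.grad.abs().item()     # equals 1 - sigmoid(-5), close to 1
print(grad_sat, grad_ns)
```

Using `criterion(outputs, real_labels)` for the generator is exactly the non-saturating variant, since BCE with labels of 1 is -log(D(G(z))).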

In [None]:
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape((images.size(0), -1)).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCELoss using real images
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        
        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            # Report 1-based epoch and step numbers
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28).cpu().detach().numpy()
    plot_batch(fake_images)

Vanilla GANs as implemented above are difficult to train, in part because of the large number of weights in their fully connected layers. DCGAN aims to overcome this by using convolutional instead of fully connected layers.

In the next cells we will implement the same training loop and setup as above, but using a DCGAN.

## Build DCGAN

**Task:** Implement the generator of a DCGAN. To allow you to train the DCGAN also with a different dataset later, we change the input resolution from `28x28` as before to `64x64`.

### Generator
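
A possible sketch of a DCGAN-style generator for single-channel `64x64` images. The layer layout (transposed convolutions with kernel size 4, batch normalization, ReLU, final tanh) follows the commonly used DCGAN recipe and is an assumption, not something prescribed by this notebook. The names `ngf` and `g_c_sketch` are hypothetical.

```python
import torch
import torch.nn as nn

latent_size = 100
ngf = 64   # hypothetical base number of feature maps

g_c_sketch = nn.Sequential(
    # latent vector (N, 100, 1, 1) -> (N, 512, 4, 4)
    nn.ConvTranspose2d(latent_size, ngf * 8, 4, 1, 0, bias=False),
    nn.BatchNorm2d(ngf * 8),
    nn.ReLU(),
    # -> (N, 256, 8, 8)
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf * 4),
    nn.ReLU(),
    # -> (N, 128, 16, 16)
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf * 2),
    nn.ReLU(),
    # -> (N, 64, 32, 32)
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf),
    nn.ReLU(),
    # -> (N, 1, 64, 64), values in [-1, 1]
    nn.ConvTranspose2d(ngf, 1, 4, 2, 1, bias=False),
    nn.Tanh(),
)

# The latent vector now needs two trailing spatial dimensions of size 1
z = torch.randn(2, latent_size, 1, 1)
out = g_c_sketch(z)
print(out.shape)
```

Each stride-2 transposed convolution with kernel size 4 and padding 1 doubles the spatial resolution, which is how the network gets from `4x4` to `64x64`.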

In [None]:
g_c = torch.nn.Sequential(
)


### Discriminator

**Task:** Implement the discriminator of a DCGAN. To allow you to train the DCGAN also with a different dataset later, we change the input resolution from `28x28` as before to `64x64`.
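
A possible sketch of a matching discriminator for single-channel `64x64` inputs. As with the generator, the layout (strided convolutions, batch normalization, `LeakyReLU` with slope 0.2) follows the commonly used DCGAN recipe and is an assumption. The names `ndf` and `d_c_sketch` are hypothetical.

```python
import torch
import torch.nn as nn

ndf = 64   # hypothetical base number of feature maps

d_c_sketch = nn.Sequential(
    # (N, 1, 64, 64) -> (N, 64, 32, 32)
    nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
    nn.LeakyReLU(0.2),
    # -> (N, 128, 16, 16)
    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf * 2),
    nn.LeakyReLU(0.2),
    # -> (N, 256, 8, 8)
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf * 4),
    nn.LeakyReLU(0.2),
    # -> (N, 512, 4, 4)
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf * 8),
    nn.LeakyReLU(0.2),
    # -> (N, 1, 1, 1), squashed to a probability
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
    nn.Sigmoid(),
    # -> (N, 1), so the output matches the label shape used with BCELoss
    nn.Flatten(),
)

x = torch.randn(2, 1, 64, 64)
p = d_c_sketch(x)
print(p.shape)
```

The final `nn.Flatten()` is a convenience so that the labels from the vanilla GAN loop, which have shape `(batch_size, 1)`, can be reused unchanged.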

In [None]:
d_c = torch.nn.Sequential(
)


**Task:** Once again, we have to send the networks to our device and initialize the optimizers. Use the same learning rate as above.

In [None]:
criterion = nn.BCELoss()

# send to device

# optimizer



### Transforms

**Task:** To fit the MNIST images to the new expected input dimension of `64x64`, we need to change our `transforms` accordingly. Find the right transform to resize all images to `64x64` and initialize the dataset and data loader with the new transforms.

In [None]:
transform_c = transforms.Compose([
                    
])
dataset = torchvision.datasets.MNIST(root=dset_path,
                                     train=True,
                                     transform=transform_c,
                                     download=False)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True, drop_last=True)


### Training Loop Reloaded

Now we can build the training loop for the DCGAN. But don't worry, we can mostly copy our code from above. Pay attention to the changed tensor dimensions, though.

**Task:** Implement the training loop from above and visualize a batch of fake data after every epoch. How does the training differ from the vanilla GAN training?
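
The sketch below shows a single DCGAN iteration to highlight the dimension changes: images stay 4-D instead of being flattened, and the noise has shape `(batch_size, latent_size, 1, 1)`. The one-layer `g_c` and `d_c` here are deliberately tiny stand-ins (an assumption) so that the snippet runs quickly; they are not real DCGAN networks. The `.detach()` in the discriminator step is an optional choice that avoids backpropagating into the generator there.

```python
import torch
import torch.nn as nn

batch_size, latent_size = 8, 100

# Tiny stand-ins for the real networks, only to demonstrate tensor shapes
g_c = nn.Sequential(nn.ConvTranspose2d(latent_size, 1, kernel_size=64), nn.Tanh())
d_c = nn.Sequential(nn.Conv2d(1, 1, kernel_size=64), nn.Sigmoid(), nn.Flatten())

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(d_c.parameters(), lr=0.0005)
g_optimizer = torch.optim.Adam(g_c.parameters(), lr=0.0005)

# Stand-in for a real batch: 4-D, no reshape to (batch_size, -1)
images = torch.rand(batch_size, 1, 64, 64) * 2 - 1
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Discriminator step
z = torch.randn(batch_size, latent_size, 1, 1)
fake_images = g_c(z)
d_loss = (criterion(d_c(images), real_labels)
          + criterion(d_c(fake_images.detach()), fake_labels))
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

# Generator step
z = torch.randn(batch_size, latent_size, 1, 1)
fake_images = g_c(z)
g_loss = criterion(d_c(fake_images), real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()

print(fake_images.shape, d_loss.item(), g_loss.item())
```

Because the fake images are already shaped `(batch_size, 1, 64, 64)`, they can be converted to NumPy for plotting without the reshape used in the vanilla loop.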

## Bonus Task

Congratulations, you've successfully trained a DCGAN that can generate MNIST digits. If you are brave, it's time to generate some more complex images.

**Task:** Have a look at the datasets that are included in `torchvision` and train a DCGAN with another dataset.

### Feedback

That's it, we're done 👏🏼🍻

If you have any suggestions on how we could improve this session, please let us know in the following cell. What did you particularly like or dislike? Was any content missing?