<div class="alert alert-block alert-info">
<b>Number of points for this notebook:</b> 2
<br>
<b>Deadline:</b> May 23, 2020 (Saturday) 23:00
</div>

# Exercise 11.2. Generative adversarial networks (GANs). WGAN-GP: Wasserstein GAN with gradient penalty

The goal of this exercise is to get familiar with WGAN-GP: one of the most popular versions of GANs, which is relatively easy to train.

The algorithm was introduced in the paper [Improved Training of Wasserstein GANs](https://arxiv.org/pdf/1704.00028.pdf).

In [0]:
skip_training = True  # Set this flag to True before validation and submission

In [0]:
# During evaluation, this cell sets skip_training to True
# skip_training = True

In [0]:
import os
import numpy as np
import matplotlib.pyplot as plt
from IPython import display

import torch
import torchvision
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
import torchvision.utils as utils

import tools
import tests

In [6]:
# When running on your own computer, you can specify the data directory by:
# data_dir = tools.select_data_dir('/your/local/data/directory')
data_dir = tools.select_data_dir()

The data directory is ../data


In [0]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

In [0]:
if skip_training:
    # The models are always evaluated on CPU
    device = torch.device("cpu")

# Data

We will use MNIST data in this exercise. Note that we re-scale images so that the pixel intensities are in the range [-1, 1].

In [0]:
transform = transforms.Compose([
    transforms.ToTensor(),  # Transform to tensor
    transforms.Normalize((0.5,), (0.5,))  # Scale to [-1, 1]
])

trainset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

# Wasserstein GAN (WGAN)

The WGAN value function is constructed as
$$
  \min_G \max_{D \in \mathcal{D}} E_{x \sim P_r}[D(x)] - E_{\tilde x \sim P_g}[D(\tilde x)]
$$
where
* the discriminator $D$ (called the critic in WGAN) is constrained to lie in the set $\mathcal{D}$ of 1-Lipschitz functions
* $P_r$ is the data distribution
* $P_g$ is the model distribution. Samples from the model distribution are produced as follows:
\begin{align}
z &\sim N(0, I)
\\
\tilde x &= G(z)
\end{align}
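
In practice, both expectations are replaced by minibatch averages. The sketch below estimates the value function with a hypothetical toy generator and critic (single linear layers on flattened 28×28 images, chosen only for illustration; they are not the models used in this notebook); `nz = 10` matches the latent size used later.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nz, batch_size = 10, 64

# Hypothetical toy models, for illustration only.
toy_G = nn.Sequential(nn.Linear(nz, 28 * 28), nn.Tanh())  # z -> flattened image in [-1, 1]
toy_D = nn.Linear(28 * 28, 1)                              # unconstrained real-valued score

x_real = torch.rand(batch_size, 28 * 28) * 2 - 1  # stand-in for real images in [-1, 1]
z = torch.randn(batch_size, nz)                   # z ~ N(0, I)
x_fake = toy_G(z)                                 # x~ = G(z)

# Minibatch estimate of E_{x ~ P_r}[D(x)] - E_{x~ ~ P_g}[D(x~)]
value = toy_D(x_real).mean() - toy_D(x_fake).mean()
print(value.item())
```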

## Generator

Implement the generator in the cell below. We recommend using the same generator architecture as in Exercise 11.1.

In [0]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        """WGAN generator.
        
        Args:
          nz:  Number of elements in the latent code.
          ngf: Base size (number of channels) of the generator layers.
          nc:  Number of channels in the generated images.
        """
        super(Generator, self).__init__()
        # YOUR CODE HERE
        self.conv1 = nn.ConvTranspose2d(in_channels=nz, out_channels=4*ngf, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv2 = nn.ConvTranspose2d(in_channels=4*ngf, out_channels=2*ngf, kernel_size=4, stride=2, bias=False)
        self.conv3 = nn.ConvTranspose2d(in_channels=2*ngf, out_channels=ngf, kernel_size=4, stride=2, bias=False)
        self.conv4 = nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False)
        
        self.bn1 = nn.BatchNorm2d(4*ngf)
        self.bn2 = nn.BatchNorm2d(2*ngf)
        self.bn3 = nn.BatchNorm2d(ngf)

    def forward(self, z, verbose=False):
        """Generate images by transforming the given noise tensor.
        
        Args:
          z of shape (batch_size, nz, 1, 1): Tensor of noise samples. We use the last two singleton dimensions
              so that we can feed z to the generator without reshaping.
          verbose (bool): Whether to print intermediate shapes (True) or not (False).
        
        Returns:
          out of shape (batch_size, nc, 28, 28): Generated images.
        """
        # YOUR CODE HERE
        z = F.relu(self.bn1(self.conv1(z)))  # b,nz,1,1 -> b,4*ngf,2,2
        z = F.relu(self.bn2(self.conv2(z)))  # b,4*ngf,2,2 -> b,2*ngf,6,6
        z = F.relu(self.bn3(self.conv3(z)))  # b,2*ngf,6,6 -> b,ngf,14,14
        z = torch.tanh(self.conv4(z))        # b,ngf,14,14 -> b,nc,28,28, values in [-1, 1]
        return z
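
The shape comments in `forward` follow from the output-size formula of `nn.ConvTranspose2d` with the default `dilation=1` and `output_padding=0`: H_out = (H_in - 1) * stride - 2 * padding + kernel_size. A small helper (hypothetical, not part of the exercise code) reproduces the 1 → 2 → 6 → 14 → 28 progression of the layers above:

```python
def conv_transpose_out(h_in, kernel_size, stride=1, padding=0):
    """Output height of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (h_in - 1) * stride - 2 * padding + kernel_size

# (kernel_size, stride, padding) of conv1..conv4 in the Generator above
layers = [(4, 2, 1), (4, 2, 0), (4, 2, 0), (4, 2, 1)]
h, sizes = 1, []
for k, s, p in layers:
    h = conv_transpose_out(h, k, s, p)
    sizes.append(h)
print(sizes)  # [2, 6, 14, 28]
```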

In [13]:
def test_Generator_shapes():
    batch_size = 32
    nz = 10
    netG = Generator(nz, ngf=64, nc=1)

    noise = torch.randn(batch_size, nz, 1, 1)
    out = netG(noise, verbose=True)

    assert out.shape == torch.Size([batch_size, 1, 28, 28]), f"Bad out.shape: {out.shape}"
    print('Success')

test_Generator_shapes()

Success


### Loss for training the generator

The generator is trained to minimize the relevant part of the value function using a fixed critic $D$:
$$
  \min_G - E_{\tilde{x} \sim P_g}[D(\tilde x)]
$$

In [0]:
def generator_loss(D, fake_images):
    """Loss computed to train the WGAN generator.

    Args:
      D: The critic whose forward function takes inputs of shape (batch_size, nc, 28, 28)
         and produces outputs of shape (batch_size, 1).
      fake_images of shape (batch_size, nc, 28, 28): Fake images produced by the generator.

    Returns:
      loss: The relevant part of the WGAN value function.
    """
    # YOUR CODE HERE
    loss = -torch.mean(D(fake_images))  # -E[D(G(z))], estimated by the batch mean
    return loss
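
As a worked example, suppose a hypothetical critic scores three generated images as 1.0, 2.0 and 3.0. The generator loss is the negated mean, -2.0, and its gradient with respect to each score is -1/3, so a gradient-descent step on the generator moves the critic's scores on fake images up:

```python
import torch

# Hypothetical critic outputs D(x~) for three fake images.
critic_scores = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = -critic_scores.mean()
loss.backward()
print(loss.item())         # -2.0
print(critic_scores.grad)  # each score gets gradient -1/3
```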

In [0]:
# This cell tests generator_loss()

## Critic

In WGAN, the discriminator is called a critic because it is not trained to classify images as real or fake; it outputs an unbounded real-valued score instead.

Implement the WGAN critic in the cell below. You can use almost the same architecture as the discriminator in Exercise 11.1. The difference is that the output layer does not need a `sigmoid` nonlinearity, because the output of the critic does not have to lie between 0 and 1.

In [0]:
class Critic(nn.Module):
    def __init__(self, nc=1, ndf=64):
        """
        Args:
          nc:  Number of channels in the images.
          ndf: Base size (number of channels) of the critic layers.
        """
        # YOUR CODE HERE
        super(Critic, self).__init__()
        self.conv1 = nn.Conv2d(nc, ndf, 4, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(ndf, 2*ndf, 4, stride=2, bias=False)
        self.conv3 = nn.Conv2d(2*ndf, 4*ndf, 4, stride=2, bias=False)
        self.conv4 = nn.Conv2d(4*ndf, 1, 4, stride=2, padding=1, bias=False)  # one score per image
        
        self.bn1 = nn.BatchNorm2d(ndf)
        self.bn2 = nn.BatchNorm2d(2*ndf)
        self.bn3 = nn.BatchNorm2d(4*ndf)
        
        self.l_relu = nn.LeakyReLU(0.2)

    def forward(self, x, verbose=False):
        """
        Args:
          x of shape (batch_size, 1, 28, 28): Images to be evaluated.
        
        Returns:
          out of shape (batch_size,): Critic outputs for images x.
        """
        # YOUR CODE HERE
        x = self.l_relu(self.bn1(self.conv1(x)))  # b,nc,28,28 -> b,ndf,14,14
        x = self.l_relu(self.bn2(self.conv2(x)))  # b,ndf,14,14 -> b,2*ndf,6,6
        x = self.l_relu(self.bn3(self.conv3(x)))  # b,2*ndf,6,6 -> b,4*ndf,2,2
        x = self.conv4(x)  # b,4*ndf,2,2 -> b,1,1,1 (no sigmoid: the score is unbounded)
        return x.view(-1)  # b; unlike squeeze(), keeps the batch dimension when b == 1
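
The critic's shape comments follow from the output-size formula of `nn.Conv2d` with `dilation=1`: H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1. A hypothetical helper traces the 28 → 14 → 6 → 2 → 1 progression, which mirrors the generator in reverse:

```python
def conv_out(h_in, kernel_size, stride=1, padding=0):
    """Output height of nn.Conv2d (dilation=1)."""
    return (h_in + 2 * padding - kernel_size) // stride + 1

# (kernel_size, stride, padding) of conv1..conv4 in the Critic above
layers = [(4, 2, 1), (4, 2, 0), (4, 2, 0), (4, 2, 1)]
h, sizes = 28, []
for k, s, p in layers:
    h = conv_out(h, k, s, p)
    sizes.append(h)
print(sizes)  # [14, 6, 2, 1]
```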

In [17]:
def test_Critic_shapes():
    nz = 10  # size of the latent z vector
    netD = Critic(nc=1, ndf=64)

    batch_size = 32
    images = torch.ones(batch_size, 1, 28, 28)
    out = netD(images, verbose=True)
    assert out.shape == torch.Size([batch_size]), f"Bad out.shape: {out.shape}"
    print('Success')

test_Critic_shapes()

Success
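
The WGAN-GP paper advises against batch normalization in the critic: the gradient penalty is defined per input, while batch normalization makes each output depend on the other samples in the batch. The paper suggests layer normalization as a drop-in replacement. The sketch below is a hypothetical variant (`CriticNoBN` is not part of the exercise) that keeps the convolution shapes of `Critic` but uses `nn.GroupNorm` with a single group, which normalizes each sample over all of its channels and positions, like layer normalization:

```python
import torch
import torch.nn as nn

class CriticNoBN(nn.Module):
    """Hypothetical critic: same convolutions as Critic, per-sample normalization."""
    def __init__(self, nc=1, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, stride=2, padding=1),     # 28 -> 14
            nn.GroupNorm(1, ndf),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, 2 * ndf, 4, stride=2),           # 14 -> 6
            nn.GroupNorm(1, 2 * ndf),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2 * ndf, 4 * ndf, 4, stride=2),       # 6 -> 2
            nn.GroupNorm(1, 4 * ndf),
            nn.LeakyReLU(0.2),
            nn.Conv2d(4 * ndf, 1, 4, stride=2, padding=1),  # 2 -> 1
        )

    def forward(self, x):
        return self.net(x).view(-1)  # shape (batch_size,)

critic = CriticNoBN(nc=1, ndf=64)
x = torch.randn(8, 1, 28, 28)
scores = critic(x)
print(scores.shape)  # torch.Size([8])
# A sample's score does not depend on the rest of the batch.
print(torch.allclose(critic(x[:1]), scores[:1], atol=1e-4))
```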


### Loss for training the WGAN critic

Recall the value function of WGAN:
$$
  \min_G \max_{D \in \mathcal{D}} E_{x \sim P_r}[D(x)] - E_{\tilde x \sim P_g}[D(\tilde x)]
$$
To train the critic, we maximize the value function with respect to $D$, which is equivalent to minimizing its negation:
$$
  \min_{D \in \mathcal{D}} - E_{x \sim P_r}[D(x)] + E_{\tilde x \sim P_g}[D(\tilde x)]
$$
Implement this loss function in the cell below, *assuming no constraints on D*.

In [0]:
def critic_loss(critic, real_images, fake_images):
    """
    Args:
      critic: The critic.
      real_images of shape (batch_size, nc, 28, 28): Real images.
      fake_images of shape (batch_size, nc, 28, 28): Fake images.

    Returns:
      loss (scalar tensor): Loss for training the WGAN critic.
    """
    # YOUR CODE HERE
    d_fake = critic(fake_images)
    d_real = critic(real_images)
    # Negated value function: pushes scores of real images up and scores of fake images down
    loss = torch.mean(d_fake) - torch.mean(d_real)
    return loss
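
Because `critic_loss` places no constraint on $D$, multiplying the critic's outputs by a constant $c$ multiplies the loss by $c$ as well. Whenever the critic already scores real images above fake ones (negative loss), the loss can therefore be pushed towards $-\infty$ just by rescaling. A sketch with hypothetical fixed critic scores:

```python
import torch

# Hypothetical critic scores on a minibatch of real and fake images.
d_real = torch.tensor([0.9, 1.1, 1.0])
d_fake = torch.tensor([-0.2, 0.1, -0.5])

losses = []
for c in [1.0, 10.0, 100.0]:
    # critic_loss for the rescaled critic c * D: E[c * D(fake)] - E[c * D(real)]
    loss = (c * d_fake).mean() - (c * d_real).mean()
    losses.append(loss.item())
print(losses)  # approximately [-1.2, -12.0, -120.0]
```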

In [0]:
# This cell tests critic_loss()

Without constraints on $D$, the WGAN value function can be made arbitrarily large, for example by multiplying $D$ by a large constant. WGAN-GP enforces the Lipschitz constraint softly, by penalizing the norm of the gradient of $D$. The penalty is computed at random points between real and generated images using the following procedure:
* Given a real image $x$ and a fake image $\tilde x$, draw a random number $\epsilon \sim U[0,1]$ (a separate one for each pair of images)
* $\hat{x} \leftarrow \epsilon x + (1 - \epsilon) \tilde x$
* Compute the gradient penalty $(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2$, where $\nabla_{\hat{x}} D(\hat{x})$ is the gradient of $D$ computed at $\hat{x}$, and average it over the images in the minibatch.
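
The procedure can be checked on a hypothetical toy critic $D(x) = \frac{1}{2}\|x\|_2^2$, whose gradient is $x$ itself, so the penalty must equal the mean of $(\|\hat x\|_2 - 1)^2$. The sketch draws one $\epsilon$ per image, takes the gradient norm per image (over all its pixels), and averages over the batch:

```python
import torch

torch.manual_seed(0)
batch_size = 4

def toy_critic(x):
    # Hypothetical critic D(x) = 0.5 * ||x||^2 per image; its gradient is x itself.
    return 0.5 * x.pow(2).flatten(1).sum(dim=1)

real = torch.rand(batch_size, 1, 28, 28)
fake = torch.rand(batch_size, 1, 28, 28)

# One epsilon per image, broadcast over (nc, 28, 28).
eps = torch.rand(batch_size, 1, 1, 1)
x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)

# Summing the scores is fine here: each score depends only on its own image,
# so the gradient of the sum gives every image its own gradient.
grad = torch.autograd.grad(toy_critic(x_hat).sum(), x_hat)[0]  # (batch_size, 1, 28, 28)
grad_norm = grad.flatten(1).norm(2, dim=1)                     # one L2 norm per image
penalty = ((grad_norm - 1) ** 2).mean()

# Closed form for this toy critic.
expected = ((x_hat.detach().flatten(1).norm(2, dim=1) - 1) ** 2).mean()
print(penalty.item(), expected.item())
```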

Your task is to implement the gradient penalty in the cell below. The gradient $\nabla D$ must be computed in a way that lets us differentiate through it when computing the derivatives with respect to the parameters of the critic. This can be achieved with [torch.autograd.grad](https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad) called with `create_graph=True`, which builds a computational graph of the gradient computation.
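
To see why `create_graph=True` matters, take a hypothetical linear critic $D(x) = w^\top x$ with $w = (3, 4)$. Its gradient with respect to $x$ is $w$ everywhere, so the penalty is $(\|w\|_2 - 1)^2 = (5 - 1)^2 = 16$, and its derivative with respect to $w$ is $2(\|w\|_2 - 1)\, w / \|w\|_2 = (4.8, 6.4)$. Autograd recovers this second-order derivative only because the gradient was computed with a graph:

```python
import torch

# Hypothetical linear critic D(x) = w . x; its gradient w.r.t. x is w everywhere.
w = torch.tensor([3.0, 4.0], requires_grad=True)
x_hat = torch.tensor([[0.5, -1.0], [2.0, 0.0]], requires_grad=True)

d_out = x_hat @ w  # shape (2,): one score per sample
grad_x, = torch.autograd.grad(d_out, x_hat,
                              grad_outputs=torch.ones_like(d_out),
                              create_graph=True)  # keep the graph: grad_x depends on w
penalty = ((grad_x.norm(2, dim=1) - 1) ** 2).mean()
penalty.backward()  # differentiates through the gradient computation

print(penalty.item())  # 16.0
print(w.grad)          # 2 * (5 - 1) * w / 5 = [4.8, 6.4]
```

Without `create_graph=True`, `grad_x` would not be connected to `w`, `penalty` would not require gradients, and `penalty.backward()` would raise an error.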

In [20]:
torch.rand(10,2,3).norm()

tensor(4.6859)

In [0]:
def gradient_penalty(critic, real, fake_detached):
    """
    Args:
      critic: The critic.
      real of shape (batch_size, nc, 28, 28): Real images.
      fake_detached of shape (batch_size, nc, 28, 28): Fake images (detached from the computational graph).

    Returns:
      grad_penalty (scalar tensor): Gradient penalty.
      x of shape (batch_size, nc, 28, 28): Points x-hat in which the gradient penalty is computed.
    """
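    # YOUR CODE HERE
    batch_size = real.size(0)
    # One interpolation coefficient per image, broadcast over (nc, 28, 28)
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    x_hat = eps * real + (1 - eps) * fake_detached
    x_hat.requires_grad_(True)
    d_out = critic(x_hat)
    # create_graph=True makes the gradient itself differentiable w.r.t. the critic parameters
    grads = torch.autograd.grad(outputs=d_out, inputs=x_hat,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True, retain_graph=True)[0]
    # L2 norm of the gradient of each image, then the mean squared deviation from 1
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    penalty = ((grad_norm - 1) ** 2).mean()
    return penalty, x_hat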

In [30]:
tests.test_gradient_penalty(gradient_penalty)

loss: tensor(729., grad_fn=<PowBackward0>)
expected: tensor(729.)
Success


# Training WGAN-GP

We will now train WGAN-GP. To assess the quality of the generated samples, we will use a simple scorer loaded in the cell below.

In [31]:
from scorer import Scorer
scorer = Scorer()
scorer.to(device)


Scorer(
  (model): MLP(
    (model): Sequential(
      (fc1): Linear(in_features=784, out_features=256, bias=True)
      (relu1): ReLU()
      (drop1): Dropout(p=0.2, inplace=False)
      (fc2): Linear(in_features=256, out_features=256, bias=True)
      (relu2): ReLU()
      (drop2): Dropout(p=0.2, inplace=False)
      (out): Linear(in_features=256, out_features=10, bias=True)
    )
  )
)

In [0]:
# Create the network
nz = 10
netG = Generator(nz=nz, ngf=128, nc=1).to(device)
netD = Critic(nc=1, ndf=128).to(device)

### Training loop

Implement the training loop in the cell below. The recommended hyperparameters:
* Optimizer of the critic:    Adam with learning rate 0.0001
* Optimizer of the generator: Adam with learning rate 0.0001
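* Weight $\lambda=10$ of the gradient penalty term in the critic loss:
$$
  \min_{D} - E_{x \sim P_r}[D(x)] + E_{\tilde x \sim P_g}[D(\tilde x)]
  + \lambda E_{\hat{x} \sim P_{\hat{x}}}\left[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\right]
$$
where $P_{\hat{x}}$ denotes the distribution of the interpolated points $\hat{x}$.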

Hints:
- We will use the scorer defined above to assess the quality of the generated samples. The desired level of 0.66 should be reached within 15-20 epochs.
- You can use the following code to track the training progress. The code plots some generated images and computes the score that we use to evaluate the trained model. Note that the images fed to the scorer need to be normalized to be in the range [0, 1].
```
with torch.no_grad():
    # Plot generated images
    z = torch.randn(144, nz, 1, 1, device=device)
    samples = netG(z)
    tools.plot_generated_samples(samples)
    
    # Compute score
    z = torch.randn(1000, nz, 1, 1, device=device)
    samples = netG(z)
    samples = (samples + 1) / 2  # Re-normalize to [0, 1]
    score = scorer(samples)
```
- The quality of the images is slightly worse than with the DCGAN.
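
The training loop in this notebook alternates one critic update with one generator update. Algorithm 1 of the WGAN-GP paper instead performs $n_{critic}$ critic updates (5 by default) per generator update and uses Adam with $\alpha = 0.0001$, $\beta_1 = 0$, $\beta_2 = 0.9$. A minimal sketch of that schedule on hypothetical 2-D toy data with tiny MLPs (all names here are illustrative, not part of the exercise):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nz, lambda_gp, n_critic, batch_size = 2, 10.0, 5, 64

# Hypothetical toy models and data: 2-D points instead of images.
G = nn.Sequential(nn.Linear(nz, 32), nn.ReLU(), nn.Linear(32, 2))
D = nn.Sequential(nn.Linear(2, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
opt_d = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_g = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))

def sample_real(n):
    # Toy "data distribution": a Gaussian blob centred at (2, -1).
    return torch.randn(n, 2) * 0.5 + torch.tensor([2.0, -1.0])

def penalty(D, real, fake):
    eps = torch.rand(real.size(0), 1)  # one epsilon per sample
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    grad, = torch.autograd.grad(D(x_hat).sum(), x_hat, create_graph=True)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

for step in range(100):
    # n_critic critic updates per generator update
    for _ in range(n_critic):
        real = sample_real(batch_size)
        fake = G(torch.randn(batch_size, nz)).detach()
        d_loss = D(fake).mean() - D(real).mean() + lambda_gp * penalty(D, real, fake)
        opt_d.zero_grad()
        d_loss.backward()
        opt_d.step()

    # One generator update with the critic held fixed
    g_loss = -D(G(torch.randn(batch_size, nz))).mean()
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

print(d_loss.item(), g_loss.item())
```

Gradients that `g_loss.backward()` accumulates in the critic are discarded by `opt_d.zero_grad()` before the next critic step, so they never affect the critic update.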

In [37]:
import time
if not skip_training:
    # YOUR CODE HERE
    lambda_gp = 10  # weight of the gradient penalty term
    d_optim = torch.optim.Adam(params=netD.parameters(), lr=0.0001, betas=(0.5, 0.999))
    g_optim = torch.optim.Adam(params=netG.parameters(), lr=0.0001, betas=(0.5, 0.999))
    
    for epoch in range(17):
        start = time.time()
        running_d_loss = []
        running_g_loss = []
        netD.train()
        netG.train()
        for i, (real_images, _) in enumerate(trainloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            
            # Generate fake images; detach so that the critic step does not update G
            z = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = netG(z).detach()
            
            # Critic loss plus weighted gradient penalty
            netD.zero_grad()
            d_loss = critic_loss(netD, real_images, fake_images)
            penalty, _ = gradient_penalty(netD, real_images, fake_images)
            loss = d_loss + lambda_gp * penalty
            
            # Critic step
            loss.backward()
            d_optim.step()
            
            # Generator step (gradients that reach netD here are cleared by netD.zero_grad()
            # at the start of the next critic step)
            netG.zero_grad()
            z = torch.randn(batch_size, nz, 1, 1, device=device)
            g_loss = generator_loss(netD, netG(z))
            g_loss.backward()
            g_optim.step()

            running_d_loss.append(d_loss.item())
            running_g_loss.append(g_loss.item())
            if i % 100 == 0:
                print(i, end=" ")

        end = time.time()
        print(f"{epoch} d_loss:{np.mean(running_d_loss):.4f} g_loss:{np.mean(running_g_loss):.4f} time:{end - start:.1f}s")
        with torch.no_grad():
            # Plot generated images
            z = torch.randn(144, nz, 1, 1, device=device)
            samples = netG(z)
            tools.plot_generated_samples(samples)

            # Compute score
            z = torch.randn(1000, nz, 1, 1, device=device)
            samples = netG(z)
            samples = (samples + 1) / 2  # Re-normalize to [0, 1]
            score = scorer(samples)
            print(score)


In [38]:
# Save the model to disk (the pth-files will be submitted automatically together with your notebook)
if not skip_training:
    tools.save_model(netG, '11_wgan_g.pth')
    tools.save_model(netD, '11_wgan_d.pth')
else:
    nz = 10
    netG = Generator(nz=nz, ngf=128, nc=1)
    netD = Critic(nc=1, ndf=128)
    
    tools.load_model(netG, '11_wgan_g.pth', device)
    tools.load_model(netD, '11_wgan_d.pth', device)

Do you want to save the model (type yes to confirm)? yes
Model saved to 11_wgan_g.pth.
Do you want to save the model (type yes to confirm)? yes
Model saved to 11_wgan_d.pth.


In [39]:
# Evaluate generated samples
with torch.no_grad():
    z = torch.randn(2000, nz, 1, 1, device=device)
    samples = (netG(z) + 1) / 2
    score = scorer(samples)

print(f'The trained WGAN-GP achieves a score of {score:.5f}')
assert score >= 0.66, "Poor GAN score! Check your architecture and training."
print('Success')

The trained WGAN-GP achieves a score of 0.79366
Success
