# Wasserstein Loss and Gradient penalty

In the course, we discussed the limitations of the binary cross entropy loss (BCE Loss). It can lead to:
* [mode collapse](https://developers.google.com/machine-learning/gan/problems): the generator gets stuck in a single mode of the distribution it is trying to replicate. For example, a generator trained on the MNIST dataset might only ever produce images of a few digits, as shown below.

<img src='assets/collapse.png' width=50% />

* vanishing gradient: this is a common problem with many neural network architectures, and it is especially common when training GANs. Because the discriminator's task is much easier than the generator's, the discriminator tends to converge faster and reach a high accuracy. The discriminator loss gets close to zero and the gradients passed back to the generator become very small, which is the vanishing gradient problem.

In this notebook, you will:
* implement the Wasserstein Loss
* implement two types of gradient penalties



## Wasserstein Loss

The [Wasserstein GAN paper](https://arxiv.org/pdf/1701.07875.pdf) introduced a new type of loss function based on the [Wasserstein Distance](https://en.wikipedia.org/wiki/Wasserstein_metric). We are now reshaping the problem GANs are solving: instead of a loss function that classifies a sample as real or fake, we have a loss function that tries to minimize the distance between the real and the fake distributions. The difference is subtle but plays a big role in the stability of GANs.

<img src='assets/gradient_replace.png' width=80% />

The discriminator is now called a **critic** because its job is no longer to distinguish between real and fake but to maximize the distance between the two distributions. However, we will use both terms interchangeably for simplicity.

The Wasserstein loss can be calculated using the formula below:

<center>$\min_{g} \max_{c} E(c(x)) - E(c(g(z)))$</center>

You are now familiar with the minimax function. The main difference from the BCE Loss is the disappearance of the logs!
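
To see why dropping the logs matters, here is a minimal sketch comparing the gradient a generator receives from the saturating BCE generator loss, $\log(1 - \sigma(s))$, with the gradient it receives from the Wasserstein generator loss, $-s$. The logit value of -10 (a fake image the discriminator confidently rejects) is an illustrative assumption.

```python
import torch

# Logit assigned to a fake image that the discriminator confidently rejects
# (illustrative value, not taken from a trained model).
logit = torch.tensor(-10.0, requires_grad=True)

# Saturating BCE generator loss: log(1 - sigmoid(logit)).
bce_gen_loss = torch.log(1 - torch.sigmoid(logit))
(bce_grad,) = torch.autograd.grad(bce_gen_loss, logit)

# Wasserstein generator loss: -logit (no sigmoid, no log).
w_gen_loss = -logit
(w_grad,) = torch.autograd.grad(w_gen_loss, logit)

print(f"BCE gradient:         {bce_grad.item():.2e}")
print(f"Wasserstein gradient: {w_grad.item():.2e}")
```

The BCE gradient equals $-\sigma(s)$, which shrinks towards zero as the discriminator becomes more confident, while the Wasserstein gradient stays at $-1$ regardless of the logit.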

### First exercise: implement the Wasserstein Loss

The Wasserstein Loss (W-Loss) takes the vector of logits output by the discriminator as input. In comparison, the BCE Loss took probabilities (logits after a sigmoid layer) as inputs. The discriminator W-Loss tries to maximize the mean value of the logits of real images and minimize the mean value of the logits of fake images. The generator W-Loss tries to maximize the mean value of the logits of fake images.


In [1]:
import torch

import tests

In [2]:
def disc_w_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor):
    """
    Wasserstein Discriminator Loss
    
    args:
    - real_logits: vector of logits output by the discriminator for real input images
    - fake_logits: vector of logits output by the discriminator for fake input images
    """
    real_loss = -real_logits.mean()
    fake_loss = fake_logits.mean()
    return real_loss + fake_loss

In [3]:
def disc_g_loss(fake_logits: torch.Tensor):
    """
    Wasserstein Generator Loss
    
    args:
    - fake_logits: vector of logits output by the discriminator for fake input images
    """
    fake_loss = -fake_logits.mean()
    return fake_loss

In [4]:
tests.check_disc_w_loss(disc_w_loss)

Congratulations, you successfully implemented the W-Loss for the discriminator


In [5]:
tests.check_gen_w_loss(disc_g_loss)

Congratulations, you successfully implemented the W-Loss for the generator
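
To show where the two losses fit, here is a minimal sketch of one critic update followed by one generator update. The toy linear `generator` and `critic`, the tensor sizes, and the RMSprop settings are all assumptions made for illustration; the Lipschitz constraint discussed in the next section is deliberately left out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models: 8-dim latent -> 16-dim "image", critic -> 1 logit.
generator = nn.Linear(8, 16)
critic = nn.Linear(16, 1)
g_opt = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
c_opt = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

real = torch.randn(4, 16)  # stand-in for a batch of real data
z = torch.randn(4, 8)      # latent vectors

# Critic step: raise the logits of real samples, lower the logits of fakes.
fake = generator(z).detach()  # detach so only the critic is updated here
c_loss = -critic(real).mean() + critic(fake).mean()
c_opt.zero_grad()
c_loss.backward()
c_opt.step()

# Generator step: raise the logits the critic gives to generated samples.
g_loss = -critic(generator(z)).mean()
g_opt.zero_grad()
g_loss.backward()
g_opt.step()
```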


## Gradient penalty

To train a GAN with the Wasserstein Loss, the discriminator (or critic) must be [1-Lipschitz continuous](https://en.wikipedia.org/wiki/Lipschitz_continuity). 

The 1-Lipschitz continuity constraint implies that the norm of the gradient of the function is at most 1 everywhere. In other words, for a function $f(x)$:

<br>
<center>$|| \frac{df}{dx} || \leq 1$</center>

Unlike the BCE Loss, whose inputs are probabilities squashed between 0 and 1, the W-Loss operates on unbounded logits. The above constraint keeps the critic's output, and therefore the loss, from growing arbitrarily large.
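
As a quick illustration of the constraint, the sketch below uses autograd to compare the gradient magnitudes of two simple functions on the interval $[-3, 3]$: $\sin(x)$, which is 1-Lipschitz, and $x^2$, which is not. Both functions and the interval are illustrative choices.

```python
import torch

x = torch.linspace(-3, 3, 101, requires_grad=True)

# sin is 1-Lipschitz: |d/dx sin(x)| = |cos(x)| <= 1 everywhere.
(grad_sin,) = torch.autograd.grad(torch.sin(x).sum(), x)

# x**2 is not: |d/dx x**2| = |2x| exceeds 1 as soon as |x| > 0.5.
(grad_sq,) = torch.autograd.grad((x ** 2).sum(), x)

max_sin = grad_sin.abs().max().item()
max_sq = grad_sq.abs().max().item()
print(f"max |grad| of sin(x): {max_sin:.3f}")
print(f"max |grad| of x**2:   {max_sq:.3f}")
```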

In the original paper, the authors enforced this condition by using weight clipping. However, per their own words:

```
Weight clipping is a clearly terrible way to enforce a Lipschitz constraint. If the
clipping parameter is large, then it can take a long time for any weights to reach
their limit, thereby making it harder to train the critic till optimality. If the clipping
is small, this can easily lead to vanishing gradients when the number of layers is
big, or batch normalization is not used (such as in RNNs). We experimented with
simple variants (such as projecting the weights to a sphere) with little difference, and
we stuck with weight clipping due to its simplicity and already good performance.
However, we do leave the topic of enforcing Lipschitz constraints in a neural network
setting for further investigation, and we actively encourage interested researchers
to improve on this method.
```
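
For reference, weight clipping can be sketched in a few lines: after each critic update, every parameter is clamped to a small interval $[-c, c]$. The toy critic architecture and the value `clip_value = 0.01` are assumptions of this sketch.

```python
import torch
import torch.nn as nn

# Hypothetical small critic used only to demonstrate clipping.
critic = nn.Sequential(nn.Linear(16, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
clip_value = 0.01  # clipping threshold c (assumed value for this sketch)

# In a training loop, this would run after every critic optimizer step.
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-clip_value, clip_value)
```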

## WGAN-GP
Introducing Wasserstein GAN with Gradient Penalty, or [WGAN-GP](https://arxiv.org/pdf/1704.00028.pdf). In this paper, the authors introduce a more robust way to enforce the 1-Lipschitz constraint on the critic: a **gradient penalty term** in the loss function. The new loss function is described below:

<img src='assets/wgan_gp.png' width=80% />


The gradient penalty is calculated as follows:
* sample a random point $\hat{x}$ on the line between a real sample $x$ and a generated sample $g(z)$: $\hat{x} = \alpha x + (1 - \alpha) g(z)$, with $\alpha$ drawn uniformly between 0 and 1.
* run this sample through the discriminator and calculate the gradient $\nabla_{\hat{x}} D(\hat{x})$
* calculate the L2 norm of the gradient $|| \nabla_{\hat{x}} D(\hat{x}) ||_{2}$
* subtract 1, square the result and calculate the mean over the batch of $(|| \nabla_{\hat{x}} D(\hat{x}) ||_{2} - 1) ^{2}$

### Second exercise: implement the gradient penalty

In the second exercise of this notebook, you will implement the above gradient penalty. To help you, I have created a dummy critic module.

**Tips**:
* to calculate the gradients with respect to a tensor, you first have to set its `requires_grad` attribute to `True`.
* you can use the following code to calculate the gradients:
```
pred = critic(x)
torch.autograd.grad(pred, x, grad_outputs=torch.ones_like(pred), create_graph=True)[0]
```

In [None]:
import torch.nn as nn

In [None]:
class Critic(nn.Module):
    """ 
    Dummy critic class 
    """
    def __init__(self):
        super(Critic, self).__init__()
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.pow(x, 2)

In [None]:
def gradient_penalty(real_sample: torch.Tensor, 
                     fake_sample: torch.Tensor,
                     critic: nn.Module) -> torch.Tensor:
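    """
    Gradient penalty of the WGAN-GP model
    args:
    - real_sample: batch of samples from the real dataset
    - fake_sample: batch of generated samples (same shape as real_sample)
    - critic: the critic network
    
    returns:
    - gradient penalty
    """
    # sample one random interpolation coefficient per sample in the batch
    batch_size = real_sample.size(0)
    alpha_shape = [batch_size] + [1] * (real_sample.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_sample.device)
    x_hat = alpha * real_sample + (1 - alpha) * fake_sample
    
    # calculate the gradient of the critic output with respect to x_hat
    x_hat.requires_grad_(True)
    pred = critic(x_hat)
    grad = torch.autograd.grad(pred, 
                               x_hat, 
                               grad_outputs=torch.ones_like(pred), 
                               create_graph=True)[0]
    
    # calculate the per-sample L2 norm and average the penalty over the batch
    norm = grad.reshape(batch_size, -1).norm(2, dim=1)
    gp = ((norm - 1)**2).mean()    
    return gp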

In [None]:
real_sample = torch.randn(3, 32, 32)
fake_sample = torch.randn(3, 32, 32)
critic = Critic()

# store the result under a new name so the function is not overwritten
gp = gradient_penalty(real_sample, fake_sample, critic)
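
The penalty is meant to be added to the critic's W-Loss, scaled by a weight $\lambda$. Below is a minimal, self-contained sketch of that combination. The `toy_critic` architecture and the random tensors are assumptions for illustration, and `LAMBDA_GP = 10` is the value I recall from the WGAN-GP paper's experiments, so treat it as a starting point rather than a requirement.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
LAMBDA_GP = 10.0  # penalty weight (assumed default for this sketch)

# Hypothetical critic without batch normalization.
toy_critic = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 1))
real = torch.randn(3, 32, 32)  # batch of 3 "real" samples
fake = torch.randn(3, 32, 32)  # batch of 3 "generated" samples

# Interpolate between real and fake with one coefficient per sample.
alpha = torch.rand(real.size(0), 1, 1)
x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

# Gradient of the critic output with respect to the interpolated samples.
pred = toy_critic(x_hat)
grad = torch.autograd.grad(pred, x_hat,
                           grad_outputs=torch.ones_like(pred),
                           create_graph=True)[0]
penalty = ((grad.reshape(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

# Critic W-Loss plus the weighted gradient penalty.
critic_loss = -toy_critic(real).mean() + toy_critic(fake).mean() + LAMBDA_GP * penalty
critic_loss.backward()
```

Passing `create_graph=True` is what lets `critic_loss.backward()` propagate through the gradient computation itself, so the penalty actually influences the critic's weights.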

## DRAGAN

The [DRAGAN paper](https://arxiv.org/pdf/1705.07215.pdf) offered a different approach to calculate the gradient penalty and enforce the 1-Lipschitz constraint on the critic.

<img src='assets/dragan_gp.png' width=60% />

As you can see, the formula is very similar, especially since the authors use $k = 1$ for their experiments. The main difference from the WGAN-GP gradient penalty is the $\delta$ term, which is a noise term. In their implementation, the authors calculate $X_{p} = X + \delta$ as follows:

<center>
    $X_{p} = X + 0.5 * \sigma({X}) * N$ 
</center>

where $\sigma(X)$ is the standard deviation of $X$ and $N$ is a noise term sampled from the uniform distribution on $[0, 1)$.
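
As a small sanity check of the perturbation above, this sketch builds $X_{p}$ for a random tensor and verifies that each element moves by less than $0.5 \, \sigma(X)$. The tensor shape and the use of `torch.rand_like` for $N$ are assumptions of this sketch.

```python
import torch

torch.manual_seed(0)

X = torch.randn(3, 32, 32)   # stand-in for a batch of real samples
noise = torch.rand_like(X)   # uniform noise on [0, 1)
X_p = X + 0.5 * X.std() * noise

# Because the noise lies in [0, 1), every element shifts by less than 0.5 * std.
max_shift = (X_p - X).abs().max().item()
bound = 0.5 * X.std().item()
print(f"largest shift: {max_shift:.4f}, bound: {bound:.4f}")
```

This illustrates why the DRAGAN penalty is applied in a small neighbourhood of the real data rather than along the whole path to generated samples.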

The gradient penalty is then calculated as follows:
* sample a random point $\hat{x}$ between the real sample $X$ and its perturbed version $X_{p}$.
* run this sample through the discriminator and calculate the gradient $\nabla_{\hat{x}} D(\hat{x})$
* calculate the L2 norm of the gradient $|| \nabla_{\hat{x}} D(\hat{x}) ||_{2}$
* subtract 1, square the result and calculate the mean over the batch of $(|| \nabla_{\hat{x}} D(\hat{x}) ||_{2} - 1) ^{2}$


### BCE Loss
Interestingly, using this gradient penalty lifts some of the constraints on the BCE Loss, and the authors use the above gradient penalty with the vanilla GAN losses (BCE Loss).
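
A minimal sketch of that combination is shown below: the usual BCE discriminator loss plus a DRAGAN-style penalty computed around the real samples. The toy discriminator `disc`, the random data, and the weight `LAMBDA_GP` are assumptions for illustration and are not taken from the paper's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
LAMBDA_GP = 10.0  # illustrative penalty weight, an assumption of this sketch

# Hypothetical discriminator that outputs one logit per sample.
disc = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 1))
bce = nn.BCEWithLogitsLoss()
real = torch.randn(3, 32, 32)
fake = torch.randn(3, 32, 32)

# Vanilla GAN discriminator loss: real -> label 1, fake -> label 0.
real_logits = disc(real)
fake_logits = disc(fake)
bce_loss = (bce(real_logits, torch.ones_like(real_logits))
            + bce(fake_logits, torch.zeros_like(fake_logits)))

# DRAGAN-style penalty in a neighbourhood of the real samples.
X_p = real + 0.5 * real.std() * torch.rand_like(real)
alpha = torch.rand(real.size(0), 1, 1)
x_hat = (alpha * real + (1 - alpha) * X_p).requires_grad_(True)
pred = disc(x_hat)
grad = torch.autograd.grad(pred, x_hat,
                           grad_outputs=torch.ones_like(pred),
                           create_graph=True)[0]
penalty = ((grad.reshape(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

d_loss = bce_loss + LAMBDA_GP * penalty
```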

### Third exercise: implement the DRAGAN gradient penalty

In the third exercise of this notebook, you will implement the DRAGAN gradient penalty. It differs from the above implementation by a single line!

In [None]:
def gradient_penalty_dragan(real_sample: torch.Tensor, critic: nn.Module) -> torch.Tensor:
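    """
    Gradient penalty of the DRAGAN model
    args:
    - real_sample: batch of samples from the real dataset
    - critic: the critic network
    
    returns:
    - gradient penalty
    """
    # perturb the real samples with uniform noise scaled by their standard deviation
    X_p = real_sample + 0.5 * real_sample.std() * torch.rand_like(real_sample)
    
    # sample a random point between each real sample and its perturbed version
    batch_size = real_sample.size(0)
    alpha_shape = [batch_size] + [1] * (real_sample.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_sample.device)
    x_hat = alpha * real_sample + (1 - alpha) * X_p
    
    # calculate the gradient of the critic output with respect to x_hat
    x_hat.requires_grad_(True)
    pred = critic(x_hat)
    grad = torch.autograd.grad(pred, 
                               x_hat, 
                               grad_outputs=torch.ones_like(pred), 
                               create_graph=True)[0]
    
    # calculate the per-sample L2 norm and average the penalty over the batch
    norm = grad.reshape(batch_size, -1).norm(2, dim=1)
    gp = ((norm - 1)**2).mean()
    return gp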

In [None]:
dragan_gp = gradient_penalty_dragan(real_sample, critic)

# WARNING

The gradient penalty terms penalize each input to the critic individually. Therefore, the critic should map a single input to a single output. However, some layers commonly used in the discriminator break this property: batch normalization layers. The authors of the WGAN-GP paper explain the following:

```
No critic batch normalization. Most prior GAN implementations use batch normalization in both the generator and the discriminator to help stabilize training, but batch normalization
changes the form of the discriminator’s problem from mapping a single input to a single output to
mapping from an entire batch of inputs to a batch of outputs. Our penalized training objective
is no longer valid in this setting, since we penalize the norm of the critic’s gradient with respect
to each input independently, and not the entire batch. To resolve this, we simply omit batch normalization in the critic in our models, finding that they perform well without it. Our method works
with normalization schemes which don’t introduce correlations between examples. In particular, we
recommend layer normalization as a drop-in replacement for batch normalization.
```
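
The effect described in the quote can be checked with a small experiment. The sketch below builds two hypothetical critics that differ only in their normalization layer and feeds the same first sample alongside two different sets of companions. The architecture, sizes, and the helper `make_critic` are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_critic(norm: nn.Module) -> nn.Sequential:
    """Hypothetical toy critic with a pluggable normalization layer."""
    return nn.Sequential(nn.Linear(16, 8), norm, nn.LeakyReLU(0.2), nn.Linear(8, 1))

bn_critic = make_critic(nn.BatchNorm1d(8))  # training mode: uses batch statistics
ln_critic = make_critic(nn.LayerNorm(8))    # normalizes each sample on its own

batch = torch.randn(4, 16)
# Same first sample, different companions in the rest of the batch.
other = torch.cat([batch[:1], torch.randn(3, 16)])

with torch.no_grad():
    bn_changes = not torch.allclose(bn_critic(batch)[0], bn_critic(other)[0])
    ln_changes = not torch.allclose(ln_critic(batch)[0], ln_critic(other)[0], atol=1e-6)

print(f"batch norm output depends on the rest of the batch: {bn_changes}")
print(f"layer norm output depends on the rest of the batch: {ln_changes}")
```

With batch normalization, the score of a sample shifts when its batch companions change, which is exactly the cross-sample correlation the per-input penalty cannot account for. Layer normalization leaves the score untouched.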

Keep this in mind if you decide to use the gradient penalty in your project! 