# Wasserstein GANs

## Motivation


## Modes

Any peak in the density function of a feature distribution is a mode. Most real-world datasets have several modes. For example, when handwritten digits are described by two features, each digit forms its own mode (see the VAE notebook).
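To make the idea of a mode concrete, the sketch below evaluates a bimodal density on a grid and reports the grid points that are local maxima. The mixture of two Gaussians, its parameters, and the helper names (`gaussian_pdf`, `mixture_pdf`, `find_modes`) are illustrative assumptions, not part of the original notes.

```python
import math

def gaussian_pdf(x, mean, std):
    """Density of a 1-D normal distribution at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def mixture_pdf(x):
    """Hypothetical bimodal density: equal mix of N(-2, 0.5) and N(3, 0.5)."""
    return 0.5 * gaussian_pdf(x, -2.0, 0.5) + 0.5 * gaussian_pdf(x, 3.0, 0.5)

def find_modes(pdf, lo, hi, steps=1000):
    """Return grid points where the density is a strict local maximum."""
    xs = [lo + (hi - lo) * i / steps for i in range(steps + 1)]
    ys = [pdf(x) for x in xs]
    return [xs[i] for i in range(1, steps) if ys[i - 1] < ys[i] > ys[i + 1]]

modes = find_modes(mixture_pdf, -6.0, 8.0)
print(modes)  # two grid points, one near -2 and one near 3
```

A generator suffering from mode collapse would only produce samples around one of these two peaks.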

**Mode collapse** happens when the generator learns to fool the discriminator by producing examples from only **one mode** (for example, a single class) of the training dataset. This is unfortunate because, although the generator is optimizing to fool the discriminator, a generator that covers only one mode is not what you ultimately want.

### Example:

Take a discriminator that has learned to identify fake handwritten digits well, except when the generated images look like ones or sevens. This could mean the **discriminator** is stuck in a **local minimum** of its cost function.

The discriminator classifies most of the digits correctly, except for those that resemble ones and sevens, and this information is passed on to the generator. The generator uses this feedback to learn how to fool the discriminator in the next round: since the images resembling a one or a seven were misclassified, it generates many pictures that resemble either of those digits.

These generated images are passed on to the discriminator in the next round, which misclassifies every picture except perhaps the ones that look more like a seven. The generator gets that feedback and sees that the discriminator's weakness is with pictures that resemble a handwritten one, so this time all the pictures it produces resemble that digit, **collapsing to a single mode** of the whole distribution of possible handwritten digits.

## Wasserstein Loss

### Problem with BCE Loss

With BCE loss, the discriminator does not output useful gradients (feedback) for the generator when the real and fake distributions are far apart. In that case the discriminator separates the two almost perfectly, its outputs saturate near 0 and 1, and the gradients passed to the generator approach 0. This is known as the vanishing gradient problem.
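The saturation can be checked numerically. The sketch below assumes the discriminator ends in a sigmoid applied to a raw score (a logit) and that the generator minimizes the term $\log(1 - d(g(z)))$; both are assumptions made for illustration. Under them, the derivative of that term with respect to the logit is $-\sigma(\text{logit})$, which shrinks toward 0 as the discriminator becomes more confident that a sample is fake.

```python
import math

def sigmoid(s):
    """Logistic function mapping a raw score to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-s))

def generator_bce_grad(logit):
    """Derivative of log(1 - sigmoid(logit)) with respect to the logit."""
    return -sigmoid(logit)

# A more negative logit means the discriminator is more certain the sample is fake.
for logit in [0.0, -5.0, -10.0, -20.0]:
    print(logit, generator_bce_grad(logit))
```

The printed gradient magnitude drops by orders of magnitude as the logit decreases, so the generator receives almost no signal exactly when it is doing worst.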

A loss that mitigates this is **Wasserstein Loss** (W-Loss), which uses the **Earth mover's distance** (EMD). Earth mover's distance measures how different two distributions are by estimating the effort (how much mass has to be moved, and how far) it takes to make the generated distribution equal to the real one.
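For intuition, the EMD has a simple closed form in one dimension: for two samples of equal size, it is the mean absolute difference between the sorted values. The sketch below uses that special case; the function name `emd_1d` and the sample values are illustrative assumptions, and real image distributions are of course high-dimensional.

```python
def emd_1d(a, b):
    """Earth mover's distance between two equal-sized 1-D samples."""
    if len(a) != len(b):
        raise ValueError("samples must have the same size")
    # Optimal transport in 1-D pairs the i-th smallest of a with the i-th smallest of b.
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

real = [0.0, 1.0, 2.0]
near = [0.5, 1.5, 2.5]     # generated samples close to the real ones
far = [10.0, 11.0, 12.0]   # generated samples far from the real ones

print(emd_1d(real, near))  # 0.5
print(emd_1d(real, far))   # 10.0
```

Unlike a saturating classifier output, the distance keeps growing as the two samples move apart, so it still tells the generator how far off it is.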

With BCE loss, the **discriminator** outputs a value between 0 and 1:

$$
\underset{d}{\min} \underset{g}{\max} - \big( \mathbb{E}(\log(d(x))) + \mathbb{E}(\log(1 - d(g(z)))) \big)
$$
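As a rough sketch, the BCE cost can be computed from a batch of discriminator outputs as follows. It uses the standard BCE form, with $\log(d(x))$ for real samples and $\log(1 - d(g(z)))$ for fake ones; the function name `bce_cost` and the example outputs are assumptions for illustration.

```python
import math

def bce_cost(d_real, d_fake):
    """BCE cost from discriminator outputs in (0, 1) on real and fake batches."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return -(real_term + fake_term)

# An undecided discriminator outputs 0.5 everywhere.
cost_unsure = bce_cost([0.5, 0.5], [0.5, 0.5])
# A confident, correct discriminator outputs values near 1 for real and near 0 for fake.
cost_good = bce_cost([0.99, 0.98], [0.01, 0.02])

print(cost_unsure, cost_good)
```

The discriminator lowers this cost by becoming more accurate, while the generator tries to raise it by producing fakes that receive outputs close to 1.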

### W-Loss

When using W-Loss, the discriminator is called a **critic**, because the model no longer classifies inputs as real or fake with an output between $0$ and $1$ but instead outputs an unbounded real-valued score.

W-Loss approximates the Earth Mover's Distance.

The **critic** outputs **any number**:

$$
\underset{g}{\min} \underset{c}{\max} \mathbb{E}(c\,(x)) - \mathbb{E}(c\,(g(z)))
$$
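A minimal sketch of how this objective turns into two losses to minimize, assuming the usual form in which the critic's mean score on fakes is subtracted from its mean score on reals. The function names and the example scores are hypothetical.

```python
def mean(values):
    return sum(values) / len(values)

def critic_objective(c_real, c_fake):
    """E[c(x)] - E[c(g(z))], which the critic tries to maximize."""
    return mean(c_real) - mean(c_fake)

def critic_loss(c_real, c_fake):
    """Negated objective, so the critic can be trained by minimization."""
    return -critic_objective(c_real, c_fake)

def generator_loss(c_fake):
    """The generator minimizes -E[c(g(z))], i.e. it wants high critic scores on fakes."""
    return -mean(c_fake)

c_real = [2.0, 4.0]    # critic scores on real samples (any real number)
c_fake = [-1.0, -3.0]  # critic scores on generated samples

print(critic_objective(c_real, c_fake))  # 5.0
print(critic_loss(c_real, c_fake))       # -5.0
print(generator_loss(c_fake))            # 2.0
```

Because the scores are not squashed through a sigmoid or a logarithm, the generator's loss changes linearly with the critic's score instead of saturating.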

So, **W-Loss** looks similar to BCE loss, but it mitigates **mode collapse** and the **vanishing gradient** problem.

### Condition on Wasserstein Critic

The **critic** network needs to be **1-Lipschitz continuous** when using W-Loss. 1-Lipschitz (1-L) continuity demands that the slope of the function is at most one in absolute value at every point, i.e. that the **L2 norm of the gradient** is less than or equal to 1 at all points:

$$
|| \, \nabla c(x) \, ||_2 \leq 1
$$

For a neural network, this requires that the norm of the critic's gradient with respect to its input never exceeds one. This condition ensures that W-Loss is a valid approximation of the EMD.
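The condition is easiest to see for a scalar function, where the gradient norm is just the absolute slope. The sketch below estimates the largest slope between neighboring grid points; `max_slope` and the two test functions are illustrative assumptions, and a real critic takes high-dimensional inputs, so this is only a one-dimensional analogy.

```python
import math

def max_slope(f, lo, hi, steps=2000):
    """Largest |f(b) - f(a)| / (b - a) over adjacent grid points in [lo, hi]."""
    xs = [lo + (hi - lo) * i / steps for i in range(steps + 1)]
    return max(abs(f(b) - f(a)) / (b - a) for a, b in zip(xs, xs[1:]))

def steep(x):
    """A function with slope 3 everywhere, so it is not 1-Lipschitz."""
    return 3.0 * x

slope_tanh = max_slope(math.tanh, -5.0, 5.0)  # tanh is 1-Lipschitz
slope_steep = max_slope(steep, -5.0, 5.0)

print(slope_tanh, slope_steep)
```

Here `math.tanh` stays within the bound, while `steep` violates it everywhere.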

### 1-Lipschitz Continuity Enforcement

There are two ways to enforce 1-L Continuity: **Weight Clipping** and **Gradient Penalty**.

**Weight Clipping** forces the weights of the critic to a fixed interval by clipping the values outside this interval.
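A minimal sketch of weight clipping on a flat list of weights. The interval bound `c=0.01` is an illustrative value and the function name is hypothetical; in a training loop this would typically be applied to every critic parameter after each critic update.

```python
def clip_weights(weights, c=0.01):
    """Clamp every weight into the interval [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]

weights = [-0.5, 0.003, 0.2]
clipped = clip_weights(weights, c=0.01)
print(clipped)  # [-0.01, 0.003, 0.01]
```

Weights already inside the interval are left untouched, while larger ones are pinned to the boundary, which limits what the critic can learn.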

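A more effective method is **Gradient Penalty**, which adds a regularization term, weighted by $\lambda$, that penalizes the critic whenever its gradient norm deviates from 1 (the term is subtracted because the critic maximizes the objective):

$$
\underset{g}{\min} \underset{c}{\max} \mathbb{E}(c\,(x)) - \mathbb{E}(c\,(g(z))) - \lambda \, \mathbb{E}\big((|| \, \nabla c(\hat{x}) \, ||_2 - 1)^2\big)
$$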

with

$$
\hat{x} = \epsilon x + (1 - \epsilon) g(z)
$$

being an **interpolation** of a real and a fake image, where $\epsilon$ is a random mixing weight sampled uniformly from $[0, 1]$ for each pair rather than a tuned hyperparameter. Since checking the critic's gradient at every possible point of the feature space is virtually impossible, the gradient of the critic with respect to these interpolated images is used instead, which approximately enforces a gradient norm of 1 almost everywhere.
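The penalty term can be sketched without a deep-learning framework by using a toy critic whose input gradient is known analytically. Everything here is an assumption made for illustration: the critic $c(x) = \sum_j v_j \tanh(W_j \cdot x)$, its weights, the sample values, and drawing $\epsilon$ uniformly from $[0, 1)$ once per real/fake pair. In practice the gradient would come from a framework's automatic differentiation.

```python
import math
import random

def critic_grad(x, W, v):
    """Analytic input gradient of the toy critic c(x) = sum_j v[j] * tanh(W[j] . x)."""
    grad = [0.0] * len(x)
    for w_row, v_j in zip(W, v):
        pre = sum(w * xi for w, xi in zip(w_row, x))
        scale = v_j * (1.0 - math.tanh(pre) ** 2)  # derivative of tanh at pre
        for i, w in enumerate(w_row):
            grad[i] += scale * w
    return grad

def gradient_penalty(real_batch, fake_batch, W, v, rng):
    """Mean of (||grad c(x_hat)||_2 - 1)^2 over interpolated samples x_hat."""
    total = 0.0
    for x, gz in zip(real_batch, fake_batch):
        eps = rng.random()  # fresh mixing weight for each real/fake pair
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(x, gz)]
        norm = math.sqrt(sum(g * g for g in critic_grad(x_hat, W, v)))
        total += (norm - 1.0) ** 2
    return total / len(real_batch)

rng = random.Random(0)
W = [[0.5, -0.2], [0.1, 0.3]]  # hypothetical critic weights
v = [1.0, -1.0]
real = [[1.0, 2.0], [0.5, -1.0]]
fake = [[-1.0, 0.0], [2.0, 1.0]]

gp = gradient_penalty(real, fake, W, v, rng)
print(gp)
```

The result is multiplied by $\lambda$ and added to the critic's loss, so the critic is nudged toward a gradient norm of 1 at the sampled points rather than being hard-constrained everywhere.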

The main W-Loss term makes the GAN less prone to **mode collapse** and **vanishing gradients**. The regularization term pushes the **critic** toward being 1-L continuous, which keeps the loss function continuous and differentiable.