## What Is a Wasserstein GAN?

It is an extension of the GAN that seeks an alternate way of training the generator model to better approximate the distribution of data observed in a given training dataset.

Wasserstein GAN, or WGAN, is a type of generative adversarial network that minimizes an approximation of the Earth Mover's (EM) distance rather than the Jensen-Shannon divergence used in the original GAN formulation.

Here in WGAN, the discriminator does not actually classify instances. Instead, for each instance it outputs a number. This number does not have to lie between 0 and 1, so we can't use 0.5 as a threshold to decide whether an instance is real or fake. Discriminator training just tries to make the output bigger for real instances than for fake instances.

Instead of using a discriminator to classify or predict the probability of generated images being real or fake, the WGAN replaces the discriminator model with a critic that scores the realness or fakeness of a given image.

This change is motivated by a mathematical argument that training the generator should minimize the distance between the distribution of the data observed in the training dataset and the distribution observed in generated examples. The argument contrasts different distribution distance measures, such as the Kullback-Leibler (KL) divergence, the Jensen-Shannon (JS) divergence, and the Earth-Mover (EM) distance, also referred to as the Wasserstein distance.

---


WGAN works with two probability distributions. One is the generator distribution (Pg), the distribution of the generator model's outputs.

The other is the distribution of the real images (Pr).

The objective of WGAN is to bring these two distributions close to each other, so that the generated output is realistic and of high quality.

Three commonly used measures of the distance between these probability distributions are:

- Kullback–Leibler divergence,
- Jensen–Shannon divergence, and
- Wasserstein distance.

The Jensen–Shannon divergence is the measure implicitly minimized by the standard GAN loss (when the discriminator is optimal), so it is the one at work in simple GAN networks.

But in WGAN, we use the Wasserstein distance (a.k.a. Earth Mover’s Distance) instead of the Jensen–Shannon divergence to compare probability distributions.
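To make the difference between these measures concrete, here is a small pure-Python sketch on toy discrete distributions: point masses on a 1D grid, an assumption chosen for illustration in the spirit of the parallel-lines example in the WGAN paper. The helper names `kl`, `js`, `wasserstein_1d` and `point_mass` are our own. When the two distributions do not overlap, KL is infinite and JS stays at log 2 however far apart they are, while the Wasserstein-1 distance grows with the shift, which is what gives the generator a useful training signal.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists over the same support."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf  # p puts mass where q has none
            total += pi * math.log(pi / qi)
    return total


def js(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def wasserstein_1d(p, q):
    """Wasserstein-1 (Earth Mover's) distance on the grid 0, 1, ..., n-1 with unit spacing:
    the sum of absolute differences between the two cumulative distribution functions."""
    cdf_p = cdf_q = total = 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total


def point_mass(position, n=10):
    """All probability mass on a single grid point."""
    return [1.0 if i == position else 0.0 for i in range(n)]


p_r = point_mass(0)  # toy "real" distribution
for shift in (1, 3, 6):
    p_g = point_mass(shift)  # toy "generator" distribution, moved further away each time
    print(f"shift={shift}: KL={kl(p_r, p_g)}, JS={js(p_r, p_g):.4f}, W1={wasserstein_1d(p_r, p_g)}")
```

JS cannot tell a shift of 1 from a shift of 6, so its gradient with respect to the generator's parameters is zero in this situation; W1 reports the distance the mass has to travel.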

**The benefit of the WGAN is that the training process is more stable and less sensitive to model architecture and choice of hyperparameter configurations.**

---

## Compared to the original GAN algorithm, the WGAN makes the following changes:

* After every gradient update on the critic function, clamp the weights to a small fixed range, [-c, c].

* Use a new loss function derived from the Wasserstein distance, with no logarithm anymore. The “discriminator” model no longer acts as a direct classifier but as a helper for estimating the Wasserstein metric between the real and generated data distributions.

* Empirically, the authors recommended the RMSProp optimizer for the critic rather than a momentum-based optimizer such as Adam, which could cause instability in model training.

---


## Key Points in WGAN

![Imgur](https://imgur.com/cWROjs7.png)

### 1. Critic Weight Clipping

The critic F has to be a Lipschitz function (1-Lipschitz in theory; in practice K-Lipschitz for some constant K, which only rescales the estimated distance). To enforce the constraint, WGAN applies a very simple clipping that restricts the maximum weight value in F,

i.e. the weights of the critic must stay within the range [-c, c] controlled by the hyperparameter c.
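Why does clipping enforce a Lipschitz constraint at all? A toy pure-Python sketch, assuming a hypothetical one-layer linear critic f(x) = w · x on flattened 28×28 inputs: its Lipschitz constant (with respect to the L2 norm) is exactly the L2 norm of w, and clipping every weight to [-c, c] bounds that norm by c * sqrt(n). For a deeper critic built from linear layers and 1-Lipschitz activations such as LeakyReLU(0.2), the per-layer bounds multiply, which is why clipping yields a K-Lipschitz critic for some K rather than exactly a 1-Lipschitz one.

```python
import math
import random


def clip(weights, c):
    """Clamp every weight to the box [-c, c], as WGAN does after each critic update."""
    return [max(-c, min(c, w)) for w in weights]


def linear_critic(w, x):
    """Hypothetical one-layer critic f(x) = w . x (no bias, no activation)."""
    return sum(wi * xi for wi, xi in zip(w, x))


def l2(v):
    return math.sqrt(sum(vi * vi for vi in v))


random.seed(0)
n, c = 784, 0.01  # input size of a flattened 28x28 image; clip value
w = clip([random.gauss(0.0, 1.0) for _ in range(n)], c)

# For f(x) = w . x the Lipschitz constant is ||w||_2, and clipping guarantees ||w||_2 <= c * sqrt(n)
K = l2(w)
print(f"K = {K:.4f}, bound c * sqrt(n) = {c * math.sqrt(n):.4f}")

# Empirical check of |f(x) - f(y)| <= K * ||x - y|| on random input pairs
for _ in range(5):
    x = [random.random() for _ in range(n)]
    y = [random.random() for _ in range(n)]
    diff = abs(linear_critic(w, x) - linear_critic(w, y))
    assert diff <= K * l2([xi - yi for xi, yi in zip(x, y)]) + 1e-9
```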

### 2. Update Critic More Than Generator

In the DCGAN, the generator and the discriminator model are updated in equal amounts.

Specifically, the discriminator is updated with a half batch of real and a half batch of fake samples each iteration, whereas the generator is updated with a single batch of generated samples.

In the WGAN model, the critic model is updated more often than the generator model.

Specifically, a new hyperparameter called n_critic controls the number of times that the critic is updated for each update to the generator model (the paper uses n_critic = 5).
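The update order can be sketched in a few lines of pure Python. This follows our reading of the paper's training loop (n_critic critic steps, then one generator step); `wgan_update_schedule` is a hypothetical helper that only records the order of updates.

```python
def wgan_update_schedule(generator_iterations, n_critic=5):
    """Order of updates: n_critic critic updates, then a single generator update, repeated."""
    schedule = []
    for _ in range(generator_iterations):
        schedule.extend(["critic"] * n_critic)  # critic trained n_critic times ...
        schedule.append("generator")            # ... before each generator update
    return schedule


steps = wgan_update_schedule(generator_iterations=3, n_critic=5)
print(steps[:7])
print(steps.count("critic"), "critic updates,", steps.count("generator"), "generator updates")
```

The training loop later in this notebook uses a slightly different arrangement with the same ratio: it updates the critic once per batch and the generator only when `i % hp.n_critic == 0`.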

### 3. Use RMSProp Stochastic Gradient Descent

The DCGAN uses the Adam version of stochastic gradient descent with a small learning rate and modest momentum.

The WGAN recommends Root Mean Square Propagation (RMSProp) instead, an adaptive learning-rate variant of gradient descent, with a small learning rate of 0.00005.
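A minimal scalar sketch of the RMSProp update rule, assuming the plain form (no momentum, not centered) that, as far as we know, matches the defaults of `torch.optim.RMSprop` (alpha=0.99, eps=1e-8). The toy objective (w - 3)^2 and the learning rate of 0.01 are assumptions chosen so that it converges quickly; WGAN itself uses 0.00005 on the network weights.

```python
import math


def rmsprop_minimize(grad_fn, w, lr=0.01, alpha=0.99, eps=1e-8, steps=2000):
    """Minimize a 1-D function with plain RMSProp (no momentum, not centered).

    grad_fn: returns the gradient of the objective at w.
    """
    square_avg = 0.0
    for _ in range(steps):
        g = grad_fn(w)
        # Running (exponentially weighted) average of the squared gradient
        square_avg = alpha * square_avg + (1 - alpha) * g * g
        # Step scaled by the root of that average: large recent gradients -> smaller steps
        w -= lr * g / (math.sqrt(square_avg) + eps)
    return w


# Toy objective (w - 3)^2 with gradient 2 * (w - 3); its minimum is at w = 3
w_final = rmsprop_minimize(lambda w: 2.0 * (w - 3.0), w=0.0)
print(w_final)
```

Unlike Adam, this update keeps no running average of the gradient itself (no momentum term), which is the property the WGAN authors preferred for the critic.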

---

## The loss function for WGAN

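#### First, for a normal GAN (e.g. DCGAN), the minimax loss is defined as:

Minimax Loss: E_x[log(D(x))] + E_z[log(1 - D(G(z)))]

The discriminator tries to maximize this function, while the generator tries to minimize it.

Where,
- D(x) is the discriminator's estimate of the probability that real data instance x is real.
- G(z) is the generator's output when given noise z.
- D(G(z)) is the discriminator's estimate of the probability that a fake instance is real.
- E_x and E_z are the expected values over all real data instances and over all generator inputs, respectively.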
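A small pure-Python sketch of the minimax value estimated over a mini-batch, using made-up discriminator probabilities (an assumption for illustration; `minimax_value` is a hypothetical helper). Note the logarithms and the probabilities restricted to (0, 1), both of which WGAN drops.

```python
import math


def minimax_value(d_real, d_fake):
    """Mini-batch estimate of E_x[log D(x)] + E_z[log(1 - D(G(z)))].

    d_real: discriminator probabilities on real samples, D(x).
    d_fake: discriminator probabilities on generated samples, D(G(z)).
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that separates real from fake well gets a value close to 0 (its maximum) ...
confident = minimax_value(d_real=[0.9, 0.95], d_fake=[0.05, 0.1])
# ... while one that outputs 0.5 everywhere gets log(0.5) + log(0.5) = -log(4)
confused = minimax_value(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])
print(confident, confused)
```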


#### Now for WGAN the Loss is defined as:

#### Critic Loss = [average critic score on real images] – [average critic score on fake images]

Critic Loss: D(x) - D(G(z))

In WGAN, the discriminator does not produce a probability; it produces a raw score.

Where,
- D(x) is the critic's output for a real instance.
- G(z) is the generator's output when given noise z.
- D(G(z)) is the critic's output for a fake instance.

The output of the critic D does not have to be between 0 and 1.

- The discriminator (critic) tries to maximize this function. In other words, it tries to maximize the difference between its output on real instances and its output on fake instances.

- So, unlike the discriminator of a normal GAN, the WGAN critic does NOT classify or predict the probability of generated images being real or fake. Instead, it scores the realness or fakeness of a given image.

- It does this by removing the final Sigmoid() layer, so that the discriminator's neural network ends with a linear layer.
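A minimal PyTorch sketch of that change, assuming a 256-dimensional input to the last layer (the same size as the last hidden layer of the critic later in this notebook) and random stand-in activations. The same linear layer is used with and without the sigmoid, so the critic's score is exactly the discriminator's pre-sigmoid logit, no longer squashed into (0, 1).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for the activations feeding the last layer (an assumption for illustration)
features = torch.randn(4, 256) * 10

last_layer = nn.Linear(256, 1)
# Standard GAN discriminator head: linear layer + sigmoid -> probability in (0, 1)
discriminator_head = nn.Sequential(last_layer, nn.Sigmoid())
# WGAN critic head: the same linear layer with the sigmoid removed -> unbounded score
critic_head = last_layer

with torch.no_grad():
    probabilities = discriminator_head(features)
    scores = critic_head(features)

print(probabilities.squeeze(1))
print(scores.squeeze(1))
```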

#### Generator Loss = -[average critic score on fake images]

Generator objective: D(G(z))

The generator tries to maximize this objective. In other words, it tries to maximize the critic's output for its fake instances. Minimizing the loss above, the negated average critic score on fake images, amounts to the same thing.



---

### Implementing Wasserstein Loss

1. Use a linear activation function in the output layer of the critic model (instead of sigmoid).

2. Use the Wasserstein loss, which promotes a larger difference between the scores for real and generated images, to train the critic and generator models.

3. Constrain the critic model's weights to a limited range after each mini-batch update (e.g. [-0.01, 0.01]). In the paper's words: in order to have parameters w lie in a compact space, something simple we can do is clamp the weights to a fixed box (say $\mathcal{W} = [-0.01, 0.01]^l$) after each gradient update.


We can summarize the functions, written as losses to be minimized (the form used in the paper's training algorithm and in the training loop below), as follows:


#### Critic Loss = [average critic score on fake images] – [average critic score on real images]

#### Generator Loss = -[average critic score on fake images]

Where the average scores are calculated across a mini-batch of samples. Minimizing this critic loss is the same as maximizing [average critic score on real images] – [average critic score on fake images], the critic function defined earlier.

The calculations are straightforward to interpret once we recall that stochastic gradient descent seeks to minimize loss.

In the case of the generator, a larger critic score on fake images results in a smaller loss for the generator, encouraging the generator to produce images that the critic scores highly. For example, an average score of 10 becomes a loss of -10, and an average score of 50 becomes -50, which is smaller, and so on.

In the case of the critic, a larger score for real images (or a smaller score for fake images) results in a smaller loss. For example, an average score of 20 for real images and 10 for fake images gives a loss of 10 - 20 = -10; an average score of 50 for real images and 10 for fake images gives 10 - 50 = -40, which is better, and so on.

The absolute values and signs of the scores do not matter; only the gap between the average scores on real and fake images does. The Wasserstein loss encourages the critic to separate these numbers.

We can also reverse the convention, encouraging the critic to output small scores for real images and large scores for fake images, and achieve the same result, as long as the generator loss is flipped accordingly to +[average critic score on fake images].
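A minimal pure-Python sketch of the two losses on made-up critic scores, using the sign convention of the WGAN paper's algorithm and of this notebook's training loop: the critic minimizes the mean fake score minus the mean real score, and the generator minimizes minus the mean fake score. In PyTorch these correspond to the `d_loss` and `g_loss` lines of the loop; the helper names here are hypothetical.

```python
def mean(xs):
    return sum(xs) / len(xs)


def critic_loss(real_scores, fake_scores):
    """Minimized by the critic; equivalent to maximizing mean(real) - mean(fake)."""
    return mean(fake_scores) - mean(real_scores)


def generator_loss(fake_scores):
    """Minimized by the generator; pushes the critic's scores on fakes upwards."""
    return -mean(fake_scores)


print(critic_loss(real_scores=[18, 22], fake_scores=[8, 12]))  # gap of 10 between the averages
print(critic_loss(real_scores=[48, 52], fake_scores=[8, 12]))  # larger gap, lower critic loss
print(generator_loss(fake_scores=[8, 12]), generator_loss(fake_scores=[48, 52]))
```

Shifting every score by the same constant leaves the critic loss unchanged, which is why only the gap between real and fake scores matters.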

---

### Main Equation

The network uses Earth Mover’s Distance instead of Jensen-Shannon Divergence to compare probability distributions.

![Imgur](https://imgur.com/EJg4nHM.png)

In the above equation, the max is taken over critics that satisfy the Lipschitz constraint on the discriminator. In the WGAN architecture, the discriminator is referred to as the critic. One of the reasons for this convention is that there is no sigmoid activation function to squash the output into the range (0, 1), i.e. a probability of real versus fake. So the discriminator in WGAN outputs a scalar score rather than a probability.

The first term of the equation is the critic's expected score on real data, while the second term is its expected score on generated data. The critic aims to maximize the difference between the two, because it wants to successfully tell real data from generated data.

The generator network aims to minimize this difference (the estimated distance between real and generated data), because it wants the generated data to be as realistic as possible.
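In case the image does not render, here is the objective as we understand it from the WGAN paper, written with the symbols used in this notebook (D for the critic, G for the generator); the exact notation in the image may differ. The Kantorovich-Rubinstein duality expresses the Wasserstein-1 distance as a supremum over 1-Lipschitz functions f, and WGAN turns it into a min-max game:

$$
W(P_r, P_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]
$$

$$
\min_G \max_{D \in \mathcal{D}} \; \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim p(z)}[D(G(z))]
$$

Here $\mathcal{D}$ is the set of 1-Lipschitz functions. With weight clipping, the max is in practice taken over critics whose weights lie in $[-c, c]$, which are K-Lipschitz for some K, so the game estimates $K \cdot W(P_r, P_g)$.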

---

## Jensen Shannon Divergence (JSD)


The objective function of our original GAN amounts, when the discriminator is optimal, to minimizing something called the Jensen Shannon Divergence (JSD) between the real and generated distributions. Specifically it is:

![Imgur](https://imgur.com/kYc2Cfv.png)
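For reference, in case the image does not render, the standard definition of the JSD between the real and generator distributions, and its link to the original GAN objective when the discriminator is optimal (as shown in the original GAN paper), is:

$$
\mathrm{JSD}(P_r \,\|\, P_g) = \tfrac{1}{2}\,\mathrm{KL}(P_r \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(P_g \,\|\, M), \qquad M = \tfrac{1}{2}(P_r + P_g)
$$

$$
\max_D \; \mathbb{E}_{x \sim P_r}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] = 2\,\mathrm{JSD}(P_r \,\|\, P_g) - \log 4
$$

Since the JSD is bounded above by log 2, it saturates when the two distributions barely overlap, which is the failure mode the Wasserstein distance avoids.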

---

Sadly, Wasserstein GAN is not perfect. Even the authors of the original WGAN paper mentioned that “Weight clipping is a clearly terrible way to enforce a Lipschitz constraint” (Oops!). WGAN can still suffer from unstable training, slow convergence after weight clipping (when the clipping window is too large), and vanishing gradients (when the clipping window is too small).

One of the most prominent improvements proposed since is to replace weight clipping with a gradient penalty (WGAN-GP).
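A minimal PyTorch sketch of the gradient penalty idea, assuming a critic that maps a batch to one score per sample. It follows the WGAN-GP recipe as we understand it (Gulrajani et al., 2017): sample points on straight lines between real and generated samples and penalize the squared deviation of the critic's gradient norm from 1, weighted by λ = 10. The function name `gradient_penalty` and the `lambda_gp` parameter are our own; the term would be added to the critic loss in place of the weight-clipping step.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """lambda_gp * mean((||grad_x critic(x_hat)||_2 - 1)^2) at random interpolates x_hat."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dimensions
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Detach both inputs so x_hat is a fresh leaf tensor we can differentiate with respect to
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()


# Usage inside a critic step (sketch), replacing the weight-clipping loop:
# d_loss = -torch.mean(critic(real_imgs)) + torch.mean(critic(fake_imgs)) \
#          + gradient_penalty(critic, real_imgs, fake_imgs)
```

A quick sanity check: for a linear critic whose weight vector has norm 5, the gradient is the same everywhere, so the penalty is 10 · (5 − 1)² = 160 regardless of the interpolation points; a unit-norm weight vector gives a penalty of (almost) zero.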

## Implementation from scratch

```python
from google.colab import drive
drive.mount('/content/drive')
```


```python
import os

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

from tqdm import tqdm

plt.ion()
from IPython.display import clear_output
```

```
!nvidia-smi
```

```
Tue Mar 22 20:37:53 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Proces
```

## HYPERPARAMETERS

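```python
class Hyperparameters(object):
    """Simple attribute container: Hyperparameters(a=1).a == 1."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

hp = Hyperparameters(
    n_epochs=200,
    batch_size=64,        # the WGAN paper's default batch size is also 64
    lr=0.00005,           # RMSProp learning rate recommended in the WGAN paper
    n_cpu=8,              # not used below
    latent_dim=100,       # size of the noise vector z
    img_size=32,          # Fashion-MNIST images (28x28) are resized to 32x32
    channels=1,           # grayscale
    n_critic=25,          # generator updated once every n_critic batches (paper default: 5)
    clip_value=.005,      # critic weights clipped to [-c, c] (paper default: c = 0.01)
    sample_interval=400,  # show generated samples every 400 batches
)

print(hp.lr)
```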

```python
root_path = '/content/drive/MyDrive/All_Datasets/Fashion_MNIST'

# The Fashion-MNIST dataset contains 60,000 training images (and 10,000 test images) of
# fashion and clothing items, taken from 10 classes. Each image is a standardized 28x28
# grayscale image (784 total pixels). Here each image is resized to hp.img_size x hp.img_size
# and normalized from [0, 1] to [-1, 1] to match the generator's Tanh output.
dataloader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        root_path,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(hp.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=hp.batch_size,
    shuffle=True,
)
```

## SETUP

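```python
# os.makedirs("images", exist_ok=True)
img_shape = (hp.channels, hp.img_size, hp.img_size)

cuda = torch.cuda.is_available()

def weights_init_normal(m):
    # DCGAN-style initialization. It only matches "Conv" and "BatchNorm2d" layers, so it
    # leaves the Linear and BatchNorm1d layers of the MLP models below at their defaults.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def to_img(x):
    # Generator outputs (Tanh) and normalized images lie in [-1, 1]; map them back to [0, 1]
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    return x

def visualise_output(images, nrow, padding):
    with torch.no_grad():
        images = images.cpu()
        images = to_img(images)
        # nrow images per grid row, `padding` pixels between images
        np_imagegrid = make_grid(images, nrow=nrow, padding=padding).numpy()
        figure(figsize=(20, 20))
        plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
        plt.show()
```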

## GENERATOR

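The generator model takes as input a point in the latent space (a noise vector of size latent_dim = 100) and outputs a single grayscale image of shape img_shape, i.e. 1×32×32 here, since the 28×28 Fashion-MNIST images are resized to 32×32.

This is achieved with a stack of fully connected layers that interpret the point in the latent space.

Each hidden layer doubles the number of features (128 → 256 → 512 → 1024), and a final linear layer with a Tanh activation produces the np.prod(img_shape) = 1024 pixel values, which are reshaped into an image.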

### np.prod()

Return the product of array elements over a given axis.

If the input array is empty, the method returns the neutral element of multiplication: 1.

By default, the axis is set to None, thereby calculating the product of all the elements in the given array. 
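A short illustration of `np.prod` as it is used in this notebook, where `img_shape` is (channels, img_size, img_size) = (1, 32, 32); the small matrix `m` is a made-up example of the `axis` argument.

```python
import numpy as np

img_shape = (1, 32, 32)              # (channels, height, width) as in this notebook
n_pixels = int(np.prod(img_shape))   # 1 * 32 * 32: width of the generator's last linear layer
empty = np.prod([])                  # neutral element of multiplication
m = np.array([[1, 2], [3, 4]])
col_products = np.prod(m, axis=0)    # product down each column
row_products = np.prod(m, axis=1)    # product along each row
print(n_pixels, empty, col_products, row_products)
```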

```python
class Generator(nn.Module):
    def __init__(self, img_shape, latent_dim):
        super(Generator, self).__init__()
        # Keep the output shape so that forward() only needs the noise vector z
        self.img_shape = img_shape

        def block(in_features, out_features, normalize=True):
            layers = [nn.Linear(in_features, out_features)]
            if normalize:
                # The second positional argument of BatchNorm1d is eps, so this sets eps=0.8
                layers.append(nn.BatchNorm1d(out_features, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(
                in_features=latent_dim, out_features=128, normalize=False
            ),  # Batch_size, latent_dim -> Batch_size, 128
            *block(
                in_features=128, out_features=256
            ),  # Batch_size, 128 -> Batch_size, 256
            *block(
                in_features=256, out_features=512
            ),  # Batch_size, 256 -> Batch_size, 512
            *block(
                in_features=512, out_features=1024
            ),  # Batch_size, 512 -> Batch_size, 1024
            nn.Linear(
                in_features=1024, out_features=int(np.prod(img_shape))
            ),  # Batch_size, 1024 -> Batch_size, np.prod(img_shape)
            nn.Tanh()  # pixel values in [-1, 1], matching the normalized training images
        )

    def forward(self, z):
        img = self.model(z)
        # Reshape the flat pixel vector into (Batch_size, channels, height, width)
        img = img.view(img.shape[0], *self.img_shape)
        return img
```

## DISCRIMINATOR

In WGAN the discriminator is called the critic. There is no sigmoid activation function at the end to squash its output into the range (0, 1), i.e. a probability of real versus fake, so the critic outputs a scalar score rather than a probability.

```python
class Critic(nn.Module):
    def __init__(self, img_shape):
        super(Critic, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(
                in_features=int(np.prod(img_shape)), out_features=512
            ),  # Batch_size, np.prod(img_shape) -> Batch_size, 512
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(
                in_features=512, out_features=256
            ),  # Batch_size, 512 -> Batch_size, 256
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(
                in_features=256, out_features=1
            ),  # Batch_size, 256 -> Batch_size, 1
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity
```

## LOSS and MODELS

```python
generator = Generator(img_shape, hp.latent_dim)
critic = Critic(img_shape)

if cuda:
    generator.cuda()
    critic.cuda()

# Initialize weights
generator.apply(weights_init_normal)
critic.apply(weights_init_normal)
```

```
Critic(
  (model): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)
```

## OPTIMIZERS and TENSOR SETUP

```python
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=hp.lr)
optimizer_D = torch.optim.RMSprop(critic.parameters(), lr=hp.lr)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
```

## TRAINING STEPS

1. The critic network is trained on a real batch of data and on a batch of data generated by the generator from a noise prior; both enter a single critic loss.

2. The critic's loss function is arranged so that it estimates the Wasserstein distance (the critic maximizes the difference between its average scores on the two distributions). The critic then clips its own weights so that it stays K-Lipschitz for some constant K, which means the distance is estimated up to a scale factor.

3. Then, the generator generates a new batch of images from the noise prior and passes them through the critic, which "informs" the generator of the (estimated) Wasserstein-1 distance between the true distribution and the distribution of the images the generator just created.

4. It does this through the generator loss, the negative of the critic's average score on the fake images. Only the generator's optimizer takes a step here, so the critic's weights stay fixed while the error propagates back through the critic to the generator, which then updates its parameters to minimize the Wasserstein distance.

5. This repeats until the critic's loss (hopefully) approaches zero and the distributions are approximately equal.

6. The critic loss is (an approximation of) the negative Wasserstein distance between the generator distribution and the data distribution.

```python
for epoch in range(hp.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Configure input (Variable wrappers are not needed in modern PyTorch)
        real_imgs = imgs.type(Tensor)

        # -----------------
        #  Train Critic
        # -----------------

        # Clear the critic's gradients left over from the previous iteration
        optimizer_D.zero_grad()

        # Sample noise as generator input: z ~ N(0, I), shape (batch_size, latent_dim)
        z = torch.randn(imgs.shape[0], hp.latent_dim).type(Tensor)

        # Generate a batch of images; detach so this step does not update the generator
        fake_imgs = generator(z).detach()

        # The critic maximizes D(x) - D(G(z)), as in the paper.
        # Maximizing an expression is the same as minimizing its negative, so the loss is
        # -(D(x) - D(G(z))) = -D(real_imgs) + D(fake_imgs), averaged over the batch.
        d_loss = -torch.mean(critic(real_imgs)) + torch.mean(critic(fake_imgs))

        d_loss.backward()
        optimizer_D.step()

        # Clip the critic's weights to the box [-c, c] after each update. Keeping the parameters
        # in a compact set keeps the critic K-Lipschitz for some constant K, as the Wasserstein
        # estimate requires. clamp_(min, max) replaces values below min with min and values above
        # max with max, in place.
        for p in critic.parameters():
            p.data.clamp_(-hp.clip_value, hp.clip_value)

        # Train the generator only every n_critic iterations, so that the critic gets more
        # updates and approximates the Wasserstein distance more closely between generator steps.
        if i % hp.n_critic == 0:
            # ---------------------
            #  Train Generator
            # ---------------------
            optimizer_G.zero_grad()

            # Sample fresh noise and generate a batch of images (no detach: gradients must reach G)
            z = torch.randn(imgs.shape[0], hp.latent_dim).type(Tensor)
            fake_images_from_generator = generator(z)
            # Adversarial loss: maximize the critic's score on fakes = minimize its negative
            g_loss = -torch.mean(critic(fake_images_from_generator))

            g_loss.backward()
            # Only the generator is stepped; the critic gradients produced by this backward pass
            # are cleared by optimizer_D.zero_grad() at the start of the next iteration
            optimizer_G.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % hp.sample_interval == 0:
            clear_output()
            print(f"Epoch:{epoch}:It{i}:DLoss{d_loss.item()}:GLoss{g_loss.item()}")
            visualise_output(fake_images_from_generator.data[:50], 10, 10)
```