In [None]:
# Colab magic selecting the TensorFlow 1.x runtime; it is placed before the Keras imports below.
%tensorflow_version 1.x

# Fetch a fresh clone of the gan-tools repository (any previous clone is deleted first)
# and cd into it, so that `core.gan` and `core.constraint` imported below come from that clone.
%cd /content
! rm -rf gan-tools
!git clone --single-branch --depth=1 --branch master https://github.com/hannesdm/gan-tools.git
%cd gan-tools
from keras.datasets import mnist
from keras import initializers
from keras import Sequential
from keras.layers import  Dense
from core.gan import GAN, WGAN
from core import constraint
import matplotlib.pyplot as plt
# Plot the generated 28x28 samples in grayscale and without axis grid lines.
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['axes.grid'] = False


## Wasserstein GAN
We will train a standard GAN and two variants of the Wasserstein GAN (with weight clipping and with gradient penalty) on the MNIST data. <br/>
All variants use a relatively simple fully connected architecture to allow for fast training. This inevitably produces worse results than larger or specialized models (cf. DCGAN). <br/>
The Wasserstein GAN implementation follows the paper of Arjovsky et al. <br/>
You are encouraged to change the parameters and architecture of the models. If you do, do **not** change the **input_dim**, the **final layer**, or, for the weight-clipping version, the **wasserstein_params**. <br/>
**Exercise** Compare the performance of the different GAN formulations over the course of training. Do you see an improvement in the stability of training and in the quality of the generated samples? <br/>Elaborate based on
the knowledge you have gained about optimal transport and the Wasserstein distance.

In [None]:
# Only the training images are used: the GANs are trained unsupervised,
# so the labels and the test split are not needed further on.
(mnist_images, Y_train_mnist), (_, _) = mnist.load_data()

# Flatten every 28x28 image into a 784-vector for the fully connected models,
# then map pixel values 0..255 linearly onto [-1, 1], the output range of the
# generator's tanh layer.
mnist_flat = mnist_images.reshape(-1, 28 * 28).astype('float32')
X_train_mnist = mnist_flat / 127.5 - 1

In [None]:
# Per-layer settings for the weight-clipping WGAN critic (Arjovsky et al.):
#   - kernels start from a normal initialiser with stddev 0.02;
#   - kernels AND biases are constrained by WeightClipping to [-0.01, 0.01],
#     which keeps the critic's parameters in a compact set.
# Do not change these for the weight-clipping version (see the notes above).
kernel_initializer = initializers.RandomNormal(stddev=0.02)
weight_clipping = constraint.WeightClipping(c1=-0.01, c2=0.01)
wasserstein_params = dict(
    kernel_initializer=kernel_initializer,
    kernel_constraint=weight_clipping,
    bias_constraint=weight_clipping,
)

def mnist_generator_model():
  """Build the fully connected MNIST generator.

  Maps a 100-dimensional latent vector through two 256-unit ReLU layers to a
  784-unit tanh output, so generated pixels lie in [-1, 1], the same range as
  the rescaled training images.

  Returns:
    A keras ``Sequential`` model.
  """
  layer_specs = (
      (256, dict(input_dim=100, activation='relu')),
      (256, dict(activation='relu')),
      (784, dict(activation='tanh')),
  )
  generator = Sequential()
  for units, options in layer_specs:
    generator.add(Dense(units, **options))
  return generator

def _mnist_discriminator_stack(output_activation, **layer_kwargs):
  """Build the shared 784 -> 256 -> 256 -> 1 fully connected discriminator.

  All discriminator/critic variants below use this single architecture, so the
  GAN vs. WGAN comparison differs only in the output activation and in the
  per-layer options, never in network capacity. Architecture changes made here
  apply to every variant at once.

  Args:
    output_activation: activation of the single output unit ('sigmoid' for a
      real/fake classifier, 'linear' for a Wasserstein critic).
    **layer_kwargs: extra keyword arguments forwarded to every Dense layer
      (e.g. kernel initialiser and weight constraints).

  Returns:
    A keras ``Sequential`` model.
  """
  discriminator = Sequential()
  discriminator.add(Dense(256, input_dim=784, activation='relu', **layer_kwargs))
  discriminator.add(Dense(256, activation='relu', **layer_kwargs))
  discriminator.add(Dense(1, activation=output_activation, **layer_kwargs))
  return discriminator


def mnist_discriminator_model():
  """Standard GAN discriminator: a classifier with a sigmoid probability output."""
  return _mnist_discriminator_stack('sigmoid')


def mnist_wgan_discriminator_model():
  """Unconstrained WGAN critic with a linear (unbounded) score output.

  Used for the gradient-penalty WGAN, where the Lipschitz constraint is
  imposed through the training objective rather than on the layers.
  """
  return _mnist_discriminator_stack('linear')


def mnist_weight_clipping_discriminator_model():
  """WGAN critic with a linear output whose layers all use `wasserstein_params`.

  `wasserstein_params` supplies the kernel initialiser and the kernel/bias
  weight-clipping constraints (Arjovsky et al.).
  """
  return _mnist_discriminator_stack('linear', **wasserstein_params)

## Train the standard GAN
The parameters **batches**, **batch_size** and **plot_interval** may be changed if wanted. <br/>
Remember that the execution may be interrupted at any time by clicking the stop button or by selecting the 'interrupt execution' option in the runtime menu.

In [None]:
# Standard GAN: sigmoid discriminator, library-default losses.
gan_discriminator = mnist_discriminator_model()
gan_generator = mnist_generator_model()
mnist_gan = GAN(discriminator=gan_discriminator, generator=gan_generator)

mnist_gan.train_random_batches(
    X_train_mnist,
    batches=20000,
    batch_size=64,
    plot_interval=500,
    image_shape=(28, 28),
)

## Train the Wasserstein GAN with Weight Clipping
In the Wasserstein formulation the discriminator acts as a critic instead of a classifier: its output is an unbounded score rather than a probability.
The original way to enforce the Lipschitz constraint on the critic is to keep its weights in a compact space. This is done by clipping the weights (here to $[-0.01, 0.01]$) after each gradient update. This implementation follows the work of Arjovsky et al. See https://arxiv.org/pdf/1701.07875.pdf  <br/>

The parameters **batches**, **batch_size** and **plot_interval** may be changed if wanted. <br/>
Remember that the execution may be interrupted at any time by clicking the stop button or by selecting the 'interrupt execution' option in the runtime menu.

In [None]:
# WGAN with weight clipping: the plain GAN class driven by Wasserstein losses,
# with a critic whose weights are clipped via its layer constraints.
# nr_train_discriminator=5 presumably means 5 critic updates per generator
# update (as in Arjovsky et al.) -- TODO confirm against core.gan.
clipping_critic = mnist_weight_clipping_discriminator_model()
clipping_generator = mnist_generator_model()
mnist_wgan = GAN(
    discriminator=clipping_critic,
    generator=clipping_generator,
    gen_loss='wasserstein',
    dis_loss='wasserstein',
)

mnist_wgan.train_random_batches(
    X_train_mnist,
    batches=20000,
    batch_size=64,
    plot_interval=500,
    image_shape=(28, 28),
    nr_train_discriminator=5,
)

## Train the Wasserstein GAN with Gradient Penalty
A more natural way of enforcing the Lipschitz constraint in the Wasserstein GAN formulation is by penalizing the norm of the gradient of the critic. This implementation follows the work of Gulrajani et al. See https://arxiv.org/pdf/1704.00028.pdf <br/>
The parameters **batches**, **batch_size** and **plot_interval** may be changed if wanted. <br/>
Remember that the execution may be interrupted at any time by clicking the stop button or by selecting the 'interrupt execution' option in the runtime menu.

In [None]:
# WGAN-GP: the WGAN class with an unconstrained linear critic; per the notes
# above, the Lipschitz constraint comes from a gradient-norm penalty instead
# of weight clipping.
gp_critic = mnist_wgan_discriminator_model()
gp_generator = mnist_generator_model()
mnist_wgan_gp = WGAN(
    discriminator=gp_critic,
    generator=gp_generator,
    gen_loss='wasserstein',
    dis_loss='wasserstein',
)

mnist_wgan_gp.train_random_batches(
    X_train_mnist,
    batches=20000,
    batch_size=64,
    plot_interval=500,
    image_shape=(28, 28),
    nr_train_discriminator=5,
)