# WGAN on MNIST

In [1]:
import os
import numpy as np
import h5py
from tqdm import tqdm
from PIL import Image
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers.merge import _Merge
from keras.layers.convolutional import Convolution2D, Conv2DTranspose
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from keras.datasets import mnist
from keras import backend as K
from functools import partial

Using TensorFlow backend.


## Define Constants

In [2]:
BATCH_SIZE = 64
TRAINING_RATIO = 5  # The training ratio is the number of discriminator updates per generator update. The paper uses 5.
GRADIENT_PENALTY_WEIGHT = 10  # As per the paper
output_dir = "../Data/output/MNIST_WGAN/"

## Define Loss

In [3]:
def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)


def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
    gradients = K.gradients(K.sum(y_pred), averaged_samples)
    gradient_l2_norm = K.sqrt(K.sum(K.square(gradients)))
    gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
    return gradient_penalty


class RandomWeightedAverage(_Merge):
    """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line
    between each pair of input points.

    Inheriting from _Merge is a little messy but it was the quickest solution I could think of.
    Improvements appreciated."""

    def _merge_function(self, inputs):
        weights = K.random_uniform((BATCH_SIZE, 1, 1, 1))
        return (weights * inputs[0]) + ((1 - weights) * inputs[1])

## Generator Architecture

In [4]:
def make_generator():
    model = Sequential()
    model.add(Dense(1024, input_dim=100))
    model.add(LeakyReLU())
    model.add(Dense(128 * 7 * 7))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,)))
    bn_axis = -1
    model.add(Conv2DTranspose(128, (5, 5), strides=2, padding='same'))
    model.add(BatchNormalization(axis=bn_axis))
    model.add(LeakyReLU())
    model.add(Convolution2D(64, (5, 5), padding='same'))
    model.add(BatchNormalization(axis=bn_axis))
    model.add(LeakyReLU())
    model.add(Conv2DTranspose(64, (5, 5), strides=2, padding='same'))
    model.add(BatchNormalization(axis=bn_axis))
    model.add(LeakyReLU())
    model.add(Convolution2D(1, (5, 5), padding='same', activation='tanh'))
    return model

generator = make_generator()
generator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 1024)              103424    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 6272)              6428800   
_________________________________________________________________
batch_normalization_1 (Batch (None, 6272)              25088     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 6272)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 128)       409728    
__________

## Discriminator Architecture

In [5]:
def make_discriminator():
    model = Sequential()
    model.add(Convolution2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)))
    model.add(LeakyReLU())
    model.add(Convolution2D(128, (5, 5), kernel_initializer='he_normal', strides=[2, 2]))
    model.add(LeakyReLU())
    model.add(Convolution2D(128, (5, 5), kernel_initializer='he_normal', padding='same', strides=[2, 2]))
    model.add(LeakyReLU())
    model.add(Flatten())
    model.add(Dense(1024, kernel_initializer='he_normal'))
    model.add(LeakyReLU())
    model.add(Dense(1, kernel_initializer='he_normal'))
    return model

discriminator = make_discriminator()
discriminator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_3 (Conv2D)            (None, 28, 28, 64)        1664      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 12, 12, 128)       204928    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 12, 12, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 6, 6, 128)         409728    
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 6, 6, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 4608)              0         
__________

## Helper functions for generating images

In [6]:
def tile_images(image_stack):
    """Given a stacked tensor of images, reshapes them into a horizontal tiling for display."""
    assert len(image_stack.shape) == 3
    image_list = [image_stack[i, :, :] for i in range(image_stack.shape[0])]
    tiled_images = np.concatenate(image_list, axis=1)
    return tiled_images

def generate_images(generator_model, output_dir, epoch):
    """Feeds random seeds into the generator and tiles and saves the output to a PNG file."""
    test_image_stack = generator_model.predict(np.random.rand(10, 100))
    test_image_stack = (test_image_stack * 127.5) + 127.5
    test_image_stack = np.squeeze(np.round(test_image_stack).astype(np.uint8))
    tiled_output = tile_images(test_image_stack)
    tiled_output = Image.fromarray(tiled_output, mode='L')  # L specifies greyscale
    outfile = os.path.join(output_dir, 'epoch_{}.png'.format(epoch))
    tiled_output.save(outfile)

## Load Data

In [7]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.concatenate((X_train, X_test), axis=0)
X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], X_train.shape[2], 1))
X_train = (X_train.astype(np.float32) - 127.5) / 127.5

## Compiling the Generator Model

In [8]:
for layer in discriminator.layers:
    layer.trainable = False
discriminator.trainable = False
generator_input = Input(shape=(100,))
generator_layers = generator(generator_input)
discriminator_layers_for_generator = discriminator(generator_layers)
generator_model = Model(inputs=[generator_input], outputs=[discriminator_layers_for_generator])
# We use the Adam paramaters from Gulrajani et al.
generator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9), loss=wasserstein_loss)

## Compiling the Discriminator Model

In [None]:
for layer in discriminator.layers:
    layer.trainable = True
for layer in generator.layers:
    layer.trainable = False
discriminator.trainable = True
generator.trainable = False

real_samples = Input(shape=X_train.shape[1:])
generator_input_for_discriminator = Input(shape=(100,))
generated_samples_for_discriminator = generator(generator_input_for_discriminator)
discriminator_output_from_generator = discriminator(generated_samples_for_discriminator)
discriminator_output_from_real_samples = discriminator(real_samples)

averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_discriminator])
averaged_samples_out = discriminator(averaged_samples)
partial_gp_loss = partial(gradient_penalty_loss,averaged_samples=averaged_samples,gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
partial_gp_loss.__name__ = 'gradient_penalty'

discriminator_model = Model(inputs=[real_samples, generator_input_for_discriminator],
                            outputs=[discriminator_output_from_real_samples,
                                     discriminator_output_from_generator,
                                     averaged_samples_out])
discriminator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),loss=[wasserstein_loss,wasserstein_loss,partial_gp_loss])

positive_y = np.ones((BATCH_SIZE, 1), dtype=np.float32)
negative_y = -positive_y
dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)

## Training the model

In [None]:
for epoch in range(100):
    np.random.shuffle(X_train)
    discriminator_loss = []
    generator_loss = []
    minibatches_size = BATCH_SIZE * TRAINING_RATIO
    num_minibatches = int(X_train.shape[0] // (BATCH_SIZE * TRAINING_RATIO))
    print("Epoch {}:".format(epoch+1))
    for i in tqdm(range(num_minibatches),total=num_minibatches):
        discriminator_minibatches = X_train[i * minibatches_size:(i + 1) * minibatches_size]
        for j in range(TRAINING_RATIO):
            image_batch = discriminator_minibatches[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
            noise = np.random.rand(BATCH_SIZE, 100).astype(np.float32)
            discriminator_loss.append(discriminator_model.train_on_batch([image_batch, noise],[positive_y, negative_y, dummy_y]))
        generator_loss.append(generator_model.train_on_batch(np.random.rand(BATCH_SIZE, 100), positive_y))
    generate_images(generator, output_dir, epoch+1)
    generator_loss_epoch = np.average(generator_loss[-(num_minibatches):],axis=0)
    discriminator_loss_epoch = np.average(discriminator_loss[-(num_minibatches*TRAINING_RATIO):],axis=0)
    print("Generator Loss: {}".format(generator_loss_epoch))
    print("Discriminator Loss: {}".format(discriminator_loss_epoch))
    generator_model.save("../Models/MNIST_WGAN/G-epoch-{:3d}.hdf5".format(epoch+1))
    discriminator_model.save("../Models/MNIST_WGAN/D-epoch-{:3d}.hdf5".format(epoch+1))

Epoch 1:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.5396284461021423
Discriminator Loss: [-0.61788917 -1.24658453  0.56703389  0.06166336]
Epoch 2:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.5013314485549927
Discriminator Loss: [-0.38489708 -0.91737449  0.50903988  0.02343691]
Epoch 3:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.450283408164978
Discriminator Loss: [-0.36606601 -0.86022884  0.47420239  0.01996152]
Epoch 4:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: -0.4235396981239319
Discriminator Loss: [-0.34667036 -0.79395664  0.43136528  0.01592029]
Epoch 5:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.43it/s]


Generator Loss: -0.33615216612815857
Discriminator Loss: [-0.32206663 -0.69238251  0.35576516  0.01455056]
Epoch 6:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: -0.2569199502468109
Discriminator Loss: [-0.30700073 -0.59622639  0.27568665  0.01353928]
Epoch 7:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.20881876349449158
Discriminator Loss: [-0.29184848 -0.52956587  0.22491166  0.01280598]
Epoch 8:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.16668200492858887
Discriminator Loss: [-0.28049749 -0.47307259  0.18063389  0.01194128]
Epoch 9:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: -0.14165599644184113
Discriminator Loss: [-0.26968783 -0.43878701  0.15737662  0.01172269]
Epoch 10:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.43it/s]


Generator Loss: -0.12847304344177246
Discriminator Loss: [-0.26067263 -0.40689969  0.1352742   0.01095229]
Epoch 11:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: -0.0698762834072113
Discriminator Loss: [-0.25224736 -0.35004592  0.08615926  0.01163901]
Epoch 12:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: -0.03977135196328163
Discriminator Loss: [-0.24826716 -0.31474966  0.05638522  0.01009737]
Epoch 13:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: 0.009271916002035141
Discriminator Loss: [-0.24133934 -0.26355374  0.01106608  0.01114838]
Epoch 14:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: 0.049115102738142014
Discriminator Loss: [-0.23893595 -0.22008865 -0.02934551  0.01049826]
Epoch 15:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.0842994973063469
Discriminator Loss: [-0.23285307 -0.17116615 -0.07244975  0.01076256]
Epoch 16:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: 0.10577556490898132
Discriminator Loss: [-0.23002678 -0.14402629 -0.09731837  0.01131793]
Epoch 17:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: 0.13532654941082
Discriminator Loss: [-0.22711679 -0.12030657 -0.11831786  0.01150752]
Epoch 18:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:32<00:00,  1.43it/s]


Generator Loss: 0.1462278813123703
Discriminator Loss: [-0.2220265  -0.10230811 -0.13151738  0.01179936]
Epoch 19:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:33<00:00,  1.42it/s]


Generator Loss: 0.14251169562339783
Discriminator Loss: [-0.2173945  -0.10061971 -0.12759754  0.01082269]
Epoch 20:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:33<00:00,  1.42it/s]


Generator Loss: 0.14812885224819183
Discriminator Loss: [-0.21335985 -0.08958244 -0.13474038  0.01096285]
Epoch 21:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.14004650712013245
Discriminator Loss: [-0.2108441  -0.09935175 -0.12212736  0.01063487]
Epoch 22:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.12277690321207047
Discriminator Loss: [-0.20729339 -0.11193161 -0.10619096  0.01082935]
Epoch 23:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.07428920269012451
Discriminator Loss: [-0.20465165 -0.15689531 -0.05838472  0.01062831]
Epoch 24:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.08835750818252563
Discriminator Loss: [-0.20351806 -0.14452568 -0.06976949  0.01077726]
Epoch 25:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.04587677866220474
Discriminator Loss: [-0.19840193 -0.17679669 -0.032639    0.01103365]
Epoch 26:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.013236289843916893
Discriminator Loss: [-0.19739275 -0.21633761  0.00891085  0.0100339 ]
Epoch 27:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.022061537951231003
Discriminator Loss: [-0.19500129 -0.19780962 -0.00744497  0.0102533 ]
Epoch 28:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: 0.005772297736257315
Discriminator Loss: [-0.19203815 -0.21765736  0.01562775  0.00999145]
Epoch 29:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.008473298512399197
Discriminator Loss: [-0.19202666 -0.22651914  0.02500597  0.00948649]
Epoch 30:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.030697491019964218
Discriminator Loss: [-0.18922736 -0.23751694  0.03912489  0.0091647 ]
Epoch 31:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.051591865718364716
Discriminator Loss: [-0.18705903 -0.26593331  0.06974408  0.00912992]
Epoch 32:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.042930833995342255
Discriminator Loss: [-0.18410754 -0.2547217   0.0616603   0.0089539 ]
Epoch 33:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:30<00:00,  1.44it/s]


Generator Loss: -0.048885539174079895
Discriminator Loss: [-0.18427294 -0.25177819  0.05816985  0.00933552]
Epoch 34:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.03502267599105835
Discriminator Loss: [-0.17994045 -0.239508    0.05073221  0.00883544]
Epoch 35:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.049883149564266205
Discriminator Loss: [-0.18229382 -0.25195825  0.06087526  0.00878913]
Epoch 36:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.07688833773136139
Discriminator Loss: [-0.17856875 -0.27954924  0.0924435   0.00853705]
Epoch 37:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.10050281882286072
Discriminator Loss: [-0.17953968 -0.30631423  0.11822975  0.00854441]
Epoch 38:


100%|████████████████████████████████████████████████████████████████████████████████| 218/218 [02:31<00:00,  1.44it/s]


Generator Loss: -0.10748814791440964
Discriminator Loss: [-0.17787424 -0.3175067   0.13115342  0.00847915]
Epoch 39:


 34%|███████████████████████████▍                                                     | 74/218 [00:51<01:39,  1.44it/s]