In [27]:
from tensorflow.keras.layers import Input,LeakyReLU, Activation, Dense, Conv2D, Flatten, Reshape, Conv2DTranspose,BatchNormalization
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

import numpy as np
import math
import matplotlib.pyplot as plt
import os

# import sys
# sys.path.append("..")
# from lib import gan

In [28]:
def plot_images(generator, noise_input, noise_label=None, noise_codes=None, show=False, step=0, model_name="gan"):
    """Generate a grid of sample images and save it as ``<model_name>/<step>.png``.

    Args:
        generator: model whose ``predict`` maps the noise input(s) to a batch of
            images. Each image is reshaped to ``(image_size, image_size)``, so a
            square, single-channel output is assumed.
        noise_input: array of latent vectors; ``noise_input.shape[0]`` is the
            number of images drawn.
        noise_label: optional second generator input. When given, ``predict``
            is called with the list ``[noise_input, noise_label]``.
        noise_codes: optional list of further generator inputs appended after
            ``noise_label``; ignored when ``noise_label`` is None.
        show: if True, display the figure with ``plt.show()``; otherwise close
            all figures after saving.
        step: training step, used as the zero-padded (5-digit) file name.
        model_name: directory the PNG is written to (created if missing).
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    # Smallest square grid that holds every sample. Flooring sqrt() would make
    # plt.subplot fail whenever the sample count is not a perfect square.
    rows = math.ceil(math.sqrt(noise_input.shape[0]))
    if noise_label is not None:
        # Multi-input generators are fed a list of inputs.
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

In [29]:
# Load only the MNIST training images; labels and the test split are unused here.
(x_train, _), (_, _) = mnist.load_data()
# Side length of the square images.
image_size = x_train.shape[1]
# Add a trailing channel axis for the CNN -> (N, image_size, image_size, 1),
# then scale pixel values from [0, 255] into [0, 1].
x_train = x_train.reshape(-1, image_size, image_size, 1).astype('float32') / 255

In [30]:
# Name used for the adversarial model, the image output directory and the saved file.
model_name = "wgan_mnist"

# Network parameters: the latent (z) vector fed to the generator is 100-dim.
latent_size = 100

# Training hyper-parameters, as given in the WGAN paper [2] (Arjovsky et al., 2017).
n_critic = 5         # critic (discriminator) updates per generator update
clip_value = 0.01    # critic weights are clipped to [-clip_value, clip_value]
batch_size = 64
lr = 5e-5            # RMSprop learning rate
train_steps = 40000

# Shape of a single image: height x width x 1 grayscale channel.
input_shape = (image_size, image_size, 1)

In [31]:
def wasserstein_loss(y_label, y_pred):
    """Wasserstein loss: the negated mean of ``y_label * y_pred``.

    With targets of +1 for real and -1 for fake samples (as this notebook's
    training loop uses), minimizing this raises the critic's score on real
    images and lowers it on fake ones.
    """
    signed_scores = y_label * y_pred
    return -K.mean(signed_scores)

In [32]:
# build discriminator (critic) model
inputs = Input(shape=input_shape, name='discriminator_input')

x = LeakyReLU(alpha=0.2)(inputs)
x = Conv2D(filters=32, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(filters=64, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(filters=128, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(filters=256, kernel_size=5, strides=1, padding='same')(x)
x = Flatten()(x)
outputs = Dense(1)(x)
# WGAN uses a linear (identity) output in the paper [2]: the critic emits an
# unbounded score rather than a probability.
outputs = Activation(activation='linear')(outputs)
discriminator = Model(inputs, outputs, name='discriminator')

# `learning_rate` is the current keyword; the old `lr` alias is deprecated and
# rejected by newer Keras optimizers.
optimizer = RMSprop(learning_rate=lr)
# WGAN discriminator uses the Wasserstein loss. 'accuracy' is kept only so that
# train_on_batch returns (loss, acc) for the training loop; with +1/-1 targets
# and an unbounded score it is not a meaningful classification accuracy.
discriminator.compile(loss=wasserstein_loss, optimizer=optimizer, metrics=['accuracy'])
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa [(None, 28, 28, 1)]       0         
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 32)        832       
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 64)          51264     
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 4, 4, 128)       

In [33]:
# build generator model
# The latent shape gets its own name: reassigning the global `input_shape` here
# would clobber the image shape the discriminator cell reads, so re-running that
# cell after this one would build a critic for (latent_size,) inputs.
latent_shape = (latent_size, )
# Two stride-2 transposed convolutions upsample by 4x back to image_size.
image_resize = image_size // 4

inputs = Input(shape=latent_shape, name='z_input')
# NOTE(review): the seed feature map is a single 7x7 channel (Dense width
# image_resize**2 * 1), a narrow bottleneck; confirm this is intentional.
x = Dense(image_resize * image_resize * 1)(inputs)
x = Reshape((image_resize, image_resize, 1))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=64, kernel_size=5, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=32, kernel_size=5, strides=1, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=1, kernel_size=5, strides=1, padding='same')(x)
# Sigmoid keeps generated pixels in [0, 1], the same range as the scaled real images.
x = Activation(activation='sigmoid')(x)

generator = Model(inputs, x, name='generator')
generator.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_input (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense_5 (Dense)              (None, 49)                4949      
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 1)           0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 7, 7, 1)           4         
_________________________________________________________________
activation_8 (Activation)    (None, 7, 7, 1)           0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 14, 14, 128)       3328      
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 128)       51

In [34]:
# build adversarial model = generator + discriminator
# freeze the weights of discriminator during adversarial training
discriminator.trainable = False
# Take the latent input from the generator itself instead of the global
# `inputs`, which the discriminator cell also assigns and would clobber on an
# out-of-order re-run.
z_input = generator.input
adversarial = Model(z_input, discriminator(generator(z_input)), name=model_name)
# Give the adversarial model its own optimizer: newer TF/Keras optimizers are
# built for one fixed set of variables and can fail when one instance is shared
# between the discriminator and the generator.
adversarial.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=lr), metrics=['accuracy'])
adversarial.summary()

Model: "wgan_mnist"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_input (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
generator (Model)            (None, 28, 28, 1)         266074    
_________________________________________________________________
discriminator (Model)        (None, 1)                 1080577   
Total params: 1,346,651
Trainable params: 265,624
Non-trainable params: 1,081,027
_________________________________________________________________


In [35]:
# the generator image is saved every 500 steps
save_interval = 500

# fixed noise vector to see how the generator output evolves during training
noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
# number of elements in train dataset
train_size = x_train.shape[0]
# labels for real data (fake data uses -real_labels)
real_labels = np.ones((batch_size, 1))
for i in range(train_steps):
    # train discriminator n_critic times
    loss = 0
    acc = 0
    for _ in range(n_critic):
        # train the discriminator for 1 batch
        # 1 batch of real (label=1.0) and fake images (label=-1.0)
        # randomly pick real images from dataset
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        # generate fake images from noise using generator
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        fake_images = generator.predict(noise)

        # train the discriminator network
        # real data label=1, fake data label=-1
        # instead of 1 combined batch of real and fake images,
        # train with 1 batch of real data first, then 1 batch
        # of fake images.
        # this tweak prevents the gradient from vanishing due to opposite
        # signs of real and fake data labels (i.e. +1 and -1) and
        # small magnitude of weights due to clipping.
        real_loss, real_acc = discriminator.train_on_batch(real_images, real_labels)
        fake_loss, fake_acc = discriminator.train_on_batch(fake_images, -real_labels)
        # accumulate average loss and accuracy
        loss += 0.5 * (real_loss + fake_loss)
        acc += 0.5 * (real_acc + fake_acc)

        # clip discriminator weights to satisfy Lipschitz constraint
        for layer in discriminator.layers:
            weights = layer.get_weights()
            weights = [np.clip(weight, -clip_value, clip_value) for weight in weights]
            layer.set_weights(weights)

    # average loss and accuracy per n_critic training iterations
    loss /= n_critic
    acc /= n_critic
    log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)

    # train the adversarial network for 1 batch
    # 1 batch of fake images with label=1.0
    # since the discriminator weights are frozen in adversarial network
    # only the generator is trained
    # generate noise using uniform distribution
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
    # train the adversarial network
    # note that unlike in discriminator training,
    # we do not save the fake images in a variable
    # the fake images go to the discriminator input of the adversarial
    # for classification
    # fake images are labelled as real
    # log the loss and accuracy
    loss, acc = adversarial.train_on_batch(noise, real_labels)
    log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
    print(log)

    # plot generator images on a periodic basis. The call must sit inside this
    # branch; otherwise a grid is rendered and written on every single step.
    # Only the final snapshot is displayed inline; the rest are saved to disk.
    if (i + 1) % save_interval == 0:
        show = (i + 1) == train_steps
        plot_images(generator, noise_input=noise_input, show=show, step=(i + 1), model_name=model_name)

0: [discriminator loss: 0.015030, acc: 0.000000] [adversarial loss: -0.000273, acc: 0.000000]
1: [discriminator loss: 0.000001, acc: 0.000000] [adversarial loss: -0.000154, acc: 0.000000]
2: [discriminator loss: -0.000145, acc: 0.000000] [adversarial loss: 0.000062, acc: 0.000000]
3: [discriminator loss: -0.000413, acc: 0.000000] [adversarial loss: 0.000465, acc: 0.000000]
4: [discriminator loss: -0.000969, acc: 0.000000] [adversarial loss: 0.001379, acc: 0.000000]
5: [discriminator loss: -0.002326, acc: 0.000000] [adversarial loss: 0.003738, acc: 0.000000]
6: [discriminator loss: -0.005621, acc: 0.000000] [adversarial loss: 0.008925, acc: 0.000000]
7: [discriminator loss: -0.013330, acc: 0.000000] [adversarial loss: 0.019076, acc: 0.000000]
8: [discriminator loss: -0.026463, acc: 0.000000] [adversarial loss: 0.033314, acc: 0.000000]
9: [discriminator loss: -0.049825, acc: 0.000000] [adversarial loss: 0.054909, acc: 0.000000]
10: [discriminator loss: -0.082861, acc: 0.000000] [adversar

85: [discriminator loss: -108.988280, acc: 0.104688] [adversarial loss: -40.464615, acc: 1.000000]
86: [discriminator loss: -113.913955, acc: 0.117188] [adversarial loss: -47.328335, acc: 1.000000]
87: [discriminator loss: -112.178378, acc: 0.109375] [adversarial loss: -50.422173, acc: 1.000000]
88: [discriminator loss: -111.910949, acc: 0.093750] [adversarial loss: -53.114223, acc: 1.000000]
89: [discriminator loss: -118.611156, acc: 0.118750] [adversarial loss: -54.235588, acc: 1.000000]
90: [discriminator loss: -110.760719, acc: 0.090625] [adversarial loss: -56.013092, acc: 1.000000]
91: [discriminator loss: -113.813723, acc: 0.098437] [adversarial loss: -57.177334, acc: 1.000000]
92: [discriminator loss: -115.851353, acc: 0.090625] [adversarial loss: -58.051231, acc: 1.000000]
93: [discriminator loss: -112.377261, acc: 0.087500] [adversarial loss: -59.101784, acc: 1.000000]
94: [discriminator loss: -115.363040, acc: 0.078125] [adversarial loss: -59.945530, acc: 1.000000]
95: [discr

168: [discriminator loss: -9.967733, acc: 0.051562] [adversarial loss: -74.308197, acc: 1.000000]
169: [discriminator loss: -8.532911, acc: 0.059375] [adversarial loss: -73.678543, acc: 1.000000]
170: [discriminator loss: -5.530589, acc: 0.050000] [adversarial loss: -73.042679, acc: 1.000000]
171: [discriminator loss: -5.368523, acc: 0.048438] [adversarial loss: -72.337173, acc: 1.000000]
172: [discriminator loss: -5.559027, acc: 0.068750] [adversarial loss: -71.632957, acc: 1.000000]
173: [discriminator loss: -2.432058, acc: 0.068750] [adversarial loss: -70.894241, acc: 1.000000]
174: [discriminator loss: -2.772424, acc: 0.057813] [adversarial loss: -70.117134, acc: 1.000000]
175: [discriminator loss: -0.450908, acc: 0.075000] [adversarial loss: -69.275848, acc: 1.000000]
176: [discriminator loss: 2.499207, acc: 0.064062] [adversarial loss: -68.328613, acc: 1.000000]
177: [discriminator loss: 1.639951, acc: 0.079687] [adversarial loss: -67.376923, acc: 1.000000]
178: [discriminator lo

252: [discriminator loss: -96.888680, acc: 0.500000] [adversarial loss: -287.584930, acc: 1.000000]
253: [discriminator loss: -100.690091, acc: 0.500000] [adversarial loss: -300.113708, acc: 1.000000]
254: [discriminator loss: -100.538452, acc: 0.500000] [adversarial loss: -312.712280, acc: 1.000000]
255: [discriminator loss: -110.202386, acc: 0.500000] [adversarial loss: -326.060059, acc: 1.000000]
256: [discriminator loss: -101.582703, acc: 0.500000] [adversarial loss: -339.127686, acc: 1.000000]
257: [discriminator loss: -113.433298, acc: 0.500000] [adversarial loss: -353.090637, acc: 1.000000]
258: [discriminator loss: -114.580116, acc: 0.500000] [adversarial loss: -368.022919, acc: 1.000000]
259: [discriminator loss: -116.316345, acc: 0.500000] [adversarial loss: -383.332581, acc: 1.000000]
260: [discriminator loss: -115.524586, acc: 0.500000] [adversarial loss: -400.358276, acc: 1.000000]
261: [discriminator loss: -123.643817, acc: 0.500000] [adversarial loss: -418.221741, acc: 1

333: [discriminator loss: -267.272650, acc: 0.500000] [adversarial loss: -1570.027954, acc: 1.000000]
334: [discriminator loss: -274.026242, acc: 0.500000] [adversarial loss: -1593.861450, acc: 1.000000]
335: [discriminator loss: -264.579298, acc: 0.500000] [adversarial loss: -1618.351074, acc: 1.000000]
336: [discriminator loss: -264.736620, acc: 0.500000] [adversarial loss: -1621.848633, acc: 1.000000]
337: [discriminator loss: -250.194495, acc: 0.500000] [adversarial loss: -1666.749268, acc: 1.000000]
338: [discriminator loss: -259.747586, acc: 0.500000] [adversarial loss: -1670.439941, acc: 1.000000]
339: [discriminator loss: -259.862570, acc: 0.500000] [adversarial loss: -1725.098511, acc: 1.000000]
340: [discriminator loss: -256.532983, acc: 0.500000] [adversarial loss: -1741.707642, acc: 1.000000]
341: [discriminator loss: -251.754651, acc: 0.500000] [adversarial loss: -1769.290649, acc: 1.000000]
342: [discriminator loss: -235.045428, acc: 0.500000] [adversarial loss: -1792.227

KeyboardInterrupt: 

In [None]:
# Persist the trained generator once training is done; the saved file can be
# reloaded later to generate new MNIST digits without retraining.
generator_path = model_name + ".h5"
generator.save(generator_path)