In [12]:
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Activation, Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.layers import concatenate

import math
import matplotlib.pyplot as plt
import os
import numpy as np

In [4]:
def plot_images(generator, noise_input, noise_label=None, noise_codes=None, show=False, step=0, model_name="gan"):
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    rows = int(math.sqrt(noise_input.shape[0]))
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

In [5]:
(x_train, _), (_, _) = mnist.load_data()

# reshape data for CNN as (28, 28, 1) and normalize
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255

In [6]:
model_name = "lsgan_mnist"
# network parameters
# the latent or z vector is 100-dim
latent_size = 100
input_shape = (image_size, image_size, 1)
batch_size = 64
lr = 2e-4
decay = 6e-8
train_steps = 40000

In [7]:
# build discriminator model
inputs = Input(shape=input_shape, name='discriminator_input')

x = LeakyReLU(alpha=0.2)(inputs)
x = Conv2D(filters=32, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(filters=64, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(filters=128, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(filters=256, kernel_size=5, strides=1, padding='same')(x)
x = Flatten()(x)
outputs = Dense(1)(x)
discriminator = Model(inputs, outputs, name='discriminator')

# [1] uses Adam, but discriminator converges easily with RMSprop
optimizer = RMSprop(lr=lr, decay=decay)
# LSGAN uses MSE loss [2]
discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa [(None, 28, 28, 1)]       0         
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 14, 14, 32)        832       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 64)          51264     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 128)       

In [8]:
# build generator model
input_shape = (latent_size, )
image_resize = image_size // 4

inputs = Input(shape=input_shape, name='z_input')
x = Dense(image_resize * image_resize * 1)(inputs)
x = Reshape((image_resize, image_resize, 1))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=64, kernel_size=5, strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=32, kernel_size=5, strides=1, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2DTranspose(filters=1, kernel_size=5, strides=1, padding='same')(x)
x = Activation(activation='sigmoid')(x)

generator = Model(inputs, x, name='generator')
generator.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_input (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 49)                4949      
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 1)           0         
_________________________________________________________________
batch_normalization (BatchNo (None, 7, 7, 1)           4         
_________________________________________________________________
activation (Activation)      (None, 7, 7, 1)           0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 128)       3328      
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 128)       51

In [9]:
# build adversarial model = generator + discriminator
optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
# freeze the weights of discriminator during adversarial training
discriminator.trainable = False
adversarial = Model(inputs, discriminator(generator(inputs)), name=model_name)
# LSGAN uses MSE loss [2]
adversarial.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
adversarial.summary()

Model: "lsgan_mnist"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_input (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
generator (Model)            (None, 28, 28, 1)         266074    
_________________________________________________________________
discriminator (Model)        (None, 1)                 1080577   
Total params: 1,346,651
Trainable params: 265,624
Non-trainable params: 1,081,027
_________________________________________________________________


In [13]:
# the generator image is saved every 500 steps
save_interval = 500
# noise vector to see how the generator output evolves during training
noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
# number of elements in train dataset
train_size = x_train.shape[0]
for i in range(train_steps):
    # train the discriminator for 1 batch
    # 1 batch of real (label=1.0) and fake images (label=0.0)
    # randomly pick real images from dataset
    rand_indexes = np.random.randint(0, train_size, size=batch_size)
    real_images = x_train[rand_indexes]
    # generate fake images from noise using generator 
    # generate noise using uniform distribution
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
    # generate fake images
    fake_images = generator.predict(noise)
    # real + fake images = 1 batch of train data
    x = np.concatenate((real_images, fake_images))
    # label real and fake images
    # real images label is 1.0
    y = np.ones([2 * batch_size, 1])
    # fake images label is 0.0
    y[batch_size:, :] = 0.0
    # train discriminator network, log the loss and accuracy
    loss, acc = discriminator.train_on_batch(x, y)
    log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)

    # train the adversarial network for 1 batch
    # 1 batch of fake images with label=1.0
    # since the discriminator weights are frozen in adversarial network
    # only the generator is trained
    # generate noise using uniform distribution
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
    # label fake images as real or 1.0
    y = np.ones([batch_size, 1])
    # train the adversarial network 
    # note that unlike in discriminator training, 
    # we do not save the fake images in a variable
    # the fake images go to the discriminator input of the adversarial
    # for classification
    # log the loss and accuracy
    loss, acc = adversarial.train_on_batch(noise, y)
    log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
    print(log)
    if (i + 1) % save_interval == 0:
        if (i + 1) == train_steps:
            show = True
        else:
            show = False
    
        # plot generator images on a periodic basis
        plot_images(generator, noise_input=noise_input, show=show, step=(i + 1), model_name=model_name)

0: [discriminator loss: 0.018497, acc: 1.000000] [adversarial loss: 0.279191, acc: 0.406250]
1: [discriminator loss: 0.017713, acc: 1.000000] [adversarial loss: 0.082497, acc: 0.953125]
2: [discriminator loss: 0.024931, acc: 1.000000] [adversarial loss: 0.232066, acc: 0.593750]
3: [discriminator loss: 0.021450, acc: 1.000000] [adversarial loss: 0.035235, acc: 1.000000]
4: [discriminator loss: 0.022464, acc: 1.000000] [adversarial loss: 0.338295, acc: 0.140625]
5: [discriminator loss: 0.035219, acc: 1.000000] [adversarial loss: 0.022817, acc: 1.000000]
6: [discriminator loss: 0.020089, acc: 1.000000] [adversarial loss: 0.254594, acc: 0.609375]
7: [discriminator loss: 0.032478, acc: 0.992188] [adversarial loss: 0.013560, acc: 1.000000]
8: [discriminator loss: 0.026180, acc: 1.000000] [adversarial loss: 0.277992, acc: 0.390625]
9: [discriminator loss: 0.045784, acc: 0.984375] [adversarial loss: 0.022420, acc: 1.000000]
10: [discriminator loss: 0.015323, acc: 1.000000] [adversarial loss: 0

KeyboardInterrupt: 

In [None]:
# save the model after training the generator
# the trained generator can be reloaded for future MNIST digit generation
generator.save(model_name + ".h5")