## Deep Convolutional Generative Adversarial Network (DCGAN) implementation on MNIST data set using Keras

Here is an implementation of DCGAN as described in [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434), trained on the [MNIST](http://yann.lecun.com/exdb/mnist/) data set using the [Keras](https://keras.io/) library with the [TensorFlow](https://www.tensorflow.org/) backend.

In [1]:
%matplotlib inline

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense, Activation, Reshape, Input, Flatten
from keras.layers.convolutional import Conv2DTranspose, Conv2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from keras.datasets import mnist
import keras.models as km

from PIL import Image
import math
import os

Using TensorFlow backend.


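Set the image data format to channels-last, i.e. arrays of shape (height, width, channels), so that the input shapes used below are interpreted the same way regardless of the backend's default.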

In [2]:
keras.backend.set_image_data_format('channels_last')
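
With channels-last ordering, a batch of MNIST images has shape (batch, height, width, channels), which is why the models below use an input shape of (28, 28, 1). As a minimal NumPy sketch (independent of Keras, with a hypothetical all-zero batch), this is how the same batch looks in channels-first order and how to convert between the two layouts:

```python
import numpy as np

# a hypothetical batch of 16 grayscale 28x28 images in channels-last layout
batch_last = np.zeros((16, 28, 28, 1), dtype=np.float32)

# channels_last (N, H, W, C) -> channels_first (N, C, H, W)
batch_first = np.transpose(batch_last, (0, 3, 1, 2))

# and back again: (N, C, H, W) -> (N, H, W, C)
batch_back = np.transpose(batch_first, (0, 2, 3, 1))

print(batch_last.shape, batch_first.shape, batch_back.shape)
```

Mixing up the two layouts would silently treat the single channel axis as a spatial dimension, which is the kind of discrepancy the explicit setting avoids.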

## Defining the Adversarial Network

### Generator

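The generator takes a vector of random numbers and transforms it into a 28x28x1 image. A dense layer first maps the noise vector to a 7x7x128 tensor; two strided transposed convolutions then upsample it to 14x14x64 and finally to 28x28x1. Every layer except the output uses batch normalization followed by a ReLU nonlinearity; the output layer uses tanh so that pixel values lie in (-1, 1).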

In [3]:
weight_initializer = 'truncated_normal'

def generator_model(z_size):
    gen_input = Input(shape=(z_size,))
    
    gen_hidden = Dense(units=7*7*128, 
                       kernel_initializer=weight_initializer)(gen_input)
    gen_hidden = BatchNormalization()(gen_hidden)
    gen_hidden = Activation('relu')(gen_hidden)
    gen_hidden = Reshape([7, 7, 128])(gen_hidden)
    # now we have 7x7x128 image
    
    gen_hidden = Conv2DTranspose(filters=64,
                                 kernel_size=[5, 5],
                                 strides=[2, 2],
                                 padding='same',
                                 kernel_initializer=weight_initializer)(gen_hidden)
    gen_hidden = BatchNormalization()(gen_hidden)
    gen_hidden = Activation('relu')(gen_hidden)
    # now we have 14x14x64 image
    
    gen_hidden = Conv2DTranspose(filters=1,
                                 kernel_size=[5, 5],
                                 strides=[2, 2],
                                 padding='same',
                                 kernel_initializer=weight_initializer)(gen_hidden)
    gen_output = Activation('tanh')(gen_hidden)
    # now we have 28x28x1 image
    
    gen_model = km.Model(inputs=gen_input, outputs=gen_output)
    
    return gen_model
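
A Conv2DTranspose layer with padding='same' multiplies each spatial dimension by its stride, so the two stride-2 layers take the 7x7 feature map to 14x14 and then to 28x28. The plain-Python sketch below traces those shapes; the helper names are hypothetical and only encode the 'same'-padding rule (output size = input size × stride) together with the layer sizes used in generator_model.

```python
def transpose_conv_same_out(size, stride):
    # with padding='same', a transposed convolution scales the size by the stride
    return size * stride

def trace_generator_shapes(z_size=100):
    shapes = [(z_size,)]
    shapes.append((7 * 7 * 128,))          # Dense layer output
    h, w, c = 7, 7, 128                    # after Reshape
    shapes.append((h, w, c))
    for filters in (64, 1):                # the two stride-2 Conv2DTranspose layers
        h = transpose_conv_same_out(h, 2)
        w = transpose_conv_same_out(w, 2)
        c = filters
        shapes.append((h, w, c))
    return shapes

for shape in trace_generator_shapes():
    print(shape)
```

Calling generator.summary() on the real model should report the same sequence of output shapes.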

### Discriminator

The discriminator takes a 28x28x1 image and maps it to a single value: the probability that the image comes from the real data set rather than from the generator.

In [4]:
def discriminator_model():
    discr_input = Input(shape=(28,28,1))
    
    discr_l1 = Conv2D(filters=64,
                      kernel_size=[5, 5],
                      strides=[2, 2],
                      padding='same',
                      kernel_initializer=weight_initializer)
    discr_hidden1 = discr_l1(discr_input)
    discr_l2 = LeakyReLU(alpha=0.2)
    discr_hidden2 = discr_l2(discr_hidden1)
    
    discr_l3 = Conv2D(filters=128,
                      kernel_size=[5, 5],
                      strides=[2, 2],
                      padding='same',
                      kernel_initializer=weight_initializer)
    discr_hidden3 = discr_l3(discr_hidden2)
    discr_l4 = BatchNormalization()
    # batch normalization is created but currently disabled (not applied)
    #discr_hidden4 = discr_l4(discr_hidden3)
    discr_l5 = LeakyReLU(alpha=0.2)
    discr_hidden5 = discr_l5(discr_hidden3)
    
    discr_l6 = Flatten()
    discr_hidden6 = discr_l6(discr_hidden5)
    discr_l7 = Dense(units=1, 
                     kernel_initializer=weight_initializer, 
                     activation='sigmoid')
    discr_output = discr_l7(discr_hidden6)
    
    discr_model = km.Model(inputs=discr_input, outputs=discr_output)
    
    return (discr_model, discr_l1, discr_l2, discr_l3, discr_l4, discr_l5, discr_l6, discr_l7)
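
Going the other way, a strided Conv2D with padding='same' produces ceil(input size / stride) outputs per dimension, so the discriminator reduces 28x28 to 14x14 and then to 7x7 before flattening 7·7·128 = 6272 features into the sigmoid unit. The NumPy sketch below traces those shapes and reimplements, for illustration only, the two element-wise functions the network uses (LeakyReLU with alpha = 0.2, and the sigmoid); it is not Keras code, and the helper names are hypothetical.

```python
import math
import numpy as np

def conv_same_out(size, stride):
    # with padding='same', a strided convolution yields ceil(size / stride)
    return math.ceil(size / stride)

def leaky_relu(x, alpha=0.2):
    # identity for positive inputs, slope alpha for negative ones
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    # squashes the final logit into a probability in (0, 1)
    return 1.0 / (1.0 + np.exp(-x))

h = w = 28
for _ in range(2):                       # the two stride-2 Conv2D layers
    h, w = conv_same_out(h, 2), conv_same_out(w, 2)
flat_features = h * w * 128              # Flatten after the 128-filter layer
print(h, w, flat_features)               # 7 7 6272

print(leaky_relu(np.array([-1.0, 0.0, 2.0])))
print(sigmoid(0.0))
```

Returning the individual layer objects from discriminator_model is what lets the combined model reuse exactly these layers, and therefore the same weights.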

### Adam optimizer

We set the Adam hyperparameters to the values suggested in the [paper](https://arxiv.org/abs/1511.06434): a learning rate of 0.0002 and beta_1 = 0.5.

In [5]:
adam_optimizer = Adam(lr=0.0002, beta_1=0.5)
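
For reference, one Adam update with these settings can be written out in NumPy as below. This is a sketch of the standard Adam rule with bias correction, not Keras' internal code; beta_2 = 0.999 and eps = 1e-7 are assumptions (believed to match the Keras defaults), since the cell above only sets lr and beta_1. A lower beta_1 shortens the averaging window of the first-moment estimate, which the paper reports helped stabilize training.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.0002, beta_1=0.5, beta_2=0.999, eps=1e-7):
    """One Adam update; beta_2 and eps are assumed defaults, not taken from the notebook."""
    m = beta_1 * m + (1.0 - beta_1) * grad          # first-moment (mean) estimate
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2     # second-moment estimate
    m_hat = m / (1.0 - beta_1 ** t)                 # bias correction for step t
    v_hat = v / (1.0 - beta_2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

param = np.array([1.0, -2.0])
grad = np.array([0.5, -3.0])
m = np.zeros_like(param)
v = np.zeros_like(param)
new_param, m, v = adam_step(param, grad, m, v, t=1)

# on the first step the bias-corrected update is roughly lr * sign(grad)
print(param - new_param)
```

The first-step behaviour illustrates why the learning rate directly bounds how far each weight can move per update.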

### Output images

Here is a helper function that arranges a batch of generated images into a single grid image. Saving such grids periodically lets us see how the generator's output improves as training progresses.

In [6]:
def combine_images(generated_images):
    num = generated_images.shape[0]
    
    # arrange the images in a near-square grid
    num_cols = int(math.sqrt(num))
    num_rows = int(math.ceil(float(num)/num_cols))
    
    # each image has shape (height, width, channels)
    single_height = generated_images.shape[1]
    single_width = generated_images.shape[2]
    
    image = np.zeros((num_rows*single_height,
                      num_cols*single_width),
                      dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = index // num_cols  # grid row
        j = index % num_cols   # grid column
        image[i*single_height:(i+1)*single_height, j*single_width:(j+1)*single_width] = img[:, :, 0]
        
    return image
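
The grid arithmetic is easiest to see with numbers: 64 images give an 8x8 grid (a 224x224 canvas for 28x28 digits), while 10 images give 3 columns and 4 rows, with the last two cells filled with zeros. As an alternative sketch, the same layout can be produced without a Python loop by zero-padding to a full grid and using reshape/transpose; grid_shape and tile_images are hypothetical helpers, not part of the notebook.

```python
import math
import numpy as np

def grid_shape(num):
    # same rule as combine_images: floor(sqrt(num)) columns, enough rows for the rest
    num_cols = int(math.sqrt(num))
    num_rows = int(math.ceil(num / num_cols))
    return num_rows, num_cols

def tile_images(images):
    """Tile an (N, H, W, 1) batch into one 2-D canvas; unused grid cells stay zero."""
    n, h, w = images.shape[:3]
    rows, cols = grid_shape(n)
    padded = np.zeros((rows * cols, h, w), dtype=images.dtype)
    padded[:n] = images[..., 0]
    # (rows, cols, h, w) -> (rows, h, cols, w) -> (rows*h, cols*w)
    grid = padded.reshape(rows, cols, h, w).transpose(0, 2, 1, 3)
    return grid.reshape(rows * h, cols * w)

images = np.random.uniform(-1.0, 1.0, size=(10, 28, 28, 1)).astype(np.float32)
canvas = tile_images(images)
print(grid_shape(10), canvas.shape)  # (4, 3) (112, 84)
```

Note that empty cells hold 0, which becomes mid-gray (127.5) after the notebook's rescaling to [0, 255], not black.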

### Connect everything together

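Now we define the inputs and outputs and compile the models. The combined model stacks the generator and the discriminator, and inside this combination the discriminator must not be trainable: when the combined model is trained, only the generator's weights should be updated. The discriminator itself is compiled separately with its weights trainable.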

In [7]:
def make_trainable(model, trainable):
    model.trainable = trainable
    for l in model.layers:
        l.trainable = trainable

In [8]:
z_size = 100

# compile generator
generator = generator_model(z_size)
generator.compile(optimizer=adam_optimizer, loss='binary_crossentropy')

In [9]:
# create discriminator model
discriminator, discr_l1, discr_l2, discr_l3, discr_l4, discr_l5, discr_l6, discr_l7 = discriminator_model()

In [10]:
make_trainable(discriminator, False) # freeze the discriminator before compiling the combined model

# build the combined model (generator followed by the shared discriminator layers) and compile it
gan_input = Input(shape=(z_size,))
gan_hidden = generator(gan_input)
gan_hidden = discr_l1(gan_hidden)
gan_hidden = discr_l2(gan_hidden)
gan_hidden = discr_l3(gan_hidden)
#gan_hidden = discr_l4(gan_hidden)  # batch normalization disabled, as in discriminator_model
gan_hidden = discr_l5(gan_hidden)
gan_hidden = discr_l6(gan_hidden)
gan_output = discr_l7(gan_hidden)

gan = km.Model(inputs=gan_input,outputs=gan_output)
gan.compile(optimizer=adam_optimizer, loss='binary_crossentropy')

In [11]:
# make the discriminator trainable again and compile it on its own
make_trainable(discriminator, True)
discriminator.compile(optimizer=adam_optimizer, loss='binary_crossentropy')
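
To make the freezing argument concrete, here is a toy NumPy sketch rather than the notebook's models: a linear generator x = a*z + b and a logistic discriminator D(x) = sigmoid(w*x + c), with all names and values chosen purely for illustration. The combined-model step minimizes the binary cross-entropy of D(G(z)) against the label 1, but only the generator parameters a and b receive a gradient update; w and c are read but never changed, which is the effect of making the discriminator non-trainable before compiling gan.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def generator_loss(a, b, w, c, z):
    # BCE against label 1 for generated samples: -mean(log D(G(z)))
    return -np.mean(np.log(sigmoid(w * (a * z + b) + c)))

def combined_step(a, b, w, c, z, lr=0.1):
    """One gradient step on the generator only; the discriminator (w, c) is frozen."""
    d = sigmoid(w * (a * z + b) + c)
    dloss_dx = -(1.0 - d) * w            # chain rule through the frozen discriminator
    grad_a = np.mean(dloss_dx * z)
    grad_b = np.mean(dloss_dx)
    # only the generator parameters are updated
    return a - lr * grad_a, b - lr * grad_b

rng = np.random.default_rng(0)
z = rng.uniform(-1.0, 1.0, size=128)
a, b = 0.5, -1.0      # toy generator parameters
w, c = 1.0, 0.0       # toy discriminator parameters (frozen)

before = generator_loss(a, b, w, c, z)
a, b = combined_step(a, b, w, c, z)
after = generator_loss(a, b, w, c, z)
print(before, after)  # the generator loss decreases while w and c stay fixed
```

The gradient still flows through the discriminator (the factor w in dloss_dx); freezing only stops the discriminator's own parameters from being updated.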

### Training

Load the MNIST data set.

In [12]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Prepare the training images: scale the pixel values from [0, 255] to [-1, 1] to match the generator's tanh output, and reshape each image into a 28x28x1 array.

In [13]:
x_train = (x_train.astype(np.float32) - 127.5)/127.5
x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], x_train.shape[2], 1))
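
The scaling maps pixel value 0 to -1 and 255 to 1; the training loop and the final cell undo it with x*127.5 + 127.5 before saving PNGs. A small NumPy check of the round trip on a few hypothetical pixel values:

```python
import numpy as np

pixels = np.array([0, 64, 127, 255], dtype=np.uint8)

# forward: [0, 255] -> [-1, 1], as in the training-data preparation
scaled = (pixels.astype(np.float32) - 127.5) / 127.5

# inverse: [-1, 1] -> [0, 255], as used before Image.fromarray(...)
restored = scaled * 127.5 + 127.5

# round before casting so float error cannot shift a value down by one
print(scaled.min(), scaled.max())
print(np.round(restored).astype(np.uint8))
```

The notebook casts with a plain astype(np.uint8), which truncates rather than rounds; for visualizing generated digits an occasional off-by-one gray level is harmless.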

Create a directory to save the generated images.

In [14]:
GEN_IMAGES_DIR = 'gen_images'
if not os.path.exists(GEN_IMAGES_DIR):
    os.makedirs(GEN_IMAGES_DIR)

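The training loop. For each batch, the discriminator is trained on real images (label 0.9, i.e. one-sided label smoothing) and generated images (label 0); then the generator is trained through the combined model with the label 1.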

In [15]:
batch_size = 128
epochs = 25
num_batches = int(x_train.shape[0] / batch_size)

for e in range(epochs):
    for b in range(num_batches):
        # sample a noise batch and take the next batch of real images
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, z_size]).astype(np.float32)
        images = x_train[b*batch_size:(b + 1)*batch_size]
        
        # generate a batch of fake images
        gen_images = generator.predict(noise, verbose=0)
        
        # save generated images periodically
        if ((b+1) % 50) == 0:
            combined_image = combine_images(gen_images)
            combined_image = combined_image*127.5+127.5
            Image.fromarray(combined_image.astype(np.uint8)) \
                .save("{}/{}_{}.png".format(GEN_IMAGES_DIR, e+1, b+1))
        
        # train the discriminator on real (label 0.9, one-sided label smoothing) and fake (label 0) images
        discr_input = np.concatenate((images, gen_images))
        discr_labels = np.array([0.9]*batch_size + [0.0]*batch_size)
        d_loss = discriminator.train_on_batch(discr_input, discr_labels)
        
        # train the generator through the combined model (discriminator frozen), targeting label 1
        make_trainable(discriminator, False)
        gen_labels = np.ones(batch_size)
        g_loss = gan.train_on_batch(noise, gen_labels)
        make_trainable(discriminator, True)
        
        # print statistics
        print("Epoch {}/{}...".format(e+1, epochs),
              "Batch {}/{}...".format(b+1, num_batches),
              "Discriminator Loss: {:.4f}...".format(d_loss),
              "Generator Loss: {:.4f}".format(g_loss))
        
        # save weights periodically
        if ((b+1) % 10) == 0:
            generator.save_weights('generator', overwrite=True)
            discriminator.save_weights('discriminator', overwrite=True)

Epoch 1/25... Batch 1/468... Discriminator Loss: 0.7413... Generator Loss: 0.7678
Epoch 1/25... Batch 2/468... Discriminator Loss: 0.6440... Generator Loss: 0.6379
Epoch 1/25... Batch 3/468... Discriminator Loss: 0.5732... Generator Loss: 0.5488
Epoch 1/25... Batch 4/468... Discriminator Loss: 0.5593... Generator Loss: 0.5040
Epoch 1/25... Batch 5/468... Discriminator Loss: 0.5559... Generator Loss: 0.5147
Epoch 1/25... Batch 6/468... Discriminator Loss: 0.5543... Generator Loss: 0.5569
Epoch 1/25... Batch 7/468... Discriminator Loss: 0.5454... Generator Loss: 0.6263
Epoch 1/25... Batch 8/468... Discriminator Loss: 0.5392... Generator Loss: 0.6733
Epoch 1/25... Batch 9/468... Discriminator Loss: 0.5334... Generator Loss: 0.7097
Epoch 1/25... Batch 10/468... Discriminator Loss: 0.5267... Generator Loss: 0.7527
Epoch 1/25... Batch 11/468... Discriminator Loss: 0.5174... Generator Loss: 0.8209
Epoch 1/25... Batch 12/468... Discriminator Loss: 0.5123... Generator Loss: 0.9093
Epoch 1/25...

KeyboardInterrupt: 
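
Because real images are labeled 0.9 rather than 1, the discriminator loss in a log like the one above cannot reach 0. The minimal NumPy sketch below evaluates binary cross-entropy over a grid of candidate outputs (ignoring Keras' epsilon clipping): for a target of 0.9 the best achievable per-sample loss is about 0.325, reached at a predicted probability of 0.9, and averaged with a perfect 0 on the fake half of each discriminator batch the floor is about 0.163.

```python
import numpy as np

def bce(target, p):
    # binary cross-entropy for a prediction p in (0, 1); no clipping
    return -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))

probs = np.linspace(0.01, 0.99, 99)          # candidate discriminator outputs
real_losses = bce(0.9, probs)                # real images use the smoothed label 0.9
best_p = probs[np.argmin(real_losses)]
real_floor = real_losses.min()

# half of each discriminator batch is real (target 0.9), half fake (target 0.0);
# a perfect prediction of 0 on the fake half contributes no loss
batch_floor = (real_floor + 0.0) / 2

print(round(best_p, 2), round(real_floor, 3), round(batch_floor, 3))
```

The smoothing keeps the discriminator from pushing its outputs on real images all the way to 1, which is the usual motivation for this one-sided variant.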

### Generate images

In [None]:
# create directory for generated images
GEN_IMAGES_DIR_FINAL = 'gen_images_final'
if not os.path.exists(GEN_IMAGES_DIR_FINAL):
    os.makedirs(GEN_IMAGES_DIR_FINAL)
    
# create generator and load its weights
generator = generator_model(z_size)
generator.compile(optimizer=adam_optimizer, loss='binary_crossentropy')
generator.load_weights('generator')
    
# generate images
NUM_GEN_IMAGES = 64
noise = np.random.uniform(-1.0, 1.0, size=[NUM_GEN_IMAGES, z_size]).astype(np.float32)
gen_images = generator.predict(noise, verbose=0)

# save images
combined_image = combine_images(gen_images)
combined_image = combined_image*127.5+127.5
Image.fromarray(combined_image.astype(np.uint8)) \
    .save("{}/gen_images.png".format(GEN_IMAGES_DIR_FINAL))

In [None]:
# show generated image
plt.imshow(combined_image, cmap='Greys_r')