In [1]:
import numpy as np
import time
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout
from keras.layers import BatchNormalization
from keras.optimizers import Adam, RMSprop
from keras.datasets import mnist
import matplotlib.pyplot as plt

# Side length (in pixels) of the square images the models work with.
image_size = 28
# Kernel size passed to every Conv2D / Conv2DTranspose layer below.
kernel_size = 5

def generator_model(latent_dim=100):
    """Build the (uncompiled) generator network and print its summary.

    The network maps a latent vector to a single-channel image with values
    in [0, 1] (sigmoid output). Shapes quoted in the comments below are for
    the module defaults ``image_size = 28`` and ``depth = 128``.

    Args:
        latent_dim: Length of the latent input vector. Defaults to 100,
            the value that was previously hard-coded here.

    Returns:
        A ``Sequential`` model; the caller is responsible for compiling it.
    """
    dropout = 0.4
    # Two stride-2 transposed convolutions each double the spatial size, so
    # start from a quarter of the target size (image_size should therefore
    # be a multiple of 4 for the output to be image_size x image_size).
    dim = image_size // 4
    depth = 128

    model = Sequential()

    # In: latent_dim  ->  Out: 7 x 7 x 128
    model.add(Dense(dim * dim * depth, input_dim=latent_dim))
    model.add(Reshape((dim, dim, depth)))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation("relu"))
    model.add(Dropout(dropout))

    # In: 7 x 7 x 128  ->  Out: 14 x 14 x 128
    model.add(Conv2DTranspose(depth, kernel_size, strides=2, padding='same'))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation("relu"))

    # In: 14 x 14 x 128  ->  Out: 28 x 28 x 64
    model.add(Conv2DTranspose(depth // 2, kernel_size, strides=2, padding='same'))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation("relu"))

    # In: 28 x 28 x 64  ->  Out: 28 x 28 x 32
    model.add(Conv2DTranspose(depth // 4, kernel_size, strides=1, padding='same'))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation("relu"))

    # In: 28 x 28 x 32  ->  Out: 28 x 28 x 1, squashed into [0, 1]
    model.add(Conv2DTranspose(1, kernel_size, strides=1, padding='same'))
    model.add(Activation('sigmoid'))

    model.summary()
    return model


def discriminator_model():
    """Build the (uncompiled) discriminator network and print its summary.

    The network takes an ``image_size x image_size x 1`` image and outputs a
    single sigmoid value. Shapes quoted in the comments below are for the
    module default ``image_size = 28``. No dropout is applied.

    Returns:
        A ``Sequential`` model; the caller is responsible for compiling it.
    """
    depth = 64
    model = Sequential()

    # In: 28 x 28 x 1  ->  Out: 14 x 14 x 64
    # Use the shared image_size constant so the input stays consistent with
    # the generator's output and the training data reshape.
    model.add(Conv2D(depth, kernel_size, strides=2, padding='same',
                     input_shape=(image_size, image_size, 1)))
    model.add(LeakyReLU(alpha=0.2))

    # In: 14 x 14 x 64  ->  Out: 7 x 7 x 128
    model.add(Conv2D(depth * 2, kernel_size, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))

    # In: 7 x 7 x 128  ->  Out: 4 x 4 x 256
    model.add(Conv2D(depth * 4, kernel_size, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))

    # In: 4 x 4 x 256  ->  Out: 4 x 4 x 512 (stride 1 keeps the spatial size)
    model.add(Conv2D(depth * 8, kernel_size, strides=1, padding='same'))
    model.add(LeakyReLU(alpha=0.2))

    # Out: a single sigmoid score per image
    model.add(Flatten())
    model.add(Dense(1))
    model.add(Activation('sigmoid'))

    model.summary()
    return model

def adversial_model(generator, discriminator):
    """Chain ``generator`` into ``discriminator`` as one stacked model.

    Side effect: sets ``discriminator.trainable = False`` on the object that
    was passed in (not on a copy), so the caller's discriminator is marked
    frozen after this call. Keras reads that flag when a model is compiled,
    so callers must take care about the order in which they compile the
    discriminator and the stacked model.

    Returns:
        An uncompiled ``Sequential`` model: generator followed by discriminator.
    """
    # Freeze first, then stack; the list constructor adds the layers in order.
    discriminator.trainable = False
    return Sequential([generator, discriminator])

Using TensorFlow backend.


In [None]:
def _save_sample_grid(generator, noise_input, step):
    """Save the generator's output for ``noise_input`` as ``mnist_<step>.png``.

    Images are laid out on a 4 x 4 grid, so ``noise_input`` should hold at
    most 16 latent vectors.
    """
    images = generator.predict(noise_input)
    fig = plt.figure(figsize=(10, 10))
    for i in range(images.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Drop the trailing channel axis so imshow gets a 2-D array.
        plt.imshow(np.reshape(images[i], [image_size, image_size]), cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig("mnist_%d.png" % step)
    # Close the figure so repeated saves do not accumulate open figures.
    plt.close(fig)


def train(epochs=1000, batch_size=256, save_interval=50, lr=0.0002, decay=6e-8):
    """Train the GAN on MNIST, logging losses and periodically saving samples.

    Args:
        epochs: Number of training iterations. Each iteration processes one
            random batch (not a full pass over the dataset).
        batch_size: Number of real images (and of fake images) per iteration.
        save_interval: Save a grid of generated samples every this many
            iterations; a value <= 0 disables saving.
        lr: Learning rate of the discriminator's RMSprop optimizer; the
            stacked (generator-training) model uses half of it.
        decay: Learning-rate decay of the discriminator's optimizer; the
            stacked model uses half of it.
    """
    # Only the training images are needed; scale pixels to [0, 1] to match
    # the generator's sigmoid output range.
    (X_train, _), _ = mnist.load_data()
    X_train = np.reshape(X_train, [-1, image_size, image_size, 1])
    X_train = X_train.astype('float32') / 255

    # Keras captures a model's trainable weights at compile time. The
    # discriminator must therefore be compiled while it is still trainable,
    # i.e. BEFORE it is frozen for the stacked model. Compiling it after the
    # freeze (as this function used to do) leaves it with no trainable
    # weights, so it would never learn.
    discriminator = discriminator_model()
    discriminator.trainable = True
    d_optim = RMSprop(lr=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim, metrics=['accuracy'])

    generator = generator_model()
    # Read the latent size from the generator instead of repeating the
    # number here, so the two cannot drift apart.
    latent_dim = generator.input_shape[-1]

    # Inside the stacked model the discriminator is frozen, so training the
    # stacked model only updates the generator. The explicit assignment makes
    # the freeze visible here rather than relying on the helper's side effect.
    adversial = adversial_model(generator, discriminator)
    discriminator.trainable = False
    g_optim = RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    adversial.compile(loss='binary_crossentropy', optimizer=g_optim, metrics=['accuracy'])

    # Fixed latent vectors so saved sample grids are comparable across steps.
    noise_input = np.random.uniform(-1, 1, size=[16, latent_dim])

    for epoch in range(epochs):
        # --- Discriminator step: real images labelled 1, fakes labelled 0.
        real_images = X_train[np.random.randint(0, X_train.shape[0], size=batch_size)]
        noise = np.random.uniform(-1, 1, size=[batch_size, latent_dim])
        fake_images = generator.predict(noise)
        combined_images = np.concatenate((real_images, fake_images))
        labels = np.ones([2 * batch_size, 1])
        labels[batch_size:, :] = 0
        d_loss = discriminator.train_on_batch(combined_images, labels)

        # --- Generator step: train the stacked model to make the (frozen)
        # discriminator output 1 for freshly generated images.
        misleading_targets = np.ones([batch_size, 1])
        random_latent_vectors = np.random.uniform(-1, 1, size=[batch_size, latent_dim])
        a_loss = adversial.train_on_batch(random_latent_vectors, misleading_targets)

        print("%d: [D loss: %f, acc: %f]  [A loss: %f, acc: %f]"
              % (epoch, d_loss[0], d_loss[1], a_loss[0], a_loss[1]))

        # Save sample images every `save_interval` iterations (1-based).
        step = epoch + 1
        if save_interval > 0 and step % save_interval == 0:
            _save_sample_grid(generator, noise_input, step)
# Kick off training: 1000 iterations of 256-image batches, saving a sample
# grid every 50 iterations.
train(epochs=1000, batch_size=256, save_interval=50)

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 14, 14, 64)        1664      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 4, 4, 256)         819456    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 4, 4, 256)         0         
_________________________________________________________________
conv