<a href="https://colab.research.google.com/github/xu-hao/colab/blob/master/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on https://github.com/Zackory/Keras-MNIST-GAN/blob/master/mnist_gan.py

In [0]:
!mkdir images
!mkdir models

In [5]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

Using TensorFlow backend.


In [0]:
K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(X_train.shape[0], 784)

In [0]:
# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)

generator_input = Input(shape=(randomDim,))
x = Dense(256, kernel_initializer=initializers.RandomNormal(stddev=0.02))(generator_input)
x = LeakyReLU(0.2)(x)
x = Dense(512)(x)
x = LeakyReLU(0.2)(x)
x = Dense(1024)(x)
x = LeakyReLU(0.2)(x)
generator_output = Dense(784, activation='tanh')(x)
generator = Model(inputs=generator_input, outputs=generator_output)
generator.compile(loss='binary_crossentropy', optimizer=adam)

discriminator_input = Input(shape=(784,))
x = Dense(1024, kernel_initializer=initializers.RandomNormal(stddev=0.02))(discriminator_input)
x = LeakyReLU(0.2)(x)
x = Dropout(0.3)(x)
x = Dense(512)(x)
x = LeakyReLU(0.2)(x)
x = Dropout(0.3)(x)
x = Dense(256)(x)
x = LeakyReLU(0.2)(x)
x = Dropout(0.3)(x)
discriminator_output = Dense(1, activation='sigmoid')(x)
discriminator = Model(inputs=discriminator_input, outputs=discriminator_output)
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

In [0]:

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

In [0]:
def train(epochs=1, batchSize=128):
    batchCount = X_train.shape[0] / batchSize
    print ('Epochs:', epochs)
    print ('Batch size:', batchSize)
    print ('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print ('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(int(batchCount))):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            yDis[:batchSize] = 0.9

            # Train discriminator
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

In [0]:
dLosses = []
gLosses = []


In [0]:
train(200, 128)

  0%|          | 0/468 [00:00<?, ?it/s]

Epochs: 200
Batch size: 128
Batches per epoch: 468.75
--------------- Epoch 1 ---------------
Instructions for updating:
Use tf.cast instead.


100%|██████████| 468/468 [00:14<00:00, 32.48it/s]
  1%|          | 5/468 [00:00<00:11, 40.99it/s]

--------------- Epoch 2 ---------------


100%|██████████| 468/468 [00:11<00:00, 40.86it/s]
  1%|          | 5/468 [00:00<00:11, 40.80it/s]

--------------- Epoch 3 ---------------


100%|██████████| 468/468 [00:11<00:00, 41.34it/s]
  1%|          | 5/468 [00:00<00:11, 40.75it/s]

--------------- Epoch 4 ---------------


100%|██████████| 468/468 [00:11<00:00, 41.79it/s]
  1%|          | 5/468 [00:00<00:11, 40.62it/s]

--------------- Epoch 5 ---------------


100%|██████████| 468/468 [00:11<00:00, 41.46it/s]
  1%|          | 4/468 [00:00<00:11, 39.85it/s]

--------------- Epoch 6 ---------------


100%|██████████| 468/468 [00:11<00:00, 41.42it/s]
  1%|          | 4/468 [00:00<00:12, 37.81it/s]

--------------- Epoch 7 ---------------


100%|██████████| 468/468 [00:11<00:00, 40.84it/s]
  1%|          | 5/468 [00:00<00:11, 42.00it/s]

--------------- Epoch 8 ---------------


 38%|███▊      | 177/468 [00:04<00:07, 40.62it/s]