In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
#plt.switch_backend('agg')
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam

  from ._conv import register_converters as _register_converters


In [2]:
# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

# Location of the MNIST archive. A machine-specific absolute path makes the
# notebook fail everywhere else, so the location is read from the MNIST_PATH
# environment variable instead. The default, 'mnist.npz', is the file name
# Keras itself uses for its dataset cache.
mnistPath = os.environ.get('MNIST_PATH', 'mnist.npz')

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data(mnistPath)
# Shift and scale the pixel values by 127.5 so they are centred on zero,
# matching the tanh output layer of the generator.
X_train = (X_train.astype(np.float32) - 127.5)/127.5
# Flatten every image into one row; the sample count is taken from the data
# rather than hard-coded.
X_train = X_train.reshape(X_train.shape[0], -1)

# Optimizer
# The learning rate is passed positionally because its keyword was renamed
# from `lr` to `learning_rate` between Keras releases.
adam = Adam(0.0001)

In [3]:
# Every compiled model gets its own optimizer, cloned from the settings of
# `adam`. Sharing a single optimizer instance between several models shares
# its internal state (step counter, moment estimates), and newer Keras
# optimizers refuse to update variables they were not built with.

# Generator: maps a noise vector of length randomDim to a flattened 784-pixel image.
generator = Sequential()
generator.add(Dense(256, input_dim=randomDim))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
# tanh bounds the generated pixel values to [-1, 1]
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=Adam.from_config(adam.get_config()))

# Discriminator: maps a flattened 784-pixel image to a single sigmoid score.
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam.from_config(adam.get_config()))
# The flag is cleared after the discriminator's own compile and before the
# combined model is compiled, so that training `gan` is meant to update the
# generator only.
discriminator.trainable = False
# Combined network: noise -> generator -> discriminator score
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=Adam.from_config(adam.get_config()))

# Per-epoch average losses, appended to by train() and read by plotLoss()
dLosses = []
gLosses = []

In [4]:
# Plot the average loss recorded for each epoch
def plotLoss(epoch):
    """Plot the recorded discriminator and generator losses and save the figure.

    Reads the module-level lists ``dLosses`` and ``gLosses`` (one averaged
    value per epoch, appended by ``train``) and writes the plot to
    ``reference/images/gan_loss_epoch_<epoch>.png``.

    Args:
        epoch: Integer used only to build the output file name.
    """
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs('reference/images', exist_ok=True)

    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminative loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('reference/images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    """Sample the generator and save a grid of the resulting images.

    Draws ``examples`` noise vectors of length ``randomDim`` from a standard
    normal distribution, runs them through the module-level ``generator``,
    and writes the images to
    ``reference/images/gan_generated_image_epoch_<epoch>.png``.

    Args:
        epoch: Integer used only to build the output file name.
        examples: Number of images to generate.
        dim: ``(rows, columns)`` of the subplot grid.
        figsize: Figure size forwarded to ``plt.figure``.

    Raises:
        ValueError: If the grid has fewer cells than ``examples``.
    """
    # Fail before running the generator instead of part-way through the
    # subplot loop when the grid cannot hold every image.
    if examples > dim[0] * dim[1]:
        raise ValueError(
            'a %dx%d grid cannot hold %d images' % (dim[0], dim[1], examples))

    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    # The generator emits flat 784-element vectors; restore the 28x28 layout.
    generatedImages = generatedImages.reshape(examples, 28, 28)

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs('reference/images', exist_ok=True)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('reference/images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    """Write the module-level generator and discriminator to HDF5 files.

    The files are ``reference/models/gan_generator_epoch_<epoch>.h5`` and
    ``reference/models/gan_discriminator_epoch_<epoch>.h5``.

    Args:
        epoch: Integer used only to build the output file names.
    """
    # Create the output directory up front so that saving does not depend on
    # it already being there.
    os.makedirs('reference/models', exist_ok=True)
    generator.save('reference/models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('reference/models/gan_discriminator_epoch_%d.h5' % epoch)

In [5]:
def train(epochs=1, batchSize=128, verbose=True):
    """Alternate discriminator and generator updates on MNIST.

    Each epoch runs ``X_train.shape[0] // batchSize`` steps. A step draws
    ``batchSize`` training images at random (with replacement), trains the
    discriminator on those images plus an equal number of generated ones, and
    then trains the generator through the combined ``gan`` model. The
    per-epoch average losses are appended to the module-level ``dLosses`` and
    ``gLosses`` lists, so calling ``train`` again extends the same history.

    Args:
        epochs: Number of epochs to run. With ``epochs <= 0`` nothing is
            trained and no loss plot is written.
        batchSize: Number of real (and of generated) samples per step.
        verbose: When true, print the generator loss after every step.

    Raises:
        ValueError: If ``batchSize`` is not positive or is larger than the
            training set, in which case no step could be run.
    """
    if batchSize < 1:
        raise ValueError('batchSize must be positive, got %r' % (batchSize,))
    batchCount = X_train.shape[0] // batchSize
    if batchCount < 1:
        # Without this check the epoch averages below would silently be NaN.
        raise ValueError(
            'batchSize %r exceeds the %d available training samples'
            % (batchSize, X_train.shape[0]))

    lastEpoch = None
    for e in range(epochs):
        print('Epoch %d' % e)
        epoch_dLoss = []
        epoch_gLoss = []
        for i in range(batchCount):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # Real images first, generated images second
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing: real samples are labelled 0.9 instead of 1
            yDis[:batchSize] = 0.9

            # Train discriminator
            # NOTE(review): whether flipping `trainable` after compile() has any
            # effect depends on the Keras version; the toggles are kept as in
            # the original so the flag always matches the model being trained.
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator: fresh noise, labelled as real
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)
            epoch_dLoss.append(dloss)
            epoch_gLoss.append(gloss)
            if verbose:
                print(gloss)
        average_dloss = np.mean(epoch_dLoss)
        average_gloss = np.mean(epoch_gLoss)
        print('generator loss: ' + str(average_gloss))
        print('discriminator loss: ' + str(average_dloss))
        dLosses.append(average_dloss)
        gLosses.append(average_gloss)

        # Sample images every 50 epochs, skipping epoch 0
        if e % 50 == 0 and e != 0:
            plotGeneratedImages(e)
            #saveModels(e)
        lastEpoch = e

    # Plot losses from every epoch; skipped when the loop never ran, since
    # there would be no epoch index to name the file after.
    if lastEpoch is not None:
        plotLoss(lastEpoch)
    print('training finished')

In [6]:
# Print the discriminator's layers with their output shapes and parameter counts
discriminator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 1024)              803840    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense_5 (Dense)              (None, 512)               524800    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 256)               131328    
__________

In [7]:
# Print the generator's layers with their output shapes and parameter counts
generator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 256)               25856     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 784)               803600    
Total para

In [6]:
# Run the training loop for 100 epochs with 128 samples per batch
train(epochs=100, batchSize=128)

Epoch 0
0.72320676
0.687928
0.6367873
0.6373629
0.60917264
0.5891676
0.58692336
0.57869315
0.56545186
0.5843009
0.5677943
0.5963724
0.58628345
0.62779343
0.65469533
0.7047914
0.7354696
0.79410154
0.8483566
0.8992345
0.9431421
1.0078995
1.0443285
1.0699167
1.0961827
1.1136625
1.1668453
1.1454762
1.1911232
1.2213593
1.1964948
1.2214944
1.2705653
1.3129435
1.3011277
1.417846
1.3889898
1.3414229
1.358686
1.3265721
1.3743792
1.3133378
1.2784407
1.4106313
1.4418548
1.5404128
1.586438
1.6454166
1.6384704
1.7532828
1.800461
1.8311273
1.8289931
1.7472836
1.6929898
1.7014805
1.7985859
1.7180855
1.786087
1.9060388
1.8484626
1.8900132
1.9892137
1.9140078
1.9377034
2.0177548
1.9559952
2.0953033
2.191064
2.2512534
2.4410396
2.4647865
2.6205301
2.6535451
2.8307905
2.970058
3.1039038
3.26375
3.4225955
3.5885348
3.7471168
3.8536074
3.7080767
3.650783
3.4495795
3.3362417
3.2622042
2.9105792
3.1295025
3.0554533
2.9903398
3.0031004
2.9657657
3.0118566
3.2216997
3.1792932
3.292562
3.3657894
3.3687973
3.339

KeyboardInterrupt: 