In [1]:
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy import vstack
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from matplotlib import pyplot

Using TensorFlow backend.


In [2]:
def define_discriminator(in_shape=(28,28,1)):
    """Build and compile the standalone discriminator network.

    The network stacks two identical downsampling stages (64-filter 3x3
    convolution with stride 2 and 'same' padding, LeakyReLU with alpha 0.2,
    dropout 0.4), flattens the result and ends in a single sigmoid unit.

    Args:
        in_shape: shape of one input image; passed as ``input_shape`` to the
            first convolution only.

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy loss,
        ``Adam(lr=0.0002, beta_1=0.5)`` and the ``accuracy`` metric.
    """
    classifier = Sequential()
    # Only the first stage declares the input shape; the second infers it.
    per_stage_extras = ({'input_shape': in_shape}, {})
    for extras in per_stage_extras:
        classifier.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', **extras))
        classifier.add(LeakyReLU(alpha=0.2))
        classifier.add(Dropout(0.4))
    classifier.add(Flatten())
    classifier.add(Dense(1, activation='sigmoid'))
    classifier.compile(
        loss='binary_crossentropy',
        optimizer=Adam(lr=0.0002, beta_1=0.5),
        metrics=['accuracy'],
    )
    return classifier

In [3]:
def define_generator(latent_dim):
    """Build the standalone generator network (returned uncompiled).

    A dense layer projects the latent vector to 128 * 7 * 7 activations that
    are reshaped to a 7x7x128 feature map. Two stride-2 transpose
    convolutions with 'same' padding then upsample it (7 -> 14 -> 28), and a
    final 7x7 convolution with a sigmoid activation emits one channel.

    Args:
        latent_dim: length of the input noise vector (``input_dim`` of the
            first dense layer).

    Returns:
        The ``Sequential`` generator model; ``compile`` is not called here.
    """
    base_h, base_w, base_ch = 7, 7, 128

    generator = Sequential()
    generator.add(Dense(base_ch * base_h * base_w, input_dim=latent_dim))
    generator.add(LeakyReLU(alpha=0.2))
    generator.add(Reshape((base_h, base_w, base_ch)))
    # Two identical upsampling stages.
    for _ in range(2):
        generator.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        generator.add(LeakyReLU(alpha=0.2))
    generator.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return generator

In [4]:
def define_gan(g_model, d_model):
    """Chain generator and discriminator into one compiled composite model.

    Note that this sets ``d_model.trainable = False`` on the object passed in
    (a side effect visible to the caller) before the composite is compiled.

    Args:
        g_model: generator model, used as the first stage.
        d_model: discriminator model, used as the second stage.

    Returns:
        A ``Sequential`` model (generator followed by discriminator) compiled
        with binary cross-entropy loss and ``Adam(lr=0.0002, beta_1=0.5)``.
    """
    # Mark the discriminator's weights as not trainable before compiling
    # the stacked model.
    d_model.trainable = False
    stacked = Sequential([g_model, d_model])
    stacked.compile(optimizer=Adam(lr=0.0002, beta_1=0.5), loss='binary_crossentropy')
    return stacked

In [5]:
def load_real_samples():
    """Load the MNIST training images prepared for the discriminator.

    Only the training images are kept; the labels and the test split
    returned by ``load_data`` are discarded.

    Returns:
        A float32 array with a trailing channel axis of size 1 appended and
        values divided by 255.0.
    """
    train_split, _test_split = load_data()
    images = train_split[0]
    # Add the channel axis, convert to float32, then scale by 1/255.
    return expand_dims(images, axis=-1).astype('float32') / 255.0

In [6]:
def generate_real_samples(dataset, n_samples):
    """Draw a random batch of real samples together with 'real' labels.

    Indices are drawn uniformly with replacement from ``[0, dataset.shape[0])``.

    Args:
        dataset: array indexed along its first axis.
        n_samples: number of samples to draw.

    Returns:
        Tuple ``(X, y)``: ``X`` holds the selected rows of ``dataset`` and
        ``y`` is an ``(n_samples, 1)`` array of ones (class label 1 = real).
    """
    picked = randint(0, dataset.shape[0], n_samples)
    real_labels = ones((n_samples, 1))
    return dataset[picked], real_labels

In [7]:
def generate_latent_points(latent_dim, n_samples):
    """Sample generator inputs from a standard normal distribution.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of latent vectors to draw.

    Returns:
        Array of shape ``(n_samples, latent_dim)`` filled with ``randn`` draws.
    """
    flat_noise = randn(latent_dim * n_samples)
    # One row per sample.
    return flat_noise.reshape(n_samples, latent_dim)

In [8]:
def generate_fake_samples(g_model, latent_dim, n_samples):
    """Produce a batch of generator outputs together with 'fake' labels.

    Args:
        g_model: generator model; its ``predict`` is called on the latent batch.
        latent_dim: length of each latent vector.
        n_samples: number of samples to generate.

    Returns:
        Tuple ``(X, y)``: ``X`` is the generator's prediction for freshly
        sampled latent points and ``y`` is an ``(n_samples, 1)`` array of
        zeros (class label 0 = fake).
    """
    latent_batch = generate_latent_points(latent_dim, n_samples)
    generated = g_model.predict(latent_batch)
    fake_labels = zeros((n_samples, 1))
    return generated, fake_labels

In [9]:
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """Report discriminator accuracy and save a snapshot of the generator.

    The discriminator is evaluated on ``n_samples`` real images and on
    ``n_samples`` generated images, the two accuracies are printed, and the
    generator is saved to ``generator_model_<epoch+1, zero-padded to 3>.h5``
    in the current working directory.

    Args:
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        g_model: generator model to sample from and to save.
        d_model: discriminator model; ``evaluate`` must return (loss, accuracy).
        dataset: array of real samples.
        latent_dim: length of each latent vector.
        n_samples: number of real and of fake samples to evaluate on.
    """
    # Real batch first, then the fake batch (keeps the random draw order).
    real_images, real_labels = generate_real_samples(dataset, n_samples)
    _loss_real, acc_real = d_model.evaluate(real_images, real_labels, verbose=0)

    fake_images, fake_labels = generate_fake_samples(g_model, latent_dim, n_samples)
    _loss_fake, acc_fake = d_model.evaluate(fake_images, fake_labels, verbose=0)

    report = '>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100)
    print(report)

    g_model.save('generator_model_%03d.h5' % (epoch + 1))

In [10]:
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=256, n_eval=10):
    """Alternate discriminator and generator updates over the dataset.

    Each step trains the discriminator on a batch that is half real samples
    and half generated samples, then trains the generator through the
    composite model using latent points labelled as real. One progress line
    is printed per step.

    Args:
        g_model: generator model (used to create fake samples).
        d_model: compiled discriminator; ``train_on_batch`` must return
            (loss, accuracy).
        gan_model: compiled composite model; ``train_on_batch`` must return
            a scalar loss.
        dataset: array of real samples, indexed along the first axis.
        latent_dim: length of each latent vector.
        n_epochs: number of passes; each pass runs
            ``int(dataset.shape[0] / n_batch)`` steps.
        n_batch: batch size for the generator update; the discriminator
            sees ``int(n_batch / 2)`` real plus as many fake samples.
        n_eval: call ``summarize_performance`` after every ``n_eval``-th
            epoch. Defaults to 10, the previously hard-coded interval.
            Pass 0 or None to disable evaluation and checkpointing.
    """
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    for i in range(n_epochs):
        for j in range(bat_per_epo):
            # Discriminator update on a stacked half-real / half-fake batch.
            X_real, y_real = generate_real_samples(dataset, half_batch)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            X, y = vstack((X_real, X_fake)), vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)
            # Generator update: fresh latent points with label 1 ("real"),
            # pushed through the composite model.
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, j+1, bat_per_epo, d_loss, g_loss))
        # A falsy n_eval disables evaluation and avoids a modulo-by-zero.
        if n_eval and (i+1) % n_eval == 0:
            summarize_performance(i, g_model, d_model, dataset, latent_dim)

In [None]:
# Length of the latent (noise) vector fed to the generator.
latent_dim = 100
# Standalone discriminator, built with its default input shape (28, 28, 1).
d_model = define_discriminator()
# Standalone generator (returned uncompiled by define_generator).
g_model = define_generator(latent_dim)
# Composite model; define_gan sets d_model.trainable = False as a side effect.
gan_model = define_gan(g_model, d_model)
# Real training images from load_real_samples().
dataset = load_real_samples()
# Run training with the default n_epochs and n_batch of train().
train(g_model, d_model, gan_model, dataset, latent_dim)

W1014 08:36:06.033453 140132469212928 deprecation.py:237] From /usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py:4139: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W1014 08:36:06.057747 140132469212928 deprecation.py:506] From /usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
W1014 08:36:06.204531 140132469212928 deprecation.py:237] From /usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py:3376: The name tf.log is deprecated. Please use tf.math.log instead.

  'Discrepancy between trainable weights and collected trainable'


>1, 1/234, d=0.689, g=0.741


  'Discrepancy between trainable weights and collected trainable'


>1, 2/234, d=0.680, g=0.763
>1, 3/234, d=0.667, g=0.785
>1, 4/234, d=0.660, g=0.813
>1, 5/234, d=0.657, g=0.837
>1, 6/234, d=0.642, g=0.851
>1, 7/234, d=0.638, g=0.860
>1, 8/234, d=0.633, g=0.875
>1, 9/234, d=0.629, g=0.870
>1, 10/234, d=0.635, g=0.852
>1, 11/234, d=0.640, g=0.828
>1, 12/234, d=0.645, g=0.799
>1, 13/234, d=0.649, g=0.770
>1, 14/234, d=0.657, g=0.745
>1, 15/234, d=0.657, g=0.727
>1, 16/234, d=0.650, g=0.716
>1, 17/234, d=0.649, g=0.710
>1, 18/234, d=0.639, g=0.704
>1, 19/234, d=0.632, g=0.702
>1, 20/234, d=0.626, g=0.699
>1, 21/234, d=0.620, g=0.699
>1, 22/234, d=0.609, g=0.698
>1, 23/234, d=0.602, g=0.698
>1, 24/234, d=0.598, g=0.698
>1, 25/234, d=0.585, g=0.699
>1, 26/234, d=0.584, g=0.699
>1, 27/234, d=0.561, g=0.700
>1, 28/234, d=0.549, g=0.701
>1, 29/234, d=0.551, g=0.701
>1, 30/234, d=0.534, g=0.702
>1, 31/234, d=0.517, g=0.703
>1, 32/234, d=0.514, g=0.704
>1, 33/234, d=0.500, g=0.704
>1, 34/234, d=0.492, g=0.706
>1, 35/234, d=0.481, g=0.707
>1, 36/234, d=0.474, g

>2, 47/234, d=0.761, g=0.613
>2, 48/234, d=0.752, g=0.634
>2, 49/234, d=0.749, g=0.645
>2, 50/234, d=0.743, g=0.659
>2, 51/234, d=0.743, g=0.667
>2, 52/234, d=0.755, g=0.673
>2, 53/234, d=0.743, g=0.704
>2, 54/234, d=0.747, g=0.703
>2, 55/234, d=0.747, g=0.713
>2, 56/234, d=0.739, g=0.705
>2, 57/234, d=0.746, g=0.731
>2, 58/234, d=0.744, g=0.705
>2, 59/234, d=0.747, g=0.699
>2, 60/234, d=0.736, g=0.679
>2, 61/234, d=0.727, g=0.681
>2, 62/234, d=0.736, g=0.683
>2, 63/234, d=0.716, g=0.678
>2, 64/234, d=0.728, g=0.693
>2, 65/234, d=0.711, g=0.694
>2, 66/234, d=0.711, g=0.692
>2, 67/234, d=0.719, g=0.694
>2, 68/234, d=0.709, g=0.693
>2, 69/234, d=0.694, g=0.710
>2, 70/234, d=0.698, g=0.700
>2, 71/234, d=0.701, g=0.709
>2, 72/234, d=0.690, g=0.717
>2, 73/234, d=0.670, g=0.721
>2, 74/234, d=0.680, g=0.721
>2, 75/234, d=0.681, g=0.741
>2, 76/234, d=0.688, g=0.755
>2, 77/234, d=0.676, g=0.756
>2, 78/234, d=0.662, g=0.734
>2, 79/234, d=0.672, g=0.742
>2, 80/234, d=0.661, g=0.757
>2, 81/234, d=

>3, 92/234, d=0.685, g=0.754
>3, 93/234, d=0.694, g=0.738
>3, 94/234, d=0.688, g=0.733
>3, 95/234, d=0.688, g=0.724
>3, 96/234, d=0.691, g=0.711
>3, 97/234, d=0.689, g=0.695
>3, 98/234, d=0.686, g=0.715
>3, 99/234, d=0.684, g=0.709
>3, 100/234, d=0.691, g=0.711
>3, 101/234, d=0.684, g=0.716
>3, 102/234, d=0.699, g=0.707
>3, 103/234, d=0.700, g=0.705
>3, 104/234, d=0.681, g=0.701
>3, 105/234, d=0.693, g=0.690
>3, 106/234, d=0.683, g=0.709
>3, 107/234, d=0.689, g=0.714
>3, 108/234, d=0.690, g=0.726
>3, 109/234, d=0.689, g=0.726
>3, 110/234, d=0.685, g=0.731
>3, 111/234, d=0.691, g=0.743
>3, 112/234, d=0.683, g=0.729
>3, 113/234, d=0.691, g=0.734
>3, 114/234, d=0.683, g=0.732
>3, 115/234, d=0.695, g=0.725
>3, 116/234, d=0.683, g=0.730
>3, 117/234, d=0.684, g=0.721
>3, 118/234, d=0.685, g=0.710
>3, 119/234, d=0.696, g=0.694
>3, 120/234, d=0.675, g=0.697
>3, 121/234, d=0.685, g=0.693
>3, 122/234, d=0.674, g=0.685
>3, 123/234, d=0.685, g=0.693
>3, 124/234, d=0.686, g=0.684
>3, 125/234, d=0.6