In [2]:
# example of training a stable gan for generating a handwritten digit
from os import makedirs
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from matplotlib import pyplot

# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
    """Build and compile the standalone real/fake image classifier.

    Two strided 4x4 convolutions, each followed by batch normalization and
    LeakyReLU, downsample the input image (28x28 -> 14x14 -> 7x7). The
    features are then flattened into a single sigmoid unit that scores the
    image as real (1) or fake (0).

    Args:
        in_shape: shape of one input image, (height, width, channels).

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy, Adam
        (learning rate 0.0002, beta_1 0.5) and an accuracy metric.
    """
    # weight initialization: zero-mean Gaussian with stddev 0.02
    init = RandomNormal(stddev=0.02)
    # define model
    model = Sequential()
    # downsample to 14x14
    model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init,
                     input_shape=in_shape))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    # downsample to 7x7
    model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    # classifier
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # Keras >= 2.3 calls the argument `learning_rate`, and newer releases
    # deprecate (Keras 3: reject) the legacy `lr` spelling. Keras < 2.3 only
    # knows `lr` and raises TypeError for unknown keywords, so try the
    # modern name first and fall back to the legacy one.
    try:
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
    except TypeError:
        opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def define_generator(latent_dim):
    """Build the (uncompiled) generator mapping latent vectors to images.

    A dense layer projects the latent vector onto a 7x7x128 feature map.
    Two transposed-convolution stages upsample it 7x7 -> 14x14 -> 28x28.
    A final 7x7 convolution with tanh emits a single-channel 28x28 image
    in [-1, 1].

    Args:
        latent_dim: length of each input latent vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    weights_init = RandomNormal(stddev=0.02)
    net = Sequential()
    # project the latent vector and reshape it into a 7x7 feature map
    net.add(Dense(7 * 7 * 128, kernel_initializer=weights_init, input_dim=latent_dim))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Reshape((7, 7, 128)))
    # two identical upsampling stages: 7x7 -> 14x14, then 14x14 -> 28x28
    for _stage in range(2):
        net.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',
                                kernel_initializer=weights_init))
        net.add(BatchNormalization())
        net.add(LeakyReLU(alpha=0.2))
    # collapse to one output channel, squashed to [-1, 1]
    net.add(Conv2D(1, (7, 7), activation='tanh', padding='same',
                   kernel_initializer=weights_init))
    return net

# define the combined generator and discriminator model, for updating the generator
def define_gan(generator, discriminator):
    """Stack generator and discriminator into a model that trains the generator.

    Side effect: sets ``discriminator.trainable = False`` on the object that
    was passed in, so that fitting the composite only updates the generator.
    NOTE(review): the discriminator is still trained directly via its own,
    earlier compile. This relies on Keras capturing trainability at compile
    time (true for Keras 2); confirm on other versions.

    Args:
        generator: model producing images from latent vectors.
        discriminator: already-compiled image classifier.

    Returns:
        A ``Sequential`` composite compiled with binary cross-entropy and
        Adam (learning rate 0.0002, beta_1 0.5).
    """
    # make weights in the discriminator not trainable
    discriminator.trainable = False
    # connect them: generator output feeds the discriminator
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    # Keras >= 2.3 calls the argument `learning_rate`, and newer releases
    # deprecate (Keras 3: reject) the legacy `lr` spelling. Keras < 2.3 only
    # knows `lr` and raises TypeError for unknown keywords, so try the
    # modern name first and fall back to the legacy one.
    try:
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
    except TypeError:
        opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

# load mnist images
def load_real_samples(digit=8):
    """Load the MNIST training images of a single digit, scaled to [-1, 1].

    Args:
        digit: class label to keep. Defaults to 8, the class this script was
            originally hard-coded to.

    Returns:
        float32 array of the selected images with a trailing channel axis
        appended, e.g. (n, 28, 28, 1).

    Raises:
        ValueError: if no training image carries the label ``digit``.
    """
    # only the training split is used; the test split is discarded
    (trainX, trainy), (_, _) = load_data()
    # append a channel axis so each image matches the models' (28, 28, 1) input
    X = expand_dims(trainX, axis=-1)
    # keep only the examples of the requested class
    X = X[trainy == digit]
    if X.shape[0] == 0:
        raise ValueError('no MNIST training images with label %r' % (digit,))
    # convert from ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [-1,1] to match the generator's tanh output
    X = (X - 127.5) / 127.5
    return X

# select real samples
def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random rows of ``dataset`` and label them real.

    Rows are drawn uniformly with replacement along axis 0.

    Returns:
        (images, labels) where labels is an (n_samples, 1) array of ones.
    """
    picks = randint(0, dataset.shape[0], n_samples)
    return dataset[picks], ones((n_samples, 1))

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    """Sample a batch of standard-normal latent vectors.

    Returns:
        Array of shape (n_samples, latent_dim).
    """
    flat = randn(n_samples * latent_dim)
    return flat.reshape((n_samples, latent_dim))

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate ``n_samples`` images with ``generator`` and label them fake.

    Returns:
        (images, labels) where labels is an (n_samples, 1) array of zeros.
    """
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # verbose=0: this is called every training step, and Keras versions whose
    # predict() defaults to a progress bar would otherwise print one per call
    X = generator.predict(x_input, verbose=0)
    # create class labels
    y = zeros((n_samples, 1))
    return X, y

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
    """Save a grid of generated images and a checkpoint of the generator.

    Writes ``results_baseline/generated_plot_XXX.png`` and
    ``results_baseline/model_XXX.h5``, where XXX is ``step + 1``. The
    directory must already exist.

    Args:
        step: zero-based training step, used to name the output files.
        g_model: generator to sample from and save.
        latent_dim: size of the generator's latent input.
        n_samples: number of images to plot. They are laid out on the
            smallest square grid that holds them (10x10 for the default 100).
    """
    import math

    # prepare fake examples
    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # scale from [-1,1] to [0,1]
    X = (X + 1) / 2.0
    # grid side derived from n_samples, so smaller batches no longer index
    # past the end of X
    side = math.ceil(math.sqrt(n_samples))
    for i in range(n_samples):
        pyplot.subplot(side, side, 1 + i)
        pyplot.axis('off')
        # plot raw pixel data of the single channel
        pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
    # save plot to file
    pyplot.savefig('results_baseline/generated_plot_%03d.png' % (step+1))
    pyplot.close()
    # save the generator model
    g_model.save('results_baseline/model_%03d.h5' % (step+1))

# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist):
    """Save a two-panel line plot of the recorded training history.

    The top panel shows the discriminator losses on real and fake batches
    and the generator loss. The bottom panel shows the discriminator
    accuracies. Output goes to ``results_baseline/plot_line_plot_loss.png``.
    """
    panels = (
        (1, ((d1_hist, 'd-real'), (d2_hist, 'd-fake'), (g_hist, 'gen'))),
        (2, ((a1_hist, 'acc-real'), (a2_hist, 'acc-fake'))),
    )
    for row, curves in panels:
        pyplot.subplot(2, 1, row)
        for series, name in curves:
            pyplot.plot(series, label=name)
        pyplot.legend()
    pyplot.savefig('results_baseline/plot_line_plot_loss.png')
    pyplot.close()

# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=128,
          n_gen_updates=2):
    """Alternate discriminator and generator updates, logging and checkpointing.

    Each step does the following:
      - trains ``d_model`` on half a batch of real images;
      - trains ``d_model`` on half a batch of generated images;
      - trains the generator through ``gan_model`` ``n_gen_updates`` times on
        one full batch of latent points labelled real.

    Per-step losses and accuracies are printed and recorded. After every
    epoch (``dataset.shape[0] // n_batch`` steps), a sample grid and a
    generator checkpoint are saved. The history is plotted when training ends.

    Args:
        g_model: generator model.
        d_model: compiled discriminator; ``train_on_batch`` returns
            (loss, accuracy).
        gan_model: compiled generator + frozen-discriminator composite.
        dataset: real images, indexed along axis 0.
        latent_dim: size of the generator's latent input.
        n_epochs: number of epochs' worth of steps to run.
        n_batch: batch size; each discriminator update uses half of it.
        n_gen_updates: generator updates per step (default 2, the original
            hard-coded value).

    Raises:
        ValueError: if ``dataset`` has fewer than ``n_batch`` samples (no full
            batch per epoch) or ``n_gen_updates`` is less than 1.
    """
    # number of whole batches per epoch
    bat_per_epo = dataset.shape[0] // n_batch
    if bat_per_epo < 1:
        raise ValueError('dataset has %d samples, fewer than one batch of %d'
                         % (dataset.shape[0], n_batch))
    if n_gen_updates < 1:
        raise ValueError('n_gen_updates must be at least 1, got %r' % (n_gen_updates,))
    # total training steps across all epochs
    n_steps = bat_per_epo * n_epochs
    half_batch = n_batch // 2
    # per-step statistics for the final history plot
    d1_hist, d2_hist, g_hist, a1_hist, a2_hist = [], [], [], [], []
    # one iteration per training step (batch); epoch boundaries come from bat_per_epo
    for i in range(n_steps):
        # update the discriminator on randomly selected 'real' samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        d_loss1, d_acc1 = d_model.train_on_batch(X_real, y_real)
        # ...and on freshly generated 'fake' samples
        X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_loss2, d_acc2 = d_model.train_on_batch(X_fake, y_fake)
        # latent points for the generator, with inverted ('real') labels
        X_gan = generate_latent_points(latent_dim, n_batch)
        y_gan = ones((n_batch, 1))
        # update the generator via the discriminator's error; the same batch
        # is reused for every update within this step
        for _ in range(n_gen_updates):
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
        # summarize loss on this batch
        print('>%d, d1=%.3f, d2=%.3f g=%.3f, a1=%d, a2=%d' %
              (i+1, d_loss1, d_loss2, g_loss, int(100*d_acc1), int(100*d_acc2)))
        # record history
        d1_hist.append(d_loss1)
        d2_hist.append(d_loss2)
        g_hist.append(g_loss)
        a1_hist.append(d_acc1)
        a2_hist.append(d_acc2)
        # evaluate the model performance at the end of every epoch
        if (i + 1) % bat_per_epo == 0:
            summarize_performance(i, g_model, latent_dim)
    plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist)
    
# dimensionality of the generator's latent (noise) input
latent_dim = 50

# every artefact (sample grids, generator checkpoints, history plot) is
# written under this directory
makedirs('results_baseline', exist_ok=True)

# build the models: the discriminator is compiled on its own first, then
# define_gan freezes it inside the generator-training composite
discriminator = define_discriminator()
generator = define_generator(latent_dim)
gan_model = define_gan(generator, discriminator)

# real training images (a single MNIST digit class), scaled to [-1, 1]
dataset = load_real_samples()
print(dataset.shape)

train(generator, discriminator, gan_model, dataset, latent_dim)


(5851, 28, 28, 1)


  'Discrepancy between trainable weights and collected trainable'


>1, d1=0.749, d2=0.697 g=0.171, a1=56, a2=53


  'Discrepancy between trainable weights and collected trainable'


>2, d1=0.182, d2=1.142 g=0.074, a1=98, a2=0
>3, d1=0.091, d2=1.419 g=0.062, a1=100, a2=0
>4, d1=0.078, d2=1.337 g=0.062, a1=100, a2=0
>5, d1=0.077, d2=1.259 g=0.077, a1=100, a2=0
>6, d1=0.055, d2=1.111 g=0.083, a1=100, a2=1
>7, d1=0.065, d2=1.097 g=0.072, a1=100, a2=4
>8, d1=0.068, d2=1.182 g=0.060, a1=100, a2=0
>9, d1=0.107, d2=1.266 g=0.059, a1=98, a2=0
>10, d1=0.112, d2=1.220 g=0.059, a1=98, a2=0
>11, d1=0.109, d2=1.187 g=0.063, a1=100, a2=0
>12, d1=0.135, d2=1.074 g=0.064, a1=96, a2=3
>13, d1=0.141, d2=0.849 g=0.061, a1=96, a2=23
>14, d1=0.136, d2=0.762 g=0.061, a1=98, a2=31
>15, d1=0.160, d2=0.704 g=0.060, a1=96, a2=50
>16, d1=0.097, d2=0.639 g=0.055, a1=100, a2=64
>17, d1=0.123, d2=0.628 g=0.058, a1=100, a2=67
>18, d1=0.190, d2=0.576 g=0.062, a1=95, a2=78
>19, d1=0.134, d2=0.509 g=0.069, a1=98, a2=93
>20, d1=0.107, d2=0.510 g=0.081, a1=100, a2=92
>21, d1=0.167, d2=0.554 g=0.093, a1=96, a2=84
>22, d1=0.104, d2=0.439 g=0.117, a1=98, a2=96
>23, d1=0.145, d2=0.407 g=0.170, a1=98, a2=

>173, d1=0.004, d2=0.003 g=0.001, a1=100, a2=100
>174, d1=0.004, d2=0.002 g=0.001, a1=100, a2=100
>175, d1=0.007, d2=0.004 g=0.001, a1=100, a2=100
>176, d1=0.003, d2=0.005 g=0.002, a1=100, a2=100
>177, d1=0.005, d2=0.002 g=0.002, a1=100, a2=100
>178, d1=0.007, d2=0.005 g=0.002, a1=100, a2=100
>179, d1=0.005, d2=0.008 g=0.002, a1=100, a2=100
>180, d1=0.004, d2=0.004 g=0.002, a1=100, a2=100
>181, d1=0.005, d2=0.004 g=0.002, a1=100, a2=100
>182, d1=0.007, d2=0.005 g=0.002, a1=100, a2=100
>183, d1=0.006, d2=0.007 g=0.002, a1=100, a2=100
>184, d1=0.005, d2=0.007 g=0.003, a1=100, a2=100
>185, d1=0.005, d2=0.004 g=0.003, a1=100, a2=100
>186, d1=0.014, d2=0.007 g=0.004, a1=100, a2=100
>187, d1=0.008, d2=0.004 g=0.003, a1=100, a2=100
>188, d1=0.007, d2=0.005 g=0.003, a1=100, a2=100
>189, d1=0.015, d2=0.005 g=0.003, a1=100, a2=100
>190, d1=0.006, d2=0.014 g=0.005, a1=100, a2=100
>191, d1=0.007, d2=0.003 g=0.007, a1=100, a2=100
>192, d1=0.009, d2=0.009 g=0.008, a1=100, a2=100
>193, d1=0.005, d2=0

>345, d1=0.872, d2=0.788 g=0.761, a1=29, a2=37
>346, d1=0.869, d2=0.880 g=0.765, a1=26, a2=28
>347, d1=0.916, d2=0.858 g=0.727, a1=23, a2=17
>348, d1=0.905, d2=0.859 g=0.726, a1=20, a2=15
>349, d1=0.835, d2=0.849 g=0.754, a1=26, a2=18
>350, d1=0.874, d2=0.816 g=0.739, a1=18, a2=26
>351, d1=0.927, d2=0.830 g=0.748, a1=14, a2=25
>352, d1=0.862, d2=0.848 g=0.718, a1=25, a2=18
>353, d1=0.954, d2=0.855 g=0.725, a1=12, a2=29
>354, d1=0.854, d2=0.878 g=0.716, a1=25, a2=26
>355, d1=0.807, d2=0.823 g=0.753, a1=26, a2=21
>356, d1=0.841, d2=0.815 g=0.758, a1=25, a2=29
>357, d1=0.861, d2=0.842 g=0.785, a1=21, a2=18
>358, d1=0.880, d2=0.852 g=0.731, a1=17, a2=20
>359, d1=0.854, d2=0.841 g=0.737, a1=29, a2=25
>360, d1=0.844, d2=0.801 g=0.711, a1=26, a2=25
>361, d1=0.873, d2=0.838 g=0.708, a1=21, a2=14
>362, d1=0.892, d2=0.831 g=0.707, a1=17, a2=25
>363, d1=0.814, d2=0.810 g=0.693, a1=26, a2=18
>364, d1=0.824, d2=0.869 g=0.697, a1=23, a2=15
>365, d1=0.841, d2=0.872 g=0.716, a1=15, a2=14
>366, d1=0.82