In [1]:
import tensorflow as tf
from tensorflow import keras


  from ._conv import register_converters as _register_converters


In [2]:
# example of loading the mnist dataset
from keras.datasets.mnist import load_data
# load the 28x28 images into memory
(trainX, trainy), (testX, testy) = load_data()
# summarize the shape of the dataset
print('Train', trainX.shape, trainy.shape)
print('Test', testX.shape, testy.shape)

Using TensorFlow backend.


Train (60000, 28, 28) (60000,)
Test (10000, 28, 28) (10000,)


In [4]:
import os
import numpy as np
from numpy import expand_dims
from numpy import ones
from numpy import zeros
from numpy.random import rand
from numpy.random import randint
from keras.models import Sequential
from keras.layers import Dense, Activation, Conv2D, Flatten, LeakyReLU, Dropout
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.utils.vis_utils import plot_model
from numpy import zeros
from numpy.random import randn
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from matplotlib import pyplot

# load and prepare mnist training images
def load_real_samples():
    # load mnist dataset
    (trainX, _), (_, _) = load_data()
    # expand to 3d, e.g. add channels dimension
    X = expand_dims(trainX, axis=3)
    # convert from unsigned ints to floats
    X = X.astype(float)
    # scale from [0,255] to [0,1]
    X = X / 255.0
    return X

def real_sampler(data, num_samples):
    #random samples
    ran = randint(num_samples, data.shape[0])
    X = data[ran-num_samples:ran]
    y = zeros((num_samples, 1))
    return(X,y)

def fake_sampler(num_samples):
    # generate uniform random numbers in [0,1]
    X = rand(28 * 28 * num_samples)
    # reshape into a batch of grayscale images
    X = X.reshape((num_samples, 28, 28, 1))
    # generate 'fake' class labels (0)
    y = zeros((num_samples, 1))
    return X, y

#create the descriminator model

img_shape = (28,28,1)
def descriminator():
    model= Sequential()
    model.add(Conv2D(64, (3,3), strides=(2,2), padding= 'same', input_shape=img_shape))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dropout(.2))
    model.add(Conv2D(32, (3,3), strides=(2,2), padding= 'same'))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dropout(.2))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer=Adam(lr=.0002), metrics=['accuracy'])
    return model 

#train the descrinator with fake samples from a random distribution not a generator
def train_discriminator(model, data, num_iter, num_batch):
    half_batch = int(num_batch/2)
    for i in range(num_iter):
        # get real samples
        X_real, y_real = real_sampler(data, half_batch)
        #train discriminator on real samples
        _, real_acc = model.train_on_batch(X_real, y_real)
        # get fake
        X_fake, y_fake = fake_sampler(half_batch)
        # train discrimanor on fake
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        #perfomrance
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))
        
        
# define the standalone generator model
def generator(latent_dim):
    model = Sequential()
    # foundation for 7x7 image
    n_nodes = 256 * 7 * 7
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 256)))
    # upsample to 14x14
    model.add(Conv2DTranspose(256, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 28x28
    model.add(Conv2DTranspose(256, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return model

# Generatr some points in the latent space to seed the fake samples

def gen_latent(latent_dim, num_samples):
    # generate points in the latent space
    x_input = randn(latent_dim * num_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(num_samples, latent_dim)
    return x_input


## generate fake samples using the generator model
def generate_fake_samples(gen_model, latent_dim, num_samples):
    # generate points in latent space
    x_input = gen_latent(latent_dim, num_samples)
    # predict outputs
    X = gen_model.predict(x_input)
    # create 'fake' class labels (0)
    y = zeros((num_samples, 1))
    return X, y


# define the combined generator and discriminator model, for updating the generator
def gan(g_model, d_model):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # connect them
    model = Sequential()
    # add generator
    model.add(g_model)
    # add the discriminator
    model.add(d_model)
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

        
# train the GAN
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=128):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    #training loop of epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            X_real, y_real = real_sampler(dataset, half_batch)
            # fake samples using the generator model
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # merge samples to create training set for the discriminator
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            # train the  discriminator model weights on batch
            d_loss, _ = d_model.train_on_batch(X, y)
            # randomnly choose points in latent space as input for the generator
            X_gan = gen_latent(latent_dim, n_batch)
            # create false positives for the fake samples to fool the generator
            y_gan = ones((n_batch, 1))
            # train the gan model with latent inputs
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, dis=%.3f, gen=%.3f' % (i+1, j+1, bat_per_epo, d_loss, g_loss))
            # evaluate the model performance, sometimes
            if (i+1) % 10 == 0:
                model_summary(i, g_model, d_model, dataset, latent_dim)
                
def model_summary(epoc, g_model, d_model, dataset, latent_dim, n_samples=100):
    #prepare real
    X_real, y_real = real_sampler(dataset, n_samples)
    #evaluate discriminator on real samples
    _, acc_real = d_model.evaluate(X_real, y_real) 
    #prepare fake examples
    X_fake, y_fake = fake_sampler(g_model, latent_dim, num_samples)
    #evaluate discriminator on fake samples
    _, acc_fake = d_model.evaluate(X_fake, y_fake, verbose=0)
    # summarize discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # save plot
    save_plot(x_fake, epoch)
    # save the generator model tile file
    cwd = os.getcwd()
    file = 'saved/generator_model_%03d.h5' % (epoch + 1)
    filename = os.path.join(cwd,filename)
    g_model.save(filename)

def save_plot(examples, epoch, n=10):
    # plot images
    for i in range(n * n):
        # define subplot
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(examples[i, :, :, 0], cmap='gray_r')
        # save plot to file
        cwd = os.getcwd()
        filename = 'saved/generated_plot_e%03d.png' % (epoch+1)
        file = os.path.join(cwd,filename)
        pyplot.savefig(file)
        pyplot.close()

# size of the latent space
latent_dim = 100
# create the discriminator
d_model = descriminator()
# create the generator
g_model = generator(latent_dim)
# create the gan
gan_model = gan(g_model, d_model)
# load image data
dataset = load_real_samples()
# train model
train(g_model, d_model, gan_model, dataset, latent_dim)

  'Discrepancy between trainable weights and collected trainable'


>1, 1/468, d=0.673, g=0.754
>1, 2/468, d=0.665, g=0.759
>1, 3/468, d=0.657, g=0.771
>1, 4/468, d=0.650, g=0.781
>1, 5/468, d=0.646, g=0.781
>1, 6/468, d=0.645, g=0.774
>1, 7/468, d=0.648, g=0.761
>1, 8/468, d=0.655, g=0.743
>1, 9/468, d=0.651, g=0.728
>1, 10/468, d=0.654, g=0.714
>1, 11/468, d=0.660, g=0.710
>1, 12/468, d=0.654, g=0.708
>1, 13/468, d=0.651, g=0.708
>1, 14/468, d=0.654, g=0.709
>1, 15/468, d=0.646, g=0.710
>1, 16/468, d=0.639, g=0.711
>1, 17/468, d=0.634, g=0.712
>1, 18/468, d=0.631, g=0.713
>1, 19/468, d=0.638, g=0.715
>1, 20/468, d=0.620, g=0.717
>1, 21/468, d=0.617, g=0.718
>1, 22/468, d=0.601, g=0.720
>1, 23/468, d=0.600, g=0.723
>1, 24/468, d=0.604, g=0.724
>1, 25/468, d=0.598, g=0.727
>1, 26/468, d=0.588, g=0.730
>1, 27/468, d=0.580, g=0.735
>1, 28/468, d=0.572, g=0.740
>1, 29/468, d=0.575, g=0.744
>1, 30/468, d=0.559, g=0.750
>1, 31/468, d=0.552, g=0.756
>1, 32/468, d=0.543, g=0.762
>1, 33/468, d=0.534, g=0.771
>1, 34/468, d=0.521, g=0.775
>1, 35/468, d=0.512, g=

>1, 278/468, d=0.001, g=6.420
>1, 279/468, d=0.001, g=6.416
>1, 280/468, d=0.001, g=6.464
>1, 281/468, d=0.001, g=6.444
>1, 282/468, d=0.001, g=6.426
>1, 283/468, d=0.001, g=6.454
>1, 284/468, d=0.001, g=6.474
>1, 285/468, d=0.001, g=6.472
>1, 286/468, d=0.001, g=6.468
>1, 287/468, d=0.001, g=6.510
>1, 288/468, d=0.001, g=6.507
>1, 289/468, d=0.001, g=6.509
>1, 290/468, d=0.001, g=6.520
>1, 291/468, d=0.001, g=6.521
>1, 292/468, d=0.001, g=6.529
>1, 293/468, d=0.001, g=6.541
>1, 294/468, d=0.001, g=6.545
>1, 295/468, d=0.001, g=6.556
>1, 296/468, d=0.001, g=6.557
>1, 297/468, d=0.001, g=6.573
>1, 298/468, d=0.001, g=6.594
>1, 299/468, d=0.001, g=6.578
>1, 300/468, d=0.001, g=6.594
>1, 301/468, d=0.001, g=6.615
>1, 302/468, d=0.001, g=6.605
>1, 303/468, d=0.001, g=6.617
>1, 304/468, d=0.001, g=6.630
>1, 305/468, d=0.001, g=6.677
>1, 306/468, d=0.001, g=6.621
>1, 307/468, d=0.001, g=6.641
>1, 308/468, d=0.001, g=6.683
>1, 309/468, d=0.001, g=6.663
>1, 310/468, d=0.001, g=6.640
>1, 311/46

KeyboardInterrupt: 