In [12]:
%matplotlib inline
#shashikanthgk
import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Dense, Dropout, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential
from keras.optimizers import Adam

In [None]:
# Image geometry: 28x28 pixels with a single channel.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Length of the noise vector fed to the generator.
z_dim = 100

In [None]:
def build_generator(z_dim):
    """Build the DCGAN generator mapping a length-``z_dim`` noise vector to an image.

    Per-sample shapes through the network:
      Dense + Reshape                                   -> 7x7x256
      Conv2DTranspose(128, stride 2) + BN + LeakyReLU   -> 14x14x128
      Conv2DTranspose(64, stride 1)  + BN + LeakyReLU   -> 14x14x64
      Conv2DTranspose(1, stride 2)   + tanh             -> 28x28x1

    Args:
        z_dim: length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    stack = [
        # Project the noise vector and reshape it into a 7x7x256 feature map.
        Dense(256 * 7 * 7, input_dim=z_dim),
        Reshape((7, 7, 256)),

        # 7x7x256 -> 14x14x128
        Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),

        # 14x14x128 -> 14x14x64
        Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),

        # 14x14x64 -> 28x28x1, squashed into [-1, 1] by tanh.
        Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'),
        Activation('tanh'),
    ]
    # Sequential(layers) adds each layer in list order, same as repeated .add().
    return Sequential(stack)

In [None]:
def build_discriminator(img_shape):
    """Build the DCGAN discriminator mapping an image to a single sigmoid score.

    Three stride-2 convolutions (32, 64, 128 filters, kernel 3, 'same' padding)
    with LeakyReLU, BatchNorm on the last two, then Flatten + Dense(1, sigmoid).
    For a 28x28x1 input the feature maps are 14x14x32 -> 7x7x64 -> 4x4x128.

    Args:
        img_shape: shape of one input image, e.g. ``(28, 28, 1)``.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = Sequential()

    # 28x28x1 -> 14x14x32. Only the first layer of a Sequential model needs
    # input_shape; later layers infer their input from the previous layer.
    model.add(
        Conv2D(32,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))
    model.add(LeakyReLU(alpha=0.01))

    # 14x14x32 -> 7x7x64
    model.add(
        Conv2D(64,
               kernel_size=3,
               strides=2,
               padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # 7x7x64 -> 4x4x128 (stride 2 with 'same' padding gives ceil(7 / 2) = 4)
    model.add(
        Conv2D(128,
               kernel_size=3,
               strides=2,
               padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Single sigmoid unit: probability that the input image is real.
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model

In [None]:
def build_gan(generator, discriminator):
    """Chain ``generator`` and ``discriminator`` into one end-to-end model.

    The result maps a noise batch straight to the discriminator's score.
    It is returned uncompiled.
    """
    # Passing the sub-models as a list adds them in order, like two .add() calls.
    return Sequential([generator, discriminator])

In [None]:
# Build and compile the discriminator on its own so train() can update it
# directly on real / fake batches.
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])

generator = build_generator(z_dim)
# Freeze the discriminator before compiling the combined model, so
# gan.train_on_batch is meant to update only the generator's weights.
# NOTE(review): this relies on Keras capturing `trainable` at compile time.
# The discriminator was compiled above while still trainable, so its own
# train_on_batch keeps updating it. Keep this statement order.
discriminator.trainable = False
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

In [None]:
# Training history filled by train(): one entry per sample_interval checkpoint.
losses, accuracies, iteration_checkpoints = [], [], []


def train(iterations, batch_size, sample_interval, image_interval=500):
    """Train the DCGAN on MNIST, alternating discriminator and generator steps.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models
    and the latent size ``z_dim``. Every ``sample_interval`` iterations the
    current discriminator loss/accuracy and generator loss are printed and
    appended once each to the module-level ``losses``, ``accuracies`` and
    ``iteration_checkpoints`` lists, so the three lists stay index-aligned.

    Args:
        iterations: number of training iterations (one minibatch each).
        batch_size: number of real and of generated images per discriminator step.
        sample_interval: record/print metrics every this many iterations.
        image_interval: show a grid of generated images via ``sample_images``
            every this many iterations (default 500). Pass 0 or None to disable.
    """
    (X_train, _), (_, _) = mnist.load_data()
    # Rescale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
    X_train = X_train / 127.5 - 1.0
    # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    X_train = np.expand_dims(X_train, axis=3)

    # Labels for real images: all ones
    real = np.ones((batch_size, 1))

    # Labels for fake images: all zeros
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # ---- Train the discriminator on one real and one generated batch ----
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict(z, verbose=0)

        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Average [loss, accuracy] over the real and fake halves.
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---- Train the generator through the frozen discriminator ----
        # Target "real" labels: the generator is rewarded for fooling D.
        z = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = gan.train_on_batch(z, real)

        step = iteration + 1
        if step % sample_interval == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(100.0 * accuracy)
            iteration_checkpoints.append(step)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (step, d_loss, 100.0 * accuracy, g_loss))

        # Only show images here. The checkpoint list is updated solely in the
        # metrics branch above, so it stays aligned with `losses`.
        if image_interval and step % image_interval == 0:
            sample_images(generator)


In [21]:
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
    """Show a grid of images produced by ``generator`` from random noise.

    Draws ``image_grid_rows * image_grid_columns`` vectors from N(0, 1) with
    length ``z_dim`` (module-level), runs them through ``generator.predict``
    and plots each result as a grayscale tile in one matplotlib figure.

    Args:
        generator: model with a ``predict`` method mapping (n, z_dim) noise to
            single-channel images. It is assumed to output tanh values in
            [-1, 1]; confirm this for other generators.
        image_grid_rows: number of rows in the grid.
        image_grid_columns: number of columns in the grid.
    """
    n_images = image_grid_rows * image_grid_columns
    z = np.random.normal(0, 1, (n_images, z_dim))
    gen_imgs = generator.predict(z, verbose=0)

    # Map tanh output from [-1, 1] to [0, 1].
    gen_imgs = 0.5 * gen_imgs + 0.5
    # Take the spatial size from the model output rather than hard-coding 28.
    img_h, img_w = gen_imgs.shape[1], gen_imgs.shape[2]

    # squeeze=False always returns a 2-D array of Axes, so .flatten() also
    # works for 1x1 and 1xN grids (where a bare Axes would otherwise be returned).
    _, axes = plt.subplots(image_grid_rows,
                           image_grid_columns,
                           figsize=(8, 8),
                           squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        img = gen_imgs[i].reshape(img_h, img_w)
        lo, hi = img.min(), img.max()
        # Per-image contrast stretch. Skip it for constant images to avoid 0/0 -> NaN.
        if hi > lo:
            img = (img - lo) / (hi - lo)
        ax.imshow(img, cmap='gray')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [None]:
# Training schedule. With sample_interval = 5, metrics are recorded and printed
# every 5 iterations (4000 lines over 20000 iterations). Raise it for less output.
iterations, batch_size, sample_interval = 20000, 128, 5
train(iterations, batch_size, sample_interval)

In [2]:
# Imports used by the model save / load cells below
from keras.models import model_from_json
from keras.models import load_model

In [None]:
# Persist the combined GAN: architecture as JSON, weights as HDF5.
# NOTE(review): the "dcgn" spelling and ".json.h5" suffix are kept as-is so
# previously saved files keep the same names.
model_dcgan_mnist = gan.to_json()
with open("model_dcgn_mnist_num.json", "w") as json_file:
    json_file.write(model_dcgan_mnist)
gan.save_weights("model_dcgn_mnist_num.json.h5")


In [None]:
# Persist the generator: architecture as JSON, weights as HDF5.
# The loading cell reads exactly these two filenames, so keep them in sync.
model_generator_mnist = generator.to_json()
with open("model_generator_mnist_num.json", "w") as json_file:
    json_file.write(model_generator_mnist)
generator.save_weights("model_generator_mnist_num.json.h5")

In [None]:
# Persist the discriminator: architecture as JSON, weights as HDF5.
model_discriminator_mnist = discriminator.to_json()
with open("model_discriminator_mnist_num.json", "w") as json_file:
    json_file.write(model_discriminator_mnist)
discriminator.save_weights("model_discriminator_mnist_num.json.h5")

In [9]:
# Reload the generator from its JSON architecture + HDF5 weights, re-save it
# as a single HDF5 model file, then load it back from that file.
# `with` closes the JSON file even if read() raises (the old open/close leaked it).
with open('model_generator_mnist_num.json', 'r') as json_file:
    loaded_model_json = json_file.read()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights("model_generator_mnist_num.json.h5")
loaded_model.save('model_dcgn_generator_num.hdf5')
loaded_model = load_model('model_dcgn_generator_num.hdf5')

In [None]:
# sample_images reads the global z_dim. Re-declare it here so this cell
# still works if the config cell was not re-run. Then show a 10x10 grid.
z_dim = 100
sample_images(loaded_model, image_grid_rows=10, image_grid_columns=10)