In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from keras.layers import Input, Dense, Dropout, LeakyReLU
from keras.models import Model, Sequential
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

In [3]:
# Length of the latent noise vector fed to the generator (also the GAN input size).
random_din = 100
# Seed NumPy's global RNG so the noise drawn via np.random.* is reproducible.
# NOTE(review): presumably this does not seed the Keras/backend RNG used by
# layer weight initializers — confirm if full reproducibility is required.
np.random.seed(10)

In [4]:
def load_mnist_data(normalize_test=False):
  """Load MNIST and prepare the images for the GAN.

  Training images are cast to float32, scaled from [0, 255] to [-1, 1]
  (the same range as the generator's tanh output) and flattened to
  784-element row vectors.

  Args:
    normalize_test: if True, apply the same scaling/flattening to x_test.
      Defaults to False, which returns x_test exactly as loaded by
      ``mnist.load_data()`` (the original behavior).

  Returns:
    Tuple ``(x_train, y_train, x_test, y_test)``.
  """
  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  def _scale_and_flatten(images):
    # (p - 127.5) / 127.5 maps pixel values [0, 255] onto [-1, 1].
    scaled = (images.astype(np.float32) - 127.5) / 127.5
    # Derive the row count from the data instead of hard-coding 60000.
    return scaled.reshape(-1, 784)

  x_train = _scale_and_flatten(x_train)
  if normalize_test:
    x_test = _scale_and_flatten(x_test)
  return (x_train, y_train, x_test, y_test)

In [6]:
def get_optimizer(learning_rate=0.0002, beta_1=0.5):
  """Return a new Adam optimizer for one of the GAN models.

  The defaults (learning_rate=2e-4, beta_1=0.5) are the Adam settings
  recommended in the DCGAN paper. A fresh instance is built on every call,
  so each model compiled with it keeps its own optimizer state.

  Args:
    learning_rate: Adam step size.
    beta_1: exponential decay rate for the first-moment estimates.

  Returns:
    A ``keras.optimizers.Adam`` instance.
  """
  return Adam(learning_rate=learning_rate, beta_1=beta_1)

In [9]:
def get_generator(optimizer):
  """Build and compile the fully-connected generator.

  Maps a noise vector of length ``random_din`` (module-level constant)
  through Dense layers of 256 -> 512 -> 1024 units, each followed by
  LeakyReLU(0.2), to a 784-unit tanh output (one flattened 28x28 image
  with values in [-1, 1]).

  Args:
    optimizer: Keras optimizer passed to ``compile``.

  Returns:
    The compiled ``Sequential`` generator.
  """
  model = Sequential()
  for position, width in enumerate((256, 512, 1024)):
    if position == 0:
      # Only the first layer declares the input size and uses an
      # N(0, 0.02) kernel init; later layers keep Keras' default initializer.
      hidden = Dense(width, input_dim=random_din,
                     kernel_initializer=initializers.RandomNormal(stddev=0.02))
    else:
      hidden = Dense(width)
    model.add(hidden)
    model.add(LeakyReLU(0.2))

  model.add(Dense(784, activation='tanh'))
  # NOTE(review): the generator is presumably trained only through the
  # combined GAN model, so this standalone compile may be unnecessary; kept as-is.
  model.compile(loss='binary_crossentropy', optimizer=optimizer)
  return model

In [12]:
def get_discriminator(optimizer):
  """Build and compile the fully-connected discriminator.

  Takes a flattened 784-element image and passes it through Dense layers
  of 1024 -> 512 -> 256 units, each followed by LeakyReLU(0.2) and
  Dropout(0.3), ending in a single sigmoid unit (probability the input is real).

  Args:
    optimizer: Keras optimizer passed to ``compile``.

  Returns:
    The compiled ``Sequential`` discriminator.
  """
  model = Sequential()
  hidden_widths = [1024, 512, 256]
  for position, width in enumerate(hidden_widths):
    dense_kwargs = {}
    if position == 0:
      # First layer fixes the input size and uses an N(0, 0.02) kernel init.
      dense_kwargs = dict(input_dim=784,
                          kernel_initializer=initializers.RandomNormal(stddev=0.02))
    model.add(Dense(width, **dense_kwargs))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.3))

  model.add(Dense(1, activation='sigmoid'))
  model.compile(loss='binary_crossentropy', optimizer=optimizer)
  return model

In [13]:
def get_gan_network(discriminator, random_din, generator, optimizer):
  """Chain generator -> discriminator into the combined model that trains the generator.

  ``discriminator.trainable`` is set to False *before* the combined model
  is compiled, so training ``gan`` updates only the generator's weights.
  The flag is set on the shared discriminator object itself.
  NOTE(review): whether the separately compiled discriminator still trains
  after this flip depends on the Keras version — confirm in the training loop.

  Args:
    discriminator: model mapping a generated sample to a real/fake score.
    random_din: length of the noise vector fed to the generator.
    generator: model mapping noise to a sample.
    optimizer: Keras optimizer passed to ``compile``.

  Returns:
    The compiled functional ``Model`` from noise to discriminator score.
  """
  discriminator.trainable = False
  noise_in = Input(shape=(random_din,))
  validity = discriminator(generator(noise_in))
  gan = Model(inputs=noise_in, outputs=validity)
  gan.compile(loss='binary_crossentropy', optimizer=optimizer)
  return gan

In [15]:
def plot_generated_images(epoch,generator,examples=100,dim=(10,10),figsize=(10,10)):
  """Sample the generator and save a grid of generated images as a PNG.

  Draws ``examples`` standard-normal noise vectors of length ``random_din``
  (module-level constant), runs them through ``generator``, and writes the
  resulting 28x28 images to ``gan_generated_image_epoch_<epoch>.png``.

  Args:
    epoch: epoch number, used only in the output file name.
    generator: model mapping ``(examples, random_din)`` noise to
      ``(examples, 784)`` outputs.
    examples: number of images to generate; must fit in the grid.
    dim: ``(rows, cols)`` of the subplot grid.
    figsize: matplotlib figure size in inches.

  Raises:
    ValueError: if ``examples`` exceeds ``dim[0] * dim[1]``.
  """
  if examples > dim[0] * dim[1]:
    raise ValueError('examples=%d does not fit in a %dx%d grid'
                     % (examples, dim[0], dim[1]))

  noise = np.random.normal(0, 1, size=[examples, random_din])
  # Reshape using the requested count; the original hard-coded 100, which
  # broke every other value of `examples`.
  generated_images = generator.predict(noise, verbose=0).reshape(examples, 28, 28)

  fig = plt.figure(figsize=figsize)
  for i, image in enumerate(generated_images):
    plt.subplot(dim[0], dim[1], i + 1)
    plt.imshow(image, interpolation='nearest', cmap='gray_r')
    plt.axis('off')
  plt.tight_layout()
  plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
  # Release the figure: this is called once per epoch, and unclosed pyplot
  # figures otherwise accumulate in memory. The saved PNG is the output.
  plt.close(fig)


In [None]:
def train(epochs=1, batch_size=128):
  x_train, y_train, x_test, y_test=load_mnist_data()
  batch_count=x_train.shape[0] / batch_size
  generator=get_generator(get_optimizer())
  discriminator=get_discriminator(get_optimizer())