# Generative Adversarial Network
## Credit to: https://heartbeat.fritz.ai/introduction-to-generative-adversarial-networks-gans-35ef44f21193

In [1]:
import keras
import matplotlib.pyplot as plt
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers
from tqdm import tqdm

Using TensorFlow backend.


In [2]:
import numpy as np
np.random.seed(1000)
#Next we set the dimension of a random noise vector.
random_dim = 100

In [3]:
def load_minst_data():
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_train = x_train.reshape(-1, 784) / 255
  x_test = x_test.reshape(-1, 784) / 255
  return (x_train, y_train, x_test, y_test)

x_train, y_train, x_test, y_test = load_minst_data()

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz


In [4]:
print(x_train.shape,x_test.shape)

(60000, 784) (10000, 784)


In [5]:
optimizer = Adam(lr=0.0002, beta_1=0.5)

Instructions for updating:
Colocations handled automatically by placer.


In [6]:
generator = Sequential()
generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=optimizer)

In [7]:
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [8]:
discriminator.trainable = False
ganInput = Input(shape=(random_dim,))

In [9]:
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)

In [10]:
20%10

0

In [11]:
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
  noise = np.random.normal(0, 1, size=[examples, random_dim])
  generated_images = generator.predict(noise)
  generated_images = generated_images.reshape(examples, 28, 28)
  plt.figure(figsize=figsize)
  for i in range(generated_images.shape[0]):
    plt.subplot(dim[0], dim[1], i+1)
    plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
    plt.axis('off')
  plt.tight_layout()
  
  if epoch in [1,20,40,60,80,100]:
    plt.savefig('image_generated_%d.png' % epoch)

In [12]:
def train(epochs=1, batch_size=128):
  x_train, y_train, x_test, y_test = load_minst_data()
  batch_count = x_train.shape[0] / batch_size
  for e in range(1, epochs+1):
    print('-'*10, 'Epoch %d' % e, '-'*10)
    for _ in tqdm(range(int(batch_count))):
        
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
      generated_images = generator.predict(noise)
      X = np.concatenate([image_batch, generated_images])
      y_dis = np.zeros(2*batch_size)
      y_dis[:batch_size] = 0.9
      discriminator.trainable = True
      discriminator.train_on_batch(X, y_dis)
        
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      y_gen = np.ones(batch_size)
      discriminator.trainable = False
      gan.train_on_batch(noise, y_gen)
                            
    plot_generated_images(e, generator)

In [None]:
train(100, 128)

---------- Epoch 1 ----------


  0%|                                                                               | 0/468 [00:00<?, ?it/s]

Instructions for updating:
Use tf.cast instead.


100%|█████████████████████████████████████████████████████████████████████| 468/468 [00:15<00:00, 30.03it/s]


---------- Epoch 2 ----------


100%|█████████████████████████████████████████████████████████████████████| 468/468 [00:09<00:00, 47.08it/s]


---------- Epoch 3 ----------


100%|█████████████████████████████████████████████████████████████████████| 468/468 [00:09<00:00, 46.81it/s]


---------- Epoch 4 ----------


100%|█████████████████████████████████████████████████████████████████████| 468/468 [00:09<00:00, 46.84it/s]


---------- Epoch 5 ----------


100%|█████████████████████████████████████████████████████████████████████| 468/468 [00:10<00:00, 46.24it/s]


---------- Epoch 6 ----------


100%|█████████████████████████████████████████████████████████████████████| 468/468 [00:09<00:00, 47.01it/s]


---------- Epoch 7 ----------


 85%|██████████████████████████████████████████████████████████▊          | 399/468 [00:08<00:01, 47.75it/s]