# Try on GAN

This notebook is my first attempt at a GAN. We will use the MNIST data to train it, and at the end we will use the trained generator to produce random digit images.


This notebook assumes that you have already read:
* [Get Start Image Classification](https://www.kaggle.com/uysimty/get-start-image-classification)
* [Understand Image Classification](https://www.kaggle.com/uysimty/keras-cnn-dog-or-cat-classification)

These kernels help you understand CNNs.

    

# GAN oversimplified

Think of how criminals and police work. The criminals have a method for creating fake money, and the police have a method for detecting it. As the police keep detecting fakes, the criminals improve their method so that the fake money is harder to detect. As the criminals' method improves, the police improve their detection in turn. Each side keeps getting better based on the other's work, so in the end both sides become very good at what they do.

In this notebook:
*   The criminal is the **generator model**, which is responsible for generating fake images.
*   The police is the **discriminator model**, which is responsible for detecting the generator's output.
*   Both keep improving based on each other's results.
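
For reference, this back-and-forth is usually written as a two-player minimax game. The formula below is the standard GAN objective rather than something defined in this notebook. Here $D$ is the discriminator, $G$ is the generator, $x$ is a real image and $z$ is a random noise vector:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator tries to make this value large by scoring real images near 1 and generated images near 0, while the generator tries to make it small by producing images the discriminator scores near 1.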


# Required libraries

In [1]:
import numpy as np
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization
from matplotlib import pyplot
import time
import os
from IPython.display import display, clear_output

# Discriminator Model

This model predicts whether an image is fake or real. It takes an image as input and outputs the probability that the image is real.

The discriminator model:
* Has two convolutional layers with 64 filters each.
* Uses a small kernel size of 3 and a larger-than-normal stride of 2.
* Has no pooling layers, and a single node in the output layer with a sigmoid activation to predict whether the input sample is real or fake.
* Is trained to minimize the binary cross-entropy loss, which is appropriate for binary classification.
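
To see why no pooling layers are needed, here is a small sketch of the shape arithmetic. It assumes the usual rule for `padding='same'` convolutions, where the output size is the input size divided by the stride, rounded up. The helper name `conv_same_output_size` is made up for this illustration.

```python
import math

def conv_same_output_size(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

size = 28                       # MNIST images are 28x28
for _ in range(2):              # two stride-2 convolutions
    size = conv_same_output_size(size, stride=2)

flattened = size * size * 64    # 64 filters in the last Conv2D
print(size, flattened)
```

Each stride-2 convolution halves the width and height (28 to 14 to 7), so the Flatten layer hands 7 × 7 × 64 values to the final Dense node.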

In [1]:
def define_discriminator(in_shape=(28,28,1)):
    model = Sequential()
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.2))
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5) # `lr` is a deprecated alias of `learning_rate`
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

In [1]:
model = define_discriminator()
model.summary()

# Generator Model

The generator model is responsible for producing our fake images with a learned model (not purely at random).
It takes random noise as input and outputs an image.

So we are going to take 100 random values and turn them into an output image of 28×28 = 784 pixels.

* The first hidden layer is a Dense layer with enough nodes to represent a low-resolution version of the output image. Specifically, an image half the size (one quarter the area) of the output image would be 14×14 or 196 nodes, and an image one quarter the size (one sixteenth the area) would be 7×7 or 49 nodes.
* The Dense output is reshaped into 128 different 7×7 feature maps.
* The low-resolution image is then upsampled to a higher resolution. A Conv2DTranspose layer can be configured with a stride of (2×2) to double the width and height of its input.
* The output layer of the model is a Conv2D with one filter and a kernel size of 7×7, designed to create a single feature map and preserve its dimensions at 28×28.
* We will use BatchNormalization and Dropout, as is common practice, to improve our model.
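
The sizes in the list above can be checked with a little arithmetic. This is only a sketch: it assumes that a `padding='same'` transposed convolution multiplies the spatial size by its stride, and the helper names are made up for this illustration.

```python
def dense_param_count(n_inputs, n_units):
    """Weights plus one bias per unit in a Dense layer."""
    return (n_inputs + 1) * n_units

def transpose_same_output_size(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

noise_dim = 100
n_nodes = 128 * 7 * 7            # 128 feature maps of 7x7

sizes = [7]
for _ in range(2):               # two stride-2 Conv2DTranspose layers
    sizes.append(transpose_same_output_size(sizes[-1], stride=2))

print(n_nodes, dense_param_count(noise_dim, n_nodes), sizes)
```

So the Dense layer needs 6272 nodes, and two upsampling steps take the 7×7 maps to 14×14 and then to the final 28×28.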

In [1]:
def define_generator(input_dim):
    model = Sequential()
    
    # foundation for 7x7 image
    n_nodes = 128 * 7 * 7
    model.add(Dense(n_nodes, input_dim=input_dim))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 128)))
    model.add(Dropout(0.2))
    
    # upsample to 14x14
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.2))
    
    # upsample to 28x28
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.2))
    
    model.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return model

In [1]:
noise_dim = 100
g_model = define_generator(noise_dim)
g_model.summary()

### Generate fake samples

In [1]:
# generate the random noise vectors that the generator turns into images
def generate_noise(noise_dim, n_samples):
    x_input = np.random.randn(noise_dim * n_samples) # draw standard normal noise
    x_input = x_input.reshape(n_samples, noise_dim) # one row of noise per sample
    return x_input

In [1]:
def generate_fake_samples(noise_dim, n_samples):
    x_input = generate_noise(noise_dim, n_samples) # random noise as generator input
    X = g_model.predict(x_input) # generate images with the (global) generator model
    y = np.zeros((n_samples, 1)) # label the generated images as 'fake' (0)
    return X, y

Let's see how our generator performs:

In [1]:
fig = pyplot.figure(figsize=(12, 12))
n_samples = 25
X, _ = generate_fake_samples(100, n_samples)
for i in range(n_samples):
    pyplot.subplot(5, 5, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(X[i, :, :, 0])
    pyplot.draw()

The results do not look good yet because our generator has not been trained.

### Load samples

In [1]:
def load_real_samples():
    (trainX, _), (_, _) = load_data() # load the MNIST dataset
    X = np.expand_dims(trainX, axis=-1) # add a grayscale channel dimension
    X = X.astype('float32') # convert pixels from ints to floats
    X = X / 255.0 # scale pixels from [0, 255] to [0, 1]
    return X
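
To make the effect of these steps concrete, here is a small sketch that runs the same preprocessing on stand-in data. The two random 28×28 arrays are an assumption for illustration only, not real MNIST digits.

```python
import numpy as np

# Stand-in for MNIST: two random 28x28 uint8 "images" (not real data)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(2, 28, 28), dtype=np.uint8)

X = np.expand_dims(images, axis=-1)   # (2, 28, 28) -> (2, 28, 28, 1)
X = X.astype('float32') / 255.0       # scale pixels to [0, 1]
print(X.shape, X.dtype, X.min(), X.max())
```

The [0, 1] range matters because the generator's output layer uses a sigmoid, so real and generated images end up on the same scale.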

### Get real samples by index

In [1]:
def get_real_samples(dataset, idx):
    n_sample = len(idx)
    X = dataset[idx]
    y = np.ones((n_sample, 1)) # label the real images as 'real' (1)
    return X, y

# Create GAN Model

* The GAN model is created by combining the generator model and the discriminator model.
* The generator is responsible for generating fake images.
* The discriminator is responsible for evaluating the generator's output with binary classification: how real does the output image look?
* From the discriminator's prediction, the GAN gets the error used to update the generator's weights.
* In other words, the weights in the generator model are updated based on the performance of the discriminator model.
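
As a rough illustration of that last point, the sketch below computes binary cross-entropy by hand for a generated image that is labeled "real" (1), which is how the combined model is trained later on. The function here is a plain Python stand-in written for this example, not the Keras implementation.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy for a single prediction, clipped for stability."""
    y_pred = min(max(y_pred, eps), 1.0 - eps)
    return -(y_true * math.log(y_pred) + (1.0 - y_true) * math.log(1.0 - y_pred))

# Discriminator output for a generated image, scored against the target label 1
for d_out in (0.1, 0.5, 0.9):
    loss = binary_crossentropy(1.0, d_out)
    print(f"D(G(z)) = {d_out:.1f} -> generator loss = {loss:.3f}")
```

The loss is large when the discriminator is confident the image is fake and small when it is fooled, so minimizing it pushes the generator toward images the discriminator accepts as real.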

In [1]:
def define_gan(g_model, d_model):
    d_model.trainable = False # freeze the discriminator so only the generator is updated here
  
    # connects discriminator and generator
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
  
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model

In [1]:
noise_dim = 100
d_model = define_discriminator()
g_model = define_generator(noise_dim)
gan_model = define_gan(g_model, d_model)
gan_model.summary()

# Train model 

Here is what we are doing:
* We take real images from the dataset and generate fake images to train our discriminator.
* The discriminator predicts, as a probability between 0.0 and 1.0, whether each image is real.
* While training the GAN, we label all the images from our generator as real (1).
* The discriminator's error on these images is then used to update and improve the generator model.
* Every 100 epochs we append the discriminator accuracy to `d_history` so we can visualize it later.
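
Before looking at the training loop, it may help to see the batch arithmetic on its own. This sketch assumes the 60,000 MNIST training images and a batch size of 256. The helper `half_batch_ranges` is made up for this illustration and shows one possible way to give every step its own non-overlapping slice of real images.

```python
def half_batch_ranges(n_samples, batch_size):
    """Index range of real images for each training step in one epoch."""
    steps = n_samples // batch_size
    half_batch = batch_size // 2
    return [range(step * half_batch, (step + 1) * half_batch)
            for step in range(steps)]

ranges = half_batch_ranges(60000, 256)
print(len(ranges), ranges[0], ranges[1], ranges[-1])
```

Each step trains the discriminator on 128 real and 128 fake images. Note that with this layout one epoch only visits the first half of the dataset, so random sampling or shuffling are alternatives worth considering.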

In [1]:
d_history = []

In [1]:
def train_gan(dataset, noise_dim, epochs, batch_size):
    steps = int(dataset.shape[0] / batch_size)
    half_batch = int(batch_size / 2)

    # generate plot slot for real time plot
    fig = pyplot.figure(figsize=(12, 12))
    axs = []
    for i in range(25):
        axs.append(pyplot.subplot(5, 5, 1 + i))

    for epoch in range(epochs):
        for step in range(steps):
            # train the discriminator on half a batch of real and half a batch of fake images
            sample_idx = range(step * half_batch, (step + 1) * half_batch) # non-overlapping slice for this step
            X_real, y_real = get_real_samples(dataset, sample_idx)
            X_fake, y_fake = generate_fake_samples(noise_dim, half_batch)
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)

            # train our GAN to improve our generator
            x_gan = generate_noise(noise_dim, batch_size)
            y_gan = np.ones((batch_size, 1))
            gan_model.train_on_batch(x_gan, y_gan)
        if epoch % 100 == 0: # evaluate every 100 epochs
            # evaluate the discriminator on 100 samples: 50 real and 50 fake
            evaluate_sample_idx = np.random.randint(0, dataset.shape[0], 50)
            X_real_test, y_real_test = get_real_samples(dataset, evaluate_sample_idx)
            x_fake_test, y_fake_test = generate_fake_samples(noise_dim, 50)
            x_test, y_test = np.vstack((X_real_test, x_fake_test)), np.vstack((y_real_test, y_fake_test))
            _, acc = d_model.evaluate(x_test, y_test, verbose=0) # score on the combined real + fake set
            d_history.append([acc, epoch])

            fig.suptitle('Discriminator accuracy: {} at epoch {}'.format(acc, epoch), fontsize=16) # display accuracy and epoch in the title

            # plot result in real time
            for i in range(25):
                ax = axs[i]
                ax.cla()
                ax.axis('off')
                ax.imshow(x_fake_test[i, :, :, 0])
            fig.savefig("result_at_epoch_{}.png".format(epoch))
            display(fig)
            clear_output(wait=True)

  

In [1]:
dataset = load_real_samples()
train_gan(dataset, noise_dim, epochs=1500, batch_size=256)

Save our generator:

In [1]:
g_model.save("model.h5")

# See our discriminator's performance

In [1]:
d_history = np.array(d_history)
pyplot.figure(figsize=(12, 6))
pyplot.plot(d_history[:, 1], d_history[:, 0]) # plot history accuracy
pyplot.show()

# Plot the final result

In [1]:
x_fake, _ = generate_fake_samples(noise_dim, 25)
fig = pyplot.figure(figsize=(12, 12))
for i in range(25):
    pyplot.subplot(5, 5, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(x_fake[i, :, :, 0])

With this number of epochs the result is not perfect yet, but we start to see the shapes of digits.