<a href="https://colab.research.google.com/github/siddheshsathe/google-colab-repos/blob/master/GAN_with_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using MNIST data to learn how GANs work

We'll follow these steps:

* Load the data using `keras`'s built-in `load_data`
* For demo/learning purposes we'll use just one digit class (any one from 0 through 9); let's use 1
* Create the `generator` model
> * This model takes a latent vector (the initial input noise) of any chosen size
> * We'll reshape and/or upsample it into 2D feature maps until it reaches the desired output image shape
> * We won't compile this model on its own, as it will be compiled as part of the complete GAN model
* Create the `discriminator` model
> * This model is just a `binary classification` model that outputs whether its input is a real or a fake image
> * Thus the output layer is a `Dense` layer with a single neuron
> * Its input shape is the shape of the input image
> * Compile the model with `binary_crossentropy` and any suitable optimizer
* Training this GAN is not just the usual call to the `fit` method
* We'll need to train the `discriminator` with the labels below
> * `fake = 0`
> * `real = 1`

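To see what these labels mean for the loss, here is a minimal NumPy sketch of binary cross-entropy, the loss the discriminator is compiled with. The function name `binary_crossentropy`, the `eps` value, and the example scores are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; eps keeps log() away from 0 (illustrative value)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    per_sample = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(per_sample))

# Two fake images (label 0) followed by two real images (label 1)
labels = [0.0, 0.0, 1.0, 1.0]

# Made-up discriminator scores: one set mostly right, one mostly wrong
good_scores = [0.1, 0.2, 0.9, 0.8]
bad_scores = [0.9, 0.8, 0.1, 0.2]

good_loss = binary_crossentropy(labels, good_scores)
bad_loss = binary_crossentropy(labels, bad_scores)
print(good_loss, bad_loss)  # the first value is the smaller one
```

A discriminator that scores fakes near 0 and reals near 1 gets a low loss; the generator is later trained with the opposite target (fakes labelled 1) so that fooling the discriminator lowers its own loss.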

# 1. Loading Data

In [0]:
%tensorflow_version 2.x
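import numpy as np
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], X_train.shape[2], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], X_test.shape[2], 1))

# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output range
X_train = X_train.astype("float32") / 127.5 - 1.0
X_test = X_test.astype("float32") / 127.5 - 1.0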
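
The introduction mentions training on a single digit class, while the cell above loads all ten. Below is a minimal sketch of how such a filter could look, using a boolean mask on the labels. The helper name `select_digit` is hypothetical, and the arrays are small synthetic stand-ins so the snippet runs without downloading MNIST.

```python
import numpy as np

def select_digit(images, labels, digit):
    """Return only the images whose label equals `digit`."""
    mask = labels == digit  # boolean array, True where the label matches
    return images[mask]

# Synthetic stand-in for (X_train, y_train): six blank 28x28x1 "images"
fake_images = np.zeros((6, 28, 28, 1), dtype=np.uint8)
fake_labels = np.array([1, 7, 1, 3, 1, 0])

only_ones = select_digit(fake_images, fake_labels, digit=1)
print(only_ones.shape)  # (3, 28, 28, 1)
```

With the notebook's arrays the call would be `X_train = select_digit(X_train, y_train, 1)`.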

# 2. Defining the models

In [0]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, UpSampling2D, Flatten, LeakyReLU, BatchNormalization, Reshape, Conv2DTranspose, Dropout
from tensorflow.keras.optimizers import Adam

In [0]:
coding_size = 100 # Initial latent size

In [0]:
generator = Sequential()

generator.add(Dense(7 * 7 * 128, input_shape=[coding_size]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
                                 activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                                 activation="tanh"))
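
The generator starts from a 7x7x128 tensor and must end at a 28x28x1 image. With `padding="same"`, a transposed convolution multiplies each spatial dimension by its stride, so two layers with `strides=2` take 7 to 14 to 28. The short sketch below traces that arithmetic in plain Python; the helper name `transposed_conv_same_size` is an assumption made for illustration.

```python
def transposed_conv_same_size(size: int, stride: int) -> int:
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

size = 7  # side length after Reshape([7, 7, 128])
sizes = [size]
for stride in (2, 2):  # the two Conv2DTranspose layers
    size = transposed_conv_same_size(size, stride)
    sizes.append(size)

print(sizes)  # [7, 14, 28]
```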

In [0]:
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3),
                        input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))
discriminator.compile(loss="binary_crossentropy", optimizer="adam")
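
The discriminator mirrors the generator: each `Conv2D` with `strides=2` and `padding="same"` shrinks a spatial dimension to the ceiling of size divided by stride. The sketch below traces the shapes down to the length of the vector that `Flatten` hands to the final `Dense` layer. The helper name `conv_same_size` is hypothetical.

```python
import math

def conv_same_size(size: int, stride: int) -> int:
    """Spatial output size of a strided convolution with padding='same'."""
    return math.ceil(size / stride)

size = 28  # input images are 28x28x1
for stride in (2, 2):  # the two Conv2D layers
    size = conv_same_size(size, stride)

filters = 128  # channels produced by the second Conv2D
flattened = size * size * filters
print(size, flattened)  # 7 6272
```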

In [0]:
GAN = Sequential([generator, discriminator])

In [0]:
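# Freeze the discriminator before compiling the combined model,
# so that training the GAN updates only the generator's weights
discriminator.trainable = False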

In [0]:
GAN.compile(optimizer='adam', loss='binary_crossentropy')

In [0]:
GAN.summary()

# 3. Setting up training batches

In [0]:
# The shuffle buffer size is independent of the latent size; 1000 is an arbitrary choice
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(buffer_size=1000)

In [0]:
batch_size = 32
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
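
`drop_remainder=True` matters here because the training loop builds its noise and label tensors with a fixed `batch_size`; a shorter final batch would no longer line up with the labels. The sketch below works out how many full batches an epoch contains and how many samples get dropped. The helper name `batches_per_epoch` and the 1000-sample example are assumptions for illustration.

```python
def batches_per_epoch(num_samples, batch_size):
    """Return (number of full batches, number of leftover samples dropped)."""
    return num_samples // batch_size, num_samples % batch_size

# The full MNIST training set has 60000 images
full_batches, dropped = batches_per_epoch(60000, 32)
print(full_batches, dropped)  # 1875 0

# A hypothetical smaller subset of 1000 images
subset_batches, subset_dropped = batches_per_epoch(1000, 32)
print(subset_batches, subset_dropped)  # 31 8
```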

# 4. Training

In [0]:
epochs = 10
# Grab the separate components
generator, discriminator = GAN.layers

# For every epoch
for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")
    i = 0
    # For every batch in the dataset
    for X_batch in dataset:
        i=i+1
        if i%10 == 0:
            print(f"Batch: {i}")
        ## Training discriminator
        
        # Create Noise
        noise = tf.random.normal(shape=[batch_size, coding_size])
        
        # Generate numbers based just on noise input
        gen_images = generator(noise)
        
        # Concatenate the generated images with the real ones
        # To use tf.concat, the data types must match!
        X_fake_vs_real = tf.concat([gen_images, tf.dtypes.cast(X_batch,tf.float32)], axis=0)
        
        # Targets set to 0 for fake images and 1 for real images
        y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
        
        # This gets rid of a Keras warning
        discriminator.trainable = True
        
        # Train the discriminator on this batch
        discriminator.train_on_batch(X_fake_vs_real, y1)
        
        
        ## Training generator
        # Using the same noise
        
        # We want the discriminator to believe that the fake images are real
        y2 = tf.constant([[1.]] * batch_size)
        
        # Avoids a warning
        discriminator.trainable = False
        
        GAN.train_on_batch(noise, y2)

# 5. Generating data with trained model

In [0]:
noise = tf.random.normal([1, coding_size])
noise.shape

In [0]:
generated_image = generator(noise)

In [0]:
generated_image.shape

In [0]:
import matplotlib.pyplot as plt

In [0]:
plt.imshow(generated_image[0].numpy().reshape(28, 28), cmap='gray')
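
The generator's last layer uses `tanh`, so its output lies in [-1, 1]. `plt.imshow` rescales values for display on its own, but saving or comparing images usually calls for ordinary 0 to 255 pixel values. Here is a minimal sketch of that conversion; the helper name `tanh_to_uint8` is hypothetical and the input is a tiny hand-made array rather than a real generated image.

```python
import numpy as np

def tanh_to_uint8(image):
    """Map values from [-1, 1] to integer pixel values in [0, 255]."""
    scaled = (np.asarray(image, dtype=np.float32) + 1.0) * 127.5
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)

# The two extremes of the tanh range
sample = np.array([[-1.0, 1.0]], dtype=np.float32)
pixels = tanh_to_uint8(sample)
print(pixels.tolist())  # [[0, 255]]
```

For a generated image this would be called as `tanh_to_uint8(generated_image[0].numpy())`.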