# <center> Generative Adversarial Networks </center>

<center>This notebook is a part of teaching material for CS-EJ3311 - Deep Learning with Python</center>
<center>Aalto University (Espoo, Finland)</center>
<center>fitech.io (Finland)</center>

# <center>Data</center>

The [MNIST](https://www.tensorflow.org/datasets/catalog/mnist) dataset consists of data points representing handwritten digits. Each data point is characterized by a $28 \times 28$ pixel grayscale image and is associated with a label $y$ that indicates to which of the $10$ classes (0, ..., 9) the digit belongs.


In [None]:
#@title  Import Python libraries

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt 
import IPython.display as ipd
from IPython.display import IFrame

In [None]:
#@title Load Data

# Fetch the MNIST training split; the test split is discarded.
(X_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel axis so each image has shape (28, 28, 1).
X_train = X_train.reshape(-1, 28, 28, 1)
# Rescale pixel values to [-1, 1], the output range of the generator's tanh layer.
X_train = X_train / 255 * 2. - 1.
X_train = tf.cast(X_train, tf.float32)  # float32 tensors for the networks

# Input pipeline: shuffle, cut into fixed-size batches (an incomplete final
# batch is dropped) and prefetch one batch ahead.
batch_size = 32
dataset = (
    tf.data.Dataset.from_tensor_slices(X_train)
    .shuffle(1000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)


<h2>DCGAN architecture:</h2>
<h3>Generator.</h3>
    
Your task is to build a generator. The generator is a sequential model with:
    
- Dense layer with 7*7*128 units and `input_shape=[codings_size]`. 
- Reshape layer `tf.keras.layers.Reshape([7, 7, 128])`
- Batch Normalization layer
- Conv2DTranspose layer with 64 kernels, kernel size 5, strides 2, padding "same" and selu activation
- Batch Normalization layer
- Conv2DTranspose layer with 1 output kernel, kernel size 5, strides 2, padding "same" and tanh activation
    
Add name to the model as `cv_generator = tf.keras.models.Sequential([...], name="Generator")`


In [None]:
#@title Build Generator Network

codings_size = 100  # length of the latent noise vector fed to the generator

# Assemble the generator layer by layer: the latent vector is projected to a
# 7x7x128 feature map, then two stride-2 transposed convolutions upsample it
# 7x7 -> 14x14 -> 28x28. The final tanh keeps pixel values in [-1, 1].
cv_generator = tf.keras.models.Sequential(name="Generator")
cv_generator.add(tf.keras.layers.Dense(7 * 7 * 128, input_shape=[codings_size]))
cv_generator.add(tf.keras.layers.Reshape([7, 7, 128]))
cv_generator.add(tf.keras.layers.BatchNormalization())
cv_generator.add(tf.keras.layers.Conv2DTranspose(
    64, kernel_size=5, strides=2, padding="same", activation="selu"))
cv_generator.add(tf.keras.layers.BatchNormalization())
cv_generator.add(tf.keras.layers.Conv2DTranspose(
    1, kernel_size=5, strides=2, padding="same", activation="tanh"))

<h3>Discriminator.</h3>
    
Your task is to build a discriminator. The discriminator takes an image as input (either a real image or one produced by the generator) and outputs the probability that the image is real.
    
Use:
    
- Conv2D layer with 64 kernels, kernel size 5, strides 2, padding "same" and leaky ReLU activation (alpha=0.2)
- Conv2D layer with 128 kernels, kernel size 5, strides 2, padding "same" and leaky ReLU activation (alpha=0.2)
- Flatten layer
- Dense output layer with 1 unit and sigmoid activation
    
Add name to the model as `cv_discriminator = tf.keras.models.Sequential([...], name="Discriminator")`


In [None]:
#@title Build Discriminator Network

# Assemble the discriminator layer by layer: two stride-2 convolutions
# downsample the 28x28x1 input, and a single sigmoid unit produces the score.
cv_discriminator = tf.keras.models.Sequential(name="Discriminator")
cv_discriminator.add(tf.keras.layers.Conv2D(
    64, kernel_size=5, strides=2, padding="same",
    activation=keras.layers.LeakyReLU(0.2), input_shape=[28, 28, 1]))
cv_discriminator.add(tf.keras.layers.Conv2D(
    128, kernel_size=5, strides=2, padding="same",
    activation=keras.layers.LeakyReLU(0.2)))
cv_discriminator.add(tf.keras.layers.Flatten())
cv_discriminator.add(tf.keras.layers.Dense(1, activation="sigmoid"))

In [None]:
#@title Build a GAN model
# Stack the two networks: latent noise -> generator -> discriminator score.
cv_gan = tf.keras.models.Sequential(layers=[cv_generator, cv_discriminator])

# Render the stacked model's diagram as the cell output.
tf.keras.utils.plot_model(cv_gan, show_shapes=True, show_layer_names=True)

In [None]:
#@title Compile and train GAN

# train loop
def train_gan(gan, dataset, batch_size, codings_size, n_epochs=3, test_noise=None):
    """Train a GAN by alternating discriminator and generator updates.

    Args:
        gan: compiled Sequential model whose two layers are, in this order,
            the generator and the (separately compiled) discriminator.
        dataset: iterable yielding batches of real images. Every batch must
            hold exactly `batch_size` images, because the label tensors
            below are built for that size.
        batch_size: number of real (and of generated) images per step.
        codings_size: length of the latent noise vector fed to the generator.
        n_epochs: number of passes over `dataset`.
        test_noise: optional fixed latent vectors used for the progress
            plots. When None, one batch of noise is drawn once before
            training so that successive plots show the same latent points.
    """
    generator, discriminator = gan.layers

    # Use an explicit snapshot input instead of relying on a notebook global.
    if test_noise is None:
        test_noise = tf.random.normal(shape=[batch_size, codings_size])

    # The targets are identical for every batch, so build them only once.
    # Discriminator phase: generated images come first (label 0), then real (label 1).
    y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
    # Generator phase: generated images are labelled as real (label 1).
    y2 = tf.constant([[1.]] * batch_size)

    itr = 0  # counts training steps across all epochs
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for X_batch in dataset:
            # Phase 1 - training the discriminator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)

            # Phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)

            # Every 100 steps, plot what the generator makes of the fixed noise.
            if itr % 100 == 0:
                gen_images = generator.predict(test_noise)
                plot_multiple_images(gen_images, 8)
                plt.show()

            itr += 1

# function for plotting images outputted by generator
# code source https://github.com/ageron/handson-ml2/blob/master/17_autoencoders_and_gans.ipynb
def plot_multiple_images(images, n_cols=None):
    """Draw `images` on a grid of subplots, filling it row by row.

    Args:
        images: array of images; a trailing channel axis of size 1 is
            squeezed away before plotting.
        n_cols: number of grid columns. When omitted (or 0), all images
            are placed on a single row.
    """
    n_images = len(images)
    n_cols = n_cols or n_images
    # Ceiling division: just enough rows to hold every image.
    n_rows = (n_images - 1) // n_cols + 1
    if images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)
    plt.figure(figsize=(n_cols, n_rows))
    for position, image in enumerate(images, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(image, cmap="binary")
        plt.axis("off")

# Compile the discriminator on its own first, so it can be trained directly
# with discriminator.train_on_batch(...).
cv_discriminator.compile(optimizer="rmsprop", loss="binary_crossentropy")
# Mark the discriminator as non-trainable before compiling the stacked model.
cv_discriminator.trainable = False
cv_gan.compile(optimizer="rmsprop", loss="binary_crossentropy")

# Fixed batch of latent vectors used for plotting the generator's progress.
test_random_vector = tf.random.normal([batch_size, codings_size])
# Run the training loop.
train_gan(cv_gan, dataset, batch_size, codings_size)
