# Generative Adversarial Neural Network for the MNIST Dataset

See the original GAN [paper](https://arxiv.org/pdf/1406.2661.pdf) and this OpenAI [blog post](https://openai.com/blog/generative-models/), which relates to this [article](https://arxiv.org/abs/1606.03498).

In implementing this, I have used some of the recommendations given in the [DCGAN paper](https://arxiv.org/pdf/1511.06434.pdf).

Finally, TensorFlow has published a [tutorial](https://www.tensorflow.org/tutorials/generative/dcgan) on a simple DCGAN for the MNIST dataset. Interestingly, it approaches the problem in much the same way as I do here. Admittedly, their `train_step` implementation is much cleaner than mine.

In [None]:
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
from scipy import interpolate
from scipy.special import softmax

import utilities as utils
import models 

<br><br><br>

# Load MNIST Dataset

In [None]:
train_set, test_set = tf.keras.datasets.mnist.load_data()
print(f"\tTrain set: {train_set[0].shape}, {train_set[1].shape}")
print(f"\tTest set:  {test_set[0].shape}, {test_set[1].shape}")

<br><br><br>

# Generative Adversarial Networks (GANs)

Now that we have confidence in the generative and the discriminative models, we can proceed with our GAN setup. I will use these pretrained models to reduce the training time. For more details on GANs, see this [blog post](https://openai.com/blog/generative-models/) from OpenAI.

First, I set up a new model that contains both the generative and the discriminative models. This new model is derived from `tf.keras.Model`, with a custom `call` method.
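The actual `call` logic lives in `models.GANModel` and is not reproduced in this notebook. As a rough NumPy sketch only, assume the three inputs are a boolean mask, a batch of real images, and a batch of generator inputs, and that the mask decides per sample whether the discriminator sees the real or the generated image. The `mix_real_and_generated` helper and the all-zeros stand-in for the generator output are hypothetical.

```python
import numpy as np


def mix_real_and_generated(use_real, real_images, generated_images):
    """Select, per sample, either the real or the generated image.

    Hypothetical helper: one way a combined GAN model could build the
    mixed batch that it passes to its discriminator.
    """
    # Reshape the (batch,) mask so it broadcasts over (batch, H, W, C).
    mask = np.asarray(use_real, dtype=bool).reshape(-1, 1, 1, 1)
    return np.where(mask, real_images, generated_images)


use_real = np.array([True, False, True])
real_images = np.ones((3, 28, 28, 1), dtype=np.float32)
# Stand-in for the generator output; a real generator would map the
# hidden representation to an image.
generated_images = np.zeros((3, 28, 28, 1), dtype=np.float32)

mixed = mix_real_and_generated(use_real, real_images, generated_images)
print(mixed.shape)  # one image per sample, same shape as the inputs
```

Under this assumption the mask can also serve as the discriminator's target label, which would fit the sparse categorical loss used when compiling the model.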

## Instantiate GAN Model

In [None]:
hidden_rep_size = 40

#tf.keras.backend.clear_session()
gan_model = models.GANModel(
    discriminator_model=models.construct_discriminator_model(),
    generator_model=models.construct_generator_model(
        input_size=hidden_rep_size, 
        output_activation="linear",
        with_batchnorm=True))

# To build the model:
gan_model((
    tf.constant([True, False, True], dtype=tf.bool), 
    tf.zeros((3, 28, 28, 1), dtype=tf.float32), 
    tf.zeros((3, hidden_rep_size), dtype=tf.float32)  # float, like the softmax inputs used below
))

gan_model.summary(print_fn=(lambda *args: print("\t", *args)))

### Sample Output of the Generator Model:

In [None]:
generator_input = softmax(
    np.random.randn(10, hidden_rep_size), axis=1)
generated_images = gan_model.generator_model.predict(generator_input)

n = 1
fig = plt.figure(figsize=(18., 2.))
for image in generated_images:
    ax = plt.subplot(1, 10, n)
    # Squeeze drops a trailing channel axis, if any, so imshow accepts the array.
    normalized_image = np.squeeze(np.clip(image, 0.0, 255.0)).astype(np.uint8)
    ax.imshow(normalized_image, cmap="gray")
    ax.axis("off")
    n += 1
plt.show()
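The generator input in the cell above is Gaussian noise passed through a softmax, so every row is strictly positive and sums to one. A NumPy-only sketch of that sampling step follows; the `sample_generator_input` name is an assumption for illustration and is not part of `utilities` or `models`.

```python
import numpy as np


def sample_generator_input(num_samples, hidden_rep_size, rng=None):
    """Draw Gaussian noise and map each row onto the probability simplex."""
    rng = np.random.default_rng() if rng is None else rng
    logits = rng.standard_normal((num_samples, hidden_rep_size))
    # Numerically stable softmax along the feature axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


z = sample_generator_input(10, 40, rng=np.random.default_rng(0))
print(z.shape, z.sum(axis=1))
```

Passing a seeded `np.random.default_rng` makes the sampled inputs reproducible, which helps when comparing generated images before and after training.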

## Compile the GAN Model

In [None]:
use_sgd_optimizer = False  # Set to True to use plain SGD instead of Adam.
if use_sgd_optimizer:
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=0.0001, momentum=0.0, nesterov=False, name="SGD")
else:
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=0.001, beta_1=0.9, beta_2=0.999, 
        epsilon=1e-07, amsgrad=False, name="Adam")

loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False, name="crossentropy")

metrics = [
    tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]

gan_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

## Train the GAN Model

In [None]:
num_epochs = 20
batch_size = 256

# Early stopping callback.
# Note: with patience equal to num_epochs (20), it cannot trigger within this run.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="loss", 
    min_delta=0.0005, 
    patience=20, 
    verbose=1,
    mode="min", 
    baseline=None, 
    restore_best_weights=True)

gan_seq_gen = utils.MNIST_GAN_Sequence_Generator(
    train_set[0][..., np.newaxis].astype(np.float32), 
    batch_size=batch_size, 
    hidden_rep_size=hidden_rep_size,
    temperature=None)

# Fit model
fit_history = gan_model.fit(
    gan_seq_gen,
    epochs=num_epochs,
    steps_per_epoch=None,
    verbose=1,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    workers=8,
    callbacks=[
        early_stopping_callback
    ])

# epochs = np.arange(1, 1 + num_epochs)
# plt.plot(epochs, fit_history.history["loss"], )

In [None]:
generator_input = softmax(
    np.random.randn(10, hidden_rep_size), axis=1)
generated_images = gan_model.generator_model.predict(generator_input)

n = 1
fig = plt.figure(figsize=(18., 2.))
for image in generated_images:
    ax = plt.subplot(1, 10, n)
    # Squeeze drops a trailing channel axis, if any, so imshow accepts the array.
    normalized_image = np.squeeze(np.clip(image, 0.0, 255.0)).astype(np.uint8)
    ax.imshow(normalized_image, cmap="gray")
    ax.axis("off")
    n += 1
plt.show()

In [None]:
for idx in range(len(gan_seq_gen) // 50):
    x, y = gan_seq_gen[idx]
    res = gan_model.predict(x)
    print(f"\t[{idx}]:\t{np.sum(y)}\t{np.sum(np.argmax(res, axis=1))}")
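The loop above compares only the sum of the labels with the sum of the predicted classes, and those sums can agree even when individual predictions are wrong. A per-sample comparison is a stricter check. The sketch below assumes `y` holds integer class labels and `res` holds per-class probabilities, as in the cell above; the arrays are made-up stand-ins, and `batch_accuracy` is a hypothetical helper.

```python
import numpy as np


def batch_accuracy(labels, probabilities):
    """Fraction of samples whose arg-max class matches the label."""
    labels = np.asarray(labels).reshape(-1)
    predicted = np.argmax(probabilities, axis=1)
    return float(np.mean(predicted == labels))


# Made-up labels and predictions for illustration only.
y = np.array([1, 0, 1, 0])
res = np.array([[0.2, 0.8], [0.4, 0.6], [0.9, 0.1], [0.7, 0.3]])

label_sum = int(np.sum(y))
predicted_sum = int(np.sum(np.argmax(res, axis=1)))
accuracy = batch_accuracy(y, res)
print(label_sum, predicted_sum, accuracy)
```

In this toy batch the two sums are equal even though only half of the predictions match their labels.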