# MNIST Handwritten Digit Generation using GAN

In [1]:
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
    print('Running TF without GPU')
else:
    print(f'Found GPU at {device_name}')

Running TF without GPU


In [2]:
%matplotlib inline

import glob
import imageio
import math
import os
import seaborn as sn
import time

from abc import abstractstaticmethod
from matplotlib import gridspec
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras

sn.set_theme()

In [3]:
NOISE_DIMENSION = 128
TRAIN_SIZE = 60000
TEST_SIZE = 10000
BATCH_SIZE = 32
EPOCHS = 100

N_EXAMPLES = 25
G_LEARNING_RATE = 2e-4
D_LEARNING_RATE = 1e-4
EPSILON = 1e-7

ARCH = 'gan'
METRICS_PATH = f'metrics/{ARCH}/'
OUTPUT_PATH = f'output/{ARCH}/'

!mkdir -p $METRICS_PATH
!mkdir -p $OUTPUT_PATH
!mkdir -p $(OUTPUT_PATH)old/
!mv $(OUTPUT_PATH)* $(OUTPUT_PATH)old/

mv: target 'output/gan//old/' is not a directory


In [4]:
assert NOISE_DIMENSION > 0
assert TRAIN_SIZE <= 600000
assert BATCH_SIZE >= 1
assert EPOCHS >= 1
assert N_EXAMPLES >= 1
assert G_LEARNING_RATE > 0
assert D_LEARNING_RATE > 0
assert EPSILON > 0

## Base Class

In [5]:
class BaseNetwork(tf.Module):
    def __init__(self):
        super().__init__()

    @tf.Module.with_name_scope
    def __call__(self, input_data, training=False) -> tf.Tensor:
        output_data = input_data
        for layer in self.layers:
            output_data = layer(output_data, training=training)
        return output_data

    @abstractstaticmethod
    def loss() -> tf.Tensor:
        raise NotImplementedError

    @abstractstaticmethod
    def optimizer(*args, **kwargs) -> tf.optimizers.Optimizer:
        raise NotImplementedError

## Generator Network
### Layers
<pre>
<b>Input</b>:  input_size=(128, 1)
<b>Dense</b>:  units=128, activation=ReLU
<b>Dense</b>:  units=128, activation=ReLU
<b>Dense</b>:  units=256, activation=ReLU
<b>Dense</b>:  units=512, activation=ReLU
<b>Dense</b>:  units=512, activation=ReLU
<b>Dense</b>:  units=784, activation=tanh, target_shape=(28, 28, 1)
</pre>

### Optimizer
<pre>
<b>Adam</b>:  learning_rate=0.0002
</pre>

### Loss
&#8466;<sub>G</sub>(<i><b>z</b></i>) = <sup>-1</sup>&frasl;<sub>m</sub> &lowast; &sum;<sub><i>i</i></sub> log(<i>D</i>(G(<i><b>z</b><sup>(i)</sup></i>)))

### Goal
Find argmin<sub>G</sub> {&#8466;<sub>G</sub>(<i><b>z</b></i>)}

In [6]:
class Generator(BaseNetwork):
    def __init__(self, noise_dimension: int, output_shape: tuple) -> None:
        super().__init__()
        # Network layers
        self.layers = [
            keras.layers.InputLayer(input_shape=(noise_dimension,)),
            keras.layers.Flatten(),
            keras.layers.Dense(units=128),
            keras.layers.ReLU(),
            keras.layers.Dense(units=128),
            keras.layers.ReLU(),
            keras.layers.Dense(units=256),
            keras.layers.ReLU(),
            keras.layers.Dense(units=512),
            keras.layers.ReLU(),
            keras.layers.Dense(units=512),
            keras.layers.ReLU(),
            keras.layers.Dense(units=tf.reduce_prod(output_shape)),
            keras.layers.Reshape(target_shape=output_shape),
            keras.layers.Activation(tf.nn.tanh),
        ]

    @staticmethod
    def optimizer(learning_rate: float, momentum: float=0.0) -> tf.optimizers.Optimizer:
        return tf.optimizers.Adam(learning_rate=learning_rate)

    @staticmethod
    @tf.function
    def loss(generated_output):
        loss_i = tf.math.log(generated_output+EPSILON)
        loss = -tf.math.reduce_mean(loss_i)
        return loss

## Discriminator Network
### Layers
<pre>
<b>Input</b>:    input_size=(28, 28, 1)
<b>Dense</b>:    units=512, activation=ReLU
<b>Dropout</b>:  rate=0.4
<b>Dense</b>:    units=512, activation=ReLU
<b>Dropout</b>:  rate=0.3
<b>Dense</b>:    units=512, activation=ReLU
<b>Dropout</b>:  rate=0.3
<b>Dense</b>:    units=256, activation=ReLU
<b>Dropout</b>:  rate=0.3
<b>Dense</b>:    units=128, activation=ReLU
<b>Dropout</b>:  rate=0.1
<b>Dense</b>:    units=64, activation=ReLU
<b>Dense</b>:    units=1, activation=sigmoid
</pre>

### Optimizer
<pre>
<b>Adam</b>:         learning_rate=0.0002
</pre>

### Loss
&#8466;<sub>D</sub>(<i><b>x</b>,<b>z</b></i>) = <sup>-1</sup>&frasl;<sub>m</sub> &lowast; &sum;<sub><i>i</i></sub> \[log <i>D</i>(<i><b>x</b><sup>(i)</sup></i>) + log (1-<i>D</i>(G(<i><b>z</b><sup>(i)</sup></i>)))\]

### Goal
Find argmin<sub>D</sub> &#8466;<sub>D</sub>(<i><b>x</b></i>,<i><b>z</b></i>)

In [7]:
class Discriminator(BaseNetwork):
    def __init__(self, input_shape: tuple) -> None:
        super().__init__()
        # Network layers
        self.layers = [
            keras.layers.InputLayer(input_shape=input_shape),
            keras.layers.Flatten(),
            keras.layers.Dense(units=512),
            keras.layers.ReLU(),
            keras.layers.Dropout(rate=0.4),
            keras.layers.Dense(units=512),
            keras.layers.ReLU(),
            keras.layers.Dropout(rate=0.3),
            keras.layers.Dense(units=256),
            keras.layers.ReLU(),
            keras.layers.Dropout(rate=0.3),
            keras.layers.Dense(units=128),
            keras.layers.ReLU(),
            keras.layers.Dropout(rate=0.2),
            keras.layers.Dense(units=64),
            keras.layers.ReLU(),
            keras.layers.Dense(units=1),
            keras.layers.Activation(tf.nn.sigmoid),
        ]

    @staticmethod
    def optimizer(learning_rate: float, momentum: float=0.0):
        return tf.optimizers.Adam(learning_rate=learning_rate)

    @staticmethod
    @tf.function
    def loss(trained_ouput, generated_output) -> tf.Tensor:
        loss_i = tf.math.log(trained_ouput+EPSILON) + tf.math.log1p(-generated_output+EPSILON)
        loss = -tf.math.reduce_mean(loss_i)
        return loss

## Preprocessing
* Import MNIST training images
* Normalize to \[-1, 1\]
* Shuffle and batch dataset

In [8]:
(train_images, _), (test_images, _) = keras.datasets.mnist.load_data()

train_images = tf.dtypes.cast((train_images[:TRAIN_SIZE]-127.5) / 127.5, tf.float32)
train_images = tf.expand_dims(input=train_images, axis=-1)
train_ds = tf.data.Dataset.from_tensor_slices(train_images) \
                          .shuffle(TRAIN_SIZE) \
                          .batch(BATCH_SIZE)

test_images = tf.dtypes.cast((train_images[:TEST_SIZE]-127.5) / 127.5, tf.float32)

### Initialize Generator and Discriminator

In [9]:
generator = Generator(noise_dimension=NOISE_DIMENSION, output_shape=(train_images.shape[1:]))
generator_optimizer = generator.optimizer(G_LEARNING_RATE)

discriminator = Discriminator(input_shape=(train_images.shape[1:]))
discriminator_optimizer = discriminator.optimizer(D_LEARNING_RATE)

In [10]:
fixed_noise = tf.random.normal([N_EXAMPLES, NOISE_DIMENSION])
metrics_names = ['g_loss', 'd_loss', 'acc', 'real_acc', 'fake_acc']
batch_history = { name: [] for name in metrics_names }

In [11]:
def record_sample(generator: Generator, epoch: int, save: bool=True, show=True):
    fixed_predictions = generator(fixed_noise)
    random_predictions = generator(tf.random.normal(fixed_noise.shape))

    f = plt.figure(figsize=(8, 4))
    f.suptitle('GAN Output')
    outer = gridspec.GridSpec(1, 2)
    titles = ['Fixed Samples', 'Random Samples']
    data = [fixed_predictions, random_predictions]
    for i in range(2):
        inner = gridspec.GridSpecFromSubplotSpec(int(math.sqrt(N_EXAMPLES)), int(math.sqrt(N_EXAMPLES)), subplot_spec=outer[i])
        predictions = data[i]
        for j in range(predictions.shape[0]):
            ax = plt.Subplot(f, inner[j])
            ax.imshow(predictions[j]*127.5 + 127.5, cmap=plt.cm.gray)
            ax.axis('off')
            if j == int(math.sqrt(N_EXAMPLES)) // 2:
                ax.set_title(titles[i])
            f.add_subplot(ax)

    if save:
        plt.savefig(os.path.join(OUTPUT_PATH, 'epoch_{:04d}.png'.format(epoch)))
    if show:
        plt.show()
    plt.close()

### Accuracy

In [12]:
def real_accuracy(real_output) -> tf.Tensor:
    # Trained images fed into D have output 1 
    correct_real_output = tf.ones_like(real_output)
    real_acc = tf.math.reduce_mean(tf.dtypes.cast(tf.math.equal(tf.math.round(real_output), correct_real_output), tf.float32))
    return real_acc

def accuracy(real_output, fake_output) -> tuple:
    # Images from G's noisy distribution should have output 0 from D
    correct_fake_output = tf.zeros_like(fake_output)
    fake_acc = tf.math.reduce_mean(tf.dtypes.cast(tf.math.equal(tf.math.round(fake_output), correct_fake_output), tf.float32))

    return real_accuracy(real_output), fake_acc

## GAN Algorithm
<pre>
<b>for</b> number of training iterations <b>do</b>
  <b>for</b> k steps <b>do</b>
     • Sample minibatch of <i>m</i> noise samples {<i><b>z</b><sup>(1)</sup>, ..., <b>z</b><sup>(m)</sup></i>} from noise prior <i>p<sub>g</sub>(<b>z</b>)</i>.
     • Sample minibatch of <i>m</i> examples {<i><b>x</b><sup>(1)</sup>, ..., <b>x</b><sup>(m)</sup></i>} from data generating distribution <i>p<sub>data</sub>(<b>x</b>)</i>.
     • Update the discriminator by <u>ascending</u> its stochastic gradient:
       <center>&Del;<sub>&theta;<sub>d</sub></sub> <sup>1</sup>&frasl;<sub>m</sub> &lowast; &sum;<sub><i>i</i></sub> [log <i>D</i>(<i><b>x</b><sup>(i)</sup></i>) + log(1-<i>D</i>(G(<i><b>z</b><sup>(i)</sup></i>)))]</center>
  <b>end for</b>
  • Sample minibatch of <i>m</i> noise samples {<i><b>z</b><sup>(1)</sup>, ..., <b>z</b><sup>(m)</sup></i>} from noise prior <i>p<sub>g</sub>(<b>z</b>)</i>.
  • Update the generator by <u>descending</u> its stochastic gradient:
  <center>&Del;<sub>&theta;<sub>d</sub></sub> <sup>1</sup>&frasl;<sub>m</sub> &lowast; &sum;<sub><i>i</i></sub> log(1-<i>D</i>(G(<i><b>z</b><sup>(i)</sup></i>)))</center>
<b>end for</b>
</pre>

Source: https://arxiv.org/pdf/1406.2661.pdf

In [None]:
record_sample(generator, 0, show=False)
steps_per_epoch = train_images.shape[0] // BATCH_SIZE

for epoch in range(1, EPOCHS+1):
    print(f'\nepoch {epoch}/{EPOCHS}')
    progress_bar = keras.utils.Progbar(steps_per_epoch, stateful_metrics=metrics_names)

    for i, image_batch in enumerate(train_ds):

        num_samples = image_batch.shape[0]
        noise = tf.random.normal([num_samples, NOISE_DIMENSION])

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # Train G on noise
            generated_images = generator(noise, training=True)
            # Train D on training images
            trained_output = discriminator(image_batch, training=True)
            # Train D on generated images
            generated_output = discriminator(generated_images, training=True)

            # Calculate loss
            g_loss = generator.loss(generated_output)
            d_loss = discriminator.loss(trained_output, generated_output)

        real_acc, fake_acc = accuracy(trained_output, generated_output)
        acc = 0.5 * (real_acc + fake_acc)

        training_metrics = {
            'g_loss': g_loss,
            'd_loss': d_loss,
            'acc': acc,
            'real_acc': real_acc,
            'fake_acc': fake_acc,
        }

        # Record loss history
        for metric in batch_history:
            batch_history[metric].append(training_metrics[metric])
    
        metric_values = training_metrics.items()
        progress_bar.update(i, values=metric_values)        

        # https://www.tensorflow.org/api_docs/python/tf/GradientTape#gradient
        grad_g = g_tape.gradient(g_loss, generator.trainable_variables)
        grad_d = d_tape.gradient(d_loss, discriminator.trainable_variables)

        # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#apply_gradients
        generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(grad_d, discriminator.trainable_variables))

    progress_bar.update(steps_per_epoch, values=metric_values, finalize=True)
    record_sample(generator, epoch, show=False)


epoch 1/100

epoch 2/100

epoch 3/100

epoch 4/100

epoch 5/100

epoch 6/100

epoch 7/100

epoch 8/100

epoch 9/100

epoch 10/100

epoch 11/100

epoch 12/100

epoch 13/100

In [None]:
plt.figure(figsize=(10, 8))
plt.plot(batch_history['g_loss'], label='Generator Loss')
plt.plot(batch_history['d_loss'], label='Discriminator Loss')
plt.xlabel('Batch #')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(METRICS_PATH, 'loss.png'))
plt.show()
plt.close()

In [None]:
plt.figure(figsize=(10, 8))
plt.plot(batch_history['acc'], label='Total Accuracy')
plt.plot(batch_history['real_acc'], label='Training Data Accuracy')
plt.plot(batch_history['fake_acc'], label='Generated Data Accuracy')
plt.xlabel('Batch #')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig(os.path.join(METRICS_PATH, 'accuracy.png'))
plt.show()
plt.close()

In [None]:
real_test_preds = tf.dtypes.cast(tf.math.round(discriminator(test_images, training=False)), tf.float32)
fake_test_preds = tf.dtypes.cast(tf.math.round(discriminator(generator(tf.random.normal([test_images.shape[0], NOISE_DIMENSION]), training=False), training=False)), tf.float32)
real_labels = tf.ones_like(real_test_preds)
fake_labels = tf.zeros_like(fake_test_preds)
preds = tf.concat([real_test_preds, fake_test_preds], axis=0)
labels = tf.concat([real_labels, fake_labels], axis=0)
preds = tf.squeeze(preds)
labels = tf.squeeze(labels)
named_labels = ['Fake', 'Real']
confusion_matrix = tf.math.confusion_matrix(labels, preds, num_classes=2).numpy()
plt.figure(figsize=(10, 7))
sn.set(font_scale=1.4)
ax = sn.heatmap(confusion_matrix, annot=True, annot_kws={'size': 16}, fmt='d', cmap=sn.cm.rocket_r, xticklabels=named_labels, yticklabels=named_labels)
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
ax.tick_params(axis='both', which='major', labelsize=16, labelbottom=False, bottom=False, top=False, labeltop=True)
plt.savefig(os.path.join(METRICS_PATH, 'confusion_matrix.png'))
plt.show()
plt.close()

In [None]:
with imageio.get_writer(os.path.join(METRICS_PATH, 'gan.gif'), mode='I') as writer:
    filenames = glob.glob(os.path.join(OUTPUT_PATH, 'epoch*.png'))
    filenames = sorted(filenames)
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
        image = imageio.imread(filename)
        writer.append_data(image)
Image.open(os.path.join(OUTPUT_PATH, 'epoch_{:04d}.png'.format(epoch)))

# for _ in range(5):
#     real_test_image = test_images[tf.random.uniform(shape=(), minval=0, maxval=TEST_SIZE+1, dtype=tf.int32)]
#     fake_test_image_1 = generator(tf.random.normal([1, NOISE_DIMENSION]), training=False)
#     fake_test_image_2 = tf.nn.tanh(tf.random.normal(train_images.shape[1:]))

#     ax1 = plt.subplot(1, 3, 1)
#     ax1.imshow(real_test_image, cmap=plt.cm.gray)
#     ax1.axis('off')
#     real_test_image = tf.expand_dims(input=real_test_image, axis=0)
#     ax1.set_title(discriminator(real_test_image, training=False).numpy()[0][0])

#     ax2 = plt.subplot(1, 3, 2)
#     ax2.imshow(fake_test_image_1[0, :, :, 0], cmap=plt.cm.gray)
#     ax2.axis('off')
#     ax2.set_title(discriminator(fake_test_image_1, training=False).numpy()[0][0])

#     ax3 = plt.subplot(1, 3, 3)
#     ax3.imshow(fake_test_image_2, cmap=plt.cm.gray)
#     ax3.axis('off')
#     fake_test_image_2 = tf.expand_dims(input=fake_test_image_2, axis=0)
#     ax3.set_title(discriminator(fake_test_image_1, training=False).numpy()[0][0])

#     plt.show()
#     plt.close()