# TensorFlow 2 GAN -> CoreML Notebook, Part 1

This notebook trains a TensorFlow 2 color image generative adversarial network.

Also included are some fun plotting functions for exploring / animating the latent space.

## Dependencies, Installation and Environment
To run this you'll need a Python 3 environment with **TensorFlow 2**:
- tensorflow or tensorflow-gpu 2+ (GPU highly recommended. I'm currently running tensorflow-gpu version 2.2.0)
- matplotlib
- numpy
- jupyter

If you do use a GPU, you can check that TensorFlow sees it by running:
```python
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
```

## The Model, Training, and Configuration
This model is going to run on a wide variety of iPhones, so for simplicity we'll start with a deep convolutional network ([DCGAN](https://www.tensorflow.org/tutorials/generative/dcgan)).

There are a lot of known issues that arise when training GANs, but in these experiments I've been mostly concerned with (1) image quality and (2) "mode collapse".

Image quality is pretty subjective. I'm looking to make pictures that are *close enough* to input images, and which are free of any weird artifacts or patterns (e.g. [checkerboarding](https://distill.pub/2016/deconv-checkerboard/)).

Mode collapse occurs when the generator maps most of the latent space to the same output, or to a handful of near-identical outputs, instead of covering the variety in the training data.
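
One rough way to watch for mode collapse is to track how different the samples in a generated batch are from each other. The sketch below assumes the batch is available as a NumPy array; the helper name `mean_pairwise_distance` and the synthetic batches are illustrative only and are not part of this notebook's training code.

```python
import numpy as np

def mean_pairwise_distance(images):
    """Average Euclidean distance between all distinct pairs in a batch."""
    flat = np.asarray(images, dtype=np.float64).reshape(len(images), -1)
    # Squared distances via |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
    sq_norms = np.sum(flat ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * flat @ flat.T
    dists = np.sqrt(np.clip(sq_dists, 0.0, None))
    # Remove round-off on the diagonal (distance of a sample to itself)
    np.fill_diagonal(dists, 0.0)
    n = len(flat)
    # Average over the n * (n - 1) ordered pairs of distinct samples
    return dists.sum() / (n * (n - 1))

rng = np.random.default_rng(0)
# Hypothetical stand-ins for generator output in [-1, 1]
diverse = rng.uniform(-1, 1, size=(16, 8, 8, 3))
collapsed = np.repeat(diverse[:1], 16, axis=0)
collapsed = collapsed + rng.normal(0, 1e-3, collapsed.shape)

print(mean_pairwise_distance(diverse))    # comparatively large
print(mean_pairwise_distance(collapsed))  # close to zero
```

A value that drops toward zero over training is a hint, not proof, that the generator has collapsed; looking at the image grids is still the real test.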

I'm generating artsy images from very small datasets: currently on the order of 500-20000 training samples. For my purposes, I've found that the models defined below tend to reach a happy place around 200-400 training epochs. You'll need to inspect the generated images and see what works for you.

This is a simple notebook, and there's much, much more that could be done to improve training performance. Some next steps might include changing up the loss function, or adding evaluation metrics like Fréchet Inception Distance. For more detailed writeups, see: [Common Problems (with GANs)](https://developers.google.com/machine-learning/gan/problems), [10 Lessons I Learned Training GANs for one Year](https://towardsdatascience.com/10-lessons-i-learned-training-generative-adversarial-networks-gans-for-a-year-c9071159628), and [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks).
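
For a feel of what Fréchet Inception Distance measures, here is a minimal NumPy sketch of the Fréchet distance between two sets of feature vectors, each summarized by its mean and covariance. In the real metric the features come from a pretrained Inception network; here they are random arrays, and the function name `frechet_distance` is an assumption for illustration.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    d^2 = |mu_a - mu_b|^2 + Tr(C_a) + Tr(C_b) - 2 Tr((C_a C_b)^(1/2))
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((C_a C_b)^(1/2)) equals the sum of square roots of the eigenvalues
    # of C_a C_b, which are real and non-negative up to round-off.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 4))   # placeholder "real" features
shifted_feats = real_feats + 3.0         # same spread, different mean

print(frechet_distance(real_feats, real_feats))     # approximately 0
print(frechet_distance(real_feats, shifted_feats))  # approximately 36
```

Lower is better: identical feature distributions score zero, and the score grows as the generated features drift away from the real ones.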

If you're looking to incorporate labels or want some control over inputs in the latent space, check out Conditional GAN (CGAN) and Information Maximizing GAN (InfoGAN).

## Getting started
The basic architecture and training sections of this notebook are based on the ["GAN overriding Model.train_step"](https://keras.io/examples/generative/dcgan_overriding_train_step/) tutorial, which trains a deep convolutional GAN with TensorFlow 2 + Keras to generate black and white digits based on the MNIST dataset.

The model has been modified to output 128x128 RGB images. Read along for more details!

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

import matplotlib.pyplot as plt
import os

# Setup

## Pick a model name
Files for this model will be saved to **./models/{model_name}**

## Add training data
Add image paths to "data_paths" to include them in training. If you're looking to include labels, modify the process_path and/or Dataset initializer.

I tested this out on my own pictures of flowers. For something similar, check out the [Oxford Flowers 17](http://www.robots.ox.ac.uk/~vgg/data/flowers/17/), [TensorFlow Flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers), or [Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102).

### Image size, etc.
You can also change the shape of latent input, or the size of output images here. **If you change the image output size, be prepared to modify the generator model to match it.**

In [None]:
# General settings and image preprocessing helpers
model_name = 'MODEL_NAME'
data_paths = ['./PATH_TO_IMAGES/*']
latent_dim = 128
img_shape = (128, 128, 3)

def decode_img(img, img_size=(128, 128)):
    img = tf.image.decode_jpeg(img, channels=3)
    # Convert image to floats in [0,1]
    img = tf.image.convert_image_dtype(img, tf.float32)
    # resize and convert to [-1, 1]
    img = tf.image.resize(img, img_size)
    img = (img - 0.5) * 2.0
    return img

def process_path(file_path):
    # load the raw data from the file as a string
    img = decode_img(tf.io.read_file(file_path))
    return img

filenames = []
for path in data_paths:
    batch_ds = tf.data.Dataset.list_files(path)
    batch_filenames = list(batch_ds)
    filenames.extend(batch_filenames)
    print("{}: {} images".format(path, len(batch_filenames)))

file_ds = tf.data.Dataset.from_tensor_slices(filenames)
all_images = file_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Cache the decoded images first so that shuffling still happens every epoch,
# and prefetch last so the next batch is prepared while the current one trains
dataset = all_images.cache().shuffle(buffer_size=1024).batch(64).prefetch(tf.data.experimental.AUTOTUNE)
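
The `decode_img` helper scales pixels to [-1, 1] so that real images live in the same range as the generator's tanh output, and the plotting code later in the notebook maps them back with `0.5 * x + 0.5`. A minimal NumPy sketch of that round trip follows; the function names `to_model_range` and `to_display_range` are made up for illustration.

```python
import numpy as np

def to_model_range(pixels_uint8):
    """uint8 pixels in [0, 255] -> floats in [-1, 1]."""
    x = pixels_uint8.astype(np.float32) / 255.0   # [0, 1]
    return (x - 0.5) * 2.0                        # [-1, 1]

def to_display_range(x):
    """Floats in [-1, 1] -> floats in [0, 1], suitable for imshow."""
    return 0.5 * x + 0.5

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
restored = to_display_range(scaled)

print(scaled.min(), scaled.max())
print(np.round(restored * 255))
```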

# Models

The model layers are all specified with Keras. Some comments:
- Use LeakyReLU in both the discriminator and the generator
- Initialize transpose convolutions with small random weights (truncated normal, stddev 0.02)
- Make sure that transpose convolutions are formatted with color channels last
- Use BatchNormalization in the generator, and tanh for its final activation

As of the time of writing, there was no out-of-the-box CoreML support for a Spectral Normalization layer.

In [None]:
discriminator_model = keras.Sequential(
    [
        keras.Input(shape=img_shape),
        layers.Conv2D(32, (5, 5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1, activation='sigmoid')
    ],
    name="discriminator",
)

discriminator_model.summary()
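
To sanity-check what `summary()` reports for the discriminator, the feature-map sizes and parameter counts can be worked out by hand. This is a plain-Python sketch under the layer settings above (5x5 kernels, stride 2, "same" padding, biases on); the helper names are hypothetical.

```python
def conv2d_params(kernel, in_channels, out_channels, use_bias=True):
    """Weights in a Conv2D layer: one kernel per (in, out) channel pair, plus biases."""
    kh, kw = kernel
    return kh * kw * in_channels * out_channels + (out_channels if use_bias else 0)

def dense_params(in_features, out_features, use_bias=True):
    return in_features * out_features + (out_features if use_bias else 0)

size, channels = 128, 3          # input is 128x128 RGB
total = 0
for filters in (32, 64, 128):
    total += conv2d_params((5, 5), channels, filters)
    size = -(-size // 2)         # "same" padding with stride 2: ceil(size / 2)
    channels = filters

# GlobalMaxPooling2D keeps one value per channel, so Dense sees `channels` inputs.
# LeakyReLU, Dropout and pooling layers have no trainable weights.
total += dense_params(channels, 1)

print(size, channels, total)  # 16 128 258753
```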

In [None]:
generator_model = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,), name='noise_in'),

        layers.Dense(16 * 16 * 256, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((16, 16, 256)),
        layers.UpSampling2D(),
        
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                               padding="same",
                               use_bias=False,
                               kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
                               data_format="channels_last"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                               padding="same",
                               use_bias=False,
                               kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
                               data_format="channels_last"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(32, (5, 5), strides=(2, 2),
                               padding="same",
                               use_bias=False,
                               kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
                               data_format="channels_last"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2D(3, (5, 5), strides=(1, 1), activation='tanh', padding='same')
    ],
    name="generator",
)

generator_model.summary()

# Make sure the generator's output shape matches the image shape we're expecting
assert generator_model.output_shape[1:] == img_shape
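
If you change the output image size, it helps to trace the spatial size through the generator before editing layers. With "same" padding, `UpSampling2D()` doubles the size and each `Conv2DTranspose` multiplies it by its stride. The sketch below walks through that arithmetic; `generator_output_size` is a made-up helper, and the alternative configuration is just one possible option.

```python
def generator_output_size(start, scale_factors):
    """Multiply the starting feature-map size by each layer's upscaling factor."""
    size = start
    for factor in scale_factors:
        size *= factor
    return size

# Reshape to 16x16, UpSampling2D (x2), then transpose convs with strides 1, 2, 2.
# The final Conv2D has stride 1 and leaves the size unchanged.
print(generator_output_size(16, [2, 1, 2, 2]))  # 128

# One possible way to reach 64x64: change the last transpose conv to stride 1
print(generator_output_size(16, [2, 1, 2, 1]))  # 64
```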

# Training

This is pretty much unmodified from the Keras tutorial linked above.

The GANMonitor class is a callback that saves a grid of generated images at the end of each training epoch.

In [None]:
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels: 1 for generated (fake) images, 0 for real ones, in the same order as combined_images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images" (real is 0 in this setup)
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}
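
The label convention in `train_step` is easy to get backwards: generated images are labeled 1 and real images 0, so the generator is rewarded when the discriminator outputs values near 0 for its fakes. Here is a small NumPy sketch of the noisy labels and the binary cross-entropy loss, with made-up discriminator outputs standing in for real predictions.

```python
import numpy as np

def binary_crossentropy(labels, preds, eps=1e-7):
    """Mean binary cross-entropy, clipping predictions away from 0 and 1."""
    preds = np.clip(preds, eps, 1.0 - eps)
    return float(np.mean(-(labels * np.log(preds) + (1.0 - labels) * np.log(1.0 - preds))))

batch = 4
rng = np.random.default_rng(0)

# Discriminator labels: fakes first (1), then reals (0), plus a little noise
d_labels = np.concatenate([np.ones((batch, 1)), np.zeros((batch, 1))], axis=0)
noisy_labels = d_labels + 0.05 * rng.uniform(size=d_labels.shape)

# Generator step: pretend every fake is real (label 0)
misleading_labels = np.zeros((batch, 1))
fooled = np.full((batch, 1), 0.1)       # hypothetical: discriminator leans "real"
not_fooled = np.full((batch, 1), 0.9)   # hypothetical: discriminator leans "fake"

g_loss_fooled = binary_crossentropy(misleading_labels, fooled)          # about 0.105
g_loss_not_fooled = binary_crossentropy(misleading_labels, not_fooled)  # about 2.303
print(g_loss_fooled, g_loss_not_fooled)
```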

In [None]:
# This callback plots a grid of images generated from a random sampling of points in the latent space

class GANMonitor(keras.callbacks.Callback):
    def __init__(self, img_path, img_rows=4, img_cols=4, latent_dim=128):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.latent_dim = latent_dim
        self.img_path = img_path

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.img_rows * self.img_cols, self.latent_dim))

        generated_images = self.model.generator(random_latent_vectors)
        generated_images = 0.5 * generated_images + 0.5
        generated_images = generated_images.numpy()
        
        fig, axs = plt.subplots(self.img_rows, self.img_cols, figsize=(15,15), gridspec_kw={'wspace': 0.025, 'hspace': 0.025})
        fig.suptitle('Epoch {}'.format(epoch), fontsize=16)
        cnt = 0
        for i in range(self.img_rows):
            for j in range(self.img_cols):
                axs[i,j].imshow(generated_images[cnt])
                axs[i,j].axis('off')
                cnt += 1
        outfile = os.path.join(self.img_path, '{}.png'.format(epoch))
        fig.savefig(outfile)
        plt.close()

# Training Setup

All model files are saved in **./models/{model_name}**. The particulars are:
 - weights.hdf5: trained weights for the combined GAN
 - generator_weights.hdf5: trained generator weights
 - generator_architecture.json: generator layer architecture
 - discriminator_architecture.json: discriminator layer architecture
 - images/{epoch}.png: evaluation images generated at the end of each epoch

**If existing model definition files or weights are found, they're loaded automatically**

In [None]:
# Create a folder for this model
# The folder will contain model architecture definitions, trained weights, and evaluation images
model_path = os.path.join('models', model_name)
weights_path = os.path.join(model_path, 'weights.hdf5') # Combined model weights
generator_weights_path = os.path.join(model_path, 'generator_weights.hdf5')
img_path = os.path.join(model_path, 'images')
mlmodel_path = os.path.join(model_path, "model.mlmodel")
generator_path = os.path.join(model_path, 'generator_architecture.json')
discriminator_path = os.path.join(model_path, 'discriminator_architecture.json')

print("Working in {}".format(model_path))

if not os.path.exists(img_path):
    os.makedirs(img_path)
    
if os.path.exists(generator_path):
    with open(generator_path, 'r') as fp:
        generator = keras.models.model_from_json(fp.read())
        print("Loaded generator architecture from {}".format(generator_path))
        generator.summary()
else:
    generator = generator_model
    print("Using new generator")
    with open(generator_path, 'w') as fp:
        fp.write(generator.to_json())
        
if os.path.exists(discriminator_path):
    with open(discriminator_path, 'r') as fp:
        discriminator = keras.models.model_from_json(fp.read())
        print("Loaded discriminator architecture from {}".format(discriminator_path))
        discriminator.summary()
else:
    discriminator = discriminator_model
    print("Using new discriminator")
    with open(discriminator_path, 'w') as fp:
        fp.write(discriminator.to_json())

# Two Time-Scale Update Rule: setting the discriminator LR higher than the generator LR
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0004, beta_1=0.5),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

plot_images = GANMonitor(img_path, latent_dim=latent_dim)

# Save weights at the end of every epoch (an integer save_freq would count batches instead).
# Monitoring Inception Score / FID could be a nice improvement
checkpoint = keras.callbacks.ModelCheckpoint(weights_path,
                                             monitor='g_loss',
                                             verbose=1,
                                             save_best_only=False,
                                             save_weights_only=True,
                                             mode='auto',
                                             save_freq='epoch')

if os.path.exists(weights_path):
    print('Loading weights from {}'.format(weights_path))
    gan.load_weights(weights_path)

# Train the model!

In [None]:
epochs = 225
gan.fit(
    dataset, epochs=epochs, callbacks=[plot_images, checkpoint]
)
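
For a rough idea of how much work a run involves, the number of optimizer steps follows from the dataset size, the batch size of 64 used above, and the epoch count. A back-of-the-envelope sketch, using the 500 to 20000 image range mentioned earlier (`training_steps` is a hypothetical helper):

```python
import math

def training_steps(num_images, batch_size, epochs):
    """Return (steps per epoch, total steps); the last batch may be partial."""
    steps_per_epoch = math.ceil(num_images / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs

print(training_steps(500, 64, 225))    # (8, 1800)
print(training_steps(20000, 64, 225))  # (313, 70425)
```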

# Save the generator weights
These will get loaded elsewhere and included with the final CoreML model.

In [None]:
gan.generator.save_weights(generator_weights_path)
print("Saved generator weights to {}".format(generator_weights_path))

# Notebook Utilities

Use these cells to visualize output in the notebook.

In [None]:
# Spherical linear interpolation for exploring latent space
# Basically, since we're sampling the latent space with a Gaussian distribution, our sample points are closer to a hypersphere than a uniformly distributed hypercube. For more on this see:
# https://github.com/soumith/dcgan.torch/issues/14
# https://machinelearningmastery.com/how-to-interpolate-and-perform-vector-arithmetic-with-faces-using-a-generative-adversarial-network/

def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        # Becomes LERP when omega -> 0
        return (1.0-val) * low + val * high
    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high

def interpolate(a, b, n_steps=10):
    # Accept TensorFlow tensors as well as NumPy arrays
    a, b = np.asarray(a), np.asarray(b)
    steps = np.linspace(0, 1, num=n_steps)
    return np.array([slerp(t, a, b) for t in steps])
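
To see why slerp is used instead of straight-line interpolation, compare vector norms at the midpoint. The sketch below repeats the `slerp` definition so it runs on its own, and it assumes two random points rescaled to the same radius, which is roughly where Gaussian samples in 128 dimensions end up.

```python
import numpy as np

def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high

rng = np.random.default_rng(0)
dim = 128
radius = np.sqrt(dim)   # typical norm of a standard Gaussian sample in `dim` dimensions

a = rng.normal(size=dim)
b = rng.normal(size=dim)
a *= radius / np.linalg.norm(a)   # put both endpoints on the same sphere
b *= radius / np.linalg.norm(b)

lerp_mid = 0.5 * a + 0.5 * b
slerp_mid = slerp(0.5, a, b)

print(np.linalg.norm(lerp_mid))   # noticeably smaller than the radius
print(np.linalg.norm(slerp_mid))  # stays at the radius
```

The linear midpoint falls inside the sphere, in a region the generator rarely saw during training, while the slerp midpoint stays at the radius where samples usually land.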

## Grid of random points in latent space

In [None]:
# The same as in GANMonitor

rows, cols = 4, 4
random_latent_vectors = tf.random.normal(shape=(rows*cols, latent_dim))

generated_images = gan.generator(random_latent_vectors)
generated_images = 0.5 * generated_images + 0.5
generated_images = generated_images.numpy()

fig, axs = plt.subplots(rows, cols, figsize=(15,15), gridspec_kw={'wspace': 0.025, 'hspace': 0.025})
cnt = 0
for i in range(rows):
    for j in range(cols):
        axs[i,j].imshow(generated_images[cnt])
        axs[i,j].axis('off')
        cnt += 1
plt.show()

## Animated interpolation between points in latent space

In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure()
ax = plt.gca()
ax.axis('off')
imgs = []

imobj = ax.imshow(np.zeros((128, 128)), origin='upper', alpha=1.0, zorder=1, aspect=1)

count = 8

points = tf.random.normal(shape=(count, latent_dim))
for i in range(0, count-1):
    interpolated = interpolate(points[i], points[i+1], n_steps=40)
    generated_images = gan.generator(interpolated)
    generated_images = 0.5 * generated_images + 0.5
    imgs.extend(generated_images)

def animate(i):
    if i < len(imgs):
        img = imgs[i]
        imobj.set_data(img)
    return imobj,

anim = animation.FuncAnimation(fig, animate, frames=len(imgs), repeat=True, interval=135, blit=True, repeat_delay=100)

# Display the animation inline
HTML(anim.to_jshtml())

In [None]:
# Optional: save the animation as a gif (comment these lines out to skip saving)
writer = animation.PillowWriter(fps=24)
anim.save(os.path.join(model_path, "animation.gif"), writer=writer)

## Grid of interpolated points in latent space

In [None]:
rows = 8
cols = 8

points = tf.random.normal(shape=(rows+1, latent_dim))
imgs = []
for i in range(0, rows):
    interpolated = interpolate(points[i], points[i+1], n_steps=cols)
    generated_images = gan.generator(interpolated)
    generated_images = 0.5 * generated_images + 0.5
    imgs.extend(generated_images)

fig, axs = plt.subplots(rows, cols, figsize=(16, 16), gridspec_kw={'wspace': 0.025, 'hspace': 0.025})
cnt = 0
for i in range(rows):
    for j in range(cols):
        axs[i,j].imshow(imgs[cnt])
        axs[i,j].axis('off')
        cnt += 1
fig.savefig(os.path.join(model_path, 'grid.png'))
plt.show()

# CoreML Conversion

For more about converting the trained generator to CoreML, check out my other notebook.