# Generating Pokémon with Generative Adversarial Networks (GANs)

We will generate new Pokémon from images of the 800+ existing Pokémon using a DCGAN (deep convolutional GAN).

GANs pair a generator model, which learns to produce the target output, with a discriminator model, which learns to distinguish true data from the output of the generator. The generator tries to fool the discriminator, and the discriminator tries to keep from being fooled. Using this architecture GANs can create images that resemble the training set. (modified from: https://developers.google.com/machine-learning/gan)

This kernel is based on https://www.tensorflow.org/tutorials/generative/dcgan
**Run this kernel with GPU acceleration.**

In [None]:
!rm -rf ../output  # Clear previous output

Let's start by loading up the Pokémon images. We will use the dataset provided by https://www.kaggle.com/vishalsubbiah/pokemon-images-and-types. 

We will create a `tf.data.Dataset`, as this appeared to be the fastest solution. I previously tried `tf.keras.preprocessing.image_dataset_from_directory`; although it makes things simpler and has built-in augmentation, it became a bottleneck for the GAN. It is not strictly necessary to understand what is happening here, but I will explain anyway. We load the images from the input folder, cache them (so later epochs do not have to read and decode the files again), prefetch (so data does not have to be loaded just in time), batch the dataset, and apply preprocessing and filtering.

### Preprocessing
Images are loaded as RGB (or grayscale if `IMAGE_CHANNELS = 1`), resized to `[IMAGE_SIZE, IMAGE_SIZE]`, and each channel value is transformed from $[0, 255]$ to $[-1, 1]$. Each image is then stored as a tensor of shape `[IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS]`.
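
As a quick check of the scaling step, the sketch below applies the same arithmetic as `preprocess` to a few raw channel values and then inverts it the way the plotting code does. The helper names `to_model_range` and `to_unit_range` are made up for this illustration and are not part of the kernel.

```python
def to_model_range(value):
    """Map a raw channel value from [0, 255] to [-1, 1]."""
    return (value - 127.5) / 127.5


def to_unit_range(value):
    """Map a model-range value from [-1, 1] to [0, 1] for plotting."""
    return value * 0.5 + 0.5


for raw in (0, 127.5, 255):
    scaled = to_model_range(raw)
    print(raw, "->", scaled, "->", to_unit_range(scaled))
```

The range $[-1, 1]$ matches the `tanh` activation of the generator's last layer, so real and generated images live on the same scale.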

### Filtering
We discard all images that have a white background (ca. 10%). This shrinks our already quite small dataset to 721 images, but it also prevents color bleeding and improves convergence. We detect the background by checking only the red channel of the top-left pixel: an image is kept only if that value is 0 (i.e. $-1$ after scaling). This would be a poor heuristic for most datasets, but it works well for this one.
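
To make the check concrete, here is a minimal sketch on plain nested lists instead of tensors. The three tiny "images" are invented for the example, and `keep_image` is a hypothetical stand-in for the kernel's `filter` function.

```python
def keep_image(img):
    """Same test as the kernel's filter: top-left pixel, red channel, equals -1."""
    return img[0][0][0] == -1


# Made-up 1x2 RGB images, already scaled to [-1, 1]
dark_corner = [[[-1.0, -1.0, -1.0], [0.2, 0.5, -0.3]]]
white_corner = [[[1.0, 1.0, 1.0], [0.2, 0.5, -0.3]]]
almost_dark = [[[-0.99, -1.0, -1.0], [0.2, 0.5, -0.3]]]

for name, img in [("dark", dark_corner), ("white", white_corner), ("almost dark", almost_dark)]:
    print(name, "->", "keep" if keep_image(img) else "discard")
```

Note that the comparison is exact, so an image whose corner is only nearly black is discarded as well.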

In [None]:
import tensorflow as tf
from tensorflow.data.experimental import AUTOTUNE
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display

BATCH_SIZE = 128
IMAGE_SIZE = 120  # reduce this to increase performance
IMAGE_CHANNELS = 3  # can be 3 (RGB) or 1 (Grayscale)
LATENT_SPACE_DIM = 100  # dimensions of the latent space that is used to generate the images

assert IMAGE_SIZE % 4 == 0


def preprocess(file_path):
    # load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    # decode the image into a uint8 tensor with IMAGE_CHANNELS channels (3 = RGB, 1 = grayscale)
    img = tf.image.decode_jpeg(img, channels=IMAGE_CHANNELS)
    # resize the image to the desired size
    img = tf.image.resize(img, [IMAGE_SIZE, IMAGE_SIZE])
    # transform the color values from [0, 255] to [-1, 1] (tf.image.resize has already returned float32)
    img = (img - 127.5) / 127.5
    return img


def filter(img):
    return img[0, 0, 0] == -1  # keep only images whose top-left red value is 0 (-1 after scaling); this discards white bg images


def configure_for_performance(ds):
    ds = ds.cache()
    ds = ds.filter(filter)
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds


list_ds = tf.data.Dataset.list_files(str('../input/pokemon-images-and-types/images/*/*'), shuffle=True)  # Get all images from subfolders
train_dataset = list_ds.take(-1)
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_dataset = train_dataset.map(preprocess, num_parallel_calls=AUTOTUNE)
train_dataset = configure_for_performance(train_dataset)

## Generator model

The goal of the generator is to create an image from an input tensor. This input tensor is a representation of the image in latent space. Think of the latent space as a kind of super-compressed zip file that, once extracted, yields an image. The generator performs this "extraction". Note that the model first turns the input into a small image with `IMAGE_SIZE // 4` pixels per side (after the `Reshape` layer) and then upscales it while it passes through the transposed convolutions, so the final side length is four times that of the `Reshape` output.
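
The upscaling can be checked with a little arithmetic. With `padding='same'`, a `Conv2DTranspose` layer multiplies the spatial size by its stride, so the sketch below only tracks the side length through the three transposed convolutions. The helper `transposed_conv_same` is a name invented for this illustration.

```python
IMAGE_SIZE = 120  # same value as in the kernel


def transposed_conv_same(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride


n = IMAGE_SIZE // 4        # side length after the Reshape layer
sizes = [n]
for stride in (1, 2, 2):   # strides of the three Conv2DTranspose layers
    sizes.append(transposed_conv_same(sizes[-1], stride))

print(sizes)  # [30, 30, 60, 120]
```

This is also why the kernel asserts `IMAGE_SIZE % 4 == 0`: otherwise the integer division would lose pixels and the output would not match the training images.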

In [None]:
def make_generator_model():
    model = tf.keras.Sequential()
    
    n = IMAGE_SIZE // 4
    
    model.add(layers.Dense(n * n * 256, use_bias=False, input_shape=(LATENT_SPACE_DIM,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((n, n, 256)))

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(IMAGE_CHANNELS, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    return model

## Discriminator model

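The task of the discriminator is to classify whether a given input image is real (i.e. similar to the ones it was trained on) or fake. It returns a single output per image: a raw score (logit) rather than a probability, since the final `Dense(1)` layer has no activation. Applying a sigmoid to this score gives the probability $p$ of the input being real ($p\in[0,1]$, where 1 means real).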
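
Since the last layer of the model is `Dense(1)` without an activation, its output is an unbounded score. The following sketch, using made-up scores, shows how a sigmoid maps such a score to a probability of the image being real.

```python
import math


def sigmoid(logit):
    """Map an unbounded score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


for logit in (-4.0, 0.0, 4.0):  # made-up discriminator outputs
    print(f"logit {logit:+.1f} -> p(real) = {sigmoid(logit):.3f}")
```

A score of 0 therefore means "undecided" ($p = 0.5$), and the sign of the score tells us which way the discriminator leans.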

In [None]:
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Now we create both models and use Adam as the optimizer for each.

In [None]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

discriminator = make_discriminator_model()
generator = make_generator_model()

## Loss functions

To assess the quality of our models, we need to provide a loss function for each. We will later use these functions to update the model weights during training.


### Generator Loss
The goal of the generator is to create an image that fools the discriminator i.e. an image that is classified as real. 

Therefore, the loss is the binary cross-entropy between the discriminator's scores for the fake images that the generator created and a tensor full of 1s (i.e. claiming that all of these images are real).

The loss approaches 0 as the discriminator labels all fake images as real with high confidence.

### Discriminator Loss
The discriminator's goal, on the other hand, is to distinguish real from fake images.

It computes two sub-losses and adds them. One represents how well real images are detected (target 1) and the other how well fake images are rejected ("called out" with target 0). The real loss is the binary cross-entropy between the discriminator's scores for the training images and a tensor full of 1s, while the fake loss is the binary cross-entropy between its scores for the generated images and a tensor full of 0s.

The loss approaches 0 as all images are classified correctly with high confidence.
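
To see how the two losses behave, here is a small pure-Python sketch of binary cross-entropy on logits, averaged over a batch. It is meant as an illustration of the formula, not as a replacement for the Keras loss used in the kernel. The logits are made up, and `bce_from_logit` and `mean` are hypothetical helper names.

```python
import math


def bce_from_logit(label, logit):
    """Numerically stable binary cross-entropy for a single logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


real_logits = [3.0, 2.0]    # made-up discriminator scores for real images
fake_logits = [-2.0, 1.0]   # made-up scores for generated images

# Discriminator: real images should get label 1, fake images label 0
disc_loss = (mean([bce_from_logit(1.0, x) for x in real_logits])
             + mean([bce_from_logit(0.0, x) for x in fake_logits]))
# Generator: wants its fake images to be labeled 1
gen_loss = mean([bce_from_logit(1.0, x) for x in fake_logits])

print(f"discriminator loss: {disc_loss:.4f}")
print(f"generator loss:     {gen_loss:.4f}")
```

The fake image with score `1.0` fools the discriminator, which raises the discriminator loss and lowers the generator loss. That is the tug-of-war between the two models.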

In [None]:
# Helper that computes the cross-entropy loss; from_logits=True because the discriminator outputs raw scores
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Now we need to define our training step. In each step we generate images from random noise using the generator and feed them, together with a batch of real training images, to the discriminator. We then compute the loss and the corresponding gradients for both models and update the weights by calling `apply_gradients()` on each model's optimizer.

Please note that we decorate the function with `@tf.function`, which compiles it into a graph and significantly increases performance.

In [None]:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, LATENT_SPACE_DIM])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)  # training=True is important, since Dropout and BatchNorm behave differently during inference

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Compute and apply the gradients outside the tape context (nothing here needs to be recorded)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

The training itself is pretty straightforward. For each batch in the dataset we call `train_step`, and we repeat this for the given number of epochs. Additionally, every `save_after` epochs we generate and save the current predictions for a fixed global seed, which helps us follow the progress of the model. Since this adds a significant computational load, it is best to keep `save_after` as high as is practical.
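
As a rough sketch of what this schedule means in numbers, the helpers below (hypothetical names, not part of the kernel) list the epochs after which the loop saves an image grid and count the training steps per epoch. The figure of 721 images is the dataset size after filtering stated above. The loop additionally renders one grid before training starts (epoch 0) and one after the final epoch.

```python
import math


def epochs_with_snapshot(epochs, save_after):
    """1-based epoch numbers after which the loop saves an image grid."""
    return [epoch + 1 for epoch in range(epochs) if (epoch + 1) % save_after == 0]


def steps_per_epoch(num_images, batch_size):
    """Number of batches per epoch; the last batch may be smaller."""
    return math.ceil(num_images / batch_size)


print(epochs_with_snapshot(epochs=10, save_after=3))  # [3, 6, 9]
print(steps_per_epoch(721, 128))                      # 6
```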

In [None]:
num_examples_to_generate = 16
# We will reuse this seed over time so it's easier
# to visualize progress in the animated GIF
seed = tf.random.normal([num_examples_to_generate, LATENT_SPACE_DIM])

def train(dataset, epochs, save_after):
    
    generate_and_save_images(generator,
                       0,
                       seed)
    
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)

        if (epoch + 1) % save_after == 0:
            # Produce images for the GIF as we go
            display.clear_output(wait=True)
            generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                       epochs,
                       seed)

Add a little helper function that makes predictions from a given model, plots them in a grid and saves into the output folder.

In [None]:
def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(10, 10))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        if predictions.shape[-1] == 3:
            plt.imshow(predictions[i] * 0.5 + .5)  # scale image to [0, 1] floats (or you could also scale to [0, 255] ints) 
        else: 
            plt.imshow(predictions[i, :, :, 0] * 0.5 + .5, cmap='gray')  # scale image to [0, 1] floats (or you could also scale to [0, 255] ints) 
        plt.axis('off')
    plt.suptitle(f'Epoch {epoch}')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


## Fun Time

**Warning: training these models takes a long time. Prepare for multiple hours even with GPU acceleration. The model will output and save generations every `save_after` epochs.**

How to improve performance:
* Increase `save_after` or skip the intermediate snapshots entirely
* Decrease `IMAGE_SIZE` (will generate smaller images)
* Change from color to grayscale by setting `IMAGE_CHANNELS = 1`
* Run on TPU (untested)
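
To get a rough feeling for the effect of `IMAGE_SIZE`, the sketch below counts only the weights of the generator's first `Dense` layer, which grows with the square of the image size. This is a back-of-the-envelope estimate, not a measurement of training time, and `first_dense_weights` is a name invented for this example.

```python
LATENT_SPACE_DIM = 100  # same value as in the kernel


def first_dense_weights(image_size, latent_dim=LATENT_SPACE_DIM):
    """Weight count of Dense(n * n * 256, use_bias=False) with n = image_size // 4."""
    n = image_size // 4
    return latent_dim * n * n * 256


for size in (120, 64, 32):
    print(f"IMAGE_SIZE = {size:3d}: {first_dense_weights(size):,} weights")
```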

In [None]:
train(train_dataset, epochs=10000, save_after=100)
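
The saved snapshots can afterwards be stitched into the animated GIF mentioned earlier. The following is a minimal sketch, assuming the PNG files written by `generate_and_save_images` are in the working directory and that `imageio` provides its v2-style `get_writer`/`imread` functions. The helper names `collect_frames` and `write_gif` and the output file name `dcgan.gif` are assumptions made for this example. Frames are sorted by their epoch number rather than by file name, because `image_at_epoch_10000.png` would otherwise sort before `image_at_epoch_9900.png`.

```python
import glob
import os
import re

FRAME_RE = re.compile(r"image_at_epoch_(\d+)\.png$")


def collect_frames(directory="."):
    """Return the snapshot paths in `directory`, sorted by epoch number."""
    numbered = []
    for path in glob.glob(os.path.join(directory, "image_at_epoch_*.png")):
        match = FRAME_RE.search(os.path.basename(path))
        if match:  # skip files that do not end in a number
            numbered.append((int(match.group(1)), path))
    return [path for _, path in sorted(numbered)]


def write_gif(frames, out_path="dcgan.gif"):
    """Write the given image files as frames of one animated GIF."""
    import imageio  # third-party package, imported only when a GIF is written

    with imageio.get_writer(out_path, mode="I") as writer:
        for filename in frames:
            writer.append_data(imageio.imread(filename))
```

After training has finished, `write_gif(collect_frames())` would produce the animation.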