## Imports

In [None]:
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import re
import PIL
import os
import shutil

from kaggle_datasets import KaggleDatasets
from random import random
from tensorflow import keras
from tensorflow.keras import layers, Model, losses, optimizers
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Activation, Concatenate
from tensorflow_addons.layers import InstanceNormalization

## Model Constants

---


In [None]:
# Images are HEIGHT x WIDTH pixels with CHANNELS colour channels.
HEIGHT = 256
WIDTH = 256
# NOTE(review): BATCH_SIZE is not referenced anywhere else in the visible source
# (gan_dataset is called with its default batch_size) - confirm it is still needed.
BATCH_SIZE = 1
CHANNELS = 3
# NOTE(review): the loss functions take their own LAMBDA parameter, so this
# module-level value does not appear to be read in the visible source.
LAMBDA = 10
# Number of epochs passed to gan_model.fit.
EPOCHS = 250

# GCS location of the 'gan-getting-started' Kaggle dataset.
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')

## Optimize for TPU usage

In [None]:
# Try to attach to a TPU; a ValueError from the resolver is treated as
# "no TPU available" and the default distribution strategy is used instead.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # tf.distribute.experimental.TPUStrategy is deprecated; use the stable alias.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

# Number of replicas the strategy keeps in sync.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# Let tf.data pick parallelism / prefetch sizes dynamically.
AUTOTUNE = tf.data.AUTOTUNE



## Load data

---



In [None]:
# Collect the TFRecord shards for each image domain from the dataset bucket.
MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))

def count_data_items(filenames):
    """Sum the item counts encoded in a list of file names.

    Each name is searched for the first ``-<digits>.`` group; the digits are
    parsed as that file's item count.

    Args:
        filenames: Iterable of file name strings.

    Returns:
        The total of all parsed counts, as computed by ``np.sum``.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for path in filenames:
        digits = count_pattern.search(path).group(1)
        per_file_counts.append(int(digits))
    return np.sum(per_file_counts)

# Total number of images per domain, derived from the counts in the shard names.
n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)

# Report shard and image counts for both domains.
print(f'Monet TFRecord files: {len(MONET_FILENAMES)}')
print(f'Monet image files: {n_monet_samples}')
print(f'Photo TFRecord files: {len(PHOTO_FILENAMES)}')
print(f'Photo image files: {n_photo_samples}')

#### Decode the image and rescale pixels to [-1, 1] to match the tanh activation function

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 image tensor.

    Args:
        image: Encoded JPEG bytes.

    Returns:
        A float32 tensor reshaped to ``[HEIGHT, WIDTH, CHANNELS]`` whose values
        are the decoded pixel values divided by 127.5 minus 1.
    """
    pixels = tf.image.decode_jpeg(image, channels=CHANNELS)
    pixels = tf.cast(pixels, tf.float32)
    # Decoded 8-bit values [0, 255] become [-1, 1].
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [HEIGHT, WIDTH, CHANNELS])

#### Read and return the images from the TFRecord files

In [None]:
def read_tfrecord(example):
    """Parse one serialized example and return its decoded image.

    The example is parsed with three scalar string features (``image_name``,
    ``image``, ``target``); only ``image`` is used.

    Args:
        example: A serialized example from a TFRecord file.

    Returns:
        The image tensor produced by ``decode_image``.
    """
    string_scalar = tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        'image_name': string_scalar,
        'image': string_scalar,
        'target': string_scalar,
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

#### Extract the images from the tfrecs

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) handed to ``tf.data.TFRecordDataset``.
        labeled: Accepted but not used by this implementation.
        ordered: Accepted but not used by this implementation.

    Returns:
        A ``tf.data.Dataset`` whose elements are produced by ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

#### Random augmentations to the data - a random horizontal flip and a random zoom

In [None]:
# Augmentation pipeline: a random horizontal flip followed by a random zoom
# created with a factor of 0.1.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomZoom(0.1)
])

In [None]:
def gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1, buffer_size=2048):
    """Build the zipped (monet, photo) dataset used for training.

    Decoded images are cached *before* augmentation. Caching after a random
    ``map`` would store the augmented images, so every later pass would replay
    the same augmentations instead of drawing new ones.

    Args:
        monet_files: TFRecord file names for the Monet images.
        photo_files: TFRecord file names for the photo images.
        augment: Optional callable mapped over every image after caching.
        repeat: If true, each dataset is repeated indefinitely.
        shuffle: If true, each dataset is shuffled with ``buffer_size``.
        batch_size: Batch size; incomplete final batches are dropped.
        buffer_size: Shuffle buffer size.

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_batch, photo_batch)`` tuples.
    """

    def _pipeline(files):
        # Cache the decoded (deterministic) images only.
        dataset = load_dataset(files).cache()

        if shuffle:
            dataset = dataset.shuffle(buffer_size)

        if repeat:
            dataset = dataset.repeat()

        # Random transformations go after cache() so they are re-drawn each pass.
        if augment:
            dataset = dataset.map(augment, num_parallel_calls=AUTOTUNE)

        dataset = dataset.batch(batch_size, drop_remainder=True)
        return dataset.prefetch(AUTOTUNE)

    monet_dataset = _pipeline(monet_files)
    photo_dataset = _pipeline(photo_files)

    return tf.data.Dataset.zip((monet_dataset, photo_dataset))

# Load dataset: zipped (monet, photo) batches with data_augmentation applied
data = gan_dataset(MONET_FILENAMES, PHOTO_FILENAMES, augment=data_augmentation)

#### Check some photos from the dataset

In [None]:
def display_samples(dataset, row, col):
    """Plot the first ``row * col`` elements of ``dataset`` on a grid.

    Args:
        dataset: Iterable of batches; the first item of each batch is shown.
        row: Number of grid rows.
        col: Number of grid columns.
    """
    batches = iter(dataset)
    figure_height = int(15*row/col)
    plt.figure(figsize=(15, figure_height))

    for position in range(1, row*col + 1):
        batch = next(batches)
        plt.subplot(row, col, position)
        plt.axis('off')
        # Shift values from [-1, 1] to [0, 1] for display.
        plt.imshow(batch[0] * 0.5 + 0.5)
    plt.show()

In [None]:
# Show Monet examples on a 4 x 4 grid (batches of 1, shuffle buffer of 30)
display_samples(load_dataset(MONET_FILENAMES).shuffle(30).batch(1), 4, 4)

In [None]:
# Show photos on a 4 x 4 grid (batches of 1, shuffle buffer of 30)
display_samples(load_dataset(PHOTO_FILENAMES).shuffle(30).batch(1), 4, 4)

In [None]:
def display_generated_samples(dataset, model, num_samples):
    """Show input images side by side with the model's prediction for them.

    Args:
        dataset: Iterable of input batches.
        model: Object with a ``predict`` method applied to each batch.
        num_samples: Number of batches to draw and display.
    """
    batches = iter(dataset)

    for _ in range(num_samples):
        source_batch = next(batches)
        predicted_batch = model.predict(source_batch)

        # Left panel: first image of the input batch.
        plt.subplot(121)
        plt.title("input image")
        plt.imshow(source_batch[0] * 0.5 + 0.5)
        plt.axis('off')

        # Right panel: first image of the prediction.
        plt.subplot(122)
        plt.title("Generated image")
        plt.imshow(predicted_batch[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()

In [None]:
def predict_and_save(input_dataset, generator_model, output_path):
    """Run the generator on every batch and save the first output as a JPEG.

    Files are named ``<output_path><n>.jpg`` with ``n`` counting from 1.
    Note that a function with this name is defined again further down the
    file; the later definition replaces this one once it runs.

    Args:
        input_dataset: Iterable of input batches.
        generator_model: Callable invoked as ``generator_model(batch, training=False)``.
        output_path: String prefix (directory with trailing separator) for the files.
    """
    for index, batch in enumerate(input_dataset, start=1):
        generated = generator_model(batch, training=False)[0].numpy()
        # Re-scale
        pixels = (generated * 127.5 + 127.5).astype(np.uint8)
        picture = PIL.Image.fromarray(pixels)
        picture.save(f'{output_path}{str(index)}.jpg')

## Build Generator Model

---

#### Resnet blocks- [Link to the paper](https://arxiv.org/abs/1707.04881) (note section 2.2. ResGAN model)
* Two 256 filter 3×3 Convolutional 2D layers
* ` InstanceNormalization(axis = -1)` ensures features are normalized per feature map
* LeakyReLU (alpha = 0.2) activation function after the first layer

In [None]:
def resnet_block(input_layer):
    """Apply two 256-filter 3x3 conv stages and concatenate with the input.

    The first stage is Conv2D -> InstanceNormalization -> LeakyReLU(0.2); the
    second is Conv2D -> InstanceNormalization. The result is concatenated with
    ``input_layer``, so the output has 256 more channels than the input.

    Args:
        input_layer: Input tensor of the block.

    Returns:
        The concatenation of the block output and ``input_layer``.
    """
    weight_init = RandomNormal(stddev=0.02)

    # First stage: conv, normalize, activate
    branch = layers.Conv2D(256, 3, padding='same', kernel_initializer=weight_init)(input_layer)
    branch = InstanceNormalization(axis=-1)(branch)
    branch = layers.LeakyReLU(alpha=0.2)(branch)

    # Second stage: conv, normalize (no activation)
    branch = layers.Conv2D(256, 3, padding='same', kernel_initializer=weight_init)(branch)
    branch = InstanceNormalization(axis=-1)(branch)

    # Skip connection realised as a concatenation
    return Concatenate()([branch, input_layer])

#### Build Generator Model

Unpaired Image-to-Image Translation [Link to paper](https://arxiv.org/abs/1703.10593) (Note section 2 and figure 3)

* Begin with preprocessing, since we only have 300 images in the Monet training set
* The generator contains 9 residual blocks: c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, R256, R256, R256, u128, u64, c7s1-3
* **c7s1-k** - 7×7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1. 
* **dk** - 3 × 3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2.
* **Rk** - resnet_block (from previous cell)
* **uk** - 3 × 3 fractional-strided-Convolution InstanceNorm-ReLU layer with k filters

    > Fractional striding is achieved by Conv2dTranspose [Here is a nice article on the topic from Beeren Sahu](https://beerensahu.wordpress.com/2018/04/10/pytorch-a-fractionally-strided-convolution-or-a-deconvolution/)


In [None]:
def generator():
    """Build the image-to-image generator network.

    Layer sequence: c7s1-64, d128, d256 (with dropout), nine ``resnet_block``
    stages, u128, u64 and a final c7s1-3 convolution with tanh activation.

    Returns:
        A Keras ``Model`` taking a ``[HEIGHT, WIDTH, CHANNELS]`` image and
        producing a 3-channel tanh-activated image.
    """
    inputs = layers.Input(shape=[HEIGHT, WIDTH, CHANNELS])
    init = RandomNormal(mean=0.0, stddev=0.02)

    # c7s1-64
    out = layers.Conv2D(64, 7, padding='same', kernel_initializer=init)(inputs)
    out = InstanceNormalization()(out)
    out = Activation('relu')(out)

    # d128
    out = layers.Conv2D(128, 3, strides=2, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = Activation('relu')(out)

    # d256
    out = layers.Conv2D(256, 3, strides=2, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = layers.Dropout(0.5)(out)
    out = Activation('relu')(out)

    # R256 (9)
    for _ in range(9):
        out = resnet_block(out)

    # u128
    out = layers.Conv2DTranspose(128, 3, strides=2, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = Activation('relu')(out)

    # u64
    out = layers.Conv2DTranspose(64, 3, strides=2, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = Activation('relu')(out)

    # c7s1-3
    # Must consume `out`, not `inputs`: feeding `inputs` here would bypass every
    # layer above and reduce the generator to a single convolution.
    out = layers.Conv2D(3, 7, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = Activation('tanh')(out)

    return Model(inputs, out)

## Build Discriminator Model

---

#### 70 x 70 patchGAN - [Link to the paper](https://arxiv.org/abs/1703.10593) (note section 7.2. Network architectures - Discriminator architectures)

* The discriminator contains 4 blocks C64 - C128 - C256 - C512
* `InstanceNormalization(axis = -1)` ensures features are normalized per feature map

In [None]:
def discriminator():
    """Build the PatchGAN-style discriminator (C64 - C128 - C256 - C512).

    The first three blocks use stride 2; the C512 block and the final 1-filter
    convolution use stride 1. No activation is applied to the last layer, so
    the model outputs raw logits per patch.

    Returns:
        A Keras ``Model`` taking a ``[HEIGHT, WIDTH, CHANNELS]`` image and
        producing a single-channel map of logits.
    """
    init = RandomNormal(mean=0.0, stddev=0.02)
    inputs = layers.Input(shape=[HEIGHT, WIDTH, CHANNELS])

    # C64 (no normalization on the first block)
    out = layers.Conv2D(64, 4, strides=2, padding='same', kernel_initializer=init)(inputs)
    out = layers.LeakyReLU(alpha=0.2)(out)

    # C128
    out = layers.Conv2D(128, 4, strides=2, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = layers.LeakyReLU(alpha=0.2)(out)

    # C256
    out = layers.Conv2D(256, 4, strides=2, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = layers.LeakyReLU(alpha=0.2)(out)

    # C512, 1x1 stride (previously built with 256 filters, contradicting the
    # documented C64 - C128 - C256 - C512 layout)
    out = layers.Conv2D(512, 4, padding='same', kernel_initializer=init)(out)
    out = InstanceNormalization()(out)
    out = layers.LeakyReLU(alpha=0.2)(out)

    # final step 1 filter
    outputs = layers.Conv2D(1, 4, padding='same', kernel_initializer=init)(out)

    return Model(inputs, outputs)

## Loss Functions

---

* Discriminator loss {0: fake, 1: real} (The discriminator loss outputs the average of the real and generated loss)
* Cycle consistency loss (measures how similar the original photo and the twice-transformed photo are to one another)
* Identity loss (compares the image with its generator (i.e. photo with photo generator))

In [None]:
with strategy.scope():

    def discriminator_loss(real, generated):
        """Average the discriminator's loss on real and generated inputs.

        Real outputs are compared against all-ones targets and generated
        outputs against all-zeros targets, using unreduced binary
        cross-entropy on logits.

        Args:
            real: Discriminator logits for real images.
            generated: Discriminator logits for generated images.

        Returns:
            Half the sum of the two unreduced losses.
        """
        bce = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.5
  

In [None]:
with strategy.scope():

    def generator_loss(generated):
        """Unreduced binary cross-entropy of the logits against all-ones targets.

        Args:
            generated: Discriminator logits for generated images.

        Returns:
            The unreduced loss tensor.
        """
        bce = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        wanted_labels = tf.ones_like(generated)
        return bce(wanted_labels, generated)

In [None]:
with strategy.scope():

    def cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute difference of the two images.

        Args:
            real_image: Original image tensor.
            cycled_image: Image after the round trip through both generators.
            LAMBDA: Weight applied to the mean absolute error.
        """
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

In [None]:
with strategy.scope():

    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute difference of the images.

        Args:
            real_image: Original image tensor.
            same_image: Output of the generator for that image's own domain.
            LAMBDA: Weight; half of it is applied to the mean absolute error.
        """
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_error

#### Initialize Optimizers

In [None]:
with strategy.scope(): 
    
    # Generators: Adam with learning rate 2e-4 and beta_1 = 0.5.
    monet_generator_optimizer = optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = optimizers.Adam(2e-4, beta_1=0.5)

    # Discriminators: lower learning rate (1e-4), epsilon = 0.1 and AMSGrad enabled.
    monet_discriminator_optimizer = optimizers.Adam(1e-4, beta_1=0.5, epsilon=0.1, amsgrad=True)
    photo_discriminator_optimizer = optimizers.Adam(1e-4, beta_1=0.5, epsilon=0.1, amsgrad=True)

## Build Composite Model

---


In [None]:
# Create the four networks inside the distribution strategy scope.
with strategy.scope():
    # One generator per target domain.
    monet_generator = generator()
    photo_generator = generator()

    # One discriminator per domain.
    monet_discriminator = discriminator()
    photo_discriminator = discriminator()

class CycleGan(Model):
    """Keras model bundling two generators and two discriminators.

    The custom ``train_step`` updates all four networks from one
    ``(monet, photo)`` batch using adversarial, cycle-consistency and
    identity losses.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        """Store the four networks and the cycle-loss weight."""
        super().__init__()
        self.lambda_cycle = lambda_cycle
        self.monet_generator = monet_generator
        self.photo_generator = photo_generator
        self.monet_discriminator = monet_discriminator
        self.photo_discriminator = photo_discriminator

    def compile(
        self,
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Register one optimizer per network and the four loss functions."""
        super().compile()
        # Optimizers
        self.monet_generator_optimizer = monet_generator_optimizer
        self.photo_generator_optimizer = photo_generator_optimizer
        self.monet_discriminator_optimizer = monet_discriminator_optimizer
        self.photo_discriminator_optimizer = photo_discriminator_optimizer
        # Loss callables
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    @tf.function
    def train_step(self, batch_data):
        """Run one optimisation step on a ``(real_monet, real_photo)`` batch.

        Returns:
            Dict with the total loss of each generator and the loss of each
            discriminator.
        """
        real_monet, real_photo = batch_data
        weight = self.lambda_cycle

        # The tape is persistent because four separate gradients are taken from it.
        with tf.GradientTape(persistent=True) as tape:
            # photo -> monet -> photo
            fake_monet = self.monet_generator(real_photo, training=True)
            cycled_photo = self.photo_generator(fake_monet, training=True)

            # monet -> photo -> monet
            fake_photo = self.photo_generator(real_monet, training=True)
            cycled_monet = self.monet_generator(fake_photo, training=True)

            # each generator applied to an image of its own target domain
            same_monet = self.monet_generator(real_monet, training=True)
            same_photo = self.photo_generator(real_photo, training=True)

            # discriminator outputs for real images
            disc_real_monet = self.monet_discriminator(real_monet, training=True)
            disc_real_photo = self.photo_discriminator(real_photo, training=True)

            # discriminator outputs for generated images
            disc_fake_monet = self.monet_discriminator(fake_monet, training=True)
            disc_fake_photo = self.photo_discriminator(fake_photo, training=True)

            # adversarial generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # cycle consistency, summed over both directions
            monet_cycle = self.cycle_loss_fn(real_monet, cycled_monet, weight)
            photo_cycle = self.cycle_loss_fn(real_photo, cycled_photo, weight)
            total_cycle_loss = monet_cycle + photo_cycle

            # identity terms
            monet_identity = self.identity_loss_fn(real_monet, same_monet, weight)
            photo_identity = self.identity_loss_fn(real_photo, same_photo, weight)

            # total generator losses
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + monet_identity
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + photo_identity

            # discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # (loss, network, optimizer) triples, in update order
        update_plan = (
            (total_monet_gen_loss, self.monet_generator, self.monet_generator_optimizer),
            (total_photo_gen_loss, self.photo_generator, self.photo_generator_optimizer),
            (monet_disc_loss, self.monet_discriminator, self.monet_discriminator_optimizer),
            (photo_disc_loss, self.photo_discriminator, self.photo_discriminator_optimizer),
        )

        # Compute every gradient first ...
        all_gradients = [
            tape.gradient(loss, network.trainable_variables)
            for loss, network, _ in update_plan
        ]

        # ... then apply them, so no update precedes a gradient computation.
        for gradients, (_, network, optimizer) in zip(all_gradients, update_plan):
            optimizer.apply_gradients(zip(gradients, network.trainable_variables))

        return {
            'Monet generator loss': total_monet_gen_loss,
            'Photo generator loss': total_photo_gen_loss,
            'Monet discriminator loss': monet_disc_loss,
            'Photo discriminator loss': photo_disc_loss
        }

## Helper Functions

---


In [None]:
def predict_and_save(input_dataset, generator_model, output_path):
    """Run the generator on every batch and save the first output as a JPEG.

    Files are named ``<output_path><n>.jpg`` with ``n`` counting from 1.

    Args:
        input_dataset: Iterable of input batches.
        generator_model: Callable invoked as ``generator_model(batch, training=False)``.
        output_path: String prefix (directory with trailing separator) for the files.
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it
    # explicitly instead of relying on another library having loaded it.
    from PIL import Image

    for index, batch in enumerate(input_dataset, start=1):
        prediction = generator_model(batch, training=False)[0].numpy()
        # Map [-1, 1] to [0, 255]; clip first so the uint8 cast can never wrap
        # around on values that fall marginally outside the range.
        pixels = np.clip(prediction * 127.5 + 127.5, 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(f'{output_path}{index}.jpg')

#### Updating:
- image buffer for fake images reduces model oscillation
- discriminators using a history of generated images rather than the ones produced by the latest generators

[Link to the paper](https://arxiv.org/abs/1703.10593) (note section 4. Implementation - Training details)

In [None]:
# def update_image_buffer(buffer_images, images, buffer_max =50):
    
#     output_images = list()
#     for image in images:
        
#         # Fill the buffer
#         if len(buffer) < buffer_max:
            
#             buffer.append(image)
#             output_images.append(image)
            
#         # Use image, don't add to buffer
#         elif random() < 0.5:
#             output_images.append(image)
            
#         # Replace an existing image and use replaced image
#         else:
#             index = randint(0, len(buffer))
#             output_images.append(buffer[index])
#             buffer[index] = image
            
#     return tf.stack(selected, axis=0)

## Compile GAN

---


In [None]:
# Assemble the CycleGan wrapper and register optimizers and loss functions.
with strategy.scope():
    
    # lambda_cycle is left at the CycleGan default.
    gan_model = CycleGan(monet_generator, photo_generator, monet_discriminator, photo_discriminator)

    gan_model.compile(
        monet_generator_optimizer = monet_generator_optimizer,
        photo_generator_optimizer = photo_generator_optimizer,
        monet_discriminator_optimizer = monet_discriminator_optimizer,
        photo_discriminator_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = cycle_loss,
        identity_loss_fn = identity_loss,
    )

## Train GAN

---


In [None]:
# Train; the dataset repeats indefinitely, so steps_per_epoch bounds each epoch.
gan_model.fit(data, epochs=EPOCHS, steps_per_epoch=(max(n_monet_samples , n_photo_samples )//5), verbose=2)

# CycleGan defines no call() and its custom train_step never invokes the wrapper
# itself, so Model.save() on gan_model has no forward pass to export. Save the
# four functional sub-models instead. The io_device option makes the write go
# through the local host, which is required when running under a TPU strategy.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
for network_name, network in (
    ('monet_generator', monet_generator),
    ('photo_generator', photo_generator),
    ('monet_discriminator', monet_discriminator),
    ('photo_discriminator', photo_discriminator),
):
    network.save(f'./cyclegan_model/{network_name}', options=save_options)

## Visualize predictions

In [None]:
# Show 4 photos next to the Monet-style image produced by monet_generator
display_generated_samples(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, 4)

In [None]:
# plot the image, its translation, and the reconstruction
def show_plot(imagesX, imagesY1, imagesY2):
    """Plot a real image, its translation and its reconstruction in one row.

    The three inputs are stacked along the first axis and titled
    'Real', 'Generated' and 'Reconstructed' in that order, so each input is
    expected to contribute one image.

    Args:
        imagesX: Batch holding the real image.
        imagesY1: Batch holding the generated image.
        imagesY2: Batch holding the reconstructed image.
    """
    # `vstack` and `pyplot` were never imported under those names; use the
    # `np` and `plt` aliases the file already imports.
    images = np.vstack((imagesX, imagesY1, imagesY2))
    titles = ['Real', 'Generated', 'Reconstructed']

    # Rescale from [-1,1] to [0,1]
    images = (images + 1) / 2.0

    # plot images side by side
    for i, image in enumerate(images):
        plt.subplot(1, len(images), 1 + i)
        plt.axis('off')
        plt.imshow(image)
        plt.title(titles[i])

    plt.show()

## Make predictions

In [None]:
# Create folder to save generated images; exist_ok avoids a FileExistsError
# when this cell is run more than once.
os.makedirs('../images/', exist_ok=True)

# Translate every photo with the Monet generator and write the results as JPEGs.
with strategy.scope():
    predict_and_save(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, '../images/')

## Submission file

In [None]:
# base_name must not end with a slash: make_archive appends '.zip' to it, so
# '/kaggle/working/images/' would create '/kaggle/working/images/.zip' rather
# than '/kaggle/working/images.zip'.
shutil.make_archive('/kaggle/working/images', 'zip', '../images')

# Count the regular files that were written to the output folder.
generated_files = [
    name for name in os.listdir('../images/')
    if os.path.isfile(os.path.join('../images/', name))
]
print(f"Generated samples: {len(generated_files)}")