In [None]:
import os, random, json, PIL, shutil, re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
from tensorflow.keras import Model, losses, optimizers
from tensorflow import keras
from tensorflow.keras import layers

## Setting up TPU

In [None]:
def _resolve_strategy():
    """Return ``(tpu_resolver, strategy)`` for the current runtime.

    When a TPU cluster can be resolved it is connected to, initialised and wrapped
    in a TPU strategy. Otherwise ``(None, default strategy)`` is returned.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        print(f'Running on TPU {resolver.master()}')
    except ValueError:
        # No TPU reachable: fall back to the default distribution strategy.
        return None, tf.distribute.get_strategy()

    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return resolver, tf.distribute.experimental.TPUStrategy(resolver)


tpu, strategy = _resolve_strategy()

REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')
AUTO = tf.data.experimental.AUTOTUNE

## Model and Training Parameters

In [None]:
# Image geometry; decode_image reshapes every decoded JPEG to exactly this size.
HEIGHT, WIDTH, CHANNELS = 256, 256, 3

# Training schedule.
EPOCHS = 50
BATCH_SIZE = 1

# Both initialisers draw from N(0, 0.02): `initializer` is used for convolution
# kernels and `gamma_init` for the InstanceNormalization scale parameters.
initializer = tf.random_normal_initializer(0., 0.02)
gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

In [None]:
PATH = KaggleDatasets().get_gcs_path()

# One glob per domain; shard names carry their image count (parsed by count_files).
MONET_FILENAMES, PHOTO_FILENAMES = (
    tf.io.gfile.glob(PATH + pattern)
    for pattern in ('/monet_tfrec/*.tfrec', '/photo_tfrec/*.tfrec')
)

def count_files(filenames):
    """Return the total number of images across a list of TFRecord shards.

    Each shard's basename is expected to contain ``-<count>.`` (a name like
    ``monet00-60.tfrec``), and the counts are summed.

    Args:
        filenames: iterable of shard paths (local or ``gs://``).

    Returns:
        The summed image count as a Python ``int``; ``0`` when ``filenames`` is empty.
        (``np.sum([])`` would return the float ``0.0``, which breaks integer uses
        such as ``steps_per_epoch``.)

    Raises:
        ValueError: if a basename contains no ``-<digits>.`` count.
    """
    pattern = re.compile(r"-(\d+)\.")
    total = 0
    for filename in filenames:
        # Search only the basename so hyphenated bucket/directory names are never matched.
        match = pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f'Cannot read an image count from TFRecord filename: {filename!r}')
        total += int(match.group(1))
    return total

n_monet_samples = count_files(MONET_FILENAMES)
n_photo_samples = count_files(PHOTO_FILENAMES)

# Shard and image counts per domain, printed one per line.
summary_lines = [
    f'Monet TFRecord files: {len(MONET_FILENAMES)}',
    f'Monet image files: {n_monet_samples}',
    f'Photo TFRecord files: {len(PHOTO_FILENAMES)}',
    f'Photo image files: {n_photo_samples}',
]
for summary_line in summary_lines:
    print(summary_line)

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 tensor scaled to [-1, 1].

    Args:
        image: scalar string tensor holding JPEG-encoded bytes.

    Returns:
        Tensor of shape ``[HEIGHT, WIDTH, CHANNELS]``. The reshape fails if the
        decoded image has a different size.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=CHANNELS), tf.float32)
    # [0, 255] -> [-1, 1], matching the generator's tanh output range.
    scaled = pixels / 127.5 - 1
    return tf.reshape(scaled, [HEIGHT, WIDTH, CHANNELS])


def read_tfrecord(example):
    """Parse one serialized ``tf.train.Example`` and return its decoded image.

    The record must contain the string features ``image_name``, ``image`` and
    ``target``. Only ``image`` is used; it is decoded with ``decode_image``.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('image_name', 'image', 'target')
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames):
    """Stream TFRecord shards as a ``tf.data.Dataset`` of decoded images.

    Records are parsed in parallel (``num_parallel_calls=AUTO``).
    """
    return tf.data.TFRecordDataset(filenames).map(read_tfrecord, num_parallel_calls=AUTO)

def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the zipped ``(monet_batch, photo_batch)`` dataset used to train the CycleGAN.

    Args:
        monet_files: TFRecord shard paths for the Monet domain.
        photo_files: TFRecord shard paths for the photo domain.
        augment: optional callable applied to each decoded image. It runs after
            caching, so augmentations are re-drawn on every pass. ``None`` disables it.
        repeat: if True, each domain repeats indefinitely.
        shuffle: if True, each domain is shuffled with a 2048-element buffer.
        batch_size: images per batch; an incomplete final batch is dropped.

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_batch, photo_batch)`` pairs.
    """
    def _prepare(files):
        # Cache the finite, decoded dataset *before* repeat/shuffle. Caching after
        # `repeat()` tries to store an infinite stream (unbounded memory growth), and
        # caching after `shuffle()` would freeze a single shuffle order.
        # NOTE(review): this keeps every decoded float32 image in memory — confirm it fits.
        ds = load_dataset(files).cache()
        if augment is not None:
            ds = ds.map(augment, num_parallel_calls=AUTO)
        if repeat:
            ds = ds.repeat()
        if shuffle:
            ds = ds.shuffle(2048)
        return ds.batch(batch_size, drop_remainder=True).prefetch(AUTO)

    return tf.data.Dataset.zip((_prepare(monet_files), _prepare(photo_files)))



In [None]:
def display_samples(dataset, row, col):
    """Plot the first ``row * col`` elements of ``dataset`` in a ``row x col`` grid.

    Each element is indexed with ``[0]`` and mapped from [-1, 1] to [0, 1].
    This presumably expects a batched dataset and shows the first image of
    each batch — confirm against callers.
    """
    sample_iterator = iter(dataset)
    plt.figure(figsize=(15, int(15 * row / col)))
    for position in range(1, row * col + 1):
        sample = next(sample_iterator)
        plt.subplot(row, col, position)
        plt.axis('off')
        plt.imshow(sample[0] * 0.5 + 0.5)
    plt.show()
        
def display_generated_samples(dataset, model, no_of_images):
    """Show up to ``no_of_images`` inputs side by side with ``model``'s translation.

    Args:
        dataset: iterable of image batches in [-1, 1]. Only the first image of
            each batch is plotted.
        model: generator called as ``model(batch, training=False)``.
        no_of_images: maximum number of figures to draw. Stops early (instead of
            raising ``StopIteration``) if ``dataset`` is shorter.
    """
    # `range` comes first in zip so no extra element is pulled from `dataset` once
    # the limit is reached.
    for _, input_batch in zip(range(no_of_images), dataset):
        # Direct call in inference mode; avoids the per-call setup cost of `predict`.
        generated_batch = model(input_batch, training=False).numpy()

        # Two explicit Axes. The previous `plt.subplots()` + `plt.subplot(121)`
        # left a stray full-figure Axes underneath the two panels.
        fig, (ax_input, ax_generated) = plt.subplots(1, 2, figsize=(15, 10))

        ax_input.set_title("Input image")
        ax_input.imshow(input_batch[0] * 0.5 + 0.5)
        ax_input.axis('off')

        ax_generated.set_title("Generated image")
        ax_generated.imshow(generated_batch[0] * 0.5 + 0.5)
        ax_generated.axis('off')

        plt.show()
        plt.close(fig)  # free the figure; this runs once per image
        
def predict_and_save(input_ds, generator_model, output_path):
    """Run ``generator_model`` over ``input_ds`` and save every output image as a JPEG.

    Files are written to ``f'{output_path}{i}.jpg'`` with ``i`` counting from 1
    across all batches. ``output_path`` is therefore a filename prefix: include a
    trailing ``/`` to target a directory.

    Args:
        input_ds: iterable of image batches in [-1, 1] (e.g. a batched ``tf.data.Dataset``).
        generator_model: callable taking ``(batch, training=False)`` and returning a
            tensor of images in [-1, 1].
        output_path: prefix prepended to each numbered filename.
    """
    # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded.
    from PIL import Image

    image_number = 1
    for batch in input_ds:
        predictions = generator_model(batch, training=False).numpy()
        # [-1, 1] -> [0, 255]; clip guards against float round-off before the uint8 cast.
        pixels = np.clip(predictions * 127.5 + 127.5, 0, 255).astype(np.uint8)
        # Save every image in the batch, not just the first one.
        for image in pixels:
            Image.fromarray(image).save(f'{output_path}{image_number}.jpg')
            image_number += 1

In [None]:
# Model functions
def downsample(filters, size, apply_instancenorm = True, strides=2):
    """Encoder block: Conv2D -> optional InstanceNormalization -> LeakyReLU.

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_instancenorm: add instance normalisation only when this equals ``True``.
        strides: convolution stride (``padding='same'``, so 2 halves height and width).

    Returns:
        A ``keras.Sequential`` block.
    """
    stack = [
        L.Conv2D(filters, size, strides=strides, padding='same',
                 kernel_initializer=initializer, use_bias=False),
    ]
    if apply_instancenorm == True:
        stack.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    stack.append(L.LeakyReLU())
    return keras.Sequential(stack)


In [None]:
def upsample(filters, size, dropout=False, strides=2, apply_dropout=False):
    """Decoder block: Conv2DTranspose -> InstanceNormalization -> optional Dropout -> ReLU.

    Args:
        filters: number of output channels.
        size: transposed-convolution kernel size.
        dropout: if True, insert ``Dropout(0.5)`` after normalisation.
        strides: transposed-convolution stride (``padding='same'``, so 2 doubles
            height and width).
        apply_dropout: alias for ``dropout``. It is accepted so callers using that
            keyword keep working; dropout is added if either flag is truthy.

    Returns:
        A ``keras.Sequential`` block.
    """
    result = keras.Sequential()
    result.add(L.Conv2DTranspose(filters, size, strides=strides, padding='same',
                                 kernel_initializer=initializer, use_bias=False))
    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    if dropout or apply_dropout:
        result.add(L.Dropout(0.5))
    result.add(L.ReLU())

    return result

In [None]:
OUTPUT_CHANNELS = 3

def create_generator():
    """Build the U-Net generator used for both translation directions.

    The encoder applies eight stride-2 ``downsample`` blocks (256 -> 1 for the
    default 256x256 input). The decoder applies one ``upsample`` block per skip
    connection. Each upsampled map is concatenated with the encoder activation of
    the same resolution. A final stride-2 ``Conv2DTranspose`` with ``tanh``
    restores full resolution with ``OUTPUT_CHANNELS`` channels in [-1, 1].

    Returns:
        A ``tf.keras.Model`` mapping ``(bs, HEIGHT, WIDTH, CHANNELS)`` images to
        ``(bs, HEIGHT, WIDTH, OUTPUT_CHANNELS)`` images.
    """
    inputs = L.Input(shape=[HEIGHT, WIDTH, CHANNELS])

    # Encoder: (bs, 256, 256, C) -> (bs, 1, 1, 512). The first block skips instance norm.
    down_stack = [downsample(64, 4, apply_instancenorm=False)]
    down_stack += [downsample(filters, 4) for filters in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder mirrors the encoder, minus its bottleneck: (bs, 1, 1, 512) -> (bs, 128, 128, 64 + 64).
    # Dropout on the three deepest blocks.
    up_stack = [upsample(filters, 4, dropout=True) for filters in (512, 512, 512)]
    up_stack += [upsample(filters, 4) for filters in (512, 256, 128, 64)]

    # Last block: stride-2 transposed conv with `tanh` -> (bs, 256, 256, 3).
    final_layer = L.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides=2, padding='same',
                                    kernel_initializer=initializer, activation='tanh')

    x = inputs

    # Downsampling, remembering every activation for the skip connections.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The deepest activation is the bottleneck itself, not a skip. Pair the rest
    # deepest-first with the decoder blocks. A length mismatch would make `zip`
    # silently drop decoder blocks.
    skips = list(reversed(skips[:-1]))
    assert len(up_stack) == len(skips), 'decoder depth must match the number of skip connections'

    # Upsampling and establishing the skip connections.
    # https://www.kaggle.com/jesperdramsch/understanding-and-improving-cyclegans-tutorial -- Highly recommended to understand skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = L.Concatenate()([x, skip])

    x = final_layer(x)

    return Model(inputs=inputs, outputs=x)

In [None]:
def create_discriminator():
    """Build the patch discriminator as a ``keras.Sequential`` model.

    The model maps a ``(bs, HEIGHT, WIDTH, CHANNELS)`` image to a single-channel
    map of raw (un-activated) scores. For the default 256x256 input this map is
    ``(bs, 30, 30, 1)``.
    """
    return keras.Sequential([
        L.Input(shape=[HEIGHT, WIDTH, CHANNELS], name='input_image'),
        downsample(64, 4, apply_instancenorm=False),                                    # (bs, 128, 128, 64)
        downsample(128, 4),                                                              # (bs, 64, 64, 128)
        downsample(256, 4),                                                              # (bs, 32, 32, 256)
        L.ZeroPadding2D(),                                                               # (bs, 34, 34, 256)
        L.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False),     # (bs, 31, 31, 512)
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
        L.LeakyReLU(),
        L.ZeroPadding2D(),                                                               # (bs, 33, 33, 512)
        L.Conv2D(1, 4, strides=1, kernel_initializer=initializer),                       # (bs, 30, 30, 1)
    ])

In [None]:
with strategy.scope():
    # Generators, built under the strategy so their variables are distributed.
    monet_G = create_generator()   # photo -> Monet-esque painting
    photo_G = create_generator()   # Monet painting -> photo

    # Discriminators: real vs. generated Monet paintings, then real vs. generated photos.
    monet_D, photo_D = create_discriminator(), create_discriminator()


class CycleGan(Model):
    """Keras ``Model`` that trains a CycleGAN (two generators, two discriminators) via ``fit``.

    Args:
        monet_G: generator translating photos into Monet-style images.
        photo_G: generator translating Monet images into photos.
        monet_D: discriminator scoring Monet images (real vs. generated).
        photo_D: discriminator scoring photos (real vs. generated).
        lambda_cycle: weight applied to the cycle-consistency loss.
        lambda_identity: identity-loss weight, expressed as a fraction of ``lambda_cycle``.
    """

    def __init__(self, monet_G, photo_G, monet_D, photo_D, lambda_cycle = 10 ,lambda_identity=0.5):
        super(CycleGan, self).__init__()
        self.m_generator = monet_G
        self.p_generator = photo_G
        self.m_disc = monet_D
        self.p_disc = photo_D
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(self, m_gen_optimizer, p_gen_optimizer, m_disc_optimizer, p_disc_optimizer,
                generator_loss_function, discriminator_loss_function,
                cycle_loss_function,
                identity_loss_function
               ):
        """Store the four optimizers and four loss callables used by ``train_step``.

        The base ``Model.compile`` is called with no arguments, so Keras itself
        tracks no loss or optimizer.
        """
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.generator_loss_function = generator_loss_function
        self.discriminator_loss_function = discriminator_loss_function
        self.cycle_loss_function = cycle_loss_function
        self.identity_loss_function = identity_loss_function

    def train_step(self, batch_data):
        """Run one update of all four networks on a ``(monet_batch, photo_batch)`` pair.

        Returns:
            Dict with the total generator losses and the discriminator losses.
        """
        real_monet, real_photo = batch_data

        # Persistent: four separate gradients are taken from this one tape.
        with tf.GradientTape(persistent=True) as tape:
            # photo -> monet -> photo
            fake_monet = self.m_generator(real_photo, training=True)
            cycled_photo = self.p_generator(fake_monet, training=True)

            # monet -> photo -> monet
            fake_photo = self.p_generator(real_monet, training=True)
            cycled_monet = self.m_generator(fake_photo, training=True)

            # Identity mappings: each generator applied to its own target domain.
            same_monet = self.m_generator(real_monet, training=True)
            same_photo = self.p_generator(real_photo, training=True)

            # Discriminator scores for real images...
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # ...and for generated images.
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Adversarial generator losses.
            monet_gen_loss = self.generator_loss_function(disc_fake_monet)
            photo_gen_loss = self.generator_loss_function(disc_fake_photo)

            # Both generators take part in both cycles (each is the inner mapping of
            # one cycle and the outer mapping of the other). Each must therefore
            # minimise the *sum* of the two cycle losses; one direction per generator
            # would drop half of the cycle-consistency gradient.
            total_cycle_loss = (
                self.cycle_loss_function(real_monet, cycled_monet)
                + self.cycle_loss_function(real_photo, cycled_photo)
            ) * self.lambda_cycle

            identity_weight = self.lambda_cycle * self.lambda_identity

            # Total generator losses: adversarial + cycle + identity.
            total_monet_generator_loss = (
                monet_gen_loss
                + total_cycle_loss
                + self.identity_loss_function(real_monet, same_monet) * identity_weight
            )
            total_photo_generator_loss = (
                photo_gen_loss
                + total_cycle_loss
                + self.identity_loss_function(real_photo, same_photo) * identity_weight
            )

            # Discriminator losses.
            monet_discriminator_loss = self.discriminator_loss_function(disc_real_monet, disc_fake_monet)
            photo_discriminator_loss = self.discriminator_loss_function(disc_real_photo, disc_fake_photo)

        # Compute every gradient before applying any update.
        monet_generator_gradients = tape.gradient(total_monet_generator_loss, self.m_generator.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_generator_loss, self.p_generator.trainable_variables)
        monet_discriminator_gradients = tape.gradient(monet_discriminator_loss, self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_discriminator_loss, self.p_disc.trainable_variables)
        del tape  # release the resources held by the persistent tape

        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients, self.m_generator.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients, self.p_generator.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients, self.p_disc.trainable_variables))

        return {
            'monet_generator_loss': total_monet_generator_loss,
            'photo_generator_loss': total_photo_generator_loss,
            'monet_discriminator_loss': monet_discriminator_loss,
            'photo_discriminator_loss': photo_discriminator_loss
        }

In [None]:
with strategy.scope():
    # The discriminators end in a 1-channel Conv2D with no activation, i.e. they emit
    # raw logits. The adversarial loss must therefore be binary cross-entropy on logits.
    # (CategoricalCrossentropy over a single class normalises every prediction to 1,
    # so the loss is identically 0 and the discriminators would never learn.)
    adversarial_loss_function = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)

    def calculate_discriminator_loss(real_image, generated_image):
        """Discriminator loss: real logits toward 1, generated logits toward 0, averaged.

        Args:
            real_image: discriminator logits for real images.
            generated_image: discriminator logits for generated images.
        """
        real_loss = adversarial_loss_function(tf.ones_like(real_image), real_image)
        generated_loss = adversarial_loss_function(tf.zeros_like(generated_image), generated_image)
        return (real_loss + generated_loss) * 0.5

    def calculate_generator_loss(generated_image):
        """Generator adversarial loss: push the discriminator's logits on fakes toward 1."""
        return adversarial_loss_function(tf.ones_like(generated_image), generated_image)

    def calculate_cycle_loss(real_image, cycled_image):
        """Cycle-consistency loss: mean absolute error between an image and its round trip."""
        return tf.reduce_mean(tf.abs(real_image - cycled_image))

    def calculate_identity_loss(real_image, same_image):
        """Identity loss: mean absolute error between an image and its own generator's output."""
        return tf.reduce_mean(tf.abs(real_image - same_image))

In [None]:
# Adam hyper-parameters shared by all four networks.
LEARNING_RATE = 2e-4
BETA_1 = 0.5

with strategy.scope():
    # One optimizer per network, created under the strategy scope so its variables
    # are distributed like the model's. A single shared instance would advance its
    # step counter four times per train step, skewing Adam's bias correction.
    m_gen_optimizer, p_gen_optimizer, m_disc_optimizer, p_disc_optimizer = (
        optimizers.Adam(LEARNING_RATE, beta_1=BETA_1) for _ in range(4)
    )
    # Keep the previous module-level name bound.
    Adam_optimizer = m_gen_optimizer

    # Create GAN
    gan_model = CycleGan(monet_G, photo_G, monet_D, photo_D)

    gan_model.compile(m_gen_optimizer=m_gen_optimizer,
                      p_gen_optimizer=p_gen_optimizer,
                      m_disc_optimizer=m_disc_optimizer,
                      p_disc_optimizer=p_disc_optimizer,
                      generator_loss_function=calculate_generator_loss,
                      discriminator_loss_function=calculate_discriminator_loss,
                      cycle_loss_function=calculate_cycle_loss,
                      identity_loss_function=calculate_identity_loss
                     )

# NOTE(review): the config cell sets EPOCHS = 50, but this trains for one epoch — confirm which is intended.
history = gan_model.fit(get_gan_dataset(MONET_FILENAMES, PHOTO_FILENAMES, batch_size=BATCH_SIZE),
                        steps_per_epoch=int(n_monet_samples // BATCH_SIZE),
                        epochs=1,
                        verbose=1).history

In [None]:
# Preview Monet-style translations of the first ten photos (batch size 1, one figure per photo).
photo_preview_ds = load_dataset(PHOTO_FILENAMES).batch(1)
display_generated_samples(photo_preview_ds, monet_G, no_of_images=10)
