# **Imports**

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from keras import Model, layers, losses, optimizers, initializers
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Sequential

# **Load Data**

In [None]:
# TFRecord shards for each domain, matched by glob pattern.
MONET_FILENAMES = tf.io.gfile.glob(
    '/kaggle/input/gan-getting-started/monet_tfrec/monet*.tfrec'
)
PHOTO_FILENAMES = tf.io.gfile.glob(
    '/kaggle/input/gan-getting-started/photo_tfrec/photo*.tfrec'
)

# Sanity check: how many shard files each pattern matched (0 means a bad path).
print('Monet Number Of Files:', len(MONET_FILENAMES))
print('Photo Number Of Files:', len(PHOTO_FILENAMES))

In [None]:
# Sentinel telling tf.data to choose the parallelism level at runtime.
AUTOTUNE = tf.data.AUTOTUNE
# Spatial size (height, width) every decoded image is reshaped to.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB image scaled to [-1, 1].

    Args:
        image: scalar string tensor containing JPEG-encoded bytes.

    Returns:
        A float32 tensor of shape ``[*IMAGE_SIZE, 3]``. The reshape raises
        if the decoded JPEG does not have exactly that many elements.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    # uint8 [0, 255] -> [-1, 1], the same range as the generator's tanh output.
    scaled = pixels / 127.5 - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized example and return its decoded image.

    The record is parsed with three scalar string features; only ``image``
    is used, and it is passed through ``decode_image``.
    """
    string_feature = lambda: tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        "image_name": string_feature(),
        "image": string_feature(),
        "target": string_feature(),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames):
    """Return a dataset of decoded images read from the given TFRecord files.

    Records are parsed in parallel (``num_parallel_calls=AUTOTUNE``). The
    dataset is neither shuffled nor batched here.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
# One image per batch for each domain (Monet first, then photos).
monet_ds, photo_ds = (
    load_dataset(shards).batch(1)
    for shards in (MONET_FILENAMES, PHOTO_FILENAMES)
)

# **Visualize sample**

In [None]:
# First batch of each dataset, kept for the previews below.
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))

# Images are stored in [-1, 1]; "* 0.5 + 0.5" maps them to [0, 1] for imshow.
plt.subplot(1, 2, 1)
plt.title('Photo')
plt.imshow(example_photo[0] * 0.5 + 0.5)

plt.subplot(1, 2, 2)
plt.title('Monet')
plt.imshow(example_monet[0] * 0.5 + 0.5)

# **Downsample & Upsample**

In [None]:
OUTPUT_CHANNELS: int = 3  # channels produced by the generator's final layer (RGB)

In [None]:
def downsample(filters, size):
    """Build a stride-2 convolution block that halves spatial resolution.

    Conv2D ('same' padding, no bias, N(0, 0.03) kernel init) followed by a
    LeakyReLU. No normalization layer is used.

    Args:
        filters: number of output channels.
        size: convolution kernel size.

    Returns:
        A ``Sequential`` model wrapping the two layers.
    """
    conv = Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.03),
        use_bias=False,
    )
    return Sequential([conv, tf.keras.layers.LeakyReLU()])

In [None]:
def upsample(filters, size):
    """Build a stride-2 transposed-convolution block that doubles resolution.

    A single Conv2DTranspose ('same' padding, ReLU activation, no bias,
    N(0, 0.03) kernel init).

    Args:
        filters: number of output channels.
        size: kernel size.

    Returns:
        A ``Sequential`` model wrapping the layer.
    """
    deconv = Conv2DTranspose(
        filters,
        size,
        strides=2,
        padding='same',
        activation='relu',
        kernel_initializer=tf.random_normal_initializer(0., 0.03),
        use_bias=False,
    )
    return Sequential([deconv])

# **Create generator & discriminator**

In [None]:
def Generator():
    """Build a U-Net style image-to-image generator.

    Eight stride-2 ``downsample`` blocks take a (256, 256, 3) input down to a
    (1, 1, 1024) bottleneck. Seven ``upsample`` blocks bring it back to
    128x128, and after each one the encoder activation of matching resolution
    is concatenated (skip connection). A final stride-2 transposed conv with
    ``tanh`` yields a (256, 256, OUTPUT_CHANNELS) image in [-1, 1].

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = layers.Input(shape=[256, 256, 3])

    # Encoder; comments give each block's output shape (batch dim omitted).
    down_stack = [
        downsample(64, 4),    # (128, 128, 64)
        downsample(128, 4),   # (64, 64, 128)
        downsample(256, 4),   # (32, 32, 256)
        downsample(512, 4),   # (16, 16, 512)
        downsample(512, 4),   # (8, 8, 512)
        downsample(512, 4),   # (4, 4, 512)
        downsample(512, 4),   # (2, 2, 512)
        downsample(1024, 4),  # (1, 1, 1024)  bottleneck
    ]

    # Decoder; comments give the shape AFTER concatenating the skip tensor
    # (upsample channels + skip channels).
    up_stack = [
        upsample(1024, 4),  # (2, 2, 1024 + 512 = 1536)
        upsample(512, 4),   # (4, 4, 512 + 512 = 1024)
        upsample(512, 4),   # (8, 8, 512 + 512 = 1024)
        upsample(512, 4),   # (16, 16, 512 + 512 = 1024)
        upsample(256, 4),   # (32, 32, 256 + 256 = 512)
        upsample(128, 4),   # (64, 64, 128 + 128 = 256)
        upsample(64, 4),    # (128, 128, 64 + 64 = 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.03)
    # Final 2x upsampling to full resolution; tanh keeps outputs in [-1, 1].
    last = Conv2DTranspose(
        OUTPUT_CHANNELS, 4, strides=2, padding='same',
        kernel_initializer=initializer, activation='tanh',
    )  # (256, 256, OUTPUT_CHANNELS)

    x = inputs

    # Encoder pass, remembering every activation for the skip connections.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck has no decoder partner; pair the remaining activations
    # deepest-first with the decoder blocks (7 skips for 7 upsample blocks).
    skips = reversed(skips[:-1])

    # Decoder pass with skip connections.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return Model(inputs=inputs, outputs=x)

In [None]:
def Discriminator():
    """Build a convolutional discriminator that scores image patches.

    Maps a (256, 256, 3) image to a (30, 30, 1) grid of unnormalized scores
    (logits): the last conv has no activation, and the losses in this
    notebook use ``from_logits=True``.

    Returns:
        An uncompiled ``keras.Model``.
    """
    init = tf.random_normal_initializer(0., 0.03)

    image = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three stride-2 blocks: 256 -> 128 -> 64 -> 32 spatially.
    features = image
    for width in (64, 128, 256):
        features = downsample(width, 4)(features)          # ends at (32, 32, 256)

    # Two stride-1 'valid' convolutions, each preceded by 1-pixel zero padding.
    features = layers.ZeroPadding2D()(features)             # (34, 34, 256)
    features = Conv2D(
        512, 4, strides=1, kernel_initializer=init, use_bias=False
    )(features)                                             # (31, 31, 512)
    features = layers.LeakyReLU()(features)
    features = layers.ZeroPadding2D()(features)             # (33, 33, 512)
    logits = layers.Conv2D(
        1, 4, strides=1, kernel_initializer=init
    )(features)                                             # (30, 30, 1)

    return Model(inputs=image, outputs=logits)

In [None]:
# Generators: photo -> Monet-style painting, and Monet painting -> photo-style image.
monet_generator, photo_generator = Generator(), Generator()

# Discriminators: real vs generated Monet paintings, and real vs generated photos.
monet_discriminator, photo_discriminator = Discriminator(), Discriminator()

In [None]:
# Preview: run the Monet generator on the sample photo.
to_monet = monet_generator(example_photo)

# Both tensors are in [-1, 1]; rescale to [0, 1] for display.
plt.subplot(121)
plt.title("Original Photo")
plt.imshow(example_photo[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title("Monet-esque Photo")
plt.imshow(to_monet[0] * 0.5 + 0.5)

plt.show()

# **Define the CycleGAN model**

In [None]:
class CycleGan(Model):
    """Keras model that trains a CycleGAN's four networks in one ``train_step``.

    ``m_gen`` translates photos to Monet style and ``p_gen`` translates Monet
    paintings to photos. ``m_disc`` and ``p_disc`` score real vs generated
    images of each domain. ``fit`` must be fed ``(monet, photo)`` batch pairs.
    """

    def __init__(self, monet_generator, photo_generator, monet_discriminator, photo_discriminator, lambda_cycle=10):
        super().__init__()
        # The four sub-networks.
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        # Multiplier on the cycle-consistency term of both generator losses.
        self.lambda_cycle = lambda_cycle

    def compile(self, m_gen_opt, p_gen_opt, m_disc_opt, p_disc_opt, gen_loss_fn, disc_loss_fn, cycle_loss_fn, identity_loss_fn):
        """Store one optimizer per network and the four loss callables."""
        super().compile()
        # Optimizers, one per network.
        self.m_gen_opt = m_gen_opt
        self.p_gen_opt = p_gen_opt
        self.m_disc_opt = m_disc_opt
        self.p_disc_opt = p_disc_opt
        # Loss callables used by train_step.
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimization step for all four networks.

        Args:
            batch_data: a ``(real_monet, real_photo)`` pair of image batches.

        Returns:
            Dict with the total generator losses and the discriminator losses
            under ``monet_gen_loss``, ``photo_gen_loss``, ``monet_disc_loss``
            and ``photo_disc_loss``.
        """
        real_monet, real_photo = batch_data

        # Persistent because four separate gradient() calls are made below.
        with tf.GradientTape(persistent=True) as tape:
            # Round trips: photo -> Monet -> photo, then Monet -> photo -> Monet.
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Identity mappings: each generator fed an image already in its output domain.
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Discriminator outputs: real images first, then generated ones.
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Adversarial terms for each generator.
            adv_monet = self.gen_loss_fn(disc_fake_monet)
            adv_photo = self.gen_loss_fn(disc_fake_photo)

            # Cycle-consistency is shared by both generators' objectives.
            cycle = (self.cycle_loss_fn(real_monet, cycled_monet)
                     + self.cycle_loss_fn(real_photo, cycled_photo))
            weighted_cycle = self.lambda_cycle * cycle

            total_monet_gen_loss = adv_monet + weighted_cycle + self.identity_loss_fn(real_monet, same_monet)
            total_photo_gen_loss = adv_photo + weighted_cycle + self.identity_loss_fn(real_photo, same_photo)

            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # (optimizer, network, loss) for each of the four networks.
        schedule = (
            (self.m_gen_opt, self.m_gen, total_monet_gen_loss),
            (self.p_gen_opt, self.p_gen, total_photo_gen_loss),
            (self.m_disc_opt, self.m_disc, monet_disc_loss),
            (self.p_disc_opt, self.p_disc, photo_disc_loss),
        )
        # Compute every gradient first, then apply the updates.
        gradients = [tape.gradient(loss, net.trainable_variables) for _, net, loss in schedule]
        for (optimizer, net, _), grads in zip(schedule, gradients):
            optimizer.apply_gradients(zip(grads, net.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

# **Loss Functions**

In [None]:
def discriminator_loss(real, generated):
    """Discriminator loss: real logits are pushed toward 1, generated toward 0.

    Binary cross-entropy on logits with ``Reduction.NONE`` (no reduction over
    the batch), summed over the real and generated halves and then halved.
    """
    bce = losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
    loss_on_real = bce(tf.ones_like(real), real)
    loss_on_fake = bce(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_fake) * 0.5

In [None]:
def generator_loss(generated):
    """Adversarial generator loss: push the discriminator's logits on fakes toward 1.

    Binary cross-entropy on logits with ``Reduction.NONE``.
    """
    bce = losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
    wanted = tf.ones_like(generated)
    return bce(wanted, generated)

In [None]:
def calc_cycle_loss(real_image, cycled_image):
    """Return the mean absolute (L1) error between an image and its round-trip reconstruction."""
    abs_error = tf.abs(cycled_image - real_image)
    return tf.reduce_mean(abs_error)

In [None]:
def identity_loss(real_image, same_image, weight=0.5174):
    """Weighted L1 identity loss.

    Penalizes a generator for changing an image that is already in its
    output domain (e.g. the Monet generator applied to a real Monet).

    Args:
        real_image: input image batch.
        same_image: the generator's output for ``real_image``.
        weight: multiplier on the mean absolute error. The default, 0.5174,
            is the value that was previously hard-coded.
            NOTE(review): this looks hand-tuned; the original CycleGAN
            formulation weights this term at 0.5 * lambda_cycle, so confirm
            which scale is intended.

    Returns:
        A scalar tensor: ``weight * mean(|real_image - same_image|)``.
    """
    return weight * tf.reduce_mean(tf.abs(real_image - same_image))

# **Optimization**

In [None]:
# A separate Adam instance per network (learning rate 2e-4, beta_1 0.5),
# so each network keeps its own optimizer state.
(
    monet_generator_optimizer,
    photo_generator_optimizer,
    monet_discriminator_optimizer,
    photo_discriminator_optimizer,
) = (optimizers.Adam(learning_rate=2e-4, beta_1=0.5) for _ in range(4))

# **Compilation**

In [None]:
# Wrap the four networks; lambda_cycle keeps its default.
cycle_gan_model = CycleGan(
    monet_generator=monet_generator,
    photo_generator=photo_generator,
    monet_discriminator=monet_discriminator,
    photo_discriminator=photo_discriminator,
)

# Custom compile(): one optimizer per network plus the four loss callables.
cycle_gan_model.compile(
    m_gen_opt=monet_generator_optimizer,
    p_gen_opt=photo_generator_optimizer,
    m_disc_opt=monet_discriminator_optimizer,
    p_disc_opt=photo_discriminator_optimizer,
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
    cycle_loss_fn=calc_cycle_loss,
    identity_loss_fn=identity_loss,
)

# **Output**

In [None]:
def printResults(num_samples=5, shuffle_buffer=300):
    """Plot photos from ``photo_ds`` next to ``monet_generator``'s translations.

    Args:
        num_samples: number of rows (photo / prediction pairs) to draw.
            Defaults to 5, as before.
        shuffle_buffer: buffer size passed to ``Dataset.shuffle``. A buffer
            smaller than the dataset only samples from near its start.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    # The figure height (3 per row) gives (15, 15) at the default of 5 rows.
    _, axes = plt.subplots(num_samples, 2, figsize=(15, 3 * num_samples), squeeze=False)

    for row, img in enumerate(photo_ds.shuffle(shuffle_buffer).take(num_samples)):
        prediction = monet_generator(img, training=False)[0].numpy()
        # Map [-1, 1] back to uint8 pixel values in [0, 255].
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
        photo = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

        panels = ((photo, "Input Photo"), (prediction, "Monet-esque"))
        for col, (picture, title) in enumerate(panels):
            axis = axes[row, col]
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis("off")

    plt.show()

# **Training**

In [None]:
# Each element is a (monet_batch, photo_batch) pair; Dataset.zip stops at the
# shorter of the two datasets.
# NOTE(review): neither dataset is shuffled, so if the Monet set is the
# shorter one, the same leading photos are paired every epoch. Consider
# shuffling photo_ds.
history = cycle_gan_model.fit(
    x=tf.data.Dataset.zip((monet_ds, photo_ds)),
    epochs=30,
)

In [None]:
# Show sample photos next to the Monet generator's translations (default: 5 rows).
printResults()