In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
input_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_path in input_paths:
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Import Libraries

In [None]:
# https://arxiv.org/pdf/1703.10593.pdf
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from tensorflow.keras import activations
from tensorflow.keras.losses import BinaryCrossentropy as BCE

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

# Set up TPU

In [None]:
# Detect a TPU and build the matching distribution strategy. If the TPU cannot
# be resolved or initialised, fall back to TensorFlow's default strategy so the
# notebook still runs on CPU/GPU.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
    # still propagate; the reason for the fallback is reported instead of hidden.
    print('TPU not available, using the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
# Sentinel that lets tf.data pick the parallelism level at runtime; it is passed
# as `num_parallel_calls` when the TFRecord datasets are mapped.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Base path of the attached Kaggle dataset as returned by `get_gcs_path()`
# (presumably a gs:// bucket URL); the TFRecord globs below are built on it.
GCS_PATH = KaggleDatasets().get_gcs_path()

# Load Data

In [None]:
# Collect the TFRecord shards for each image domain and report how many were found.
MONET_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/monet_tfrec/*.tfrec')
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/photo_tfrec/*.tfrec')
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

# Create Datasets

In [None]:
# Height and width every decoded image is reshaped to.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB tensor scaled to [-1, 1].

    Args:
        image: String tensor holding one JPEG-encoded image.

    Returns:
        A float32 tensor of shape ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    pixels = tf.cast(pixels, tf.float32)
    # Map the 0..255 byte range onto -1..1.
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record schema declares ``image_name``, ``image`` and ``target`` as
    scalar string features; only ``image`` is used.

    Args:
        example: Serialized example, as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        The image tensor produced by ``decode_image``.
    """
    feature_spec = {
        feature_name: tf.io.FixedLenFeature([], tf.string)
        for feature_name in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) handed to ``tf.data.TFRecordDataset``.
        labeled: Accepted but currently unused.
        ordered: Accepted but currently unused.

    Returns:
        An unbatched dataset whose elements are produced by ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
# One image per batch for both domains.
monet_ds, photo_ds = (
    load_dataset(domain_files, labeled=True).batch(1)
    for domain_files in (MONET_FILENAMES, PHOTO_FILENAMES)
)

In [None]:
# Pull the first batch of each dataset to use as a visual example.
example_monet, example_photo = (
    next(iter(domain_ds)) for domain_ds in (monet_ds, photo_ds)
)

In [None]:
# Images are stored in [-1, 1]; rescale the first image of the batch to [0, 1] for display.
plt.imshow(example_monet[0]* 0.5 + 0.5)

# Build downsample and upsample blocks

In [None]:
def downsample(num_filters, kernel_size, apply_instancenorm=True):
    """Build an encoder block: stride-2 Conv2D -> optional InstanceNorm -> LeakyReLU.

    With ``strides=2`` and ``padding='same'`` the convolution itself does the
    downsampling; there is no pooling layer.

    Args:
        num_filters: Number of output channels of the convolution.
        kernel_size: Kernel size of the convolution.
        apply_instancenorm: Whether to add instance normalization after the
            convolution.

    Returns:
        A ``keras.Sequential`` containing the block's layers.
    """
    block_layers = [
        layers.Conv2D(
            num_filters,
            kernel_size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    ]
    if apply_instancenorm:
        norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        block_layers.append(
            tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init)
        )
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)
    
    

In [None]:
def upsample(num_filters, kernel_size, dropout=False):
    """Build a decoder block: stride-2 Conv2DTranspose -> InstanceNorm -> optional Dropout -> LeakyReLU.

    Args:
        num_filters: Number of output channels of the transposed convolution.
        kernel_size: Kernel size of the transposed convolution.
        dropout: Whether to insert ``Dropout(0.5)`` after the normalization.

    Returns:
        A ``keras.Sequential`` containing the block's layers.
    """
    block_layers = [
        # The transposed convolution is created without a bias term.
        layers.Conv2DTranspose(
            num_filters,
            kernel_size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tfa.layers.InstanceNormalization(
            gamma_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        ),
    ]
    if dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)
    

# Build Generator UNet

In [None]:
# https://www.researchgate.net/figure/The-architecture-of-the-U-Net-model-used-in-this-study-with-an-input-sample-size-of-256_fig2_341034897
def Generator():
    """Build the U-Net style generator.

    A (256, 256, 3) input goes through six stride-2 ``downsample`` blocks, then
    five stride-2 ``upsample`` blocks, each concatenated with the encoder
    output of matching resolution, and finally a tanh transposed convolution
    that restores (256, 256, 3).

    Returns:
        A ``keras.Model`` mapping the input image to the generated image.
    """
    inputs = layers.Input(shape=[256, 256, 3])

    # (filters, apply_instancenorm) per encoder block. Output shapes, in order:
    # (bs,128,128,64) (bs,64,64,128) (bs,32,32,256) (bs,16,16,512)
    # (bs,8,8,1024) (bs,4,4,1024)
    encoder_specs = [
        (64, False),
        (128, True),
        (256, True),
        (512, True),
        (1024, True),
        (1024, True),
    ]
    down_stack = [
        downsample(filters, 4, apply_instancenorm=use_norm)
        for filters, use_norm in encoder_specs
    ]

    # (filters, dropout) per decoder block. Output shapes before concatenation:
    # (bs,8,8,1024) (bs,16,16,512) (bs,32,32,256) (bs,64,64,128) (bs,128,128,64)
    decoder_specs = [
        (1024, True),
        (512, False),
        (256, False),
        (128, False),
        (64, False),
    ]
    up_stack = [
        upsample(filters, 4, dropout=use_dropout)
        for filters, use_dropout in decoder_specs
    ]

    # Final layer brings the feature map back to (bs, 256, 256, 3) in [-1, 1].
    last_layer = layers.Conv2DTranspose(
        3, 4, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='tanh',
    )

    # Encoder pass, remembering every block output for the skip connections.
    x = inputs
    encoder_outputs = []
    for down in down_stack:
        x = down(x)
        encoder_outputs.append(x)

    # The deepest output (4x4) is the bottleneck and is not used as a skip;
    # the remaining outputs are consumed deepest-first.
    skips = encoder_outputs[-2::-1]

    # Decoder pass: upsample, then concatenate the matching encoder output.
    for up, skip in zip(up_stack, skips):
        x = layers.Concatenate()([up(x), skip])

    return keras.Model(inputs=inputs, outputs=last_layer(x))
    
    

In [None]:
# Sanity check: build a throwaway generator and print its layer summary.
test = Generator()
test.summary()

# Build PatchGAN Discriminator 

In [None]:
# PatchGAN architecture
def Discriminator():
    """Build the PatchGAN discriminator.

    A (256, 256, 3) image is reduced by three stride-2 ``downsample`` blocks to
    a 32x32 feature map, followed by two stride-1 convolutions. The output is a
    single-channel map of raw scores, one per patch; no sigmoid is applied.

    Returns:
        A ``tf.keras.Model`` mapping the input image to the patch score map.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    # Output shapes: (bs,128,128,64) -> (bs,64,64,128) -> (bs,32,32,256)
    features = inp
    for filters, use_norm in ((64, False), (128, True), (256, True)):
        features = downsample(filters, 4, use_norm)(features)

    # Stride 1 with 'same' padding keeps the 32x32 resolution: (bs, 32, 32, 512)
    features = layers.Conv2D(
        512, 4, strides=1, padding='same',
        kernel_initializer=initializer, use_bias=False,
    )(features)
    features = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(features)
    features = layers.LeakyReLU()(features)

    # One score per patch: (bs, 32, 32, 1)
    patch_scores = layers.Conv2D(
        1, 4, strides=1, padding='same',
        kernel_initializer=initializer, use_bias=False,
    )(features)

    return tf.keras.Model(inputs=inp, outputs=patch_scores)
    
    
    
    

In [None]:
# Sanity check: build a throwaway discriminator (rebinding `test`) and render
# its layer graph with tensor shapes.
test = Discriminator()
tf.keras.utils.plot_model(test, show_shapes=True, dpi=70)

In [None]:
# Initialize the two generators and two discriminators inside the distribution
# strategy scope so their variables are created under that strategy.
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

In [None]:
# Run the (still untrained) Monet generator on the example photo and show
# input and output side by side.
to_monet = monet_generator(example_photo)

preview_panels = (
    ("Original Photo", example_photo),
    ("Monet-esque Photo", to_monet),
)
for panel_index, (panel_title, panel_batch) in enumerate(preview_panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.title(panel_title)
    # Rescale from [-1, 1] to [0, 1] for display.
    plt.imshow(panel_batch[0] * 0.5 + 0.5)
plt.show()

In [None]:
# define cycleGAN model
class CycleGan(keras.Model):
    """CycleGAN wrapper that trains two generators and two discriminators together.

    ``m_gen`` maps photos to Monet-style images and ``p_gen`` maps Monet images
    to photos; ``m_disc`` and ``p_disc`` score images of the respective domain.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        """Store the four sub-models and the weight passed to the cycle/identity losses."""
        super().__init__()
        self.m_gen, self.p_gen = monet_generator, photo_generator
        self.m_disc, self.p_disc = monet_discriminator, photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Attach one optimizer per sub-model and the four loss callables.

        Args:
            m_gen_optimizer: Optimizer for the Monet generator.
            p_gen_optimizer: Optimizer for the photo generator.
            m_disc_optimizer: Optimizer for the Monet discriminator.
            p_disc_optimizer: Optimizer for the photo discriminator.
            gen_loss_fn: Called as ``fn(disc_output_on_fake)``.
            disc_loss_fn: Called as ``fn(disc_output_on_real, disc_output_on_fake)``.
            cycle_loss_fn: Called as ``fn(real, cycled, lambda_cycle)``.
            identity_loss_fn: Called as ``fn(real, same, lambda_cycle)``.
        """
        super().compile()

        # Optimizers
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer

        # Loss callables
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on a ``(real_monet, real_photo)`` batch pair.

        Returns:
            Dict with the total generator loss and the discriminator loss of
            each domain.
        """
        real_monet, real_photo = batch_data
        weight = self.lambda_cycle

        # Persistent tape: four separate gradient computations read from it.
        with tf.GradientTape(persistent=True) as tape:
            # Forward cycle: photo -> Monet -> photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # Backward cycle: Monet -> photo -> Monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Identity mapping: each generator fed an image of its own target domain
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Discriminator scores on real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # Discriminator scores on generated images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Adversarial part of the generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # Cycle-consistency loss, shared by both generators
            monet_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, weight)
            photo_cycle_loss = self.cycle_loss_fn(real_photo, cycled_photo, weight)
            total_cycle_loss = monet_cycle_loss + photo_cycle_loss

            # Total generator losses: adversarial + cycle + identity
            monet_identity_loss = self.identity_loss_fn(real_monet, same_monet, weight)
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + monet_identity_loss
            photo_identity_loss = self.identity_loss_fn(real_photo, same_photo, weight)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + photo_identity_loss

            # Discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # (loss, model, optimizer) for each of the four sub-models.
        update_plan = (
            (total_monet_gen_loss, self.m_gen, self.m_gen_optimizer),
            (total_photo_gen_loss, self.p_gen, self.p_gen_optimizer),
            (monet_disc_loss, self.m_disc, self.m_disc_optimizer),
            (photo_disc_loss, self.p_disc, self.p_disc_optimizer),
        )

        # Compute every gradient first, then apply them, so no gradient is
        # taken after another sub-model has already been updated.
        all_gradients = [
            tape.gradient(loss, model.trainable_variables)
            for loss, model, _ in update_plan
        ]
        for gradients, (_, model, optimizer) in zip(all_gradients, update_plan):
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

# Loss functions

In [None]:
# The discriminator sees both real and generated images, so its loss is the
# average of the two binary cross-entropy terms.
with strategy.scope():
    def discriminator_loss(real, generated):
        """Return the unreduced discriminator loss.

        Args:
            real: Discriminator logits for real images (target: all ones).
            generated: Discriminator logits for generated images (target: all zeros).

        Returns:
            Half the sum of the two cross-entropy losses, computed with
            ``Reduction.NONE``.
        """
        bce = BCE(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.5

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Return the unreduced adversarial loss for a generator.

        The generator wants the discriminator to label its output as real, so
        the logits in ``generated`` are compared against a target of all ones.
        """
        bce = BCE(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        wanted_labels = tf.ones_like(generated)
        return bce(wanted_labels, generated)

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute difference between an image and its cycled reconstruction."""
        absolute_error = tf.abs(real_image - cycled_image)
        return LAMBDA * tf.reduce_mean(absolute_error)

In [None]:
# Identity loss: penalise a generator for changing an image that is already in
# its target domain.
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute difference between an image and its identity mapping."""
        absolute_error = tf.abs(real_image - same_image)
        return LAMBDA * 0.5 * tf.reduce_mean(absolute_error)

In [None]:
# One independent Adam optimizer (lr 2e-4, beta_1 0.5) per sub-model.
with strategy.scope():
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

In [None]:
# Assemble the CycleGAN from the four sub-models, then attach their optimizers
# and the loss functions.
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator=monet_generator,
        photo_generator=photo_generator,
        monet_discriminator=monet_discriminator,
        photo_discriminator=photo_discriminator,
    )

    cycle_gan_model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

In [None]:
# Directory where training checkpoints are kept.
checkpoint_path = "./checkpoints/train"

# Track all four sub-models together with their optimizers.
# discriminator_y must be the photo discriminator: it previously pointed at
# monet_discriminator again, so the photo discriminator's weights were never
# tracked or restored even though its optimizer was.
ckpt = tf.train.Checkpoint(generator_g=monet_generator,
                           generator_f=photo_generator,
                           discriminator_x=monet_discriminator,
                           discriminator_y=photo_discriminator,
                           generator_g_optimizer=monet_generator_optimizer,
                           generator_f_optimizer=photo_generator_optimizer,
                           discriminator_x_optimizer=monet_discriminator_optimizer,
                           discriminator_y_optimizer=photo_discriminator_optimizer)

# Keep only the five most recent checkpoints.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
# Train on (monet_batch, photo_batch) pairs, which is the tuple layout
# `CycleGan.train_step` unpacks.
paired_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
cycle_gan_model.fit(paired_ds, epochs=100)

# Visualize cycled images

In [None]:
# For ten photos show: the input, its Monet-style translation, and the
# translation mapped back to a photo by the second generator.
_, ax = plt.subplots(10, 3, figsize=(30, 30))
column_titles = ("Input Photo", "Monet-esque", "Cycled")

for row, photo_batch in enumerate(photo_ds.take(10)):
    monet_batch = monet_generator(photo_batch, training=False)
    # The generated batch already has shape (1, 256, 256, 3), so it can be fed
    # straight into the photo generator.
    cycled_batch = photo_generator(monet_batch, training=False)

    # Convert each image from [-1, 1] floats to displayable 0..255 bytes.
    photo_image = (photo_batch[0] * 127.5 + 127.5).numpy().astype(np.uint8)
    monet_image = (monet_batch[0].numpy() * 127.5 + 127.5).astype(np.uint8)
    cycled_image = (cycled_batch[0].numpy() * 127.5 + 127.5).astype(np.uint8)

    row_images = (photo_image, monet_image, cycled_image)
    for axis, column_title, column_image in zip(ax[row], column_titles, row_images):
        axis.imshow(column_image)
        axis.set_title(column_title)
        axis.axis("off")
plt.show()