## Version Log

1. Initial notebook
2. Added the [monet-tfrecords-extdata](https://www.kaggle.com/doanquanvietnamca/monet-tfrecords-extdata) dataset and trained for the full 31 epochs

### Necessary Imports

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Pick the distribution strategy: use the TPU when one can be resolved and
# initialised, otherwise fall back to TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Device :",tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

except Exception as tpu_error:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and report why the TPU
    # path was skipped instead of silently hiding the failure.
    print("TPU not available, falling back to the default strategy :", tpu_error)
    strategy = tf.distribute.get_strategy()

print("Number of replicas : ", strategy.num_replicas_in_sync)


# Let tf.data choose parallelism / prefetch buffer sizes automatically.
AUTO = tf.data.experimental.AUTOTUNE
print(tf.__version__)

In [None]:
from kaggle_datasets import KaggleDatasets
# Look up the GCS path of the 'gan-getting-started' Kaggle dataset; later cells
# glob the Monet and photo TFRecord files underneath it.
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')
# Bare last expression: displays the resolved path as the cell output.
GCS_PATH

**Version 2 update**

Adding the extra Monet TFRecord data.

In [None]:
# Look up the GCS path of the additional 'monet-tfrecords-extdata' Kaggle
# dataset; a later cell globs its 'monet*.tfrec' files.
GCS_PATH_MONET = KaggleDatasets().get_gcs_path('monet-tfrecords-extdata')
# Bare last expression: displays the resolved path as the cell output.
GCS_PATH_MONET

In [None]:
# Locate the TFRecord shards of both image domains under the competition path.
MONET_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/monet_tfrec/*.tfrec')
PHOTO_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/photo_tfrec/*.tfrec')

# Report how many shards were found (and list the Monet ones).
print("Monet TFRecord Files : ", len(MONET_FILENAMES))
print(MONET_FILENAMES)
print("Photo TFRecord Files : ", len(PHOTO_FILENAMES))

In [None]:
# Version 2

# Shards contributed by the additional Monet dataset.
NEW_MONET_FILENAMES = tf.io.gfile.glob(GCS_PATH_MONET + '/monet*.tfrec')
print("New Monet TFRecord Files : ", len(NEW_MONET_FILENAMES))

# Competition shards first, followed by the extra shards.
TOTAL_MONET_FILES = [*MONET_FILENAMES, *NEW_MONET_FILENAMES]
print(len(TOTAL_MONET_FILES))

In [None]:
# Display the merged list of Monet TFRecord paths (competition + extra shards).
TOTAL_MONET_FILES

In [None]:
IMAGE_SIZE = [256,256]

def decode_image(image):
    """Decode a JPEG byte string into a float32 RGB image scaled to [-1, 1].

    Args:
        image: scalar string tensor holding the encoded JPEG bytes.

    Returns:
        float32 tensor reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels = 3), tf.float32)
    # Dividing by 127.5 and subtracting 1 maps 0..255 onto -1..1.
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])


def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record must provide the scalar string features ``image_name``,
    ``image`` and ``target``; only ``image`` is used.

    Args:
        example: serialized ``tf.train.Example`` (scalar string tensor).

    Returns:
        The decoded image produced by ``decode_image``.
    """
    scalar_string = tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        "image_name" : scalar_string,
        "image" : scalar_string,
        "target" : scalar_string,
    }

    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled = True, ordered = False):
    """Create a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) to read.
        labeled: accepted for interface compatibility; not used.
        ordered: accepted for interface compatibility; not used.

    Returns:
        A ``tf.data.Dataset`` whose elements are the images returned by
        ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls = tf.data.experimental.AUTOTUNE)

## Data Augmentation

In [None]:
import random
import tensorflow as tf
import math
import tensorflow.keras.backend as K 


def transform_rotation(image, height, rotation):
    """Rotate a square image about its centre by a random angle.

    The angle is ``rotation`` degrees scaled by a uniform sample from [0, 1).
    Every output pixel position is mapped through the rotation matrix and the
    pixel found at the resulting (truncated, clipped) position is copied, so
    positions that fall outside the frame reuse pixels from the border region.

    Args:
        image: image tensor; assumed to be (height, height, 3), which the
            final reshape requires.
        height: side length of the square image.
        rotation: upper bound of the rotation angle, in degrees.

    Returns:
        Tensor of shape (height, height, 3) containing the rotated image.
    """
    
    dim = height
    # 1 for odd sizes, 0 for even sizes; shifts the lower clip bound below.
    xdim = height % 2
    # Scale the maximum angle by a uniform sample in [0, 1) ...
    rotation = rotation * tf.random.uniform([1], dtype = 'float32')
    # ... and convert degrees to radians.
    rotation = math.pi/180 * rotation
    
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    
    # 3x3 homogeneous rotation matrix [[c, s, 0], [-s, c, 0], [0, 0, 1]].
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero,-s1,c1,zero,zero,zero,one],axis=0),shape=[3,3])

    
    # Centred coordinates of every output pixel, flattened to length dim*dim.
    x = tf.repeat(tf.range(dim//2,-dim//2,-1), dim) # each value repeated dim times, e.g. [1 1 1 2 2 2 3 3 3]
    y = tf.tile(tf.range(-dim//2,dim//2), [dim]) # whole range repeated dim times, e.g. [1 2 3 1 2 3 1 2 3]
    z = tf.ones([dim*dim], dtype='int32')
    
    # 3 x (dim*dim) matrix of homogeneous coordinates.
    idx = tf.stack([x,y,z])
    
    # Rotate the coordinates, truncate them to integers and clip them so the
    # array indices derived from them below stay inside the image.
    idx2 = K.dot(rotation_matrix, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    idx2 = K.clip(idx2,-dim//2+1+xdim, dim//2)
    
    # Convert centred coordinates back to (row, column) array indices.
    idx3 = tf.stack([dim//2 -idx2[0], dim//2 -1 + idx2[1,]])
    
    
    # Look up the source pixel for every output position.
    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d, shape = [dim, dim, 3])
    
def _random_upscale_crop(image, min_size, max_size):
    """Up-scale ``image`` to a random square size, then crop back to 256x256x3.

    The target size is drawn with ``tf.random.uniform`` so that a new value is
    sampled for every image when this runs inside ``tf.data.Dataset.map``.

    Args:
        image: image tensor with 3 channels.
        min_size: smallest side length to resize to (inclusive).
        max_size: largest side length to resize to (inclusive).

    Returns:
        A randomly positioned (256, 256, 3) crop of the resized image.
    """
    # maxval is exclusive for integer dtypes, hence the + 1.
    side = tf.random.uniform([], min_size, max_size + 1, dtype = tf.int32)
    resized = tf.image.resize(image, size = [side, side])
    cropped = tf.image.random_crop(resized, size = [256,256,3])
    # Re-assert the static shape, which the dynamically sized resize may hide.
    return tf.reshape(cropped, [256, 256, 3])


def data_augment(image):
    """Randomly crop, rotate and flip a single 256x256x3 image.

    Meant to be passed to ``tf.data.Dataset.map``. All randomness comes from
    TensorFlow ops: a Python ``random.randint`` call would be evaluated only
    once, when ``map`` traces this function, and would therefore pick one
    fixed resize size for the whole dataset.

    Three independent uniform draws in [0, 1) drive the augmentation:

    * ``p_crop``    > 0.5: resize to 260..290 px and random-crop to 256;
                    > 0.9: additionally resize to 290..310 px and crop again.
    * ``p_rotate``  > 0.25 / 0.5 / 0.75: rotate by 1 / 2 / 3 quarter turns;
                    >= 0.3: additionally apply ``transform_rotation`` with a
                    maximum angle of 45 degrees.
    * ``p_spatial`` > 0.6: random left/right and up/down flips;
                    > 0.85: additionally transpose the image.

    Args:
        image: image tensor of shape (256, 256, 3).

    Returns:
        The augmented image tensor.
    """
    p_rotate = tf.random.uniform([], 0, 1.0, dtype = tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype = tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype = tf.float32)

    if p_crop > 0.5:
        image = _random_upscale_crop(image, 260, 290)

        if p_crop > 0.9:
            image = _random_upscale_crop(image, 290, 310)

    if p_rotate > 0.75:
        image = tf.image.rot90(image,k=3)
    elif p_rotate > .5:
        image = tf.image.rot90(image,k=2)
    elif p_rotate > 0.25:
        image = tf.image.rot90(image,k=1)

    if p_rotate >=0.3:
        image = transform_rotation(image, height=256,rotation=45.)

    if p_spatial >0.6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial >0.85:
            image = tf.image.transpose(image)

    return image
    
    

### GAN Dataset

In [None]:
AUTO = tf.data.experimental.AUTOTUNE


def _build_image_stream(files, augment, repeat, shuffle, batch_size):
    """Build the batched input pipeline for one image domain.

    Order of stages: load/decode -> cache -> augment -> repeat -> shuffle ->
    batch -> prefetch.
    """
    stream = load_dataset(files)

    # Cache the decoded images BEFORE augmentation, repeat and shuffle. Caching
    # after repeat() would try to cache an endless stream that is never fully
    # iterated (so the cache never finalises and keeps growing), and would
    # cache already-augmented images. With no filename the cache lives in
    # memory, so it holds one decoded copy of the files.
    stream = stream.cache()

    if augment:
        stream = stream.map(augment, num_parallel_calls = AUTO)

    if repeat:
        stream = stream.repeat()

    if shuffle:
        stream = stream.shuffle(2048)

    stream = stream.batch(batch_size, drop_remainder = True)
    return stream.prefetch(AUTO)


def get_gan_dataset(monet_files, photo_files, augment = None, repeat = True, shuffle = True, batch_size = 1):
    """Create the paired (Monet, photo) training dataset.

    Args:
        monet_files: TFRecord paths of the Monet paintings.
        photo_files: TFRecord paths of the photos.
        augment: optional per-image function mapped over both domains.
        repeat: repeat both streams indefinitely when True.
        shuffle: shuffle both streams with a 2048-element buffer when True.
        batch_size: number of images per batch (incomplete batches are dropped).

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_batch, photo_batch)`` tuples.
    """
    monet_ds = _build_image_stream(monet_files, augment, repeat, shuffle, batch_size)
    photo_ds = _build_image_stream(photo_files, augment, repeat, shuffle, batch_size)

    return tf.data.Dataset.zip((monet_ds, photo_ds))

In [None]:
#gan_dataset = get_gan_dataset(MONET_FILENAMES, PHOTO_FILENAMES, augment = data_augment, repeat = True, shuffle = True, batch_size = 1)

In [None]:
# Version 2 update

# Pair the merged Monet shards (competition + extra dataset) with the photos.
gan_dataset = get_gan_dataset(
    TOTAL_MONET_FILES,
    PHOTO_FILENAMES,
    augment = data_augment,
    repeat = True,
    shuffle = True,
    batch_size = 1,
)


### Visualizing the Dataset

Version 2 Update

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Fetch one (monet_batch, photo_batch) pair for visual inspection.
# next(..., None) leaves a_tuple as None if the dataset yields nothing and
# avoids leaking a loop variable into the notebook namespace.
a_tuple = next(iter(gan_dataset.take(1)), None)

In [None]:
# PAINTING

# Pixels are stored in [-1, 1] (see decode_image); map them back to [0, 1]
# because imshow clips float RGB values outside that range.
plt.imshow(np.array(a_tuple[0]).squeeze() * 0.5 + 0.5)

In [None]:
# PHOTO IMAGE

# Pixels are stored in [-1, 1] (see decode_image); map them back to [0, 1]
# because imshow clips float RGB values outside that range.
plt.imshow(np.array(a_tuple[1]).squeeze() * 0.5 + 0.5)

### Developing the GAN Structure

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm = True):
    """Build a down-sampling block: stride-2 Conv2D, optional InstanceNorm, LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_instancenorm: insert an instance-normalisation layer after the
            convolution when True.

    Returns:
        A ``keras.Sequential`` model containing the layers in that order.
    """
    stages = [
        layers.Conv2D(
            filters,
            size,
            strides = 2,
            padding= 'same',
            kernel_initializer = tf.random_normal_initializer(0., 0.02),
            use_bias = False,
        )
    ]

    if apply_instancenorm:
        gamma_init = keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        stages.append(tfa.layers.InstanceNormalization(gamma_initializer = gamma_init))

    stages.append(layers.LeakyReLU())

    return keras.Sequential(stages)

In [None]:
def upsample(filters, size, apply_dropout = False):
    """Build an up-sampling block: stride-2 Conv2DTranspose, InstanceNorm, optional Dropout, ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: insert a ``Dropout(0.5)`` layer after the normalisation
            when True.

    Returns:
        A ``keras.Sequential`` model containing the layers in that order.
    """
    stages = [
        layers.Conv2DTranspose(
            filters,
            size,
            strides = 2,
            padding = 'same',
            kernel_initializer = tf.random_normal_initializer(0., 0.02),
            use_bias = False,
        ),
        tfa.layers.InstanceNormalization(
            gamma_initializer = keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        ),
    ]

    if apply_dropout:
        stages.append(layers.Dropout(0.5))

    stages.append(layers.ReLU())

    return keras.Sequential(stages)

In [None]:
def Generator():
    """Build the encoder-decoder generator with skip connections.

    Eight stride-2 ``downsample`` blocks reduce the 256x256x3 input to a 1x1
    bottleneck; seven ``upsample`` blocks plus a final stride-2 transposed
    convolution bring it back to 256x256. After each ``upsample`` block the
    output of the matching encoder block is concatenated onto it.

    Returns:
        A ``keras.Model`` mapping a (256, 256, 3) image to a tanh-activated
        image with ``OUTPUT_CHANNELS`` channels.
    """
    inputs = layers.Input(shape = [256,256,3])

    # Encoder: the first block skips instance normalisation.
    down_stack = [downsample(64, 4, apply_instancenorm = False)]
    down_stack += [downsample(width, 4) for width in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder: the first three blocks use dropout.
    up_stack = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    up_stack += [upsample(width, 4) for width in (512, 256, 128, 64)]

    last = layers.Conv2DTranspose(
        OUTPUT_CHANNELS,
        4,
        strides = 2,
        padding = "same",
        kernel_initializer = tf.random_normal_initializer(0., 0.02),
        activation = 'tanh',
    )

    # Downsampling, remembering every intermediate output for the skips.
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Upsampling: drop the bottleneck output and walk the remaining encoder
    # outputs from deepest to shallowest.
    for up, skip in zip(up_stack, skips[-2::-1]):
        x = layers.Concatenate()([up(x), skip])

    return keras.Model(inputs = inputs, outputs = last(x))

In [None]:
def Discriminator():
    """Build the convolutional discriminator.

    Three ``downsample`` blocks (64, 128, 256 filters) are followed by
    zero-padding, a 512-filter stride-1 convolution with instance
    normalisation and LeakyReLU, another zero-padding and a final 1-filter
    stride-1 convolution without activation.

    Returns:
        A ``tf.keras.Model`` mapping a (256, 256, 3) image to a single-channel
        map of unnormalised scores.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)

    inp = layers.Input(shape = [256, 256, 3], name = 'input_image')

    # Strided feature extraction; the first block has no instance norm.
    x = inp
    for filters, use_norm in ((64, False), (128, True), (256, True)):
        x = downsample(filters, 4, use_norm)(x)

    x = layers.ZeroPadding2D()(x)
    x = layers.Conv2D(512, 4, strides = 1, kernel_initializer = initializer, use_bias = False)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer = gamma_init)(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x)
    scores = layers.Conv2D(1, 4, strides = 1, kernel_initializer = initializer)(x)

    return tf.keras.Model(inputs = inp, outputs = scores)
    
    

In [None]:
with strategy.scope():
    # Create the four networks inside the strategy scope: one generator and
    # one discriminator per image domain.
    monet_generator, photo_generator = Generator(), Generator()
    monet_discriminator, photo_discriminator = Discriminator(), Discriminator()

### Train the CycleGAN

In [None]:
with strategy.scope():
    # One independent Adam optimizer (learning rate 2e-4, beta_1 0.5) per network.
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5) for _ in range(4))

In [None]:
class CycleGAN(keras.Model):
    """Keras model that trains two generators and two discriminators jointly.

    ``m_gen`` is applied to photos to produce Monet-style images and ``p_gen``
    is applied to Monet paintings to produce photo-style images; ``m_disc``
    and ``p_disc`` score real versus generated images of their domain.
    """

    def __init__(self, monet_generator, photo_generator, monet_discriminator, photo_discriminator, lambda_cycle = 10):
        """Store the four networks and the cycle-loss weight.

        Args:
            monet_generator: generator producing Monet-style images.
            photo_generator: generator producing photo-style images.
            monet_discriminator: discriminator for Monet images.
            photo_discriminator: discriminator for photo images.
            lambda_cycle: weight passed to the cycle and identity loss functions.
        """
        super(CycleGAN, self).__init__()

        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(self, m_gen_optimizer, p_gen_optimizer, m_disc_optimizer, p_disc_optimizer, gen_loss_fn, disc_loss_fn, cycle_loss_fn, identity_loss_fn):
        """Attach one optimizer per network and the four loss functions.

        Args:
            m_gen_optimizer: optimizer for ``m_gen``.
            p_gen_optimizer: optimizer for ``p_gen``.
            m_disc_optimizer: optimizer for ``m_disc``.
            p_disc_optimizer: optimizer for ``p_disc``.
            gen_loss_fn: called as ``gen_loss_fn(disc_output_on_fake)``.
            disc_loss_fn: called as ``disc_loss_fn(disc_output_on_real, disc_output_on_fake)``.
            cycle_loss_fn: called as ``cycle_loss_fn(real, cycled, lambda_cycle)``.
            identity_loss_fn: called as ``identity_loss_fn(real, same, lambda_cycle)``.
        """
        super(CycleGAN, self).compile()

        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer

        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on a ``(real_monet, real_photo)`` batch.

        Args:
            batch_data: tuple ``(real_monet, real_photo)`` of image batches.

        Returns:
            Dict with the total generator losses and the discriminator losses
            of both domains.
        """
        real_monet, real_photo = batch_data

        # Only the forward pass and the loss computation are recorded on the
        # tape. It is persistent because four separate gradients are taken.
        with tf.GradientTape(persistent = True) as tape:

            # photo -> Monet -> photo
            fake_monet = self.m_gen(real_photo, training = True)
            cycled_photo = self.p_gen(fake_monet, training = True)

            # Monet -> photo -> Monet
            fake_photo = self.p_gen(real_monet, training = True)
            cycled_monet = self.m_gen(fake_photo, training = True)

            # each generator applied to an image of its own target domain
            same_monet = self.m_gen(real_monet, training = True)
            same_photo = self.p_gen(real_photo, training = True)

            # discriminator outputs for real images
            disc_real_monet = self.m_disc(real_monet, training = True)
            disc_real_photo = self.p_disc(real_photo, training = True)

            # discriminator outputs for generated images
            disc_fake_monet = self.m_disc(fake_monet, training = True)
            disc_fake_photo = self.p_disc(fake_photo, training = True)

            # adversarial generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # total cycle consistency loss (both directions)
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # total generator losses: adversarial + cycle + identity
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Gradients are taken AFTER leaving the tape context: calling
        # tape.gradient inside a persistent tape's context records the gradient
        # ops on the tape as well, which costs extra memory and CPU for nothing.
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)
        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # apply the gradients with the optimizer of each network
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients, self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients, self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients, self.p_disc.trainable_variables))

        return {
            "monet_gen_loss" : total_monet_gen_loss,
            "photo_gen_loss" : total_photo_gen_loss,
            "monet_disc_loss" : monet_disc_loss,
            "photo_disc_loss" : photo_disc_loss
        }

In [None]:
"""class EarlyStopping_custom(tf.keras.callbacks.EarlyStopping):
    
    def __init__(self, monitor, mode, patience, restore_best_weights = True):
        
        super(EarlyStopping_custom, self).__init__()
        self.restore_best_weights = restore_best_weights
        self.monitor = monitor
        self.mode = mode
        self."""

In [None]:
with strategy.scope():

    def discriminator_loss(real, generated):
        """Discriminator loss: real outputs should score 1, generated outputs 0.

        Args:
            real: discriminator logits for real images.
            generated: discriminator logits for generated images.

        Returns:
            Half the sum of the two unreduced binary cross-entropy losses.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits = True, reduction = tf.keras.losses.Reduction.NONE)

        real_loss = bce(tf.ones_like(real), real)
        generated_loss = bce(tf.zeros_like(generated), generated)

        return (real_loss + generated_loss) * 0.5

In [None]:
with strategy.scope():

    def generator_loss(generated):
        """Generator loss: generated images should be scored as real (label 1).

        Args:
            generated: discriminator logits for generated images.

        Returns:
            The unreduced binary cross-entropy against an all-ones target.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits = True, reduction = tf.keras.losses.Reduction.NONE)
        return bce(tf.ones_like(generated), generated)

In [None]:
with strategy.scope():

    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle-consistency loss: ``LAMBDA`` times the mean absolute difference.

        Args:
            real_image: original image batch.
            cycled_image: the batch after a round trip through both generators.
            LAMBDA: weight applied to the mean absolute error.

        Returns:
            Scalar tensor ``LAMBDA * mean(|real_image - cycled_image|)``.
        """
        return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))

In [None]:
with strategy.scope():

    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: half of ``LAMBDA`` times the mean absolute difference.

        Args:
            real_image: original image batch.
            same_image: output of the generator applied to ``real_image``.
            LAMBDA: weight; the result is additionally scaled by 0.5.

        Returns:
            Scalar tensor ``LAMBDA * 0.5 * mean(|real_image - same_image|)``.
        """
        return LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))

In [None]:
# NOTE(review): this cell repeats the optimizer definitions from the earlier
# "Train the CycleGAN" cell verbatim. Running it rebinds the four names to
# fresh Adam instances, and those are the ones later passed to compile().
with strategy.scope():
    
    monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5)
    
    monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1 = 0.5)
    
    

In [None]:
with strategy.scope():
    # Wire the four networks into one trainable CycleGAN model.
    cycle_gan_model = CycleGAN(
        monet_generator = monet_generator,
        photo_generator = photo_generator,
        monet_discriminator = monet_discriminator,
        photo_discriminator = photo_discriminator,
    )

    # Optimizers are passed positionally in compile()'s parameter order
    # (m_gen, p_gen, m_disc, p_disc); the loss functions by keyword.
    cycle_gan_model.compile(
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss,
    )

In [None]:
# Number of training epochs passed to cycle_gan_model.fit().
EPOCHS_NUM = 31

In [None]:
BATCH_SIZE = 2
# NOTE(review): 7038 is presumably the number of photo images - TODO confirm.
steps_per_epoch = (7038//BATCH_SIZE)

# Rebuild the training dataset with BATCH_SIZE. The dataset created earlier
# uses batch_size = 1, so without this the steps_per_epoch value derived from
# BATCH_SIZE would not match the batches actually fed to fit().
gan_dataset = get_gan_dataset(TOTAL_MONET_FILES, PHOTO_FILENAMES, augment = data_augment, repeat = True, shuffle = True, batch_size = BATCH_SIZE)

In [None]:
# The training dataset is built with repeat = True, so steps_per_epoch is what
# ends each epoch.
cycle_gan_model.fit(gan_dataset, epochs = EPOCHS_NUM, steps_per_epoch = steps_per_epoch)

# Notebook in the making...

### Possible Hyperparameters to tune

* IMAGE_SIZE
* NUMBER OF EPOCHS
* Optimizer Selected