In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory tree and print the full path of every file in it.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    full_paths = (os.path.join(root_dir, file_name) for file_name in file_names)
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## 1. Import libraries 

In [None]:
import tensorflow as tf 
import tensorflow_addons as tfa 

import matplotlib.pyplot as plt
import matplotlib.image as mpimg 
import cv2 
import os

import numpy as np
import pandas as pd

## 2. Set up the input pipeline

In [None]:
# Try to attach to a TPU. If none can be reached, fall back to the default
# distribution strategy so the notebook still runs on CPU/GPU.
try:
  # The resolver lives in the `tf.distribute.cluster_resolver` module; the
  # previous spelling `tf.distribute_cluster_resolver` raised AttributeError,
  # which the bare `except` swallowed, so the TPU branch could never succeed.
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
  print('Device:', tpu.master())
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
  # Deliberately best-effort: any failure to reach a TPU falls back to the
  # default strategy, but the reason is reported instead of being hidden.
  print('TPU not available, falling back to the default strategy:', tpu_error)
  strategy = tf.distribute.get_strategy()

print('Number of replicas:', strategy.num_replicas_in_sync)

print(tf.__version__)

## 3. Hyperparameter config

In [None]:
# Target image size (in pixels) and channel count used by resize() and random_crop().
IMG_WIDTH = 256
IMG_HEIGTH = 256  # NOTE(review): "HEIGTH" is a misspelling of "HEIGHT"; kept as-is because other cells reference this name.
CHANNELS = 3

BATCH_SIZE = 1  # intended number of images per batch
BUFFER_SIZE = 1000  # buffer size handed to Dataset.shuffle() in the input pipeline


# Lets tf.data choose the degree of parallelism for map() calls.
AUTOTUNE = tf.data.experimental.AUTOTUNE

## 4. Helper functions

In [None]:
def resize(image):
  """Resize `image` to IMG_HEIGTH x IMG_WIDTH with tf.image.resize's default method."""
  target_size = [IMG_HEIGTH, IMG_WIDTH]
  return tf.image.resize(image, target_size)

In [None]:
# Random Crop 
def random_crop(image):
  """Return a randomly positioned IMG_HEIGTH x IMG_WIDTH x CHANNELS crop of `image`."""
  crop_shape = [IMG_HEIGTH, IMG_WIDTH, CHANNELS]
  return tf.image.random_crop(image, size=crop_shape)

In [None]:
# Normalize the pictures to [-1, 1]
def normalize(image):
  """Cast `image` to float32 and rescale it with x / 127.5 - 1 (0..255 maps to -1..1)."""
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1

In [None]:
def denormalize(image):
  """Map a [-1, 1] image back to integer pixel values in 0..255 (int32)."""
  as_float = tf.cast(image, tf.float32)
  # First to the 0..1 range, then up to the 0..255 range.
  pixel_range = (as_float * 0.5 + 0.5) * 255
  return tf.cast(pixel_range, tf.int32)

In [None]:
def random_jitter(image):
  """Augment `image`: enlarge to 286x286, take a random crop, then randomly mirror it.

  The enlargement uses nearest-neighbour interpolation; the crop size is
  whatever `random_crop` produces.
  """
  enlarged = tf.image.resize(image, [286, 286],
                             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  cropped = random_crop(enlarged)
  # Horizontal flip applied at random.
  return tf.image.random_flip_left_right(cropped)

In [None]:
def preprocess_image_train(image):
  """Training preprocessing: resize, apply random jitter, then normalize."""
  return normalize(random_jitter(resize(image)))

In [None]:
def preprocess_image_test(image):
  """Evaluation preprocessing: resize and normalize only (no random augmentation)."""
  return normalize(resize(image))

In [None]:
def count_data_items(filenames):
  """Return how many entries the iterable `filenames` yields."""
  return sum(1 for _ in filenames)

In [None]:
def read_tfrecord(example):
  """Parse one serialized TFRecord example and return its decoded 3-channel JPEG image.

  The record is parsed with three scalar string features ("image_name",
  "image", "target"); only "image" is used.
  """
  feature_spec = {
      feature_name: tf.io.FixedLenFeature([], tf.string)
      for feature_name in ("image_name", "image", "target")
  }
  parsed = tf.io.parse_single_example(example, feature_spec)
  return tf.image.decode_jpeg(parsed['image'], channels=3)

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
  """Build a dataset of decoded images from the TFRecord files in `filenames`.

  `labeled` and `ordered` are accepted but not used by this implementation.
  """
  records = tf.data.TFRecordDataset(filenames)
  return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

## 5. Input Pipeline

In [None]:
# Directories holding the Monet and photo TFRecord files (relative to the working directory).
MONET_TF_PATH = '../input/gan-getting-started/monet_tfrec'
PHOTO_TF_PATH = '../input/gan-getting-started/photo_tfrec'

In [None]:
# Collect the path of every *.tfrec file in each directory.
MONET_FILENAMES = tf.io.gfile.glob(MONET_TF_PATH + '/*.tfrec')
PHOTO_FILENAMES = tf.io.gfile.glob(PHOTO_TF_PATH + '/*.tfrec')

In [None]:
# Decode both TFRecord collections into datasets of images.
# `labeled` is accepted by load_dataset but has no effect there.
monet_ds = load_dataset(MONET_FILENAMES, labeled=True)
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=True)

In [None]:
# Preview four raw images per domain on a 2x4 grid.
plt.figure(figsize=(10, 10))

# (dataset, first subplot slot): Monet images fill slots 1-4, photos fill slots 5-8.
for preview_ds, first_slot in ((monet_ds, 1), (photo_ds, 5)):
  for i, img in enumerate(preview_ds.take(4)):
    plt.subplot(2, 4, first_slot + i)
    plt.imshow(img)

plt.show()

In [None]:
# Directories holding the JPEG versions of the images; below they are only used to count files.
MONET_JPG_PATH = '../input/gan-getting-started/monet_jpg'
PHOTO_JPG_PATH = '../input/gan-getting-started/photo_jpg'

In [None]:
# Number of JPEG files per domain, used to size the train/test/val splits.
# NOTE(review): presumably the JPEG folders contain the same images as the TFRecords — confirm.
MONET_DATA_SIZE = count_data_items(tf.io.gfile.glob(MONET_JPG_PATH + '/*.jpg'))
PHOTO_DATA_SIZE = count_data_items(tf.io.gfile.glob(PHOTO_JPG_PATH + '/*.jpg'))

In [None]:
# Split the photo dataset into train (70%), test (15%) and validation (the
# remainder after train and test).
train_size = int(0.7 * PHOTO_DATA_SIZE)
test_size = int(0.15 * PHOTO_DATA_SIZE)
val_size = int(0.15 * PHOTO_DATA_SIZE)  # informational only; val takes whatever is left
BUFFER_SIZE = 1000

# reshuffle_each_iteration=False keeps the shuffled order fixed across
# iterations. With the default (True) every take()/skip() below would see a
# different order, so the three splits could share images.
photo_ds = photo_ds.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)

train_photo = photo_ds.take(train_size)
# Skip the training images first, THEN take the test images. Previously the
# skip() result was overwritten by take(test_size) on the full dataset, which
# made the test set a subset of the training set.
test_photo = photo_ds.skip(train_size).take(test_size)

val_photo = photo_ds.skip(train_size + test_size)

In [None]:
# Split the Monet dataset into train (70%), test (15%) and validation (the
# remainder after train and test).
train_size = int(0.7 * MONET_DATA_SIZE)
test_size = int(0.15 * MONET_DATA_SIZE)
val_size = int(0.15 * MONET_DATA_SIZE)  # informational only; val takes whatever is left
BUFFER_SIZE = 1000

# reshuffle_each_iteration=False keeps the shuffled order fixed across
# iterations. With the default (True) every take()/skip() below would see a
# different order, so the three splits could share images.
monet_ds = monet_ds.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)

train_monet = monet_ds.take(train_size)

# Skip the training images first, THEN take the test images. Previously the
# skip() result was overwritten by take(test_size) on the full dataset, which
# made the test set a subset of the training set.
test_monet = monet_ds.skip(train_size).take(test_size)

val_monet = monet_ds.skip(train_size + test_size)

In [None]:
def _build_pipeline(dataset, preprocess_fn, cache_before_map=False):
  """Preprocess, cache, shuffle and batch `dataset`.

  Args:
    dataset: tf.data.Dataset of decoded images.
    preprocess_fn: function applied to each image via Dataset.map.
    cache_before_map: when True the cache sits BEFORE the map, so
      `preprocess_fn` is re-executed on every pass over the data. Use this
      for random augmentation; caching after a random map would store the
      first epoch's augmented images and replay them unchanged.

  Returns:
    The shuffled dataset, batched with BATCH_SIZE.
  """
  if cache_before_map:
    dataset = dataset.cache().map(preprocess_fn, num_parallel_calls=AUTOTUNE)
  else:
    dataset = dataset.map(preprocess_fn, num_parallel_calls=AUTOTUNE).cache()
  return dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


# Training sets use random jitter, so they cache the raw images and augment
# anew on every epoch.
train_monet = _build_pipeline(train_monet, preprocess_image_train, cache_before_map=True)
train_photo = _build_pipeline(train_photo, preprocess_image_train, cache_before_map=True)

# Test/validation preprocessing has no random step, so its output can be cached directly.
test_monet = _build_pipeline(test_monet, preprocess_image_test)
val_monet = _build_pipeline(val_monet, preprocess_image_test)
test_photo = _build_pipeline(test_photo, preprocess_image_test)
val_photo = _build_pipeline(val_photo, preprocess_image_test)

In [None]:
# Preview four preprocessed training images per domain on a 2x4 grid.
plt.figure(figsize=(10, 10))

# (dataset, first subplot slot): Monet images fill slots 1-4, photos fill slots 5-8.
for batched_ds, first_slot in ((train_monet, 1), (train_photo, 5)):
  for i, img in enumerate(batched_ds.take(4)):
    plt.subplot(2, 4, first_slot + i)
    # Drop the leading batch dimension, then map the image back to 0..255 for display.
    first_image = img[0, ...]
    plt.imshow(denormalize(first_image), vmin=0, vmax=255)

plt.show()

## 6. Loss function

In [None]:
# Weight applied to the L1 term in cal_cycle_loss; identity_loss uses half of it.
LAMBDA = 10

In [None]:
# Shared binary cross-entropy used by the adversarial losses below.
# from_logits=True: inputs are treated as raw scores, not probabilities.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
# Discriminator_loss
def discriminator_loss(real, generated):
  """Discriminator objective: `real` scored against ones plus `generated` scored against zeros.

  Both arguments are discriminator outputs; the two cross-entropy terms are
  summed without further weighting.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  return loss_on_real + loss_on_fake

In [None]:
def generated_loss(generated):
  """Adversarial generator loss: cross-entropy of `generated` against an all-ones target."""
  all_ones = tf.ones_like(generated)
  return loss_obj(all_ones, generated)

In [None]:
def cal_cycle_loss(real_image, cycled_image):
    """Return LAMBDA times the mean absolute difference between the two images."""
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_diff

In [None]:
def identity_loss(real_image, same_image):
    """Return LAMBDA * 0.5 times the mean absolute difference between the two images."""
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mean_abs_diff

## 7. Build the Generator

In [None]:
def downsample(filters, apply_norm = True):
  """Build an encoder block: stride-2 Conv2D, optional InstanceNormalization, LeakyReLU.

  Args:
    filters: number of output channels of the convolution.
    apply_norm: when True an InstanceNormalization layer follows the
      convolution and the convolution is created without a bias term.

  Returns:
    A tf.keras.Sequential containing the block's layers.
  """
  conv = tf.keras.layers.Conv2D(
      filters,
      kernel_size=4,
      strides=2,
      padding='same',
      kernel_initializer=tf.random_normal_initializer(0, 0.02),
      # The bias is only kept when no normalization layer follows.
      use_bias=not apply_norm)
  block_layers = [conv]

  if apply_norm:
    norm = tfa.layers.InstanceNormalization(
        gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
    block_layers.append(norm)

  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

In [None]:
def upsample(filters, apply_dropout=True):
  """Build a decoder block: stride-2 Conv2DTranspose, InstanceNormalization, optional Dropout, ReLU.

  Args:
    filters: number of output channels of the transposed convolution.
    apply_dropout: when True a Dropout(0.5) layer is inserted between the
      normalization and the activation.

  Returns:
    A tf.keras.Sequential containing the block's layers.
  """
  deconv = tf.keras.layers.Conv2DTranspose(
      filters,
      kernel_size=4,
      strides=2,
      padding='same',
      kernel_initializer=tf.random_normal_initializer(0, 0.02),
      use_bias=False)
  norm = tfa.layers.InstanceNormalization(
      gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
  block_layers = [deconv, norm]

  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))

  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)


In [None]:
class Generator(tf.keras.Model):
  """Encoder-decoder generator with skip connections.

  Eight `downsample` blocks encode the input, seven `upsample` blocks decode
  it, and a final stride-2 transposed convolution with tanh activation emits
  a 3-channel image. Each decoder output is concatenated with the encoder
  output of matching depth.
  """

  def __init__(self):
    super(Generator, self).__init__()
    # Encoder: each block halves the spatial size (stride-2 convolution).
    self.downstack = [downsample(64, apply_norm=False),
                      downsample(128),
                      downsample(256),
                      downsample(512),
                      downsample(512),
                      downsample(512),
                      downsample(512),
                      downsample(512)]

    # Decoder: each block doubles the spatial size; only the first three use dropout.
    self.upstack = [upsample(512),
                    upsample(512),
                    upsample(512),
                    upsample(512, apply_dropout=False),
                    upsample(256, apply_dropout=False),
                    upsample(128, apply_dropout=False),
                    upsample(64, apply_dropout=False)]

    # Output layer: back to 3 channels, tanh activation.
    self.last = tf.keras.layers.Conv2DTranspose(filters = 3,
                                                kernel_size = 4,
                                                strides = 2,
                                                padding = 'same',
                                                kernel_initializer = tf.random_normal_initializer(0,0.02),
                                                activation = 'tanh')

  def call(self, inputs, training=False):
    """Run the generator on a batch of images.

    Args:
      inputs: batch of input images.
      training: forwarded to every encoder/decoder block so that the Dropout
        layers inside the decoder follow the caller's mode.

    Returns:
      The output of the final tanh transposed convolution.
    """
    x = inputs
    skips = []

    # Run the encoder, remembering every block output for the skip connections.
    for down in self.downstack:
      x = down(x, training=training)
      skips.append(x)

    # The deepest encoder output is the decoder's direct input, so it is not
    # used as a skip. The rest are reversed so the deepest remaining output
    # pairs with the first decoder block.
    skips = reversed(skips[:-1])

    # Run the decoder, joining each block's output with its skip on the channel axis.
    # tf.concat is used instead of constructing a new Concatenate layer on every call.
    for up, skip in zip(self.upstack, skips):
      x = up(x, training=training)
      x = tf.concat([x, skip], axis=-1)

    return self.last(x)

In [None]:
# Build a standalone generator for 256x256 3-channel inputs and print its layer summary.
generator = Generator()
generator.build((None, 256, 256, 3))
generator.summary()

## 8. Build the discriminator

In [None]:
class Discriminator(tf.keras.Model):
  """Convolutional discriminator that ends in a single-channel, activation-free convolution.

  Three `downsample` blocks are followed by zero-padding, a 512-filter
  convolution with instance normalization and LeakyReLU, another
  zero-padding, and a final 1-filter convolution.
  """

  def __init__(self):
    super().__init__()

    # Strided encoder blocks.
    self.down1 = downsample(64, apply_norm=False)
    self.down2 = downsample(128)
    self.down3 = downsample(256)

    # Padded stride-1 convolution, then normalization and activation.
    self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
    self.conv = tf.keras.layers.Conv2D(
        filters=512,
        kernel_size=4,
        strides=1,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False)
    self.norm1 = tfa.layers.InstanceNormalization(
        gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02))
    self.leaky_relu = tf.keras.layers.LeakyReLU()

    # Padded stride-1 convolution producing the one-channel output.
    self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
    self.last = tf.keras.layers.Conv2D(
        filters=1,
        kernel_size=4,
        strides=1,
        kernel_initializer=tf.random_normal_initializer(0., 0.02))

  def call(self, inputs, training=False):
    """Return the discriminator's raw output for a batch of images."""
    features = inputs
    for stage in (self.down1, self.down2, self.down3):
      features = stage(features)

    features = self.conv(self.zero_pad1(features))
    features = self.leaky_relu(self.norm1(features))

    return self.last(self.zero_pad2(features))

In [None]:
# Build a standalone discriminator for 256x256 3-channel inputs and print its layer summary.
discriminator = Discriminator()
discriminator.build((None, 256, 256, 3))
discriminator.summary()

## 9. Build the CycleGAN model

In [None]:
class CycleGAN(tf.keras.Model):
  """CycleGAN wrapper that trains two generators and two discriminators together.

  `generator_monet` maps photos to Monet-style images and `generator_photo`
  maps Monet images to photos. Each domain has a discriminator that scores
  real versus generated images of that domain.
  """

  def __init__(self, generator_monet, generator_photo, discrimator_monet, discriminator_photo):
    """Store the four sub-models.

    Args:
      generator_monet: model turning photos into Monet-style images.
      generator_photo: model turning Monet images into photos.
      discrimator_monet: discriminator for the Monet domain. The parameter
        keeps its original (misspelled) name so existing callers still work.
      discriminator_photo: discriminator for the photo domain.
    """
    super(CycleGAN, self).__init__()
    self.generator_monet = generator_monet
    self.generator_photo = generator_photo
    # Use the constructor argument. Previously this read a global named
    # `discriminator_monet`, which only worked if such a global existed.
    self.discriminator_monet = discrimator_monet
    self.discriminator_photo = discriminator_photo


  def compile(self,
              generator_monet_optimizer,
              generator_photo_optimizer,
              discriminator_monet_optimizer,
              discriminator_photo_optimizer,
              discriminator_loss,
              adversarial_loss,
              calc_cycle_loss,
              identity_loss):
    """Register one optimizer per sub-model and the four loss functions.

    Args:
      generator_monet_optimizer: optimizer for `generator_monet`.
      generator_photo_optimizer: optimizer for `generator_photo`.
      discriminator_monet_optimizer: optimizer for `discriminator_monet`.
      discriminator_photo_optimizer: optimizer for `discriminator_photo`.
      discriminator_loss: callable(real_output, generated_output) -> loss.
      adversarial_loss: callable(generated_output) -> generator adversarial loss.
      calc_cycle_loss: callable(real_image, cycled_image) -> cycle loss.
      identity_loss: callable(real_image, same_image) -> identity loss.
    """
    super(CycleGAN, self).compile()
    self.generator_monet_optimizer = generator_monet_optimizer
    self.generator_photo_optimizer = generator_photo_optimizer
    self.discriminator_monet_optimizer = discriminator_monet_optimizer
    self.discriminator_photo_optimizer = discriminator_photo_optimizer
    self.discriminator_loss = discriminator_loss
    self.adversarial_loss = adversarial_loss
    # Store the argument that was passed in, not a module-level function.
    self.calc_cycle_loss = calc_cycle_loss
    self.identity_loss = identity_loss

  def train_step(self, batch_data):
    """Run one optimization step on a `(monet, photo)` batch pair.

    Args:
      batch_data: tuple `(monet, photo)` of image batches, in that order.

    Returns:
      Dict with the total loss of each generator and each discriminator.
    """
    monet, photo =  batch_data

    # persistent=True because four separate gradients are taken from this tape.
    with tf.GradientTape(persistent=True) as tape:

      # training=True is passed explicitly so training-only layers (Dropout)
      # are active; without it the sub-models run with their default
      # training=False.

      # photo -> fake monet -> back to photo
      fake_monet = self.generator_monet(photo, training=True)
      cycled_photo = self.generator_photo(fake_monet, training=True)

      # monet -> fake photo -> back to monet
      fake_photo = self.generator_photo(monet, training=True)
      cycled_monet = self.generator_monet(fake_photo, training=True)

      # monet discriminator
      fake_monet_disc = self.discriminator_monet(fake_monet, training=True)
      real_monet_disc = self.discriminator_monet(monet, training=True)

      # photo discriminator
      fake_photo_disc = self.discriminator_photo(fake_photo, training=True)
      real_photo_disc = self.discriminator_photo(photo, training=True)

      # Each generator applied to its own target domain, used for the identity loss.
      same_monet = self.generator_monet(monet, training=True)
      same_photo = self.generator_photo(photo, training=True)

      # Adversarial loss of each generator.
      gen_monet_adversarial_loss = self.adversarial_loss(fake_monet_disc)
      gen_photo_adversarial_loss = self.adversarial_loss(fake_photo_disc)

      # Cycle-consistency loss compares each real image with its round-trip
      # reconstruction (the `cycled_*` tensors), not with the identity outputs.
      total_cycle_loss = (self.calc_cycle_loss(monet, cycled_monet)
                          + self.calc_cycle_loss(photo, cycled_photo))

      # identity loss
      gen_monet_identity_loss = self.identity_loss(monet, same_monet)
      gen_photo_identity_loss = self.identity_loss(photo, same_photo)

      # Total_loss = adversarial_loss + cycle_loss + identity_loss
      total_gen_monet_loss = gen_monet_adversarial_loss + total_cycle_loss + gen_monet_identity_loss
      total_gen_photo_loss = gen_photo_adversarial_loss + total_cycle_loss + gen_photo_identity_loss

      # discriminator loss
      disc_monet_loss = self.discriminator_loss(real_monet_disc, fake_monet_disc)
      disc_photo_loss = self.discriminator_loss(real_photo_disc, fake_photo_disc)

    # Calculate the gradients for generators and discriminators.
    gen_monet_gradients = tape.gradient(total_gen_monet_loss,
                                        self.generator_monet.trainable_variables)
    gen_photo_gradients = tape.gradient(total_gen_photo_loss,
                                        self.generator_photo.trainable_variables)

    disc_monet_gradients = tape.gradient(disc_monet_loss,
                                         self.discriminator_monet.trainable_variables)
    disc_photo_gradients = tape.gradient(disc_photo_loss,
                                         self.discriminator_photo.trainable_variables)

    # Apply each gradient set with the matching optimizer.
    self.generator_monet_optimizer.apply_gradients(zip(gen_monet_gradients,
                                                   self.generator_monet.trainable_variables))

    self.generator_photo_optimizer.apply_gradients(zip(gen_photo_gradients,
                                                   self.generator_photo.trainable_variables))

    self.discriminator_monet_optimizer.apply_gradients(zip(disc_monet_gradients,
                                                       self.discriminator_monet.trainable_variables))
    self.discriminator_photo_optimizer.apply_gradients(zip(disc_photo_gradients,
                                                       self.discriminator_photo.trainable_variables))

    return {
        "total_gen_monet_loss": total_gen_monet_loss,
        "total_gen_photo_loss": total_gen_photo_loss,
        "disc_monet_loss": disc_monet_loss,
        "disc_photo_loss": disc_photo_loss
    }

In [None]:
# Create the two generators and two discriminators used for training,
# each built for 256x256 3-channel inputs.
generator_monet = Generator()
generator_monet.build((None, 256, 256, 3))
generator_photo = Generator()
generator_photo.build((None, 256, 256, 3))
discriminator_monet = Discriminator()
discriminator_monet.build((None, 256, 256, 3))
discriminator_photo = Discriminator()
discriminator_photo.build((None, 256, 256, 3))

In [None]:
# One Adam optimizer per network, all with learning rate 2e-4 and beta_1 = 0.5.
generator_monet_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_photo_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_monet_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_photo_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Wrap the four networks; arguments are positional in the order
# (monet generator, photo generator, monet discriminator, photo discriminator).
cycle_gan_model = CycleGAN(generator_monet, generator_photo, discriminator_monet, discriminator_photo)

In [None]:
# Register the optimizers and loss functions (all positional).
# `generated_loss` is bound to compile()'s `adversarial_loss` parameter.
cycle_gan_model.compile(generator_monet_optimizer,
                        generator_photo_optimizer,
                        discriminator_monet_optimizer,
                        discriminator_photo_optimizer,
                        discriminator_loss,
                        generated_loss,
                        cal_cycle_loss,
                        identity_loss)

In [None]:
import time

# NOTE: process_time() measures CPU time of this process, not wall-clock time.
t1 = time.process_time()

# CycleGAN.train_step unpacks every batch as `monet, photo = batch_data`, so
# the zipped dataset must yield (monet, photo) pairs in exactly that order.
# Zipping (photo, monet) would feed each network the wrong domain.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_monet, train_photo)),
    epochs=5
)

t2 = time.process_time()

print ("Accelerator =  ----- Computation time = " + str(1000*(t2 - t1)) + "ms")

## 10. Test the model

In [None]:
def generate_images(model, test_input):
  """Show the first image of `test_input` next to the model's prediction for it.

  Both images are rescaled with x * 0.5 + 0.5 before display.
  """
  prediction = model(test_input)
  panels = (('Input Image', test_input), ('Predicted Image', prediction))

  plt.figure(figsize=(12, 12))

  for slot, (title, batch) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(batch[0] * 0.5 + 0.5)
    plt.title(title)
    plt.axis('off')

  plt.show()