# CycleGAN - Style Transfer (Photos to Monet Paintings)


This notebook follows the development of the CycleGAN architecture to capture the special characteristics of Monet paintings and to figure out how these characteristics can be translated into the other image collection, all in the absence of any paired training examples.

This is a Kaggle competition to generate images in the style of Oscar-Claude Monet using generative adversarial networks (GANs). The competition provides datasets containing Monet paintings ("monet_tfrec") and real photos ("photo_tfrec"). The images are provided in TFRecord format as well as in JPEG format. In this competition, we are asked to use the Monet paintings to train a model that adds Monet style to the real photos, and to submit the generated JPEG images as a zip file.

This problem is an image-to-image translation task. In general, there are two approaches to tackle it: a paired approach and an unpaired approach. The paired approach requires paired representations of the same data in both domains. In the unpaired approach, the two domains can be completely different, with no one-to-one correspondence between their images, as in our case (converting real photos into Monet-style paintings). Several methods work on the principle of the unpaired approach, and one of them that performs very well and has shown impressive results is CycleGAN.



Sources:
* https://arxiv.org/pdf/1703.10593.pdf (paper)

* https://www.kaggle.com/amyjang/monet-cyclegan-tutorial (baseline competition)

* https://www.kaggle.com/dimitreoliveira/introduction-to-cyclegan-monet-paintings

* https://junyanz.github.io/CycleGAN/

* https://hardikbansal.github.io/CycleGANBlog/

* https://www.tensorflow.org/tutorials/generative/cyclegan

* https://towardsdatascience.com/cyclegan-learning-to-translate-images-without-paired-training-data-5b4e93862c8d

# CycleGAN 
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

It uses generators and discriminators: each generator has to produce images that are accepted by its discriminator, while the discriminator tries to detect the images that are not real and reject the ones produced by the generator. CycleGAN trains two generators and two discriminators (one pair per translation direction) and adds a cycle-consistency loss, which allows training without the need for paired data.
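For reference, the full objective from the CycleGAN paper (first source above) can be written using this notebook's names, with G_MONET mapping photos to Monet paintings and G_PHOTO mapping Monet paintings to photos. This is our transcription: the L1 norms are implemented in the code as per-pixel mean absolute errors, λ corresponds to `lambda_cycle` (15 in this notebook, 10 in the paper), and the identity term uses the 0.5λ weight of this notebook's `identity_loss`.

$$
\mathcal{L}_{\text{cyc}} = \mathbb{E}_{x \sim \text{photo}}\big[\lVert G_{\text{PHOTO}}(G_{\text{MONET}}(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim \text{Monet}}\big[\lVert G_{\text{MONET}}(G_{\text{PHOTO}}(y)) - y \rVert_1\big]
$$

$$
\mathcal{L}_{\text{id}} = \mathbb{E}_{y \sim \text{Monet}}\big[\lVert G_{\text{MONET}}(y) - y \rVert_1\big] + \mathbb{E}_{x \sim \text{photo}}\big[\lVert G_{\text{PHOTO}}(x) - x \rVert_1\big]
$$

$$
\mathcal{L} = \mathcal{L}_{\text{GAN}}(G_{\text{MONET}}, D_{\text{MONET}}) + \mathcal{L}_{\text{GAN}}(G_{\text{PHOTO}}, D_{\text{PHOTO}}) + \lambda\,\mathcal{L}_{\text{cyc}} + 0.5\,\lambda\,\mathcal{L}_{\text{id}}
$$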



# Implementation

Below, we set up the input pipeline and import all the dependencies.

A TPU accelerator should be enabled to run this notebook in a reasonable time. If no TPU is found, the code falls back to the default distribution strategy (CPU or a single GPU).

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTO = tf.data.experimental.AUTOTUNE
print(tf.__version__)


Load the dataset

In [None]:
GCS_PATH_MONET = KaggleDatasets().get_gcs_path('monet-tfrecords-extdata')
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')
GCS_PATH_MONET, GCS_PATH

In [None]:
#GCS_PATH_tareq = KaggleDatasets().get_gcs_path('tareq-picture')

In [None]:
import matplotlib.image as mpimg 
import matplotlib.pyplot as plt 
  
# Read Images 
# img = mpimg.imread() 

In [None]:
# To read external data (my picture)
# from PIL import Image  
  
# # Opens a image in RGB mode  
# im = Image.open('../input/tareq-picture/image_6487327.JPG') 

In [None]:
# To resize my picture and make it (256,256,3)
# im2 = im.resize((256,256))
# im2 = np.asarray(im2)
# im2 = im2[:,:, 0:3]

# plt.imshow(im2)
# plt.show()

# print(im2.shape)

Reading the data

In [None]:
import re
MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

def count_data_items(filenames):
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)

BATCH_SIZE =  4
EPOCHS_NUM = 30

print(f'Monet TFRecord files: {len(MONET_FILENAMES)}')
print(f'Monet image files: {n_monet_samples}')
print(f'Photo TFRecord files: {len(PHOTO_FILENAMES)}')
print(f'Photo image files: {n_photo_samples}')
print(f"Batch_size: {BATCH_SIZE}")
print(f"Epochs number: {EPOCHS_NUM}")
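The counting trick above relies on each TFRecord shard name encoding its number of images just before the extension. Below is a small standalone sketch of the same logic, run on made-up file names. The helper name `count_images_from_names`, the `gs://bucket/...` prefix and the shard names are hypothetical, and the helper returns a plain Python `int` instead of a NumPy scalar.

```python
import re

import numpy as np


def count_images_from_names(filenames):
    """Sum the per-shard image counts encoded in TFRecord file names.

    Assumes each name ends in '-<count>.tfrec', e.g. 'monet00-60.tfrec'
    (a hypothetical name used for illustration).
    """
    pattern = re.compile(r"-([0-9]*)\.")
    counts = [int(pattern.search(name).group(1)) for name in filenames]
    return int(np.sum(counts))


# Hypothetical shard names: five shards of 60 images each.
example_files = [f"gs://bucket/monet_tfrec/monet{i:02d}-60.tfrec" for i in range(5)]
print(count_images_from_names(example_files))  # 300
```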

IMAGE PRE-PROCESSING

* Decoding the JPEG bytes and reshaping each image to 256x256x3
* Normalizing the pixel values to [-1, 1]

In [None]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image

def read_tfrecord(example):
    tfrecord_format = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image
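`decode_image` maps raw 8-bit pixel values to [-1, 1], which matches the tanh output range of the generator. The plotting cells invert this with `* 0.5 + 0.5`. Here is a NumPy sketch of that round trip. The helper names `normalize` and `denormalize` are ours and do not appear in the notebook.

```python
import numpy as np


def normalize(pixels):
    """Map uint8 pixels in [0, 255] to float32 in [-1, 1], as in decode_image."""
    return pixels.astype(np.float32) / 127.5 - 1.0


def denormalize(image):
    """Map [-1, 1] back to [0, 1] for plt.imshow, as in `x * 0.5 + 0.5`."""
    return image * 0.5 + 0.5


pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = normalize(pixels)
print(scaled)               # approximately [-1.0, -0.0039, 1.0]
print(denormalize(scaled))  # approximately [0.0, 0.498, 1.0]
```

The round trip ends in [0, 1] rather than [0, 255], which is the range `plt.imshow` expects for float RGB images.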

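We apply random jittering and mirroring to the training dataset. These image augmentation techniques help to avoid overfitting. The data_augment function below:

* Resizes an image to a bigger height and width (286x286, and sometimes 300x300)
* Randomly crops it back to the target size (256x256)
* Randomly rotates the image by 90, 180 or 270 degrees
* Randomly flips the image left-right and up-down, and sometimes transposes it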

In [None]:
def data_augment(image):
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    if p_crop > .5:
        image = tf.image.resize(image, [286, 286]) #resizing to 286 x 286 x 3
        image = tf.image.random_crop(image, size=[256, 256, 3]) # randomly cropping to 256 x 256 x 3
        if p_crop > .9:
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])
    
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3) # rotate 270°
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2) # rotate 180°
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1) # rotate 90°

    # random mirroring
    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            image = tf.image.transpose(image)
    
    return image

def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    return dataset

def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):

    monet_ds = load_dataset(monet_files)
    photo_ds = load_dataset(photo_files)
    
    if augment:
        monet_ds = monet_ds.map(augment, num_parallel_calls=AUTO)
        photo_ds = photo_ds.map(augment, num_parallel_calls=AUTO)
        
    if repeat:
        monet_ds = monet_ds.repeat()
        photo_ds = photo_ds.repeat()
    if shuffle:
        monet_ds = monet_ds.shuffle(2048)
        photo_ds = photo_ds.shuffle(2048)
        
    monet_ds = monet_ds.batch(batch_size, drop_remainder=True)
    photo_ds = photo_ds.batch(batch_size, drop_remainder=True)
    # No .cache() here: caching after repeat() on an infinite dataset never
    # completes a pass and keeps accumulating elements in memory.
    monet_ds = monet_ds.prefetch(AUTO)
    photo_ds = photo_ds.prefetch(AUTO)
    
    gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
    
    return gan_ds
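Each branch in `data_augment` is gated by an independent uniform draw, so the thresholds fix how often each transform fires:

* Jittering happens for 50% of images, with a second, stronger jitter for 10%.
* Rotation is skipped for 50% of images; 90° is applied to 20%, 180° to 20% and 270° to 10%.
* The random-flip step runs for 40% of images, and each flip then fires with probability 0.5. An extra transpose is added for 10%.

The hypothetical helper `augmentation_plan` below restates those decisions in plain Python, so the probabilities can be checked without TensorFlow. It only describes the decisions and does not transform any image.

```python
def augmentation_plan(p_crop, p_rotate, p_spatial):
    """Return which transforms data_augment would apply for given uniform draws.

    Mirrors the thresholds in data_augment; descriptive only (hypothetical helper).
    """
    plan = []
    if p_crop > 0.5:
        plan.append("resize 286 + crop 256")
        if p_crop > 0.9:
            plan.append("resize 300 + crop 256")
    if p_rotate > 0.9:
        plan.append("rot90 k=3")
    elif p_rotate > 0.7:
        plan.append("rot90 k=2")
    elif p_rotate > 0.5:
        plan.append("rot90 k=1")
    if p_spatial > 0.6:
        plan.append("random flips")
        if p_spatial > 0.9:
            plan.append("transpose")
    return plan


# Count how often each rotation is chosen over 1000 evenly spaced draws in (0, 1).
draws = [(i + 0.5) / 1000 for i in range(1000)]
rotations = {"none": 0, "rot90 k=1": 0, "rot90 k=2": 0, "rot90 k=3": 0}
for p in draws:
    steps = [s for s in augmentation_plan(0.0, p, 0.0) if s.startswith("rot90")]
    rotations[steps[0] if steps else "none"] += 1
print(rotations)  # {'none': 500, 'rot90 k=1': 200, 'rot90 k=2': 200, 'rot90 k=3': 100}
```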

Building the full dataset, with the image preprocessing and augmentation added to the pipeline.

In [None]:
full_dataset = get_gan_dataset(MONET_FILENAMES, PHOTO_FILENAMES, augment=data_augment, repeat=True, shuffle=True, batch_size=BATCH_SIZE)


Visualizing the data to check that it has been loaded correctly.

In [None]:
example_monet , example_photo = next(iter(full_dataset))

In [None]:
# Visualizing the real photo
plt.subplot(121)
plt.title('Real photo')
plt.imshow(example_photo[2] * 0.5 + 0.5)

# Visualizing the Monet painting
plt.subplot(122)
plt.title('Monet painting')
plt.imshow(example_monet[2]* 0.5 + 0.5)

MODEL

To build the model, we will follow these steps:

* Build the Generator
* Build the Discriminator
* Define the loss functions:

> Discriminator loss

> Generator loss

> Adversarial loss

> Cycle loss

> Identity loss

* Define the optimizers

Build the Generator:
The generator architecture is a modified U-Net, consisting of an encoder and a decoder, each made up of simpler blocks of layers.
Each encoder block, which we call downsample-k (where k denotes the number of filters), consists of the following layers:
* Convolution
* Instance Normalization (not applied to the first block)
* Leaky ReLU

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    # Convolutional layer
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))
    # Normalization layer
    if apply_instancenorm:
        result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    # Activation layer
    result.add(layers.LeakyReLU())

    return result

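The decoder (upsampling) is made up of blocks, which we call upsample-k, with the following layers:
* Transposed Convolution
* Instance Normalization
* Dropout (applied only to the first 3 blocks)
* ReLU

Skip connections link each decoder block to the encoder block of matching resolution.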

In [None]:
def upsample(filters, size, apply_dropout=False):
    # Weight initializers
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    # Transposed convolutional layer (stride 2 doubles height and width)
    result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))
    # Instance Normalization
    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    # Dropout layer
    if apply_dropout:
        result.add(layers.Dropout(0.5))
    # Activation layer
    result.add(layers.ReLU())

    return result

Generator:
* First, a stack of eight downsample-k blocks (the encoder)
* Then, three upsample-k blocks with Dropout
* Then, four upsample-k blocks without Dropout; the output of every upsample block is concatenated with the matching encoder output (skip connection)
* Finally, a Transposed Convolution with 3 filters and a tanh activation converts the output into a 256x256 image with 3 channels


All convolutional layers in downsample use strides = 2, which halves the height and width (the first two dimensions, not counting the batch dimension). Likewise, the Transposed Convolution layers in upsample use strides = 2, which doubles them.
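The shape comments in the `Generator` cell follow from this halving/doubling rule plus the channel concatenation at each skip connection. The plain-Python sketch below reproduces those comments. The helper `unet_shapes` is ours, and it assumes power-of-two sizes: with `padding='same'`, a stride-2 convolution gives `ceil(n / 2)`, which equals `n // 2` here.

```python
def unet_shapes(size=256):
    """Trace (height, width, channels) through the Generator's encoder and decoder.

    With padding='same', a stride-2 Conv2D halves height/width and a stride-2
    Conv2DTranspose doubles them; each decoder output is concatenated with the
    matching encoder output, so their channel counts add up.
    """
    down_filters = [64, 128, 256, 512, 512, 512, 512, 512]
    up_filters = [512, 512, 512, 512, 256, 128, 64]

    encoder = []
    for f in down_filters:
        size //= 2
        encoder.append((size, size, f))

    # Skip connections: all encoder outputs except the bottleneck, deepest first.
    skips = list(reversed(encoder[:-1]))
    decoder = []
    for f, skip in zip(up_filters, skips):
        size *= 2
        assert size == skip[0], "spatial sizes must match to concatenate"
        decoder.append((size, size, f + skip[2]))

    final = (size * 2, size * 2, 3)  # last Conv2DTranspose with 3 filters
    return encoder, decoder, final


encoder, decoder, final = unet_shapes()
print(encoder[-1])  # (1, 1, 512)
print(decoder[0])   # (2, 2, 1024)
print(decoder[-1])  # (128, 128, 128)
print(final)        # (256, 256, 3)
```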

In [None]:
def Generator():
    inputs = layers.Input(shape=[256,256,3])

    # bs = batch size
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

Testing the model and showing the generator architecture

In [None]:
generator_g = Generator()
tf.keras.utils.plot_model(generator_g, show_shapes=True, dpi=64)

In [None]:
photo = (example_photo[0,...] * 0.5 + 0.5)
plt.imshow(photo, vmin=0, vmax=255) 

In [None]:
example_gen_output_y = generator_g(photo[tf.newaxis,...], training=False)
plt.imshow(example_gen_output_y[0]) 

In [None]:
# We pass the denormalized photo so that some result can be seen, since the model is not trained
photo = (example_photo[0,...] * 0.5 + 0.5)
example_gen_output_y = generator_g(photo[tf.newaxis,...], training=False)

plt.subplot(1,2,1)
plt.imshow(photo, vmin=0, vmax=255) 

plt.subplot(1,2,2)
plt.imshow(example_gen_output_y[0]) 

plt.show()

Build the Discriminator

The discriminator's task is to decide whether an input image is real or was produced by a generator (fake).

Its architecture is a PatchGAN-type convolutional network. Instead of returning a single real/fake decision for the whole image, it decides whether each patch of the image can be considered real or fake.
Like the generator's encoder, it is built from downsample-k blocks, each of which compresses the image (downsampling). A block consists of the following layers:
* Convolution
* Instance Normalization (not applied to the first block)
* Leaky ReLU

All convolutional layers in downsample use strides = 2, which halves the height and width.
The discriminator output has shape (batch_size, 30, 30, 1). Each of the 30x30 output values classifies a 70x70 patch of the input image.
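The 30x30 output size and the 70x70 patch size can both be checked by hand. The sketch below applies the standard convolution output-size formula and the receptive-field recurrence `r_in = (r_out - 1) * stride + kernel` to the layers of `Discriminator`. The helper and variable names are ours. They were chosen so as not to overwrite the notebook's `layers` import.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output length of a 'valid' convolution after `pad` zeros on each side."""
    return (size + 2 * pad - kernel) // stride + 1


# Spatial size through the Discriminator (stride-2 'same' convs halve the size).
out_size = 256
for _ in range(3):                          # down1, down2, down3: 128 -> 64 -> 32
    out_size //= 2
out_size = conv_out(out_size, 4, 1, pad=1)  # ZeroPadding2D + Conv2D(512): 34 -> 31
out_size = conv_out(out_size, 4, 1, pad=1)  # ZeroPadding2D + Conv2D(1):   33 -> 30
print(out_size)  # 30

# Receptive field: walk back from one output element to the input image.
conv_stack = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]  # (kernel, stride), input -> output
rf = 1
for kernel, stride in reversed(conv_stack):
    rf = (rf - 1) * stride + kernel
print(rf)  # 70
```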

In [None]:
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

Testing the model and showing the discriminator architecture



In [None]:
discriminator_y = Discriminator()
tf.keras.utils.plot_model(discriminator_y, show_shapes=True, dpi=64)

In [None]:
photo = example_photo[0,...]* 0.5 + 0.5
photo_1=photo 
plt.imshow(photo_1, vmin=0, vmax=255) 

In [None]:

example_gen_output_y = generator_g(photo_1[tf.newaxis,...], training=False)
plt.imshow(example_gen_output_y[0,...])

In [None]:
# example_disc_out = discriminator_y([example_photo, example_gen_output_y], training=False)

In [None]:
# example_disc_out = discriminator_y([example_photo, example_gen_output_y], training=False)
# m = example_disc_out[0,...,-1].numpy()*1000
# im = plt.imshow(m, vmin=-20, vmax=20, cmap='RdBu_r')
# plt.colorbar(im,fraction=0.046, pad=0.04)

In [None]:
# print(example_disc_out.shape)

In [None]:
# # We pass the denormalized photo so that some result can be seen, since the model is not trained
# photo = example_photo[0,...]* 0.5 + 0.5
# example_gen_output_y = generator_g(photo[tf.newaxis,...], training=False)
# example_disc_out = discriminator_y([example_photo, example_gen_output_y], training=False)

# print(example_disc_out.shape)

# plt.figure(figsize=(10,10))

# plt.subplot(1,3,1)
# plt.imshow(photo, vmin=0, vmax=255) 

# plt.subplot(1,3,2)
# plt.imshow(example_gen_output_y[0,...]) 

# plt.subplot(1,3,3)
# m = example_disc_out[0,...,-1].numpy()*1000
# im = plt.imshow(m, vmin=-20, vmax=20, cmap='RdBu_r')
# plt.colorbar(im,fraction=0.046, pad=0.04)

# plt.show()

Build the complete Model
The complete model consists of:  
* 2 generators (G_MONET and G_PHOTO)
The G_MONET generator learns how to transform a photograph into a Monet painting.

The G_PHOTO generator learns how to transform a Monet painting into a photograph.

As the images are not paired, two cycles are needed:
Input Photo -> G_MONET -> Fake Monet -> G_PHOTO -> Cycle Photo
Input Monet -> G_PHOTO -> Fake Photo -> G_MONET -> Cycle Monet

* 2 discriminators (D_MONET and D_PHOTO)
The D_MONET discriminator learns to tell whether a Monet painting is real or fake. It is used to calculate the adversarial loss that improves the G_MONET generator.

The D_PHOTO discriminator learns to tell whether a photo is real or fake. It is used to calculate the adversarial loss that improves the G_PHOTO generator.


In [None]:
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos


In [None]:
to_monet = monet_generator(example_photo)

plt.subplot(1, 2, 1)
plt.title("Original Photo")
plt.imshow(example_photo[0] * 0.5 + 0.5)

plt.subplot(1, 2, 2)
plt.title("Monet-esque Photo")
plt.imshow(to_monet[0] * 0.5 + 0.5)
plt.show()

In [None]:
# to_photo = photo_generator(example_monet)

# plt.subplot(1, 2, 1)
# plt.title("Original Monet")
# plt.imshow(example_monet[0] * 0.5 + 0.5)

# plt.subplot(1, 2, 2)
# plt.title("Photo-esque Monet")
# plt.imshow(to_photo [0] * 0.5 + 0.5)
# plt.show()

In [None]:
#example_disc_out = discriminator_y([example_photo, example_gen_output_y], training=False)

In [None]:
# plt.subplot(1,3,3)
# m = example_disc_out[0,...,-1].numpy()*1000
# im = plt.imshow(m, vmin=-20, vmax=20, cmap='RdBu_r')
# plt.colorbar(im,fraction=0.046, pad=0.04)

Example

Generator part:
* Starting from a photo, a fake Monet painting is generated, and from this fake painting the other generator attempts to reconstruct the original photo

* Starting from a Monet painting, a fake photo is generated, and from this fake photo the other generator attempts to reconstruct the original Monet

In [None]:
photo = (example_photo[0,...] * 0.5 + 0.5)
monet = (example_monet[0,...] * 0.5 + 0.5)

# From photo we generate Monet (fake) and regenerate the photo (cycle) again
example_gen_output_monet_fake = monet_generator(photo[tf.newaxis,...], training=False)
example_gen_output_photo_cycle =  photo_generator(example_gen_output_monet_fake, training=False)

# We run the discriminator for Monet (fake)
example_disc_out_monet = monet_discriminator(example_gen_output_monet_fake, training=False)


# From Monet we generate photo (fake) and regenerate Monet (cycle) again
example_gen_output_photo_fake =  photo_generator(monet[tf.newaxis,...], training=False)
example_gen_output_monet_cycle = monet_generator(example_gen_output_photo_fake, training=False)

# We execute the discriminator for Photo (fake)
example_disc_out_photo = photo_discriminator(example_gen_output_photo_fake, training=False)


# We show the results. Since the network is not trained yet, the outputs are poor,
# so we scale them up (contrast) to make something visible

plt.figure(figsize=(10,10))

# Input Photo
plt.subplot(2,4,1)
plt.imshow(photo, vmin=0, vmax=255) 

# Fake Monet
plt.subplot(2,4,2)
#m = example_gen_output_monet_fake[0,...].numpy()
#print(np.min(m), np.max(m))
contrast = 100 
plt.imshow(example_gen_output_monet_fake[0,...]*contrast) 

# Photo Cycle
plt.subplot(2,4,3)
#m = example_gen_output_photo_cycle[0,...].numpy()
#print(np.min(m), np.max(m))
contrast = 100
plt.imshow(example_gen_output_photo_cycle[0,...]*contrast) 

# Monet discriminator result
#plt.subplot(2,4,4)
#m = example_disc_out_monet[0,...,-1].numpy()
#print(np.min(m), np.max(m))
#contrast = 1000
#im = plt.imshow(m*contrast, vmin=-20, vmax=20, cmap='RdBu_r')
#plt.colorbar(im,fraction=0.046, pad=0.04)



# Input Monet
plt.subplot(2,4,5)
plt.imshow(monet, vmin=0, vmax=255) 

# Fake Photo
plt.subplot(2,4,6)
#m = example_gen_output_photo_fake[0,...].numpy()
#print(np.min(m), np.max(m))
contrast = 100
plt.imshow(example_gen_output_photo_fake[0,...]*contrast) 

# Monet Cycle
plt.subplot(2,4,7)
#m = example_gen_output_monet_cycle[0,...].numpy()
#print(np.min(m), np.max(m))
contrast = 100
plt.imshow(example_gen_output_monet_cycle[0,...]*contrast) 

# Photo discriminator result  
#plt.subplot(2,4,8)
#m = example_disc_out_photo[0,...,-1].numpy()
#print(np.min(m), np.max(m))
#contrast = 1000
#im = plt.imshow(m*contrast, vmin=-20, vmax=20, cmap='RdBu_r')
#plt.colorbar(im,fraction=0.046, pad=0.04)

In [None]:
class CycleGan(keras.Model):
    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=15,
    ):
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
        
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # generating itself
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator used to check, inputing real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator used to check, inputing fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # evaluates total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # evaluates total generator loss
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }


Sigmoid cross-entropy is used to calculate the adversarial losses of both the discriminator and the generator.

**Discriminator loss**
The discriminator loss function takes 2 inputs.

For discriminator_monet, the inputs are:
* The output of discriminator_monet when its input is a real Monet from the training set
* The output of discriminator_monet when its input is a fake Monet produced by generator_monet

For discriminator_photo, the inputs are:
* The output of discriminator_photo when its input is a real photo from the training set
* The output of discriminator_photo when its input is a fake photo produced by generator_photo

The loss has two components:
* real_loss compares the discriminator's output on the real image with a matrix of ones (real)
* generated_loss compares the discriminator's output on the fake image with a matrix of zeros (fake)

The total loss is the sum of real_loss and generated_loss, scaled by 0.4 in this notebook's code. The CycleGAN paper instead divides the discriminator objective by 2 (a factor of 0.5), which slows down the discriminator relative to the generator.
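With `from_logits=True`, `BinaryCrossentropy` applies a sigmoid to the raw PatchGAN outputs before computing the cross-entropy. The NumPy sketch below reproduces that math for the discriminator loss, using the 0.4 scale from this notebook. It is not the Keras API, and the helper names are ours.

```python
import numpy as np


def bce_from_logits(labels, logits):
    """Element-wise sigmoid cross-entropy, in the numerically stable form
    max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def discriminator_loss_np(real, generated, scale=0.4):
    """NumPy sketch of discriminator_loss: real patches vs ones, fake vs zeros."""
    real_loss = bce_from_logits(np.ones_like(real), real)
    generated_loss = bce_from_logits(np.zeros_like(generated), generated)
    return (real_loss + generated_loss) * scale


# A confused discriminator (all logits 0) scores log(2) per patch on each term.
confused = discriminator_loss_np(np.zeros((30, 30)), np.zeros((30, 30)))
print(confused[0, 0])  # 0.4 * 2 * log(2), about 0.5545

# A confident, correct discriminator drives the loss toward 0.
sharp = discriminator_loss_np(np.full((30, 30), 8.0), np.full((30, 30), -8.0))
print(sharp.mean() < 1e-3)  # True
```

Keras additionally averages over the last axis, which has size 1 for the (bs, 30, 30, 1) discriminator output, so the per-patch values are the same.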

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)

        generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.zeros_like(generated), generated)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.4
    


**Generator loss**
The generator loss has 3 terms:

* Adversarial loss
* Cycle loss
* Identity loss

**Generator adversarial loss**

The function takes the output of the discriminator as input:
* For the loss of generator_monet, it takes the output of discriminator_monet on a fake Monet
* For the loss of generator_photo, it takes the output of discriminator_photo on a fake photo

A perfect generator would make the discriminator output only ones (real). The loss therefore compares the discriminator's output on the generated image with a matrix of ones.
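Because the target is all ones, the loss is small when the discriminator assigns high ("real") logits to generated images. Here is a NumPy sketch of the same math. The helper name `generator_adv_loss_np` is ours.

```python
import numpy as np


def generator_adv_loss_np(generated):
    """NumPy sketch of generator_loss: sigmoid cross-entropy of D's logits on
    fake images against a target of all ones ("real")."""
    # With label z = 1 the stable form reduces to max(x, 0) - x + log(1 + exp(-|x|)).
    return np.maximum(generated, 0) - generated + np.log1p(np.exp(-np.abs(generated)))


logits = np.array([-4.0, 0.0, 4.0])  # D says: fake, unsure, real
print(generator_adv_loss_np(logits))  # loss shrinks as D is fooled: about [4.018, 0.693, 0.018]
```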

In [None]:
with strategy.scope():
    def generator_loss(generated):
        return tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                  reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(generated), generated)

**Generator cycle consistency loss**
The cycle consistency loss pushes the network to learn a correct mapping, so that translating an image to the other domain and back returns something similar to the original input.


Input Photo -> G_MONET -> Fake Monet -> G_PHOTO -> Cycle Photo
A Monet-style image is generated from a photo. This generated image is passed as input to the second generator, which should regenerate the original photo from the fake Monet-style image.

Input Monet -> G_PHOTO -> Fake Photo -> G_MONET -> Cycle Monet
Conversely, an image imitating a real photo is generated from a Monet painting. This generated image is passed as input to the second generator, which should regenerate the original Monet painting from the fake photo.


To calculate the cycle consistency loss:
* The mean absolute error for photos is calculated between Input Photo and Cycle Photo
* The mean absolute error for Monet is calculated between Input Monet and Cycle Monet
* The cycle loss is the sum of both terms, each multiplied by LAMBDA (lambda_cycle = 15 in this notebook)
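In the notebook's `calc_cycle_loss`, each direction contributes LAMBDA times the mean absolute difference between an input and its reconstruction. The sketch below restates that in NumPy. It uses made-up, exactly invertible functions as stand-ins for the two generators, which in practice are not linear. All names are hypothetical. The point is that a perfect round trip gives zero loss and a lossy one does not.

```python
import numpy as np


def calc_cycle_loss_np(real_image, cycled_image, lam):
    """NumPy version of calc_cycle_loss: lam * mean absolute error."""
    return lam * np.mean(np.abs(real_image - cycled_image))


def toy_g_monet(x):
    """Hypothetical stand-in for G_MONET."""
    return 0.5 * x


def toy_g_photo(y):
    """Hypothetical stand-in for G_PHOTO, the exact inverse of toy_g_monet."""
    return 2.0 * y


def identity_generator(y):
    """A hypothetical photo generator that does not undo the Monet step."""
    return y


rng = np.random.default_rng(0)
toy_photo = rng.uniform(-1, 1, size=(256, 256, 3))
toy_monet = rng.uniform(-1, 1, size=(256, 256, 3))

# An exactly invertible pair gives zero cycle loss in both directions.
perfect = (calc_cycle_loss_np(toy_photo, toy_g_photo(toy_g_monet(toy_photo)), 15)
           + calc_cycle_loss_np(toy_monet, toy_g_monet(toy_g_photo(toy_monet)), 15))
print(perfect)  # 0.0

# A non-inverting photo generator is penalized: 15 * mean|x - 0.5 x|, roughly 3.75 here.
lossy = calc_cycle_loss_np(toy_photo, identity_generator(toy_g_monet(toy_photo)), 15)
print(lossy)
```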



In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

        return LAMBDA * loss1

**Identity loss**
The identity loss encourages each generator to leave an image unchanged when the image already belongs to the generator's target domain. For example, G_MONET applied to a real Monet should return nearly the same Monet.

For the loss of generator_monet, the function takes the training Monet image and the output of generator_monet on that same image (same_monet).

For the loss of generator_photo, the function takes the training photo and the output of generator_photo on that same photo (same_photo).

The loss is the mean absolute error between the real image and the generated one, multiplied by 0.5 * LAMBDA.
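In this notebook the identity term is weighted by half of `lambda_cycle` (0.5 × 15 = 7.5). Below is a NumPy sketch of `identity_loss`; the helper name is ours. An untouched image costs nothing, and a uniform shift of 0.1 (in [-1, 1] units) costs about 0.75.

```python
import numpy as np


def identity_loss_np(real_image, same_image, lam):
    """NumPy version of identity_loss: 0.5 * lam * mean absolute error."""
    return lam * 0.5 * np.mean(np.abs(real_image - same_image))


rng = np.random.default_rng(1)
toy_monet = rng.uniform(-1, 1, size=(256, 256, 3))

# A Monet generator that leaves a real Monet untouched pays no identity penalty.
print(identity_loss_np(toy_monet, toy_monet, 15))  # 0.0

# One that shifts every pixel by 0.1 pays about 15 * 0.5 * 0.1 = 0.75.
print(identity_loss_np(toy_monet, toy_monet + 0.1, 15))
```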

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        loss = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * loss

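Initializing the optimizers for all the generators and discriminators. All four use Adam with beta_1 = 0.5 and a learning rate of 0.002. Note that the CycleGAN paper uses a learning rate of 0.0002.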

In [None]:
with strategy.scope():
    # decay=0 was dropped: no decay is already the default, and newer Keras optimizers may reject the argument
    monet_generator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)

In [None]:
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, monet_discriminator, photo_discriminator
    )

    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )
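
For reference, the CycleGAN paper (linked in the sources) writes these objectives as follows. Here $G$ = monet_generator (photo to Monet), $F$ = photo_generator (Monet to photo), $D_Y$ = monet_discriminator, $D_X$ = photo_discriminator, $x$ is a photo and $y$ a Monet painting. The notebook's `tf.reduce_mean(tf.abs(...))` averages the L1 distance over pixels. The identity term is shown with weight $0.5\lambda$ to match the `identity_loss` function above. How the `CycleGan` class combines `gen_loss_fn`, `cycle_loss_fn` and `identity_loss_fn` is defined earlier in the notebook and not shown here, so read the last line as the paper-style objective rather than a description of that class.

$$
\mathcal{L}_{cyc}(G,F) = \mathbb{E}_{x}\big[\lVert F(G(x)) - x\rVert_1\big] + \mathbb{E}_{y}\big[\lVert G(F(y)) - y\rVert_1\big]
$$

$$
\mathcal{L}_{identity}(G,F) = \mathbb{E}_{y}\big[\lVert G(y) - y\rVert_1\big] + \mathbb{E}_{x}\big[\lVert F(x) - x\rVert_1\big]
$$

$$
\mathcal{L}(G,F,D_X,D_Y) = \mathcal{L}_{GAN}(G,D_Y) + \mathcal{L}_{GAN}(F,D_X) + \lambda\,\mathcal{L}_{cyc}(G,F) + 0.5\,\lambda\,\mathcal{L}_{identity}(G,F)
$$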


In [None]:
cycle_gan_model.fit(
    full_dataset,
    epochs=EPOCHS_NUM,
    steps_per_epoch=(max(n_monet_samples, n_photo_samples)//BATCH_SIZE),
)
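
The `steps_per_epoch` expression sizes an epoch by the larger of the two domains. The sketch below works through that arithmetic. The sample counts (300 Monet paintings, 7038 photos) are assumed from the competition's dataset description, and the batch size of 4 is hypothetical. Because the Monet set is much smaller, one epoch cycles through it many times. This presumes `full_dataset` repeats the smaller dataset, which is not shown in this section.

```python
def steps_per_epoch(n_monet_samples, n_photo_samples, batch_size):
    # Number of batches needed to see the larger of the two domains once
    return max(n_monet_samples, n_photo_samples) // batch_size

def passes_over_smaller_domain(n_monet_samples, n_photo_samples, batch_size):
    # How many times the smaller domain is cycled through in one such epoch
    steps = steps_per_epoch(n_monet_samples, n_photo_samples, batch_size)
    return steps * batch_size / min(n_monet_samples, n_photo_samples)

# Assumed dataset sizes and a hypothetical batch size of 4
print(steps_per_epoch(300, 7038, 4))             # 1759
print(passes_over_smaller_domain(300, 7038, 4))  # roughly 23.45
```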


In [None]:
# From the photo: generate a Monet (fake), reconstruct the photo (cycle),
# and pass the photo through the photo generator (same, as used by the identity loss)
example_gen_output_monet_fake = monet_generator(example_photo, training=False)
example_gen_output_photo_cycle = photo_generator(example_gen_output_monet_fake, training=False)
example_gen_output_photo_same = photo_generator(example_photo, training=False)

# Run the photo discriminator on the real photo
example_disc_out_photo_real = photo_discriminator(example_photo, training=False)

# Run the Monet discriminator on the generated Monet (fake)
example_disc_out_monet_fake = monet_discriminator(example_gen_output_monet_fake, training=False)

# From the Monet painting: generate a photo (fake), reconstruct the Monet (cycle),
# and pass the Monet through the Monet generator (same)
example_gen_output_photo_fake = photo_generator(example_monet, training=False)
example_gen_output_monet_cycle = monet_generator(example_gen_output_photo_fake, training=False)
example_gen_output_monet_same = monet_generator(example_monet, training=False)

# Run the Monet discriminator on the real Monet
example_disc_out_monet_real = monet_discriminator(example_monet, training=False)

# Run the photo discriminator on the generated photo (fake)
example_disc_out_photo_fake = photo_discriminator(example_gen_output_photo_fake, training=False)

In [None]:
example_photo.shape

In [None]:
plt.figure(figsize=(10,10))

# Generator inputs/outputs are in [-1, 1]; "* 0.5 + 0.5" rescales them to [0, 1] for display.
# Discriminator panels show the last channel of each discriminator's spatial output map.

# Input photo
plt.subplot(4,4,1)
plt.imshow(example_photo[0] * 0.5 + 0.5)
plt.title('Photo')

# Generated Monet (fake)
plt.subplot(4,4,2)
plt.imshow(example_gen_output_monet_fake[0] * 0.5 + 0.5)
plt.title('Monet (fake)')

# Photo cycle (photo -> Monet -> photo)
plt.subplot(4,4,3)
plt.imshow(example_gen_output_photo_cycle[0] * 0.5 + 0.5)
plt.title('Photo cycle')

# Photo same (photo generator applied to the photo)
plt.subplot(4,4,4)
plt.imshow(example_gen_output_photo_same[0] * 0.5 + 0.5)
plt.title('Photo same')

# Photo discriminator on the real photo
plt.subplot(4,4,5)
m = example_disc_out_photo_real[0,...,-1].numpy()
plt.imshow(m, vmin=np.min(m), vmax=np.max(m), cmap='RdBu_r')
plt.title('D photo (real)')

# Monet discriminator on the generated Monet
plt.subplot(4,4,7)
m = example_disc_out_monet_fake[0,...,-1].numpy()
plt.imshow(m, vmin=np.min(m), vmax=np.max(m), cmap='RdBu_r')
plt.title('D Monet (fake)')

# Input Monet
plt.subplot(4,4,9)
plt.imshow(example_monet[0] * 0.5 + 0.5)
plt.title('Monet')

# Generated photo (fake)
plt.subplot(4,4,10)
plt.imshow(example_gen_output_photo_fake[0] * 0.5 + 0.5)
plt.title('Photo (fake)')

# Monet cycle (Monet -> photo -> Monet)
plt.subplot(4,4,11)
plt.imshow(example_gen_output_monet_cycle[0] * 0.5 + 0.5)
plt.title('Monet cycle')

# Monet same (Monet generator applied to the Monet)
plt.subplot(4,4,12)
plt.imshow(example_gen_output_monet_same[0] * 0.5 + 0.5)
plt.title('Monet same')

# Monet discriminator on the real Monet
plt.subplot(4,4,13)
m = example_disc_out_monet_real[0,...,-1].numpy()
plt.imshow(m, vmin=np.min(m), vmax=np.max(m), cmap='RdBu_r')
plt.title('D Monet (real)')

# Photo discriminator on the generated photo
plt.subplot(4,4,15)
m = example_disc_out_photo_fake[0,...,-1].numpy()
plt.imshow(m, vmin=np.min(m), vmax=np.max(m), cmap='RdBu_r')
plt.title('D photo (fake)')

plt.tight_layout()
plt.show()

In [None]:
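import numpy as np
from PIL import Image  # "import PIL" alone does not guarantee that PIL.Image is loaded

def predict_and_save(input_ds, generator_model, output_path):
    """Translate every image in input_ds (batches of size 1) and save it as output_path + '<n>.jpg'."""
    for i, img in enumerate(input_ds, start=1):
        prediction = generator_model(img, training=False)[0].numpy()  # make prediction
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)      # rescale [-1, 1] to [0, 255]
        im = Image.fromarray(prediction)
        im.save(f'{output_path}{i}.jpg')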
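
The rescaling in `predict_and_save` inverts the preprocessing that maps pixels into the generator's [-1, 1] range. The NumPy sketch below checks both directions on all 256 pixel values. It shows that `(x - 127.5) / 127.5` and `x / 127.5 - 1` are the same mapping. It also shows that rounding with `np.rint` and clipping before the `uint8` cast gives an exact round trip. Plain `.astype(np.uint8)`, as used above, truncates, so a value such as 254.99998 would become 254. The helper names `to_model_range` and `to_uint8` are hypothetical.

```python
import numpy as np

def to_model_range(img_uint8):
    """uint8 [0, 255] -> float32 [-1, 1]."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0

def to_uint8(img_float):
    """float [-1, 1] -> uint8 [0, 255], rounding and clipping instead of truncating."""
    return np.clip(np.rint(img_float * 127.5 + 127.5), 0, 255).astype(np.uint8)

pixels = np.arange(256, dtype=np.uint8)

# The two scaling formulas agree up to float32 rounding
a = (pixels.astype(np.float32) - 127.5) / 127.5
b = pixels.astype(np.float32) / 127.5 - 1.0
print(np.allclose(a, b, atol=1e-6))  # True

# With rounding, every pixel value survives the round trip unchanged
print(np.array_equal(to_uint8(to_model_range(pixels)), pixels))  # True
```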

In [None]:
import os
os.makedirs('../images/', exist_ok=True)  # Create folder to save generated images (no error if it already exists)

predict_and_save(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, '../images/')


In [None]:
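import os
import shutil

# Zip ../images into /kaggle/working/images.zip. make_archive appends ".zip" to the base name,
# so the base name must not end with "/" (that would produce a hidden file named ".zip").
shutil.make_archive('/kaggle/working/images', 'zip', '../images')

n_generated = len([name for name in os.listdir('../images/') if os.path.isfile(os.path.join('../images/', name))])
print(f"Generated samples: {n_generated}")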
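
Before submitting, it can help to count the JPEGs inside the archive itself rather than in the source folder. This is a small standard-library sketch. `count_jpegs_in_zip` is a hypothetical helper, and the self-check builds a throwaway archive with made-up file names. Point it at wherever `make_archive` wrote the submission zip.

```python
import os
import tempfile
import zipfile

def count_jpegs_in_zip(zip_path):
    """Count the .jpg entries in a zip archive (other files and directories are ignored)."""
    with zipfile.ZipFile(zip_path) as zf:
        return sum(1 for name in zf.namelist() if name.lower().endswith('.jpg'))

# Self-check on a throwaway archive with hypothetical file names
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, 'demo.zip')
    with zipfile.ZipFile(demo_zip, 'w') as zf:
        zf.writestr('1.jpg', b'')
        zf.writestr('2.jpg', b'')
        zf.writestr('notes.txt', b'')
    demo_count = count_jpegs_in_zip(demo_zip)
print(demo_count)  # 2

# On Kaggle (path assumed): print(count_jpegs_in_zip('/kaggle/working/images.zip'))
```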