# **Competition Objective**

**Using GANs for generating Monet Style Art**

The task of generating Monet-style art falls in the domain of **Image-to-Image Translation**.

Image-to-Image Translation is a framework of conditional generation that transforms images into different styles.
Taking in an image and transforming it into an image of a different style, while maintaining the content of that image, is what we try to achieve with image-to-image translation.

Because GANs are really good at realistic generation, they are really well-suited for this image-to-image translation task. 



## **Unpaired Image to Image Translation**

Since we have only a limited number of images (300 Monet paintings, none of them paired with a corresponding real photo), our task falls under unpaired image-to-image translation.

**Unpaired image-to-image translation** is an unsupervised method that uses two piles of differently styled images instead of paired images.
The model learns that mapping between those two piles by ***keeping the contents*** that are present in both, while ***changing the style*** which is different or unique to each of those piles. 

## Dependencies 

In [None]:
import os, random, json, PIL, shutil, re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import Model, losses, optimizers

## Configuring TPU

In [None]:
# Try to attach to a TPU; a ValueError from the resolver is treated as
# "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if not tpu:
    # No TPU: use TensorFlow's current default distribution strategy.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Number of replicas the strategy keeps in sync, and the tf.data autotune
# sentinel used for parallel map / prefetch below.
REPLICAS = strategy.num_replicas_in_sync
AUTO = tf.data.experimental.AUTOTUNE

print(f'REPLICAS: {REPLICAS}')
print(tf.__version__)

## Model Parameters

In [None]:
# Image geometry: decode_image reshapes every decoded image to
# [HEIGHT, WIDTH, CHANNELS] and Generator uses the same shape for its input.
HEIGHT = 256
WIDTH = 256
CHANNELS = 3
# Training hyper-parameters.
EPOCHS = 50
BATCH_SIZE = 1

## Load in the Data

In [None]:
# Resolve the GCS location of the 'gan-getting-started' dataset, then list the
# TFRecord shards of each image domain.
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')

MONET_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/monet_tfrec/*.tfrec')
PHOTO_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/photo_tfrec/*.tfrec')

def count_data_items(filenames):
    """Return the total number of records across TFRecord shard files.

    Each file name is expected to embed its record count as the digits between
    a ``-`` and the following ``.`` (e.g. ``monet00-60.tfrec`` contributes 60).

    Args:
        filenames: Iterable of shard file names or paths.

    Returns:
        Integer sum of the per-file counts; an integer ``0`` for empty input.

    Raises:
        AttributeError: If a name does not contain the ``-<digits>.`` pattern.
    """
    # Compile once rather than once per file name.
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = [int(count_pattern.search(name).group(1)) for name in filenames]
    # An explicit integer dtype keeps the empty-input result an integer;
    # np.sum([]) would otherwise return the float 0.0.
    return np.sum(counts, dtype=np.int64)

# Image totals are derived from the counts embedded in the shard file names.
n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)

# Report shard and image counts for both domains.
print(f'Monet TFRecord files: {len(MONET_FILENAMES)}')
print(f'Monet image files: {n_monet_samples}')
print(f'Photo TFRecord files: {len(PHOTO_FILENAMES)}')
print(f'Photo image files: {n_photo_samples}')

## Loading TFRecord Dataset and Visualization functions

In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 image tensor.

    Pixel values are divided by 127.5 and shifted by -1, then the tensor is
    reshaped to ``[HEIGHT, WIDTH, CHANNELS]``.
    """
    pixels = tf.image.decode_jpeg(image, channels=CHANNELS)
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [HEIGHT, WIDTH, CHANNELS])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record is parsed with three scalar string features ('image_name',
    'image', 'target'); only 'image' is used.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('image_name', 'image', 'target')
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames):
    """Return a dataset of decoded images read from the given TFRecord files."""
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTO)

def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the zipped (monet_batch, photo_batch) training dataset.

    Args:
        monet_files: TFRecord files holding the Monet paintings.
        photo_files: TFRecord files holding the photos.
        augment: Accepted for interface compatibility; currently unused.
        repeat: If true, repeat each dataset indefinitely.
        shuffle: If true, shuffle each dataset with a 2048-element buffer.
        batch_size: Batch size; incomplete final batches are dropped.

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_batch, photo_batch)`` tuples.
    """

    def _prepare(files):
        """Cache decoded images, then repeat / shuffle / batch / prefetch."""
        ds = load_dataset(files)
        # Cache BEFORE repeat/shuffle: caching after them would try to store an
        # endless stream and would replay one frozen shuffle order.
        # NOTE(review): this holds every decoded image in memory -- confirm it
        # fits for the larger photo set.
        ds = ds.cache()
        if repeat:
            ds = ds.repeat()
        if shuffle:
            ds = ds.shuffle(2048)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds.prefetch(AUTO)

    return tf.data.Dataset.zip((_prepare(monet_files), _prepare(photo_files)))

def display_samples(ds, row, col):
    """Plot the first image of the next ``row * col`` batches of ``ds`` in a grid.

    Images are mapped back for display with ``x * 0.5 + 0.5``.
    """
    plt.figure(figsize=(15, int(15*row/col)))
    batches = iter(ds)
    for cell in range(1, row * col + 1):
        batch = next(batches)
        plt.subplot(row, col, cell)
        plt.axis('off')
        plt.imshow(batch[0] * 0.5 + 0.5)
    plt.show()
        
def display_generated_samples(ds, model, n_samples):
    """Show ``n_samples`` input / generated image pairs side by side.

    For each batch taken from ``ds`` the model's prediction is computed and the
    first image of the input and of the prediction are plotted next to each
    other, one figure per sample.
    """
    batches = iter(ds)
    for _ in range(n_samples):
        source = next(batches)
        generated = model.predict(source)

        panels = (
            (121, "input image", source),
            (122, "Generated image", generated),
        )
        for position, caption, images in panels:
            plt.subplot(position)
            plt.title(caption)
            plt.imshow(images[0] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()
        

In [None]:
# Preview a 4 x 6 grid of Monet paintings (batches of one image).
display_samples(load_dataset(MONET_FILENAMES).batch(1), 4, 6)

In [None]:
# Preview a 4 x 6 grid of photos (batches of one image).
display_samples(load_dataset(PHOTO_FILENAMES).batch(1), 4, 6)

In [None]:
# Single-image batches of each domain.
monet_ds = load_dataset(MONET_FILENAMES).batch(1)
photo_ds = load_dataset(PHOTO_FILENAMES).batch(1)

# Larger batches (32 images per replica) for bulk inference and FID statistics.
_bulk_batch_size = 32*strategy.num_replicas_in_sync

fast_photo_ds = load_dataset(PHOTO_FILENAMES).batch(_bulk_batch_size).prefetch(32)
# Only the first 1024 photos are used on the photo side of the FID computation.
fid_photo_ds = load_dataset(PHOTO_FILENAMES).take(1024).batch(_bulk_batch_size).prefetch(32)
fid_monet_ds = load_dataset(MONET_FILENAMES).batch(_bulk_batch_size).prefetch(32)

In [None]:
with strategy.scope():

    # Feature extractor for FID: InceptionV3 cut at the "mixed9" layer, with
    # global average pooling applied to that layer's output. Frozen.
    inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)

    mix3 = inception_model.get_layer("mixed9").output
    f0 = tf.keras.layers.GlobalAveragePooling2D()(mix3)

    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False

    def calculate_activation_statistics_mod(images, fid_model):
        """Return the mean vector and covariance matrix of model activations.

        Args:
            images: Input passed to ``fid_model.predict``.
            fid_model: Model whose predictions are the activations.

        Returns:
            ``(mu, sigma)`` where ``mu`` is the mean over axis 0 and ``sigma``
            is ``E[x^T x] - mu^T mu`` (normalised by the number of rows).
        """
        act = tf.cast(fid_model.predict(images), tf.float32)
        n_rows = tf.cast(tf.shape(act)[0], tf.float32)

        mu = tf.reduce_mean(act, axis=0)
        mean_row = tf.reduce_mean(act, axis=0, keepdims=True)

        outer_of_means = tf.matmul(tf.transpose(mean_row), mean_row)
        second_moment = tf.matmul(tf.transpose(act), act) / n_rows
        sigma = second_moment - outer_of_means
        return mu, sigma

    # Reference statistics of the real Monet paintings, computed once.
    myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(fid_monet_ds, inception_model)
    fids = []

In [None]:
with strategy.scope():
    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
        """Frechet distance between two Gaussians given their statistics.

        Computes ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2)
        - 2 * Tr(sqrt(sigma1 @ sigma2))``.

        Args:
            mu1, mu2: Mean vectors.
            sigma1, sigma2: Covariance matrices.

        Returns:
            The distance as the result of a (1, d) x (d, 1) matmul, i.e. a
            1 x 1 tensor for 1-D mean vectors.
        """
        # The matrix square root is taken in complex arithmetic; only its real
        # part is kept.
        covmean = tf.linalg.sqrtm(tf.cast(tf.matmul(sigma1, sigma2), tf.complex64))
        covmean = tf.cast(tf.math.real(covmean), tf.float32)
        tr_covmean = tf.linalg.trace(covmean)

        mean_diff = mu1 - mu2
        squared_mean_distance = tf.matmul(tf.expand_dims(mean_diff, axis=0),
                                          tf.expand_dims(mean_diff, axis=1))

        return squared_mean_distance + tf.linalg.trace(sigma1) + tf.linalg.trace(sigma2) - 2 * tr_covmean

    def FID(images, gen_model, inception_model=inception_model, myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
        """Frechet distance between generated images and the reference statistics.

        Args:
            images: Input dataset fed through ``gen_model``.
            gen_model: Generator whose outputs are evaluated.
            inception_model: Feature extractor applied to the generator output.
            myFID_mu2: Reference mean vector.
            myFID_sigma2: Reference covariance matrix.

        The three defaults are bound when this function is defined, so later
        rebinding of the module-level names does not affect them.

        Returns:
            The value returned by ``calculate_frechet_distance``.
        """
        # Chain generator and feature extractor into one model so a single
        # predict() call yields the activations of the generated images.
        inp = keras.Input(shape=[256, 256, 3], name='input_image')
        x = gen_model(inp)
        x = inception_model(x)
        fid_model = tf.keras.Model(inputs=inp, outputs=x)

        mu1, sigma1 = calculate_activation_statistics_mod(images, fid_model)

        return calculate_frechet_distance(mu1, sigma1, myFID_mu2, myFID_sigma2)


In [None]:
# Take the first batch of each domain as a running example.
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))



# Show the two examples side by side; x * 0.5 + 0.5 maps the normalised
# pixel values back for display.
plt.subplot(121)
plt.title('Photo')
plt.imshow(example_photo[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Monet')
plt.imshow(example_monet[0] * 0.5 + 0.5)

## **GAN Architecture: CycleGAN**

**CycleGAN** is made up of two GANs that form a cycle and rely on each other to compute all the different loss terms.
In fact, the generators have six loss terms in total, three for each generator:
* The least squares adversarial loss (the main loss term).
* The cycle consistency loss.
* The optional identity loss.

The discriminators are a bit simpler: they use just a least squares adversarial loss
on a ***PatchGAN*** architecture, as introduced in pix2pix.

**Cycle consistency** is important in transferring of common style elements while maintaining common content across those images, and it is a really, ***really important loss term***.
This can be done by adding a pixel distance loss to the adversarial loss to encourage cycle consistency in both directions.
In the horse ↔ zebra example, this means comparing the reconstructed (cycled) zebra with the real zebra and the reconstructed horse with the real horse. The ablation studies show that the cycle consistency loss term in both directions helps prevent mode collapse and helps with this unwieldy, unpaired image-to-image translation task.

## Model Functions

In [None]:
def ContractingBlock(inputs, filters, kernel_size, strides = 2, apply_instancenorm=True):
    """Conv2D -> optional InstanceNormalization -> LeakyReLU(0.2).

    Args:
        inputs: Input tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride (2 by default, 'same' padding).
        apply_instancenorm: Whether to normalise before the activation.

    Returns:
        The activated output tensor.
    """
    kernel_init = tf.keras.initializers.RandomNormal(0., 0.02)
    norm_gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # Bias is omitted from the convolution.
    features = L.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer=kernel_init,
        use_bias=False,
    )(inputs)

    if apply_instancenorm:
        features = tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init)(features)

    return L.LeakyReLU(alpha=0.2)(features)



def ResnetBlock(input_layer, filters, kernel_size):
    """Residual block: two stride-1 ContractingBlocks plus a skip connection.

    The block output is added to ``input_layer``, so ``filters`` has to match
    the channel count of the input.
    """
    residual = input_layer
    for _ in range(2):
        residual = ContractingBlock(residual, filters, kernel_size, strides = 1)

    return residual + input_layer



def ExpandingBlock(inputs, filters, kernel_size, strides = 2, apply_dropout=False):
    """Conv2DTranspose -> InstanceNormalization -> optional Dropout(0.5) -> ReLU.

    Args:
        inputs: Input tensor.
        filters: Number of transposed-convolution filters.
        kernel_size: Kernel size.
        strides: Stride (2 by default, 'same' padding).
        apply_dropout: Whether to insert a 0.5 dropout before the activation.

    Returns:
        The activated output tensor.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    upsample = L.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer=kernel_init,
        use_bias=False,
    )
    features = upsample(inputs)
    features = tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init)(features)

    if apply_dropout:
        features = L.Dropout(0.5)(features)

    return L.ReLU()(features)

## Building the Generator

In [None]:
def Generator():
    """
    Build the image-to-image generator as a Keras functional model.

    Layout: one stride-1 7x7 contracting block without instance norm, two
    stride-2 contracting blocks, nine residual blocks with 256 filters, two
    stride-2 expanding blocks, and a final stride-1 7x7 transposed convolution
    with tanh activation that outputs 3 channels.
    """
    inputs = keras.Input(shape=[HEIGHT, WIDTH, CHANNELS])

    # Up-feature layer, then two downsampling steps.
    # With 256x256 inputs: (bs, 256, 256, 64) -> (bs, 128, 128, 128) -> (bs, 64, 64, 256)
    features = ContractingBlock(inputs, 64, 7, strides=1, apply_instancenorm=False)
    for n_filters in (128, 256):
        features = ContractingBlock(features, n_filters, 3)

    # Nine residual blocks at constant width.
    for _ in range(9):
        features = ResnetBlock(features, 256, 3)

    # Two upsampling steps back to the input resolution.
    for n_filters in (128, 64):
        features = ExpandingBlock(features, n_filters, 3)

    # Down-feature layer producing the output image.
    initializer = tf.random_normal_initializer(0., 0.02)
    to_image = L.Conv2DTranspose(3, 7,
                                 strides=1,
                                 padding='same',
                                 kernel_initializer=initializer,
                                 activation='tanh')

    return Model(inputs=inputs, outputs=to_image(features))

In [None]:
# Build a generator and call it on a symbolic input as a quick wiring check.
gen = Generator()
inp = keras.Input(shape=[HEIGHT, WIDTH, CHANNELS])
x = gen(inp)

In [None]:
# Render the generator's layer graph, including tensor shapes.
tf.keras.utils.plot_model(gen, show_shapes=True, dpi=64)

## Build The Discriminator

In [None]:
def Discriminator():
    '''
    Discriminator Module
    Structured like the contracting path of the Generator, the discriminator
    outputs a matrix of values classifying corresponding portions of the image
    as real or fake (no final activation is applied).

    Takes no parameters; the input shape is fixed to 256 x 256 x 3.

    NOTE: the three contracting blocks are chained. Previously the second and
    third block were fed the raw input, which discarded the earlier blocks and
    left a single downsampling layer. Checkpoints saved with that older
    architecture are not compatible with this one.
    '''
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = keras.Input(shape=[256, 256, 3], name='input_image')

    # Three stride-2 downsampling blocks, each consuming the previous output.
    des_img = ContractingBlock(inp, 64, 7)        # (bs, 128, 128, 64)
    des_img = ContractingBlock(des_img, 128, 4)   # (bs, 64, 64, 128)
    des_img = ContractingBlock(des_img, 256, 4)   # (bs, 32, 32, 256)

    zero_pad1 = L.ZeroPadding2D()(des_img)        # (bs, 34, 34, 256)
    conv = L.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = L.LeakyReLU()(norm1)

    zero_pad2 = L.ZeroPadding2D()(leaky_relu)     # (bs, 33, 33, 512)

    # One output channel: a per-patch real/fake score map.
    last = L.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)


In [None]:
# Build a discriminator and call it on a symbolic input as a quick wiring check.
dis = Discriminator()
inp = keras.Input(shape=[HEIGHT, WIDTH, CHANNELS])
x = dis(inp)

In [None]:
# Render the discriminator's layer graph, including tensor shapes.
tf.keras.utils.plot_model(dis, show_shapes=True, dpi=64)

## Building the CycleGAN Model

**The CycleGAN architecture works as follows**:

The model architecture is comprised of two generator models: one generator (Generator-A) for generating images for the first domain (Domain-A) and the second generator (Generator-B) for generating images for the second domain (Domain-B).

*   **Generator-A -> Domain-A**
*   **Generator-B -> Domain-B**

The generator models perform image translation, meaning that the image generation process is conditional on an input image, specifically an image from the other domain. Generator-A takes an image from Domain-B as input and Generator-B takes an image from Domain-A as input.

 *   **Domain-B -> Generator-A -> Domain-A**
 *  **Domain-A -> Generator-B -> Domain-B**

Each generator has a corresponding discriminator model.

The first discriminator model (Discriminator-A) takes real images from Domain-A and generated images from Generator-A and predicts whether they are real or fake. The second discriminator model (Discriminator-B) takes real images from Domain-B and generated images from Generator-B and predicts whether they are real or fake.

*    **Domain-A -> Discriminator-A -> [Real/Fake]**
*    **Domain-B -> Generator-A -> Discriminator-A -> [Real/Fake]**
*    **Domain-B -> Discriminator-B -> [Real/Fake]**
*   **Domain-A -> Generator-B -> Discriminator-B -> [Real/Fake]**

The discriminator and generator models are trained in an adversarial zero-sum process, like normal GAN models.

The generated image must retain the properties of the original image: if we generate a fake image using one generator, say Generator A→B, then we must be able to get back to the original image using the other generator, Generator B→A. In other words, the mapping must satisfy cycle consistency.

[Reference](https://machinelearningmastery.com/how-to-develop-cyclegan-models-from-scratch-with-keras/)

In [None]:
# Create the four networks inside the distribution strategy's scope.
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

**Since our generators are not trained yet, the generated Monet-esque photo does not show what is expected at this point.**

In [None]:
# Run the generator on the example photo and show input and output together.
to_monet = monet_generator(example_photo)

preview_panels = (
    (1, "Original Photo", example_photo),
    (2, "Monet-esque Photo", to_monet),
)
for slot, heading, images in preview_panels:
    plt.subplot(1, 2, slot)
    plt.title(heading)
    plt.imshow(images[0] * 0.5 + 0.5)
plt.show()

We will subclass a ***Model*** so that we can run ***fit()*** later to train our model. During the training step, the model transforms a photo to a Monet painting and then back to a photo. The difference between the original photo and the twice-transformed photo is the cycle-consistency loss. We want the original photo and the twice-transformed photo to be similar to one another.


In [None]:
class CycleGan(Model):
    """Keras model that trains two generators and two discriminators jointly.

    ``m_gen`` maps photos to Monet-style images and ``p_gen`` maps Monet images
    to photos; ``m_disc`` and ``p_disc`` score images of the respective domain.
    ``lambda_cycle`` is passed as the weight to the cycle and identity losses.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        super().__init__()
        self.m_gen, self.p_gen = monet_generator, photo_generator
        self.m_disc, self.p_disc = monet_discriminator, photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Store one optimizer per network and the four loss callables."""
        super().compile()
        # Optimizers.
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        # Loss callables.
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on a ``(real_monet, real_photo)`` batch.

        Returns a dict with the total generator losses and the discriminator
        losses of both domains.
        """
        real_monet, real_photo = batch_data
        weight = self.lambda_cycle

        # Persistent tape: four separate gradient computations follow.
        with tf.GradientTape(persistent=True) as tape:
            # photo -> monet -> photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet -> photo -> monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Each generator applied to an image of its own target domain.
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Discriminator scores on real images ...
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # ... and on generated images.
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Adversarial generator losses.
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # Cycle-consistency loss, summed over both directions.
            monet_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, weight)
            photo_cycle_loss = self.cycle_loss_fn(real_photo, cycled_photo, weight)
            total_cycle_loss = monet_cycle_loss + photo_cycle_loss

            # Total generator loss = adversarial + cycle + identity.
            monet_identity_loss = self.identity_loss_fn(real_monet, same_monet, weight)
            photo_identity_loss = self.identity_loss_fn(real_photo, same_photo, weight)
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + monet_identity_loss
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + photo_identity_loss

            # Discriminator losses.
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # (optimizer, loss, network) triples, generators first.
        updates = (
            (self.m_gen_optimizer, total_monet_gen_loss, self.m_gen),
            (self.p_gen_optimizer, total_photo_gen_loss, self.p_gen),
            (self.m_disc_optimizer, monet_disc_loss, self.m_disc),
            (self.p_disc_optimizer, photo_disc_loss, self.p_disc),
        )

        # Compute every gradient before any weights are updated.
        all_gradients = [
            tape.gradient(loss, network.trainable_variables)
            for _, loss, network in updates
        ]

        for (optimizer, _, network), gradients in zip(updates, all_gradients):
            optimizer.apply_gradients(zip(gradients, network.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }


## Discriminator loss
### Adversarial loss Part 1

The discriminator must be trained so that its prediction for real images from category A is as close to 1 as possible, and likewise for discriminator B with real images from category B. So Discriminator A would like to minimize ***(DiscriminatorA(a)−1)^2***, and the same goes for B.


### Adversarial loss Part 2

Since the discriminator should be able to distinguish between generated and original images, it should also predict 0 for images produced by the generator, i.e. Discriminator A would like to minimize ***(DiscriminatorA(GeneratorB→A(b)))^2***.

The above two losses can be implemented as follows:


In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Least-squares discriminator loss.

        Scores on real images are compared against ones and scores on generated
        images against zeros; the two unreduced MSE terms are averaged.
        """
        mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

        loss_on_real = mse(tf.ones_like(real), real)
        loss_on_generated = mse(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.5

## Generator Loss

The generator should eventually be able to fool the discriminator about the authenticity of its generated images. This is achieved when the discriminator's prediction for the generated images is as close to 1 as possible. So the generator would like to minimize ***(DiscriminatorB(GeneratorA→B(a))−1)^2***.
So the loss is:

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Least-squares generator loss: unreduced MSE of the scores against ones."""
        mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        target = tf.ones_like(generated)
        return mse(target, generated)

## Cyclic loss

One of the most important losses is the cyclic loss. It captures the requirement that we can get the original image back using the other generator, so the difference between the original image and the cycled image should be as small as possible.


In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute difference of the two images."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

## Identity Loss

We'll want to measure the change in an image when you pass the generator an example from the target domain instead of the input domain it's expecting. **The output should be the same as the input since it is already of the target domain class.** For example, if you put a horse through a zebra -> horse generator, you'd expect the output to be the same horse because nothing needed to be transformed. It's already a horse! You don't want your generator to be transforming it into any other thing, so you want to encourage this behavior. In encouraging this identity mapping, the authors of CycleGAN found that for some tasks, this helped properly preserve the colors of an image, even when the expected input (here, a zebra) was put in. This is particularly useful for **the photos <-> paintings mapping** and, while an optional aesthetic component, you might find it useful for your applications down the line.

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute difference of the images."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        half_weight = LAMBDA * 0.5
        return half_weight * mean_abs_error

## Train the CycleGAN

Let's compile our model. Since we used ***tf.keras.Model*** to build our CycleGAN, we can just use the fit function to train our model.


In [None]:
with strategy.scope():
    # One independent Adam optimizer (lr 2e-4, beta_1 0.5) per network.
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

In [None]:
with strategy.scope():
    # Bundle the four networks (lambda_cycle keeps its default).
    cycle_gan_model = CycleGan(
        monet_generator=monet_generator,
        photo_generator=photo_generator,
        monet_discriminator=monet_discriminator,
        photo_discriminator=photo_discriminator,
    )

    # Attach the per-network optimizers and the four loss functions.
    cycle_gan_model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )


In [None]:
# Restore previously saved weights before the training call below.
cycle_gan_model.load_weights("../input/cyclegan-monet/my_checkpoint.ckpt")    

In [None]:
# Train on (monet_batch, photo_batch) pairs; zip ends when either dataset is
# exhausted. The epoch count comes from the EPOCHS constant defined with the
# other model parameters instead of a duplicated literal.
cycle_gan_model.fit(
    tf.data.Dataset.zip((monet_ds, photo_ds)),
    epochs=EPOCHS
)



In [None]:
# Frechet distance between monet_generator's outputs on fid_photo_ds and the
# reference statistics computed from the real Monet paintings.
FID(fid_photo_ds,monet_generator) 

In [None]:
# Six rows: input photo on the left, generated painting on the right.
_, ax = plt.subplots(6, 2, figsize=(12, 12))
for (photo_axis, monet_axis), img in zip(ax, photo_ds.take(6)):
    prediction = monet_generator(img, training=False)[0].numpy()
    # Map the normalised values back to 8-bit pixel values.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    photo_axis.imshow(img)
    photo_axis.set_title("Input Photo")
    photo_axis.axis("off")

    monet_axis.imshow(prediction)
    monet_axis.set_title("Monet-esque")
    monet_axis.axis("off")
plt.show()

In [None]:
# Import the Image submodule explicitly: a bare `import PIL` does not load it.
import PIL.Image
# Create the output directory; exist_ok makes re-running this cell harmless.
os.makedirs("./images", exist_ok=True)

In [None]:
# Translate every photo and save the result as ./images/<n>.jpg, numbered
# from 1 in dataset order.
for i, img in enumerate(photo_ds, start=1):
    prediction = monet_generator(img, training=False)[0].numpy()
    # Map the normalised values back to 8-bit pixel values.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    im = PIL.Image.fromarray(prediction)
    im.save(f"./images/{i}.jpg")

In [None]:
import shutil
# Zip the contents of /kaggle/working/images into /kaggle/working/images.zip.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/working/images")

In [None]:
# Persist the weights of the combined model.
cycle_gan_model.save_weights('./my_checkpoint.ckpt')

In [None]:
# Sanity check: number of entries in the output directory.
print(len(os.listdir("./images")))