CycleGAN without identity loss

## Identifying problems
* CycleGAN training is not stable: the more epochs it runs, the worse the results become.
* If one of the image sets is relatively small (only 300 Monet paintings), the discriminator overfits: it memorizes the training paintings and classifies them accurately. When tested on other real Monet paintings, it does not recognize them as Monet's work and confidently labels them as generated fakes.
* Objective function problem. Matching the reconstructed image to the original and preserving colors in the generated image distract too much from the real goal: generating a Monet painting (or a photograph) that looks like a real one.
* Too many hyperparameters. The weight initialization also affects the training outcome.

## Fail: CycleGAN without identity loss

CycleGAN model problem. If we feed the model photos of horses mixed with zebras and expect zebras at the output (or photos of men and women, expecting photos of women), then it is logical to take identity into account: if photos of zebras (or women) are fed in, the output should be the same photo, preferably unchanged. This is what the baseline CycleGAN notebook does.

But if the transformation is significant (for example, producing a topographic plan from satellite images, or a cartoon from photographs), then feeding Monet paintings (or a topographic plan) into the generator in place of photos (or satellite images) does not make sense.

My hypothesis: if we remove the identity check, which requires the generator to return its input practically unchanged when that input already meets the requirements (that is, it is already a Monet painting or a topographic plan), then training will proceed more naturally. Why “torment” CycleGAN with an unnecessary component of the loss function?

I removed the identity loss from the model.

The effect of removing it was hard to assess while the FID kept growing after a certain epoch with the basic BCE loss.

After switching to WGAN-GP and then to LSE loss, it became clear that the identity loss gives better results. I added lambda_id to control its weight.


## Fail: MSE cycle loss

Loss function problem. I tried the mean squared error (MSE) for the reconstructed image instead of the MAE used in the baseline notebook. If the color difference between corresponding pixels of the original and the reconstructed image is small, it is indistinguishable to the human eye. Squaring penalizes large deviations more heavily and reduces the weight of small ones. At the same time, the usual disadvantage of MSE, its sensitivity to outliers in the training data, does not apply to our problem, since the training data contains no “ground truth” answer.

Fail: heavily penalizing the deviation of the cycled image is a bad idea. A good transformation loses some information that cannot be reconstructed. Changed to Huber loss.
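
To make the trade-off between these losses concrete, here is a minimal sketch in plain Python (no TensorFlow needed). The error values and the `delta` threshold are made up for illustration and are not taken from the notebook's runs.

```python
def mae(errors):
    """Mean absolute error: linear in every deviation."""
    return sum(abs(e) for e in errors) / len(errors)


def mse(errors):
    """Mean squared error: shrinks small deviations, inflates large ones."""
    return sum(e * e for e in errors) / len(errors)


def huber(errors, delta=1.0):
    """Huber loss: quadratic for |e| <= delta, linear beyond it."""
    total = 0.0
    for e in errors:
        a = abs(e)
        total += 0.5 * a * a if a <= delta else delta * (a - 0.5 * delta)
    return total / len(errors)


# Hypothetical per-pixel reconstruction errors on the [-1, 1] image scale.
small = [0.05, 0.05, 0.05, 0.05]   # barely visible color drift
large = [0.05, 0.05, 0.05, 1.5]    # one pixel that cannot be reconstructed

for name, errs in (("small", small), ("large", large)):
    print(name, mae(errs), mse(errs), huber(errs, delta=0.5))
```

On the second list, MSE is dominated by the single unrecoverable pixel, while the Huber loss grows only linearly with it and still treats the small deviations quadratically.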



## Early stopping 

Fail: I spent two weeks looking for the Inception layers on which the MIFID metric is evaluated. I inspected pretrained Inception/Xception/ResNet/VGG models and their early layers (with max/average pooling) from tensorflow and tensorflow_hub, but found nothing that matched the competition metric.
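
For reference, the plain FID that the notebook monitors locally (the `calculate_frechet_distance` function in the code cells) is the Fréchet distance between two Gaussians fitted to feature activations, with mean and covariance $(\mu_1, \Sigma_1)$ for the generated images and $(\mu_2, \Sigma_2)$ for the real Monet paintings:

$$
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\left(\Sigma_1\right) + \operatorname{Tr}\left(\Sigma_2\right) - 2\,\operatorname{Tr}\left(\left(\Sigma_1 \Sigma_2\right)^{1/2}\right)
$$

The notebook takes these activations from the `mixed3` layer of InceptionV3 followed by global max pooling, so the local value is only a proxy and should not be expected to match the competition score.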


## Additional metric

I implemented a test_step method in the CycleGan class to evaluate metrics on a validation set.

For training I used 6038 photos and 300 Monet paintings. The validation set consisted of 1000 photos and 1367 Monet paintings from the monet-tfrecords-256x256 Kaggle dataset.

I tracked the two parts of total_cycle_loss, cycle_loss_mpm and cycle_loss_pmp: the deviations of the Monet-Photo-Monet and Photo-Monet-Photo transformations.

Observation: when comparing the MAE deviation of the Monet-Photo-Monet and Photo-Monet-Photo transformations for the training and test data, I noticed the following:
* for the training data, the MAEs of the two transformations are approximately equal (cycle_loss_mpm ~= cycle_loss_pmp);
* for the test data, the MAE of the Monet-Photo-Monet transformation is almost twice as high as the MAE of the Photo-Monet-Photo transformation (val_cycle_loss_mpm >> val_cycle_loss_pmp).


Fail: excluding 1000 photos from the training set is a very bad idea for the competition metric. I removed the model evaluation on the test set.


## Fail: compute Discriminator and Generator loss for Real_Monet <-> Cycled_Monet and Real_Photo <-> Cycled_Photo

It does not help with the Discriminator overfitting on Real_Monet <-> Fake_Monet.

## Fail: Wasserstein loss (WGAN-GP) for discriminator and generator

I tried the WGAN-GP loss (versions 56 - 78).
Positive changes: the FID metric no longer grows without bound after some epoch.
Negative: the competition metric is higher than 44.
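
As a rough illustration of the loss tried here, the sketch below evaluates a WGAN-GP critic loss on a toy linear critic in plain Python. The linear critic, the two-element "images" and all function names are assumptions made for this example; the notebook applied the penalty to its PatchGAN discriminators through TensorFlow gradients. Both the standard two-sided penalty and the one-sided ("maximum with 0") variant are shown.

```python
import math
import random


def critic(w, x):
    """Toy linear critic D(x) = w . x, a stand-in for a real discriminator."""
    return sum(wi * xi for wi, xi in zip(w, x))


def critic_input_gradient(w, x):
    """Gradient of D with respect to x; for a linear critic it is simply w."""
    return list(w)


def gradient_penalty(w, real, fake, one_sided=False):
    """Penalty on the gradient norm at a random point between real and fake."""
    alpha = random.random()
    x_hat = [alpha * r + (1.0 - alpha) * f for r, f in zip(real, fake)]
    grad = critic_input_gradient(w, x_hat)
    norm = math.sqrt(sum(g * g for g in grad))
    excess = norm - 1.0
    if one_sided:
        excess = max(excess, 0.0)  # only penalize norms above 1
    return excess ** 2


def wgan_critic_loss(w, real, fake, lambda_gp=10.0, one_sided=False):
    """Critic loss: D(fake) - D(real) + lambda_gp * gradient penalty."""
    penalty = gradient_penalty(w, real, fake, one_sided)
    return critic(w, fake) - critic(w, real) + lambda_gp * penalty


def wgan_generator_loss(w, fake):
    """Generator loss: -D(fake)."""
    return -critic(w, fake)


real, fake = [1.0, 1.0], [0.0, 0.5]
print(wgan_critic_loss([0.6, 0.8], real, fake))   # gradient norm 1: penalty vanishes
print(wgan_critic_loss([0.3, 0.4], real, fake))   # gradient norm 0.5: penalized
print(wgan_critic_loss([0.3, 0.4], real, fake, one_sided=True))
```

Because the critic's output is an unbounded score rather than a probability, the loss values are not comparable with those of the BCE or least-squares variants.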

## Current experiment: LSE loss for discriminator and generator

With the LSE (least-squares error) loss, the competition score was 41 from the first version, and the FID metric decreases with each epoch.
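
The least-squares losses can be sketched on scalar discriminator scores as follows. This mirrors the `discriminator_loss` and `generator_loss` functions in the code cells, but in plain Python; the function names and the scores fed in are illustrative only.

```python
def lsgan_discriminator_loss(real_score, fake_score):
    """Push scores of real images toward 1 and scores of generated images toward 0."""
    return 0.5 * ((1.0 - real_score) ** 2 + (0.0 - fake_score) ** 2)


def lsgan_generator_loss(fake_score):
    """Push the discriminator's score of generated images toward 1."""
    return (1.0 - fake_score) ** 2


# A confident discriminator: nothing left to learn for it, maximal signal for the generator.
print(lsgan_discriminator_loss(1.0, 0.0), lsgan_generator_loss(0.0))
# A fooled discriminator: the roles are reversed.
print(lsgan_discriminator_loss(0.0, 1.0), lsgan_generator_loss(1.0))
# An undecided discriminator: both sides still receive a gradient.
print(lsgan_discriminator_loss(0.5, 0.5), lsgan_generator_loss(0.5))
```

Unlike BCE on sigmoid outputs, the squared error keeps growing with the distance from the target, so samples that are classified correctly but lie far from the target still contribute a gradient.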



## Other minor changes

* fast generation of submission images (the fast_photo_ds dataset uses a different batch size than training)


## Result

CycleGAN without identity loss is a bad idea for this competition. The reason to use the identity loss is to preserve the color composition.
The final version of the notebook uses the identity loss weighted by lambda_id.
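
The notebook's `identity_loss` compares average-pooled versions of its two arguments (16x16 windows with stride 4), so it constrains local color averages rather than individual pixels. The sketch below shows the idea in plain Python on a tiny single-channel image; the 2x2 window, the images and the helper names are assumptions chosen to keep the example small.

```python
def avg_pool(image, ksize, stride):
    """VALID average pooling over a 2-D list of floats (one color channel)."""
    rows, cols = len(image), len(image[0])
    pooled = []
    for top in range(0, rows - ksize + 1, stride):
        pooled_row = []
        for left in range(0, cols - ksize + 1, stride):
            window = [image[r][c]
                      for r in range(top, top + ksize)
                      for c in range(left, left + ksize)]
            pooled_row.append(sum(window) / len(window))
        pooled.append(pooled_row)
    return pooled


def pooled_identity_loss(source, translated, lambda_id, ksize=2, stride=2):
    """lambda_id times the MAE between the average-pooled source and translation."""
    a = avg_pool(source, ksize, stride)
    b = avg_pool(translated, ksize, stride)
    diffs = [abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return lambda_id * sum(diffs) / len(diffs)


photo = [[0.0, 1.0, 0.0, 1.0],
         [1.0, 0.0, 1.0, 0.0],
         [0.0, 1.0, 0.0, 1.0],
         [1.0, 0.0, 1.0, 0.0]]
# Same local averages, different texture: the change goes unpenalized.
restyled = [[1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0]]
# Uniformly brighter image: the color composition has changed.
recolored = [[v + 0.5 for v in row] for row in photo]

print(pooled_identity_loss(photo, restyled, lambda_id=0.3))
print(pooled_identity_loss(photo, recolored, lambda_id=0.3))
```

The generator is thus free to change fine texture inside each window, while a global shift in color is penalized in proportion to lambda_id.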



In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_probability as tfp
import os, random, json, PIL, shutil, re, gc, warnings
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from scipy import linalg
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
def seed_everything(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
SEED = 0
seed_everything(SEED)


%matplotlib inline
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE
tf.__version__


In [None]:
GCS_PATH = KaggleDatasets().get_gcs_path("gan-getting-started")

# GCS_PATH_1367 = KaggleDatasets().get_gcs_path("monet-tfrecords-256x256")
# GCS_PATH_1193 = KaggleDatasets().get_gcs_path("tfrecords-monet-paintings-256x256")

MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
# MONET_FILENAMES_1193 = tf.io.gfile.glob(str(GCS_PATH_1193 + '/mon*.tfrec'))
# MONET_FILENAMES_1367 = tf.io.gfile.glob(str(GCS_PATH_1367 + '/mon*.tfrec'))


print('Monet TFRecord Files:', len(MONET_FILENAMES))
print('Monet TFRecord Files:', MONET_FILENAMES)

print('Photo TFRecord Files:', len(PHOTO_FILENAMES))
print('Photo Files:', PHOTO_FILENAMES)


# print('Monet 1193 TFRecord Files:', len(MONET_FILENAMES_1193))
# print('Monet 1193 TFRecord Files:', MONET_FILENAMES_1193)

# print('Monet 1367 TFRecord Files:', len(MONET_FILENAMES_1367))
# print('Monet 1367 TFRecord Files:', MONET_FILENAMES_1367)


In [None]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image

def read_tfrecord(example):
    tfrecord_format = {
        "image": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image

In [None]:
def data_augment_color(image):
    image = tf.image.random_flip_left_right(image)
    image = (image + 1) / 2
    image = tf.image.random_saturation(image, 0.7, 1.2)
    image = tf.clip_by_value(image, 0, 1) 
    image = (image - 0.5) * 2    
    return image

In [None]:
###### from pats notebook https://www.kaggle.com/swepat/cyclegan-to-generate-monet-style-images   #############
def data_augment(image):
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Apply jitter
    if p_crop > .5:
        image = tf.image.resize(image, [286, 286])
        image = tf.image.random_crop(image, size=[256, 256, 3])
        if p_crop > .9:
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])
    
    # Random rotation
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1) # rotate 90º
    
    # Random mirroring
    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            image = tf.image.transpose(image)
    
    return image

In [None]:
BATCH_SIZE = 1
# EPOCHS = 5

def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset

photo_ds = load_dataset(PHOTO_FILENAMES)
monet_ds = load_dataset(MONET_FILENAMES)

# photo_ds_val=photo_ds.skip(6038)
# monet_ds_val = load_dataset(MONET_FILENAMES_1367)
# val_ds=tf.data.Dataset.zip((monet_ds_val, photo_ds_val)).batch(64)

# photo_ds=photo_ds.take(6038)

monet_ds = monet_ds.repeat()
photo_ds = photo_ds.repeat()

# monet_ds = monet_ds.map(data_augment_color, num_parallel_calls=AUTOTUNE)
# photo_ds = photo_ds.map(data_augment, num_parallel_calls=AUTOTUNE)

# photo_ds=photo_ds.batch(BATCH_SIZE)
# monet_ds = monet_ds.batch(BATCH_SIZE)
  
gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds)).batch(BATCH_SIZE).prefetch(AUTOTUNE)

fast_photo_ds = load_dataset(PHOTO_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)

monet_ds_fid = load_dataset(MONET_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)


In [None]:
with strategy.scope():
#         inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)
    inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)

    mix3  = inception_model.get_layer("mixed3").output
    f0 = tf.keras.layers.GlobalMaxPooling2D()(mix3)

    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False
    
def calculate_activation_statistics_mod(images, fid_model):
    """Mean and covariance of the feature activations of fid_model over images."""
    act = fid_model.predict(images)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma
myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(monet_ds_fid,inception_model)

In [None]:
print(myFID_mu2.shape,myFID_sigma2.shape)


In [None]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians (mu1, sigma1) and (mu2, sigma2)."""
    fid_epsilon = 1e-14
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = f'fid calculation produces singular product; adding {fid_epsilon} to diagonal of cov estimates'
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * fid_epsilon
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f'Imaginary component {m}')
        covmean = covmean.real
    tr_covmean = np.trace(covmean)
    return (mu1 - mu2).dot(mu1 - mu2) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


    
def FID(images, gen_model, inception_model=inception_model, myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
    """FID between the generator's output on images and the precomputed Monet statistics."""
    with strategy.scope():
        # Chain the generator and the feature extractor into a single model
        inp = layers.Input(shape=[256, 256, 3], name='input_image')
        x = gen_model(inp)
        x = inception_model(x)
        fid_model = tf.keras.Model(inputs=inp, outputs=x)

    mu1, sigma1 = calculate_activation_statistics_mod(images, fid_model)
    fid_value = calculate_frechet_distance(mu1, sigma1, myFID_mu2, myFID_sigma2)
    return fid_value

In [None]:

OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_instancenorm:
        result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    result.add(layers.LeakyReLU())

    return result

In [None]:
def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    if apply_dropout:
        result.add(layers.Dropout(0.5))

    result.add(layers.ReLU())

    return result

In [None]:
def Generator():
    inputs = layers.Input(shape=[256,256,3])

    # bs = batch size
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

In [None]:
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
    
    

    return tf.keras.Model(inputs=inp, outputs=last)



In [None]:
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

In [None]:
class CycleGan(keras.Model):
    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=1,
        lambda_id=0.3,
#         lambda_GP=10,        
    ):
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id
#         self.lambda_GP = lambda_GP


        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    
    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
        
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            
#             noise = tf.random.normal(shape = (256,256,3), mean = 0.0, stddev = 0.02, dtype = tf.float32) 
            
            fake_monet = self.m_gen(real_photo, training=True)
#             cycled_photo = self.p_gen(fake_monet+noise, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
#             cycled_monet = self.m_gen(fake_photo+noise, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # generating itself
#             same_monet = self.m_gen(real_monet, training=True)
#             same_photo = self.p_gen(real_photo, training=True)

            # discriminator used to check, inputing real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator used to check, inputing fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # discriminator used to check, inputing cycled images
#             disc_cycled_monet = self.m_disc(cycled_monet, training=True)
#             disc_cycled_photo = self.p_disc(cycled_photo, training=True)

            
            # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)
#             monet_gen_loss = self.gen_loss_fn(disc_fake_monet)+self.gen_loss_fn(disc_cycled_monet)
#             photo_gen_loss = self.gen_loss_fn(disc_fake_photo)+self.gen_loss_fn(disc_cycled_photo)
#             monet_gen_loss = -disc_fake_monet-disc_cycled_monet # W
#             photo_gen_loss = -disc_fake_photo-disc_cycled_photo # W
#             monet_gen_loss = -disc_fake_monet # W
#             photo_gen_loss = -disc_fake_photo # W


            # evaluates total cycle consistency loss
            cycle_loss_mpm = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle)
            cycle_loss_pmp = self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            total_cycle_loss = cycle_loss_mpm + cycle_loss_pmp

            # evaluates total generator loss
#             total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo+noise, fake_monet, self.lambda_id)
#             total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet+noise, fake_photo, self.lambda_id)
#             total_monet_gen_loss = monet_gen_loss + total_cycle_loss 
#             total_photo_gen_loss = photo_gen_loss + total_cycle_loss 
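            # The identity term must depend on the generator being updated:
            # fake_monet comes from m_gen(real_photo), fake_photo from p_gen(real_monet).
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, fake_monet, self.lambda_id)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, fake_photo, self.lambda_id)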

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
#             monet_disc_loss2 = self.disc_loss_fn(disc_real_monet, disc_cycled_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)
#             photo_disc_loss2 = self.disc_loss_fn(disc_real_photo, disc_cycled_photo)
            
            # evaluates the discriminator gradient penalty (WGAN-GP)
#             alpha_m_f = tf.random.uniform(shape=[BATCH_SIZE,1,1,1], minval=0.,maxval=1.)
#             alpha_p_f = tf.random.uniform(shape=[BATCH_SIZE,1,1,1], minval=0.,maxval=1.)
#             alpha_m_c = tf.random.uniform(shape=[BATCH_SIZE,1,1,1], minval=0.,maxval=1.)
#             alpha_p_c = tf.random.uniform(shape=[BATCH_SIZE,1,1,1], minval=0.,maxval=1.)
            
#             m_f_hat = alpha_m_f * real_monet+ (1.0-alpha_m_f) * fake_monet
#             m_c_hat = alpha_m_c * real_monet+ (1.0-alpha_m_c) * cycled_monet
#             p_f_hat = alpha_p_f * real_photo+ (1.0-alpha_p_f) * fake_photo
#             p_c_hat = alpha_p_c * real_photo+ (1.0-alpha_p_c) * cycled_photo
            
#             d_m_f_hat = self.m_disc(m_f_hat, training=False) # changed from False
#             d_m_c_hat = self.m_disc(m_c_hat, training=False)
#             d_p_f_hat = self.p_disc(p_f_hat, training=False)
#             d_p_c_hat = self.p_disc(p_c_hat, training=False)

 
            # maybe take the maximum with 0 (one-sided penalty)?
#             GP_m_f = tf.reduce_mean((tf.sqrt(tf.reduce_sum(tf.gradients(d_m_f_hat,m_f_hat)[0]**2,axis=[1,2,3]))-1.0)**2)
#             GP_m_f = tf.reduce_mean(tf.maximum(tf.sqrt(tf.reduce_sum(tf.gradients(d_m_f_hat,m_f_hat)[0]**2,axis=[1,2,3]))-1.0,0)**2)

#             GP_m_c = tf.reduce_mean(tf.maximum(tf.sqrt(tf.reduce_sum(tf.gradients(d_m_c_hat,m_c_hat)[0]**2,axis=[1,2,3]))-1.0,0)**2)
#             GP_p_f = tf.reduce_mean((tf.sqrt(tf.reduce_sum(tf.gradients(d_p_f_hat,p_f_hat)[0]**2,axis=[1,2,3]))-1.0)**2)
#             GP_p_f = tf.reduce_mean(tf.maximum(tf.sqrt(tf.reduce_sum(tf.gradients(d_p_f_hat,p_f_hat)[0]**2,axis=[1,2,3]))-1.0,0)**2)

        #             GP_p_c = tf.reduce_mean(tf.maximum(tf.sqrt(tf.reduce_sum(tf.gradients(d_p_c_hat,p_c_hat)[0]**2,axis=[1,2,3]))-1.0,0)**2)

            # evaluates discriminator loss
#             monet_disc_loss = -2*disc_real_monet + disc_fake_monet + disc_cycled_monet + self.lambda_GP*GP_m_f + self.lambda_GP*GP_m_c # W
#             photo_disc_loss = -2*disc_real_photo + disc_fake_photo + disc_cycled_photo + self.lambda_GP*GP_p_f + self.lambda_GP*GP_p_c # W            

#             monet_disc_loss = -disc_real_monet + disc_fake_monet  + self.lambda_GP*GP_m_f # W
#             photo_disc_loss = -disc_real_photo + disc_fake_photo  + self.lambda_GP*GP_p_f # W            


            
        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        return {
            "monet_disc_loss": monet_disc_loss,
#             "monet_disc_loss2": -disc_real_monet+disc_cycled_monet,
            "photo_disc_loss": photo_disc_loss,
#             "photo_disc_loss2": -disc_real_photo+disc_cycled_photo,
            "cycle_loss_mpm": cycle_loss_mpm,
            "cycle_loss_pmp": cycle_loss_pmp,
            "disc_real_monet": disc_real_monet,
            "disc_fake_monet": disc_fake_monet,            
            "disc_real_photo": disc_real_photo,            
            "disc_fake_photo": disc_fake_photo,            
        }
    

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        real_loss = tf.square(tf.ones_like(real) - real)

        generated_loss = tf.square(tf.zeros_like(generated) - generated)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.5


In [None]:
with strategy.scope():
    def generator_loss(generated):
        return tf.square(tf.ones_like(generated) - generated)

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        loss1 = tf.reduce_mean(tf.abs(inception_model(real_image) - inception_model(cycled_image)))
        return LAMBDA * loss1

In [None]:
with strategy.scope():
    def identity_loss(real_image, translated_image, LAMBDA):
        loss = tf.reduce_mean(tf.abs(tf.nn.avg_pool2d(real_image, ksize=16, strides=4, padding="VALID") - tf.nn.avg_pool2d(translated_image, ksize=16, strides=4, padding="VALID")))
        return LAMBDA *  loss

In [None]:
with strategy.scope():

    monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, monet_discriminator, photo_discriminator
    )

    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )

In [None]:
# callbacks = [keras.callbacks.ModelCheckpoint(filepath='monet.h5',save_weights_only=True,save_best_only=True, monitor='val_total_cycle_loss', verbose=1)]
disc_m_loss1=[]
disc_p_loss1=[]
# cycle_gan_model.built = True
# cycle_gan_model.load_weights('../input/cyclegan-with-dg-pretraining/monet.h5')

In [None]:
%%time
fids=[]
best_fid=999999999
for epoch in range(1,32):

    print("Epoch = ",epoch)
    hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=1500, epochs=1).history
    cur_fid=FID(fast_photo_ds,monet_generator)
#     disc_m_loss1.append(hist["monet_disc_loss1"][0][0][0])
#     disc_p_loss1.append(hist["photo_disc_loss1"][0][0][0])
    fids.append(cur_fid)
    print("After epoch #{} FID = {}\n".format(epoch,cur_fid))
    if cur_fid<best_fid:
            best_fid=cur_fid
            monet_generator.save('monet_generator.h5')


# hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=30, epochs=EPOCHS).history
# hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=30,validation_data=([1]), epochs=3).history


In [None]:
# plt.plot(disc_m_loss1, label='monet_disc_loss1')
# plt.plot(disc_p_loss1, label='photo_disc_loss1')
plt.plot(np.array(fids), label='FID')

plt.legend()
plt.show()

In [None]:
# !conda install -y gdown 
# import gdown 
# url = 'https://drive.google.com/uc?export=download&id=18UWaVxb_UHDMq4KzJqHqGSPizXXy4H7' 
# output = 'photo.jpg'
# gdown.download(url, output)

In [None]:
# cycle_gan_model.built = True
# cycle_gan_model.load_weights('monet.h5')
with strategy.scope():
    monet_generator = tf.keras.models.load_model('monet_generator.h5')

In [None]:
_, ax = plt.subplots(5, 3, figsize=(32, 32))
for i, img in enumerate(photo_ds.batch(1).take(5)):
    prediction = monet_generator(img, training=False)
    cycledphoto = photo_generator(prediction, training=False)
    prediction = (prediction * 127.5 + 127.5)[0].numpy().astype(np.uint8)
    cycledphoto = (cycledphoto * 127.5 + 127.5)[0].numpy().astype(np.uint8)

    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 2].imshow(cycledphoto)
    ax[i, 0].set_title("Input Photo")
    ax[i, 1].set_title("Monet-esque")
    ax[i, 2].set_title("Cycled Photo")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
    ax[i, 2].axis("off")

plt.show()

In [None]:
_, ax = plt.subplots(5, 3, figsize=(32, 32))
for i, img in enumerate(monet_ds.batch(1).take(5)):
    prediction = photo_generator(img, training=False)
    cycledphoto = monet_generator(prediction, training=False)
    prediction = (prediction * 127.5 + 127.5)[0].numpy().astype(np.uint8)
    cycledphoto = (cycledphoto * 127.5 + 127.5)[0].numpy().astype(np.uint8)

    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 2].imshow(cycledphoto)
    ax[i, 0].set_title("Input Monet")
    ax[i, 1].set_title("Generated Photo")
    ax[i, 2].set_title("Cycled Monet")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
    ax[i, 2].axis("off")

plt.show()

In [None]:
ds_iter = iter(photo_ds.batch(1))
for n_sample in range(8):
        example_sample = next(ds_iter)
        generated_sample = monet_generator(example_sample)
        
        f = plt.figure(figsize=(32, 32))
        
        plt.subplot(121)
        plt.title('Input image')
        plt.imshow(example_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        
        plt.subplot(122)
        plt.title('Generated image')
        plt.imshow(generated_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()


In [None]:
ds_iter = iter(monet_ds.batch(1))
for n_sample in range(8):

        example_sample = next(ds_iter)
        generated_sample = photo_generator(example_sample)
        
        f = plt.figure(figsize=(24, 24))
        
        plt.subplot(121)
        plt.title('Input image')
        plt.imshow(example_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        
        plt.subplot(122)
        plt.title('Generated image')
        plt.imshow(generated_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()

In [None]:
import PIL.Image  # import the Image submodule explicitly; PIL.Image.fromarray is used below
! mkdir ../images

In [None]:
%%time
i = 1
for img in fast_photo_ds:
    prediction = monet_generator(img, training=False).numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    for pred in prediction:
        im = PIL.Image.fromarray(pred)
        im.save("../images/" + str(i) + ".jpg")
        i += 1
    
    


In [None]:
import shutil
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")