# CycleGAN with pretraining of generators and discriminators


## Contribution:
I suggest that the discriminators and generators for unpaired image-to-image translation be pre-trained separately: the discriminators on a classification problem and the generators on an image scaling problem. The advantage is that this removes one hyperparameter, the choice of random initialization.


## Identifying problems

* The cycle consistency loss is CycleGAN's most important problem.
* Small domain dataset. If one of the image sets is relatively small (only 300 Monet paintings), the Monet discriminator overfits: it memorizes all 300 paintings and recognizes them accurately. When tested on other real Monet paintings, it fails to recognize them and confidently classifies them as generated fakes.
* Too many hyperparameters. Even the random initialization affects the learning outcome.


## 1. Cycle consistency loss

Why do we use the cycle consistency loss? We assume that a translated image can be reconstructed. But this is not the case: some information is lost in translation. If we convert zebras to horses, we lose the stripes, and the model cannot restore them in the correct places. When the generator turns the horses back into zebras, it may place the stripes in completely wrong places. What does the cycle consistency loss say? A very bad zebra! What will the model do next time? It will encode the stripes in the "horse" image so that it can reconstruct them later!
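
For reference, the cycle consistency loss as defined in CycleGAN [1], for generators $G: X \to Y$ and $F: Y \to X$, is:

$$
\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\lVert F(G(x)) - x \rVert_1\right] + \mathbb{E}_{y \sim p_{\text{data}}(y)}\left[\lVert G(F(y)) - y \rVert_1\right]
$$

Note that it only constrains how well $x$ can be recovered from $G(x)$, not what $G(x)$ itself contains.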

To demonstrate this capacity for encoding, I inverted the Monet generator's adversarial loss, effectively asking it to "make the least realistic Monet possible!". The result was gray squares. Yet from these gray squares the photo generator reconstructed the original photos!
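
The encoding effect can be reproduced in a toy setting. The sketch below is a hypothetical illustration with hand-written functions (`toy_monet_generator`, `toy_photo_generator` and `cycle_l1` are not networks or losses from this notebook): the forward map produces an image that is almost uniformly gray, yet the cycle consistency term is essentially zero because the photo is hidden in a tiny-amplitude signal.

```python
import numpy as np

# Toy illustration only: hand-written functions, not the notebook's networks.
GRAY = 0.0   # mid-gray in the [-1, 1] pixel range used by decode_image
EPS = 1e-3   # amplitude of the hidden signal, far too small to see

def toy_monet_generator(photo):
    # Output is within EPS of uniform gray everywhere, yet still carries the photo.
    return GRAY + EPS * photo

def toy_photo_generator(fake_monet):
    # Decodes the hidden signal, recovering the photo (up to round-off).
    return (fake_monet - GRAY) / EPS

def cycle_l1(real, cycled):
    # Mean absolute reconstruction error, i.e. an unweighted cycle consistency term.
    return np.mean(np.abs(real - cycled))

rng = np.random.default_rng(0)
photo = rng.uniform(-1.0, 1.0, size=(256, 256, 3))
fake_monet = toy_monet_generator(photo)
cycled_photo = toy_photo_generator(fake_monet)

print(np.max(np.abs(fake_monet - GRAY)))  # at most 1e-3: visually a gray square
print(cycle_l1(photo, cycled_photo))      # essentially 0
```

A real generator has no reason to use such an obvious code, but the cycle loss rewards any invertible encoding equally.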

But this is only one part of the cycle consistency loss problem. In image-to-image translation we want the model to be accurate: a tree must be translated into a tree, and all semantics must be preserved. The cycle consistency loss provides no incentive for this. Its only incentive is to encode the original image into the translated image.

So the cycle consistency loss is useless, and duplicating the generators and discriminators is useless as well. We should start with a regular GAN and look for a good objective function.

## 2. Random initialization problem

The learning outcome depends significantly on the initialization. It is an essential hyperparameter; can we avoid a brute-force search over it? My hypothesis: what if the generator and the discriminator were pre-trained separately on supervised learning problems, such as classification? Consider an example from human activity. Competition significantly improves a person's skills, and play is a form of learning. But do people start playing without any training? Children first learn the basic skills of walking, running and kicking a ball, and only then compete in a game of football. Children under one year old do not play football! Then why is a neural network trained exclusively through competition? Perhaps the deep layers of the network should first become suited to related tasks on which good results are achievable, and only then should the competitive mode refine the skills of generating and recognizing fakes, eventually surpassing humans?

## Related works

Unpaired image-to-image translation frameworks were proposed in 2017 [1, 2, 3]. I found two different approaches:

* UNIT [2] (Unsupervised Image-to-Image Translation Networks) combines variational autoencoders (VAEs) with Coupled Generative Adversarial Networks (CoGAN) [4], a GAN framework in which two generators share weights to learn the joint distribution of images across domains;
* CycleGAN [1] and DiscoGAN [3] preserve key attributes between the input and the translated image by utilizing a cycle consistency loss.

There are many papers on the cycle consistency loss problem. They usually fall into two types:
* Improvements of CycleGAN: [5] (StarGAN transfers images between multiple domains and uses a domain classification loss as part of its full objective), [6] (adding noise as a defense against a self-adversarial attack), [7] (training stability using the Wasserstein loss), [8] (multi-cycle translation with mutual information constraints, MCMI), [9] (HarmonicGAN adds a smoothness constraint to CycleGAN), [10] (Lipschitz-regularized CycleGAN for improving semantic robustness in unpaired image-to-image translation).
* Without cycle consistency loss: DistanceGAN [11], the geometry-consistent generative adversarial network GcGAN [12] and contrastive learning for unpaired image-to-image translation, CUT [13], have been proposed. [14] proposes learning a common latent space with a siamese network in addition to the adversarial network, without a cycle consistency loss.

CerfGAN [15] needs only two networks (a decoder and a multi-class discriminator that also works as an encoder) to solve image-to-image translation problems, but it uses a reconstruction loss just as CycleGAN does.

The random initialization problem was noted in [16]: “Training Stability: Domain adaptation approaches that rely on some form of adversarial training are sensitive to random initialization. To address this, we incorporate a task–specific loss trained on both source and generated images and a pixel similarity regularization that allows us to avoid mode collapse and stabilize training. By using these tools, we are able to reduce variance of performance for the same hyperparameters across different random initializations of our model.”

I found transfer learning for image-to-image translation discussed only in [17]. The authors introduce the DeepI2I model, which obtains inferior results on limited datasets; to mitigate this, they propose initializing DeepI2I from a pre-trained GAN. They are therefore the first to study transfer learning for I2I networks. Another important DeepI2I feature is that it does not use a pixelwise cycle consistency loss. Instead, it compares the outputs of the discriminator's hidden layers.


## Notebook


In this notebook, I did my best to make CycleGAN work. One of the most important changes is using a least-squares error (LSE) loss instead of a binary cross-entropy (BCE) loss.
I print the FID metric after every epoch, and I also implemented pretraining of the generators and discriminators.
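
The loss functions themselves are passed to `compile` and defined outside this section. As a minimal sketch, assuming the usual least-squares GAN formulation (the function names and the 0.5 factor are assumptions, not the notebook's exact code), LSE losses on the discriminator's patch scores penalize each score by its squared distance from the target (1 for real, 0 for fake) rather than by a log-likelihood:

```python
import numpy as np

# Hypothetical least-squares (LSGAN-style) losses for PatchGAN score maps.
# Signatures mirror how CycleGan.train_step calls disc_loss_fn / gen_loss_fn.

def lse_discriminator_loss(disc_real, disc_fake):
    # Real patches are pushed towards 1, generated patches towards 0.
    return 0.5 * (np.mean((disc_real - 1.0) ** 2) + np.mean(disc_fake ** 2))

def lse_generator_loss(disc_fake):
    # The generator is rewarded when its outputs are scored as real (1).
    return np.mean((disc_fake - 1.0) ** 2)

real_scores = np.ones((1, 30, 30, 1))        # real patches scored exactly on target
fake_scores = np.full((1, 30, 30, 1), -3.0)  # fakes scored far below the "fake" target 0
print(lse_discriminator_loss(real_scores, fake_scores))  # 4.5
print(lse_generator_loss(fake_scores))                   # 16.0
```

In the notebook the same formulas would be written with `tf.reduce_mean` on tensors instead of NumPy arrays.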

## References

[1] J.-Y. Zhu, T. Park, P. Isola, and A. A. Efros. Unpaired image-to-image translation using cycle-consistent adversarial networks. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), 2017. 

[2] M.-Y. Liu, T. Breuel, and J. Kautz. Unsupervised image-to-image translation networks. arXiv preprint arXiv:1703.00848, 2017. 

[3] T. Kim, M. Cha, H. Kim, J. K. Lee, and J. Kim. Learning to discover cross-domain relations with generative adversarial networks. In Proceedings of the 34th International Conference on Machine Learning (ICML), pages 1857–1865, 2017.

[4] M.-Y. Liu and O. Tuzel. Coupled generative adversarial networks. In Advances in Neural Information Processing Systems (NIPS), pages 469–477, 2016.

[5] Choi, Y., Choi, M., Kim, M., Ha, J.W., Kim, S., Choo, J.: StarGAN: Unified generative adversarial networks for multi-domain image-to-image translation. In: Conference on Computer Vision and Pattern Recognition. pp. 8789–8797 (2018)

[[6] D. Bashkirova, B. Usman, and K. Saenko. Adversarial self-defense for cycle-consistent GANs. In Advances in Neural Information Processing Systems, pages 635–645, 2019](https://arxiv.org/abs/1908.01517)

[7] Arjovsky, Martín et al. “Wasserstein GAN.” ArXiv abs/1701.07875 (2017)

[[8] Xu, X. et al. “MCMI: Multi-Cycle Image Translation with Mutual Information Constraints.” ArXiv abs/2007.02919 (2020)](https://arxiv.org/pdf/2007.02919.pdf)

[[9] Rui Zhang, Tomas Pfister, and Jia Li. Harmonic unpaired image-to-image translation. arXiv preprint arXiv:1902.09727, 2019](https://arxiv.org/abs/1902.09727)

[[10] Jia, Z., Yuan, B., Wang, K., Wu, H., Clifford, D., Yuan, Z., & Su, H. (2020). Lipschitz Regularized CycleGAN for Improving Semantic Robustness in Unpaired Image-to-image Translation. ArXiv, abs/2012.04932](https://arxiv.org/abs/2012.04932)

[[11] S. Benaim and L. Wolf. One-sided unsupervised domain mapping. In NIPS, 2017 ](https://arxiv.org/abs/1706.00826)

[[12] Huan Fu, Mingming Gong, Chaohui Wang, Kayhan Batmanghelich, Kun Zhang, and Dacheng Tao. Geometry-consistent generative adversarial networks for one-sided unsupervised domain mapping. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2427–2436, 2019](https://arxiv.org/abs/1809.05852)

[[13] Taesung Park, Alexei A Efros, Richard Zhang, and Jun-Yan Zhu. Contrastive learning for unpaired image-to-image translation. In European Conference on Computer Vision, pages 319–345. Springer, 2020](https://arxiv.org/abs/2007.15651)

[[14] Amodio, M., Krishnaswamy, S.: TraVeLGAN: Image-to-image translation by transformation vector learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 8983–8992 (2019)](https://arxiv.org/abs/1902.09631)

[[15] X. Liu, S. Zhang, H. Liu, X. Liu, and R. Ji. Cerfgan: A compact, effective, robust, and fast model for unsupervised multi-domain image-to-image translation. arXiv preprint arXiv:1805.10871v2, 2018](https://arxiv.org/pdf/1805.10871.pdf)

[[16] Konstantinos Bousmalis, Nathan Silberman, David Dohan, Dumitru Erhan, and Dilip Krishnan. Unsupervised pixel-level domain adaptation with generative adversarial networks. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), July 2017.](https://arxiv.org/pdf/1612.05424)

[[17] Wang, Y., Yu, L., & Weijer, J.V. (2020). DeepI2I: Enabling Deep Hierarchical Image-to-Image Translation by Transferring from GANs. ArXiv, abs/2011.05867.](https://arxiv.org/pdf/2011.05867.pdf)



In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_probability as tfp
import os, random, json, PIL, shutil, re, gc, warnings  # warnings is used by calculate_frechet_distance
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from scipy import linalg
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
def seed_everything(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
SEED = 0
seed_everything(SEED)


%matplotlib inline
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:  # no TPU found: fall back to the default (CPU/GPU) strategy
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE
tf.__version__


In [None]:
GCS_PATH = KaggleDatasets().get_gcs_path("gan-getting-started")
GCS_PATH_1367 = KaggleDatasets().get_gcs_path("monet-tfrecords-256x256")

MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))

MONET_FILENAMES_1367 = tf.io.gfile.glob(str(GCS_PATH_1367 + '/mon*.tfrec'))

m = MONET_FILENAMES_1367
p = PHOTO_FILENAMES
# Interleave shards: photo shard p[i] follows Monet shard m[i % 4] for i < 19,
# and the last pair (m[4], p[19]) comes at the end (it is split off for validation below).
PRETRAIN_FILENAMES = []
PRETRAIN_LABELS = []  # 1 = Monet, 0 = photo; one label per image, in file order
for i in range(19):
    PRETRAIN_FILENAMES += [m[i % 4], p[i]]
    PRETRAIN_LABELS += [1] * 274 + [0] * 352
PRETRAIN_FILENAMES += [m[4], p[19]]
PRETRAIN_LABELS += [1] * 271 + [0] * 350


In [None]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image

def read_tfrecord(example):
    tfrecord_format = {
        "image": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image

In [None]:
def data_augment_color(image):
    image = tf.image.random_flip_left_right(image)
    image = (image + 1) / 2
    image = tf.image.random_saturation(image, 0.7, 1.2)
    image = tf.clip_by_value(image, 0, 1) 
    image = (image - 0.5) * 2    
    return image


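def pretrain_data_augment(image):
    # Random jitter: upscale to 286x286, then randomly crop back to 256x256.
    # The generators are pretrained to reproduce this rescaled crop ("image scaling").
    image = tf.image.resize(image, [286, 286])
    image = tf.image.random_crop(image, size=[256, 256, 3])
    return image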

def to3030(label):
    # Broadcast a scalar label to the 30x30x1 patch grid output by the PatchGAN discriminator.
    return tf.ones((30, 30, 1), dtype=tf.float32) * label
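
During pretraining, `to3030` turns each image's scalar label into a 30x30x1 map matching the discriminator's patch output, and `Pretrain_CycleGan.train_step` trains the Monet discriminator towards this map while the photo discriminator gets the complementary map `1 - label`. The NumPy sketch below illustrates those targets on a hypothetical batch; the `mse` helper is an assumed least-squares loss, not necessarily the `disc_loss_fn` the notebook passes to `compile`.

```python
import numpy as np

# Hypothetical batch of three images: Monet, photo, Monet.
labels = np.array([1.0, 0.0, 1.0], dtype="float32")

# Equivalent of mapping to3030 over the labels and batching: shape (3, 30, 30, 1).
label_is_monet = np.ones((3, 30, 30, 1), dtype="float32") * labels[:, None, None, None]

# Targets as built in Pretrain_CycleGan.train_step.
monet_disc_target = np.ones_like(label_is_monet) * label_is_monet        # 1 on Monet, 0 on photos
photo_disc_target = np.ones_like(label_is_monet) * (1 - label_is_monet)  # 0 on Monet, 1 on photos

def mse(pred, target):
    # Assumed least-squares loss; the notebook's disc_loss_fn is defined elsewhere.
    return float(np.mean((pred - target) ** 2))

uncertain = np.full_like(label_is_monet, 0.5)     # a discriminator that cannot decide
print(mse(uncertain, monet_disc_target))          # 0.25
print(mse(monet_disc_target, monet_disc_target))  # 0.0
```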

In [None]:
###### from pats notebook https://www.kaggle.com/swepat/cyclegan-to-generate-monet-style-images   #############
def data_augment(image):
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Apply jitter
    if p_crop > .5:
        image = tf.image.resize(image, [286, 286])
        image = tf.image.random_crop(image, size=[256, 256, 3])
        if p_crop > .9:
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])
    
    # Random rotation
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1) # rotate 90º
    
    # Random mirroring
    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            image = tf.image.transpose(image)
    
    return image

In [None]:
BATCH_SIZE = 1
# EPOCHS = 5

def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset

photo_ds = load_dataset(PHOTO_FILENAMES)
monet_ds = load_dataset(MONET_FILENAMES)

# photo_ds_val=photo_ds.skip(6038)
# monet_ds_val = load_dataset(MONET_FILENAMES_1367)
# val_ds=tf.data.Dataset.zip((monet_ds_val, photo_ds_val)).batch(64)

# photo_ds=photo_ds.take(6038)

monet_ds = monet_ds.repeat()
photo_ds = photo_ds.repeat()

# monet_ds = monet_ds.map(data_augment, num_parallel_calls=AUTOTUNE)
# photo_ds = photo_ds.map(data_augment, num_parallel_calls=AUTOTUNE)

# photo_ds=photo_ds.batch(BATCH_SIZE)
# monet_ds = monet_ds.batch(BATCH_SIZE)
  
gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds)).batch(BATCH_SIZE).prefetch(AUTOTUNE)

fast_photo_ds = load_dataset(PHOTO_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)

monet_ds_fid = load_dataset(MONET_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)


pretrain_ds = load_dataset(PRETRAIN_FILENAMES)

pretrain_labels=tf.data.Dataset.from_tensor_slices(np.array(PRETRAIN_LABELS,dtype="float32"))
pretrain_labels=pretrain_labels.map(to3030,num_parallel_calls=AUTOTUNE)
pretrain_trans=pretrain_ds.map(pretrain_data_augment, num_parallel_calls=AUTOTUNE)
pretrain_ds=tf.data.Dataset.zip((pretrain_ds, pretrain_trans,pretrain_labels))
val_pretrain_ds=pretrain_ds.skip(11894).batch(32, drop_remainder=True).prefetch(2)
pretrain_ds=pretrain_ds.take(11894).shuffle(400).batch(32, drop_remainder=True).prefetch(2)
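
The constant 11894 above is the size of the first 19 (Monet, photo) shard pairs; the last pair (`m[4]`, `p[19]`) becomes the validation split, so its Monet paintings are never seen during pretraining. A plain-arithmetic check, using the per-shard counts encoded in `PRETRAIN_LABELS`:

```python
# Worked check of the pretraining split sizes (plain arithmetic, no TensorFlow needed).
MONET_PER_SHARD, PHOTOS_PER_SHARD = 274, 352   # first 19 (Monet, photo) shard pairs
LAST_MONET, LAST_PHOTOS = 271, 350             # held-out pair (m[4], p[19])
BATCH = 32

train_images = 19 * (MONET_PER_SHARD + PHOTOS_PER_SHARD)
val_images = LAST_MONET + LAST_PHOTOS
train_batches = train_images // BATCH          # drop_remainder=True discards the tail
val_batches = val_images // BATCH

print(train_images, val_images)     # 11894 621
print(train_batches, val_batches)   # 371 19
```

The training split holds 19 * 274 = 5206 Monet images (each of the first four Monet shards is reused four or five times) against 19 * 352 = 6688 photos.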

In [None]:
with strategy.scope():
#         inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)
    inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)

#     mix3  = inception_model.get_layer("mixed4").output
#     f0 = tf.keras.layers.GlobalMaxPooling2D()(mix3)

#     inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False

    
    
def calculate_activation_statistics_mod(images, fid_model):
    # Mean vector and covariance matrix of the model's pooled activations over all images.
    act = fid_model.predict(images)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma
myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(monet_ds_fid,inception_model)

In [None]:
print(myFID_mu2.shape,myFID_sigma2.shape)


In [None]:
def calculate_frechet_distance(mu1,sigma1,mu2,sigma2):
        fid_epsilon = 1e-14
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = f'fid calculation produces singular product; adding {fid_epsilon} to diagonal of cov estimates'
            warnings.warn(msg)
            offset = np.eye(sigma1.shape[0]) * fid_epsilon
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
            
        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError(f'Imaginary component {m}')
            covmean = covmean.real
        tr_covmean = np.trace(covmean)
        return (mu1 - mu2).dot(mu1 - mu2) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


    
def FID(images,gen_model,inception_model=inception_model,myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
            with strategy.scope():
                inp = layers.Input(shape=[256, 256, 3], name='input_image')
                x  = gen_model(inp)
                x=inception_model(x)
                fid_model = tf.keras.Model(inputs=inp, outputs=x)
                
            mu1, sigma1 = calculate_activation_statistics_mod(images,fid_model)

            fid_value = calculate_frechet_distance(mu1, sigma1,myFID_mu2, myFID_sigma2)

            return fid_value
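
`calculate_frechet_distance` implements $\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\left(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\right)$. As a cross-check, here is a NumPy-only sketch (`frechet_distance_eig` is a hypothetical helper, not used elsewhere in the notebook) that replaces `scipy.linalg.sqrtm` with an eigenvalue computation: for covariance matrices, $\Sigma_1 \Sigma_2$ has real, non-negative eigenvalues, and the trace of its square root is the sum of their square roots.

```python
import numpy as np

def frechet_distance_eig(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((sigma1 @ sigma2)^(1/2)) = sum of sqrt(eigenvalues); round-off can leave
    # tiny imaginary or negative parts, which are discarded.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean

mu = np.zeros(4)
sigma = np.eye(4)
# Identical statistics give a distance of (numerically) zero.
print(frechet_distance_eig(mu, sigma, mu, sigma))
# Equal covariances leave only the mean term: ||[1, 2, 0, 0]||^2 = 5.
print(frechet_distance_eig(mu, sigma, np.array([1.0, 2.0, 0.0, 0.0]), sigma))
```

For the 2048-dimensional InceptionV3 features used here, both approaches should agree up to numerical error; `sqrtm` additionally returns the full matrix square root, which FID does not need.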

In [None]:

OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_instancenorm:
        result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    result.add(layers.LeakyReLU())

    return result

In [None]:
def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    if apply_dropout:
        result.add(layers.Dropout(0.5))

    result.add(layers.ReLU())

    return result

In [None]:
def Generator():
    inputs = layers.Input(shape=[256,256,3])

    # bs = batch size
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

In [None]:
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
    
    

    return tf.keras.Model(inputs=inp, outputs=last)
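
The `(bs, 30, 30, 1)` output shape, and hence the 30x30x1 label maps built by `to3030`, follows from standard convolution arithmetic. The plain-Python sketch below (`conv_out` and `layers_k_s` are hypothetical names) retraces the shape comments in `Discriminator` and also computes the receptive field of one output score, showing that each of the 900 scores judges a 70x70 patch of the input.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a Keras convolution along one axis."""
    if padding == "same":
        return -(-size // stride)          # ceil(size / stride)
    return (size - kernel) // stride + 1   # 'valid', the Conv2D default

size = 256
for _ in range(3):                         # down1..down3: 4x4, stride 2, 'same'
    size = conv_out(size, 4, 2, "same")    # 128 -> 64 -> 32
size = conv_out(size + 2, 4, 1, "valid")   # ZeroPadding2D (+1 per side), 4x4 conv: 31
size = conv_out(size + 2, 4, 1, "valid")   # ZeroPadding2D again, final 4x4 conv: 30
print(size)                                # 30

# Receptive field of one output score, walking the convs from output back to input.
layers_k_s = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]   # (kernel, stride) per conv
rf = 1
for kernel, stride in reversed(layers_k_s):
    rf = (rf - 1) * stride + kernel
print(rf)                                  # 70
```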



In [None]:
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

In [None]:
class CycleGan(keras.Model):
    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        monet_expert,
        lambda_cycle=10,
        lambda_id=3,
#         lambda_GP=10,        
    ):
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.monet_expert = monet_expert
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id
#         self.lambda_GP = lambda_GP


        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
        expert_loss_fn,
        
    ):
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        self.expert_loss_fn = expert_loss_fn
        
    
    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
      
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet 
            
            fake_monet = self.m_gen(real_photo, training=True)
            fake_photo = self.p_gen(real_monet, training=True)
            
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator used to check, inputing real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)
            # discriminator used to check, inputing fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)
            
            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

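            # Targets must match the expert's output shape (a patch map), not the image shape.
            expert_fake_monet = self.monet_expert(fake_monet, training=False)
            expert_fake_photo = self.monet_expert(fake_photo, training=False)
            monet_expert_loss = self.expert_loss_fn(expert_fake_monet, tf.ones_like(expert_fake_monet))
            photo_expert_loss = self.expert_loss_fn(expert_fake_photo, tf.zeros_like(expert_fake_photo))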


                # back to monet 
            cycled_monet = self.m_gen(fake_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)


                # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)


                # evaluates total cycle consistency loss
            cycle_loss_mpm = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle)
            cycle_loss_pmp = self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            total_cycle_loss = cycle_loss_mpm + cycle_loss_pmp

                # evaluates total generator loss
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_id) + monet_expert_loss
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_id) + photo_expert_loss
#             total_monet_gen_loss*=tf.dtypes.cast(tf.random.uniform(shape=[1])[0]<0.50, tf.float32)             
#             total_photo_gen_loss*=tf.dtypes.cast(tf.random.uniform(shape=[1])[0]<0.50, tf.float32)

            
        # Calculate the gradients for generator and discriminator
        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                      self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                      self.p_gen.trainable_variables)

        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                     self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                     self.p_gen.trainable_variables))


        # Apply the gradients to the optimizer

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        return {
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss,
            "disc_real_monet": disc_real_monet,
            "disc_fake_monet": disc_fake_monet,            
            "disc_real_photo": disc_real_photo,            
            "disc_fake_photo": disc_fake_photo,            
            "monet_expert_loss": monet_expert_loss,            
            "photo_expert_loss": photo_expert_loss,            

        }
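
Written out, the generator objectives assembled in `train_step` are as follows, with $m$ and $p$ the real Monet and photo batches, $G_M, G_P$ the generators, $D_M, D_P$ the discriminators, and $E$ the frozen `monet_expert` (called with `training=False`; defaults $\lambda_{\text{cyc}} = 10$, $\lambda_{\text{id}} = 3$). The exact form of each term depends on the loss functions passed to `compile`, which are defined outside this section.

$$
\begin{aligned}
\mathcal{L}_{\text{cyc,tot}} &= \mathcal{L}_{\text{cyc}}\big(m, G_M(G_P(m)); \lambda_{\text{cyc}}\big) + \mathcal{L}_{\text{cyc}}\big(p, G_P(G_M(p)); \lambda_{\text{cyc}}\big) \\
\mathcal{L}_{G_M} &= \mathcal{L}_{\text{gen}}\big(D_M(G_M(p))\big) + \mathcal{L}_{\text{cyc,tot}} + \mathcal{L}_{\text{id}}\big(m, G_M(m); \lambda_{\text{id}}\big) + \mathcal{L}_{\text{exp}}\big(E(G_M(p)), \mathbf{1}\big) \\
\mathcal{L}_{G_P} &= \mathcal{L}_{\text{gen}}\big(D_P(G_P(m))\big) + \mathcal{L}_{\text{cyc,tot}} + \mathcal{L}_{\text{id}}\big(p, G_P(p); \lambda_{\text{id}}\big) + \mathcal{L}_{\text{exp}}\big(E(G_P(m)), \mathbf{0}\big)
\end{aligned}
$$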
    

In [None]:
class Pretrain_CycleGan(keras.Model):
    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        monet_expert,
        lambda_cycle=25,
    ):
        super(Pretrain_CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.m_expert = monet_expert
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        m_expert_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        expert_loss_fn,
    ):
        super(Pretrain_CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.m_expert_optimizer = m_expert_optimizer
        self.expert_loss_fn = expert_loss_fn
        

        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn

        
    def train_step(self, batch_data):
        real_image, transformed_image, label_is_monet = batch_data
        
        with tf.GradientTape(persistent=True) as tape:

            gen_monet = self.m_gen(real_image, training=True)
            gen_photo = self.p_gen(real_image, training=True)

            disc_real_monet = self.m_disc(real_image, training=True)
            disc_real_photo = self.p_disc(real_image, training=True)
#             expert_monet = self.m_expert(real_image, training=True)
#             expert_monet = tf.reduce_mean(self.m_expert(real_image, training=True),axis=[1,2,3])
            monet_exp_loss=self.expert_loss_fn(self.m_expert(real_image, training=True),label_is_monet)
         

            # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(gen_monet,transformed_image,1)
            photo_gen_loss = self.gen_loss_fn(gen_photo,transformed_image,1)

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, tf.ones_like(disc_real_monet)*label_is_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, tf.ones_like(disc_real_photo)*(1-label_is_monet))
#             monet_exp_loss = self.disc_loss_fn(expert_monet, tf.ones_like(expert_monet)*label_is_monet)


        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        monet_expert_gradients = tape.gradient(monet_exp_loss,
                                                      self.m_expert.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        self.m_expert_optimizer.apply_gradients(zip(monet_expert_gradients,
                                                  self.m_expert.trainable_variables))

        
        return {
                "monet_gen_loss": monet_gen_loss,
                "photo_gen_loss": photo_gen_loss,
                "monet_disc_loss": monet_disc_loss,
                "photo_disc_loss": photo_disc_loss,
                "monet_exp_loss" : monet_exp_loss
        }


    def test_step(self, batch_data):
        real_image, transformed_image, label_is_monet = batch_data

        # Evaluation pass: training=False disables dropout and other train-only behaviour
        gen_monet = self.m_gen(real_image, training=False)
        gen_photo = self.p_gen(real_image, training=False)

        disc_real_monet = self.m_disc(real_image, training=False)
        disc_real_photo = self.p_disc(real_image, training=False)
#         expert_monet = self.m_expert(real_image, training=True)
#         disc_real_monet = tf.reduce_mean(self.m_disc(real_image, training=True),axis=[1,2,3])
#         disc_real_photo = tf.reduce_mean(self.p_disc(real_image, training=True),axis=[1,2,3])
        monet_exp_loss = self.expert_loss_fn(self.m_expert(real_image, training=False), label_is_monet)

        # evaluates generator loss
        monet_gen_loss = self.gen_loss_fn(gen_monet, transformed_image, 1)
        photo_gen_loss = self.gen_loss_fn(gen_photo, transformed_image, 1)

        # evaluates discriminator loss (same target maps as in train_step)
        monet_disc_loss = self.disc_loss_fn(disc_real_monet, tf.ones_like(disc_real_monet) * label_is_monet)
        photo_disc_loss = self.disc_loss_fn(disc_real_photo, tf.ones_like(disc_real_photo) * (1 - label_is_monet))

        return {
            "monet_gen_loss": monet_gen_loss,
            "photo_gen_loss": photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss,
            "monet_exp_loss": monet_exp_loss
        }
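The pretraining step treats each discriminator as a classifier: `label_is_monet` is copied onto every patch of the discriminator output with `tf.ones_like(...) * label_is_monet`. The NumPy sketch below illustrates that broadcast under assumed shapes, (batch, 30, 30, 1) patch scores and a (batch, 1, 1, 1) label, neither of which is stated in this section. It also shows why a flat (batch,) label would silently produce a target of the wrong shape.

```python
import numpy as np

# Assumed shapes (not stated in this section): a PatchGAN-style discriminator that returns
# one score per patch, shaped (batch, 30, 30, 1), and label_is_monet shaped (batch, 1, 1, 1).
rng = np.random.default_rng(0)
batch = 2
patch_scores = rng.random((batch, 30, 30, 1)).astype(np.float32)
label_is_monet = np.array([1.0, 0.0], dtype=np.float32).reshape(batch, 1, 1, 1)

# Same construction as train_step: copy each image's label onto every patch.
monet_target = np.ones_like(patch_scores) * label_is_monet
photo_target = np.ones_like(patch_scores) * (1.0 - label_is_monet)

# pretrain_disc_loss is a per-patch squared error against that target map.
monet_disc_loss = np.square(monet_target - patch_scores)
photo_disc_loss = np.square(photo_target - patch_scores)
print(monet_target.shape, monet_disc_loss.shape)  # (2, 30, 30, 1) (2, 30, 30, 1)

# Pitfall: a flat (batch,) label lines up with the channel axis instead of the batch axis.
flat_label = np.array([1.0, 0.0], dtype=np.float32)
print((np.ones_like(patch_scores) * flat_label).shape)  # (2, 30, 30, 2)
```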

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        real_loss = tf.square(tf.ones_like(real) - real)

        generated_loss = tf.square(generated)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.5


In [None]:
with strategy.scope():
    def generator_loss(generated):
        return tf.square(tf.ones_like(generated) - generated)
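`discriminator_loss` and `generator_loss` are the least-squares (LSGAN) objectives: the discriminator pushes scores for real images toward 1 and for generated images toward 0, while the generator pushes scores for its own images toward 1. Below is a small NumPy re-statement with hand-checkable numbers, assuming the same (batch, 30, 30, 1) patch output; it is an illustration, not the notebook's TensorFlow code.

```python
import numpy as np

# NumPy re-statement of discriminator_loss and generator_loss above;
# inputs are raw discriminator scores for real and generated images.
def lsgan_disc_loss(real, generated):
    real_loss = np.square(np.ones_like(real) - real)   # push real scores toward 1
    generated_loss = np.square(generated)              # push fake scores toward 0
    return 0.5 * (real_loss + generated_loss)

def lsgan_gen_loss(generated):
    # The generator wants its images to be scored as real (1)
    return np.square(np.ones_like(generated) - generated)

# A discriminator that outputs 0.5 everywhere cannot tell real from fake.
undecided = np.full((1, 30, 30, 1), 0.5)
print(lsgan_disc_loss(undecided, undecided).mean())  # 0.25
print(lsgan_gen_loss(undecided).mean())              # 0.25

# A perfect discriminator (1 on real, 0 on fake) drives its own loss to 0
# while the generator loss rises to 1.
print(lsgan_disc_loss(np.ones((1, 30, 30, 1)), np.zeros((1, 30, 30, 1))).mean())  # 0.0
print(lsgan_gen_loss(np.zeros((1, 30, 30, 1))).mean())                            # 1.0
```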

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
#         loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
#         loss1 = tf.reduce_mean(tf.keras.losses.Huber(0.5,reduction=tf.keras.losses.Reduction.NONE)(real_image, cycled_image))
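        # Perceptual cycle loss: Huber distance between inception_model outputs instead of raw pixels
        huber = tf.keras.losses.Huber(0.5, reduction=tf.keras.losses.Reduction.NONE)
        loss1 = tf.reduce_mean(huber(inception_model(real_image), inception_model(cycled_image)))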

        return LAMBDA * loss1
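`calc_cycle_loss` compares `inception_model` outputs of the real and cycled images with a Huber loss (delta 0.5) rather than a pixel L1 distance. The feature extractor is defined elsewhere and is left out here. This NumPy sketch only shows the piecewise formula that `tf.keras.losses.Huber` applies elementwise: quadratic up to delta, linear beyond it, so large differences are penalised less harshly than with a squared error.

```python
import numpy as np

def huber(y_true, y_pred, delta=0.5):
    """Elementwise Huber loss, the same piecewise formula tf.keras.losses.Huber uses."""
    error = np.abs(y_true - y_pred)
    quadratic = 0.5 * np.square(error)            # used where |error| <= delta
    linear = delta * error - 0.5 * delta ** 2     # used where |error| > delta
    return np.where(error <= delta, quadratic, linear)

errors = np.array([0.1, 0.5, 2.0])
print(huber(np.zeros(3), errors))
```

For errors of 0.1, 0.5 and 2.0 this gives 0.005, 0.125 and 0.875; a plain half squared error would give 2.0 for the last one. The two branches meet at 0.125 when the error equals delta, so the loss stays continuous.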

In [None]:
with strategy.scope():
    def identity_loss(real_image, translated_image, LAMBDA):
#         loss = tf.reduce_mean(tf.keras.losses.Huber(0.5,reduction=tf.keras.losses.Reduction.NONE)(real_image, translated_image))
#         loss = tf.reduce_mean(tf.abs(real_image - translated_image))
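        # Compare heavily downsampled versions (32x32 windows, stride 16), so only coarse colour layout is constrained
        pooled_real = tf.nn.avg_pool2d(real_image, ksize=32, strides=16, padding="VALID")
        pooled_translated = tf.nn.avg_pool2d(translated_image, ksize=32, strides=16, padding="VALID")
        huber = tf.keras.losses.Huber(0.5, reduction=tf.keras.losses.Reduction.NONE)
        loss = tf.reduce_mean(huber(pooled_real, pooled_translated))
        return LAMBDA * loss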
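`identity_loss` applies the Huber loss to `tf.nn.avg_pool2d` outputs (32×32 windows, stride 16) instead of raw pixels. On a 256×256 input that leaves a 15×15 grid of window means, so the loss appears to constrain only colour and coarse layout and to leave fine texture such as brushstrokes free. The NumPy sketch below assumes 256×256 NHWC images in [-1, 1]; `avg_pool2d_valid` is a hypothetical stand-in written for illustration, not a TensorFlow function. It adds a zero-mean checkerboard to an image: every pixel changes by 0.5, but the pooled views stay the same.

```python
import numpy as np

def avg_pool2d_valid(images, ksize=32, strides=16):
    """Hypothetical NumPy stand-in for tf.nn.avg_pool2d(..., padding="VALID") on NHWC input."""
    n, h, w, c = images.shape
    out_h = (h - ksize) // strides + 1
    out_w = (w - ksize) // strides + 1
    out = np.empty((n, out_h, out_w, c), dtype=images.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = images[:, i * strides:i * strides + ksize, j * strides:j * strides + ksize, :]
            out[:, i, j, :] = window.mean(axis=(1, 2))
    return out

rng = np.random.default_rng(0)
photo = rng.uniform(-1, 1, size=(1, 256, 256, 3)).astype(np.float32)

# +/-0.5 checkerboard: a strong pixel-level change whose mean over any window is zero here
checker = 0.5 * ((np.indices((256, 256)).sum(axis=0) % 2) * 2 - 1)
textured = photo + checker[None, :, :, None].astype(np.float32)

pooled_photo = avg_pool2d_valid(photo)
pooled_textured = avg_pool2d_valid(textured)
print(pooled_photo.shape)                            # (1, 15, 15, 3)
print(np.abs(photo - textured).mean())               # about 0.5 per pixel
print(np.abs(pooled_photo - pooled_textured).max())  # about 0: the pooled views agree
```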

In [None]:
with strategy.scope():
    def pretrain_identity_loss(real_image, translated_image, LAMBDA):
        # Plain L1 reconstruction loss used while pretraining the generators
        loss = tf.reduce_mean(tf.abs(real_image - translated_image))
        return LAMBDA * loss

In [None]:
with strategy.scope():
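    def pretrain_disc_loss(image, label):
        # `image` is the discriminator's patch output; it is regressed toward the domain label
        return tf.square(label - image)

    def expert_loss(generated, label):
        # Average the expert's patch logits and the label map to one value per image, then apply BCE on logits
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
        return bce(tf.reduce_mean(label, axis=[1, 2, 3]), tf.reduce_mean(generated, axis=[1, 2, 3]))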
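`expert_loss` first averages the expert's patch logits (and the broadcast label) to one value per image and only then applies a sigmoid binary cross-entropy on logits. Averaging logits before the sigmoid is not the same as averaging per-patch probabilities. The NumPy sketch below reproduces the per-image part, assuming a (batch, 30, 30, 1) expert output and a (batch, 1, 1, 1) label; the helper names are hypothetical. The Keras loss then combines these per-image values into one scalar with its own reduction, which this sketch leaves out.

```python
import numpy as np

def sigmoid_bce_from_logits(labels, logits):
    """Per-element sigmoid binary cross-entropy in the numerically stable logits form."""
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def expert_bce_per_image(expert_output, label):
    # Collapse the patch logits and the broadcast label to one value per image,
    # mirroring the tf.reduce_mean(..., axis=[1, 2, 3]) calls in expert_loss.
    logits = expert_output.mean(axis=(1, 2, 3))
    labels = label.mean(axis=(1, 2, 3))
    return sigmoid_bce_from_logits(labels, logits)

labels = np.array([1.0, 0.0]).reshape(2, 1, 1, 1)  # first image Monet, second a photo
undecided = np.zeros((2, 30, 30, 1))                # logit 0 means probability 0.5
confident = np.concatenate([np.full((1, 30, 30, 1), 8.0), np.full((1, 30, 30, 1), -8.0)])
print(expert_bce_per_image(undecided, labels))  # log(2) for each image
print(expert_bce_per_image(confident, labels))  # close to 0 for each image
```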

In [None]:
with strategy.scope():

    monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    monet_expert_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    
    pt_monet_generator = Generator() # transforms photos to Monet-esque paintings
    pt_photo_generator = Generator() # transforms Monet paintings to be more like photos

    pt_monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    pt_photo_discriminator = Discriminator() # differentiates real photos and generated photos
    pt_monet_expert = Discriminator() # differentiates Monet paintings from photos
    monet_expert = Discriminator() # differentiates Monet paintings from photos

In [None]:
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, monet_discriminator, photo_discriminator,monet_expert
    )

    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss,
        expert_loss_fn = expert_loss,
    )

In [None]:
with strategy.scope():
    pretrain_cycle_gan_model = Pretrain_CycleGan(pt_monet_generator, pt_photo_generator, pt_monet_discriminator, pt_photo_discriminator,pt_monet_expert)

    pretrain_cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        m_expert_optimizer = monet_expert_optimizer,
        gen_loss_fn = pretrain_identity_loss,
        disc_loss_fn = pretrain_disc_loss,
        expert_loss_fn = expert_loss,
    )

In [None]:
hist=pretrain_cycle_gan_model.fit(pretrain_ds,validation_data=val_pretrain_ds, epochs=33).history

In [None]:
pretrain_cycle_gan_model.save_weights('premonet.h5')

In [None]:
cycle_gan_model.built = True
cycle_gan_model.load_weights('premonet.h5')

In [None]:
! rm premonet.h5
# callbacks = [keras.callbacks.ModelCheckpoint(filepath='monet.h5',save_weights_only=True,save_best_only=True, monitor='val_total_cycle_loss', verbose=1)]
disc_m_loss1=[]
disc_p_loss1=[]

In [None]:
%%time
fids=[]
best_fid=999999999
for epoch in range(1,48):

    print("Epoch = ",epoch)
    hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=1500, epochs=1).history
#     disc_m_loss1.append(hist["monet_disc_loss1"][0][0][0])
#     disc_p_loss1.append(hist["photo_disc_loss1"][0][0][0])
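    if epoch > 35:
        # FID() comes from earlier in the notebook; lower is better
        cur_fid = FID(fast_photo_ds, monet_generator)
        fids.append(cur_fid)
        best_fid = min(best_fid, cur_fid)
        print("After epoch #{} FID = {} (best so far = {})\n".format(epoch, cur_fid, best_fid))
    if epoch > 42:
        # Keep a checkpoint of each of the last epochs so the best one can be chosen afterwards
        monet_generator.save('monet_generator_' + str(epoch) + '.h5')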


# hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=30, epochs=EPOCHS).history
# hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=30,validation_data=([1]), epochs=3).history
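The loop records an FID value from epoch 36 on but saves a generator for every epoch from 43 on, so a checkpoint still has to be picked afterwards. Below is a small pure-Python helper (hypothetical, not part of the notebook) that picks the saved epoch with the lowest FID from the `fids` list. The FID values in the call are made up purely to demonstrate it.

```python
def best_saved_epoch(fids, first_fid_epoch=36, first_saved_epoch=43):
    """Return (epoch, fid) with the lowest FID among the epochs whose generator was saved.

    Assumes fids[k] was measured after epoch first_fid_epoch + k, matching the loop above
    (FID from epoch 36 on, checkpoints from epoch 43 on).
    """
    candidates = [
        (first_fid_epoch + k, fid)
        for k, fid in enumerate(fids)
        if first_fid_epoch + k >= first_saved_epoch
    ]
    if not candidates:
        raise ValueError("no saved epoch has an FID value")
    return min(candidates, key=lambda pair: pair[1])

# Made-up FID values for epochs 36..47, only to show the call.
example_fids = [90.0, 85.0, 80.0, 78.0, 77.0, 76.0, 75.0, 74.0, 70.0, 72.0, 71.0, 73.0]
epoch, fid = best_saved_epoch(example_fids)
print(epoch, fid, 'monet_generator_' + str(epoch) + '.h5')  # 44 70.0 monet_generator_44.h5
```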


In [None]:
# plt.plot(disc_m_loss1, label='monet_disc_loss1')
# plt.plot(disc_p_loss1, label='photo_disc_loss1')
plt.plot(range(36, 36 + len(fids)), fids, label='FID')  # FID is recorded from epoch 36 on
plt.xlabel('Epoch')

plt.legend()
plt.show()

In [None]:
# !conda install -y gdown 
# import gdown 
# url = 'https://drive.google.com/uc?export=download&id=18UWaVxb_UHDMq4KzJqHqGSPizXXy4H7' 
# output = 'photo.jpg'
# gdown.download(url, output)

In [None]:
# cycle_gan_model.built = True
# cycle_gan_model.load_weights('monet.h5')
# with strategy.scope():
#     monet_generator = tf.keras.models.load_model('monet_generator.h5')
#     cycle_gan_model.save_weights('monet.h5')

In [None]:
_, ax = plt.subplots(5, 3, figsize=(32, 32))
for i, img in enumerate(photo_ds.batch(1).take(5)):
    prediction = monet_generator(img, training=False)
    cycledphoto = photo_generator(prediction, training=False)
    prediction = (prediction * 127.5 + 127.5)[0].numpy().astype(np.uint8)
    cycledphoto = (cycledphoto * 127.5 + 127.5)[0].numpy().astype(np.uint8)

    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 2].imshow(cycledphoto)
    ax[i, 0].set_title("Input Photo")
    ax[i, 1].set_title("Monet-esque")
    ax[i, 2].set_title("Cycled Photo")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
    ax[i, 2].axis("off")

plt.show()

In [None]:
_, ax = plt.subplots(5, 3, figsize=(32, 32))
for i, img in enumerate(monet_ds.batch(1).take(5)):
    prediction = photo_generator(img, training=False)
    cycledmonet = monet_generator(prediction, training=False)
    prediction = (prediction * 127.5 + 127.5)[0].numpy().astype(np.uint8)
    cycledmonet = (cycledmonet * 127.5 + 127.5)[0].numpy().astype(np.uint8)

    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 2].imshow(cycledmonet)
    ax[i, 0].set_title("Input Monet")
    ax[i, 1].set_title("Generated Photo")
    ax[i, 2].set_title("Cycled Monet")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
    ax[i, 2].axis("off")

plt.show()

In [None]:
ds_iter = iter(photo_ds.batch(1))
for n_sample in range(8):
    example_sample = next(ds_iter)
    generated_sample = monet_generator(example_sample, training=False)

    plt.figure(figsize=(32, 32))

    plt.subplot(121)
    plt.title('Input image')
    plt.imshow(example_sample[0] * 0.5 + 0.5)
    plt.axis('off')

    plt.subplot(122)
    plt.title('Generated image')
    plt.imshow(generated_sample[0] * 0.5 + 0.5)
    plt.axis('off')
    plt.show()


In [None]:
ds_iter = iter(monet_ds.batch(1))
for n_sample in range(8):
    example_sample = next(ds_iter)
    generated_sample = photo_generator(example_sample, training=False)

    plt.figure(figsize=(24, 24))

    plt.subplot(121)
    plt.title('Input image')
    plt.imshow(example_sample[0] * 0.5 + 0.5)
    plt.axis('off')

    plt.subplot(122)
    plt.title('Generated image')
    plt.imshow(generated_sample[0] * 0.5 + 0.5)
    plt.axis('off')
    plt.show()

In [None]:
import PIL.Image  # `import PIL` alone does not load the Image submodule used below
! mkdir ../images

In [None]:
%%time
i = 1
for img in fast_photo_ds:
    prediction = monet_generator(img, training=False).numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    for pred in prediction:
        im = PIL.Image.fromarray(pred)
        im.save("../images/" + str(i) + ".jpg")
        i += 1
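The export cell maps generator output from [-1, 1] to 8-bit pixels with `(prediction * 127.5 + 127.5).astype(np.uint8)`, which truncates 127.5 down to 127. A slightly more defensive variant, offered as an optional alternative rather than the notebook's method, clips to [-1, 1] and rounds before casting. Clipping only matters if the generator can output values outside that range; with a final tanh layer it cannot.

```python
import numpy as np

def to_uint8(images):
    """Map generator output in [-1, 1] to 0..255 pixel values."""
    # Clip first: out-of-range floats do not convert to uint8 in a well-defined way.
    scaled = np.clip(images, -1.0, 1.0) * 127.5 + 127.5
    # Round to the nearest integer instead of truncating toward zero.
    return np.round(scaled).astype(np.uint8)

print(to_uint8(np.array([-1.0, 0.0, 1.0])).tolist())  # [0, 128, 255]
```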
    
    


In [None]:
import shutil
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")