# ***CYCLEGAN***

### CycleGAN uses a cycle consistency loss to enable training without the need for paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domains. This opens up many interesting tasks, such as photo enhancement, image colorization, and style transfer. All you need is a source and a target dataset (each simply a directory of images).

### **I drew on the notebooks provided by [Keras](https://keras.io/examples/generative/cyclegan/) and [TensorFlow](https://www.tensorflow.org/tutorials/generative/cyclegan). Many thanks to [Amy Jang](http://www.kaggle.com/amyjang) for her notebook [Monet CycleGAN Tutorial](https://www.kaggle.com/amyjang/monet-cyclegan-tutorial); as a first-time learner of GANs, I found it very helpful.**

In [None]:
# Pin TensorFlow to 2.2.0 for this notebook.
# NOTE(review): confirm the tensorflow_addons build in the environment is compatible with this version.
!pip install tensorflow==2.2.0

In [None]:
# NOTE(review): numba's `jit`/`cuda` are imported in the next cell but never used in this
# notebook, and cudatoolkit is usually installed via conda rather than pip — confirm these
# installs are actually needed.
!pip install  numba  
!pip install cudatoolkit

In [None]:
import io
from numba import jit, cuda
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
from tensorflow import keras

import tqdm

In [None]:
# Use a TPU when one is attached; otherwise fall back to the default
# (single-device) distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate;
    # report why the TPU path failed instead of silently swallowing it.
    print('TPU unavailable, falling back to the default strategy:', err)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

In [None]:
# Training hyperparameters.
LAMBDA = 10  # loss weight constant
BATCH_SIZE = 4
EPOCHS_NUM = 30
image_size = (256, 256)  # (height, width) that decoded images are reshaped to
buffer_size = 256  # NOTE(review): not referenced elsewhere; get_dataset shuffles with 2048

In [None]:
# Root path of the attached Kaggle dataset; the TFRecord glob patterns below are built from it.
DATA_PATH = KaggleDatasets().get_gcs_path()

In [None]:
import re
# Collect the TFRecord shards of each domain under the dataset root.
MONET_FILENAMES = tf.io.gfile.glob(f"{DATA_PATH}/monet_tfrec/*.tfrec")
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(f"{DATA_PATH}/photo_tfrec/*.tfrec")
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

# Item count embedded in a shard name, e.g. "monet00-60.tfrec" -> 60. Compiled once.
_TFREC_COUNT_RE = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total number of records across TFRecord shards.

    Each shard's record count is read from its file name, which is expected to
    contain ``-<count>.`` (e.g. ``monet00-60.tfrec``). Only the base name is
    inspected, so dashes or dots in the directory/bucket part of the path are
    ignored.

    Args:
        filenames: Iterable of shard paths (local or ``gs://``).

    Returns:
        The summed record count as an ``int`` (0 for an empty iterable).

    Raises:
        ValueError: If a file name carries no ``-<count>.`` component.
    """
    total = 0
    for filename in filenames:
        basename = filename.rsplit("/", 1)[-1]
        match = _TFREC_COUNT_RE.search(basename)
        if match is None:
            raise ValueError(
                f"Cannot read an item count from TFRecord filename {filename!r}"
            )
        total += int(match.group(1))
    return total

n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)

# Print a summary of the data and the training configuration.
for summary_line in (
    f'Monet TFRecord files: {len(MONET_FILENAMES)}',
    f'Monet image files: {n_monet_samples}',
    f'Photo TFRecord files: {len(PHOTO_FILENAMES)}',
    f'Photo image files: {n_photo_samples}',
    f"Batch_size: {BATCH_SIZE}",
    f"Epochs number: {EPOCHS_NUM}",
):
    print(summary_line)

In [None]:

def decode_image(image):
    """Decode a JPEG byte string into a float32 image in [-1, 1].

    The decoded pixels are scaled from [0, 255] to [-1, 1], reshaped to
    ``(*image_size, 3)`` and randomly flipped left/right.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    pixels = pixels / 127.5 - 1
    pixels = tf.reshape(pixels, [*image_size, 3])
    return tf.image.random_flip_left_right(pixels)

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image."""
    string_feature = tf.io.FixedLenFeature([], tf.string)
    features = tf.io.parse_single_example(
        example,
        {
            "image_name": string_feature,
            "image": string_feature,
            "target": string_feature,
        },
    )
    return decode_image(features['image'])



In [None]:
def data_augmentation(image):
    """Randomly crop, rotate and flip/transpose a single 256x256x3 image.

    Three independent uniform draws in [0, 1) pick the transforms:
      * crop draw > 0.5: resize to 286x286 and random-crop back to 256x256
        (applied a second time when the draw is > 0.9);
      * rotate draw in (0.5, 0.7] / (0.7, 0.9] / (0.9, 1): rotate by 1 / 2 / 3
        counter-clockwise quarter turns;
      * spatial draw > 0.6: random left/right and up/down flips, plus a
        transpose when the draw is > 0.9.
    """
    rotate_draw = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    spatial_draw = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    crop_draw = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    if crop_draw > 0.5:
        image = tf.image.random_crop(tf.image.resize(image, [286, 286]), size=[256, 256, 3])
        if crop_draw > 0.9:
            image = tf.image.random_crop(tf.image.resize(image, [286, 286]), size=[256, 256, 3])

    # The elif chain already excludes higher ranges, so only lower bounds are needed.
    if rotate_draw > 0.9:
        image = tf.image.rot90(image, k=3)
    elif rotate_draw > 0.7:
        image = tf.image.rot90(image, k=2)
    elif rotate_draw > 0.5:
        image = tf.image.rot90(image, k=1)

    if spatial_draw > 0.6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if spatial_draw > 0.9:
            image = tf.image.transpose(image)

    return image

In [None]:
def load_dataset(filename, aug):
    """Read TFRecord file(s) into a dataset of decoded images, optionally augmented.

    Args:
        filename: A TFRecord path or a list of paths.
        aug: A per-image callable to map over the dataset, ``True`` to apply
            ``data_augmentation``, or a falsy value to skip augmentation.

    Returns:
        A ``tf.data.Dataset`` of single (unbatched) images.

    Raises:
        TypeError: If ``aug`` is truthy but neither ``True`` nor callable.
    """
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    if aug is True:
        # get_dataset() defaults to aug=True; mapping the bool itself would raise.
        aug = data_augmentation
    if aug:
        if not callable(aug):
            raise TypeError(f"aug must be a callable, True, or falsy; got {aug!r}")
        dataset = dataset.map(aug, num_parallel_calls=AUTOTUNE)
    return dataset

In [None]:
def get_dataset(filename, aug=True):
    """Return a shuffled, batched and prefetched image dataset.

    Args:
        filename: A TFRecord path or a list of paths.
        aug: Forwarded unchanged to ``load_dataset``.
    """
    return (
        load_dataset(filename, aug)
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

In [None]:
# Augmented, shuffled and batched training data for each domain.
monet_image, photos = [
    get_dataset(filenames, aug=data_augmentation)
    for filenames in (MONET_FILENAMES, PHOTO_FILENAMES)
]

In [None]:
example_monet = iter(monet_image)
example_photo = iter(photos)

_ = plt.figure(figsize=(15, 30))

# Ten rows of (Monet, photo) pairs. Each cell shows the first image of a fresh
# batch, mapped from [-1, 1] back to uint8 pixels.
panel_sources = ((example_monet, "Monet image"), (example_photo, "Photo image"))
for i in range(10):
    for column, (batches, title) in enumerate(panel_sources, start=1):
        plt.subplot(10, 2, i * 2 + column)
        plt.imshow(tf.cast(next(batches)[0] * 127.5 + 127.5, tf.uint8))
        plt.title(title)



In [None]:
# Architecture-wide constants.
OUTPUT_CHANNELS = 3
# N(0, 0.02) initializer for convolution kernels.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# N(0, 0.02) initializer for the instance-normalization gamma.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
input_img_size = (256, 256, 3)  # (height, width, channels)

In [None]:
class ReflectionPadding2D(layers.Layer):
    """Pad the two spatial axes of a 4-D tensor with reflected values.

    Args:
        padding: ``(width_pad, height_pad)``; each amount is added on both sides
            of the corresponding axis.
    """

    def __init__(self, padding=(1,1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        # Assumes NHWC input: batch and channel axes are left unpadded.
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode='REFLECT')

    def get_config(self):
        """Return constructor arguments so Keras can re-create the layer when (de)serializing."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({"padding": tuple(self.padding)})
        return config
    
    
    
    
    

def residual_block(x, activation, kernel_initializer=kernel_init,
                   kernel_size=(3, 3), strides=(1, 1), padding="valid",
                   gamma_initializer=gamma_init, use_bias=False):
    """Apply a two-convolution residual block and add the input back.

    Each of the two stages is: 1-pixel reflection padding -> Conv2D with the
    input's channel count -> instance normalization. ``activation`` is applied
    only after the first stage; the block output is ``x + stages(x)``.
    """
    shortcut = x
    channels = x.shape[-1]
    out = shortcut
    for stage in range(2):
        out = ReflectionPadding2D()(out)
        out = layers.Conv2D(channels, kernel_size, strides=strides,
                            kernel_initializer=kernel_initializer,
                            padding=padding, use_bias=use_bias)(out)
        out = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(out)
        if stage == 0:
            out = activation(out)
    return layers.add([shortcut, out])


In [None]:
def downsample(x, filters, activation, kernel_initializer=kernel_init, kernel_size=(3, 3),
               strides=(2, 2), padding='same', gamma_initializer=gamma_init, use_bias=False,
               apply_instancenorm=True):
    """Build (but do not apply) a strided-convolution downsampling block.

    Layers: Conv2D -> [InstanceNormalization] -> BatchNormalization -> activation.

    Args:
        x: Unused; kept only so existing callers keep working. The block is
            returned un-applied, so call it on a tensor yourself, and pass the
            channel count via ``filters`` (not as the first argument).
        filters: Number of convolution filters.
        activation: Activation layer (or anything ``layers.Activation`` accepts)
            appended at the end; ``None`` falls back to ``layers.LeakyReLU()``.
        apply_instancenorm: Whether to add InstanceNormalization after the conv.

    Returns:
        A ``keras.Sequential`` block.
    """
    block = keras.Sequential()
    block.add(layers.Conv2D(filters, kernel_size, strides=strides,
                            kernel_initializer=kernel_initializer,
                            padding=padding, use_bias=use_bias))
    if apply_instancenorm:
        block.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer))
    # NOTE(review): BatchNormalization is always added, even when
    # apply_instancenorm=False — confirm this stacking is intended.
    block.add(tf.keras.layers.BatchNormalization())
    # Honour the caller's activation instead of a hard-coded LeakyReLU().
    if activation is None:
        block.add(layers.LeakyReLU())
    elif isinstance(activation, layers.Layer):
        block.add(activation)
    else:
        block.add(layers.Activation(activation))
    return block

In [None]:
def upsample(x, filters, activation, kernel_initializer=kernel_init, kernel_size=(3, 3),
             strides=(2, 2), padding="same", gamma_initializer=gamma_init, use_bias=False,
             apply_dropout=True):
    """Build (but do not apply) a transposed-convolution upsampling block.

    Layers: Conv2DTranspose -> InstanceNormalization -> BatchNormalization
    -> [Dropout(0.5)] -> activation.

    Args:
        x: Unused; kept only so existing callers keep working. The block is
            returned un-applied, so call it on a tensor yourself, and pass the
            channel count via ``filters`` (not as the first argument).
        filters: Number of transposed-convolution filters.
        activation: Activation layer (or anything ``layers.Activation`` accepts)
            appended at the end; ``None`` falls back to ``layers.ReLU()``.
        gamma_initializer: Gamma initializer for the instance normalization.
        apply_dropout: Whether to add ``Dropout(0.5)`` before the activation.

    Returns:
        A ``keras.Sequential`` block.
    """
    block = keras.Sequential()
    block.add(layers.Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding,
                                     kernel_initializer=kernel_initializer, use_bias=use_bias))
    # Use the caller-supplied initializer rather than the global `gamma_init`.
    block.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer))
    block.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    if activation is None:
        block.add(layers.ReLU())
    elif isinstance(activation, layers.Layer):
        block.add(activation)
    else:
        block.add(layers.Activation(activation))
    return block
    
    

def Generator(filters=64, num_downsampling_blocks=2, num_residual_blocks=9, num_upsample_blocks=2,
              gamma_initializer=gamma_init, name=None):
    """Build a ResNet-style CycleGAN generator (encoder, residual blocks, decoder).

    NOTE: a later cell redefines ``Generator`` (a U-Net), so this version is
    shadowed for any call made after that cell runs.

    Args:
        filters: Channels of the initial 7x7 convolution; doubled by each
            downsampling block and halved by each upsampling block.
        num_downsampling_blocks: Stride-2 downsampling blocks (must equal
            ``num_upsample_blocks`` so the output is 256x256 again).
        num_residual_blocks: Residual blocks at the bottleneck.
        num_upsample_blocks: Stride-2 upsampling blocks.
        gamma_initializer: Gamma initializer for the stem's instance norm.
        name: Optional model / input-layer name.

    Returns:
        A ``keras.Model`` mapping (bs, 256, 256, 3) -> (bs, 256, 256, 3) in [-1, 1].

    Raises:
        ValueError: If the downsampling and upsampling block counts differ.
    """
    if num_downsampling_blocks != num_upsample_blocks:
        raise ValueError(
            "num_downsampling_blocks and num_upsample_blocks must match so the "
            f"output resolution equals the input; got {num_downsampling_blocks} "
            f"and {num_upsample_blocks}"
        )

    image_input = layers.Input(shape=[256, 256, 3], name=name)

    x = ReflectionPadding2D(padding=(3, 3))(image_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # downsample()/upsample() return un-applied Sequential blocks and ignore their
    # first argument, so pass None for it and call the returned block on x.
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(None, filters=filters, activation=layers.Activation("relu"))(x)

    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(None, filters, activation=layers.Activation("relu"))(x)

    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    return keras.models.Model(inputs=image_input, outputs=x, name=name)
    
    

In [None]:
def Generator(kernel_initializer=kernel_init, num_residual_blocks=9):
    """Build the U-Net generator used for both translation directions.

    Eight stride-2 downsampling blocks take the 256x256x3 input to 1x1. Seven
    upsampling blocks climb back to 128x128, each concatenating the matching
    encoder activation (skip connection). A reflection-padded 7x7 convolution
    and a stride-2 ``tanh`` transposed convolution then produce 256x256x3.

    Args:
        kernel_initializer: Initializer for the down/up block kernels and the
            final transposed convolution.
        num_residual_blocks: Unused; kept so existing callers keep working.

    Returns:
        A ``tf.keras.Model`` mapping (bs, 256, 256, 3) -> (bs, 256, 256, 3).
    """
    inputs = layers.Input(shape=[256, 256, 3])

    # downsample()/upsample() take (x, filters, ...) and ignore `x`. The previous
    # positional calls such as downsample(64, 4, ...) therefore built 4-filter 3x3
    # convolutions. Pass filters and kernel_size by keyword so the intended sizes apply.
    def down(filters, apply_instancenorm=True):
        return downsample(None, filters=filters, activation=layers.LeakyReLU(0.2),
                          kernel_initializer=kernel_initializer, kernel_size=4,
                          apply_instancenorm=apply_instancenorm)

    def up(filters):
        # NOTE(review): every block keeps upsample()'s default apply_dropout=True,
        # matching the previous calls — confirm later blocks should not skip dropout.
        return upsample(None, filters=filters, activation=layers.Activation("relu"),
                        kernel_initializer=kernel_initializer, kernel_size=4)

    # bs = batch size
    down_stack = [
        down(64, apply_instancenorm=False),  # (bs, 128, 128, 64)
        down(128),   # (bs, 64, 64, 128)
        down(256),   # (bs, 32, 32, 256)
        down(512),   # (bs, 16, 16, 512)
        down(512),   # (bs, 8, 8, 512)
        down(512),   # (bs, 4, 4, 512)
        down(512),   # (bs, 2, 2, 512)
        down(1024),  # (bs, 1, 1, 1024)
    ]
    # NOTE(review): at the 1x1 bottleneck, instance normalization normalizes over a
    # single spatial position.

    # One block per skip connection (the deepest activation has no skip partner).
    # The old eighth 64-filter block was never reached by zip() below.
    up_stack = [
        up(1024),  # (bs, 2, 2, 1024) -> concat -> 1536 channels
        up(512),   # (bs, 4, 4, 512)
        up(512),   # (bs, 8, 8, 512)
        up(512),   # (bs, 16, 16, 512)
        up(512),   # (bs, 32, 32, 512)
        up(256),   # (bs, 64, 64, 256)
        up(128),   # (bs, 128, 128, 128)
    ]

    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=kernel_initializer,
                                  activation='tanh')  # (bs, 256, 256, 3)

    x = inputs

    # Encoder: remember every activation for the skip connections.
    skips = []
    for block in down_stack:
        x = block(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Decoder: upsample and concatenate the matching encoder activation.
    for block, skip in zip(up_stack, skips):
        x = block(x)
        x = layers.Concatenate()([x, skip])

    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(128, (7, 7), padding="valid")(x)  # (bs, 128, 128, 128)
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
with strategy.scope():
    # Create both generators' variables under the distribution strategy.
    monet_generator, photo_generator = Generator(), Generator()

In [None]:
# Use tf.keras's plot_model. Every model here is a tf.keras model, and the
# standalone `keras` package may be a different version or missing entirely.
from tensorflow.keras.utils import plot_model
plot_model(monet_generator, to_file='Generator.png', show_shapes=True, show_layer_names=True)

In [None]:
def discriminator(kernel_initializer=kernel_init, gamma_initializer=gamma_init):
    """Build a PatchGAN discriminator returning a grid of real/fake logits.

    Args:
        kernel_initializer: Initializer for every convolution kernel.
        gamma_initializer: Gamma initializer for the instance normalizations.

    Returns:
        A ``tf.keras.Model`` mapping (bs, 256, 256, 3) -> (bs, 30, 30, 1) logits.
    """
    input_image = layers.Input(shape=[256, 256, 3], name='input_image')

    # downsample() ignores its first argument, so pass filters/kernel_size by
    # keyword. Positional downsample(64, 4, ...) built 4-filter 3x3 convolutions.
    def down(filters, apply_instancenorm=True):
        return downsample(None, filters=filters, activation=layers.LeakyReLU(0.2),
                          kernel_initializer=kernel_initializer, kernel_size=4,
                          gamma_initializer=gamma_initializer,
                          apply_instancenorm=apply_instancenorm)

    x = down(64, apply_instancenorm=True)(input_image)  # (bs, 128, 128, 64)
    x = down(128)(x)  # (bs, 64, 64, 128)
    x = down(256)(x)  # (bs, 32, 32, 256)

    x = layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
    x = layers.Conv2D(512, (4, 4), strides=(1, 1),
                      kernel_initializer=kernel_initializer,
                      use_bias=False)(x)  # (bs, 31, 31, 512)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)

    # The input is already zero-padded, so use 'valid'. The former 'same' padded a
    # second time and produced a 33x33 grid instead of 30x30.
    last = layers.Conv2D(1, 4, strides=(1, 1), padding="valid",
                         kernel_initializer=kernel_initializer)(x)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=input_image, outputs=last)

In [None]:
with strategy.scope():
    # Create both discriminators' variables under the distribution strategy.
    monet_discriminator, photo_discriminator = discriminator(), discriminator()

In [None]:
plot_model(photo_discriminator, to_file='Discriminator.png', show_shapes=True, show_layer_names=True)

# `example_photo` is an iterator over batches (created with iter(photos)), so it can
# be neither fed to the model nor indexed directly. Draw one concrete batch first.
sample_photo = next(example_photo)
to_monet = monet_generator(sample_photo, training=False)

plt.subplot(1, 2, 1)
plt.title("Original Photo")
plt.imshow(sample_photo[0] * 0.5 + 0.5)

plt.subplot(1, 2, 2)
plt.title("Monet-esque Photo")
plt.imshow(to_monet[0] * 0.5 + 0.5)
plt.show()

In [None]:
# Batches are already scaled to [-1, 1] (the range the generator sees in training),
# and the generator's tanh output is in [-1, 1]: rescale to [0, 1] for imshow.
random_img = monet_generator(next(example_monet), training=False)
random_img = random_img[0]
plt.imshow(random_img * 0.5 + 0.5)
plt.axis('off')



In [None]:
# Batches are already scaled to [-1, 1] (the range the generator sees in training),
# and the generator's tanh output is in [-1, 1]: rescale to [0, 1] for imshow.
random_img = photo_generator(next(example_photo), training=False)
random_img = random_img[0]
plt.imshow(random_img * 0.5 + 0.5)
plt.axis('off')

In [None]:
# Patch logits for the first image of one batch from each domain.
monet_example = monet_discriminator(next(example_monet))[0]
photo_example = photo_discriminator(next(example_photo))[0]

_ = plt.figure(figsize=(10, 10))
heatmaps = ((monet_example, "Monet image"), (photo_example, "Photo image"))
for position, (patch_logits, title) in enumerate(heatmaps, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(tf.squeeze(patch_logits, 2))
    plt.title(title)

In [None]:





class CycleGAN(keras.Model):
    """Keras model that trains two generators and two discriminators jointly.

    ``monet_generator`` translates photos into Monet-style images (this is how the
    rest of the notebook calls it), and ``photo_generator`` translates Monet images
    into photos. ``monet_discriminator`` / ``photo_discriminator`` score real vs.
    generated images of their domain. ``fit`` is expected to receive
    ``(monet_batch, photo_batch)`` pairs.
    """

    def __init__(self, monet_generator, photo_generator, monet_discriminator, photo_discriminator,
                 lambda_cycle=10.0, lambda_identity=0.5):
        """Store the four networks and the loss weights.

        Args:
            lambda_cycle: Weight passed to each cycle-consistency loss term.
            lambda_identity: Factor applied on top of ``lambda_cycle`` for the
                identity loss terms.
        """
        super(CycleGAN, self).__init__()
        self.monet_g = monet_generator
        self.photo_g = photo_generator

        self.monet_d = monet_discriminator
        self.photo_d = photo_discriminator

        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(self, monet_g_optimizer, photo_g_optimizer, monet_d_optimizer,
                photo_d_optimizer, gen_loss_fn, dis_loss_fn, cycle_loss_fn, identity_loss_fn):
        """Attach one optimizer per network plus the four loss callables.

        ``cycle_loss_fn`` and ``identity_loss_fn`` are called as
        ``fn(real, reconstructed, weight)``.
        """
        super(CycleGAN, self).compile()

        self.monet_g_optimizer = monet_g_optimizer
        self.photo_g_optimizer = photo_g_optimizer

        self.monet_d_optimizer = monet_d_optimizer
        self.photo_d_optimizer = photo_d_optimizer

        self.gen_loss_fn = gen_loss_fn
        self.dis_loss_fn = dis_loss_fn

        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    @tf.autograph.experimental.do_not_convert
    def train_step(self, batch_data):
        """Run one optimisation step for all four networks and return their losses."""
        real_monet, real_photo = batch_data

        with tf.GradientTape(persistent=True) as tape:
            # photo -> Monet -> photo
            fake_monet = self.monet_g(real_photo, training=True)
            cycled_photo = self.photo_g(fake_monet, training=True)

            # Monet -> photo -> Monet
            fake_photo = self.photo_g(real_monet, training=True)
            cycled_monet = self.monet_g(fake_photo, training=True)

            # Identity mapping: a generator fed an image already in its output
            # domain should leave it unchanged.
            same_monet = self.monet_g(real_monet, training=True)
            same_photo = self.photo_g(real_photo, training=True)

            dis_real_monet = self.monet_d(real_monet, training=True)
            dis_fake_monet = self.monet_d(fake_monet, training=True)

            dis_real_photo = self.photo_d(real_photo, training=True)
            dis_fake_photo = self.photo_d(fake_photo, training=True)

            # Adversarial terms: each generator is judged by the discriminator of
            # the domain it produces.
            gen_monet_loss = self.gen_loss_fn(dis_fake_monet)
            gen_photo_loss = self.gen_loss_fn(dis_fake_photo)

            # Both cycles depend on both generators, so each generator gets the
            # full cycle loss. The weight is applied once, through the helper's
            # weight argument, instead of a global LAMBDA times lambda_cycle.
            total_cycle_loss = (
                self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle)
                + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            )

            identity_weight = self.lambda_cycle * self.lambda_identity
            id_loss_monet = self.identity_loss_fn(real_monet, same_monet, identity_weight)
            id_loss_photo = self.identity_loss_fn(real_photo, same_photo, identity_weight)

            # Each identity term is charged to the generator that produced it.
            total_loss_monet = gen_monet_loss + total_cycle_loss + id_loss_monet
            total_loss_photo = gen_photo_loss + total_cycle_loss + id_loss_photo

            dis_monet_loss = self.dis_loss_fn(dis_real_monet, dis_fake_monet)
            dis_photo_loss = self.dis_loss_fn(dis_real_photo, dis_fake_photo)

        grad_monet_g = tape.gradient(total_loss_monet, self.monet_g.trainable_variables)
        grad_photo_g = tape.gradient(total_loss_photo, self.photo_g.trainable_variables)

        grad_monet_d = tape.gradient(dis_monet_loss, self.monet_d.trainable_variables)
        grad_photo_d = tape.gradient(dis_photo_loss, self.photo_d.trainable_variables)
        # A persistent tape holds its resources until released.
        del tape

        self.monet_g_optimizer.apply_gradients(zip(grad_monet_g, self.monet_g.trainable_variables))
        self.photo_g_optimizer.apply_gradients(zip(grad_photo_g, self.photo_g.trainable_variables))

        self.monet_d_optimizer.apply_gradients(zip(grad_monet_d, self.monet_d.trainable_variables))
        self.photo_d_optimizer.apply_gradients(zip(grad_photo_d, self.photo_d.trainable_variables))

        return {
            "Monet_generator_loss": total_loss_monet,
            "Photo_generator_loss": total_loss_photo,
            "Monet_discriminator_loss": dis_monet_loss,
            "Photo_discriminator_loss": dis_photo_loss
        }
    
    
   
            
            
        
        
        

In [None]:
# Loss weight constant (same value as assigned in the hyperparameter cell).
LAMBDA=10

Cycle consistency means the result should be close to the original input. For example, if one translates a sentence from English to French and then translates it back from French to English, the resulting sentence should be the same as the original sentence.

In the cycle consistency loss:

* Image $X$ is passed via generator $G$ that yields generated image $\hat{Y}$.
* Generated image $\hat{Y}$ is passed via generator $F$ that yields cycled image $\hat{X}$.
* Mean absolute error is calculated between $X$ and $\hat{X}$.

$$\text{forward cycle consistency loss: } X \rightarrow G(X) \rightarrow F(G(X)) = \hat{X} \approx X$$

$$\text{backward cycle consistency loss: } Y \rightarrow F(Y) \rightarrow G(F(Y)) = \hat{Y} \approx Y$$


![Cycle loss](images/cycle_loss.png)

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Return the PatchGAN discriminator loss for one domain.

        Real logits are pushed towards label 1 and generated logits towards
        label 0. The two binary cross-entropy terms are summed and halved.

        Args:
            real: Discriminator logits for real images.
            generated: Discriminator logits for generated images.

        Returns:
            The unreduced loss tensor (``Reduction.NONE``), averaged over the last
            axis only.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction=tf.keras.losses.Reduction.NONE)
        real_loss = bce(tf.ones_like(real), real)
        # Generated images must be labelled fake (0). Labelling them 1 trains the
        # discriminator to accept fakes.
        generated_loss = bce(tf.zeros_like(generated), generated)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.5

In [None]:


with strategy.scope():
    def generator_loss(generated):
        """Adversarial loss: how far the discriminator is from calling fakes real (label 1)."""
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction=tf.keras.losses.Reduction.NONE)
        target_labels = tf.ones_like(generated)
        return bce(target_labels, generated)



In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute error between an image and its reconstruction."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

In [None]:


with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute error between an image and its identity mapping."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_error



In [None]:
with strategy.scope():
    # One independent Adam optimizer (lr=2e-4, beta_1=0.5) per network.
    (monet_g_optimizer, photo_g_optimizer,
     monet_d_optimizer, photo_d_optimizer) = [
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)
    ]

In [None]:


with strategy.scope():
    cycle_gan_model = CycleGAN(monet_generator, photo_generator,
                               monet_discriminator, photo_discriminator)

    # Optimizers and loss callables handed to CycleGAN.compile().
    compile_kwargs = dict(
        monet_g_optimizer=monet_g_optimizer,
        photo_g_optimizer=photo_g_optimizer,
        monet_d_optimizer=monet_d_optimizer,
        photo_d_optimizer=photo_d_optimizer,
        gen_loss_fn=generator_loss,
        dis_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )
    cycle_gan_model.compile(**compile_kwargs)



In [None]:
class GANMonitor(keras.callbacks.Callback):
    """A callback to generate and save images after each epoch.

    At the end of every epoch it runs the model's ``monet_g`` generator on
    ``num_img`` batches. It plots the first image of each batch next to its
    translation and saves each translation as ``generated_img_{i}_{epoch}.png``.

    Args:
        num_img: Number of batches to visualise (one figure row each).
        dataset: Batched dataset of inputs in [-1, 1]. Defaults to the global
            ``photos`` dataset, since the Monet generator is applied to photos
            elsewhere in this notebook.
    """

    def __init__(self, num_img=4, dataset=None):
        super().__init__()  # initialise keras.callbacks.Callback state
        self.num_img = num_img
        self.dataset = dataset

    def on_epoch_end(self, epoch, logs=None):
        source = photos if self.dataset is None else self.dataset
        # One row per image; squeeze=False keeps `ax` 2-D even when num_img == 1.
        _, ax = plt.subplots(self.num_img, 2, figsize=(12, 3 * self.num_img), squeeze=False)
        for i, img in enumerate(source.take(self.num_img)):
            # CycleGAN stores the generator passed as `monet_generator` in `self.monet_g`.
            prediction = self.model.monet_g(img, training=False)[0].numpy()
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = keras.preprocessing.image.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()


In [None]:
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
# CycleGAN is a subclassed keras.Model with a custom compile() signature. Saving
# weights only avoids serialising the whole model object.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
)

In [None]:
# Train for the configured EPOCHS_NUM (the value printed in the data summary)
# rather than a separate hard-coded count. Dataset.zip stops when the shorter of
# the two datasets is exhausted.
cycle_gan_model.fit(
    tf.data.Dataset.zip((monet_image, photos)),
    epochs=EPOCHS_NUM,
    verbose=1,
)

In [None]:
# Show five photos next to their Monet-style translations (first image of each batch).
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for i, img in enumerate(photos.take(5)):
    translated = monet_generator(img, training=False)[0].numpy()
    translated = (translated * 127.5 + 127.5).astype(np.uint8)
    original = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    for column, (picture, title) in enumerate(((original, "Input Photo"), (translated, "Monet-esque"))):
        ax[i, column].imshow(picture)
        ax[i, column].set_title(title)
        ax[i, column].axis("off")
plt.show()

In [None]:
import PIL
# Create the directory that generated images are written to (as ../images/<n>.jpg).
! mkdir ../images

In [None]:

import os

from PIL import Image  # `import PIL` alone does not guarantee `PIL.Image` is loaded

# Translate every photo (not just the first of each batch) and save it as
# ../images/<n>.jpg, numbering from 1. Use a non-augmented copy of the photo
# dataset so the submission is not randomly cropped or rotated.
output_dir = "../images"
os.makedirs(output_dir, exist_ok=True)

i = 1
for img in get_dataset(PHOTO_FILENAMES, aug=False):
    predictions = monet_generator(img, training=False).numpy()
    predictions = (predictions * 127.5 + 127.5).astype(np.uint8)
    for prediction in predictions:
        Image.fromarray(prediction).save(f"{output_dir}/{i}.jpg")
        i += 1


In [None]:
import shutil
# Zip the contents of /kaggle/images into /kaggle/working/images.zip.
# NOTE(review): assumes the notebook runs from /kaggle/working, so the earlier
# "../images" output directory is /kaggle/images — confirm.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")