<center><img src='https://raw.githubusercontent.com/dimitreOliveira/MachineLearning/master/Kaggle/Cassava%20Leaf%20Disease%20Classification/banner.png' height=350></center>
<p>
<h1><center> Cassava Leaf Disease - CycleGAN data augmentation </center></h1>


TODO: Explain the goal of the notebook and give some description


#### This work is based on my previous work [Improving CycleGAN - Monet paintings](https://www.kaggle.com/dimitreoliveira/improving-cyclegan-monet-paintings) from the [I’m Something of a Painter Myself](https://www.kaggle.com/c/gan-getting-started) competition.
- Dataset source [here](https://www.kaggle.com/dimitreoliveira/cassava-leaf-disease-tfrecords-classes-512x512)
- Dataset source [discussion thread](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/198744)

## Dependencies

In [None]:
import os, random, json, PIL, shutil, re, imageio, glob
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import ImageDraw
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
from tensorflow.keras import Model, losses, optimizers, applications
from tensorflow.keras.callbacks import Callback


def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and set the hash-seed / deterministic-ops env vars.

    Args:
        seed: Integer seed passed to every random number generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)


SEED = 0
seed_everything(SEED)

## Hardware configuration

In [None]:
# Try to locate a TPU; a ValueError from the resolver is treated as "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if not tpu:
    # No TPU: use whatever distribution strategy is currently active.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)


AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# Model parameters

In [None]:
# Image size used to pick the source TFRecord dataset (it appears in the dataset name)
HEIGHT_DS, WIDTH_DS = 512, 512
# Size the decoded images are resized to
HEIGHT, WIDTH = 256, 256
# Size produced by the augmentation step and fed to the models
HEIGHT_RESIZE, WIDTH_RESIZE = 256, 256
CHANNELS = 3

# Training schedule
BATCH_SIZE = 16
EPOCHS = 30

# Number of residual (transformer) blocks in each generator
TRANSFORMER_BLOCKS = 4

# Adam learning rates
GENERATOR_LR = 2e-4
DISCRIMINATOR_LR = 2e-4

# Load data

In [None]:
# Resolve the GCS location of the Kaggle dataset whose name matches HEIGHT_DS x WIDTH_DS
GCS_PATH = KaggleDatasets().get_gcs_path(f'cassava-leaf-disease-tfrecords-classes-{HEIGHT_DS}x{WIDTH_DS}')

# One list of TFRecord files per class, selected by file-name prefix
CBB_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/CBB*.tfrec')
CBSD_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/CBSD*.tfrec')
CGM_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/CGM*.tfrec')
CMD_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/CMD*.tfrec')
HEALTHY_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/Healthy*.tfrec')


def count_data_items(filenames):
    """Sum the sample counts encoded in TFRecord file names.

    Each name must contain a ``-<digits>.`` fragment; the digits between the
    dash and the dot are read as the number of samples stored in that file.

    Args:
        filenames: Iterable of file names / paths.

    Returns:
        The sum of the parsed counts, as returned by ``np.sum``.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for filename in filenames:
        per_file_counts.append(int(count_pattern.search(filename).group(1)))
    return np.sum(per_file_counts)

# Keep the healthy-class sample count in a variable; it is reused in the summary below
n_healthy_samples = count_data_items(HEALTHY_FILENAMES)


# Per class: number of TFRecord files found and total samples parsed from their file names
print(f'CBB {len(CBB_FILENAMES)} TFRecord files with {count_data_items(CBB_FILENAMES)} total samples')
print(f'CBSD {len(CBSD_FILENAMES)} TFRecord files with {count_data_items(CBSD_FILENAMES)} total samples')
print(f'CGM {len(CGM_FILENAMES)} TFRecord files with {count_data_items(CGM_FILENAMES)} total samples')
print(f'CMD {len(CMD_FILENAMES)} TFRecord files with {count_data_items(CMD_FILENAMES)} total samples')
print(f'Healthy {len(HEALTHY_FILENAMES)} TFRecord files with {n_healthy_samples} total samples')

# Augmentations

Data augmentation for GANs should be done very carefully, especially for tasks similar to style transfer: if we apply transformations that change the style of the data too much (e.g. strong brightness, contrast or saturation shifts), the generator may fail to learn the base style efficiently. For this reason we rely mainly on spatial transformations (flips, transposes and 90° rotations) and keep the pixel-level changes (saturation, contrast and brightness) mild; the crop-based augmentations are currently disabled.

In [None]:
def data_augment(image):
    """Randomly augment a single image and resize it to [HEIGHT_RESIZE, WIDTH_RESIZE].

    Applies random flips, an optional transpose, an optional rotation by a
    multiple of 90 degrees, and optional saturation / contrast / brightness
    jitter. The crop-based augmentations below are disabled (commented out).

    Args:
        image: A single image tensor (in this file it is mapped over the
            dataset before `normalize_img`).

    Returns:
        The augmented image resized to [HEIGHT_RESIZE, WIDTH_RESIZE].
    """
    # One independent uniform draw in [0, 1) per augmentation decision
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
#     p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    
#     # Random jitter
#     image = tf.image.resize(image, [560, 560], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # Transpose with probability 0.25
    if p_spatial > .75:
        image = tf.image.transpose(image)
    # Rotates (each of 90º / 180º / 270º / no rotation is picked with probability 0.25)
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90º
    # Pixel-level transforms (each applied with probability 0.6)
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        # NOTE(review): this runs before normalize_img, so the image is presumably still in the
        # [0, 255] range; max_delta=.1 would then be a nearly invisible shift - confirm intended scale.
        image = tf.image.random_brightness(image, max_delta=.1)
        
#     # Crops
#     if p_crop > .6: # random crop
#         crop_size = tf.random.uniform([], int(HEIGHT*.7), HEIGHT, dtype=tf.int32)
#         image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
#     elif p_crop > .2: # central crop
#         if p_crop > .5:
#             image = tf.image.central_crop(image, central_fraction=.7)
#         elif p_crop > .35:
#             image = tf.image.central_crop(image, central_fraction=.8)
#         else:
#             image = tf.image.central_crop(image, central_fraction=.9)
            
#     # Train on crops
#     image = tf.image.random_crop(image, size=[HEIGHT_RESIZE, WIDTH_RESIZE, CHANNELS])
    
    # Final resize to the model input size
    image = tf.image.resize(image, [HEIGHT_RESIZE, WIDTH_RESIZE])
    
    return image

## Auxiliary functions

In [None]:
def normalize_img(img):
    """Cast the image to float32 and rescale it so that [0, 255] maps onto [-1, 1]."""
    as_float = tf.cast(img, dtype=tf.float32)
    return (as_float / 127.5) - 1.0

def decode_image(image):
    """Decode JPEG bytes into a [HEIGHT, WIDTH, CHANNELS] image tensor.

    Args:
        image: Encoded JPEG contents.

    Returns:
        The decoded image, resized to [HEIGHT, WIDTH] and reshaped to
        [HEIGHT, WIDTH, CHANNELS].
    """
    target_size = [HEIGHT, WIDTH]
    decoded = tf.image.decode_jpeg(image, channels=CHANNELS)
    resized = tf.image.resize(decoded, target_size)
    return tf.reshape(resized, [*target_size, CHANNELS])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return only its decoded image.

    The 'target' and 'image_name' features are parsed but not returned.
    """
    feature_spec = dict(
        image=tf.io.FixedLenFeature([], tf.string),
        target=tf.io.FixedLenFeature([], tf.int64),
        image_name=tf.io.FixedLenFeature([], tf.string),
    )
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames):
    """Return a dataset of decoded images read from the given TFRecord files."""
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTO)

def get_dataset(filenames, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the input pipeline for one set of TFRecord files.

    Pipeline: decode -> cache -> (augment) -> normalize -> (repeat) ->
    (shuffle) -> batch -> prefetch.

    Args:
        filenames: TFRecord files to read.
        augment: Optional per-image function mapped over the dataset before
            normalization.
        repeat: When true, the dataset repeats indefinitely.
        shuffle: When true, elements are shuffled with a buffer of 512.
        batch_size: Number of images per batch.

    Returns:
        A batched, prefetched `tf.data.Dataset` of normalized images.
    """
    dataset = load_dataset(filenames)
    # Cache the decoded images while the dataset is still finite and deterministic.
    # Caching after repeat()/shuffle()/batch() (as before) stored an infinite stream:
    # the cache was never finalized and kept growing while training ran.
    # NOTE: this keeps the decoded images of the files in memory once fully iterated.
    dataset = dataset.cache()

    if augment:
        dataset = dataset.map(augment, num_parallel_calls=AUTO)
    dataset = dataset.map(normalize_img, num_parallel_calls=AUTO)
    if repeat:
        dataset = dataset.repeat()
    if shuffle:
        dataset = dataset.shuffle(512)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO)

    return dataset

def display_samples(ds, n_samples):
    """Show the first image of each of the next `n_samples` elements of `ds`.

    Images are mapped back from [-1, 1] to [0, 1] before plotting.
    """
    batches = iter(ds)
    for _ in range(n_samples):
        batch = next(batches)
        plt.subplot(121)
        plt.imshow(batch[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()
        
def display_generated_samples(ds, model, n_samples):
    """Plot input / generated image pairs for the next `n_samples` elements of `ds`.

    Args:
        ds: Dataset yielding image batches; only the first image of each batch is shown.
        model: Model whose `predict` output is plotted next to the input.
        n_samples: Number of dataset elements to draw.
    """
    batches = iter(ds)
    for _ in range(n_samples):
        source = next(batches)
        generated = model.predict(source)

        plt.figure(figsize=(12, 12))
        panels = ((121, 'Input image', source), (122, 'Generated image', generated))
        for position, title, images in panels:
            plt.subplot(position)
            plt.title(title)
            # Map from [-1, 1] back to [0, 1] for display
            plt.imshow(images[0] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()
        
def evaluate_cycle(ds, generator_a, generator_b, n_samples=1):
    """Plot input, generated and cycled images side by side, one row per sample.

    Args:
        ds: Dataset yielding image batches; only the first image of each batch is shown.
        generator_a: Model applied to the input batch.
        generator_b: Model applied to the output of `generator_a`.
        n_samples: Number of rows (dataset elements) to plot.
    """
    fig, axes = plt.subplots(n_samples, 3, figsize=(22, (n_samples*6)))
    axes = axes.flatten()
    titles = ('Input image', 'Generated image', 'Cycled image')

    batches = iter(ds)
    for row in range(n_samples):
        source = next(batches)
        translated = generator_a.predict(source)
        cycled = generator_b.predict(translated)

        row_axes = axes[row * 3:row * 3 + 3]
        for ax, title, images in zip(row_axes, titles, (source, translated, cycled)):
            ax.set_title(title, fontsize=18)
            # Map from [-1, 1] back to [0, 1] for display
            ax.imshow(images[0] * 0.5 + 0.5)
            ax.axis('off')

    plt.show()

def create_gif(images_path, gif_path):
    """Assemble the images matching `images_path` into a GIF with an epoch label on each frame.

    Args:
        images_path: Glob pattern selecting the frame images.
        gif_path: Destination of the GIF file (written at 2 frames per second).
    """
    def digits_as_int(path):
        # Order frames by the integer formed from every digit found in the path
        return int(''.join(filter(str.isdigit, path)))

    frames = []
    ordered_files = sorted(glob.glob(images_path), key=digits_as_int)
    for epoch, filename in enumerate(ordered_files):
        frame = PIL.ImageDraw.Image.open(filename)
        # Label drawn at the top-left corner; numbering follows the sorted position
        ImageDraw.Draw(frame).text((0, 0), f'Epoch {epoch+1}')
        frames.append(frame)
    imageio.mimsave(gif_path, frames, fps=2)
        
def predict_and_save(input_ds, generator_model, output_path):
    """Run the generator on every element of `input_ds` and save each result as a JPEG.

    Files are named `<output_path><index>.jpg`, with the index starting at 1.
    Only the first image of each batch is saved.
    """
    for i, img in enumerate(input_ds, start=1):
        generated = generator_model(img, training=False)[0].numpy()
        # Map from [-1, 1] back to [0, 255] before converting to uint8
        pixels = (generated * 127.5 + 127.5).astype(np.uint8)
        PIL.Image.fromarray(pixels).save(f'{output_path}{str(i)}.jpg')

## Auxiliary functions (model)

Here we define the building blocks of our models:
- Encoder block: Apply convolutional filters while also reducing data resolution and increasing features.
- Decoder block: Apply convolutional filters while also increasing data resolution and decreasing features.
- Transformer block: Apply convolutional filters to find relevant data patterns and keeps features constant.

In [None]:
# Shared weight initializers (normal distribution, mean 0.0, stddev 0.02):
# one for the convolution kernels and one for the InstanceNormalization gamma.
conv_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02)
gamma_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    
def encoder_block(input_layer, filters, size=3, strides=2, apply_instancenorm=True, activation=L.ReLU(), name='block_x'):
    """Convolution -> optional instance normalization -> activation.

    Args:
        input_layer: Tensor the block is applied to.
        filters: Number of convolution filters.
        size: Convolution kernel size.
        strides: Convolution strides.
        apply_instancenorm: Whether to apply instance normalization after the convolution.
        activation: Activation layer applied last. The default instance is created
            once, when the function is defined, and is shared by every call that
            does not pass its own activation.
        name: Suffix of the convolution layer name (`encoder_<name>`).

    Returns:
        The output tensor of the block.
    """
    conv = L.Conv2D(
        filters,
        size,
        strides=strides,
        padding='same',
        use_bias=False,
        kernel_initializer=conv_initializer,
        name=f'encoder_{name}',
    )
    features = conv(input_layer)

    if apply_instancenorm:
        norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
        features = norm(features)

    return activation(features)

def transformer_block(input_layer, size=3, strides=1, name='block_x'):
    """Residual block: conv -> ReLU -> conv, added back to the input.

    The number of filters equals the channel count of `input_layer`.
    Instance normalization after each convolution is currently disabled.

    Args:
        input_layer: Tensor the block is applied to.
        size: Convolution kernel size.
        strides: Convolution strides.
        name: Part of the convolution layer names (`transformer_<name>_1` / `_2`).

    Returns:
        The sum of the block output and `input_layer`.
    """
    filters = input_layer.shape[-1]
    conv_options = dict(strides=strides, padding='same', use_bias=False,
                        kernel_initializer=conv_initializer)

    residual = L.Conv2D(filters, size, name=f'transformer_{name}_1', **conv_options)(input_layer)
    residual = L.ReLU()(residual)
    residual = L.Conv2D(filters, size, name=f'transformer_{name}_2', **conv_options)(residual)

    return L.Add()([residual, input_layer])

def decoder_block(input_layer, filters, size=3, strides=2, apply_instancenorm=True, name='block_x'):
    """Transposed convolution -> optional instance normalization -> ReLU.

    Args:
        input_layer: Tensor the block is applied to.
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        strides: Strides of the transposed convolution.
        apply_instancenorm: Whether to apply instance normalization.
        name: Suffix of the layer name (`decoder_<name>`).

    Returns:
        The output tensor of the block.
    """
    deconv = L.Conv2DTranspose(
        filters,
        size,
        strides=strides,
        padding='same',
        use_bias=False,
        kernel_initializer=conv_initializer,
        name=f'decoder_{name}',
    )
    features = deconv(input_layer)

    if apply_instancenorm:
        norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
        features = norm(features)

    return L.ReLU()(features)

# Resized convolution
def decoder_rc_block(input_layer, filters, size=3, strides=1, apply_instancenorm=True, name='block_x'):
    """Bilinear 2x upsampling followed by convolution, optional instance norm and ReLU.

    A variant using symmetric `tf.pad` plus a 'valid' convolution was noted in
    the original code as working only on GPU, so 'same' padding is used here.

    Args:
        input_layer: Tensor the block is applied to; its height and width are doubled.
        filters: Number of convolution filters.
        size: Convolution kernel size.
        strides: Convolution strides.
        apply_instancenorm: Whether to apply instance normalization.
        name: Suffix of the convolution layer name (`decoder_<name>`).

    Returns:
        The output tensor of the block.
    """
    doubled_size = (input_layer.shape[1]*2, input_layer.shape[2]*2)
    upsampled = tf.image.resize(images=input_layer, method='bilinear', size=doubled_size)

    conv = L.Conv2D(
        filters,
        size,
        strides=strides,
        padding='same',
        use_bias=False,
        kernel_initializer=conv_initializer,
        name=f'decoder_{name}',
    )
    features = conv(upsampled)

    if apply_instancenorm:
        norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
        features = norm(features)

    return L.ReLU()(features)

# Generator model

The `generator` is responsible for generating images from a specific domain. `CycleGAN` architecture has two generators, in this context for example we can take one `generator` that will take `Healthy` images and generate `CBB` images, and the other `generator` will take `CBB` images and generate `Healthy` images.

Below, we have the architecture of the original `CycleGAN` `generator`; ours has some changes to improve performance on this task.

<center><img src='https://github.com/dimitreOliveira/MachineLearning/blob/master/Kaggle/I%E2%80%99m%20Something%20of%20a%20Painter%20Myself/generator_architecture.png?raw=true' height=250></center>

In [None]:
def generator_fn(height=HEIGHT, width=WIDTH, channels=CHANNELS, transformer_blocks=TRANSFORMER_BLOCKS):
    """Build a generator: encoder, residual blocks, and a decoder with skip connections.

    Args:
        height: Input image height.
        width: Input image width.
        channels: Number of input channels.
        transformer_blocks: Number of residual (transformer) blocks.

    Returns:
        A Keras `Model` mapping an input image to a 3-channel image with
        `tanh` activation on the output.
    """
    OUTPUT_CHANNELS = 3
    inputs = L.Input(shape=[height, width, channels], name='input_image')

    # Encoder (shape comments assume the default 256x256 input)
    enc_1 = encoder_block(inputs, 64, 7, 1,
                          apply_instancenorm=False, activation=L.ReLU(), name='block_1')  # (bs, 256, 256, 64)
    enc_2 = encoder_block(enc_1, 128, 3, 2,
                          apply_instancenorm=True, activation=L.ReLU(), name='block_2')   # (bs, 128, 128, 128)
    enc_3 = encoder_block(enc_2, 256, 3, 2,
                          apply_instancenorm=True, activation=L.ReLU(), name='block_3')   # (bs, 64, 64, 256)

    # Transformer: residual blocks at constant resolution
    bottleneck = enc_3
    for n in range(transformer_blocks):
        bottleneck = transformer_block(bottleneck, 3, 1, name=f'block_{n+1}')  # (bs, 64, 64, 256)

    # Decoder: each stage is concatenated with the encoder output of matching resolution
    merged = L.Concatenate(name='enc_dec_skip_1')([bottleneck, enc_3])

    dec_1 = decoder_block(merged, 128, 3, 2, apply_instancenorm=True, name='block_1')  # (bs, 128, 128, 128)
    merged = L.Concatenate(name='enc_dec_skip_2')([dec_1, enc_2])

    dec_2 = decoder_block(merged, 64, 3, 2, apply_instancenorm=True, name='block_2')   # (bs, 256, 256, 64)
    merged = L.Concatenate(name='enc_dec_skip_3')([dec_2, enc_1])

    output_conv = L.Conv2D(OUTPUT_CHANNELS, 7,
                           strides=1, padding='same',
                           kernel_initializer=conv_initializer,
                           use_bias=False,
                           activation='tanh',
                           name='decoder_output_block')
    outputs = output_conv(merged)  # (bs, 256, 256, 3)

    return Model(inputs, outputs)

sample_generator = generator_fn()
sample_generator.summary()

# Discriminator model


The `discriminator` is responsible for differentiating real images from images that have been generated by a `generator` model.

Below, we have the architecture of the original `CycleGAN` `discriminator`; again, ours has some changes to improve performance on this task.

<center><img src='https://github.com/dimitreOliveira/MachineLearning/blob/master/Kaggle/I%E2%80%99m%20Something%20of%20a%20Painter%20Myself/discriminator_architecture.png?raw=true' height=550 width=550></center>

In [None]:
def discriminator_fn(height=HEIGHT, width=WIDTH, channels=CHANNELS):
    """Build a discriminator: ImageNet-pretrained MobileNetV2 features plus a 1-filter convolution.

    The hand-built four-stage `encoder_block` stack (64/128/256/512 filters,
    LeakyReLU(0.2)) and the GPU-only random 70x70 input crop that earlier
    versions used are disabled in favour of the pre-trained backbone.

    Args:
        height: Input image height.
        width: Input image width.
        channels: Number of input channels.

    Returns:
        A Keras `Model` mapping an image to a map of un-activated scores
        (the losses in this file treat them as logits).
    """
    inputs = L.Input(shape=[height, width, channels], name='input_image')

    # Using pre-trained model
    backbone = applications.MobileNetV2(weights='imagenet', include_top=False)
    features = backbone(inputs)

    # NOTE(review): the previous "(bs, 29, 29, 1)" shape comment described the hand-built
    # encoder. With the MobileNetV2 backbone the output is presumably (bs, 5, 5, 1) for
    # 256x256 inputs - confirm against the summary printed below.
    score_conv = L.Conv2D(1, 4, strides=1, padding='valid', kernel_initializer=conv_initializer)
    outputs = score_conv(features)

    return Model(inputs, outputs)


sample_discriminator = discriminator_fn()
sample_discriminator.summary()

# Build model (CycleGAN)

In [None]:
class CycleGan(Model):
    """CycleGAN wrapper training two generators and two discriminators in a custom `train_step`.

    `f_gen` is fed second-domain images and its output is scored by `f_disc`
    (the first-domain discriminator); `s_gen` is fed first-domain images and
    its output is scored by `s_disc`.
    """

    def __init__(
        self,
        first_domain_generator,
        second_domain_generator,
        first_domain_discriminator,
        second_domain_discriminator,
        first_domain_name='first_domain',
        second_domain_name='second_domain',
        lambda_cycle=10,
    ):
        """Store the four sub-models, the domain names used in the metric keys, and `lambda_cycle`.

        `lambda_cycle` is forwarded as the third argument of the cycle and
        identity loss functions.
        """
        super(CycleGan, self).__init__()
        self.f_gen = first_domain_generator
        self.s_gen = second_domain_generator
        self.f_disc = first_domain_discriminator
        self.s_disc = second_domain_discriminator
        self.first_domain_name = first_domain_name
        self.second_domain_name = second_domain_name
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        f_gen_optimizer,
        s_gen_optimizer,
        f_disc_optimizer,
        s_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Register one optimizer per sub-model and the four loss functions.

        Call signatures used by `train_step`:
            gen_loss_fn(disc_output_on_fake)
            disc_loss_fn(disc_output_on_real, disc_output_on_fake)
            cycle_loss_fn(real_image, cycled_image, lambda_cycle)
            identity_loss_fn(real_image, same_image, lambda_cycle)
        """
        super(CycleGan, self).compile()
        self.f_gen_optimizer = f_gen_optimizer
        self.s_gen_optimizer = s_gen_optimizer
        self.f_disc_optimizer = f_disc_optimizer
        self.s_disc_optimizer = s_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    def train_step(self, batch_data):
        """Run one optimization step on a `(real_first_domain, real_second_domain)` batch pair.

        Returns:
            A dict with the total generator loss and the discriminator loss of
            each domain, keyed by `<domain_name>_gen_loss` / `<domain_name>_disc_loss`.
        """
        real_first_domain, real_second_domain = batch_data
        
        # Persistent tape: gradients are taken four times below (two generators, two discriminators)
        with tf.GradientTape(persistent=True) as tape:
            # second_domain to first_domain back to second_domain
            fake_first_domain = self.f_gen(real_second_domain, training=True)
            cycled_second_domain = self.s_gen(fake_first_domain, training=True)

            # first_domain to second_domain back to first_domain
            fake_second_domain = self.s_gen(real_first_domain, training=True)
            cycled_first_domain = self.f_gen(fake_second_domain, training=True)

            # generating itself
            same_first_domain = self.f_gen(real_first_domain, training=True)
            same_second_domain = self.s_gen(real_second_domain, training=True)

            # discriminator used to check, inputing real images
            disc_real_first_domain = self.f_disc(real_first_domain, training=True)
            disc_real_second_domain = self.s_disc(real_second_domain, training=True)

            # discriminator used to check, inputing fake images
            disc_fake_first_domain = self.f_disc(fake_first_domain, training=True)
            disc_fake_second_domain = self.s_disc(fake_second_domain, training=True)

            # evaluates generator loss
            first_domain_gen_loss = self.gen_loss_fn(disc_fake_first_domain)
            second_domain_gen_loss = self.gen_loss_fn(disc_fake_second_domain)

            # evaluates total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_first_domain, cycled_first_domain, self.lambda_cycle) + self.cycle_loss_fn(real_second_domain, cycled_second_domain, self.lambda_cycle)

            # evaluates total generator loss
            # (the same total cycle loss is added to both generators' losses)
            total_first_domain_gen_loss = first_domain_gen_loss + total_cycle_loss + self.identity_loss_fn(real_first_domain, same_first_domain, self.lambda_cycle)
            total_second_domain_gen_loss = second_domain_gen_loss + total_cycle_loss + self.identity_loss_fn(real_second_domain, same_second_domain, self.lambda_cycle)

            # evaluates discriminator loss
            first_domain_disc_loss = self.disc_loss_fn(disc_real_first_domain, disc_fake_first_domain)
            second_domain_disc_loss = self.disc_loss_fn(disc_real_second_domain, disc_fake_second_domain)

        # Calculate the gradients for generator and discriminator
        first_domain_generator_gradients = tape.gradient(total_first_domain_gen_loss,
                                                  self.f_gen.trainable_variables)
        second_domain_generator_gradients = tape.gradient(total_second_domain_gen_loss,
                                                  self.s_gen.trainable_variables)

        first_domain_discriminator_gradients = tape.gradient(first_domain_disc_loss,
                                                      self.f_disc.trainable_variables)
        second_domain_discriminator_gradients = tape.gradient(second_domain_disc_loss,
                                                      self.s_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.f_gen_optimizer.apply_gradients(zip(first_domain_generator_gradients,
                                                 self.f_gen.trainable_variables))

        self.s_gen_optimizer.apply_gradients(zip(second_domain_generator_gradients,
                                                 self.s_gen.trainable_variables))

        self.f_disc_optimizer.apply_gradients(zip(first_domain_discriminator_gradients,
                                                  self.f_disc.trainable_variables))

        self.s_disc_optimizer.apply_gradients(zip(second_domain_discriminator_gradients,
                                                  self.s_disc.trainable_variables))
        
        # Losses reported to Keras, keyed by domain name
        return {f'{self.first_domain_name}_gen_loss': total_first_domain_gen_loss,
                f'{self.second_domain_name}_gen_loss': total_second_domain_gen_loss,
                f'{self.first_domain_name}_disc_loss': first_domain_disc_loss,
                f'{self.second_domain_name}_disc_loss': second_domain_disc_loss
               }

# Loss functions

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Discriminator loss with targets {0: fake, 1: real}.

        Returns the un-reduced binary cross-entropy (from logits) on the real
        outputs plus the one on the generated outputs, multiplied by 0.5.
        """
        bce = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)
        return (loss_on_real + loss_on_generated) * 0.5

    def generator_loss(generated):
        """Un-reduced binary cross-entropy (from logits) of the discriminator output against all-ones targets."""
        bce = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        return bce(tf.ones_like(generated), generated)

    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle consistency loss: mean absolute difference between the original and the twice-transformed image.

        `LAMBDA` is accepted but not applied; the weighted variant
        (`LAMBDA * loss`) is currently disabled.
        """
        return tf.reduce_mean(tf.abs(real_image - cycled_image))

    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: mean absolute difference between an image and the output of its own-domain generator.

        `LAMBDA` is accepted but not applied; the weighted variant
        (`LAMBDA * 0.5 * loss`) is currently disabled.
        """
        return tf.reduce_mean(tf.abs(real_image - same_image))

# Create datasets

In [None]:
# Create dataset
## Single class datasets (augmented; get_dataset defaults make them repeated and shuffled)
healthy_ds = get_dataset(HEALTHY_FILENAMES, augment=data_augment, batch_size=BATCH_SIZE)
cbb_ds = get_dataset(CBB_FILENAMES, augment=data_augment, batch_size=BATCH_SIZE)
cbsd_ds = get_dataset(CBSD_FILENAMES, augment=data_augment, batch_size=BATCH_SIZE)
cgm_ds = get_dataset(CGM_FILENAMES, augment=data_augment, batch_size=BATCH_SIZE)
cmd_ds = get_dataset(CMD_FILENAMES, augment=data_augment, batch_size=BATCH_SIZE)

## Joint datasets: (disease batch, healthy batch) pairs
cbb_healthy_ds = tf.data.Dataset.zip((cbb_ds, healthy_ds))
cbsd_healthy_ds = tf.data.Dataset.zip((cbsd_ds, healthy_ds))
cgm_healthy_ds = tf.data.Dataset.zip((cgm_ds, healthy_ds))
cmd_healthy_ds = tf.data.Dataset.zip((cmd_ds, healthy_ds))

## Eval datasets: no augmentation, single pass, original order, one image per batch
healthy_ds_eval = get_dataset(HEALTHY_FILENAMES, repeat=False, shuffle=False, batch_size=1)
cbb_ds_eval = get_dataset(CBB_FILENAMES, repeat=False, shuffle=False, batch_size=1)
cbsd_ds_eval = get_dataset(CBSD_FILENAMES, repeat=False, shuffle=False, batch_size=1)
cgm_ds_eval = get_dataset(CGM_FILENAMES, repeat=False, shuffle=False, batch_size=1)
cmd_ds_eval = get_dataset(CMD_FILENAMES, repeat=False, shuffle=False, batch_size=1)

# Callbacks
class GANMonitor(Callback):
    """Callback that saves generator outputs as PNG files after each epoch.

    At the end of every epoch the first `num_img` elements of `input_ds` are
    passed through `generator`, and the first image of each output batch is
    written to `output_path` as `generated_<index>_<epoch>.png` (epoch is
    1-based).

    Args:
        generator: model invoked as `generator(batch, training=False)`.
        output_path: directory that receives the PNG files; created if missing.
        input_ds: dataset providing the input batches. The default is bound to
            `healthy_ds_eval` when the class is defined.
        num_img: number of dataset elements rendered per epoch.
    """

    def __init__(self, generator, output_path, input_ds=healthy_ds_eval, num_img=1):
        # Let the base Callback set up its own attributes first
        super().__init__()
        self.num_img = num_img
        self.input_ds = input_ds
        self.generator = generator
        self.output_path = output_path
        # Create the directory for the generated images; exist_ok avoids the
        # check-then-create race and makes re-running the cell safe
        os.makedirs(self.output_path, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Generate and save `num_img` images for the epoch that just ended."""
        for i, img in enumerate(self.input_ds.take(self.num_img)):
            prediction = self.generator(img, training=False)[0].numpy()
            # Map [-1, 1] to [0, 255]; clip before casting so values outside
            # that range cannot wrap around in the uint8 conversion
            prediction = np.clip(prediction * 127.5 + 127.5, 0, 255).astype(np.uint8)
            image = PIL.Image.fromarray(prediction)
            image.save(f'{self.output_path}/generated_{i}_{epoch+1}.png')

# CBB generator training

In [None]:
# Re-initialize the TPU system (when running on TPU) and clear the Keras
# session before building the models for this class pair
if tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)
K.clear_session()

with strategy.scope():
    # Create generators
    healthy_cbb_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)
    cbb_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)

    healthy_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)
    cbb_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)

    # Create discriminators
    healthy_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)
    cbb_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)

    healthy_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)
    cbb_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)

    # Create GAN
    gan_model = CycleGan(cbb_generator, healthy_cbb_generator,
                         cbb_discriminator, healthy_discriminator,
                         'cbb', 'healthy')

    gan_model.compile(f_gen_optimizer=cbb_generator_optimizer,
                      s_gen_optimizer=healthy_generator_optimizer,
                      f_disc_optimizer=cbb_discriminator_optimizer,
                      s_disc_optimizer=healthy_discriminator_optimizer,
                      gen_loss_fn=generator_loss,
                      disc_loss_fn=discriminator_loss,
                      cycle_loss_fn=calc_cycle_loss,
                      identity_loss_fn=identity_loss)

# `cbb_healthy_ds` is already batched (its source datasets are created with
# batch_size=BATCH_SIZE), so `batch_size` is not passed to `fit`: Keras
# documents that it must not be specified for dataset inputs.
# GANMonitor writes one image from `cbb_generator` to the 'cbb' folder after each epoch.
history = gan_model.fit(cbb_healthy_ds,
                        epochs=EPOCHS,
                        callbacks=[GANMonitor(cbb_generator, 'cbb')],
                        steps_per_epoch=(n_healthy_samples//BATCH_SIZE),
                        verbose=2).history

# Output models
healthy_cbb_generator.save('healthy_cbb_generator.h5')
cbb_generator.save('cbb_generator.h5')

# CBSD generator training

In [None]:
# Re-initialize the TPU system (when running on TPU) and clear the Keras
# session before building the models for this class pair
if tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)
K.clear_session()

with strategy.scope():
    # Create generators
    healthy_cbsd_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)
    cbsd_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)

    healthy_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)
    cbsd_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)

    # Create discriminators
    healthy_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)
    cbsd_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)

    healthy_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)
    cbsd_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)

    # Create GAN
    gan_model = CycleGan(cbsd_generator, healthy_cbsd_generator,
                         cbsd_discriminator, healthy_discriminator,
                         'cbsd', 'healthy')

    gan_model.compile(f_gen_optimizer=cbsd_generator_optimizer,
                      s_gen_optimizer=healthy_generator_optimizer,
                      f_disc_optimizer=cbsd_discriminator_optimizer,
                      s_disc_optimizer=healthy_discriminator_optimizer,
                      gen_loss_fn=generator_loss,
                      disc_loss_fn=discriminator_loss,
                      cycle_loss_fn=calc_cycle_loss,
                      identity_loss_fn=identity_loss)

# `cbsd_healthy_ds` is already batched (its source datasets are created with
# batch_size=BATCH_SIZE), so `batch_size` is not passed to `fit`: Keras
# documents that it must not be specified for dataset inputs.
# GANMonitor writes one image from `cbsd_generator` to the 'cbsd' folder after each epoch.
history = gan_model.fit(cbsd_healthy_ds,
                        epochs=EPOCHS,
                        callbacks=[GANMonitor(cbsd_generator, 'cbsd')],
                        steps_per_epoch=(n_healthy_samples//BATCH_SIZE),
                        verbose=2).history

# Output models
healthy_cbsd_generator.save('healthy_cbsd_generator.h5')
cbsd_generator.save('cbsd_generator.h5')

# CGM generator training

In [None]:
# Re-initialize the TPU system (when running on TPU) and clear the Keras
# session before building the models for this class pair
if tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)
K.clear_session()

with strategy.scope():
    # Create generators
    healthy_cgm_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)
    cgm_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)

    healthy_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)
    cgm_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)

    # Create discriminators
    healthy_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)
    cgm_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)

    healthy_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)
    cgm_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)

    # Create GAN
    gan_model = CycleGan(cgm_generator, healthy_cgm_generator,
                         cgm_discriminator, healthy_discriminator,
                         'cgm', 'healthy')

    gan_model.compile(f_gen_optimizer=cgm_generator_optimizer,
                      s_gen_optimizer=healthy_generator_optimizer,
                      f_disc_optimizer=cgm_discriminator_optimizer,
                      s_disc_optimizer=healthy_discriminator_optimizer,
                      gen_loss_fn=generator_loss,
                      disc_loss_fn=discriminator_loss,
                      cycle_loss_fn=calc_cycle_loss,
                      identity_loss_fn=identity_loss)

# `cgm_healthy_ds` is already batched (its source datasets are created with
# batch_size=BATCH_SIZE), so `batch_size` is not passed to `fit`: Keras
# documents that it must not be specified for dataset inputs.
# GANMonitor writes one image from `cgm_generator` to the 'cgm' folder after each epoch.
history = gan_model.fit(cgm_healthy_ds,
                        epochs=EPOCHS,
                        callbacks=[GANMonitor(cgm_generator, 'cgm')],
                        steps_per_epoch=(n_healthy_samples//BATCH_SIZE),
                        verbose=2).history

# Output models
healthy_cgm_generator.save('healthy_cgm_generator.h5')
cgm_generator.save('cgm_generator.h5')

# CMD generator training

In [None]:
# Re-initialize the TPU system (when running on TPU) and clear the Keras
# session before building the models for this class pair
if tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)
K.clear_session()

with strategy.scope():
    # Create generators
    healthy_cmd_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)
    cmd_generator = generator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE, transformer_blocks=TRANSFORMER_BLOCKS)

    healthy_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)
    cmd_generator_optimizer = optimizers.Adam(learning_rate=GENERATOR_LR, beta_1=0.5)

    # Create discriminators
    healthy_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)
    cmd_discriminator = discriminator_fn(height=HEIGHT_RESIZE, width=WIDTH_RESIZE)

    healthy_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)
    cmd_discriminator_optimizer = optimizers.Adam(learning_rate=DISCRIMINATOR_LR, beta_1=0.5)

    # Create GAN
    gan_model = CycleGan(cmd_generator, healthy_cmd_generator,
                         cmd_discriminator, healthy_discriminator,
                         'cmd', 'healthy')

    gan_model.compile(f_gen_optimizer=cmd_generator_optimizer,
                      s_gen_optimizer=healthy_generator_optimizer,
                      f_disc_optimizer=cmd_discriminator_optimizer,
                      s_disc_optimizer=healthy_discriminator_optimizer,
                      gen_loss_fn=generator_loss,
                      disc_loss_fn=discriminator_loss,
                      cycle_loss_fn=calc_cycle_loss,
                      identity_loss_fn=identity_loss)

# `cmd_healthy_ds` is already batched (its source datasets are created with
# batch_size=BATCH_SIZE), so `batch_size` is not passed to `fit`: Keras
# documents that it must not be specified for dataset inputs.
# GANMonitor writes one image from `cmd_generator` to the 'cmd' folder after each epoch.
history = gan_model.fit(cmd_healthy_ds,
                        epochs=EPOCHS,
                        callbacks=[GANMonitor(cmd_generator, 'cmd')],
                        steps_per_epoch=(n_healthy_samples//BATCH_SIZE),
                        verbose=2).history

# Output models
healthy_cmd_generator.save('healthy_cmd_generator.h5')
cmd_generator.save('cmd_generator.h5')

In [None]:
# Reload the saved weights into the generator models (all eight are generators)
saved_generator_weights = (
    ## Healthy -> disease generators
    (cbb_generator, 'cbb_generator.h5'),
    (cbsd_generator, 'cbsd_generator.h5'),
    (cgm_generator, 'cgm_generator.h5'),
    (cmd_generator, 'cmd_generator.h5'),
    ## Disease -> healthy generators
    (healthy_cbb_generator, 'healthy_cbb_generator.h5'),
    (healthy_cbsd_generator, 'healthy_cbsd_generator.h5'),
    (healthy_cgm_generator, 'healthy_cgm_generator.h5'),
    (healthy_cmd_generator, 'healthy_cmd_generator.h5'),
)
for gen_model, weights_file in saved_generator_weights:
    gen_model.load_weights(weights_file)

We can follow each generator's progress by building a `gif` from the image it generated at the end of every epoch.

## Generated image GIFs

In [None]:
# Build one GIF per disease class from the PNG files in its output folder
# (the absolute paths assume the working directory is /kaggle/working)
gif_jobs = (
    # ('/kaggle/working/healthy/*.png', 'healthy.gif'),  # healthy GIF is disabled
    ('/kaggle/working/cbb/*.png', 'cbb.gif'),
    ('/kaggle/working/cbsd/*.png', 'cbsd.gif'),
    ('/kaggle/working/cgm/*.png', 'cgm.gif'),
    ('/kaggle/working/cmd/*.png', 'cmd.gif'),
)
for frames_glob, gif_name in gif_jobs:
    create_gif(frames_glob, gif_name)

<figure class="half">
  <figcaption>Left: "Healthy" to "CBB". Right: "Healthy" to "CBSD".</figcaption>
  <table>
    <tr>
      <td>
        <img style="width:400px;" src="cbb.gif">
      </td>
      <td>
        <img style="width:400px;" src="cbsd.gif">
      </td>
    </tr>
  </table>
</figure>

<figure class="half">
  <figcaption>Left: "Healthy" to "CGM". Right: "Healthy" to "CMD".</figcaption>
  <table>
    <tr>
      <td>
        <img style="width:400px;" src="cgm.gif">
      </td>
      <td>
        <img style="width:400px;" src="cmd.gif">
      </td>
    </tr>
  </table>
</figure>

# Evaluating generator models

Here we evaluate the generator models, including how good the generator cycle is: we take an image from one domain, generate its counterpart in another domain, and then use that generated image to generate an image back in the original domain.

## Healthy (input) -> CBB (generated) -> Healthy (generated)

In [None]:
# Cycle check on the first two healthy eval images: healthy -> CBB -> healthy
n_cycle_samples = 2
evaluate_cycle(healthy_ds_eval.take(n_cycle_samples), cbb_generator, healthy_cbb_generator, n_samples=n_cycle_samples)

## Healthy (input) -> CBSD (generated) -> Healthy (generated)

In [None]:
# Cycle check on the first two healthy eval images: healthy -> CBSD -> healthy
n_cycle_samples = 2
evaluate_cycle(healthy_ds_eval.take(n_cycle_samples), cbsd_generator, healthy_cbsd_generator, n_samples=n_cycle_samples)

## Healthy (input) -> CGM (generated) -> Healthy (generated)

In [None]:
# Cycle check on the first two healthy eval images: healthy -> CGM -> healthy
n_cycle_samples = 2
evaluate_cycle(healthy_ds_eval.take(n_cycle_samples), cgm_generator, healthy_cgm_generator, n_samples=n_cycle_samples)

## Healthy (input) -> CMD (generated) -> Healthy (generated)

In [None]:
# Cycle check on the first two healthy eval images: healthy -> CMD -> healthy
n_cycle_samples = 2
evaluate_cycle(healthy_ds_eval.take(n_cycle_samples), cmd_generator, healthy_cmd_generator, n_samples=n_cycle_samples)

# Visualize predictions

A common issue with images generated by GANs is that they often show some undesired artifacts; a very common one is known as "[checkerboard artifacts](https://distill.pub/2016/deconv-checkerboard/)". A good practice is to inspect some of the images to check their quality and whether any of these undesired artifacts are present.

## CBB (generated)

In [None]:
# Pass the first four healthy eval images through cbb_generator for inspection
n_preview = 4
display_generated_samples(healthy_ds_eval.take(n_preview), cbb_generator, n_preview)

## CBSD (generated)

In [None]:
# Pass the first four healthy eval images through cbsd_generator for inspection
n_preview = 4
display_generated_samples(healthy_ds_eval.take(n_preview), cbsd_generator, n_preview)

## CGM (generated)

In [None]:
# Pass the first four healthy eval images through cgm_generator for inspection
n_preview = 4
display_generated_samples(healthy_ds_eval.take(n_preview), cgm_generator, n_preview)

## CMD (generated)

In [None]:
# Pass the first four healthy eval images through cmd_generator for inspection
n_preview = 4
display_generated_samples(healthy_ds_eval.take(n_preview), cmd_generator, n_preview)

## Healthy (generated)

### From CBB

In [None]:
# Pass the first three CBB eval images through healthy_cbb_generator for inspection
n_preview = 3
display_generated_samples(cbb_ds_eval.take(n_preview), healthy_cbb_generator, n_preview)

### From CBSD

In [None]:
# Pass the first three CBSD eval images through healthy_cbsd_generator for inspection
n_preview = 3
display_generated_samples(cbsd_ds_eval.take(n_preview), healthy_cbsd_generator, n_preview)

### From CGM

In [None]:
# Pass the first three CGM eval images through healthy_cgm_generator for inspection
n_preview = 3
display_generated_samples(cgm_ds_eval.take(n_preview), healthy_cgm_generator, n_preview)

### From CMD

In [None]:
# Pass the first three CMD eval images through healthy_cmd_generator for inspection
n_preview = 3
display_generated_samples(cmd_ds_eval.take(n_preview), healthy_cmd_generator, n_preview)

## Generating images predictions

#### Not doing this here because it takes too much time using the TPU

### CBB

In [None]:
# %%time

# # Create folders
# os.makedirs('../cbb_generated/') # Create folder to save generated images
# # Generate images
# predict_and_save(healthy_ds_eval, cbb_generator, '../cbb_generated/')
# # Zip folders
# shutil.make_archive('/kaggle/working/cbb_generated/', 'zip', '../cbb_generated')
# # Count images
# print(f"Generated CBB samples: {len([name for name in os.listdir('../cbb_generated/') if os.path.isfile(os.path.join('../cbb_generated/', name))])}")

### CBSD

In [None]:
# %%time

# Create folders
# os.makedirs('../cbsd_generated/') # Create folder to save generated images
# Generate images
# predict_and_save(healthy_ds_eval, cbsd_generator, '../cbsd_generated/')
# Zip folders
# shutil.make_archive('/kaggle/working/cbsd_generated/', 'zip', '../cbsd_generated')
# Count images
# print(f"Generated CBSD samples: {len([name for name in os.listdir('../cbsd_generated/') if os.path.isfile(os.path.join('../cbsd_generated/', name))])}")

### CGM

In [None]:
# %%time

# Create folders
# os.makedirs('../cgm_generated/') # Create folder to save generated images
# Generate images
# predict_and_save(healthy_ds_eval, cgm_generator, '../cgm_generated/')
# Zip folders
# shutil.make_archive('/kaggle/working/cgm_generated/', 'zip', '../cgm_generated')
# Count images
# print(f"Generated CGM samples: {len([name for name in os.listdir('../cgm_generated/') if os.path.isfile(os.path.join('../cgm_generated/', name))])}")

### CMD

In [None]:
# %%time

# Create folders
# os.makedirs('../cmd_generated/') # Create folder to save generated images
# Generate images
# predict_and_save(healthy_ds_eval, cmd_generator, '../cmd_generated/')
# Zip folders
# shutil.make_archive('/kaggle/working/cmd_generated/', 'zip', '../cmd_generated')
# Count images
# print(f"Generated CMD samples: {len([name for name in os.listdir('../cmd_generated/') if os.path.isfile(os.path.join('../cmd_generated/', name))])}")