# Introduction and Setup

Most of the code here is similar to the [Monet CycleGAN Tutorial](https://www.kaggle.com/amyjang/monet-cyclegan-tutorial).

---

## What I've added here:
1. Residual Blocks in the Generator.
2. Weights and Biases support to log images to the cloud.

---
## References:
1. [Improving CycleGAN - Monet paintings](https://www.kaggle.com/dimitreoliveira/improving-cyclegan-monet-paintings/)
2. [CycleGAN Tutorial from Scratch: Monet-to-Photo](https://www.kaggle.com/songseungwon/cyclegan-tutorial-from-scratch-monet-to-photo)


# Installing Dependencies

In [None]:
!pip install wandb

In [None]:
import re
import wandb
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from kaggle_datasets import KaggleDatasets

In [None]:
# Connect to a TPU when one can be resolved and build a TPU strategy for it;
# otherwise fall back to whatever strategy TensorFlow currently provides.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate; the reason for the fallback is reported.
    print('TPU not available, falling back to the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Shared by the input pipeline cells for parallel map calls and prefetching.
AUTOTUNE = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(tf.__version__)

Let us define a few constants.

In [None]:
# Spatial size of every image; used when reshaping decoded images and when
# cropping during augmentation.
IMAGE_SIZE = [256, 256]
# Number of filters of the generator's final layer (one per colour channel).
OUTPUT_CHANNELS = 3
# Batch size of the training datasets; also divides the sample count to get
# the number of steps per epoch.
BATCH_SIZE = 16
# Number of epochs passed to `fit`.
EPOCHS = 25

Now we will initialize our WandB project.

In [None]:
# Read the W&B API key from the Kaggle secret named "wandb_key" so that the
# key itself never appears in the notebook.
from kaggle_secrets import UserSecretsClient
wandb_key = UserSecretsClient().get_secret("wandb_key")

# Authenticate, then start the run that the rest of the notebook logs to.
wandb.login(key=wandb_key)
wandb.init(project="monet-cyclegan-kaggle")

# Load in the data

We want to keep our photo dataset and our Monet dataset separate. First, load in the filenames of the TFRecords.

In [None]:
# Root path under which the TFRecord files are globbed in the next cell.
GCS_PATH = KaggleDatasets().get_gcs_path()

In [None]:
# Glob the TFRecord shards of each image domain below GCS_PATH.
monet_pattern = GCS_PATH + '/monet_tfrec/*.tfrec'
photo_pattern = GCS_PATH + '/photo_tfrec/*.tfrec'

MONET_FILENAMES = tf.io.gfile.glob(monet_pattern)
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(photo_pattern)
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

In [None]:
def count_data_items(filenames):
    """Return the summed sample count encoded in the given file names.

    For every name, the digits found between a ``-`` and the following ``.``
    are parsed as an integer; the per-file numbers are then added up with
    ``np.sum``.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")

    per_file_counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        per_file_counts.append(int(match.group(1)))

    return np.sum(per_file_counts)

# Total number of samples per domain, derived from the shard file names;
# later used to compute the number of training steps per epoch.
n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)

print(f"Total Monet Samples: {n_monet_samples}")
print(f"Total Photo Samples: {n_photo_samples}")

# Auxiliary Functions

## Data Loading

All the images for the competition are already sized to 256x256. As these images are RGB images, set the channel to 3. Additionally, we need to scale the images to a [-1, 1] scale. Because we are building a generative model, we don't need the labels or the image id so we'll only return the image from the TFRecord.

Before we return the image from the TFRecord, we will apply a few augmentations. We need to ensure that we apply augmentations carefully as the generator might learn these augmentations.

In [None]:
def augment(image):
    """Apply random jitter to a single image.

    The image is enlarged to 286x286 with nearest-neighbour resizing, a random
    window of ``IMAGE_SIZE`` with 3 channels is cropped out of it, and the
    crop is mirrored left-to-right at random.
    """
    enlarged = tf.image.resize(
        image,
        [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    cropped = tf.image.random_crop(enlarged, [*IMAGE_SIZE, 3])
    return tf.image.random_flip_left_right(cropped)

In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 image scaled to the [-1, 1] range.

    The result is reshaped to ``IMAGE_SIZE`` with 3 channels.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized example and return only its decoded image.

    The ``image_name`` and ``target`` features are parsed together with the
    image but are not returned.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

Define the function to extract the image from the files.

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Return a dataset of decoded images read from the TFRecord *filenames*.

    ``labeled`` and ``ordered`` are accepted but not used by this function.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

Now we will write a function that will prepare our datasets for training and testing.

In [None]:
def prepare_dataset(filenames, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the batched input pipeline for training or evaluation.

    Args:
        filenames: TFRecord files handed to ``load_dataset``.
        augment: Optional per-image callable mapped over the dataset; no
            augmentation is applied when it is falsy. (Inside this function
            the name refers to this argument, not the module-level function.)
        repeat: Repeat the dataset indefinitely when true.
        shuffle: Shuffle with a 512-element buffer when true.
        batch_size: Number of images per batch.

    Returns:
        The batched and prefetched ``tf.data.Dataset``.
    """
    dataset = load_dataset(filenames)

    # Cache the decoded images before any random or unbounded stage. Caching
    # after `repeat()` would try to store an infinite stream, and caching after
    # augmentation/shuffling would replay the same "random" results.
    # NOTE(review): this keeps every decoded image in host memory - drop the
    # cache if the dataset does not fit.
    dataset = dataset.cache()

    if augment:
        dataset = dataset.map(augment, num_parallel_calls=AUTOTUNE)
    if repeat:
        dataset = dataset.repeat()
    if shuffle:
        dataset = dataset.shuffle(512)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTOTUNE)

    return dataset

Let's load in our datasets.

In [None]:
# Create dataset
# Training pipelines: augmented, repeated, shuffled and batched with BATCH_SIZE
# (repeat and shuffle are the defaults of prepare_dataset).
monet_ds = prepare_dataset(MONET_FILENAMES, augment=augment, batch_size=BATCH_SIZE)
photo_ds = prepare_dataset(PHOTO_FILENAMES, augment=augment, batch_size=BATCH_SIZE)
# Pairs one Monet batch with one photo batch per element.
# NOTE(review): the training cell zips the two datasets again instead of
# using `gan_ds`.
gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))

# For evaluation
# Single pass, fixed order, one un-augmented image per batch.
photo_ds_eval = prepare_dataset(PHOTO_FILENAMES, repeat=False, shuffle=False, batch_size=1)
monet_ds_eval = prepare_dataset(MONET_FILENAMES, repeat=False, shuffle=False, batch_size=1)

## Data Visualization

Let's visualize a photo example and a Monet example.

In [None]:
# Reference: https://stackoverflow.com/questions/41071947/how-to-remove-the-space-between-subplots-in-matplotlib-pyplot
def prepare_grid(n_rows, n_cols, wspace=0.0, hspace=0.0):
    """Create a figure and a ``GridSpec`` for an ``n_rows`` x ``n_cols`` grid.

    The figure is ``(n_cols + 1, n_rows + 1)`` in size and the grid leaves a
    margin of half a cell on every side; ``wspace``/``hspace`` set the spacing
    between cells. Returns the ``(figure, gridspec)`` pair.
    """
    row_margin = 0.5 / (n_rows + 1)
    col_margin = 0.5 / (n_cols + 1)

    fig = plt.figure(figsize=(n_cols + 1, n_rows + 1))
    gs = gridspec.GridSpec(
        n_rows,
        n_cols,
        wspace=wspace,
        hspace=hspace,
        top=1. - row_margin,
        bottom=row_margin,
        left=col_margin,
        right=1 - col_margin,
    )
    return fig, gs

In [None]:
def display_samples(dataset, n_samples, n_rows=2):
    """Plot the first entry of up to ``n_samples`` consecutive dataset elements.

    Args:
        dataset: Iterable whose elements are indexed with ``[0]`` to get the
            image to draw (values are mapped back with ``* 0.5 + 0.5``).
        n_samples: Number of images to draw; nothing is drawn when it is
            not positive.
        n_rows: Number of grid rows; the column count is the smallest one
            that fits ``n_samples`` images.

    Fewer images are drawn when the dataset runs out early.
    """
    if n_samples <= 0:
        return

    # Ceiling division: smallest number of columns that holds n_samples cells.
    n_cols = -(-n_samples // n_rows)

    fig, grid = prepare_grid(n_rows, n_cols)

    # `range` comes first so that zip stops after n_samples without pulling an
    # extra element from the dataset; it also stops cleanly on exhaustion.
    for index, example in zip(range(n_samples), dataset):
        row, col = divmod(index, n_cols)
        ax = plt.subplot(grid[row, col])
        ax.imshow(example[0] * 0.5 + 0.5)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis("off")

    plt.show()

In [None]:
# One row of five samples from each training dataset.
for samples in (monet_ds, photo_ds):
    display_samples(samples, n_samples=5, n_rows=1)

In [None]:
def display_generated_samples(dataset, model, n_samples):
    """Show ``n_samples`` rows of (real image, model output) pairs.

    Each row takes the next element of ``dataset``, draws its first entry in
    the left column and the first entry of ``model.predict`` on that element
    in the right column. Column titles are set on the first row only.
    """
    batches = iter(dataset)
    fig, grid = prepare_grid(n_samples, 2, wspace=0.1)

    def show(cell, picture, title):
        # Draw one image (mapped back with * 0.5 + 0.5) without axes.
        ax = plt.subplot(cell)
        ax.imshow(picture[0] * 0.5 + 0.5)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis("off")
        if title is not None:
            ax.set_title(title)

    for row in range(n_samples):
        first_row = row == 0

        real = next(batches)
        show(grid[row, 0], real, "Real" if first_row else None)

        generated = model.predict(real)
        show(grid[row, 1], generated, "Generated" if first_row else None)

    plt.show()

In [None]:
def evaluate_cycle(dataset, gen_1, gen_2, n_samples=1):
    """Show ``n_samples`` rows of real -> generated -> cycled images.

    For each row the next element of ``dataset`` is drawn, then the output of
    ``gen_1.predict`` on it, then the output of ``gen_2.predict`` on that
    generated image. Column titles are set on the first row only.
    """
    batches = iter(dataset)

    # One entry per column: its title and how to obtain the image from the
    # image of the previous column.
    steps = (
        ("Real", lambda _previous: next(batches)),
        ("Generated", gen_1.predict),
        ("Cycled", gen_2.predict),
    )

    fig, grid = prepare_grid(n_samples, len(steps), wspace=0.1)

    for row in range(n_samples):
        image = None
        for col, (title, produce) in enumerate(steps):
            image = produce(image)

            ax = plt.subplot(grid[row, col])
            ax.imshow(image[0] * 0.5 + 0.5)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.axis("off")

            if row == 0:
                ax.set_title(title)

    plt.show()

# Build the Generator

We'll be using a UNET architecture for our CycleGAN. To build our generator, let's first define our `downsample` and `upsample` methods.

The `downsample`, as the name suggests, reduces the 2D dimensions, the width and height, of the image by the stride. The stride is the length of the step the filter takes. Since the stride is 2, the filter is applied to every other pixel, hence halving the width and height.

We'll be using an instance normalization instead of batch normalization. As the instance normalization is not standard in the TensorFlow API, we'll use the layer from TensorFlow Add-ons.

In [None]:
def downsample(filters, size, apply_instancenorm=True):
    """Return a ``Sequential`` block: stride-2 Conv2D -> [InstanceNorm] -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_instancenorm: Insert the instance-normalization layer when true.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    layers = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        )
    ]
    if apply_instancenorm:
        layers.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(layers)

`upsample` does the opposite of `downsample` and increases the dimensions of the image. `Conv2DTranspose` does basically the opposite of a `Conv2D` layer.

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Return a ``Sequential`` block: stride-2 Conv2DTranspose -> InstanceNorm -> [Dropout] -> ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Insert a ``Dropout(0.5)`` layer before the ReLU when true.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    layers = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        ),
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
    ]
    if apply_dropout:
        layers.append(tf.keras.layers.Dropout(0.5))
    layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(layers)

A `Residual Block` is a stack of layers set in such a way that the output of a layer is taken and added to another layer deeper in the block.

In [None]:
def residual_block(input_layer, size=3):
    """Apply Conv2D -> ReLU -> Conv2D to ``input_layer`` and add the input back.

    Both convolutions use as many filters as the input has channels and
    ``'same'`` padding, so the sum with the input is shape-compatible.
    """
    channels = input_layer.shape[-1]

    kernel_init = tf.random_normal_initializer(0., 0.02)
    # NOTE(review): gamma_init is created but not used by any layer here.
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    def conv():
        # Fresh bias-free convolution that keeps the channel count.
        return tf.keras.layers.Conv2D(
            channels,
            size,
            padding='same',
            use_bias=False,
            kernel_initializer=kernel_init,
        )

    hidden = conv()(input_layer)
    hidden = tf.keras.layers.ReLU()(hidden)
    hidden = conv()(hidden)

    return tf.keras.layers.Add()([hidden, input_layer])

Let's build our generator!

The generator first downsamples the input image and then upsamples it while establishing long skip connections. Skip connections are a way to help bypass the vanishing gradient problem by concatenating the output of a layer to multiple layers instead of only one. Here we concatenate the output of the downsample layer to the upsample layer in a symmetrical fashion.

In [None]:
def Generator():
    """Build the generator model.

    A 256x256x3 input goes through eight ``downsample`` blocks, six
    ``residual_block`` calls at the bottleneck and seven ``upsample`` blocks.
    Each upsample output is concatenated with the matching encoder output,
    and a final tanh transposed convolution produces ``OUTPUT_CHANNELS``
    channels.
    """
    image_in = tf.keras.layers.Input(shape=[256, 256, 3])

    # Encoder: spatial size halves per block, 256 -> 1 (bs = batch size).
    encoder = [downsample(64, 4, apply_instancenorm=False)]  # (bs, 128, 128, 64)
    encoder += [
        downsample(filters, 4)
        for filters in (128, 256, 512, 512, 512, 512, 512)  # ... down to (bs, 1, 1, 512)
    ]

    # Decoder: spatial size doubles per block, 1 -> 128; the first three
    # blocks use dropout. Channel counts double after the skip concatenation.
    decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    decoder += [upsample(filters, 4) for filters in (512, 256, 128, 64)]

    kernel_init = tf.random_normal_initializer(0., 0.02)
    to_image = tf.keras.layers.Conv2DTranspose(
        OUTPUT_CHANNELS,
        4,
        strides=2,
        padding='same',
        kernel_initializer=kernel_init,
        activation='tanh',
    )  # (bs, 256, 256, 3)

    # Downsampling through the model, remembering every encoder output.
    features = image_in
    encoder_outputs = []
    for down in encoder:
        features = down(features)
        encoder_outputs.append(features)

    # Residual blocks on the bottleneck features.
    for _ in range(6):
        features = residual_block(features)

    # Upsampling with skip connections: every encoder output except the last,
    # visited from deepest to shallowest.
    for up, skip in zip(decoder, encoder_outputs[-2::-1]):
        features = up(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    generated = to_image(features)

    return tf.keras.Model(inputs=image_in, outputs=generated)

# Build the Discriminator

The discriminator takes in the input image and classifies it as real or fake (generated). Instead of outputting a single node, the discriminator outputs a smaller 2D image with higher pixel values indicating a real classification and lower values indicating a fake classification.

In [None]:
def Discriminator():
    """Build the discriminator model.

    A 256x256x3 input passes through three ``downsample`` blocks, a
    zero-padded 512-filter convolution with instance normalization and
    LeakyReLU, and a final zero-padded single-filter convolution, giving a
    map of unbounded scores rather than a single value.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

    # (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
    features = image_in
    for filters, use_norm in ((64, False), (128, True), (256, True)):
        features = downsample(filters, 4, use_norm)(features)

    features = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 34, 34, 256)
    features = tf.keras.layers.Conv2D(
        512,
        4,
        strides=1,
        kernel_initializer=kernel_init,
        use_bias=False,
    )(features)  # (bs, 31, 31, 512)
    features = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(features)
    features = tf.keras.layers.LeakyReLU()(features)

    features = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 33, 33, 512)
    scores = tf.keras.layers.Conv2D(
        1,
        4,
        strides=1,
        kernel_initializer=kernel_init,
    )(features)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=image_in, outputs=scores)

In [None]:
# Create the four networks inside the distribution strategy's scope.
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

Since our generators are not trained yet, the generated Monet-esque photo does not show what is expected at this point.

In [None]:
# Two photo / generated pairs from the untrained generator.
display_generated_samples(dataset=photo_ds, model=monet_generator, n_samples=2)

# Build the CycleGAN model

We will subclass a `tf.keras.Model` so that we can run `fit()` later to train our model. During the training step, the model transforms a photo to a Monet painting and then back to a photo. The difference between the original photo and the twice-transformed photo is the cycle-consistency loss. We want the original photo and the twice-transformed photo to be similar to one another.

The losses are defined in the next section.

In [None]:
class CycleGan(tf.keras.Model):
    """Keras model that trains two generators and two discriminators together.

    ``m_gen`` is applied to photos and ``p_gen`` to Monet paintings;
    ``m_disc`` and ``p_disc`` score real and generated images of the Monet and
    photo domain respectively. ``lambda_cycle`` is the weight handed to the
    cycle-consistency and identity loss functions.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        super().__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Store the four optimizers and four loss callables used by ``train_step``."""
        super().compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one update of all four networks on a ``(monet, photo)`` batch pair.

        Returns a dict with the total loss of each generator and the loss of
        each discriminator.
        """
        real_monet, real_photo = batch_data
        cycle_weight = self.lambda_cycle

        # Persistent tape: four separate gradients are taken from it below.
        with tf.GradientTape(persistent=True) as tape:
            # photo -> monet -> photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet -> photo -> monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # each generator applied to an image of its own target domain
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator scores for real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator scores for generated images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # adversarial part of the generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # cycle-consistency loss summed over both directions
            monet_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, cycle_weight)
            photo_cycle_loss = self.cycle_loss_fn(real_photo, cycled_photo, cycle_weight)
            total_cycle_loss = monet_cycle_loss + photo_cycle_loss

            # total generator loss = adversarial + cycle + identity
            monet_identity_loss = self.identity_loss_fn(real_monet, same_monet, cycle_weight)
            photo_identity_loss = self.identity_loss_fn(real_photo, same_photo, cycle_weight)
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + monet_identity_loss
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + photo_identity_loss

            # discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # (loss, network, optimizer) in the order generators, then discriminators.
        updates = (
            (total_monet_gen_loss, self.m_gen, self.m_gen_optimizer),
            (total_photo_gen_loss, self.p_gen, self.p_gen_optimizer),
            (monet_disc_loss, self.m_disc, self.m_disc_optimizer),
            (photo_disc_loss, self.p_disc, self.p_disc_optimizer),
        )

        # Compute every gradient first, then apply them all.
        gradients = [
            tape.gradient(loss, network.trainable_variables)
            for loss, network, _ in updates
        ]
        for grads, (_, network, optimizer) in zip(gradients, updates):
            optimizer.apply_gradients(zip(grads, network.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

# Define loss functions

The discriminator loss function below compares real images to a matrix of 1s and fake images to a matrix of 0s. The perfect discriminator will output all 1s for real images and all 0s for fake images. The discriminator loss outputs the average of the real and generated loss.

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Return half the sum of the real-vs-ones and generated-vs-zeros losses.

        Both terms are binary cross-entropy on logits with
        ``Reduction.NONE``, so the result is not reduced to a scalar here.
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.5

The generator wants to fool the discriminator into thinking the generated image is real. The perfect generator will have the discriminator output only 1s. Thus, it compares the generated image to a matrix of 1s to find the loss.

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Return the binary cross-entropy (on logits, unreduced) of ``generated`` against all ones."""
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        targets = tf.ones_like(generated)
        return bce(targets, generated)

We want our original photo and the twice-transformed photo to be similar to one another. Thus, we can calculate the cycle consistency loss by finding the average of their absolute difference.

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute difference of the two images."""
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_diff

The identity loss compares the image with its generator (i.e. photo with photo generator). If given a photo as input, we want it to generate the same image as the image was originally a photo. The identity loss compares the input with the output of the generator.

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute difference of the two images."""
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
        weight = LAMBDA * 0.5
        return weight * mean_abs_diff

# Weights and Biases Callback

In [None]:
class WandBMonitor(tf.keras.callbacks.Callback):
    """Keras callback that logs generated images and epoch losses to WandB.

    At the end of every epoch it runs the module-level ``monet_generator`` on
    the first ``n_images`` elements of ``photo_ds_eval`` and the module-level
    ``photo_generator`` on the first ``n_images`` elements of
    ``monet_ds_eval``. All resulting images are logged together with the mean
    of each loss found in ``logs``.
    """

    # Loss entries of the training logs that are forwarded to WandB.
    LOSS_KEYS = (
        "monet_gen_loss",
        "photo_gen_loss",
        "monet_disc_loss",
        "photo_disc_loss",
    )

    def __init__(self, n_images=1):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.n_images = n_images

    def get_wandb_image(self, images):
        """Map ``images`` from [-1, 1] to [0, 255] and wrap them in a ``wandb.Image``."""
        images = (images * 127.5 + 127.5)
        images = wandb.Image(images)
        return images

    def on_epoch_end(self, epoch, logs=None):
        """Log generated images and mean losses for ``epoch``.

        Args:
            epoch: Epoch index, used as the WandB step.
            logs: Dict of training results; loss keys missing from it are
                simply not logged.
        """
        logs = logs or {}

        # Real to Monet Generated Images - keep every image, not only the last.
        monet_images = [
            self.get_wandb_image(monet_generator.predict(image))
            for image in photo_ds_eval.take(self.n_images)
        ]

        # Monet to Real Generated Images
        real_images = [
            self.get_wandb_image(photo_generator.predict(image))
            for image in monet_ds_eval.take(self.n_images)
        ]

        payload = {
            "real_to_monet": monet_images,
            "monet_to_real": real_images,
        }

        # Log losses
        for key in self.LOSS_KEYS:
            if key in logs:
                payload[key] = tf.reduce_mean(logs[key])

        wandb.log(payload, step=epoch)
            

# Train the CycleGAN

Let's compile our model. Since we used `tf.keras.Model` to build our CycleGAN, we can just use the `fit` function to train our model.

In [None]:
with strategy.scope():
    # One independent Adam optimizer (learning rate 2e-4, beta_1 0.5) per network.
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = [tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)]

In [None]:
# Wrap the four networks in the CycleGan model (lambda_cycle keeps its default)
# and hand it the optimizers and loss functions defined above.
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, monet_discriminator, photo_discriminator
    )

    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )

In [None]:
# Train on paired (monet, photo) batches. Both datasets repeat indefinitely,
# so the epoch length is set explicitly: the larger sample count divided by
# the batch size.
cycle_gan_model.fit(
    tf.data.Dataset.zip((monet_ds, photo_ds)),
    epochs=EPOCHS,
    callbacks=[WandBMonitor()],
    steps_per_epoch=(max(n_monet_samples, n_photo_samples)//BATCH_SIZE)
)

# Visualize our Monet-esque photos

In [None]:
# Five rows: the first photo of a batch on the left, its Monet-style
# translation on the right, both converted from [-1, 1] to uint8.
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for i, img in enumerate(photo_ds.take(5)):
    prediction = monet_generator(img, training=False)[0].numpy()
    panels = (
        ("Input Photo", (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)),
        ("Monet-esque", (prediction * 127.5 + 127.5).astype(np.uint8)),
    )

    for axis, (title, picture) in zip(ax[i], panels):
        axis.imshow(picture)
        axis.set_title(title)
        axis.axis("off")
plt.show()

# Visualize Outputs

Photo (Real) -> Monet (Generated) -> Photo (Generated)

In [None]:
# photo -> generated Monet -> cycled photo
evaluate_cycle(
    dataset=photo_ds_eval,
    gen_1=monet_generator,
    gen_2=photo_generator,
    n_samples=2,
)

Monet (Real) -> Photo (Generated) -> Monet (Generated)

In [None]:
# Monet -> generated photo -> cycled Monet
evaluate_cycle(
    dataset=monet_ds_eval,
    gen_1=photo_generator,
    gen_2=monet_generator,
    n_samples=2,
)

# Create submission file

In [None]:
import PIL
! mkdir ../images

In [None]:
%%time

i = 1
for img in photo_ds_eval:
    prediction = monet_generator(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    im = PIL.Image.fromarray(prediction)
    im.save("../images/" + str(i) + ".jpg")
    i += 1

In [None]:
import shutil
# Zip the generated images into /kaggle/working/images.zip.
# NOTE(review): the images were written to the relative path ../images; this
# assumes the working directory is /kaggle/working - confirm.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")

# Save Models

We will now save our models locally, then upload them to WandB.

In [None]:
# Save each network to an HDF5 file in the current directory.
for filename, network in (
    ("monet_generator.h5", monet_generator),
    ("photo_generator.h5", photo_generator),
    ("monet_discriminator.h5", monet_discriminator),
    ("photo_discriminator.h5", photo_discriminator),
):
    network.save(filename)

In [None]:
# Upload the saved model files to the current WandB run.
for filename in (
    "monet_generator.h5",
    "photo_generator.h5",
    "monet_discriminator.h5",
    "photo_discriminator.h5",
):
    wandb.save(filename)