## In this notebook we implement [CUT](https://github.com/taesungp/contrastive-unpaired-translation/) (Contrastive Unpaired Translation), the newer model from the creators of CycleGAN, using a ResNet backbone. CUT is supposed to train faster than CycleGAN while achieving even better results. As you will see, this is not the case with my code, so there must be an error somewhere that I can't find. If anyone spots it, please let me know! You can also check the full code, with TF training for easier debugging, in my GitHub [repository](https://github.com/Brechard/Simple-CUT-TF)

## We also take some functions from this [notebook](https://www.kaggle.com/amyjang/monet-cyclegan-tutorial) to load the data.


In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import datetime
import PIL
! mkdir ../images

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:
    # TPUClusterResolver raises ValueError when it cannot find a TPU; in that case fall
    # back to the default strategy (CPU / single GPU). Catching only ValueError instead
    # of a bare `except:` lets genuine TPU initialization failures (and
    # KeyboardInterrupt) surface instead of silently switching devices.
    # NOTE(review): nothing in this notebook builds the models inside
    # `strategy.scope()`, so even when a TPU is found the models are not replicated on
    # it — confirm whether that is intended.
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

In [None]:
GCS_PATH = KaggleDatasets().get_gcs_path()

# TFRecord shards of each domain, listed from the competition's GCS bucket.
MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))

print('Monet TFRecord Files:', len(MONET_FILENAMES))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

## Define the functions used to load each photo and apply some data augmentation

In [None]:
IMAGE_SIZE = (256, 256, 3)  # (height, width, channels) every image is reshaped to
BATCH_SIZE = 1


def decode_image(image, use_augmentation):
    """Decode a JPEG byte string into a float32 tensor of shape IMAGE_SIZE in [-1, 1].

    When ``use_augmentation`` is truthy, random horizontal/vertical flips, contrast and
    brightness jitter and a random crop to IMAGE_SIZE are applied, in that order, to the
    uint8 image before it is rescaled.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    if use_augmentation:
        augmentations = (
            tf.image.random_flip_left_right,
            tf.image.random_flip_up_down,
            lambda img: tf.image.random_contrast(img, 0.9, 1.1),
            lambda img: tf.image.random_brightness(img, 0.1),
            lambda img: tf.image.random_crop(img, size=IMAGE_SIZE),
        )
        for augment in augmentations:
            pixels = augment(pixels)
    # [0, 255] -> [-1, 1]
    scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
    return tf.reshape(scaled, IMAGE_SIZE)


def read_tfrecord(example, use_augmentation):
    """Parse one serialized TFRecord example and return its decoded image tensor.

    Every record carries three string features (image_name, image, target); only the
    JPEG bytes in ``image`` are used, via ``decode_image``.
    """
    schema = {
        feature_name: tf.io.FixedLenFeature([], tf.string)
        for feature_name in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image'], use_augmentation)


def load_dataset(filenames, use_augmentation):
    """Build a shuffled, batched, prefetched dataset of images from TFRecord files.

    Args:
        filenames: list of TFRecord file paths.
        use_augmentation: forwarded to ``read_tfrecord``/``decode_image``.

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of shape
        (BATCH_SIZE, *IMAGE_SIZE); an incomplete final batch is dropped.
    """
    dataset = tf.data.TFRecordDataset(filenames)
    # Cache the raw serialized records *before* decoding/augmenting. Caching after the
    # augmenting map (as before) froze a single random augmentation per image for all
    # epochs, and stored decoded float32 images instead of the much smaller JPEG bytes.
    dataset = dataset.cache()
    dataset = dataset.map(lambda x: read_tfrecord(x, use_augmentation), num_parallel_calls=AUTOTUNE)
    dataset = dataset.shuffle(2020).batch(BATCH_SIZE, drop_remainder=True)
    return dataset.prefetch(AUTOTUNE)

def display_samples(ds, n_row=4, n_col=7):
    """Show the first image of each of the first ``n_row * n_col`` batches of ``ds`` in a grid."""
    n_images = n_row * n_col
    plt.figure(figsize=(20, int(20 * n_row / n_col)))
    for position, batch in enumerate(ds.take(n_images), start=1):
        plt.subplot(n_row, n_col, position)
        # Undo the [-1, 1] normalization for display.
        plt.imshow(batch[0] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# Both domains are loaded with data augmentation enabled.
monet_ds = load_dataset(MONET_FILENAMES, use_augmentation=True)
photo_ds = load_dataset(PHOTO_FILENAMES, use_augmentation=True)

In [None]:
display_samples(photo_ds, n_row=4, n_col=7)  # grid of sample photos

In [None]:
display_samples(monet_ds, n_row=4, n_col=7)  # grid of sample Monet paintings

## We create a common function to plot the history of the trained models

In [None]:
def plot_history_losses(history, suptitle):
    """Plot every metric recorded in a Keras ``History``, one subplot per metric, 4 per row."""
    metrics = history.history
    n_cols = 4
    n_rows = -(-len(metrics) // n_cols)  # ceiling division
    fig = plt.figure(figsize=(20, 10))
    for position, (metric_name, values) in enumerate(metrics.items(), start=1):
        axis = fig.add_subplot(n_rows, n_cols, position)
        axis.plot(values)
        axis.set_title(metric_name)
    plt.suptitle(suptitle)
    plt.show()


## The following functions are used to create the ResNet structure. These functions are simple modifications of the CycleGAN code from [Keras](https://keras.io/examples/generative/cyclegan/).

## The main difference is that we use antialiased convolutions, based on a modification of the code found [here](https://github.com/adobe/antialiased-cnns/issues/10).

In [None]:

# Initializer for all convolution kernels: N(0, 0.02).
KERNEL_INIT = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Initializer for the InstanceNormalization scale (gamma). It is centered at 1.0, like the
# N(1.0, 0.02) init used for normalization weights in the pix2pix/CycleGAN/CUT PyTorch code.
# With mean 0.0, every normalized activation started out multiplied by a tiny (~0.02)
# factor of random sign.
GAMMA_INIT = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)


def get_filter(filt_size):
    """Return a normalized 2-D binomial blur kernel of shape (filt_size, filt_size).

    The 1-D taps are the binomial coefficients (a row of Pascal's triangle) for sizes
    2-7. Any other size, including 1, yields the 1x1 identity kernel ``[[1.]]``.
    The kernel is the outer product of the taps with themselves, scaled to sum to 1.
    """
    taps = np.array([1., ])
    if filt_size in (2, 3, 4, 5, 6, 7):
        # Convolving with [1, 1] (filt_size - 1) times builds the Pascal's-triangle row.
        for _ in range(int(filt_size) - 1):
            taps = np.convolve(taps, [1., 1.])
    kernel = np.outer(taps, taps)
    return kernel / kernel.sum()


class BlurPool(tf.keras.layers.Layer):
    """Anti-aliased downsampling: reflect-pad, blur with a fixed binomial kernel, subsample.

    Args:
        filt_size: side length of the binomial blur kernel (see ``get_filter``).
        stride: subsampling factor applied to height and width.
    """

    def __init__(self, filt_size=3, stride=2, **kwargs):
        super(BlurPool, self).__init__(**kwargs)
        self.strides = (stride, stride)
        self.filt_size = filt_size

        self.filter = get_filter(filt_size)
        # Total padding of (filt_size - 1) per spatial axis, split floor/ceil between the
        # two sides. The previous implementation always padded 1 per side, which is only
        # right for filt_size == 3 (other sizes produced a different output size than
        # compute_output_shape reported). For the default filt_size=3 this is (1, 1).
        pad_before = (filt_size - 1) // 2
        pad_after = filt_size - 1 - pad_before
        self.pad_sizes = (pad_before, pad_after)

    def compute_output_shape(self, input_shape):
        def _downsampled(size, stride):
            # The padded 'valid' convolution yields ceil(size / stride) positions.
            return None if size is None else -(-size // stride)

        height = _downsampled(input_shape[1], self.strides[0])
        width = _downsampled(input_shape[2], self.strides[1])
        channels = input_shape[3]
        return input_shape[0], height, width, channels

    def call(self, x):
        channels = tf.keras.backend.int_shape(x)[-1]
        # Depthwise kernel (filt_size, filt_size, channels, 1): the same blur per channel.
        filter_ = np.tile(self.filter[:, :, None, None], (1, 1, channels, 1))
        filter_ = tf.keras.backend.constant(filter_, dtype=tf.keras.backend.floatx())
        pad_before, pad_after = self.pad_sizes
        x = tf.pad(
            x,
            [[0, 0], [pad_before, pad_after], [pad_before, pad_after], [0, 0]],
            mode="REFLECT",
        )
        return tf.keras.backend.depthwise_conv2d(x, filter_, strides=self.strides, padding='valid')

    def get_config(self):
        config = super(BlurPool, self).get_config()
        config.update({"filt_size": self.filt_size, "stride": self.strides[0]})
        return config
    

class Upsample(tf.keras.layers.Layer):
    """Anti-aliased upsampling: a transposed convolution with a fixed binomial kernel.

    Args:
        filt_size: side length of the binomial kernel (see ``get_filter``).
        stride: upsampling factor applied to height and width.
    """

    def __init__(self, filt_size=4, stride=2, **kwargs):
        super(Upsample, self).__init__(**kwargs)
        self.filt_size = filt_size
        self.filt_odd = np.mod(filt_size, 2) == 1
        self.pad_size = int((filt_size - 1) / 2)
        self.strides = (stride, stride)
        self.off = int((stride - 1) / 2.)

        # Scaled by stride**2 so the transposed convolution preserves the mean intensity.
        self.filter = get_filter(filt_size=self.filt_size) * (stride ** 2)

    def compute_output_shape(self, input_shape):
        height = input_shape[1] * self.strides[0]
        width = input_shape[2] * self.strides[1]
        channels = input_shape[3]
        # Batch dimension passed through unchanged (was hard-coded to the global BATCH_SIZE).
        return input_shape[0], height, width, channels

    def call(self, x):
        _, height, width, channels = tf.keras.backend.int_shape(x)
        # conv2d_transpose kernel layout is [h, w, out_channels, in_channels]; with
        # out_channels = 1 the op is applied per input channel, as the original intended.
        filter_ = np.tile(self.filter[:, :, None, None], (1, 1, 1, channels))
        filter_ = tf.keras.backend.constant(filter_, dtype=tf.keras.backend.floatx())
        # Read the batch size at run time so any batch size works; the spatial sizes
        # and channels are static, so the static output shape is still inferred.
        output_shape = tf.stack([
            tf.shape(x)[0],
            height * self.strides[0],
            width * self.strides[1],
            channels,
        ])
        return tf.nn.conv2d_transpose(x, filter_, output_shape=output_shape, strides=self.strides)

    def get_config(self):
        config = super(Upsample, self).get_config()
        config.update({"filt_size": self.filt_size, "stride": self.strides[0]})
        return config


class ReflectionPadding2D(tf.keras.layers.Layer):
    """Implements Reflection Padding as a layer.

    Args:
        padding(tuple): Amount of padding for the spatial dimensions, given as
            (width_padding, height_padding); each side of an axis gets that amount.

    Returns:
        A padded tensor with the same type as the input tensor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")

    def get_config(self):
        # Needed so models containing this layer can be saved, loaded and cloned with
        # the same padding (the constructor takes a non-default argument).
        config = super(ReflectionPadding2D, self).get_config()
        config.update({"padding": self.padding})
        return config


def residual_block(
        input_tensor,
        activation,
        kernel_initializer=KERNEL_INIT,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="valid",
        gamma_initializer=GAMMA_INIT,
        use_bias=False,
        res_block_n=None
):
    """Build a stand-alone residual block as a Keras sub-model.

    ``input_tensor`` is only used for its shape: the block gets its own Input layer.
    Body: [reflect-pad -> conv -> InstanceNorm] twice, with ``activation`` after the
    first stage only, then a skip connection adding the block input.

    Returns:
        ``tf.keras.Model`` named ``residual_block_{res_block_n}``.
    """
    channels = input_tensor.shape[-1]
    block_input = layers.Input(input_tensor.shape[1:])

    out = block_input
    for stage in range(2):
        out = ReflectionPadding2D()(out)
        out = layers.Conv2D(
            channels,
            kernel_size,
            strides=strides,
            kernel_initializer=kernel_initializer,
            padding=padding,
            use_bias=use_bias,
        )(out)
        out = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(out)
        if stage == 0:
            out = activation(out)

    out = layers.add([block_input, out])
    return tf.keras.models.Model(block_input, out, name=f'residual_block_{res_block_n}')


def downsample(
        input_tensor,
        filters,
        activation,
        kernel_initializer=KERNEL_INIT,
        kernel_size=(3, 3),
        padding="same",
        gamma_initializer=GAMMA_INIT,
        use_bias=True,
        norm=True
):
    """Anti-aliased downsampling block.

    Stride-1 convolution -> optional InstanceNorm (``norm``) -> optional ``activation``
    -> ``BlurPool`` (blur + stride-2 subsampling).
    """
    # TODO: Add no antialias option
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=(1, 1),
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )
    out = conv(input_tensor)
    if norm:
        out = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(out)
    if activation:
        out = activation(out)
    return BlurPool()(out)


def upsample(
        input_tensor,
        filters,
        activation,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="same",
        kernel_initializer=KERNEL_INIT,
        gamma_initializer=GAMMA_INIT,
        use_bias=True
):
    """Anti-aliased upsampling block.

    ``Upsample`` (fixed-kernel 2x transposed conv) -> learned ``Conv2DTranspose`` with
    ``strides`` -> InstanceNorm -> optional ``activation``.
    """
    upsampled = Upsample()(input_tensor)
    transposed = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(upsampled)
    normalized = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(transposed)
    return activation(normalized) if activation else normalized

## The generator is composed of two sub-models (plus the input layer): the Encoder and the Decoder. The split matters because the features for the NCE loss are computed only from the Encoder.

In [None]:
def get_resnet_encoder(filters=64, num_downsampling_blocks=2, num_residual_blocks=4,
                       gamma_initializer=GAMMA_INIT, name='Encoder'):
    """Build the ResNet encoder.

    Layout: 7x7 conv stem (reflect-padded) + InstanceNorm + ReLU, then
    ``num_downsampling_blocks`` anti-aliased downsampling blocks (doubling the width each
    time), then ``num_residual_blocks`` residual blocks. The layer order matters:
    ``get_features`` selects NCE feature layers by index.
    """
    img_input = layers.Input(shape=IMAGE_SIZE, name=name + "_img_input")

    # Stem
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=KERNEL_INIT, use_bias=False)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling: widths filters*2, filters*4, ...
    for block_idx in range(num_downsampling_blocks):
        x = downsample(x, filters=filters * 2 ** (block_idx + 1), activation=layers.Activation("relu"))

    # Residual blocks (each one is a single sub-model layer)
    for block_idx in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"), res_block_n=block_idx)(x)

    return tf.keras.models.Model(img_input, x, name=name)


def get_resnet_decoder(
        input_shape,
        filters=64,
        num_upsample_blocks=2,
        name='Decoder',
):
    """Build the ResNet decoder.

    ``filters`` is the width *before* the first halving: upsampling block k (1-based)
    has ``filters // 2**k`` channels. A reflect-padded 7x7 conv to 3 channels and a tanh
    produce the output image.
    """
    img_input = layers.Input(shape=input_shape, name=name + "_img_input")

    x = img_input
    for block_idx in range(1, num_upsample_blocks + 1):
        x = upsample(x, filters // 2 ** block_idx, activation=layers.Activation("relu"))

    # Output head
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    output = layers.Activation("tanh")(x)

    return tf.keras.models.Model(img_input, output, name=name)


def get_generator(num_downsampling_blocks=2, num_residual_blocks=4, num_upsample_blocks=2):
    """Build the generator as Input -> Encoder sub-model -> Decoder sub-model.

    ``CUT`` relies on this layout: ``generator.layers[:2]`` are the input layer and
    the encoder.
    """
    encoder = get_resnet_encoder(num_downsampling_blocks=num_downsampling_blocks, num_residual_blocks=num_residual_blocks)
    # get_resnet_decoder halves `filters` before each upsampling block, so it must start
    # from the encoder's output width (256 by default -> 128 -> 64 channels). Relying on
    # the decoder's own default of 64 gave 32 -> 16 channels, a far smaller decoder.
    encoder_channels = encoder.output.shape[-1]
    decoder = get_resnet_decoder(
        encoder.output.shape[1:],
        filters=encoder_channels,
        num_upsample_blocks=num_upsample_blocks,
    )

    img_input = layers.Input(shape=IMAGE_SIZE, name="generator_img_input")
    x = img_input
    x = encoder(x)
    x = decoder(x)
    return tf.keras.models.Model(img_input, x, name='Generator')

## The discriminator is the same as the CycleGAN one in the [keras repository](https://keras.io/examples/generative/cyclegan/), except that it uses the antialiased downsampling blocks defined above.

In [None]:
def get_discriminator(filters=64, kernel_initializer=KERNEL_INIT, num_downsampling=3):
    """Build the discriminator: a fully-convolutional net whose output is a one-channel
    spatial map of real/fake scores.

    Args:
        filters: width of the first block; doubled by each following downsampling block.
        kernel_initializer: initializer for *every* convolution kernel. Previously it
            only reached the last convolution; the downsampling blocks and the
            penultimate convolution silently used KERNEL_INIT.
        num_downsampling: number of anti-aliased downsampling blocks.
    """
    img_input = layers.Input(shape=IMAGE_SIZE, name="discriminator_img_input")
    # First block has no normalization.
    x = downsample(
        img_input,
        filters=filters,
        activation=layers.LeakyReLU(0.2),
        kernel_initializer=kernel_initializer,
        kernel_size=(4, 4),
        norm=False
    )

    num_filters = filters
    for _ in range(1, num_downsampling):
        num_filters *= 2
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_initializer=kernel_initializer,
            kernel_size=(4, 4),
        )

    x = layers.Conv2D(num_filters * 2, (4, 4), strides=(1, 1), kernel_initializer=kernel_initializer, padding='same', use_bias=True)(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=GAMMA_INIT)(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer)(x)

    return tf.keras.models.Model(inputs=img_input, outputs=x, name='Discriminator')

## To extract the features of the encoder model, we pass the image through it layer by layer and save the selected outputs in a list that is returned. The feature layers are the outputs of the first padding layer, the convolution layers of the two downsampling blocks, and the first and fifth ResNet blocks.

In [None]:
# Indices, counted over encoder.layers[1:] (i.e. after the input layer), whose outputs
# are used as PatchNCE features.
feature_layers = [0, 4, 8, 12, 16]
def get_features(encoder, x):
    """Run ``x`` through ``encoder`` one layer at a time, collecting the outputs of the
    layers listed in ``feature_layers`` (in layer order)."""
    activations = encoder.layers[0](x)  # the input layer
    selected = []
    for layer_idx, encoder_layer in enumerate(encoder.layers[1:]):
        activations = encoder_layer(activations)
        if layer_idx in feature_layers:
            selected.append(activations)
    return selected


def normalize(x):
    """L2-normalize ``x`` along axis 1; 1e-7 is added to the norm to avoid division by zero."""
    squared_sum = tf.reduce_sum(tf.math.pow(x, 2), axis=1, keepdims=True)
    l2_norm = tf.pow(squared_sum, 1 / 2)
    return tf.divide(x, l2_norm + 1e-7)


## There is one MLP created for each feature that we will extract.

In [None]:
class MLP:
    """Container of per-feature-layer projection heads (Dense -> ReLU -> Dense).

    Heads are created lazily on the first ``forward`` call with ``use_mlp=True`` (one per
    feature map) and stored as attributes ``mlp_0``, ``mlp_1``, ...

    Args:
        dimension: output width of every head.
        num_patches: number of spatial positions sampled per feature map.
    """

    def __init__(self, dimension=256, num_patches=64):
        self.dim = dimension
        self.num_patches = num_patches
        self.n_mlps = 0

    def create_mlp(self, feats):
        """Create one projection head per feature map in ``feats`` (input width = channels)."""
        for mlp_id, feat in enumerate(feats):
            mlp = tf.keras.Sequential([tf.keras.layers.Dense(self.dim, input_dim=feat.shape[-1]),
                                       tf.keras.layers.ReLU(),
                                       tf.keras.layers.Dense(self.dim)])
            setattr(self, f'mlp_{mlp_id}', mlp)
        # len() rather than the loop variable, which is undefined for an empty list.
        self.n_mlps = len(feats)

    def forward(self, features, patch_ids=None, use_mlp=True):
        """
        Forward the features through their corresponding Multi Layer Perceptron.
        If the Patch IDs are not provided it means that it is the first time being used with this Batch. What we do
        then is randomly select "num_patches" of patches to process and return the IDs so that in the second run we
        execute the same patches.
        Args:
            features: Features extracted from the encoder. Shape must be [Batch size, Height, Width, Channels]
            patch_ids: IDs of the patches to execute. If None, they will be randomly chosen.
            use_mlp: if False the sampled features are only L2-normalized, not projected.

        Returns:
            Processed features (one (n_samples, dim) tensor per feature map) and the patch ids.
        """
        if self.n_mlps == 0 and use_mlp:
            self.create_mlp(features)
        return_ids, return_feats = [], []
        for feat_id, feat in enumerate(features):
            _, H, W, C = feat.shape
            # -1 for the batch so a dynamic (None) batch dimension also works.
            feat_reshape = tf.reshape(feat, (-1, H * W, C))

            if patch_ids is None:
                patch_id = tf.random.shuffle(tf.range(H * W))[:self.num_patches]
            else:
                patch_id = patch_ids[feat_id]

            # The same patch positions are taken from every image of the batch.
            x_sample = tf.reshape(tf.gather(feat_reshape, patch_id, axis=1), (-1, C))
            if use_mlp:
                # Look the head up only when it is used: with use_mlp=False the heads are
                # never created, and the unconditional getattr raised AttributeError.
                x_sample = getattr(self, f'mlp_{feat_id}')(x_sample)
            x_sample = normalize(x_sample)
            return_ids.append(patch_id)
            return_feats.append(x_sample)

        return return_feats, return_ids

## The losses. Simple translation to TF from the original [CUT](https://github.com/taesungp/contrastive-unpaired-translation/blob/master/models/patchnce.py) code.

In [None]:
def patch_nce_loss(feat_src, feat_tgt, nce_T):
    """PatchNCE (InfoNCE over patches) loss.

    Row i of ``feat_src`` is a query. Row i of ``feat_tgt`` is its positive key, and every
    other row of ``feat_tgt`` is a negative.

    Args:
        feat_src: (n_patches, dim) query features, expected L2-normalized.
        feat_tgt: (n_patches, dim) key features in the same patch order.
        nce_T: softmax temperature.

    Returns:
        (n_patches,) per-patch cross-entropy (no reduction).
    """
    n_patches = tf.shape(feat_src)[0]

    # Positive logits <q_i, k_i>: (n_patches, 1).
    l_pos = tf.reduce_sum(feat_src * feat_tgt, axis=1, keepdims=True)

    # Negative logits <q_i, k_j> for all j: (n_patches, n_patches).
    l_neg = tf.matmul(feat_src, feat_tgt, transpose_b=True)

    # The diagonal repeats the positive pair. The reference code overwrites it with -10
    # (masked_fill_) before dividing by T, so it effectively drops out of the softmax.
    # Multiplying it by ~0 instead (previous version) left a spurious logit of 0 in every
    # row, which is significant next to logits in [-1/T, 1/T].
    l_neg = tf.linalg.set_diag(l_neg, tf.fill([n_patches], tf.constant(-10.0, dtype=l_neg.dtype)))

    logits = tf.concat((l_pos, l_neg), axis=1) / nce_T
    # The positive is always in column 0.
    labels = tf.zeros([n_patches], dtype=tf.int32)
    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)


def gan_loss(y, target_is_real, target_real_label=1.0, target_fake_label=0.0):
    """Least-squares GAN loss: MSE between the discriminator output ``y`` and the
    real (``target_is_real`` truthy) or fake target label."""
    label = target_real_label if target_is_real else target_fake_label
    return tf.keras.losses.MSE(y, label)


## Definition of the CUT/FastCUT model.

In [None]:
class CUT(tf.keras.Model):
    """CUT / FastCUT model: generator + discriminator + PatchNCE projection heads.

    ``generator.layers[0]`` must be its input layer and ``generator.layers[1]`` its encoder
    (as built by ``get_generator``); the encoder's intermediate outputs feed the NCE loss.

    Args:
        generator, discriminator: Keras models.
        mlp: ``MLP`` container of projection heads.
        use_mlp, lambda_NCE, lambda_GAN, nce_T: loss configuration.
        fast: FastCUT when True (no identity NCE term).
        use_defaults: when True (default) use_mlp, lambda_NCE, lambda_GAN and nce_T are
            replaced by built-in defaults (lambda_NCE = 10 for FastCUT, 1 otherwise).
    """

    def __init__(
            self,
            generator,
            discriminator,
            mlp,
            use_mlp=True,
            lambda_NCE=1.0,
            lambda_GAN=1.0,
            fast=False,
            nce_T=0.07,
            use_defaults=True
    ):
        super(CUT, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.input_layer, self.encoder = generator.layers[:2]
        self.mlp = mlp
        self.fast = fast
        self.use_mlp = True if use_defaults else use_mlp
        self.lambda_NCE = (10.0 if fast else 1.0) if use_defaults else lambda_NCE
        self.lambda_GAN = 1.0 if use_defaults else lambda_GAN
        self.nce_T = 0.07 if use_defaults else nce_T

    def calculate_NCE_loss(self, src, tgt):
        """PatchNCE loss between an input batch ``src`` and its translation ``tgt``.

        Patches of the translation are the queries. The input patches at the same
        positions are the positive keys, and the other sampled input patches are
        negatives (the query/key roles of the reference CUT implementation).

        Returns:
            (mean of the per-layer losses, list of per-layer losses), both already
            scaled by ``lambda_NCE``.
        """
        encoded_features_src = get_features(self.encoder, self.input_layer(src))
        encoded_features_tgt = get_features(self.encoder, self.input_layer(tgt))

        # Patch ids are sampled on the source features and reused for the target, so
        # query i and key i come from the same spatial location. Both calls honour
        # self.use_mlp (the source call used to ignore it).
        mlp_features_src, feat_ids = self.mlp.forward(encoded_features_src, use_mlp=self.use_mlp)
        mlp_features_tgt, _ = self.mlp.forward(encoded_features_tgt, feat_ids, self.use_mlp)

        losses = []
        for feat_src, feat_tgt in zip(mlp_features_src, mlp_features_tgt):
            # Keys are detached (as PatchNCELoss does with feat_k in the reference), so the
            # loss pulls the translation's features towards the input's and not vice versa.
            nce_loss = self.patch_nce_loss_fn(feat_tgt, tf.stop_gradient(feat_src), self.nce_T)
            losses.append(tf.reduce_mean(nce_loss) * self.lambda_NCE)
        return tf.add_n(losses) / len(encoded_features_tgt), losses

    def compile(
            self,
            generator_optimizer,
            discriminator_optimizer,
            mlp_optimizer,
            discriminator_loss_fn,
            patch_nce_loss_fn
    ):
        super(CUT, self).compile()
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.mlp_optimizer = mlp_optimizer
        self.discriminator_loss_fn = discriminator_loss_fn
        self.patch_nce_loss_fn = patch_nce_loss_fn

    def _projection_head_variables(self):
        """Trainable variables of every projection head created so far (empty if not use_mlp)."""
        if not self.use_mlp:
            return []
        return [
            variable
            for mlp_id in range(self.mlp.n_mlps)
            for variable in getattr(self.mlp, f'mlp_{mlp_id}').trainable_variables
        ]

    def _forward_and_losses(self, source, target):
        """Run G and D on one (source, target) batch and return every loss.

        The keys of the returned dict are the metric names logged by ``fit``.
        """
        zero = tf.constant(0.0)

        # Source (photo) -> fake target (Monet), and D's scores for fake and real.
        fake_target = self.generator(source, training=True)
        pred_fake = self.discriminator(fake_target, training=True)
        pred_real = self.discriminator(target, training=True)

        # Generator adversarial term: G(source) should be scored as real.
        loss_G_GAN = zero
        if self.lambda_GAN > 0:
            loss_G_GAN = tf.reduce_mean(self.discriminator_loss_fn(pred_fake, True)) * self.lambda_GAN

        # Contrastive terms. CUT also adds an identity term (target -> G(target));
        # FastCUT does not. Explicit branches replace the previous chained tuple
        # unpacking, which crashed when lambda_NCE == 0.
        loss_nce = loss_nce_identity = total_loss_nce = zero
        if self.lambda_NCE > 0:
            loss_nce, _ = self.calculate_NCE_loss(source, fake_target)
            total_loss_nce = loss_nce
            if not self.fast:
                fake_identity = self.generator(target, training=True)
                loss_nce_identity, _ = self.calculate_NCE_loss(target, fake_identity)
                total_loss_nce = (loss_nce + loss_nce_identity) / 2

        loss_G = loss_G_GAN + total_loss_nce

        # Discriminator: fake scores pushed to the fake label, real scores to the real label.
        loss_D_fake = tf.reduce_mean(self.discriminator_loss_fn(pred_fake, False))
        loss_D_real = tf.reduce_mean(self.discriminator_loss_fn(pred_real, True))
        loss_D = (loss_D_fake + loss_D_real) / 2

        return {
            "G_loss": loss_G,
            "D_loss": loss_D,
            "loss_G_GAN": loss_G_GAN,
            "total_loss_nce": total_loss_nce,
            "loss_nce": loss_nce,
            "loss_nce_idt": loss_nce_identity,
            'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real
        }

    def data_dependent_initialize(self, source_ds, target_ds):
        """
        The feature network MLP is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of the Generator. Because of this, the weights of MLP are
        initialized at the first feedforward pass with some input images.

        Only a forward pass on one batch is run here; unlike the previous version, no
        optimizer step is applied to the MLP before training starts.
        """
        for source, target in zip(source_ds.take(1), target_ds.take(1)):
            self._forward_and_losses(source, target)

    def train_step(self, batch_data):
        """One simultaneous update of G, D and the projection heads; returns the loss logs."""
        # source is PHOTO and target is MONET
        source, target = batch_data

        with tf.GradientTape(persistent=True) as tape:
            logs = self._forward_and_losses(source, target)

        generator_vars = self.generator.trainable_variables
        discriminator_vars = self.discriminator.trainable_variables
        generator_gradients = tape.gradient(logs["G_loss"], generator_vars)
        discriminator_gradients = tape.gradient(logs["D_loss"], discriminator_vars)

        # The heads are trained on the full generator loss, so they also receive the
        # identity NCE term (previously only the source->fake per-layer losses were used).
        mlp_vars = self._projection_head_variables()
        mlp_gradients = tape.gradient(logs["G_loss"], mlp_vars) if mlp_vars else []
        del tape  # release the persistent tape's resources

        self.generator_optimizer.apply_gradients(zip(generator_gradients, generator_vars))
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator_vars))
        if mlp_vars:
            self.mlp_optimizer.apply_gradients(zip(mlp_gradients, mlp_vars))

        return logs

    def call(self, source):
        return self.generator(source, training=False)


## Callback to plot images every N epochs

In [None]:
class GANMonitor(tf.keras.callbacks.Callback):
    """Plot source, fake target, target and identity images every ``every_n_epoch`` epochs.

    Images are taken from the global ``photo_ds`` (source) and ``monet_ds`` (target)
    datasets and translated with the generator of the model being trained.
    """

    def __init__(self, num_img=2, every_n_epoch=5):
        # Initialize the base Callback state that Keras expects on every callback.
        super(GANMonitor, self).__init__()
        self.num_img = num_img
        self.every_n_epoch = every_n_epoch

    @staticmethod
    def _to_uint8(image):
        """Map an image from [-1, 1] to uint8 [0, 255]."""
        return (np.asarray(image) * 127.5 + 127.5).astype(np.uint8)

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.every_n_epoch != 0:
            return
        # Use the model attached by fit() rather than the global `cut_mlp`, so the
        # callback works with any CUT instance.
        generator = self.model.generator
        # squeeze=False keeps `ax` two-dimensional even when num_img == 1.
        _, ax = plt.subplots(self.num_img, 4, figsize=(20, 10), squeeze=False)
        for col, title in enumerate(["Source", "Fake target", "Target", "Identity target"]):
            ax[0, col].set_title(title)
        for row, (source, target) in enumerate(zip(photo_ds.take(self.num_img), monet_ds.take(self.num_img))):
            panels = [
                self._to_uint8(source[0]),
                self._to_uint8(generator(source)[0]),
                self._to_uint8(target[0]),
                self._to_uint8(generator(target)[0]),
            ]
            for col, panel in enumerate(panels):
                ax[row, col].imshow(panel)
                ax[row, col].axis("off")

        plt.show()
        plt.close()

In [None]:
# NOTE: this value is not passed to get_model/CUT anywhere in this notebook; CUT uses its
# own nce_T (0.07 when use_defaults=True).
nce_T = 0.07
plotter = GANMonitor(num_img=2, every_n_epoch=5)  # plot samples every 5 epochs

def get_model(num_residual_blocks, use_mlp, fast):
    """Build and compile a CUT model (FastCUT when ``fast`` is True).

    All three optimizers are Adam(lr=2e-4, beta_1=0.5); the adversarial loss is
    ``gan_loss`` and the contrastive loss is ``patch_nce_loss``.
    """
    model = CUT(
        generator=get_generator(num_residual_blocks=num_residual_blocks),
        discriminator=get_discriminator(),
        mlp=MLP(),
        use_mlp=use_mlp,
        fast=fast,
    )

    def make_adam():
        return tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

    model.compile(
        generator_optimizer=make_adam(),
        discriminator_optimizer=make_adam(),
        mlp_optimizer=make_adam(),
        discriminator_loss_fn=gan_loss,
        patch_nce_loss_fn=patch_nce_loss,
    )
    return model

In [None]:
cut_mlp = get_model(num_residual_blocks=6, use_mlp=True, fast=False)
# Build the MLP projection heads with one real batch before training starts.
cut_mlp.data_dependent_initialize(source_ds=photo_ds, target_ds=monet_ds)


In [None]:
# List the encoder layers whose outputs are used as PatchNCE features.
encoder_layers = cut_mlp.generator.layers[1].layers[1:]
for layer in (lyr for idx, lyr in enumerate(encoder_layers) if idx in feature_layers):
    print(f'Feature layer {layer}')

In [None]:
# Pairs of (photo batch, Monet batch); one epoch ends when the shorter dataset runs out.
paired_ds = tf.data.Dataset.zip((photo_ds, monet_ds))
history_cut_mlp = cut_mlp.fit(paired_ds, epochs=50, callbacks=[plotter])

In [None]:
n_images = 2
column_titles = ["Source", "Fake target", "Target", "Identity target"]
_, ax = plt.subplots(n_images, 4, figsize=(20, 10))
for col, title in enumerate(column_titles):
    ax[0, col].set_title(title)

for row, (source, target) in enumerate(zip(photo_ds.take(n_images), monet_ds.take(n_images))):
    # Generated images are mapped from [-1, 1] back to uint8 [0, 255] for display.
    prediction_ = cut_mlp(source)[0].numpy()
    panels = [
        (source[0] * 127.5 + 127.5).numpy().astype(np.uint8),
        (prediction_ * 127.5 + 127.5).astype(np.uint8),
        (target[0] * 127.5 + 127.5).numpy().astype(np.uint8),
        (cut_mlp(target)[0].numpy() * 127.5 + 127.5).astype(np.uint8),
    ]
    for col, panel in enumerate(panels):
        ax[row, col].imshow(panel)
        ax[row, col].axis("off")

plt.show()
plt.close()

In [None]:
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded

# Translate the photos and save the results as ../images/1.jpg, 2.jpg, ...
# Use an un-augmented copy of the photo dataset: `photo_ds` applies random flips and
# contrast/brightness jitter, which should not leak into the submitted images.
# NOTE(review): load_dataset drops an incomplete final batch, so with BATCH_SIZE > 1 a
# few photos may be skipped.
inference_ds = load_dataset(PHOTO_FILENAMES, False)
i = 1
for img in inference_ds:
    # Save every image in the batch, not only the first one.
    for prediction in cut_mlp(img).numpy():
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
        PIL.Image.fromarray(prediction).save("../images/" + str(i) + ".jpg")
        i += 1

In [None]:
import shutil

# Bundle every generated image in /kaggle/images into /kaggle/working/images.zip.
shutil.make_archive(base_name="/kaggle/working/images", format='zip', root_dir="/kaggle/images")
