# Contrastive Unpaired Translation
This notebook uses contrastive unpaired translation (CUT) to translate photos into Monet-style paintings.
More information about CUT can be found [here](https://taesung.me/ContrastiveUnpairedTranslation/) (videos, GitHub, link to paper, etc.).
The code in this notebook is taken from [this GitHub repository](https://github.com/cryu854/CUT), which contains a TensorFlow 2 implementation of CUT.
The last two cells, which store the generated images, are based on https://www.kaggle.com/tpothjuan/cyclegan-with-gpu-with-explanations (maybe remove this last sentence if the second link generates and stores the final images in exactly the same way).

**Notes:**
* Run this notebook on a GPU.
* Current speeds: (TBA)

**Plans:**
* Early stopping (although this may not make sense, because the losses do not seem to stabilise between epochs)
* Adapting the number of resnet layers
* Choosing the best learning rates
* Data-augmentation: more images, maybe zooming in.

**Implemented:**
* Data-augmentation: horizontally flipped images, removed two circular images.
* Anti-aliasing: part of the code from the TensorFlow 2 implementation that this notebook is based on. It increased the score slightly (by 1 point) but also increased the running time by 30%. It did not remove the checkerboard pattern, even though that was the reason we implemented it.
* Validation data: assessing the performance of the model on unseen data shows how well it generalises.

In [None]:
# Maximum line-length (79) including the sharp and space for PEP-8:
# 34567890123456789012345678901234567890123456789012345678901234567890123456789
# Remember that the maximum line length for comments is 72.

#### Customisation parameters #### Set these to what you prefer!

# Which Monet dataset to train on: True selects the augmented
# 'monet-jpg-improved' dataset, False the original 'gan-getting-started'.
augmented_input = True

# Anti-aliased sampling: True to enable it, False to disable it.
anti_alias = True

# Upper bound on the number of training epochs.
max_epochs = 15

# Run-time threshold in seconds (4 hours); once exceeded, no new epoch
# is started.
max_train_time = 4 * 60 * 60


In [None]:
#### Imports ####
# Kaggle:
from kaggle_datasets import KaggleDatasets

# Basic:
import time
import matplotlib.pyplot as plt
import numpy as np

# Tensorflow:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Lambda, Layer, Conv2D,
    Conv2DTranspose, BatchNormalization, Activation)
# I think these can be removed:
# from tensorflow import keras
# from tensorflow.keras import layers

# Outputting the generated images:
import os
import shutil
from PIL import Image

# Constants:
AUTOTUNE = tf.data.experimental.AUTOTUNE


# Loading the data

We use the JPG images for training. This only works in combination with a GPU, unless the code is changed substantially.


In [None]:
@tf.function
def load_image(image_file, image_size=None, data_augmentation=False):
    """ Load an image file as a float32 RGB tensor scaled to [-1, 1].

    Input:
    - image_file: (String tensor) path of the image file
    - image_size: optional (height, width) to resize the image to
    - data_augmentation: if True, randomly flip the image horizontally

    Output:
    - image: (H, W, 3) float32 tensor with values in [-1, 1]
    """
    image = tf.io.read_file(image_file)
    # decode_image detects the file format, so the '*.jpg' files listed by
    # create_dataset are decoded as JPEGs (not via decode_png). channels=3
    # turns grayscale into RGB and drops any alpha channel, so every image
    # has the 3 channels the models expect; expand_animations=False
    # guarantees a rank-3 result.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    # Map uint8 [0, 255] to [-1, 1], matching the generator's tanh output.
    image = (tf.cast(image, tf.float32) / 127.5) - 1.0

    if data_augmentation:
        image = tf.image.random_flip_left_right(image)
    if image_size is not None:
        image = tf.image.resize(image, size=(image_size[0], image_size[1]))

    return image

def create_dataset(src_folder, tar_folder, batch_size, train_fraction=0.8,
                   split_seed=42):
    """ Create the training and validation tf.data.Datasets.

    The '*.jpg' files of each folder are permuted once with a fixed seed;
    the first `train_fraction` of that permutation becomes the training
    part and the rest the validation part.

    Input:
    - src_folder, tar_folder: (String) folder path of training data and
      their targets
    - batch_size: the batch size for the generated datasets
    - train_fraction: fraction of each folder used for training
      (default 0.8, i.e. an 80/20 train/validation split)
    - split_seed: seed of the one-off permutation that decides the split,
      so the same files end up in the validation set on every run

    Output:
    - train_dataset: (tf.data.Dataset) zipped (source, target) batches
      of the training files, reshuffled every epoch
    - val_dataset: (tf.data.Dataset) zipped (source, target) batches of
      the validation files, in a fixed order
    Note: Dataset.zip stops at the shorter of the two datasets.
    """

    def _split_files(folder):
        """List the folder's JPGs and split them into train/val paths."""
        files = tf.data.Dataset.list_files(folder + '/*.jpg', shuffle=False)
        n_files = int(files.cardinality().numpy())
        n_train = int(train_fraction * n_files)
        # One fixed, seeded permutation. take() and skip() below both read
        # this same order, so the two parts are disjoint and reproducible.
        files = files.shuffle(n_files, seed=split_seed,
                              reshuffle_each_iteration=False)
        # The training part is additionally reshuffled every epoch; the
        # buffer must be > 0 even when the training part is empty.
        train_files = files.take(n_train).shuffle(
            max(n_train, 1), reshuffle_each_iteration=True)
        val_files = files.skip(n_train)
        return train_files, val_files

    def _to_batches(files):
        """Decode, batch and prefetch a dataset of image paths."""
        return (
            files.map(load_image, num_parallel_calls=AUTOTUNE)
            .batch(batch_size, drop_remainder=True)
            .prefetch(AUTOTUNE)
        )

    train_src, val_src = _split_files(src_folder)
    train_tar, val_tar = _split_files(tar_folder)

    train_dataset = tf.data.Dataset.zip(
        (_to_batches(train_src), _to_batches(train_tar)))
    val_dataset = tf.data.Dataset.zip(
        (_to_batches(val_src), _to_batches(val_tar)))

    return train_dataset, val_dataset


In [None]:
# Input folders on Google Cloud Storage. The photos always come from the
# competition dataset; the Monet paintings come either from the augmented
# 'monet-jpg-improved' dataset or from the competition dataset itself.
GCS_PATH_KAGGLE = KaggleDatasets().get_gcs_path('gan-getting-started')
train_src_folder = f'{GCS_PATH_KAGGLE}/photo_jpg'

if not augmented_input:
    train_tar_folder = f'{GCS_PATH_KAGGLE}/monet_jpg'
else:
    GCS_PATH_AUGMENTED = KaggleDatasets().get_gcs_path('monet-jpg-improved')
    train_tar_folder = f'{GCS_PATH_AUGMENTED}/monet_jpg_improved'

# PatchNCELoss / PatchSampleMLP only support a batch size of 1.
batch_size = 1
train_dataset, val_dataset = create_dataset(
    train_src_folder, train_tar_folder, batch_size)

# Peek at one batch to learn the per-image shapes (batch axis dropped).
source_image, target_image = next(iter(train_dataset))
source_shape, target_shape = source_image.shape[1:], target_image.shape[1:]


# CUT Model
**Loss Functions**

In [None]:
class GANLoss:
    """Adversarial loss for the generator and the discriminator.

    Only the least-squares GAN objective ('lsgan') is supported: the
    discriminator output is regressed with a mean squared error towards
    ones (real) or zeros (fake).
    """

    def __init__(self, gan_mode):
        self.gan_mode = gan_mode
        if gan_mode != 'lsgan':
            raise NotImplementedError(f'gan mode {gan_mode} not implemented.')
        self.loss = tf.keras.losses.MeanSquaredError()

    def __call__(self, prediction, target_is_real):
        """Return the loss of `prediction` against an all-real/all-fake
        label map of the same shape."""
        if self.gan_mode == 'lsgan':
            if target_is_real:
                labels = tf.ones_like(prediction)
            else:
                labels = tf.zeros_like(prediction)
            loss = self.loss(labels, prediction)
        return loss


class PatchNCELoss:
    """Patch-wise contrastive (InfoNCE) loss used by CUT.

    For every encoder layer, patch i of the source features is the
    positive for patch i of the target features; all other sampled
    patches act as negatives. The per-layer losses are averaged.
    """

    def __init__(self, nce_temp, nce_lambda):
        # Potential: only supports batch_size=1 for now (see the
        # set_shape([1, ...]) in PatchSampleMLP.call).
        self.nce_temp = nce_temp
        self.nce_lambda = nce_lambda
        self.cross_entropy_loss = tf.keras.losses.CategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

    def __call__(self, source, target, netE, netF):
        """Return the mean PatchNCE loss between `source` and `target`.

        netE maps an image batch to a list of feature maps; netF samples
        and embeds patches, returning (embeddings, patch_ids) per layer.
        NOTE(review): the official PyTorch CUT detaches the source/key
        features; this port does not. Confirm that this is intended.
        """
        enc_src = netE(source, training=True)
        enc_tgt = netE(target, training=True)

        # Sample patch locations on the source features and reuse exactly
        # those locations on the target, so row i of both pools refers to
        # the same spatial position.
        pooled_src, patch_ids = netF(enc_src, patch_ids=None, training=True)
        pooled_tgt, _ = netF(enc_tgt, patch_ids=patch_ids, training=True)

        total_nce_loss = 0.0
        for emb_src, emb_tgt in zip(pooled_src, pooled_tgt):
            n_patches, _ = emb_src.shape
            # Similarity of every source patch with every target patch.
            logits = tf.matmul(emb_src, tf.transpose(emb_tgt)) / self.nce_temp
            # One-hot rows: the diagonal holds the positive pairs.
            labels = tf.eye(n_patches)
            per_patch = self.cross_entropy_loss(labels, logits)
            total_nce_loss += tf.reduce_mean(per_patch * self.nce_lambda)

        return total_nce_loss / len(pooled_src)


**Layers**

Up- and Downsampling 

In [None]:
def _setup_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    assert k.ndim == 2
    assert k.shape[0] == k.shape[1]
    return k

def _shape(tf_expr, dim_idx):
    if tf_expr.shape.rank is not None:
        dim = tf_expr.shape[dim_idx]
        if dim is not None:
            return dim
    return tf.shape(tf_expr)[dim_idx]


In [None]:
def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Slow reference implementation of `upfirdn_2d()` using standard
    TensorFlow ops.

    Pipeline: upsample by inserting zeros, pad (or crop, for negative
    pads), convolve every channel with the 2-D FIR filter `k`, then
    downsample by keeping every `downy`-th row and `downx`-th column.

    Args:
        x: 4-D tensor indexed as [N, H, W, C]. H and W must be known
           statically (they are used in Python arithmetic and reshapes).
        k: 2-D filter of shape [kernelH, kernelW].
        upx, upy: integer upsampling factors along width / height.
        downx, downy: integer downsampling factors along width / height.
        padx0, padx1, pady0, pady1: integer padding added before/after
           the upsampled image along width / height; negative values crop.

    Returns:
        Tensor [N, H', W', C] with
        H' = ceil((H*upy + pady0 + pady1 - kernelH + 1) / downy) and
        W' = ceil((W*upx + padx0 + padx1 - kernelW + 1) / downx).
    """

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    assert x.shape.rank == 4
    inH = x.shape[1]
    inW = x.shape[2]
    minorDim = _shape(x, 3)
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    # Upsample (insert zeros).
    # Give every pixel its own 1x1 cell, pad the cell to upy x upx with
    # zeros, then merge the cells back into an (H*upy, W*upx) image.
    x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
    x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

    # Pad (crop if negative).
    # Positive pads are applied with tf.pad; negative pads are applied as
    # a slice that removes that many rows/columns.
    x = tf.pad(x, [[0, 0],
                   [max(pady0, 0), max(pady1, 0)],
                   [max(padx0, 0), max(padx1, 0)],
                   [0, 0]]
              )
    x = x[:,
          max(-pady0, 0) : x.shape[1] - max(-pady1, 0),
          max(-padx0, 0) : x.shape[2] - max(-padx1, 0),
          :
         ]

    # Convolve with filter.
    # Move channels next to the batch axis and fold them into it, so one
    # single-channel filter is applied to every channel independently.
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1,
                       1,
                       inH * upy + pady0 + pady1,
                       inW * upx + padx0 + padx1
                      ]
                  )
    # The kernel is flipped because tf.nn.conv2d computes a
    # cross-correlation; flipping turns it into a true convolution.
    # NOTE(review): NCHW conv2d may be unsupported on CPU in some TF
    # versions; this notebook is meant to run on a GPU.
    w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID',
                     data_format='NCHW')
    # Un-fold channels and return to [N, H, W, C].
    x = tf.reshape(x, [-1,
                       minorDim,
                       inH * upy + pady0 + pady1 - kernelH + 1,
                       inW * upx + padx0 + padx1 - kernelW + 1
                      ]
                  )
    x = tf.transpose(x, [0, 2, 3, 1])

    # Downsample (throw away pixels).
    return x[:, ::downy, ::downx, :]


In [None]:
def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW',
                       impl='ref'):
    assert data_format in ['NCHW', 'NHWC']
    assert x.shape.rank == 4
    y = x
    if data_format == 'NCHW':
        y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
        #######MODIFIED#########
    if impl== "ref":
        y = _upfirdn_2d_ref(y, k, upx=up, upy=up, downx=down, downy=down,
                            padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1)
    else: 
         raise NotImplementedError(f'implentation {impl} not implemented.')
    if data_format == 'NCHW':
        y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
    return y


In [None]:
def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='ref'):
    r"""Upsample a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or
    `[N, H, W, C]` and upsamples each image with the given filter.
    The filter is normalized so that if the input pixels are constant,
    they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is
    padded with zeros so that its shape is a multiple of the upsampling
    factor.
    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]`
           (separable). The default is `[1] * factor`, which
           corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).
        data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl: Name of the implementation to use. Can ONLY be `"ref"`
              (the default; any other value raises NotImplementedError).
    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    # Zero-insertion keeps only 1 of every factor**2 pixels, so the
    # normalised kernel is scaled by factor**2 to preserve brightness.
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = k.shape[0] - factor
    # pad0 + pad1 == kernel_size + factor - 2, so the 'VALID' convolution
    # yields exactly H * factor (and W * factor) outputs.
    return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1,
                              pad1=p//2, data_format=data_format, impl=impl)


In [None]:
def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='ref'):
    r"""Downsample a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or
    `[N, H, W, C]` and downsamples each image with the given filter.
    The filter is normalized so that if the input pixels are constant,
    they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter
    is padded with zeros so that its shape is a multiple of the
    downsampling factor.
    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]`
           (separable). The default is `[1] * factor`, which corresponds
           to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1).
        data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl: Name of the implementation to use. Can ONLY be `"ref"`
              (the default; any other value raises NotImplementedError).
    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = k.shape[0] - factor
    # Pad by (kernel_size - factor) in total, split as evenly as possible
    # between the two sides, before filtering and striding.
    return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2,
                              data_format=data_format, impl=impl)


Other Layers

In [None]:
class Padding2D(Layer):
    """ 2D padding layer.

    Pads axes 1 and 2 of a rank-4 input with `tf.pad`.

    Args:
        padding: an int used for both axes, or a (width, height) pair.
            The width amount is added on both sides of axis 2, the
            height amount on both sides of axis 1.
        pad_type: 'constant', 'reflect' or 'symmetric' (the `mode` of
            `tf.pad`).
    """
    def __init__(self, padding=(1, 1), pad_type='constant', **kwargs):
        assert pad_type in ['constant', 'reflect', 'symmetric']
        super(Padding2D, self).__init__(**kwargs)
        self.padding = ((padding, padding) if type(padding) is int
                        else tuple(padding))
        self.pad_type = pad_type

    def call(self, inputs, training=None):
        pad_w, pad_h = self.padding
        paddings = [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]]
        return tf.pad(inputs, paddings, mode=self.pad_type)


class InstanceNorm(Layer):
    """ Instance Normalization layer
    (https://arxiv.org/abs/1607.08022).

    Normalises each sample and channel over axes 1 and 2:
        y = (x - mean) / sqrt(var + epsilon)
    With `affine=True`, a learnable per-channel scale `gamma`
    (initialised from N(0, 0.02)) and shift `beta` (initialised to 0)
    are applied afterwards.
    """
    def __init__(self, epsilon=1e-5, affine=False, **kwargs):
        super(InstanceNorm, self).__init__(**kwargs)
        self.epsilon = epsilon
        self.affine = affine

    def build(self, input_shape):
        if not self.affine:
            return
        n_channels = input_shape[-1]
        self.gamma = self.add_weight(
            name='gamma', shape=(n_channels,),
            initializer=tf.random_normal_initializer(0, 0.02),
            trainable=True)
        self.beta = self.add_weight(
            name='beta', shape=(n_channels,),
            initializer=tf.zeros_initializer(),
            trainable=True)

    def call(self, inputs, training=None):
        mean, var = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.math.sqrt(var + self.epsilon)
        if not self.affine:
            return normalized
        return self.gamma * normalized + self.beta

    
class AntialiasSampling(Layer):
    """ Down/Up sampling layer with blur-kernel.

    Resamples the (NHWC) input with upsample_2d / downsample_2d at their
    default factor of 2, using a binomial blur kernel.

    Args:
        kernel_size: positive integer length of the 1-D binomial kernel
            (row `kernel_size - 1` of Pascal's triangle, e.g.
            4 -> [1, 3, 3, 1]); expanded to 2-D by the sampling helpers.
        mode: 'up' or 'down'; checked when the layer is called.
        impl: implementation name forwarded to the sampling helpers.

    Raises:
        ValueError: if `kernel_size` is not a positive integer, or (when
            called) if `mode` is neither 'up' nor 'down'.
    """
    def __init__(self,
                 kernel_size,
                 mode,
                 impl,
                 **kwargs):
        super(AntialiasSampling, self).__init__(**kwargs)
        # Previously only sizes 1..7 were handled; any other value left
        # self.k undefined and failed later with an AttributeError.
        if kernel_size < 1 or kernel_size != int(kernel_size):
            raise ValueError(
                f'kernel_size must be a positive integer, got {kernel_size}')
        # Build the binomial row by repeated convolution with [1, 1];
        # this reproduces the former hard-coded tables exactly.
        k = np.array([1.])
        for _ in range(int(kernel_size) - 1):
            k = np.convolve(k, [1., 1.])
        self.k = k
        self.mode = mode
        self.impl = impl

    def call(self, inputs, training=None):
        if self.mode == 'up':
            x = upsample_2d(inputs, k=self.k, data_format='NHWC',
                            impl=self.impl)
        elif self.mode == 'down':
            x = downsample_2d(inputs, k=self.k, data_format='NHWC',
                              impl=self.impl)
        else:
            raise ValueError(f'Unsupported sampling mode: {self.mode}')

        return x
    
class ConvBlock(Layer):
    """Conv2D followed by optional normalisation and an activation.

    Args:
        filters, kernel_size, strides, padding, use_bias: forwarded to
            Conv2D (kernel initialised from N(0, 0.02)).
        norm_layer: 'batch' (BatchNormalization), 'instance'
            (InstanceNorm without affine parameters) or anything else for
            no normalisation.
        activation: any Keras activation identifier (default 'linear').
    """
    def __init__(self, filters, kernel_size, strides=(1,1), padding='valid',
                 use_bias=True, norm_layer=None, activation='linear',
                 **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        self.conv2d = Conv2D(
            filters, kernel_size, strides, padding, use_bias=use_bias,
            kernel_initializer=tf.random_normal_initializer(0., 0.02))
        self.activation = Activation(activation)
        if norm_layer == 'instance':
            norm = InstanceNorm(affine=False)
        elif norm_layer == 'batch':
            norm = BatchNormalization()
        else:
            norm = tf.identity
        self.normalization = norm

    def call(self, inputs, training=None):
        return self.activation(self.normalization(self.conv2d(inputs)))


class ConvTransposeBlock(Layer):
    """Conv2DTranspose followed by optional normalisation and an
    activation.

    Args:
        filters, kernel_size, strides, padding, use_bias: forwarded to
            Conv2DTranspose (kernel initialised from N(0, 0.02)).
        norm_layer: 'batch' (BatchNormalization), 'instance'
            (InstanceNorm without affine parameters) or anything else for
            no normalisation.
        activation: any Keras activation identifier (default 'linear').
    """
    def __init__(self, filters, kernel_size, strides=(1,1), padding='valid',
                 use_bias=True, norm_layer=None, activation='linear',
                 **kwargs):
        super(ConvTransposeBlock, self).__init__(**kwargs)
        self.convT2d = Conv2DTranspose(
            filters, kernel_size, strides, padding, use_bias=use_bias,
            kernel_initializer=tf.random_normal_initializer(0., 0.02))
        self.activation = Activation(activation)
        if norm_layer == 'instance':
            norm = InstanceNorm(affine=False)
        elif norm_layer == 'batch':
            norm = BatchNormalization()
        else:
            norm = tf.identity
        self.normalization = norm

    def call(self, inputs, training=None):
        return self.activation(self.normalization(self.convT2d(inputs)))


class ResBlock(Layer):
    """ ResBlock is a ConvBlock with skip connections.
    Original Resnet paper (https://arxiv.org/pdf/1512.03385.pdf).

    Computes  inputs + B(pad(A(pad(inputs))))  where A is a ConvBlock
    with ReLU and B a ConvBlock without activation; both use 'valid'
    convolutions on reflect-padded inputs.

    Args:
        filters: output channels of both convolutions. The skip
            connection adds element-wise, so this must equal the
            channel count of the inputs.
        kernel_size: odd int, or (height, width) pair of odd ints. The
            reflect padding is kernel_size // 2 per side, so the spatial
            size is preserved.
        use_bias, norm_layer: forwarded to both ConvBlocks.
    """
    def __init__(self, filters, kernel_size, use_bias, norm_layer, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        # Derive the padding from the kernel size (it used to be fixed at
        # 1, which only matches 3x3 kernels).
        if isinstance(kernel_size, (tuple, list)):
            k_h, k_w = kernel_size
            pad = (int(k_w) // 2, int(k_h) // 2)  # Padding2D: (width, height)
        else:
            pad = int(kernel_size) // 2
        self.reflect_pad1 = Padding2D(pad, pad_type='reflect')
        self.conv_block1 = ConvBlock(filters,
                                     kernel_size,
                                     padding='valid',
                                     use_bias=use_bias,
                                     norm_layer=norm_layer,
                                     activation='relu')
        self.reflect_pad2 = Padding2D(pad, pad_type='reflect')
        self.conv_block2 = ConvBlock(filters,
                                     kernel_size,
                                     padding='valid',
                                     use_bias=use_bias,
                                     norm_layer=norm_layer)

    def call(self, inputs, training=None):
        x = self.reflect_pad1(inputs)
        x = self.conv_block1(x)
        x = self.reflect_pad2(x)
        x = self.conv_block2(x)
        return inputs + x


**The Model Itself**


In [None]:
def Generator(input_shape, output_shape, norm_layer, use_antialias,
              resnet_blocks, impl):
    """ Create a Resnet-based generator.
    Adapt from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    For BatchNorm, we use learnable affine parameters and track running
    statistics (mean/stddev). For InstanceNorm, we do not use learnable
    affine parameters. We do not track running statistics.

    Architecture: reflect-padded 7x7 conv (64 filters) -> two
    downsampling stages (128, 256 filters) -> `resnet_blocks` ResBlocks
    (256 filters) -> two upsampling stages (128, 64 filters) ->
    reflect-padded 7x7 conv with tanh to `output_shape[-1]` channels.
    With `use_antialias`, sampling uses stride-1 convolutions plus
    AntialiasSampling layers; otherwise strided Conv2D / Conv2DTranspose.

    Args:
        input_shape: shape of one input image, without the batch axis.
        output_shape: shape of one output image; only its last entry
            (the number of channels) is used.
        norm_layer: 'batch', 'instance' or None (see ConvBlock).
        use_antialias: select anti-aliased sampling (see above).
        resnet_blocks: number of residual blocks in the bottleneck.
        impl: implementation name forwarded to AntialiasSampling.

    Returns:
        A functional Keras Model named 'generator'.

    NOTE: Encoder() picks intermediate outputs of this model by layer
    index (`generator.get_layer(index=...)`). The order in which layers
    are created below therefore defines what CUT_model's `nce_layers`
    refer to. The same indices land on different layers with and without
    `use_antialias`, because the antialias path creates extra layers.
    """
    # Convolution biases are enabled only with instance norm; with batch
    # norm or no normalisation, all but the final convolution have none.
    use_bias = (norm_layer == 'instance')
    inputs = Input(shape=input_shape)

    # Stem: reflect-pad by 3 so the 7x7 'valid' conv keeps the size.
    x = Padding2D(3, pad_type='reflect')(inputs)
    x = ConvBlock(64, 7, padding='valid', use_bias=use_bias,
                  norm_layer=norm_layer, activation='relu')(x)

    # Two downsampling stages.
    if use_antialias:
        x = ConvBlock(128, 3, padding='same', use_bias=use_bias,
                      norm_layer=norm_layer, activation='relu')(x)
        x = AntialiasSampling(4, mode='down', impl=impl)(x)
        x = ConvBlock(256, 3, padding='same', use_bias=use_bias,
                      norm_layer=norm_layer, activation='relu')(x)
        x = AntialiasSampling(4, mode='down', impl=impl)(x)
    else:
        x = ConvBlock(128, 3, strides=2, padding='same', use_bias=use_bias,
                      norm_layer=norm_layer, activation='relu')(x)
        x = ConvBlock(256, 3, strides=2, padding='same', use_bias=use_bias,
                      norm_layer=norm_layer, activation='relu')(x)

    # Residual bottleneck.
    for _ in range(resnet_blocks):
        x = ResBlock(256, 3, use_bias, norm_layer)(x)

    # Two upsampling stages, mirroring the downsampling path.
    if use_antialias:
        x = AntialiasSampling(4, mode='up', impl=impl)(x)
        x = ConvBlock(128, 3, padding='same', use_bias=use_bias,
                      norm_layer=norm_layer, activation='relu')(x)
        x = AntialiasSampling(4, mode='up', impl=impl)(x)
        x = ConvBlock(64, 3, padding='same', use_bias=use_bias,
                      norm_layer=norm_layer, activation='relu')(x)
    else:
        x = ConvTransposeBlock(128, 3, strides=2, padding='same',
                               use_bias=use_bias, norm_layer=norm_layer,
                               activation='relu')(x)
        x = ConvTransposeBlock(64, 3, strides=2, padding='same',
                               use_bias=use_bias, norm_layer=norm_layer,
                               activation='relu')(x)

    # Output head: tanh maps to [-1, 1], the range load_image uses.
    x = Padding2D(3, pad_type='reflect')(x)
    outputs = ConvBlock(output_shape[-1], 7, padding='valid',
                        activation='tanh')(x)

    return Model(inputs=inputs, outputs=outputs, name='generator')

def Discriminator(input_shape, norm_layer, use_antialias, impl):
    """ Create a PatchGAN discriminator.

    PatchGAN classifier described in the original pix2pix paper
    (https://arxiv.org/abs/1611.07004). It outputs a map of real/fake
    scores, one per receptive-field patch, instead of a single score.
    This gives it fewer parameters than a full-image discriminator and
    lets it work on arbitrarily-sized images in a fully convolutional
    fashion.

    Three 4x4 convolutions (64, 128, 256 filters, leaky ReLU) each reduce
    the resolution: with stride 2, or, when `use_antialias` is set, with
    a stride-1 convolution followed by AntialiasSampling('down'). The
    first of them uses ConvBlock's defaults (bias, no normalisation).
    Two zero-padded 4x4 'valid' convolutions (512 filters, then 1) then
    produce the score map.
    """
    use_bias = (norm_layer == 'instance')
    leaky_relu = tf.nn.leaky_relu
    inputs = Input(shape=input_shape)

    x = inputs
    for filters in (64, 128, 256):
        if filters == 64:
            norm_kwargs = {}
        else:
            norm_kwargs = {'use_bias': use_bias, 'norm_layer': norm_layer}
        conv_strides = (1, 1) if use_antialias else 2
        x = ConvBlock(filters, 4, strides=conv_strides, padding='same',
                      activation=leaky_relu, **norm_kwargs)(x)
        if use_antialias:
            x = AntialiasSampling(4, mode='down', impl=impl)(x)

    x = Padding2D(1, pad_type='constant')(x)
    x = ConvBlock(512, 4, padding='valid', use_bias=use_bias,
                  norm_layer=norm_layer, activation=leaky_relu)(x)
    x = Padding2D(1, pad_type='constant')(x)
    outputs = ConvBlock(1, 4, padding='valid')(x)

    return Model(inputs=inputs, outputs=outputs, name='discriminator')

def Encoder(generator, nce_layers):
    """ Create an Encoder that shares weights with the generator.

    The encoder is a Keras Model with the generator's input and one output
    per entry of `nce_layers`: the output of
    `generator.get_layer(index=idx)`. It reuses the generator's layer
    objects, so the weights are shared.

    Args:
        generator: functional Keras model (exposes `layers`, `input`,
            `get_layer`).
        nce_layers: non-empty sequence of layer indices, each in
            [0, len(generator.layers)).

    Returns:
        Model named 'encoder'.

    Raises:
        AssertionError: if an index is outside [0, len(generator.layers)).
    """
    n_layers = len(generator.layers)
    # Valid indices are 0 .. n_layers - 1, hence the strict `<`.
    assert max(nce_layers) < n_layers and min(nce_layers) >= 0
    outputs = [generator.get_layer(index=idx).output for idx in nce_layers]
    return Model(inputs=generator.input, outputs=outputs, name='encoder')


class PatchSampleMLP(Model):
    """ Create a PatchSampleMLP.
    Adapt from official CUT implementation
    (https://github.com/taesungp/contrastive-unpaired-translation).
    PatchSampler samples patches from pixel/feature-space. Two-layer
    MLP projects both the input and output patches to a shared
    embedding space.

    Args:
        units: width of both Dense layers of every per-feature MLP.
        num_patches: maximum number of spatial positions sampled per
            feature map (capped at H * W).
    """
    def __init__(self, units, num_patches, **kwargs):
        super(PatchSampleMLP, self).__init__(**kwargs)
        self.units = units
        self.num_patches = num_patches
        # L2-normalise along the last axis. The 1e-10 epsilon keeps rsqrt
        # finite for an all-zero vector (it was mistyped as `10-10`, which
        # evaluates to 0 and produced inf/NaN).
        self.l2_norm = Lambda(lambda x: x * tf.math.rsqrt(tf.reduce_sum(
            tf.square(x), axis=-1, keepdims=True) + 1e-10))

    def build(self, input_shape):
        # One independent two-layer MLP per input feature map.
        initializer = tf.random_normal_initializer(0., 0.02)
        feats_shape = input_shape
        for feat_id in range(len(feats_shape)):
            mlp = tf.keras.models.Sequential([
                    Dense(self.units, activation="relu",
                          kernel_initializer=initializer),
                    Dense(self.units, kernel_initializer=initializer),
                ])
            setattr(self, f'mlp_{feat_id}', mlp)

    def call(self, inputs, patch_ids=None, training=None):
        """Sample and embed patches from each feature map.

        Args:
            inputs: list of rank-4 feature maps; the batch size is forced
                to 1 by `set_shape`.
            patch_ids: optional list (one per feature map) of indices into
                the flattened H * W positions to reuse; if None, up to
                `num_patches` positions are drawn at random.

        Returns:
            (samples, ids): per feature map, the L2-normalised embeddings
            of shape [n_sampled, units] and the index tensor used.
        """
        feats = inputs
        samples = []
        ids = []
        for feat_id, feat in enumerate(feats):
            feat.set_shape([1, feat.shape[1], feat.shape[2], feat.shape[3]])
            B, H, W, C = feat.shape
            feat_reshape = tf.reshape(feat, [B, -1, C])

            if patch_ids is not None:
                patch_id = patch_ids[feat_id]
            else:
                patch_id = tf.random.shuffle(
                    tf.range(H * W))[:min(self.num_patches, H * W)]

            x_sample = tf.reshape(
                tf.gather(feat_reshape, patch_id, axis=1), [-1, C])
            mlp = getattr(self, f'mlp_{feat_id}')
            x_sample = mlp(x_sample)
            x_sample = self.l2_norm(x_sample)
            samples.append(x_sample)
            ids.append(patch_id)
        return samples, ids


class CUT_model(Model):
    """ Create a CUT/FastCUT model, described in the paper Contrastive
    Learning for Unpaired Image-to-Image Translation. Taesung Park,
    Alexei A. Efros, Richard Zhang, Jun-Yan Zhu. ECCV, 2020
    (https://arxiv.org/abs/2007.15651).

    Only cut_mode='cut' and gan_mode='lsgan' are accepted here.

    Args:
        source_shape: shape of source images, passed to the Generator.
        target_shape: shape of target images, passed to the Generator and
            the Discriminator.
        cut_mode: must be 'cut'; enables the NCE identity loss on B.
        gan_mode: must be 'lsgan'; passed to GANLoss in compile().
        use_antialias: forwarded to the Generator and Discriminator.
        norm_layer: None, 'batch' or 'instance'.
        resnet_blocks: forwarded to the Generator.
        netF_units: width of the projection MLPs in netF (> 0).
        netF_num_patches: max patches sampled per layer by netF (> 0).
        nce_temp: temperature passed to PatchNCELoss.
        nce_layers: layer indices handed to the Encoder; defaults to
            [0, 3, 5, 7, 11].
        impl: forwarded to the Generator and Discriminator.
    """
    def __init__(self, source_shape, target_shape, cut_mode='cut',
                 gan_mode='lsgan', use_antialias=True, norm_layer='instance',
                 resnet_blocks=9, netF_units=256, netF_num_patches=256,
                 nce_temp=0.07, nce_layers=None, impl='ref', **kwargs):
        assert cut_mode in ['cut']
        assert gan_mode in ['lsgan']
        assert norm_layer in [None, 'batch', 'instance']
        assert netF_units > 0
        assert netF_num_patches > 0
        # Do not forward `self` positionally: Keras' Model.__init__ does
        # not expect it (tf.keras ignores it, Keras 3 rejects it).
        super(CUT_model, self).__init__(**kwargs)

        # Fresh list per instance instead of a shared mutable default.
        if nce_layers is None:
            nce_layers = [0, 3, 5, 7, 11]

        self.gan_mode = gan_mode
        self.nce_temp = nce_temp
        self.nce_layers = nce_layers
        self.netG = Generator(source_shape, target_shape, norm_layer,
                              use_antialias, resnet_blocks, impl)
        self.netD = Discriminator(target_shape, norm_layer, use_antialias,
                                  impl)
        self.netE = Encoder(self.netG, self.nce_layers)
        self.netF = PatchSampleMLP(netF_units, netF_num_patches)

        if cut_mode == 'cut':
            self.nce_lambda = 1.0
            self.use_nce_identity = True
        else:
            raise ValueError(cut_mode)

    def compile(self, G_optimizer, F_optimizer, D_optimizer,):
        """Attach one optimizer per network and create the loss functions.

        No Keras loss or metrics are given to the base compile(); all
        losses are computed by hand in train_step/test_step.
        """
        super(CUT_model, self).compile()
        self.G_optimizer = G_optimizer
        self.F_optimizer = F_optimizer
        self.D_optimizer = D_optimizer
        self.gan_loss_func = GANLoss(self.gan_mode)
        self.nce_loss_func = PatchNCELoss(self.nce_temp, self.nce_lambda)

    def _cut_losses(self, real_A, real_B, training):
        """Run the networks on one batch and return (D, G, NCE) losses.

        Shared by train_step (inside a GradientTape) and test_step.
        """
        # With the identity loss, A and B go through the generator in a
        # single batch and are split apart again afterwards.
        if self.use_nce_identity:
            real = tf.concat([real_A, real_B], axis=0)
        else:
            real = real_A

        fake = self.netG(real, training=training)
        num_src = real_A.shape[0]
        fake_B = fake[:num_src]

        # GAN loss for the discriminator
        fake_score = self.netD(fake_B, training=training)
        D_fake_loss = tf.reduce_mean(self.gan_loss_func(fake_score, False))
        real_score = self.netD(real_B, training=training)
        D_real_loss = tf.reduce_mean(self.gan_loss_func(real_score, True))
        D_loss = (D_fake_loss + D_real_loss) * 0.5

        # GAN loss and NCE loss for the generator
        G_loss = tf.reduce_mean(self.gan_loss_func(fake_score, True))
        NCE_loss = self.nce_loss_func(real_A, fake_B, self.netE, self.netF)
        if self.use_nce_identity:
            idt_B = fake[num_src:]
            NCE_B_loss = self.nce_loss_func(real_B, idt_B, self.netE,
                                            self.netF)
            NCE_loss = (NCE_loss + NCE_B_loss) * 0.5
        G_loss += NCE_loss
        return D_loss, G_loss, NCE_loss

    @tf.function
    def train_step(self, batch_data):
        """One optimisation step for netD, netG and netF (in that order)."""
        # A is source and B is target
        real_A, real_B = batch_data

        with tf.GradientTape(persistent=True) as tape:
            D_loss, G_loss, NCE_loss = self._cut_losses(
                real_A, real_B, training=True)

        # Each network is updated only with the gradient of its own loss.
        for loss, net, optimizer in (
                (D_loss, self.netD, self.D_optimizer),
                (G_loss, self.netG, self.G_optimizer),
                (NCE_loss, self.netF, self.F_optimizer)):
            grads = tape.gradient(loss, net.trainable_variables)
            optimizer.apply_gradients(zip(grads, net.trainable_variables))

        del tape
        return {'D_loss': D_loss, 'G_loss': G_loss, 'NCE_loss': NCE_loss}

    @tf.function
    def test_step(self, batch_data):
        """Compute the same losses as train_step without updating weights."""
        # A is source and B is target
        real_A, real_B = batch_data
        D_loss, G_loss, NCE_loss = self._cut_losses(
            real_A, real_B, training=False)
        return {'D_loss': D_loss, 'G_loss': G_loss, 'NCE_loss': NCE_loss}

**Instantiate and train**

In [None]:
mode = "cut"
batch_size = 1
beta_1 = 0.5
beta_2 = 0.999
lr = 0.0002
lr_decay_rate = 0.9
lr_decay_step = 100000

# Build the CUT model (generator, discriminator, encoder and patch MLP).
cut = CUT_model(source_shape, target_shape, cut_mode=mode,
                use_antialias=anti_alias)

# Staircase exponential decay: the learning rate is multiplied by
# lr_decay_rate once every lr_decay_step optimizer steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=lr,
    decay_steps=lr_decay_step,
    decay_rate=lr_decay_rate,
    staircase=True,
)


def make_adam():
    """Return a new Adam optimizer using the shared schedule and betas."""
    return tf.keras.optimizers.Adam(learning_rate=lr_schedule,
                                    beta_1=beta_1, beta_2=beta_2)


# Every network gets its own optimizer instance (separate Adam state).
cut.compile(G_optimizer=make_adam(),
            F_optimizer=make_adam(),
            D_optimizer=make_adam())


In [None]:
class GANMonitor(tf.keras.callbacks.Callback):
    """A callback to generate and save images after each epoch.

    For each of the first ``num_img`` (source, target) batches of
    ``val_dataset`` it plots the first image of the batch: the source,
    its translation, the target, and the generator applied to the target
    (identity). The figure is saved to ``{out_dir}/epoch={epoch + 1}.png``
    and then shown.
    """
    def __init__(self, generator, val_dataset, out_dir, num_img=2):
        super().__init__()
        self.num_img = num_img
        self.generator = generator
        self.val_dataset = val_dataset
        self.out_dir = out_dir

    @staticmethod
    def _to_uint8(image):
        """Map a tensor scaled like [-1, 1] to uint8 pixel values."""
        return (image.numpy() * 127.5 + 127.5).astype(np.uint8)

    def on_epoch_end(self, epoch, logs=None):
        titles = ['Source', "Translated", "Target", "Identity"]
        # squeeze=False keeps `ax` 2-D even when num_img == 1.
        fig, ax = plt.subplots(self.num_img, len(titles), figsize=(20, 10),
                               squeeze=False)
        for col, title in enumerate(titles):
            ax[0, col].set_title(title)

        for row, (source, target) in enumerate(
                self.val_dataset.take(self.num_img)):
            translated = self._to_uint8(
                self.generator(source, training=False)[0])
            idt = self._to_uint8(self.generator(target, training=False)[0])
            images = [self._to_uint8(source[0]), translated,
                      self._to_uint8(target[0]), idt]
            for col, img in enumerate(images):
                ax[row, col].imshow(img)
                ax[row, col].axis("off")

        # Save before showing: with inline / non-interactive backends,
        # plt.show() releases the figure and a later savefig writes a
        # blank image.
        fig.savefig(f'{self.out_dir}/epoch={epoch + 1}.png')
        plt.show()
        plt.close(fig)


# Create validating callback to generate output image every epoch
out_dir = 'callbacks'
if not os.path.exists(out_dir):
    print("Creating folder...")
    os.makedirs(out_dir)
plotter_callback = GANMonitor(cut.netG, val_dataset, out_dir)

# Train the model for max. nr of epochs or until time-out:
loss_names = ['D_loss', 'G_loss', 'NCE_loss']
epoch = 0
time_start = time.time()
history_val = {name: [] for name in loss_names}
history_train = {name: [] for name in loss_names}
while time.time() - time_start < max_train_time and epoch < max_epochs:
    print("Epoch:", epoch)
    # Run exactly one epoch but give Keras its true index; otherwise every
    # fit() call restarts at epoch 0 and the callback keeps overwriting
    # epoch=1.png.
    train_losses = cut.fit(train_dataset, initial_epoch=epoch,
                           epochs=epoch + 1, callbacks=[plotter_callback],
                           verbose=1)
    # return_dict=True: look losses up by name instead of list position.
    val_losses = cut.evaluate(val_dataset, return_dict=True)
    for name in loss_names:
        history_train[name].append(train_losses.history[name][-1])
        history_val[name].append(val_losses[name])

    print("Train:", train_losses.history)
    print("Val:", val_losses)
    epoch += 1
    epoch += 1


In [None]:
# One figure per loss, comparing the per-epoch training and validation
# values collected by the training loop.
loss_plots = (
    ('D_loss', 'Discriminator Loss'),
    ('G_loss', 'Generator Loss'),
    ('NCE_loss', 'Patch NCE Loss'),
)
for loss_key, plot_title in loss_plots:
    plt.plot(history_train[loss_key], label="Training")
    plt.plot(history_val[loss_key], label="Validation")
    plt.title(plot_title)
    plt.legend()
    plt.show()


**Generate images and store them.**

In [None]:
# Rebuild the training source photos (same files, fixed order) without
# the paired targets, so they can be translated and written to disk.
train_src_dataset = (
    tf.data.Dataset.list_files(f'{train_src_folder}/*.jpg', shuffle=False)
    .map(load_image, num_parallel_calls=AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    .prefetch(AUTOTUNE)
)

# Folder that receives the generated images.
if not os.path.exists('images'):
    print("Creating folder...")
    os.makedirs('images')

def predict_and_save(input_ds, generator_model, out_dir='images'):
    """Translate every image in ``input_ds`` and save each one as a JPEG.

    Files are numbered consecutively from 1 (``1.jpg``, ``2.jpg``, ...)
    in the order the dataset yields them. Batches may contain any number
    of images.

    Args:
        input_ds: dataset yielding batches of source images.
        generator_model: model mapping a batch of images to a batch of
            translated images; the ``* 127.5 + 127.5`` mapping below
            assumes its outputs lie in [-1, 1].
        out_dir: existing directory the JPEG files are written to.
    """
    image_id = 1
    for batch in input_ds:
        prediction = generator_model(batch, training=False)
        # Clip first: float -> uint8 casts do not saturate, so values just
        # outside [-1, 1] would otherwise produce garbage pixels.
        pixels = tf.cast(
            tf.clip_by_value(prediction * 127.5 + 127.5, 0., 255.),
            tf.uint8)
        # Iterate over the batch dimension so batch sizes > 1 also work.
        for image in pixels.numpy():
            Image.fromarray(image).save(
                os.path.join(out_dir, f'{image_id}.jpg'))
            image_id += 1


predict_and_save(train_src_dataset, cut.netG)


In [None]:
# Zip the generated images into /kaggle/working/images.zip, then remove
# the loose files. The base name must not end in '/': make_archive just
# appends '.zip', so 'images/' can produce 'images/.zip' inside the very
# folder that rmtree deletes next (recent Python versions no longer
# normalise base_name when root_dir is given).
shutil.make_archive('/kaggle/working/images', 'zip', 'images')
shutil.rmtree("./images")
