Resources:

* [CS109B Data Science 2 (Harvard): Generative Adversarial Networks (GANs) — Vincent Casser, Pavlos Protopapas](https://harvard-iacs.github.io/2019-CS109B/a-sections/a-section8/presentation/cs109b_asec8_slides_gan.pdf)
* [How to Develop a CycleGAN for Image-to-Image Translation with Keras](https://machinelearningmastery.com/cyclegan-tutorial-with-keras/)
* [Image2Art Translation Using CycleGAN](https://sahiltinky94.medium.com/image2art-translation-using-cyclegan-e1bc096b7315)
* [Transforming the World Into Paintings with CycleGAN](https://medium.com/analytics-vidhya/transforming-the-world-into-paintings-with-cyclegan-6748c0b85632)
* [Tensorflow CycleGAN](https://www.tensorflow.org/tutorials/generative/cyclegan)
* [Keras CycleGAN](https://keras.io/examples/generative/cyclegan/)  
* [CycleGAN-Tensorflow-2](https://github.com/LynnHo/CycleGAN-Tensorflow-2)  

In [None]:
import os, random, time, math, datetime, glob, sys, warnings
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import tensorflow as tf
from tensorflow.keras import backend as K
from kaggle_datasets import KaggleDatasets
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Concatenate, Input,  ReLU, Layer, Dropout, ZeroPadding2D
from tensorflow_addons.layers import InstanceNormalization
import tensorflow_addons as tfa
import tensorflow.keras as keras

# tf.data sentinel that lets the runtime tune parallelism / prefetch depth;
# also used below as the default value of the `repeat` parameters.
AUTO = tf.data.experimental.AUTOTUNE 
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

In [None]:
# Pick a distribution strategy: TPU if one can be resolved, otherwise all
# visible GPUs, otherwise the default (CPU) strategy.
gpus = []  # bound up front so the GPU check below never sees an undefined name
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    # No TPU could be resolved; fall back to any logical GPUs.
    tpu = None
    gpus = tf.config.experimental.list_logical_devices("GPU")
if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    print('Running on TPU ', tpu.master())  
elif gpus:
    # MirroredStrategy expects device names (e.g. '/device:GPU:0'), not
    # LogicalDevice objects.
    strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])
    print('Running on ', len(gpus), ' GPU(s) ')
else:
    strategy = tf.distribute.get_strategy()
    print('Running on CPU')
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# Resolve the Kaggle "gan-getting-started" dataset location and list its files.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started') 
# JPEG file lists; used as the defaults of make_zip_dataset().
MONET_FILENAMES = tf.io.gfile.glob(GCS_DS_PATH+'/monet_jpg/*.jpg') # monet_tfrec monet_jpg
PHOTO_FILENAMES = tf.io.gfile.glob(GCS_DS_PATH+'/photo_jpg/*.jpg') # photo_tfrec photo_jpg

# Shuffle in place with Python's `random` module (no seed is set, so the order differs per run).
random.shuffle(MONET_FILENAMES)
random.shuffle(PHOTO_FILENAMES)

# TFRecord shard lists; read through load_dataset() by the GAN and FID pipelines.
PHOTO_FILENAMES_TFREC = tf.io.gfile.glob(GCS_DS_PATH+'/photo_tfrec/*.tfrec')
MONET_FILENAMES_TFREC = tf.io.gfile.glob(GCS_DS_PATH+'/monet_tfrec/*.tfrec')

random.shuffle(PHOTO_FILENAMES_TFREC)
random.shuffle(MONET_FILENAMES_TFREC)

# Global batch size: one image per replica.
BATCH_SIZE = 1 * strategy.num_replicas_in_sync
STEPS_PER_EPOCH = 200
EPOCHS = 150

# Shuffle-buffer size used by the TFRecord pipelines.
BUFFER_SIZE = 300
IMAGE_HEIGHT = 256
IMAGE_WIDTH  = 256

# Images consumed per epoch (not referenced elsewhere in the visible code).
DATA_SIZE = BATCH_SIZE*STEPS_PER_EPOCH
IMAGE_SIZE = [IMAGE_HEIGHT, IMAGE_WIDTH]
# ModelCheckpoint frequency in batches: every 20% of the total number of epochs.
SAVE_FREQ = STEPS_PER_EPOCH * int(EPOCHS*0.2)
# Per-image tensor shape for the active Keras image data format.
# NOTE(review): decode_image() reshapes decoded HWC pixels to SHAPE; under
# 'channels_first' that reorders data rather than transposing it, and the models
# are built channels-last anyway — confirm channels_last is always in effect.
if K.image_data_format() == 'channels_first':
    SHAPE = (3,*IMAGE_SIZE)
else:
    SHAPE = (*IMAGE_SIZE, 3)
    
# Base filter count; the model layers use multiples of it.
DIM = 64  # 64
# Convolution kernel size used throughout the models.
SIZE = 4 # 4
# Not referenced elsewhere in the visible code.
EPSILON = 1e-12 # 1e-5
# Initial learning rates for the generators and the discriminators (fed to LinearDecay).
START_LEARNING_RATE=5e-4
START_LEARNING_RATE_DISC=3e-4

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 tensor of shape SHAPE.

    Pixel values 0..255 are mapped linearly to [-1, 1].
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    scaled = pixels / 127.5 - 1
    return tf.reshape(scaled, SHAPE)

@tf.function
def _map_fn(img):  # preprocessing
    """Resize to 286x286, clip to [0, 255] and rescale to [-1, 1]."""
    resized = tf.image.resize(img, [286, 286])
    unit_range = tf.clip_by_value(resized, 0, 255) / 255.0
    return unit_range * 2 - 1

@tf.function
def augmentation(image):
    """Random jitter for a single image: upscale by 15%, random-crop back, random flips.

    The image is resized (nearest neighbour) to 115% of IMAGE_HEIGHT x IMAGE_WIDTH,
    randomly cropped back to IMAGE_HEIGHT x IMAGE_WIDTH x 3, then randomly flipped
    vertically and horizontally. The 3-element crop size means `image` must be a
    single HWC image, not a batch.
    """
    # Scale height and width separately so non-square targets keep their aspect.
    aug_height = IMAGE_HEIGHT + int(IMAGE_HEIGHT * 0.15)
    aug_width = IMAGE_WIDTH + int(IMAGE_WIDTH * 0.15)
    image = tf.image.resize(image, [aug_height, aug_width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # random_crop's size is ordered [height, width, channels].
    image = tf.image.random_crop(image, size=[IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_flip_left_right(image)
    return image

def read_tfrecord(example):
    """Parse one serialized example and decode its "image" bytes with decode_image()."""
    features = tf.io.parse_single_example(
        example,
        {"image": tf.io.FixedLenFeature([], tf.string)},
    )
    return decode_image(features['image'])
     
def load_dataset(filenames):
    """Return a dataset of decoded images read from the given TFRecord files."""
    records = tf.data.TFRecordDataset(filenames)
    images = records.map(read_tfrecord, num_parallel_calls=AUTO)
    return images


# Evaluation pipelines, batched at 32 images per replica (partial batches dropped).
# `fast_photo_ds` and `fid_photo_ds` (first 1024 shuffled photos) are not referenced
# elsewhere in the visible code; `fid_monet_ds` supplies the reference Monet
# statistics for FID.
fast_photo_ds = load_dataset(PHOTO_FILENAMES_TFREC).shuffle(BUFFER_SIZE).batch(32*strategy.num_replicas_in_sync, drop_remainder=True).prefetch(AUTO)
fid_photo_ds = load_dataset(PHOTO_FILENAMES_TFREC).shuffle(BUFFER_SIZE).take(1024).batch(32*strategy.num_replicas_in_sync, drop_remainder=True).prefetch(AUTO)
fid_monet_ds = load_dataset(MONET_FILENAMES_TFREC).shuffle(BUFFER_SIZE).batch(32*strategy.num_replicas_in_sync, drop_remainder=True).prefetch(AUTO)

def get_gan_dataset():
    """Return an infinite dataset of (monet_batch, photo_batch) pairs from the TFRecords.

    Each side is shuffled, batched to BATCH_SIZE, cached, reshuffled at the
    batch level, repeated forever and prefetched.
    """
    def _pipeline(filenames):
        return (
            load_dataset(filenames)
            .shuffle(BUFFER_SIZE)
            .batch(BATCH_SIZE, drop_remainder=True)
            .cache()
            .shuffle(BUFFER_SIZE)
            .repeat()
            .prefetch(AUTO)
        )

    return tf.data.Dataset.zip((_pipeline(MONET_FILENAMES_TFREC), _pipeline(PHOTO_FILENAMES_TFREC)))

In [None]:
import multiprocessing
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=AUTO):
    """Shuffle, filter/map, batch, repeat and prefetch a `tf.data.Dataset`.

    Args:
        dataset: input `tf.data.Dataset`.
        batch_size: number of elements per batch.
        drop_remainder: drop a trailing partial batch.
        n_prefetch_batch: number of batches to prefetch.
        filter_fn: optional predicate for `Dataset.filter`.
        map_fn: optional function for `Dataset.map`.
        n_map_threads: parallel calls for `map`; defaults to the CPU count.
        filter_after_map: apply `filter_fn` after `map_fn` instead of before.
        shuffle: shuffle elements before filtering/mapping.
        shuffle_buffer_size: shuffle buffer; defaults to max(128 * batch_size, 2048).
        repeat: count forwarded to `Dataset.repeat`. None or -1 (the AUTO
            sentinel) repeats indefinitely; a positive integer repeats that
            many times.

    Returns:
        The transformed dataset.
    """
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048

    # Shuffling before `map`/`filter` is cheaper, because those can be costly.
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    def _filter(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def _map(ds):
        return ds.map(map_fn, num_parallel_calls=n_map_threads) if map_fn else ds

    if filter_after_map:
        # Slower: elements are mapped even if they are later filtered out.
        dataset = _filter(_map(dataset))
    else:
        dataset = _map(_filter(dataset))

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    # Honour the caller's `repeat` count instead of always repeating forever.
    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=AUTO):
    """Build a batched dataset by slicing in-memory data along its first axis.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
        Sliced with `tf.data.Dataset.from_tensor_slices`; every other argument
        is forwarded unchanged to `batch_dataset`.
    """
    pipeline_options = dict(
        drop_remainder=drop_remainder,
        n_prefetch_batch=n_prefetch_batch,
        filter_fn=filter_fn,
        map_fn=map_fn,
        n_map_threads=n_map_threads,
        filter_after_map=filter_after_map,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        repeat=repeat,
    )
    sliced = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(sliced, batch_size, **pipeline_options)


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=AUTO):
    """Batch dataset of images read from disk (decoded with `tf.image.decode_jpeg`).

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists, optional
        Sliced alongside `img_paths`. When given, each element is
        `(image, label)`; otherwise it is the 1-tuple `(image,)`.

    Images are decoded with 3 channels and cast to float32 (values 0..255).
    If `map_fn` is given it is fused with decoding and receives the decoded
    image (plus the label when `labels` is given). The remaining arguments
    are forwarded to `memory_data_batch_dataset`.
    """
    if labels is None:
        memory_data = img_paths
    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):
        # `label` is empty without labels; otherwise it holds the label that
        # from_tensor_slices paired with `path`, and is passed through unchanged.
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, 3)  # fix channels to 3
        img = tf.cast(img, tf.float32)
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn`
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset

def make_dataset(img_paths, batch_size, load_size, crop_size, training, drop_remainder=True, shuffle=True, repeat=AUTO):
    """Create a batched image dataset from JPEG paths with CycleGAN-style preprocessing.

    Training: random horizontal flip, resize to load_size x load_size, random crop
    to crop_size x crop_size. Evaluation: resize straight to crop_size x crop_size.
    In both cases pixels are clipped to [0, 255] and rescaled to [-1, 1].
    """
    def _to_signed_unit(img):
        # [0, 255] -> [-1, 1]
        return tf.clip_by_value(img, 0, 255) / 127.5 - 1

    if not training:
        @tf.function
        def _map_fn(img):  # preprocessing
            return _to_signed_unit(tf.image.resize(img, [crop_size, crop_size]))
    else:
        @tf.function
        def _map_fn(img):  # preprocessing
            flipped = tf.image.random_flip_left_right(img)
            resized = tf.image.resize(flipped, [load_size, load_size])
            cropped = tf.image.random_crop(resized, [crop_size, crop_size, tf.shape(resized)[-1]])
            return _to_signed_unit(cropped)

    return disk_image_batch_dataset(
        img_paths,
        batch_size,
        drop_remainder=drop_remainder,
        map_fn=_map_fn,
        shuffle=shuffle,
        repeat=repeat,
    )


def make_zip_dataset(A_img_paths=MONET_FILENAMES, B_img_paths=PHOTO_FILENAMES, batch_size=BATCH_SIZE, load_size=IMAGE_WIDTH+int(IMAGE_WIDTH*0.15), crop_size=IMAGE_WIDTH, training=True, shuffle=True, repeat=AUTO):
    """Zip two image-path datasets into a dataset of (A_batch, B_batch) pairs.

    If `repeat` is truthy, both sides get repeat=None (cycle forever). Otherwise
    the side with more paths gets repeat=1 and the shorter side gets repeat=None,
    so the zip is aligned with the longer one.

    Returns:
        (A_B_dataset, len_dataset), where len_dataset is
        max(len(A_img_paths), len(B_img_paths)) // batch_size.
    """
    if repeat:
        A_repeat, B_repeat = None, None  # cycle both
    elif len(A_img_paths) >= len(B_img_paths):
        A_repeat, B_repeat = 1, None  # cycle the shorter one (B)
    else:
        A_repeat, B_repeat = None, 1  # cycle the shorter one (A)

    def _one_side(paths, side_repeat):
        return make_dataset(paths, batch_size, load_size, crop_size, training,
                            drop_remainder=True, shuffle=shuffle, repeat=side_repeat)

    A_B_dataset = tf.data.Dataset.zip((_one_side(A_img_paths, A_repeat), _one_side(B_img_paths, B_repeat)))
    return A_B_dataset, max(len(A_img_paths), len(B_img_paths)) // batch_size


class ItemPool:
    """Fixed-capacity history buffer of items.

    Calling the pool with a batch returns a batch of the same length:
    - pool_size == 0: the input is returned unchanged (same object);
    - while the buffer has room, each item is stored and returned as is;
    - once full, each item is, with probability ~0.5, swapped with a randomly
      chosen stored item (the stored one is returned); otherwise it is returned as is.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.items = []

    def __call__(self, in_items):
        # `in_items` should be a batch tensor (iterated along its first axis).
        if self.pool_size == 0:
            return in_items
        return tf.stack([self._exchange(item) for item in in_items], axis=0)

    def _exchange(self, item):
        """Store/swap one item according to the pool policy and return the item to emit."""
        if len(self.items) < self.pool_size:
            self.items.append(item)
            return item
        if np.random.rand() > 0.5:
            slot = np.random.randint(0, len(self.items))
            previous = self.items[slot]
            self.items[slot] = item
            return previous
        return item

In [None]:
# Training pipeline built from the JPEG file lists: (monet, photo) batch pairs,
# plus the number of batches in the longer domain.
dataset, len_dataset = make_zip_dataset()
# TFRecord-based (monet, photo) pipeline, used for epoch previews and FID.
test_dataset=get_gan_dataset()

In [None]:
# def plot_images(sample_monet1, sample_monet2):
#     plt.figure(figsize=(16, 816))
#     display_list = [sample_monet2[0], sample_monet1[0]]
#     title = [f'Input Image', f'Result']

#     for i in range(2):
#         plt.subplot(1, 2, i+1)
#         plt.title(title[i])
#         plt.imshow(display_list[i] * 0.5 + 0.5)
#         plt.axis('off')
#     plt.show()
    
# for monet, photo in dataset.take(2):
#     plot_images(monet, photo)

### Define the models (Pix2Pix-style U-Net generator and convolutional discriminator)

In [None]:
# Shared kernel initializer, N(0, 0.02), used by every conv / transposed-conv layer built below.
initializer = tf.random_normal_initializer(0., 0.02)

def _get_norm_layer(norm):
    if norm == 'none':
        return lambda: lambda x: x
    elif norm == 'batch_norm':
        return keras.layers.BatchNormalization
    elif norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    elif norm == 'layer_norm':
        return keras.layers.LayerNormalization

def downsample(filters, size, apply_norm=True, norm='instance_norm', strides=4):
    """Encoder stage: strided Conv2D ('same' padding, no bias) -> optional norm -> LeakyReLU(0.2).

    With the default strides=4 each stage shrinks height and width by a factor of 4.
    """
    stack = [Conv2D(filters, size, strides=strides, padding='same', kernel_initializer=initializer, use_bias=False)]
    if apply_norm:
        stack.append(_get_norm_layer(norm)())
    stack.append(LeakyReLU(alpha=0.2))
    return Sequential(stack)


def upsample(filters, size, apply_dropout=False, norm='instance_norm', strides=4):
    """Decoder stage: strided Conv2DTranspose ('same' padding, no bias) -> norm -> optional Dropout(0.5) -> ReLU.

    With the default strides=4 each stage grows height and width by a factor of 4.
    """
    stack = [
        Conv2DTranspose(filters, size, strides=strides, padding='same', kernel_initializer=initializer, use_bias=False),
        _get_norm_layer(norm)(),
    ]
    if apply_dropout:
        stack.append(Dropout(0.5))
    stack.append(ReLU())
    return Sequential(stack)


def get_generator(output_channels=3, norm='instance_norm'):
    """Build a U-Net style generator: strided encoder, decoder and skip connections.

    With the default strides=4 of `downsample`/`upsample`, each of the four
    encoder stages divides height/width by 4, and each of the three decoder stages
    plus the final transposed conv multiplies them by 4. A 256x256 input therefore
    goes 256 -> 64 -> 16 -> 4 -> 1 and back to 256. Other sizes presumably need
    height/width divisible by 256 for the skip shapes to match — TODO confirm.

    Args:
        output_channels: channels of the output image.
        norm: normalization name understood by `_get_norm_layer`.

    Returns:
        A `tf.keras.Model` mapping [None, None, 3] inputs to tanh outputs
        (values in [-1, 1]) with `output_channels` channels.
    """
    down_stack = [
      downsample(DIM, SIZE, apply_norm=False, norm=norm),  # first stage has no normalization
      downsample(DIM*2, SIZE, norm=norm), 
      downsample(DIM*4, SIZE, norm=norm), 
#       downsample(DIM*8, SIZE, norm=norm),
#       downsample(DIM*8, SIZE, norm=norm), 
#       downsample(DIM*8, SIZE, norm=norm),
#       downsample(DIM*8, SIZE, norm=norm),
      downsample(DIM*8, SIZE, norm=norm),
    ]

    up_stack = [
#       upsample(DIM*8, SIZE, apply_dropout=True, norm=norm),
#       upsample(DIM*8, SIZE, apply_dropout=True, norm=norm), 
      upsample(DIM*8, SIZE, apply_dropout=True, norm=norm),
#       upsample(DIM*8, SIZE, norm=norm),
#       upsample(DIM*8, SIZE, norm=norm),
      upsample(DIM*4, SIZE, norm=norm),
      upsample(DIM*2, SIZE, norm=norm),
    ]

    # Final x4 upsampling to the output image; tanh keeps values in [-1, 1].
    last = Conv2DTranspose(output_channels, SIZE, strides=4, padding='same', kernel_initializer=initializer,activation='tanh')

    # A single (stateless) Concatenate layer is reused for every skip connection.
    concat = Concatenate()

    inputs = Input(shape=[None, None, 3])
    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Drop the bottleneck output and pair the remaining encoder outputs,
    # deepest first, with the decoder stages.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

def get_discriminator(norm='instance_norm'):
    """Build a fully convolutional discriminator that outputs a spatial map of logits.

    Three strided `downsample` stages (the first without normalization) are
    followed by zero-padding + 4x4 conv (DIM*8 filters) + norm + LeakyReLU, then
    zero-padding + a 1-filter conv with no activation (raw logits). For a 256x256
    input the output is a 2x2x1 logit map.
    """
    image_in = Input(shape=[None, None, 3], name='input_image')
    features = image_in
    for depth, use_norm in ((DIM, False), (DIM * 2, True), (DIM * 4, True)):
        features = downsample(depth, SIZE, apply_norm=use_norm, norm=norm)(features)
    features = ZeroPadding2D()(features)
    features = Conv2D(DIM * 8, SIZE, strides=1, kernel_initializer=initializer, use_bias=False)(features)
    features = _get_norm_layer(norm)()(features)
    features = LeakyReLU(alpha=0.2)(features)
    features = ZeroPadding2D()(features)
    logits = Conv2D(1, SIZE, strides=1, kernel_initializer=initializer)(features)
    return Model(inputs=image_in, outputs=logits)

### Plotting helpers (sample translations and training-loss curves)

In [None]:
# Matplotlib config
plt.ioff()
plt.rc('image', cmap='gray_r')
plt.rc('grid', linewidth=1)
plt.rc('xtick', top=False, bottom=False, labelsize='large')
plt.rc('ytick', left=False, right=False, labelsize='large')
plt.rc('axes', facecolor='F8F8F8', titlesize="large", edgecolor='white')
plt.rc('text', color='a8151a')
plt.rc('figure', facecolor='F0F0F0', figsize=(16,9))
# Matplotlib fonts
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")

def _show_image_pair(images, titles):
    """Show two images side by side in one row, rescaled from [-1, 1] to [0, 1]."""
    # A single row of two panels only needs a short canvas; an extremely tall
    # figure wastes space and can exceed the rendering backend's pixel limit.
    plt.figure(figsize=(16, 8))
    for i, (image, title) in enumerate(zip(images, titles)):
        plt.subplot(1, 2, i + 1)
        plt.title(title)
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()


def display_epoch_predict(photo2monet, photo, monet, monet2photo):
    """Show the first sample of each batch: photo vs. photo2monet, then monet vs. monet2photo.

    All four arguments are batches indexed along axis 0; values are presumably
    in [-1, 1], since they are mapped with x * 0.5 + 0.5 before display.
    """
    _show_image_pair([photo[0], photo2monet[0]], ['photo', 'photo2monet'])
    _show_image_pair([monet[0], monet2photo[0]], ['monet', 'monet2photo'])
         
class PlotTraining(tf.keras.callbacks.Callback):
    """Keras callback that plots training losses at the end of every epoch.

    - Every `sample_rate`-th batch of an epoch, all metrics except 'batch',
      'size' and those ending in 'gen_loss' (i.e. the discriminator losses) are
      recorded.
    - At every epoch end, all metrics except those ending in 'disc_loss' (i.e.
      the generator losses) are recorded.
    Keys ending in 'disc_loss' are drawn on the left axes (y in [0, 2]), the rest
    on the right axes (y in [0, 10]); the x-axis is measured in epochs.
    Logged values may be scalars or arrays; arrays contribute their last element.
    """

    def __init__(self, sample_rate=10, zoom=16):
        super().__init__()
        self.sample_rate = sample_rate  # record batch metrics every `sample_rate` batches
        self.step = 0  # global 0-based index of the most recent batch
        self.zoom = zoom  # kept for API compatibility; not used by the plots
        # `step` counts batches, so one epoch spans STEPS_PER_EPOCH of them.
        self.steps_per_epoch = STEPS_PER_EPOCH
        self._epoch = 0

    @staticmethod
    def _last_scalar(value):
        """Reduce a logged metric (scalar or per-example/per-patch array) to its last element."""
        return np.asarray(value).reshape(-1)[-1]

    def on_train_begin(self, logs=None):
        self.batch_history = {}
        self.batch_step = []
        self.epoch_history = {}
        self.epoch_step = []
        self.fig, self.axes = plt.subplots(1, 2, figsize=(16, 7))
        self.fig.subplots_adjust(wspace=0.12, hspace=0.12)
        plt.ioff()

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch = epoch

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        # Derive the global batch index from the epoch and the in-epoch batch index,
        # so the x-axis is correct even if batch hooks are not called every step.
        self.step = self._epoch * self.steps_per_epoch + batch
        if batch % self.sample_rate != 0:
            return
        self.batch_step.append(self.step)
        for k, v in logs.items():
            # skip bookkeeping entries and the generator losses (plotted per epoch)
            if k in ('batch', 'size') or k.endswith('gen_loss'):
                continue
            self.batch_history.setdefault(k, []).append(self._last_scalar(v))

    def on_epoch_end(self, epoch, logs=None):
        from IPython.display import display

        logs = logs or {}
        # Close the pyplot figure, presumably so it is not auto-rendered; it is
        # redrawn below and shown explicitly with display().
        plt.close(self.fig)
        for ax in self.axes:
            ax.cla()

        self.axes[0].set_ylim(0, 2)
        self.axes[1].set_ylim(0, 10)

        # Epoch markers sit at the number of batches completed so far.
        self.epoch_step.append((epoch + 1) * self.steps_per_epoch)

        for k, v in logs.items():
            # per-epoch curves: everything except the discriminator losses
            if k.endswith('disc_loss'):
                continue
            self.epoch_history.setdefault(k, []).append(self._last_scalar(v))

        for k, v in self.batch_history.items():
            self.axes[0 if k.endswith('disc_loss') else 1].plot(np.array(self.batch_step) / self.steps_per_epoch, v, label=k)

        for k, v in self.epoch_history.items():
            self.axes[0 if k.endswith('disc_loss') else 1].plot(np.array(self.epoch_step) / self.steps_per_epoch, v, label=k, linewidth=3)

        for ax in self.axes:
            ax.legend()
            ax.set_xlabel('epochs')
            ax.minorticks_on()
            ax.grid(True, which='major', axis='both', linestyle='-', linewidth=1)
            ax.grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)
        display(self.fig)

In [None]:
with strategy.scope():
    LAMBDA = 10  # weight of the cycle-consistency loss; the identity loss uses half of it
    # Element-wise BCE on logits with no reduction: callers receive unreduced tensors.
    loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def discriminator_loss(real, generated):
        """Half the sum of BCE(real logits -> 1) and BCE(generated logits -> 0)."""
        return (loss_obj(tf.ones_like(real), real) + loss_obj(tf.zeros_like(generated), generated)) * 0.5

    def generator_loss(generated):
        """BCE pushing the discriminator's logits on generated images towards 1."""
        target = tf.ones_like(generated)
        return loss_obj(target, generated)

    def _mean_abs_error(a, b):
        """Scalar mean absolute difference between two tensors."""
        return tf.reduce_mean(tf.abs(a - b))

    def calc_cycle_loss(real_image, cycled_image):
        """LAMBDA-weighted L1 distance between an image and its round-trip reconstruction."""
        return LAMBDA * _mean_abs_error(real_image, cycled_image)

    def identity_loss(real_image, same_image):
        """Half-LAMBDA-weighted L1 distance between an image and the generator's output on it."""
        return LAMBDA * 0.5 * _mean_abs_error(real_image, same_image)


### Define the learning-rate schedule and optimizers

In [None]:
class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule: constant, then linear decay to zero.

    For ``step < step_decay`` the rate is ``initial_learning_rate``. From
    ``step_decay`` on it decreases linearly and reaches 0 at
    ``step == total_steps``; beyond that it is clamped at 0 instead of going
    negative.

    Every call also writes the returned rate into ``current_learning_rate``
    (a non-trainable float32 variable), so the latest value can be read back.

    Args:
        initial_learning_rate: starting (plateau) learning rate.
        total_steps: step at which the rate reaches zero.
        step_decay: step at which the linear decay starts.
    """

    def __init__(self, initial_learning_rate, total_steps, step_decay):
        super(LinearDecay, self).__init__()
        self._initial_learning_rate = initial_learning_rate
        self._steps = total_steps
        self._step_decay = step_decay
        self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate, trainable=False, dtype=tf.float32)

    def __call__(self, step):
        def _decayed():
            # `step` is presumably a float tensor, since it is combined with a
            # float coefficient below — TODO confirm for non-optimizer callers.
            lr = self._initial_learning_rate * (1 - 1 / (self._steps - self._step_decay) * (step - self._step_decay))
            # Clamp so training past `total_steps` cannot yield a negative rate.
            return tf.maximum(lr, 0.0)

        self.current_learning_rate.assign(tf.cond(
            step >= self._step_decay,
            true_fn=_decayed,
            false_fn=lambda: self._initial_learning_rate
        ))
        return self.current_learning_rate

    def get_config(self):
        # Keys mirror __init__'s parameters so from_config(get_config()) round-trips.
        return {
            'initial_learning_rate': self._initial_learning_rate,
            'total_steps': self._steps,
            'step_decay': self._step_decay,
        }
    
with strategy.scope():
    # Constant rate for the first STEPS_PER_EPOCH steps (one epoch), then linear
    # decay reaching zero after EPOCHS * STEPS_PER_EPOCH steps.
    lr_schedule_gen = LinearDecay(START_LEARNING_RATE, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)
    lr_schedule_disc = LinearDecay(START_LEARNING_RATE_DISC, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)
    
    # Both generator optimizers share one schedule object and both discriminator
    # optimizers share the other, so each schedule's `current_learning_rate`
    # variable is overwritten by whichever optimizer queried it last.
    monet_generator_optimizer = Adam(learning_rate=lr_schedule_gen, beta_1=0.5)
    photo_generator_optimizer = Adam(learning_rate=lr_schedule_gen, beta_1=0.5)
    monet_discriminator_optimizer = Adam(learning_rate=lr_schedule_disc, beta_1=0.5)
    photo_discriminator_optimizer = Adam(learning_rate=lr_schedule_disc, beta_1=0.5)

    # monet_generator: photo -> Monet; photo_generator: Monet -> photo (see CycleGan.train_step).
    monet_generator = get_generator(norm='instance_norm') # layer_norm batch_norm instance_norm
    photo_generator = get_generator(norm='instance_norm') # layer_norm batch_norm instance_norm
    monet_discriminator = get_discriminator(norm='instance_norm') # layer_norm batch_norm instance_norm
    photo_discriminator = get_discriminator(norm='instance_norm') # layer_norm batch_norm instance_norm

In [None]:
class CycleGan(tf.keras.Model):
    """CycleGAN training wrapper around two generators and two discriminators.

    `m_gen` maps photos to Monet-style images and `p_gen` maps Monet images to
    photos; `m_disc` / `p_disc` score Monet / photo images. `call` runs `m_gen`.
    """

    def __init__(self, m_gen, p_gen, m_disc, p_disc):
        super(CycleGan, self).__init__()
        self.m_gen = m_gen
        self.p_gen = p_gen
        self.m_disc = m_disc
        self.p_disc = p_disc
      
    def compile(
            self, 
            m_gen_optimizer, 
            p_gen_optimizer,
            m_disc_optimizer, 
            p_disc_optimizer, 
            disc_loss_fn, 
            gen_loss_fn, 
            calc_cycle_loss_fn, 
            identity_loss_fn,
            steps_per_execution
        ):
        """Store the four optimizers and loss functions used by `train_step`.

        `steps_per_execution` is forwarded to `tf.keras.Model.compile`, which is
        what makes Keras run that many training steps per `tf.function` call;
        storing it on the instance alone had no effect on training.
        """
        super(CycleGan, self).compile(steps_per_execution=steps_per_execution)
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.disc_loss_fn = disc_loss_fn
        self.gen_loss_fn = gen_loss_fn
        self.calc_cycle_loss = calc_cycle_loss_fn
        self.identity_loss = identity_loss_fn
        self.steps_per_execution = steps_per_execution
        
    def call(self, inputs, training=False):
        return self.m_gen(inputs)
    
    def train_step(self, input_batch):
        """Run one joint generator/discriminator update on a (monet, photo) batch.

        Returns the total generator losses and the discriminator losses as
        returned by the loss functions (not reduced here).
        """
        real_monet, real_photo = input_batch
        # Persistent tape: four separate gradients are taken from one forward pass.
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # generating itself
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator used to check, inputing real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator used to check, inputing fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # evaluates total cycle consistency loss
            total_cycle_loss =  self.calc_cycle_loss(real_monet, cycled_monet) + self.calc_cycle_loss(real_photo, cycled_photo)

            # evaluates total generator loss 
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss(real_monet, same_monet)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss(real_photo, same_photo)

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss, self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients, self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients, self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients, self.p_disc.trainable_variables))
        
        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

In [None]:
# Assemble the CycleGAN from the four networks and attach optimizers and losses.
with strategy.scope():    
    model = CycleGan(
        m_gen=monet_generator,
        p_gen=photo_generator,
        m_disc=monet_discriminator,
        p_disc=photo_discriminator
    )
    model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        disc_loss_fn = discriminator_loss, 
        gen_loss_fn = generator_loss, 
        calc_cycle_loss_fn = calc_cycle_loss, 
        identity_loss_fn = identity_loss,
        # Forwarded to CycleGan.compile; scaled down by the replica count.
        steps_per_execution = STEPS_PER_EPOCH//strategy.num_replicas_in_sync
    )

In [None]:
# Save the model weights to a single HDF5 file every SAVE_FREQ batches.
# NOTE(review): per the Keras ModelCheckpoint docs, `monitor`/`mode` only matter
# with save_best_only=True, which is not set, so the file is simply overwritten.
checkpoint_path = "/kaggle/working/cp-last-tpu.h5"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    monitor='monet_gen_loss', 
    verbose=1, 
    mode='min', 
    save_weights_only=True,
    save_freq=SAVE_FREQ
)


# Loss-curve plotter: samples batch metrics every 8th batch (`zoom` is not used for plotting).
plot_training = PlotTraining(sample_rate=8, zoom=12)

class DisplayCallback(tf.keras.callbacks.Callback):
    """After each epoch: clear the cell output, preview both translation directions
    on one batch from `test_dataset`, and print the current generator and
    discriminator learning rates as reported by the optimizers' `_decayed_lr`."""

    def __init__(self):
        super().__init__()
        self._supports_tf_logs = True

    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        sample_monet, sample_photo = next(iter(test_dataset))
        generated_monet = monet_generator.predict(sample_photo)
        generated_photo = photo_generator.predict(sample_monet)
        display_epoch_predict(generated_monet, sample_photo, sample_monet, generated_photo)
        gen_lr = self.model.m_gen_optimizer._decayed_lr(tf.float32)
        print(f'gen lr: {gen_lr}')
        disc_lr = self.model.m_disc_optimizer._decayed_lr(tf.float32)
        print(f'disc lr: {disc_lr}')

class SaveModelsWeightsCallback(tf.keras.callbacks.Callback):
    """Overwrite both generators' weight files in the working directory after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        targets = (
            (monet_generator, './monet_generator-tpu.h5'),
            (photo_generator, './photo_generator-tpu.h5'),
        )
        for generator, weights_path in targets:
            generator.save_weights(weights_path)

# Callbacks passed to model.fit (preview, periodic checkpoint, loss plots, generator weights).
callbacks_list = [DisplayCallback(), cp_callback, plot_training, SaveModelsWeightsCallback()] # lr_decay_callback, cp_callback plot_training, , plot_training

In [None]:
# model.built = True
# model.predict(next(iter(dataset))[0])
# model.load_weights('../input/checkpoints/cp-last-tpu.h5')

In [None]:
# Train on the (monet, photo) pipeline; an epoch is defined by steps_per_epoch.
model.fit(dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, callbacks=callbacks_list)

In [None]:
# Save the weights of the whole CycleGan model after training.
model.save_weights('./cp-last-tpu.h5')

In [None]:
from tensorflow.keras import layers

with strategy.scope():
    # Feature extractor for FID: InceptionV3 truncated at the "mixed9" block and
    # globally average-pooled. The new Model's output replaces the pooling="avg"
    # head, so the features come from mixed9 (despite the variable name `mix3`).
    inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)
    mix3  = inception_model.get_layer("mixed9").output
    f0 = tf.keras.layers.GlobalAveragePooling2D()(mix3)
    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False
    
    def calculate_activation_statistics_mod(images,fid_model):
            """Return (mean, covariance) of `fid_model` activations over `images`.

            The covariance is computed as E[x x^T] - mu mu^T over all predicted
            rows (the biased, divide-by-N estimator).
            """
            act=tf.cast(fid_model.predict(images), tf.float32)
            mu = tf.reduce_mean(act, axis=0)
            mean_x = tf.reduce_mean(act, axis=0, keepdims=True)
            mx = tf.matmul(tf.transpose(mean_x), mean_x)
            vx = tf.matmul(tf.transpose(act), act)/tf.cast(tf.shape(act)[0], tf.float32)
            sigma = vx - mx
            return mu, sigma
    # Reference statistics of real Monet paintings.
    myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(fid_monet_ds,inception_model)        
    # Not referenced elsewhere in the visible code.
    fids=[]
with strategy.scope():
    def calculate_frechet_distance(mu1,sigma1,mu2,sigma2):
        """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

        Returns ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrtm(sigma1 @ sigma2))
        as a (1, 1) tensor. The matrix square root is taken in complex64 and only
        its real part is kept.
        """
        mean_diff = mu1 - mu2
        mean_term = tf.matmul(tf.expand_dims(mean_diff, axis=0), tf.expand_dims(mean_diff, axis=1))
        sqrt_product = tf.linalg.sqrtm(tf.cast(tf.matmul(sigma1, sigma2), tf.complex64))
        trace_sqrt = tf.linalg.trace(tf.cast(tf.math.real(sqrt_product), tf.float32))
        return mean_term + tf.linalg.trace(sigma1) + tf.linalg.trace(sigma2) - 2 * trace_sqrt
    
    def FID(images,gen_model,inception_model=inception_model,myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
        """FID between gen_model(images) and the reference statistics (mu2, sigma2).

        A temporary model chains `gen_model` into `inception_model`, and its
        activation statistics over `images` are compared to the reference ones.
        """
        source = layers.Input(shape=[256, 256, 3], name='input_image')
        features = inception_model(gen_model(source))
        fid_model = tf.keras.Model(inputs=source, outputs=features)
        mu1, sigma1 = calculate_activation_statistics_mod(images, fid_model)
        return calculate_frechet_distance(mu1, sigma1, myFID_mu2, myFID_sigma2)

In [None]:
# `test_dataset` yields (monet, photo) batch pairs, and Keras predict() would treat
# the first element (Monet) as the input. Keep only the photos so FID scores
# monet_generator(photo) against the Monet reference statistics.
FID(test_dataset.take(300).map(lambda monet, photo: photo), monet_generator) 
# 35 Score for FID 9.726887

In [None]:
# model.built = True
# model.call(next(iter(test_photo)))
# model.predict(next(iter(test_photo)))
# model.load_weights('./cp-138-loss1.11.h5')

# Show 15 input photos next to the model's (monet_generator) translation.
for _, photo in test_dataset.take(15):
    translated = model.predict(photo)
    plt.figure(figsize=(16, 8))
    for i, (image, title) in enumerate(((photo[0], 'Input Image'), (translated[0], 'Result'))):
        plt.subplot(1, 2, i + 1)
        plt.title(title)
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()