Resources:

* [CS109B Data Science 2 (Harvard): Generative Adversarial Networks (GANs) — Vincent Casser, Pavlos Protopapas](https://harvard-iacs.github.io/2019-CS109B/a-sections/a-section8/presentation/cs109b_asec8_slides_gan.pdf)
* [How to Develop a CycleGAN for Image-to-Image Translation with Keras](https://machinelearningmastery.com/cyclegan-tutorial-with-keras/)
* [Image2Art Translation Using CycleGAN](https://sahiltinky94.medium.com/image2art-translation-using-cyclegan-e1bc096b7315)
* [Transforming the World Into Paintings with CycleGAN](https://medium.com/analytics-vidhya/transforming-the-world-into-paintings-with-cyclegan-6748c0b85632)
* [Tensorflow CycleGAN](https://www.tensorflow.org/tutorials/generative/cyclegan)
* [Keras CycleGAN](https://keras.io/examples/generative/cyclegan/)  
* [CycleGAN-Tensorflow-2 ResNet](https://github.com/LynnHo/CycleGAN-Tensorflow-2) 
* [AttGAN-Tensorflow (U-Net-style encoder–decoder)](https://github.com/LynnHo/AttGAN-Tensorflow) 


In [None]:
import os, random, time, math, datetime, glob, sys, warnings, tqdm, functools, gc
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import tensorflow as tf
from tensorflow.keras import backend as K
from kaggle_datasets import KaggleDatasets
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Concatenate, Input,  ReLU, Layer, Dropout, ZeroPadding2D
import tensorflow_addons as tfa
import tensorflow.keras as keras
# Sentinel that lets tf.data choose parallelism / prefetch depth automatically.
# It is also reused below as the default `repeat` argument of the dataset helpers.
AUTO = tf.data.experimental.AUTOTUNE

# Suppress all Python warnings to keep notebook output readable.
warnings.filterwarnings("ignore")

In [None]:
# Choose a distribution strategy, in order of preference:
#   1. TPU, if a TPU cluster resolver can be constructed;
#   2. MirroredStrategy over every logical GPU that TensorFlow can see;
#   3. the default (single-device / CPU) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    # No TPU available: remember the GPUs for the fallback branches below.
    tpu = None
    gpus = tf.config.experimental.list_logical_devices("GPU")

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    print('Running on TPU ', tpu.master())
elif gpus:
    strategy = tf.distribute.MirroredStrategy(gpus)
    print('Running on ', len(gpus), ' GPU(s) ')
else:
    strategy = tf.distribute.get_strategy()
    print('Running on CPU')

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# ---- Data locations -------------------------------------------------------
# Alternative (commented out): resolve the dataset through KaggleDatasets' GCS path.
# GCS_DS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')
GCS_DS_PATH = '../input/gan-getting-started'

# JPEG file lists (consumed by make_zip_dataset), shuffled in place.
MONET_FILENAMES = tf.io.gfile.glob(f'{GCS_DS_PATH}/monet_jpg/*.jpg')
PHOTO_FILENAMES = tf.io.gfile.glob(f'{GCS_DS_PATH}/photo_jpg/*.jpg')
random.shuffle(MONET_FILENAMES)
random.shuffle(PHOTO_FILENAMES)

# TFRecord shard lists (consumed by load_dataset / get_gan_dataset), shuffled in place.
PHOTO_FILENAMES_TFREC = tf.io.gfile.glob(f'{GCS_DS_PATH}/photo_tfrec/*.tfrec')
MONET_FILENAMES_TFREC = tf.io.gfile.glob(f'{GCS_DS_PATH}/monet_tfrec/*.tfrec')
random.shuffle(PHOTO_FILENAMES_TFREC)
random.shuffle(MONET_FILENAMES_TFREC)

# ---- Training schedule ----------------------------------------------------
BATCH_SIZE = 1 * strategy.num_replicas_in_sync   # one image per replica per step
STEPS_PER_EPOCH = 400
EPOCHS = 200

# ---- Image geometry / input pipeline --------------------------------------
BUFFER_SIZE = 300                                # shuffle buffer size
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

DATA_SIZE = BATCH_SIZE * STEPS_PER_EPOCH
IMAGE_SIZE = [IMAGE_HEIGHT, IMAGE_WIDTH]
SAVE_FREQ = STEPS_PER_EPOCH * int(EPOCHS * 0.2)  # every 20% of EPOCHS, expressed in steps
SHAPE = (3, *IMAGE_SIZE) if K.image_data_format() == 'channels_first' else (*IMAGE_SIZE, 3)

# ---- Model hyper-parameters -----------------------------------------------
DIM = 64                        # base filter count (default `dim` of both networks)
SIZE = 4                        # not referenced in the visible cells
START_LEARNING_RATE = 3e-5      # presumably the generator learning rate (used in later cells)
START_LEARNING_RATE_DISC = 1e-5  # presumably the discriminator learning rate (used in later cells)
N_DOWNSAMPLINGS = 2             # stride-2 stages in ResnetGenerator
N_DOWNSAMPLINGS_DESC = 3        # stride-2 stages in ConvDiscriminator
N_BLOCKS = 9                    # residual blocks in ResnetGenerator

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 image scaled to [-1, 1].

    The result is reshaped to the global ``SHAPE``.

    NOTE(review): decode_jpeg yields (H, W, C). When ``SHAPE`` is channels-first
    (3, H, W), ``tf.reshape`` reinterprets the memory rather than transposing it,
    so the pixels get scrambled. Confirm that channels_last is the only supported
    configuration.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    scaled = pixels / 127.5 - 1
    return tf.reshape(scaled, SHAPE)

@tf.function
def _map_fn(img):
    """Resize to 286x286, then map pixel values from [0, 255] to [-1, 1].

    Values are clipped to [0, 255] before normalizing. No crop is applied here.
    """
    resized = tf.image.resize(img, [286, 286])
    unit_range = tf.clip_by_value(resized, 0, 255) / 255.0
    return unit_range * 2 - 1

@tf.function
def augmentation(image):
    """Randomly crop and flip a single (H, W, 3) image.

    Steps:
    1. Enlarge each side by ~15% with nearest-neighbour resizing.
    2. Take a random crop back to (IMAGE_HEIGHT, IMAGE_WIDTH).
    3. Apply random vertical and horizontal flips.

    Height and width are handled independently. The previous version cropped to
    ``[IMAGE_WIDTH, IMAGE_HEIGHT, 3]`` (width first), which is wrong for
    non-square images because ``random_crop`` expects height first.
    """
    aug_height = IMAGE_HEIGHT + int(IMAGE_HEIGHT * 0.15)
    aug_width = IMAGE_WIDTH + int(IMAGE_WIDTH * 0.15)
    image = tf.image.resize(image, [aug_height, aug_width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.image.random_crop(image, size=[IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_flip_left_right(image)
    return image

def read_tfrecord(example):
    """Parse one serialized TFRecord example and decode its "image" feature."""
    features = tf.io.parse_single_example(
        example,
        {"image": tf.io.FixedLenFeature([], tf.string)},
    )
    return decode_image(features['image'])
     
def load_dataset(filenames):
    """Build a dataset of decoded images from the given TFRecord files (unbatched)."""
    records = tf.data.TFRecordDataset(filenames)
    images = records.map(read_tfrecord, num_parallel_calls=AUTO)
    return images


# Large-batch (32 images per replica) TFRecord datasets. The names suggest fast
# inference and FID evaluation — presumably used in later cells; confirm.
# All photos, shuffled and batched:
fast_photo_ds = load_dataset(PHOTO_FILENAMES_TFREC).shuffle(BUFFER_SIZE).batch(32*strategy.num_replicas_in_sync, drop_remainder=True).prefetch(AUTO)
# A random subset of 1024 photos:
fid_photo_ds = load_dataset(PHOTO_FILENAMES_TFREC).shuffle(BUFFER_SIZE).take(1024).batch(32*strategy.num_replicas_in_sync, drop_remainder=True).prefetch(AUTO)
# All Monet paintings:
fid_monet_ds = load_dataset(MONET_FILENAMES_TFREC).shuffle(BUFFER_SIZE).batch(32*strategy.num_replicas_in_sync, drop_remainder=True).prefetch(AUTO)

def get_gan_dataset():
    """Return an infinite, zipped dataset of (monet_batch, photo_batch) from TFRecords.

    Each side is shuffled, batched to BATCH_SIZE (dropping partial batches),
    cached, shuffled again at batch level, repeated forever and prefetched.
    """
    def _pipeline(filenames):
        return (load_dataset(filenames)
                .shuffle(BUFFER_SIZE)
                .batch(BATCH_SIZE, drop_remainder=True)
                .cache()
                .shuffle(BUFFER_SIZE)
                .repeat()
                .prefetch(AUTO))

    monet_batches = _pipeline(MONET_FILENAMES_TFREC)
    photo_batches = _pipeline(PHOTO_FILENAMES_TFREC)
    return tf.data.Dataset.zip((monet_batches, photo_batches))

In [None]:
import multiprocessing
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=AUTO):
    """Shuffle, filter/map, batch, repeat and prefetch a `tf.data.Dataset`.

    Parameters
    ----------
    dataset : tf.data.Dataset
    batch_size : int
    drop_remainder : bool
        Drop the final partial batch.
    n_prefetch_batch : int
        Number of batches to prefetch.
    filter_fn, map_fn : callable or None
        Optional element filter / transform.
    n_map_threads : int or None
        Parallelism for `map`; defaults to `multiprocessing.cpu_count()`.
    filter_after_map : bool
        Apply `filter_fn` after `map_fn` (slower) instead of before.
    shuffle : bool
        Shuffle elements before filter/map.
    shuffle_buffer_size : int or None
        Defaults to ``max(batch_size * 128, 2048)``.
    repeat : int or None
        Forwarded to `Dataset.repeat`. ``None`` or ``-1`` (the default `AUTO`)
        repeats indefinitely; a positive int repeats that many times. This
        argument used to be ignored (the dataset always repeated forever), which
        prevented `make_zip_dataset` from making the longer side finite.
    """
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048

    # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
    else:  # [*] this is slower
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    # Honour the caller's repeat count instead of always repeating forever.
    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=AUTO):
    """Slice in-memory data into a dataset and run it through `batch_dataset`.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
        Sliced along the first axis by `tf.data.Dataset.from_tensor_slices`.

    All remaining arguments are forwarded unchanged to `batch_dataset`.
    """
    slices = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(
        slices,
        batch_size,
        drop_remainder=drop_remainder,
        n_prefetch_batch=n_prefetch_batch,
        filter_fn=filter_fn,
        map_fn=map_fn,
        n_map_threads=n_map_threads,
        filter_after_map=filter_after_map,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        repeat=repeat,
    )


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=AUTO):
    """Batch dataset of images read from disk (decoded with `tf.image.decode_jpeg`, 3 channels).

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists, optional
        When given, each element becomes ``(img, *labels)``. Previously this
        path crashed because the parser accepted only the image path.
    map_fn : callable, optional
        Applied to the parsed tuple, i.e. ``map_fn(img)`` or ``map_fn(img, *labels)``.

    Remaining arguments are forwarded to `memory_data_batch_dataset`.
    Without `map_fn`, each element is a 1-tuple ``(img,)`` (float32, unnormalized).
    """
    if labels is None:
        memory_data = img_paths
    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):
        # `label` is empty when no labels were provided.
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, 3)  # fix channels to 3
        img = tf.cast(img, tf.float32)
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn`
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset

def make_dataset(img_paths, batch_size, load_size, crop_size, training, drop_remainder=True, shuffle=True, repeat=AUTO):
    """Build a batched image dataset from JPEG paths with train/eval preprocessing.

    Training: random horizontal flip, resize to ``load_size`` x ``load_size``,
    then a random ``crop_size`` crop. Evaluation: resize straight to
    ``crop_size`` x ``crop_size``. In both modes pixel values are clipped to
    [0, 255] and mapped to [-1, 1].
    """
    def _to_signed_unit(img):
        # [0, 255] -> [0, 1] -> [-1, 1]
        return tf.clip_by_value(img, 0, 255) / 255.0 * 2 - 1

    @tf.function
    def _train_map_fn(img):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.resize(img, [load_size, load_size])
        img = tf.image.random_crop(img, [crop_size, crop_size, tf.shape(img)[-1]])
        return _to_signed_unit(img)

    @tf.function
    def _eval_map_fn(img):
        return _to_signed_unit(tf.image.resize(img, [crop_size, crop_size]))

    return disk_image_batch_dataset(
        img_paths,
        batch_size,
        drop_remainder=drop_remainder,
        map_fn=_train_map_fn if training else _eval_map_fn,
        shuffle=shuffle,
        repeat=repeat,
    )


def make_zip_dataset(A_img_paths=MONET_FILENAMES, B_img_paths=PHOTO_FILENAMES, batch_size=BATCH_SIZE, load_size=IMAGE_WIDTH+int(IMAGE_WIDTH*0.15), crop_size=IMAGE_WIDTH, training=True, shuffle=True, repeat=AUTO):
    """Zip two image datasets (domain A, domain B) into one dataset of batch pairs.

    With a truthy ``repeat`` both sides cycle forever. Otherwise the longer list
    is read once and the shorter one cycles, so the pair is aligned to the longer
    list.

    Returns:
        (A_B_dataset, len_dataset), where ``len_dataset`` is the number of batches
        in the longer path list (integer division by ``batch_size``).
    """
    if repeat:
        A_repeat, B_repeat = None, None            # cycle both
    elif len(A_img_paths) >= len(B_img_paths):
        A_repeat, B_repeat = 1, None               # cycle the shorter one (B)
    else:
        A_repeat, B_repeat = None, 1               # cycle the shorter one (A)

    common = dict(drop_remainder=True, shuffle=shuffle)
    A_dataset = make_dataset(A_img_paths, batch_size, load_size, crop_size, training, repeat=A_repeat, **common)
    B_dataset = make_dataset(B_img_paths, batch_size, load_size, crop_size, training, repeat=B_repeat, **common)

    longest = max(len(A_img_paths), len(B_img_paths))
    return tf.data.Dataset.zip((A_dataset, B_dataset)), longest // batch_size


class ItemPool:
    """History buffer of past samples (e.g. generated images).

    Until the pool holds ``pool_size`` items, every incoming item is stored and
    returned as-is. After that, each incoming item has a 50% chance of being
    swapped with a random stored item (the stored one is returned and the
    incoming one takes its place); otherwise it is returned unchanged.
    A ``pool_size`` of 0 disables the pool entirely.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.items = []

    def __call__(self, in_items):
        """Filter a batch (iterable over its first axis) through the pool; returns a stacked tensor."""
        if self.pool_size == 0:
            return in_items

        selected = []
        for item in in_items:
            if len(self.items) < self.pool_size:
                # Pool still filling: keep a copy and pass the item through.
                self.items.append(item)
                selected.append(item)
                continue
            if np.random.rand() > 0.5:
                slot = np.random.randint(0, len(self.items))
                selected.append(self.items[slot])
                self.items[slot] = item
            else:
                selected.append(item)
        return tf.stack(selected, axis=0)

In [None]:
# JPEG training pipeline: zipped (Monet batch, photo batch) pairs using
# make_zip_dataset's defaults, plus the batch count of the longer file list.
dataset, len_dataset = make_zip_dataset()
# Alternative (Monet batch, photo batch) pipeline built from the TFRecord shards.
test_dataset=get_gan_dataset()

In [None]:
def plot_images(sample_monet1, sample_monet2):
    """Show the first image of two batches side by side.

    The left panel shows ``sample_monet2[0]`` (titled 'photo') and the right
    panel shows ``sample_monet1[0]`` (titled 'monet'), matching the call
    ``plot_images(monet, photo)``. Images are assumed to be in [-1, 1] and are
    rescaled to [0, 1] for display.

    The figure height was previously 816 inches (a typo), which produced an
    enormous canvas; it is now 8 inches, suitable for a 1x2 grid of square images.
    """
    plt.figure(figsize=(16, 8))
    display_list = [sample_monet2[0], sample_monet1[0]]
    title = ['photo', 'monet']

    for i, (img, name) in enumerate(zip(display_list, title)):
        plt.subplot(1, 2, i + 1)
        plt.title(name)
        plt.imshow(img * 0.5 + 0.5)
        plt.axis('off')
    plt.show()
    
# Sanity check: display the first (Monet, photo) pair of two training batches.
for monet, photo in dataset.take(2):
    plot_images(monet, photo)

### Loss functions

In [None]:
def get_gan_losses_fn():
    """Return (d_loss_fn, g_loss_fn) for the standard GAN with sigmoid cross-entropy on logits.

    d_loss_fn(r_logit, f_logit) -> (real_loss, fake_loss); targets are 1 for real and 0 for fake.
    g_loss_fn(f_logit) -> loss pushing fake logits toward the "real" target 1.
    """
    bce = tf.losses.BinaryCrossentropy(from_logits=True)

    def d_loss_fn(r_logit, f_logit):
        real_term = bce(tf.ones_like(r_logit), r_logit)
        fake_term = bce(tf.zeros_like(f_logit), f_logit)
        return real_term, fake_term

    def g_loss_fn(f_logit):
        return bce(tf.ones_like(f_logit), f_logit)

    return d_loss_fn, g_loss_fn


def get_hinge_v1_losses_fn():
    """Return (d_loss_fn, g_loss_fn) for hinge loss, with a hinge on the generator side too.

    Discriminator: mean(max(1 - D(real), 0)) and mean(max(1 + D(fake), 0)).
    Generator:     mean(max(1 - D(fake), 0)).
    """
    def d_loss_fn(r_logit, f_logit):
        real_term = tf.reduce_mean(tf.maximum(1 - r_logit, 0))
        fake_term = tf.reduce_mean(tf.maximum(1 + f_logit, 0))
        return real_term, fake_term

    def g_loss_fn(f_logit):
        return tf.reduce_mean(tf.maximum(1 - f_logit, 0))

    return d_loss_fn, g_loss_fn


def get_hinge_v2_losses_fn():
    """Return (d_loss_fn, g_loss_fn) for hinge loss with a linear generator objective.

    Discriminator: mean(max(1 - D(real), 0)) and mean(max(1 + D(fake), 0)).
    Generator:     mean(-D(fake)).
    """
    def d_loss_fn(r_logit, f_logit):
        real_term = tf.reduce_mean(tf.maximum(1 - r_logit, 0))
        fake_term = tf.reduce_mean(tf.maximum(1 + f_logit, 0))
        return real_term, fake_term

    def g_loss_fn(f_logit):
        return tf.reduce_mean(tf.negative(f_logit))

    return d_loss_fn, g_loss_fn


def get_lsgan_losses_fn():
    """Return (d_loss_fn, g_loss_fn) for least-squares GAN (mean squared error to 1/0 targets)."""
    mse = tf.losses.MeanSquaredError()

    def d_loss_fn(r_logit, f_logit):
        real_term = mse(tf.ones_like(r_logit), r_logit)
        fake_term = mse(tf.zeros_like(f_logit), f_logit)
        return real_term, fake_term

    def g_loss_fn(f_logit):
        return mse(tf.ones_like(f_logit), f_logit)

    return d_loss_fn, g_loss_fn


def get_wgan_losses_fn():
    """Return (d_loss_fn, g_loss_fn) for the Wasserstein critic objective.

    Critic: (-mean(D(real)), mean(D(fake))). Generator: -mean(D(fake)).
    """
    def d_loss_fn(r_logit, f_logit):
        return -tf.reduce_mean(r_logit), tf.reduce_mean(f_logit)

    def g_loss_fn(f_logit):
        return -tf.reduce_mean(f_logit)

    return d_loss_fn, g_loss_fn


def get_adversarial_losses_fn(mode):
    """Return ``(d_loss_fn, g_loss_fn)`` for the named adversarial loss.

    Args:
        mode: one of 'gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'.

    Raises:
        ValueError: for any other mode. Previously the function silently returned
            None, which only failed later when the caller unpacked the result.
    """
    if mode == 'gan':
        return get_gan_losses_fn()
    if mode == 'hinge_v1':
        return get_hinge_v1_losses_fn()
    if mode == 'hinge_v2':
        return get_hinge_v2_losses_fn()
    if mode == 'lsgan':
        return get_lsgan_losses_fn()
    if mode == 'wgan':
        return get_wgan_losses_fn()
    raise ValueError(
        f"Unknown adversarial loss mode {mode!r}; expected one of "
        "'gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'."
    )
@tf.function
def minmax_norm(x, epsilon=1e-12):
    """Rescale ``x`` (cast to float32) linearly so its global min maps to 0 and its global max to 1.

    The range is floored at ``epsilon`` so that a constant input does not divide by zero.
    """
    x = tf.cast(x, tf.float32)
    lo, hi = tf.reduce_min(x), tf.reduce_max(x)
    return (x - lo) / tf.maximum(hi - lo, epsilon)


@tf.function
def reshape(x, shape):
    """Reshape ``x``, resolving placeholder entries of ``shape`` against ``x``'s own axes.

    Entries of ``shape`` are interpreted positionally:
      * 0    -> replaced by the static size ``x.shape[i]``;
      * None -> replaced by the runtime size ``tf.shape(x)[i]``. This also applies
                when a 0 resolved to an unknown (None) static size;
      * any other value is passed to ``tf.reshape`` unchanged (so -1 still means "infer").
    The 0/None entries index ``x``'s axes by position, so they must refer to axes that exist.
    """
    x = tf.convert_to_tensor(x)
    # Pass 1: 0 -> static size of the matching axis.
    shape = [x.shape[i] if shape[i] == 0 else shape[i] for i in range(len(shape))]  # TODO(Lynn): is it slow here?
    # Pass 2: None (given, or produced by pass 1) -> dynamic size of the matching axis.
    shape = [tf.shape(x)[i] if shape[i] is None else shape[i] for i in range(len(shape))]
    return tf.reshape(x, shape)

def gradient_penalty(f, real, fake, mode):
    """Gradient-penalty regularizer for a critic/discriminator ``f``.

    Args:
        f: callable mapping a batch of inputs to predictions (e.g. a discriminator).
        real: batch of real samples.
        fake: batch of generated samples; ignored for 'dragan'.
        mode: one of
            - 'none': return 0 (dtype of ``real``);
            - 'dragan': penalty at points between ``real`` and a noise-perturbed copy of ``real``;
            - 'wgan-gp': penalty at random interpolates between ``real`` and ``fake``.

    Returns:
        Scalar tensor: the batch mean of ``(||grad_x f(x)||_2 - 1) ** 2``, or 0 for 'none'.

    Raises:
        ValueError: for an unknown ``mode``. Previously this surfaced as an
            UnboundLocalError on ``gp``.
    """
    def _gradient_penalty(f, real, fake=None):
        def _interpolate(a, b=None):
            if b is None:   # interpolation in DRAGAN
                # Perturb `a` with uniform noise scaled by half its standard deviation.
                beta = tf.random.uniform(shape=tf.shape(a), minval=0., maxval=1.)
                b = a + 0.5 * tf.math.reduce_std(a) * beta
            # One mixing coefficient per sample, broadcast over the remaining axes.
            shape = [tf.shape(a)[0]] + [1] * (a.shape.ndims - 1)
            alpha = tf.random.uniform(shape=shape, minval=0., maxval=1.)
            inter = a + alpha * (b - a)
            inter.set_shape(a.shape)
            return inter

        x = _interpolate(real, fake)
        with tf.GradientTape() as t:
            t.watch(x)
            pred = f(x)
        grad = t.gradient(pred, x)
        # Per-sample L2 norm of the flattened gradient.
        norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
        gp = tf.reduce_mean((norm - 1.)**2)

        return gp

    if mode == 'none':
        return tf.constant(0, dtype=real.dtype)
    if mode == 'dragan':
        return _gradient_penalty(f, real)
    if mode == 'wgan-gp':
        return _gradient_penalty(f, real, fake)
    raise ValueError(
        f"Unknown gradient penalty mode {mode!r}; expected 'none', 'dragan' or 'wgan-gp'."
    )

### make dataset @tf.function (earlier version, kept commented out)

In [None]:
# import multiprocessing
# def batch_dataset(dataset,
#                   batch_size,
#                   drop_remainder=True,
#                   n_prefetch_batch=1,
#                   filter_fn=None,
#                   map_fn=None,
#                   n_map_threads=None,
#                   filter_after_map=False,
#                   shuffle=True,
#                   shuffle_buffer_size=None,
#                   repeat=None):
#     # set defaults
#     if n_map_threads is None:
#         n_map_threads = multiprocessing.cpu_count()
#     if shuffle and shuffle_buffer_size is None:
#         shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048

#     # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
#     if shuffle:
#         dataset = dataset.shuffle(shuffle_buffer_size)

#     if not filter_after_map:
#         if filter_fn:
#             dataset = dataset.filter(filter_fn)

#         if map_fn:
#             dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

#     else:  # [*] this is slower
#         if map_fn:
#             dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

#         if filter_fn:
#             dataset = dataset.filter(filter_fn)

#     dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

#     dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

#     return dataset


# def memory_data_batch_dataset(memory_data,
#                               batch_size,
#                               drop_remainder=True,
#                               n_prefetch_batch=1,
#                               filter_fn=None,
#                               map_fn=None,
#                               n_map_threads=None,
#                               filter_after_map=False,
#                               shuffle=True,
#                               shuffle_buffer_size=None,
#                               repeat=None):
#     """Batch dataset of memory data.
#     Parameters
#     ----------
#     memory_data : nested structure of tensors/ndarrays/lists
#     """
#     dataset = tf.data.Dataset.from_tensor_slices(memory_data)
#     dataset = batch_dataset(dataset,
#                             batch_size,
#                             drop_remainder=drop_remainder,
#                             n_prefetch_batch=n_prefetch_batch,
#                             filter_fn=filter_fn,
#                             map_fn=map_fn,
#                             n_map_threads=n_map_threads,
#                             filter_after_map=filter_after_map,
#                             shuffle=shuffle,
#                             shuffle_buffer_size=shuffle_buffer_size,
#                             repeat=repeat)
#     return dataset


# def disk_image_batch_dataset(img_paths,
#                              batch_size,
#                              labels=None,
#                              drop_remainder=True,
#                              n_prefetch_batch=1,
#                              filter_fn=None,
#                              map_fn=None,
#                              n_map_threads=None,
#                              filter_after_map=False,
#                              shuffle=True,
#                              shuffle_buffer_size=None,
#                              repeat=None):
#     dataset = load_dataset(img_paths)
#     dataset = dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True).repeat().prefetch(AUTO)

#     return dataset

# def make_dataset(img_paths, batch_size, load_size, crop_size, training, drop_remainder=True, shuffle=True, repeat=1):
#     if training:
#         @tf.function
#         def _map_fn(img):  # preprocessing
#             img = tf.image.random_flip_left_right(img)
#             img = tf.image.resize(img, [load_size, load_size])
#             img = tf.image.random_crop(img, [crop_size, crop_size, tf.shape(img)[-1]])
#             img = tf.clip_by_value(img, 0, 255) / 255.0  # or img = tl.minmax_norm(img)
#             img = img * 2 - 1
#             return img
#     else:
#         @tf.function
#         def _map_fn(img):  # preprocessing
#             img = tf.image.resize(img, [crop_size, crop_size])  # or img = tf.image.resize(img, [load_size, load_size]); img = tl.center_crop(img, crop_size)
#             img = tf.clip_by_value(img, 0, 255) / 255.0  # or img = tl.minmax_norm(img)
#             img = img * 2 - 1
#             return img

#     return disk_image_batch_dataset(img_paths,
#                                        batch_size,
#                                        drop_remainder=drop_remainder,
#                                        map_fn=_map_fn,
#                                        shuffle=shuffle,
#                                        repeat=repeat)


# def make_zip_dataset(A_img_paths, B_img_paths, batch_size, load_size, crop_size, training, shuffle=True, repeat=False):
#     # zip two datasets aligned by the longer one
#     if repeat:
#         A_repeat = B_repeat = None  # cycle both
#     else:
#         if len(A_img_paths) >= len(B_img_paths):
#             A_repeat = 1
#             B_repeat = None  # cycle the shorter one
#         else:
#             A_repeat = None  # cycle the shorter one
#             B_repeat = 1

#     A_dataset = make_dataset(A_img_paths, batch_size, load_size, crop_size, training, drop_remainder=True, shuffle=shuffle, repeat=A_repeat)
#     B_dataset = make_dataset(B_img_paths, batch_size, load_size, crop_size, training, drop_remainder=True, shuffle=shuffle, repeat=B_repeat)

#     A_B_dataset = tf.data.Dataset.zip((A_dataset, B_dataset))
#     len_dataset = max(len(A_img_paths), len(B_img_paths)) // batch_size

#     return A_B_dataset, len_dataset


# class ItemPool:
#     def __init__(self, pool_size=50):
#         self.pool_size = pool_size
#         self.items = []

#     def __call__(self, in_items):
#         # `in_items` should be a batch tensor

#         if self.pool_size == 0:
#             return in_items

#         out_items = []
#         for in_item in in_items:
#             if len(self.items) < self.pool_size:
#                 self.items.append(in_item)
#                 out_items.append(in_item)
#             else:
#                 if np.random.rand() > 0.5:
#                     idx = np.random.randint(0, len(self.items))
#                     out_item, self.items[idx] = self.items[idx], in_item
#                     out_items.append(out_item)
#                 else:
#                     out_items.append(in_item)
#         return tf.stack(out_items, axis=0)

### ResnetGenerator ConvDiscriminator LinearDecay

In [None]:
def _get_norm_layer(norm):
    if norm == 'none':
        return lambda: lambda x: x
    elif norm == 'batch_norm':
        return keras.layers.BatchNormalization
    elif norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    elif norm == 'layer_norm':
        return keras.layers.LayerNormalization

def ResnetGenerator(input_shape=SHAPE,
                    output_channels=3,
                    dim=DIM,
                    n_downsamplings=N_DOWNSAMPLINGS,
                    n_blocks=N_BLOCKS, # number of residual blocks (default 9)
                    norm='instance_norm'):
    """Build the ResNet-style image-to-image generator as a functional Keras model.

    Architecture:
      1. Reflect-pad by 3, then a 7x7 conv with ``dim`` filters + norm + ReLU.
      2. ``n_downsamplings`` stride-2 3x3 convs, each doubling the filter count.
      3. ``n_blocks`` residual blocks at the bottleneck resolution.
      4. ``n_downsamplings`` stride-2 transposed convs, each halving the filter count.
      5. Reflect-pad by 3, then a 7x7 conv to ``output_channels``, then tanh,
         so outputs lie in (-1, 1).

    The padding specs pad axes 1 and 2, so inputs are expected as
    (batch, H, W, C), i.e. channels-last. NOTE(review): H and W presumably need
    to be divisible by 2**n_downsamplings for the output size to match the
    input — confirm.

    Args:
        input_shape: per-sample input shape (without batch axis).
        output_channels: channels of the generated image.
        dim: filter count of the first conv layer.
        n_downsamplings: number of down- (and up-) sampling stages.
        n_blocks: number of residual blocks.
        norm: normalization name understood by ``_get_norm_layer``.
    """
    Norm = _get_norm_layer(norm)

    def _residual_block(x):
        # Two reflect-padded 3x3 convs at the input's channel count, plus an identity skip.
        dim = x.shape[-1]
        h = x

        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        h = keras.layers.Conv2D(dim, 3, padding='valid', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        h = keras.layers.Conv2D(dim, 3, padding='valid', use_bias=False)(h)
        h = Norm()(h)

        return keras.layers.add([x, h])

    # 0
    h = inputs = keras.Input(shape=input_shape)

    # 1 - stem: reflect padding keeps the spatial size through the 7x7 'valid' conv
    h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
    h = keras.layers.Conv2D(dim, 7, padding='valid', use_bias=False)(h)
    h = Norm()(h)
    h = tf.nn.relu(h)

    # 2 - downsampling
    for _ in range(n_downsamplings):
        dim *= 2
        h = keras.layers.Conv2D(dim, 3, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

    # 3 - residual blocks
    for _ in range(n_blocks):
        h = _residual_block(h)

    # 4 - upsampling
    for _ in range(n_downsamplings):
        dim //= 2
        h = keras.layers.Conv2DTranspose(dim, 3, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

    # 5 - output head
    h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
    h = keras.layers.Conv2D(output_channels, 7, padding='valid')(h)
    h = tf.tanh(h)

    return keras.Model(inputs=inputs, outputs=h)

def ConvDiscriminator(input_shape=SHAPE,
                      dim=DIM,
                      n_downsamplings=N_DOWNSAMPLINGS_DESC,
                      norm='instance_norm'):
    """Build a convolutional discriminator that outputs a spatial map of logits.

    Architecture (all convs 4x4, LeakyReLU slope 0.2):
      1. A stride-2 conv with ``dim`` filters (bias, no norm).
      2. ``n_downsamplings - 1`` stride-2 convs + norm. The filter count doubles
         each time, capped at ``dim * 8``.
      3. One stride-1 conv + norm (filter count doubled again, same cap).
      4. A stride-1 conv to 1 channel with no activation (unnormalized logits).

    The output is therefore a one-channel map downsampled by 2**n_downsamplings,
    with one logit per receptive-field patch rather than a single scalar.

    Args:
        input_shape: per-sample input shape (without batch axis).
        dim: filter count of the first conv layer.
        n_downsamplings: number of stride-2 convolutions.
        norm: normalization name understood by ``_get_norm_layer``.
    """
    dim_ = dim  # remember the base width for the dim_ * 8 cap
    Norm = _get_norm_layer(norm)

    # 0
    h = inputs = keras.Input(shape=input_shape)

    # 1 - first downsampling conv (no normalization)
    h = keras.layers.Conv2D(dim, 4, strides=2, padding='same')(h)
    h = tf.nn.leaky_relu(h, alpha=0.2)

    for _ in range(n_downsamplings - 1):
        dim = min(dim * 2, dim_ * 8)
        h = keras.layers.Conv2D(dim, 4, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.leaky_relu(h, alpha=0.2)

    # 2 - stride-1 conv: widens the receptive field without further downsampling
    dim = min(dim * 2, dim_ * 8)
    h = keras.layers.Conv2D(dim, 4, strides=1, padding='same', use_bias=False)(h)
    h = Norm()(h)
    h = tf.nn.leaky_relu(h, alpha=0.2)

    # 3 - one-channel logit map
    h = keras.layers.Conv2D(1, 4, strides=1, padding='same')(h)

    return keras.Model(inputs=inputs, outputs=h)

class LinearDecay(keras.optimizers.schedules.LearningRateSchedule):
    """Constant learning rate followed by linear decay to zero.

    * ``step < step_decay``: the rate is ``initial_learning_rate``.
    * ``step_decay <= step``: the rate falls linearly and reaches 0 at
      ``total_steps``. Beyond that it is clamped at 0 instead of going negative.

    The latest computed rate is mirrored in the non-trainable variable
    ``self.current_learning_rate``, which is also what ``__call__`` returns.
    """

    def __init__(self, initial_learning_rate, total_steps, step_decay):
        super(LinearDecay, self).__init__()
        self._initial_learning_rate = initial_learning_rate
        self._steps = total_steps
        self._step_decay = step_decay
        self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate, trainable=False, dtype=tf.float32)

    def __call__(self, step):
        # Optimizers may pass an int64 iteration counter. Do the arithmetic in
        # float32 so it can be combined with the Python-float coefficients below.
        step = tf.cast(step, tf.float32)
        decay_span = self._steps - self._step_decay
        if decay_span <= 0:
            # Degenerate schedule (total_steps <= step_decay): avoid dividing by zero.
            decay_span = 1

        def _decayed():
            lr = self._initial_learning_rate * (1 - 1 / decay_span * (step - self._step_decay))
            # Clamp so steps beyond total_steps do not yield a negative rate.
            return tf.maximum(lr, 0.0)

        self.current_learning_rate.assign(tf.cond(
            step >= self._step_decay,
            true_fn=_decayed,
            false_fn=lambda: tf.constant(self._initial_learning_rate, dtype=tf.float32)
        ))
        return self.current_learning_rate

    def get_config(self):
        """Return constructor arguments so the schedule can be serialized with the optimizer."""
        return {
            "initial_learning_rate": self._initial_learning_rate,
            "total_steps": self._steps,
            "step_decay": self._step_decay,
        }

### Plotting helpers (learning rate, predictions, training curves)

In [None]:
# Matplotlib config: notebook-wide defaults for the plots below.
plt.ioff()
plt.rcParams.update({
    'image.cmap': 'gray_r',
    'grid.linewidth': 1,
    'xtick.top': False,
    'xtick.bottom': False,
    'xtick.labelsize': 'large',
    'ytick.left': False,
    'ytick.right': False,
    'ytick.labelsize': 'large',
    'axes.facecolor': 'F8F8F8',
    'axes.titlesize': "large",
    'axes.edgecolor': 'white',
    'text.color': 'a8151a',
    'figure.facecolor': 'F0F0F0',
    'figure.figsize': (16, 9),
})
# Directory of the TrueType fonts bundled with matplotlib.
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")

def plot_learning_rate(lr_func, epochs):
    """Plot ``lr_func`` evaluated at epochs 0..``epochs`` (inclusive) as a step curve.

    Args:
        lr_func: callable mapping a float epoch to a learning rate. It may return
            a Python number or a scalar tensor/variable. For example,
            ``LinearDecay.__call__`` returns the same ``tf.Variable`` on every call.
        epochs: last epoch to evaluate.

    Fixes over the previous version:
      * it called the undefined global ``lr_decay`` instead of ``lr_func``;
      * it used ``np.float``, which was removed from NumPy;
      * values are snapshotted with ``float()``, since a shared variable would
        otherwise make every point equal the last assignment;
      * ``epochs=0`` no longer raises IndexError.
    """
    xx = np.arange(epochs + 1, dtype=np.float64)
    y = [float(lr_func(x)) for x in xx]
    # With where='post', the last point only closes the final step, so the
    # rate "at the end" is the second-to-last value (when there is one).
    end_lr = y[-2] if len(y) > 1 else y[-1]
    fig, ax = plt.subplots(figsize=(9, 6))
    ax.set_xlabel('epochs')
    ax.set_title('Learning rate\ndecays from {:0.3g} to {:0.3g}'.format(y[0], end_lr))
    ax.minorticks_on()
    ax.grid(True, which='major', axis='both', linestyle='-', linewidth=1)
    ax.grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)
    ax.step(xx, y, linewidth=3, where='post')
    display(fig)
    # Interactive mode is off (plt.ioff above), so release the figure explicitly.
    plt.close(fig)

def display_epoch_predict(photo2monet, photo, monet, monet2photo):
    """Show two figures: (photo, photo2monet) and (monet, monet2photo).

    Each figure shows the first image of each batch. Images are assumed to be
    in [-1, 1] and are rescaled to [0, 1] for display.

    The figure height was previously 816 inches (a typo), which produced an
    enormous canvas; it is now 8 inches. The duplicated plotting code is
    factored into ``_show_pair``.
    """
    def _show_pair(images, titles):
        # One row with two panels.
        plt.figure(figsize=(16, 8))
        for i, (img, name) in enumerate(zip(images, titles)):
            plt.subplot(1, 2, i + 1)
            plt.title(name)
            plt.imshow(img * 0.5 + 0.5)
            plt.axis('off')
        plt.show()

    _show_pair([photo[0], photo2monet[0]], ['photo', 'photo2monet'])
    _show_pair([monet[0], monet2photo[0]], ['monet', 'monet2photo'])
         
class PlotTraining(tf.keras.callbacks.Callback):
    """Keras callback that redraws loss curves in the notebook at every epoch end.

    * Every ``sample_rate``-th batch, the batch logs are recorded, except the keys
      'batch', 'size' and any key ending in 'gen_loss'.
    * At each epoch end, the epoch logs are recorded (except keys ending in
      'disc_loss') and all curves are replotted. Keys ending in 'disc_loss' go on
      the left axis (y in [0, 2]); everything else goes on the right axis
      (y in [0, 10]). Epoch-level curves are drawn thicker.

    Args:
        sample_rate: record batch logs every this many batches.
        zoom: stored but not used by this class.
        steps_per_epoch: divisor that converts the batch counter into epochs on
            the x-axis. Defaults to ``STEPS_PER_EPOCH * BATCH_SIZE``, as before.
    """

    def __init__(self, sample_rate=10, zoom=16, steps_per_epoch=None):
        # Initialize the Keras Callback base (model/params attributes).
        super(PlotTraining, self).__init__()
        self.sample_rate = sample_rate
        # Number of batches seen so far; used as the x coordinate of each sample.
        self.step = 0
        self.zoom = zoom
        # NOTE(review): `step` counts batches, so when BATCH_SIZE > 1 (several
        # replicas) STEPS_PER_EPOCH alone may be the right divisor — confirm against fit().
        if steps_per_epoch is None:
            steps_per_epoch = STEPS_PER_EPOCH * BATCH_SIZE
        self.steps_per_epoch = steps_per_epoch

    def on_train_begin(self, logs=None):
        self.batch_history = {}
        self.batch_step = []
        self.epoch_history = {}
        self.epoch_step = []
        self.fig, self.axes = plt.subplots(1, 2, figsize=(16, 7))
        self.fig.subplots_adjust(wspace=0.12, hspace=0.12)
        plt.ioff()

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        if batch % self.sample_rate == 0:
            self.batch_step.append(self.step)
            for k, v in logs.items():
                # Skip bookkeeping keys and generator losses (those are only plotted per epoch).
                if k in ('batch', 'size') or k.endswith('gen_loss'):
                    continue
                self.batch_history.setdefault(k, []).append(v)
        # Advance once per batch. Previously this was indented inside the loop
        # above, so it advanced once per logged key of a sampled batch.
        self.step += 1

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        plt.close(self.fig)
        self.axes[0].cla()
        self.axes[1].cla()

        self.axes[0].set_ylim(0, 2)
        self.axes[1].set_ylim(0, 10)

        self.epoch_step.append(self.step)

        for k, v in logs.items():
            # Discriminator losses are shown per batch only; skip them at epoch level.
            if k.endswith('disc_loss'):
                continue
            self.epoch_history.setdefault(k, []).append(v)

        for k, v in self.batch_history.items():
            self.axes[0 if k.endswith('disc_loss') else 1].plot(np.array(self.batch_step) / self.steps_per_epoch, v, label=k)

        for k, v in self.epoch_history.items():
            self.axes[0 if k.endswith('disc_loss') else 1].plot(np.array(self.epoch_step) / self.steps_per_epoch, v, label=k, linewidth=3)

        for ax in self.axes:
            ax.legend()
            ax.set_xlabel('epochs')
            ax.minorticks_on()
            ax.grid(True, which='major', axis='both', linestyle='-', linewidth=1)
            ax.grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)
        display(self.fig)

### Define CycleGan model

In [None]:
class CycleGan(tf.keras.Model):
    """CycleGAN trainer wrapping two generators and two discriminators.

    ``G_A2B`` translates domain A to domain B, and ``G_B2A`` translates B to A.
    ``D_A`` / ``D_B`` score images of domain A / B. The metric names returned
    by ``train_step`` call domain A "monet" and domain B "photo". ``call``
    runs ``G_B2A`` (photo -> monet).
    """

    def __init__(self, G_A2B, G_B2A, D_B, D_A):
        super().__init__()
        self.G_A2B = G_A2B
        self.G_B2A = G_B2A
        self.D_B = D_B
        self.D_A = D_A

    def compile(
            self,
            G_optimizer,
            D_optimizer,
            disc_loss_fn,
            gen_loss_fn,
            cycle_loss_fn,
            identity_loss_fn,
            cycle_loss_weight=10.0,
            identity_loss_weight=0.0,
            gradient_penalty_weight=10.0,
            gradient_penalty_mode='none',
        ):
        """Attach optimizers, loss callables and loss weights.

        Args:
            G_optimizer: Optimizer applied to both generators.
            D_optimizer: Optimizer applied to both discriminators.
            disc_loss_fn: ``(real_logits, fake_logits) -> (real_loss, fake_loss)``.
            gen_loss_fn: ``fake_logits -> loss`` for the generators.
            cycle_loss_fn: ``(original, reconstruction) -> loss``.
            identity_loss_fn: ``(original, identity_output) -> loss``.
            cycle_loss_weight: Multiplier on the summed cycle losses.
            identity_loss_weight: Multiplier on the summed identity losses.
                Expected to be a Python number. When it is 0, the identity
                forward passes are skipped entirely.
            gradient_penalty_weight: Multiplier on the summed discriminator
                gradient penalties.
            gradient_penalty_mode: Forwarded as ``mode`` to ``gradient_penalty``.
        """
        super().compile()
        self.G_optimizer = G_optimizer
        self.D_optimizer = D_optimizer
        self.d_loss_fn = disc_loss_fn
        self.g_loss_fn = gen_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        self.cycle_loss_weight = cycle_loss_weight
        self.identity_loss_weight = identity_loss_weight
        self.gradient_penalty_weight = gradient_penalty_weight
        self.gradient_penalty_mode = gradient_penalty_mode

    def call(self, inputs):
        """Translate a batch from domain B to domain A with ``G_B2A``."""
        return self.G_B2A(inputs)

    def train_step(self, input_batch):
        """Run one joint generator + discriminator update.

        Args:
            input_batch: ``(A, B)`` pair of image batches.

        Returns:
            Dict of scalar losses. Each ``*_disc_loss`` entry for a single
            discriminator is its full objective (real term + fake term).
        """
        A, B = input_batch
        with tf.GradientTape(persistent=True) as tape:
            # --- Generators: translations and cycle reconstructions.
            A2B = self.G_A2B(A, training=True)
            B2A = self.G_B2A(B, training=True)
            A2B2A = self.G_B2A(A2B, training=True)
            B2A2B = self.G_A2B(B2A, training=True)

            # Adversarial terms: each translation is scored by the
            # discriminator of its target domain.
            A2B_g_loss = self.g_loss_fn(self.D_B(A2B, training=True))
            B2A_g_loss = self.g_loss_fn(self.D_A(B2A, training=True))

            A2B2A_cycle_loss = self.cycle_loss_fn(A, A2B2A)
            B2A2B_cycle_loss = self.cycle_loss_fn(B, B2A2B)

            G_loss = (A2B_g_loss + B2A_g_loss) + (A2B2A_cycle_loss + B2A2B_cycle_loss) * self.cycle_loss_weight

            # The identity terms cost two extra generator passes. At weight 0
            # they add nothing to the gradient, so they are skipped. Skipping
            # also keeps those passes out of any BatchNorm moving statistics.
            if self.identity_loss_weight:
                A2A = self.G_B2A(A, training=True)
                B2B = self.G_A2B(B, training=True)
                A2A_id_loss = self.identity_loss_fn(A, A2A)
                B2B_id_loss = self.identity_loss_fn(B, B2B)
                G_loss += (A2A_id_loss + B2B_id_loss) * self.identity_loss_weight

            # --- Discriminators. Fakes are re-scored here instead of reusing
            # the generator-side logits, so an image pool can be inserted:
            # A2B = A2B_pool(A2B); B2A = B2A_pool(B2A)
            # (calling the pool on .numpy() values is much slower due to
            # CPU<->device transfer).
            A_d_logits = self.D_A(A, training=True)
            B2A_d_logits = self.D_A(B2A, training=True)
            B_d_logits = self.D_B(B, training=True)
            A2B_d_logits = self.D_B(A2B, training=True)

            A_d_loss, B2A_d_loss = self.d_loss_fn(A_d_logits, B2A_d_logits)
            B_d_loss, A2B_d_loss = self.d_loss_fn(B_d_logits, A2B_d_logits)
            D_A_loss = A_d_loss + B2A_d_loss
            D_B_loss = B_d_loss + A2B_d_loss

            D_A_gp = gradient_penalty(functools.partial(self.D_A, training=True), A, B2A, mode=self.gradient_penalty_mode)
            D_B_gp = gradient_penalty(functools.partial(self.D_B, training=True), B, A2B, mode=self.gradient_penalty_mode)

            D_loss = D_A_loss + D_B_loss + (D_A_gp + D_B_gp) * self.gradient_penalty_weight

        G_vars = self.G_A2B.trainable_variables + self.G_B2A.trainable_variables
        D_vars = self.D_A.trainable_variables + self.D_B.trainable_variables
        G_grad = tape.gradient(G_loss, G_vars)
        D_grad = tape.gradient(D_loss, D_vars)
        # A persistent tape keeps its recorded resources alive until released.
        del tape

        self.G_optimizer.apply_gradients(zip(G_grad, G_vars))
        self.D_optimizer.apply_gradients(zip(D_grad, D_vars))

        return {
            "monet_gen_loss": B2A_g_loss,
            "photo_gen_loss": A2B_g_loss,
            "G_gen_loss": G_loss,
            # Full per-discriminator objective (real + fake), not just the real term.
            "monet_disc_loss": D_A_loss,
            "photo_disc_loss": D_B_loss,
            "D_disc_loss": D_loss,
        }

### Create and compile models, define losses and optimizers

In [None]:
# with strategy.scope():
# Two generators (A->B, B->A) and one discriminator per domain.
# Available `norm` choices: 'layer_norm', 'batch_norm', 'instance_norm'.
G_A2B = ResnetGenerator(input_shape=SHAPE, norm='batch_norm')
G_B2A = ResnetGenerator(input_shape=SHAPE, norm='batch_norm')
D_B = ConvDiscriminator(input_shape=SHAPE, norm='batch_norm')
D_A = ConvDiscriminator(input_shape=SHAPE, norm='batch_norm')

model = CycleGan(G_A2B=G_A2B, G_B2A=G_B2A, D_B=D_B, D_A=D_A)

# Learning-rate schedules spanning the whole run (EPOCHS * STEPS_PER_EPOCH steps).
G_lr_scheduler = LinearDecay(START_LEARNING_RATE, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)
D_lr_scheduler = LinearDecay(START_LEARNING_RATE_DISC, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)

# Mean-squared-error criterion shared by the adversarial losses below.
mse = tf.losses.MeanSquaredError()

def discriminator_loss(r_logit, f_logit):
    """Least-squares discriminator loss.

    Real logits are regressed toward 1 and fake logits toward 0 using the
    module-level ``mse`` criterion.

    Returns:
        ``(real_loss, fake_loss)`` tuple.
    """
    real_target = tf.ones_like(r_logit)
    fake_target = tf.zeros_like(f_logit)
    return mse(real_target, r_logit), mse(fake_target, f_logit)

def generator_loss(f_logit):
    """Least-squares generator loss: regress discriminator scores on fakes toward 1."""
    return mse(tf.ones_like(f_logit), f_logit)

# Alternative adversarial losses (disabled): binary cross-entropy on logits.
# bce = tf.losses.BinaryCrossentropy(from_logits=True)
#
# def discriminator_loss(r_logit, f_logit):
#     r_loss = bce(tf.ones_like(r_logit), r_logit)
#     f_loss = bce(tf.zeros_like(f_logit), f_logit)
#     return r_loss, f_loss
#
# def generator_loss(f_logit):
#     f_loss = bce(tf.ones_like(f_logit), f_logit)
#     return f_loss

# Adam (beta_1=0.5) for both network pairs, driven by the LR schedules above.
# Cycle-consistency and identity terms use mean absolute error (L1).
model.compile(
    G_optimizer=Adam(learning_rate=G_lr_scheduler, beta_1=0.5),
    D_optimizer=Adam(learning_rate=D_lr_scheduler, beta_1=0.5),
    disc_loss_fn=discriminator_loss,
    gen_loss_fn=generator_loss,
    cycle_loss_fn=tf.losses.MeanAbsoluteError(),
    identity_loss_fn=tf.losses.MeanAbsoluteError(),
)

In [None]:
# Weights-only checkpoint of the whole CycleGan, written every SAVE_FREQ.
# NOTE(review): per the Keras docs, `monitor`/`mode` only matter when
# save_best_only=True. That is not set here, so each save overwrites the file.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="/kaggle/working/cp-last-gpu.h5",
    save_weights_only=True,
    save_freq=SAVE_FREQ,
    monitor='photo_gen_loss',
    mode='min',
    verbose=1,
)
                                         
class DisplayCallback(tf.keras.callbacks.Callback):
    """After each epoch, plot sample translations and print current learning rates.

    The generators are taken from ``self.model`` (the CycleGan this callback
    is attached to), matching how the optimizers are read. The preview
    therefore reflects the networks being trained, not whatever the
    module-level ``G_A2B`` / ``G_B2A`` names currently point at.
    """

    def __init__(self):
        super().__init__()
        # Private Keras flag; presumably lets fit() hand over logs without
        # converting them to numpy. This callback does not read logs.
        self._supports_tf_logs = True

    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        # test_dataset yields (monet, photo) pairs, in the order unpacked here.
        monet, photo = next(iter(test_dataset))
        display_epoch_predict(
            self.model.G_B2A.predict(photo),
            photo,
            monet,
            self.model.G_A2B.predict(monet),
        )
        # NOTE(review): _decayed_lr is a private optimizer method and may be
        # absent in newer Keras optimizers — confirm against the installed TF.
        print(f'm_gen lr: {self.model.G_optimizer._decayed_lr(tf.float32).numpy()}')
        print(f'm_disc lr: {self.model.D_optimizer._decayed_lr(tf.float32).numpy()}')

class SaveModelsWeightsCallback(tf.keras.callbacks.Callback):
    """Save both generators' weights to `.h5` files at the end of every epoch.

    The generators are read from ``self.model`` (the attached CycleGan)
    rather than from module-level globals.

    Args:
        a2b_path: Destination file for ``G_A2B`` weights.
        b2a_path: Destination file for ``G_B2A`` weights.
    """

    # NOTE(review): the default names carry a '-tpu' suffix even though the
    # ModelCheckpoint in this cell writes 'cp-last-gpu.h5'; kept unchanged so
    # existing downstream loaders keep working.
    def __init__(self, a2b_path='./G_A2B-tpu.h5', b2a_path='./G_B2A-tpu.h5'):
        super().__init__()
        self.a2b_path = a2b_path
        self.b2a_path = b2a_path

    def on_epoch_end(self, epoch, logs=None):
        self.model.G_A2B.save_weights(self.a2b_path)
        self.model.G_B2A.save_weights(self.b2a_path)

# Live loss-curve plotter (sample_rate=8, zoom=12).
plot_training = PlotTraining(sample_rate=8, zoom=12)
callbacks_list = [
    DisplayCallback(),            # sample translations + learning rates
    plot_training,                # loss curves
    cp_callback,                  # periodic weights-only checkpoint
    SaveModelsWeightsCallback(),  # per-epoch generator weights
]

In [None]:
# Mark the subclassed model as built and run one forward pass, presumably so
# its variables exist before load_weights() tries to match them.
# NOTE(review): element [0] of a `dataset` batch is domain A (monet), while
# model.call runs G_B2A (photo -> monet). This works only because both
# domains share SHAPE.
model.built = True
model.predict(next(iter(dataset))[0])

# Resume from a previous run when its checkpoint dataset is attached;
# otherwise continue from the freshly initialised weights instead of crashing.
resume_checkpoint = '../input/checkpoints/cp-last-gpu.h5'
if os.path.exists(resume_checkpoint):
    model.load_weights(resume_checkpoint)
else:
    print(f'No checkpoint found at {resume_checkpoint}; training from scratch.')

In [None]:
# Train; previews, loss curves and checkpoints are produced by callbacks_list.
model.fit(
    dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    callbacks=callbacks_list,
)

In [None]:
# Final save: weights of the whole CycleGan, plus the B->A generator
# (photo -> monet) on its own.
model.save_weights('./cp-last-gpu.h5')
G_B2A.save_weights('./G_B2A.h5')

# FIT NATIVE EXAMPLE (alternative `model.fit` training setup, kept commented out)

In [None]:
# with strategy.scope():
#     A2B_pool = ItemPool(50)
#     B2A_pool = ItemPool(50)

# class CycleGan(tf.keras.Model):
#     def __init__(self, G_A2B, G_B2A, D_B, D_A):
#         super(CycleGan, self).__init__()
#         self.G_A2B = G_A2B
#         self.G_B2A = G_B2A
#         self.D_B = D_B
#         self.D_A = D_A
      
#     def compile(
#             self, 
#             G_optimizer, 
#             D_optimizer,
#             disc_loss_fn, 
#             gen_loss_fn, 
#             cycle_loss_fn, 
#             identity_loss_fn
#         ):
#         super(CycleGan, self).compile()
#         self.G_optimizer = G_optimizer
#         self.D_optimizer = D_optimizer
#         self.d_loss_fn = disc_loss_fn
#         self.g_loss_fn = gen_loss_fn
#         self.cycle_loss_fn = cycle_loss_fn
#         self.identity_loss_fn = identity_loss_fn
        
#     def call(self, inputs):
#         return self.G_B2A(inputs)
    
#     def train_step(self, input_batch):
#         A, B = input_batch
#         with tf.GradientTape(persistent=True) as tape:
#             A2B = self.G_A2B(A, training=True)
#             B2A = self.G_B2A(B, training=True)
            
#             A2B2A = self.G_B2A(A2B, training=True)
#             B2A2B = self.G_A2B(B2A, training=True)
            
#             A2A = self.G_B2A(A, training=True)
#             B2B = self.G_A2B(B, training=True)

#             A2B_d_logits = self.D_B(A2B, training=True)
#             B2A_d_logits = self.D_A(B2A, training=True)

#             A2B_g_loss = self.g_loss_fn(A2B_d_logits)
#             B2A_g_loss = self.g_loss_fn(B2A_d_logits)
            
#             A2B2A_cycle_loss = self.cycle_loss_fn(A, A2B2A)
#             B2A2B_cycle_loss = self.cycle_loss_fn(B, B2A2B)
            
#             A2A_id_loss = self.identity_loss_fn(A, A2A)
#             B2B_id_loss = self.identity_loss_fn(B, B2B)
            
#             G_loss = (A2B_g_loss + B2A_g_loss) + (A2B2A_cycle_loss + B2A2B_cycle_loss) * 10.0 + (A2A_id_loss + B2B_id_loss) * 0.0
           
# #             A2B = A2B_pool(A2B)  # or A2B = A2B_pool(A2B.numpy()), but it is much slower
# #             B2A = B2A_pool(B2A)  # because of the communication between CPU and GPU
            
#             A_d_logits = self.D_A(A, training=True)
#             B2A_d_logits = self.D_A(B2A, training=True)
            
#             B_d_logits = self.D_B(B, training=True)
#             A2B_d_logits = self.D_B(A2B, training=True)

#             A_d_loss, B2A_d_loss = self.d_loss_fn(A_d_logits, B2A_d_logits)
#             B_d_loss, A2B_d_loss = self.d_loss_fn(B_d_logits, A2B_d_logits)
            
#             D_A_gp = gradient_penalty(functools.partial(self.D_A, training=True), A, B2A, mode='none')
#             D_B_gp = gradient_penalty(functools.partial(self.D_B, training=True), B, A2B, mode='none')

#             D_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss) + (D_A_gp + D_B_gp) * 10.0

#         G_grad = tape.gradient(G_loss, self.G_A2B.trainable_variables + self.G_B2A.trainable_variables)
#         self.G_optimizer.apply_gradients(zip(G_grad,self. G_A2B.trainable_variables + self.G_B2A.trainable_variables))
        
#         D_grad = tape.gradient(D_loss, self.D_A.trainable_variables + self.D_B.trainable_variables)
#         self.D_optimizer.apply_gradients(zip(D_grad, self.D_A.trainable_variables + self.D_B.trainable_variables))

#         return {
#             "monet_gen_loss": B2A_g_loss,
#             "photo_gen_loss": A2B_g_loss,
#             "G_gen_loss": G_loss,
#             "monet_disc_loss": A_d_loss,
#             "photo_disc_loss": B_d_loss,
#             "D_disc_loss": D_loss
#         }
                                     
# # with strategy.scope():    
# model = CycleGan(
#     G_A2B=ResnetGenerator(input_shape=SHAPE),
#     G_B2A=ResnetGenerator(input_shape=SHAPE),
#     D_B=ConvDiscriminator(input_shape=SHAPE),
#     D_A=ConvDiscriminator(input_shape=SHAPE)
# )

# G_lr_scheduler = LinearDecay(0.0002, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)
# D_lr_scheduler = LinearDecay(0.0002, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)

# mse = tf.losses.MeanSquaredError()

# def discriminator_loss(r_logit, f_logit):
#     r_loss = mse(tf.ones_like(r_logit), r_logit)
#     f_loss = mse(tf.zeros_like(f_logit), f_logit)
#     return r_loss, f_loss

# def generator_loss(f_logit):
#     f_loss = mse(tf.ones_like(f_logit), f_logit)
#     return f_loss

# model.compile(
#     G_optimizer=Adam(learning_rate=G_lr_scheduler, beta_1=0.5),
#     D_optimizer=Adam(learning_rate=D_lr_scheduler, beta_1=0.5),
#     disc_loss_fn=discriminator_loss,
#     gen_loss_fn=generator_loss,
#     cycle_loss_fn=tf.losses.MeanAbsoluteError(),
#     identity_loss_fn=tf.losses.MeanAbsoluteError()
# )

# cp_callback = tf.keras.callbacks.ModelCheckpoint(
#     filepath="/kaggle/working/cp-last.h5", 
#     monitor='monet_gen_loss', 
#     verbose=1, 
#     mode='min', 
#     save_weights_only=True,
#     save_freq=SAVE_FREQ
# )
                                         
# class DisplayCallback(tf.keras.callbacks.Callback):
#     def __init__(self):
#         super().__init__()
#         self._supports_tf_logs = True
#     def on_epoch_end(self, epoch, logs=None):
#         clear_output(wait=True)
#         simple_image = next(iter(test_photo))
#         plot_images(self.model.predict(simple_image), simple_image)
#         print(f'm_gen lr: {self.model.G_optimizer._decayed_lr(tf.float32).numpy()}')
#         print(f'm_disc lr: {self.model.D_optimizer._decayed_lr(tf.float32).numpy()}')

# plot_training = PlotTraining(sample_rate=8, zoom=12)                                         
# callbacks_list = [DisplayCallback(), plot_training, cp_callback] # cp_callback, plot_training 
                                         
# model.fit(dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, callbacks=callbacks_list)

# CUSTOM NATIVE EXAMPLE (alternative custom training loop with separate `train_G`/`train_D` steps, kept commented out)

In [None]:
# args = {
#     'dataset': 'horse2zebra',
#     'datasets_dir': '../input/gan-getting-started',
#     'load_size': 286,
#     'crop_size': 256,
#     'batch_size': 1,
#     'epochs':EPOCHS,
#     'epoch_decay':EPOCHS,
#     'lr':0.0002,
#     'beta_1':0.5,
#     'adversarial_loss_mode':'lsgan',
#     'gradient_penalty_mode':'none',
#     'gradient_penalty_weight':10.0,
#     'cycle_loss_weight':10.0,
#     'identity_loss_weight':0.0,
#     'pool_size':50,
# }
# # output_dir
# output_dir = 'output_dir/'
# !mkdir output_dir

# A_img_paths = tf.io.gfile.glob(GCS_DS_PATH+'monet_jpg/*.jpg')
# B_img_paths = tf.io.gfile.glob(GCS_DS_PATH+'photo_jpg/*.jpg')
# A_B_dataset, len_dataset = make_zip_dataset(A_img_paths, B_img_paths, BATCH_SIZE, 286, 256, training=True, repeat=False)
# A_B_dataset = dataset
# A2B_pool = ItemPool(50)
# B2A_pool = ItemPool(50)

# G_A2B = ResnetGenerator(input_shape=SHAPE)
# G_B2A = ResnetGenerator(input_shape=SHAPE)

# D_A = ConvDiscriminator(input_shape=SHAPE)
# D_B = ConvDiscriminator(input_shape=SHAPE)

# d_loss_fn, g_loss_fn = get_adversarial_losses_fn('lsgan')
# cycle_loss_fn = tf.losses.MeanAbsoluteError()
# identity_loss_fn = tf.losses.MeanAbsoluteError()
# # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
# #     0.0004, decay_steps=STEPS_PER_EPOCH, decay_rate=0.85, staircase=False, name=None
# # )
# G_lr_scheduler = LinearDecay(0.0002, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)
# D_lr_scheduler = LinearDecay(0.0002, EPOCHS * STEPS_PER_EPOCH, STEPS_PER_EPOCH)
# G_optimizer = Adam(learning_rate=G_lr_scheduler, beta_1=0.5)
# D_optimizer = Adam(learning_rate=D_lr_scheduler, beta_1=0.5)

# @tf.function
# def train_G(A, B):
#     with tf.GradientTape() as t:
#         A2B = G_A2B(A, training=True)
#         B2A = G_B2A(B, training=True)
#         A2B2A = G_B2A(A2B, training=True)
#         B2A2B = G_A2B(B2A, training=True)
#         A2A = G_B2A(A, training=True)
#         B2B = G_A2B(B, training=True)

#         A2B_d_logits = D_B(A2B, training=True)
#         B2A_d_logits = D_A(B2A, training=True)

#         A2B_g_loss = g_loss_fn(A2B_d_logits)
#         B2A_g_loss = g_loss_fn(B2A_d_logits)
#         A2B2A_cycle_loss = cycle_loss_fn(A, A2B2A)
#         B2A2B_cycle_loss = cycle_loss_fn(B, B2A2B)
#         A2A_id_loss = identity_loss_fn(A, A2A)
#         B2B_id_loss = identity_loss_fn(B, B2B)

#         G_loss = (A2B_g_loss + B2A_g_loss) + (A2B2A_cycle_loss + B2A2B_cycle_loss) * 10.0 + (A2A_id_loss + B2B_id_loss) * 0.0

#     G_grad = t.gradient(G_loss, G_A2B.trainable_variables + G_B2A.trainable_variables)
#     G_optimizer.apply_gradients(zip(G_grad, G_A2B.trainable_variables + G_B2A.trainable_variables))

#     return A2B, B2A, {'A2B_g_loss': A2B_g_loss,
#                       'B2A_g_loss': B2A_g_loss,
#                       'A2B2A_cycle_loss': A2B2A_cycle_loss,
#                       'B2A2B_cycle_loss': B2A2B_cycle_loss,
#                       'A2A_id_loss': A2A_id_loss,
#                       'B2B_id_loss': B2B_id_loss}


# @tf.function
# def train_D(A, B, A2B, B2A):
#     with tf.GradientTape() as t:
#         A_d_logits = D_A(A, training=True)
#         B2A_d_logits = D_A(B2A, training=True)
#         B_d_logits = D_B(B, training=True)
#         A2B_d_logits = D_B(A2B, training=True)

#         A_d_loss, B2A_d_loss = d_loss_fn(A_d_logits, B2A_d_logits)
#         B_d_loss, A2B_d_loss = d_loss_fn(B_d_logits, A2B_d_logits)
#         D_A_gp = gradient_penalty(functools.partial(D_A, training=True), A, B2A, mode='none')
#         D_B_gp = gradient_penalty(functools.partial(D_B, training=True), B, A2B, mode='none')

#         D_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss) + (D_A_gp + D_B_gp) * 10.0

#     D_grad = t.gradient(D_loss, D_A.trainable_variables + D_B.trainable_variables)
#     D_optimizer.apply_gradients(zip(D_grad, D_A.trainable_variables + D_B.trainable_variables))

#     return {'A_d_loss': A_d_loss + B2A_d_loss,
#             'B_d_loss': B_d_loss + A2B_d_loss,
#             'D_A_gp': D_A_gp,
#             'D_B_gp': D_B_gp}


# def train_step(A, B):
#     A2B, B2A, G_loss_dict = train_G(A, B)

#     # cannot autograph `A2B_pool`
#     A2B = A2B_pool(A2B)  # or A2B = A2B_pool(A2B.numpy()), but it is much slower
#     B2A = B2A_pool(B2A)  # because of the communication between CPU and GPU

#     D_loss_dict = train_D(A, B, A2B, B2A)

#     return G_loss_dict, D_loss_dict


# @tf.function
# def sample(A, B):
#     A2B = G_A2B(A, training=False)
#     B2A = G_B2A(B, training=False)
#     A2B2A = G_B2A(A2B, training=False)
#     B2A2B = G_A2B(B2A, training=False)
#     return A2B, B2A, A2B2A, B2A2B

# ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

# # checkpoint
# # summary
# train_summary_writer = tf.summary.create_file_writer(output_dir+ 'summaries'+'/train')

# # sample
# sample_dir = output_dir + 'samples_training/'
# !mkdir sample_dir

# data_iter = iter(A_B_dataset)
# # main loop
# with train_summary_writer.as_default():
#     for ep in range(EPOCHS):
#         start_epoch_time = time.time()
#         if ep < ep_cnt:
#             continue

#         # update epoch counter
#         ep_cnt.assign_add(1)

#         # train for an epoch
#         for step in range(STEPS_PER_EPOCH):
#             A, B = next(data_iter)
#             G_loss_dict, D_loss_dict = train_step(A, B)
#             del A, B
#             gc.collect()
#             if (step+1) % STEPS_PER_EPOCH == 0:
#                 clear_output(wait=True)
#                 print ('Time taken for Epoch {}/{} is {} \n'.format(ep+1, EPOCHS, str(datetime.timedelta(seconds=int(time.time()-start_epoch_time)))))
#                 print(f'm_gen lr: {G_optimizer._decayed_lr(tf.float32).numpy()}')
#                 print(f'm_disc lr: {D_optimizer._decayed_lr(tf.float32).numpy()} \n')
#                 print('A2B_g_loss : ' , G_loss_dict['A2B_g_loss'].numpy()) 
#                 print('B2A_g_loss : ' , G_loss_dict['B2A_g_loss'].numpy()) 
#                 print('A2B2A_cycle_loss : ' , G_loss_dict['A2B2A_cycle_loss'].numpy()) 
#                 print('B2A2B_cycle_loss : ' ,  G_loss_dict['B2A2B_cycle_loss'].numpy()) 
#                 print('A2A_id_loss : ' , G_loss_dict['A2A_id_loss'].numpy()) 
#                 print('B2B_id_loss : ' , G_loss_dict['B2B_id_loss'].numpy())
#                 print('A_d_loss : ' , D_loss_dict['A_d_loss'].numpy())
#                 print('B_d_loss : ' , D_loss_dict['B_d_loss'].numpy())
#                 print('D_A_gp : ' , D_loss_dict['D_A_gp'].numpy())
#                 print('D_B_gp : ' , D_loss_dict['D_B_gp'].numpy())
#                 A, B = next(iter(dataset))
#                 A2B, B2A, A2B2A, B2A2B = sample(A, B)
#                 plot_images(B2A, B)
#                 del A
#                 del B
#                 del A2B, B2A, A2B2A, B2A2B
#                 gc.collect()

In [None]:
# Preview: translate the first 10 items of test_photo with model.call (G_B2A)
# and plot each prediction next to its input.
for inp in test_photo.take(10):
    plot_images(model.predict(inp), inp)