In [None]:
import os
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image  # explicit: `import PIL` alone does not load the Image submodule
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
# NOTE(review): this rebinds `keras` to the standalone package, shadowing the
# `from tensorflow import keras` import above; kept last to preserve the original binding.
import keras

In [None]:
# Pick the distribution strategy: use a TPU when one can be resolved and initialised,
# otherwise fall back to TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (rather than a bare `except:`) keeps KeyboardInterrupt/SystemExit
    # propagating, and the message makes the reason for the fallback visible.
    print('TPU not available, falling back to the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# **Utils**

In [None]:
class TfrecCreator:
    """Static helpers that pack a directory of images into TFRecord shards.

    Every record holds three bytes features: 'image_name', 'image' (JPEG bytes)
    and 'target'.
    """

    @staticmethod
    def _bytes_feature(value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _float_feature(value):
        """Returns a float_list from a float / double."""
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        """Returns an int64_list from a bool / enum / int / uint."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    @staticmethod
    def serialize_example(feature0, feature1, feature2):
        """Serialize one record.

        :param feature0: bytes stored under 'image_name'
        :param feature1: bytes stored under 'image'
        :param feature2: bytes stored under 'target'
        :return: the serialized tf.train.Example as a byte string
        """
        feature = {
            'image_name': TfrecCreator._bytes_feature(feature0),
            'image': TfrecCreator._bytes_feature(feature1),
            'target': TfrecCreator._bytes_feature(feature2)
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    @staticmethod
    def create_tfrec(jpeg_path, tfrecords_filepath):
        """Write every file found in `jpeg_path` into TFRecord shards of up to 30 images.

        Shards are named `<tfrecords_filepath><shard index, 2 digits>-<image count>.tfrec`
        and every record is tagged with the target b'monet'.

        :param jpeg_path: directory whose entries are all read as images
        :param tfrecords_filepath: path prefix for the shard files
        :raises ValueError: if an entry of `jpeg_path` cannot be read as an image
        """
        # PATHS TO IMAGES
        PATH = jpeg_path
        # Sorted because os.listdir returns entries in arbitrary order, which would make
        # the shard contents differ from run to run.
        IMGS = sorted(os.listdir(PATH))
        print(f'Found {len(IMGS)} images in {PATH}')

        SIZE = 30
        CT = len(IMGS) // SIZE + int(len(IMGS) % SIZE != 0)
        for j in range(CT):
            print()
            print('Writing TFRecord %i of %i...' % (j, CT))
            CT2 = min(SIZE, len(IMGS) - j * SIZE)
            with tf.io.TFRecordWriter(tfrecords_filepath + ('%.2i-%i.tfrec' % (j, CT2))) as writer:
                for k in range(CT2):
                    file_name = IMGS[SIZE * j + k]
                    file_path = os.path.join(PATH, file_name)
                    img = cv2.imread(file_path)
                    if img is None:
                        # cv2.imread signals failure by returning None; fail with a clear message.
                        raise ValueError(f'Could not read image file: {file_path}')
                    # NOTE(review): cv2.imread already yields BGR and cv2.imencode expects BGR,
                    # so this swap looks like it stores R/B-swapped JPEGs. Behaviour is kept
                    # as-is; confirm against a decoded record before changing it.
                    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # Fix incorrect colors
                    # tobytes() replaces ndarray.tostring(), which NumPy deprecated and later removed.
                    img = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, 94))[1].tobytes()
                    name = file_name.split('.')[0]
                    example = TfrecCreator.serialize_example(
                        str.encode(name),
                        img,
                        str.encode('monet'))
                    writer.write(example)
                    if k % 100 == 0: print(k, ', ', end='')


def create_submission_file(photo_ds, monet_generator):
    """Save every photo and its generated Monet version as JPEGs, then zip the results.

    Originals go to ./original_images/<i>.jpg and generated images to ./images/<i>.jpg
    (i starts at 1); ./images is finally archived as ./images.zip.

    :param photo_ds: iterable of image batches exposing `.numpy()`; each batch must have a
        leading axis of size 1 (it is squeezed away). Values are assumed to be in [-1, 1].
    :param monet_generator: callable `(batch, training=False)` whose first output element
        exposes `.numpy()`
    """
    def to_uint8(scaled):
        # Map [-1, 1] back to [0, 255]. Round before casting: a bare astype(np.uint8)
        # truncates, so a value landing just below an integer after the float round-trip
        # would come out one level too low. Clip so out-of-range values cannot wrap around.
        return np.clip(np.rint(scaled * 127.5 + 127.5), 0, 255).astype(np.uint8)

    # exist_ok avoids the check-then-create pattern and makes re-runs safe.
    os.makedirs('original_images', exist_ok=True)
    os.makedirs('images', exist_ok=True)

    for i, img in enumerate(photo_ds, start=1):
        original_img = PIL.Image.fromarray(to_uint8(img.numpy()).squeeze(0))
        original_img.save("./original_images/" + str(i) + ".jpg")

        prediction = monet_generator(img, training=False)[0].numpy()
        im = PIL.Image.fromarray(to_uint8(prediction))
        im.save("./images/" + str(i) + ".jpg")

    shutil.make_archive("./images", 'zip', "./images")
    

def visualize_our_monet_photos(photo_ds, monet_generator):
    """Plot the first five photos of `photo_ds` beside their generated Monet versions.

    The figure is a 5x2 grid (left: input photo, right: generator output) and a timer
    closes it 10 seconds after it is created.

    :param photo_ds: dataset supporting `.take(5)` that yields image batches
    :param monet_generator: callable `(batch, training=False)` returning a batch of images
    """
    figure, axes = plt.subplots(5, 2, figsize=(12, 12))

    def _auto_close():
        # Invoked by the timer below once its 10000 ms interval has elapsed.
        plt.close()

    close_timer = figure.canvas.new_timer(interval=10000)
    close_timer.add_callback(_auto_close)
    close_timer.start()

    for row, photo in enumerate(photo_ds.take(5)):
        generated = monet_generator(photo, training=False)[0].numpy()
        generated = (generated * 127.5 + 127.5).astype(np.uint8)
        source = (photo[0] * 127.5 + 127.5).numpy().astype(np.uint8)

        panels = (
            (axes[row, 0], source, "Input Photo"),
            (axes[row, 1], generated, "Monet-esque"),
        )
        for axis, picture, caption in panels:
            axis.imshow(picture)
            axis.set_title(caption)
            axis.axis("off")
    plt.show()


class CustomCallback(keras.callbacks.Callback):
    """Keras callback that shows generated Monet samples at the end of every epoch."""

    def __init__(self, photo_ds, monet_generator):
        """
        :param photo_ds: photo dataset passed on to visualize_our_monet_photos
        :param monet_generator: generator passed on to visualize_our_monet_photos
        """
        # Let the base Callback initialise its own state before adding ours.
        super().__init__()
        self.photo_ds = photo_ds
        self.monet_generator = monet_generator

    def on_epoch_end(self, epoch, logs=None):
        """Print the available log keys and plot a fresh batch of generated images."""
        # `logs` defaults to None, so guard before asking for its keys.
        keys = list((logs or {}).keys())
        print(f"End epoch {epoch} of training; got log keys: {keys}")
        visualize_our_monet_photos(photo_ds=self.photo_ds, monet_generator=self.monet_generator)


In [None]:
# Get all monet file names -----------------------------------------------------------------------------------------
# The images selected here are written out as the Monet TFRecords and tagged 'monet',
# so they must come from the Monet JPEG folder (the path previously pointed at photo_jpg).
source_dir_path = '../input/gan-getting-started/monet_jpg'
all_files = os.listdir(source_dir_path)
# Sorted so that the "first 30" picked in the next cell is the same on every run;
# os.listdir returns entries in arbitrary order.
images = sorted(file for file in all_files if file.endswith('.jpg'))

# **Choose 30 images**

In [None]:
# Copy (at most) the first 30 listed images into their own directory: ./chosen30/first_30
images = images[:30]
chosen30dir = './chosen30'
destination_dir_name = 'first_30'
destination_full_path = os.path.join(chosen30dir, destination_dir_name)

# Create the parent first, then the nested target, skipping whichever already exists.
for directory in (chosen30dir, destination_full_path):
    if not os.path.isdir(directory):
        os.mkdir(directory)

for img_name in images:
    copy_from = os.path.join(source_dir_path, img_name)
    copy_to = os.path.join(destination_full_path, img_name)
    shutil.copy2(copy_from, copy_to)
    

# **Create TFrec file**

In [None]:
# Pack the chosen images into TFRecord shards under ./tfrec/monet_tfrec_<destination_dir_name>
monet_tfrecords_dir_path = './tfrec/monet_tfrec_' + destination_dir_name

# Parent directory first, then the per-selection directory; existing ones are left alone.
for tfrec_dir in ('./tfrec', monet_tfrecords_dir_path):
    if not os.path.isdir(tfrec_dir):
        os.mkdir(tfrec_dir)

# Shard files are named from the prefix "<dir>/monet".
TfrecCreator.create_tfrec(destination_full_path, monet_tfrecords_dir_path + '/monet')

# **DataLoader**

In [None]:
class TFRecordsDataLoader:
    """Finds the Monet and photo TFRecord shards and turns them into image datasets."""

    def __init__(self, directory_path, monet_dir_path='monet_tfrec'):
        """
        :param directory_path: dataset root; photo shards are globbed from
            `<directory_path>/photo_tfrec/*.tfrec`
        :param monet_dir_path: directory globbed for the Monet `*.tfrec` shards
        """
        self.AUTOTUNE = tf.data.experimental.AUTOTUNE

        # Collect the shard file names of both domains.
        monet_pattern = os.path.join(monet_dir_path, '*.tfrec')
        self.MONET_FILENAMES = tf.io.gfile.glob(monet_pattern)
        print(f'Monet TFRecord Files: {len(self.MONET_FILENAMES)}')

        photo_pattern = os.path.join(directory_path, 'photo_tfrec/*.tfrec')
        self.PHOTO_FILENAMES = tf.io.gfile.glob(photo_pattern)
        print(f'Photo TFRecord Files: {len(self.PHOTO_FILENAMES)}')

        # Every record carries three scalar string features.
        self.tfrecord_format = {
            feature_name: tf.io.FixedLenFeature([], tf.string)
            for feature_name in ("image_name", "image", "target")
        }

        # Per-shard weights keyed by shard name (not read anywhere else in this class).
        self.monet_splits = {"monet00-30": 1}

        # Shards photo00..photo18 are keyed with suffix 352, the last one with 350.
        self.photo_splits = {f"photo{shard:02d}-352": 0.05 for shard in range(19)}
        self.photo_splits["photo19-350"] = 0.05

        self.IMAGE_SIZE = [256, 256]

    def decode_image(self, image):
        """Decode JPEG bytes into a float32 tensor scaled to [-1, 1] with shape IMAGE_SIZE + [3]."""
        pixels = tf.image.decode_jpeg(image, channels=3)  # uint8, 3 channels
        pixels = (tf.cast(pixels, tf.float32) / 127.5) - 1  # [0, 255] -> [-1, 1]
        # Pin the static shape to [256, 256, 3] (IMAGE_SIZE plus the channel axis).
        return tf.reshape(pixels, [*self.IMAGE_SIZE, 3])

    def read_tfrecord(self, example):
        """Parse one serialized record and return only its decoded image."""
        parsed = tf.io.parse_single_example(serialized=example, features=self.tfrecord_format)
        return self.decode_image(image=parsed['image'])

    def load_dataset(self, filenames, labeled=True, ordered=False):
        """
        Build a dataset of decoded images from TFRecord files.
        :param filenames: TFRecord file names to read
        :param labeled: accepted but not used by this implementation
        :param ordered: accepted but not used by this implementation
        :return: a tf.data dataset yielding one decoded image per record
        """
        records = tf.data.TFRecordDataset(filenames)
        return records.map(self.read_tfrecord, num_parallel_calls=self.AUTOTUNE)

# **Nets**

In [None]:
# Number of filters (output channels) of the generator's final Conv2DTranspose layer.
OUTPUT_CHANNELS = 3


def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolution block: Conv2D, optional InstanceNormalization, LeakyReLU.

    :param filters: number of convolution filters
    :param size: convolution kernel size
    :param apply_instancenorm: whether to insert instance normalization after the convolution
    :return: a keras.Sequential holding the block
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=kernel_init, use_bias=False),
    ]
    if apply_instancenorm:
        block_layers.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)


def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block.

    Layers: Conv2DTranspose, InstanceNormalization, optional Dropout(0.5), ReLU.

    :param filters: number of transposed-convolution filters
    :param size: kernel size
    :param apply_dropout: whether to insert the dropout layer before the activation
    :return: a keras.Sequential holding the block
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                               kernel_initializer=kernel_init, use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())

    return keras.Sequential(block_layers)


def Generator():
    """
    Build the encoder-decoder generator with symmetric skip connections.
    The 256x256x3 input is halved by eight downsample blocks, then doubled again by seven
    upsample blocks; after each upsample block the output of the mirrored downsample block
    is concatenated on. A final tanh transposed convolution restores OUTPUT_CHANNELS channels.
    :return: a keras.Model mapping (bs, 256, 256, 3) images to (bs, 256, 256, 3) images
    """
    inputs = layers.Input(shape=[256,256,3])

    # Encoder (bs = batch size): (bs, 128, 128, 64) after the first block, then
    # 64x64x128, 32x32x256, 16x16x512, 8x8x512, 4x4x512, 2x2x512 and finally (bs, 1, 1, 512).
    down_stack = [downsample(64, 4, apply_instancenorm=False)]
    down_stack += [downsample(filters, 4) for filters in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder: the first three blocks use dropout. Shapes after concatenating the skip run
    # from (bs, 2, 2, 1024) up to (bs, 128, 128, 128).
    up_stack = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    up_stack += [upsample(filters, 4) for filters in (512, 256, 128, 64)]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    # Run the encoder, remembering every intermediate output.
    x = inputs
    encoder_outputs = []
    for down in down_stack:
        x = down(x)
        encoder_outputs.append(x)

    # The deepest output is the bottleneck itself, so it is not used as a skip;
    # the remaining ones are consumed deepest-first.
    for up, skip in zip(up_stack, encoder_outputs[-2::-1]):
        x = layers.Concatenate()([up(x), skip])

    return keras.Model(inputs=inputs, outputs=last(x))


def Discriminator():
    """
    Build the patch discriminator.
    Rather than a single real/fake score it emits a small one-channel map; higher values
    indicate "real" and lower values "fake" for the corresponding image region.
    :return: a tf.keras.Model mapping (bs, 256, 256, 3) images to a (bs, 30, 30, 1) map
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three stride-2 blocks: (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256).
    # Only the first block skips instance normalization.
    features = inp
    for filters, use_norm in ((64, False), (128, True), (256, True)):
        features = downsample(filters, 4, use_norm)(features)

    features = layers.ZeroPadding2D()(features)  # (bs, 34, 34, 256)
    features = layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=initializer,
                             use_bias=False)(features)  # (bs, 31, 31, 512)
    features = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(features)
    features = layers.LeakyReLU()(features)

    features = layers.ZeroPadding2D()(features)  # (bs, 33, 33, 512)
    patch_scores = layers.Conv2D(1, 4, strides=1,
                                 kernel_initializer=initializer)(features)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=patch_scores)

# **CycleGan Model**

In [None]:
class CycleGan(keras.Model):
    """
    keras.Model subclass bundling two generators and two discriminators so that fit() can
    drive CycleGAN training.
    In each training step a photo is translated to a Monet painting and back (and a Monet
    to a photo and back); the difference between an image and its twice-translated version
    is the cycle-consistency loss, which training tries to keep small.
    """
    def __init__(
            self,
            monet_generator,
            photo_generator,
            monet_discriminator,
            photo_discriminator,
            lambda_cycle=10,
    ):
        """Store the four networks and the weight passed to the cycle and identity losses."""
        super().__init__()
        self.m_gen, self.p_gen = monet_generator, photo_generator
        self.m_disc, self.p_disc = monet_discriminator, photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
            self,
            m_gen_optimizer,
            p_gen_optimizer,
            m_disc_optimizer,
            p_disc_optimizer,
            gen_loss_fn,
            disc_loss_fn,
            cycle_loss_fn,
            identity_loss_fn
    ):
        """Attach one optimizer per network and the four loss functions used by train_step."""
        super().compile()
        self.m_gen_optimizer, self.p_gen_optimizer = m_gen_optimizer, p_gen_optimizer
        self.m_disc_optimizer, self.p_disc_optimizer = m_disc_optimizer, p_disc_optimizer
        self.gen_loss_fn, self.disc_loss_fn = gen_loss_fn, disc_loss_fn
        self.cycle_loss_fn, self.identity_loss_fn = cycle_loss_fn, identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on a (real_monet, real_photo) pair of batches.

        :return: dict with the total generator losses and the discriminator losses
        """
        real_monet, real_photo = batch_data

        # Persistent tape: four separate gradient computations are taken from it below.
        with tf.GradientTape(persistent=True) as tape:
            # photo -> monet -> photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet -> photo -> monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # each generator applied to an image that is already in its target domain
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator outputs for real images ...
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # ... and for generated images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # adversarial part of the generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # cycle consistency in both directions, shared by both generators
            monet_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle)
            photo_cycle_loss = self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            total_cycle_loss = monet_cycle_loss + photo_cycle_loss

            # total generator loss = adversarial + cycle + identity
            total_monet_gen_loss = (
                monet_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            )
            total_photo_gen_loss = (
                photo_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)
            )

            # discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # (optimizer, network, loss) triples, in the order they are processed.
        updates = (
            (self.m_gen_optimizer, self.m_gen, total_monet_gen_loss),
            (self.p_gen_optimizer, self.p_gen, total_photo_gen_loss),
            (self.m_disc_optimizer, self.m_disc, monet_disc_loss),
            (self.p_disc_optimizer, self.p_disc, photo_disc_loss),
        )

        # Compute every gradient first, and only then apply any of them.
        all_gradients = [
            tape.gradient(loss, network.trainable_variables)
            for _, network, loss in updates
        ]
        for (optimizer, network, _), gradients in zip(updates, all_gradients):
            optimizer.apply_gradients(zip(gradients, network.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }


# **Loss Functions**

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """
        Binary cross-entropy of the discriminator, averaged over its two inputs.
        Outputs for real images are compared with a tensor of 1s and outputs for generated
        images with a tensor of 0s; both are treated as logits and no reduction is applied.
        :param real: discriminator output for real images
        :param generated: discriminator output for generated images
        :return: (real_loss + generated_loss) / 2
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction=tf.keras.losses.Reduction.NONE)
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)
        return (loss_on_real + loss_on_generated) * 0.5

    def generator_loss(generated):
        """
        Adversarial loss of a generator.
        The generator wants the discriminator to answer "real", so the discriminator output
        for generated images is compared with a tensor of 1s (logits, no reduction).
        :param generated: discriminator output for generated images
        :return: generator loss
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction=tf.keras.losses.Reduction.NONE)
        return bce(tf.ones_like(generated), generated)

    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """
        Cycle-consistency loss: mean absolute difference between an image and its
        twice-translated version, scaled by LAMBDA.
        :param real_image: original image
        :param cycled_image: image after translating there and back
        :param LAMBDA: the cycle loss factor
        :return: cycle loss
        """
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_diff

    def identity_loss(real_image, same_image, LAMBDA):
        """
        Identity loss: mean absolute difference between an image and the output of the
        generator that targets the image's own domain, scaled by half of LAMBDA.
        :param real_image: original image
        :param same_image: generator output for that image
        :param LAMBDA: loss factor
        :return: identity loss
        """
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_diff

In [None]:
# Run configuration
dataset_path = '../input/gan-getting-started'
batch = 1
epochs = 25

with strategy.scope():
    # Networks ---------------------------------------------------------------------------------------------------
    # monet_generator: photo -> Monet-esque painting; photo_generator: Monet painting -> photo-like image
    monet_generator, photo_generator = Generator(), Generator()
    # One discriminator per domain, each telling real images from generated ones
    monet_discriminator, photo_discriminator = Discriminator(), Discriminator()

    # Data -------------------------------------------------------------------------------------------------------
    print(f'monet tfrec path: {monet_tfrecords_dir_path}')
    data_loader = TFRecordsDataLoader(directory_path=dataset_path, monet_dir_path=monet_tfrecords_dir_path)

    # Decoded image datasets for both domains, batched
    monet_ds = data_loader.load_dataset(data_loader.MONET_FILENAMES, labeled=True).batch(batch)
    photo_ds = data_loader.load_dataset(data_loader.PHOTO_FILENAMES, labeled=True).batch(batch)

    # Optional sanity check (disabled): show one photo next to the output of the still
    # untrained generator, which is not expected to look like a painting yet.
    # example_photo = next(iter(photo_ds))
    # to_monet = monet_generator(example_photo)
    # plt.subplot(1, 2, 1)
    # plt.title("Original Photo")
    # plt.imshow(example_photo[0] * 0.5 + 0.5)
    # plt.subplot(1, 2, 2)
    # plt.title("Monet-esque Photo")
    # plt.imshow(to_monet[0] * 0.5 + 0.5)
    # plt.show()

    # Optimizers: one independent Adam instance per network, all with the same settings ----------------------------
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

    # Model ------------------------------------------------------------------------------------------------------
    cycle_gan_model = CycleGan(
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    )

    # Wire the optimizers and loss functions into the model
    cycle_gan_model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

    # Optional (disabled): plot generated samples before training and after every epoch.
    # customCallback = CustomCallback(photo_ds=photo_ds, monet_generator=monet_generator)
    # visualize_our_monet_photos(photo_ds=photo_ds, monet_generator=monet_generator)

# **Train CycleGAN model**

In [None]:
with strategy.scope():
    # Each training element is a (monet batch, photo batch) pair, which is what
    # CycleGan.train_step unpacks.
    cycle_gan_model.fit(
        x=tf.data.Dataset.zip((monet_ds, photo_ds)),
        epochs=epochs,
        # callbacks=[customCallback],
    )

# **Create a submit zip file**

In [None]:
with strategy.scope():
    # Run every photo through the photo-to-Monet generator, save the images and zip them.
    create_submission_file(photo_ds, monet_generator)