In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import os 
import tensorflow_addons as tfa

# try:
#     tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
#     print('Device:', tpu.master())
#     tf.config.experimental_connect_to_cluster(tpu)
#     tf.tpu.experimental.initialize_tpu_system(tpu)
#     strategy = tf.distribute.experimental.TPUStrategy(tpu)
# except:
#     strategy = tf.distribute.get_strategy()
# print('Number of replicas:', strategy.num_replicas_in_sync)
# tf.data's AUTOTUNE sentinel: asks the tf.data runtime to choose the level of
# parallelism itself. Used as `num_parallel_calls` for the TFRecord decode map.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Root of the Kaggle "gan-getting-started" dataset; change this to run the notebook elsewhere.
DATA_DIR = '../input/gan-getting-started'

# TFRecord shard paths for each domain. A count of 0 means DATA_DIR is wrong and
# every dataset built from these lists will be empty.
MONET_FILENAMES = tf.io.gfile.glob(DATA_DIR + '/monet_tfrec/*.tfrec')
print('Monet TFRecord Files:', len(MONET_FILENAMES))
PHOTO_FILENAMES = tf.io.gfile.glob(DATA_DIR + '/photo_tfrec/*.tfrec')
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

In [None]:
im_size = [256, 256]  # spatial size every decoded image is reshaped to


def decode_image(image):
    """Decode a JPEG byte string into a float32 image tensor scaled to [-1, 1].

    Args:
        image: scalar string tensor holding JPEG bytes.

    Returns:
        float32 tensor of shape ``im_size + [3]``. A JPEG whose pixel count does
        not match ``im_size`` makes the reshape fail.
    """
    rgb = tf.image.decode_jpeg(image, channels=3)
    # uint8 [0, 255] -> float32 [-1, 1], matching the generators' tanh output range.
    scaled = tf.cast(rgb, tf.float32) / 127.5 - 1
    return tf.reshape(scaled, list(im_size) + [3])
def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record must carry string features ``image_name``, ``image`` and
    ``target``. All three are parsed, but only ``image`` is decoded and returned.
    """
    parsed = tf.io.parse_single_example(
        example,
        {
            "image_name": tf.io.FixedLenFeature([], tf.string),
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.string),
        },
    )
    return decode_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a tf.data pipeline yielding decoded images from TFRecord files.

    Args:
        filenames: list of TFRecord paths.
        labeled: accepted for interface compatibility; currently ignored.
        ordered: accepted for interface compatibility; currently ignored.

    Returns:
        Unbatched ``tf.data.Dataset`` of images produced by ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
# One image per batch for each domain (Monet first, then photos).
monet_ds, photo_ds = (
    load_dataset(filenames, labeled=True).batch(1)
    for filenames in (MONET_FILENAMES, PHOTO_FILENAMES)
)

In [None]:
# First batch of each domain, used for the quick sanity checks below.
im, pim = next(iter(monet_ds)), next(iter(photo_ds))

In [None]:
# Smallest pixel value of the sample Monet batch; after decode_image it should not go below -1.
np.min(im.numpy())

In [None]:
# Side-by-side preview: photo on the left, Monet on the right, mapped from [-1, 1] to [0, 1] for imshow.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(pim[0] * 0.5 + 0.5)
axes[1].imshow(im[0] * 0.5 + 0.5)

In [None]:
def conv_blk(filters, k_size, strd, inp, norm=True):
    """Conv2D -> optional BatchNormalization -> LeakyReLU.

    The convolution uses 'same' padding, no bias and N(0, 0.02) kernel init.

    Args:
        filters: number of output channels.
        k_size: kernel size.
        strd: stride.
        inp: input Keras tensor.
        norm: when False, the BatchNormalization layer is skipped.

    Returns:
        The activated output Keras tensor.
    """
    conv = layers.Conv2D(
        filters,
        k_size,
        strides=strd,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    out = conv(inp)
    out = layers.BatchNormalization()(out) if norm else out
    return layers.LeakyReLU()(out)

def rev_conv(chnl, k_size, inp):
    """2x upsampling block: Conv2DTranspose (stride 2) -> LayerNormalization -> ReLU capped at 6.

    Args:
        chnl: number of output channels.
        k_size: transposed-convolution kernel size.
        inp: input Keras tensor.

    Returns:
        Output Keras tensor with doubled spatial dimensions.
    """
    upsample = layers.Conv2DTranspose(
        chnl,
        k_size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    normalize = layers.LayerNormalization(axis=-1)
    activate = layers.ReLU(max_value=6)
    return activate(normalize(upsample(inp)))

In [None]:
def Gen():
    """U-Net style generator mapping a 256x256x3 image to a 256x256x3 image in [-1, 1].

    Encoder: eight stride-2 ``conv_blk`` stages (256 -> 1 spatially). BatchNorm is
    applied on every stage except the first.
    Decoder: seven stride-2 ``rev_conv`` stages, each concatenated with the encoder
    activation of matching resolution, then a final stride-2 transposed conv with tanh.
    """
    im_input = layers.Input(shape=[256, 256, 3])
    extractor = [32, 64, 128, 256, 256, 512, 512, 512]
    skips = []
    x = im_input
    for depth, ly in enumerate(extractor):
        # Only the very first stage runs without BatchNorm.
        # NOTE(review): with batch size 1, the last (1x1) stage's BatchNorm sees a single
        # value per channel during training; confirm this is acceptable.
        x = conv_blk(ly, 4, 2, x, norm=depth > 0)
        skips.append(x)
    # The deepest activation is the bottleneck, not a skip; the decoder consumes the
    # remaining skips from coarse to fine.
    skips = reversed(skips[:-1])

    constructor = [512, 512, 256, 256, 128, 64, 32]
    for ly, s in zip(constructor, skips):
        x = rev_conv(ly, 4, x)
        x = layers.Concatenate()([x, s])
    out = layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')(x)
    return tf.keras.Model(inputs=im_input, outputs=out, name="Generator")

def myGen():
    """Compact residual generator: 256x256x3 input -> 256x256x3 tanh output.

    Downsamples 4x with two stride-2 ReLU convs and applies two residual
    ``conv_blk`` stages at 128 channels. It then upsamples back with two
    ``rev_conv`` blocks and adds a 1x1-conv projection of the input before the
    final tanh conv.
    """
    def relu_conv(filters, kernel, stride):
        # ReLU conv with 'same' padding and a zero-initialised bias (fresh initializer per layer).
        return layers.Conv2D(
            filters,
            kernel,
            strides=stride,
            padding='same',
            activation='relu',
            bias_initializer=tf.keras.initializers.constant(0.0),
        )

    im_input = layers.Input(shape=[256, 256, 3])
    shortcut = relu_conv(32, 1, 1)(im_input)

    h = relu_conv(32, 7, 1)(im_input)
    h = relu_conv(64, 3, 2)(h)
    h = relu_conv(128, 3, 2)(h)

    # Two residual stages: h <- conv_blk(h) + h.
    for _ in range(2):
        h = layers.Add()([conv_blk(128, 3, 1, h), h])

    h = rev_conv(32, 4, rev_conv(64, 4, h))
    h = layers.Add()([h, shortcut])
    out = layers.Conv2D(3, 3, strides=1, padding='same', activation='tanh')(h)
    return tf.keras.Model(inputs=im_input, outputs=out, name="Generator")


def Judge():
    """Convolutional discriminator: 256x256x3 image -> 16x16x1 map of raw scores.

    Four stride-2 ``conv_blk`` stages, with residual refinements after the first
    two. The final 3x3 conv has no activation, so outputs are unnormalised logits.
    """
    im_input = layers.Input(shape=[256, 256, 3])

    h = conv_blk(32, 6, 2, im_input, norm=False)
    h = layers.Add()([conv_blk(32, 3, 1, h), h])

    h = conv_blk(64, 3, 2, h)
    h = layers.Add()([conv_blk(64, 3, 1, h), h])

    h = conv_blk(128, 3, 2, h)
    h = conv_blk(256, 3, 2, h)

    logits = layers.Conv2D(
        1,
        3,
        1,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        bias_initializer=tf.keras.initializers.constant(0.0),
    )(h)
    return tf.keras.Model(inputs=im_input, outputs=logits, name='Judge')




In [None]:
class CycleGan(keras.Model):
    """CycleGAN trainer wrapping two generators and two discriminators.

    ``m_gen`` maps photo -> Monet and ``p_gen`` maps Monet -> photo. ``m_judge``
    scores Monet images and ``p_judge`` scores photos. ``fit`` must be given
    batches of ``(real_monet, real_photo)`` pairs.

    Args:
        monet_gen: photo -> Monet generator model.
        photo_gen: Monet -> photo generator model.
        monet_judge: discriminator for Monet images.
        photo_judge: discriminator for photos.
        lambda_cycle: weight passed to both the cycle-consistency and identity losses.
    """

    def __init__(self,monet_gen,photo_gen,monet_judge,photo_judge,lambda_cycle=8):
        super(CycleGan, self).__init__()
        self.m_gen = monet_gen
        self.p_gen = photo_gen
        self.m_judge = monet_judge
        self.p_judge = photo_judge
        self.lambda_cycle = lambda_cycle

    def compile(self,m_gen_opt,p_gen_opt,m_jg_opt,p_jg_opt,gen_loss_fn,judge_loss_fn,cycle_loss_fn,identitty_loss_fn):
        """Store one optimizer per network plus the four loss callables used by ``train_step``.

        NOTE: ``super().compile()`` is called without an optimizer, so
        ``self.optimizer`` is whatever Keras defaults to and ``train_step`` never
        uses it. Callbacks that act on ``model.optimizer`` (such as
        ``LearningRateScheduler``) therefore do not affect the four optimizers stored here.
        """
        super(CycleGan,self).compile()
        self.m_gen_optimizer = m_gen_opt
        self.p_gen_optimizer = p_gen_opt
        self.m_judge_optimizer = m_jg_opt
        self.p_judge_optimizer = p_jg_opt
        self.gen_loss_fn = gen_loss_fn
        self.judge_loss_fn = judge_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.id_loss_fn = identitty_loss_fn

    def train_step(self,batch_data):
        """Run one joint update of all four networks on a ``(real_monet, real_photo)`` batch.

        Returns:
            dict of the total generator losses and the discriminator losses, exactly
            as produced by the loss callables (these may be non-scalar tensors).
        """
        real_monet, real_photo = batch_data
        # persistent=True: four separate gradient() calls are made on the same tape.
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # identity mapping: each generator applied to its own target domain
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator outputs on real images
            disc_real_monet = self.m_judge(real_monet, training=True)
            disc_real_photo = self.p_judge(real_photo, training=True)

            # discriminator outputs on generated images
            disc_fake_monet = self.m_judge(fake_monet, training=True)
            disc_fake_photo = self.p_judge(fake_photo, training=True)

            # adversarial generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # total cycle-consistency loss, shared by both generators
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # total generator losses: adversarial + cycle + identity
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.id_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.id_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # discriminator losses
            monet_disc_loss = self.judge_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.judge_loss_fn(disc_real_photo, disc_fake_photo)

        # Gradients for each network w.r.t. its own loss. If a loss is a non-scalar
        # tensor, tape.gradient differentiates the sum of its elements.
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_judge.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_judge.trainable_variables)
        # A persistent tape keeps its recorded operations alive until it is released.
        del tape

        # Apply each set of gradients with that network's own optimizer.
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_judge_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_judge.trainable_variables))

        self.p_judge_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_judge.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }


In [None]:

# One stateless loss object shared by both adversarial losses, instead of building
# an identical BinaryCrossentropy on every call. With Reduction.NONE no reduction
# across examples or positions is applied, so callers get a tensor rather than a scalar.
_ADVERSARIAL_BCE = tf.keras.losses.BinaryCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)


def discriminator_loss(real, generated):
    """Discriminator loss: push logits on real images toward 1 and on generated images toward 0.

    Args:
        real: discriminator logits for real images.
        generated: discriminator logits for generated images.

    Returns:
        0.8 * (BCE(real, 1) + BCE(generated, 0)), unreduced.
    """
    real_loss = _ADVERSARIAL_BCE(tf.ones_like(real), real)
    generated_loss = _ADVERSARIAL_BCE(tf.zeros_like(generated), generated)
    # NOTE(review): the 0.8 scale is unexplained; presumably it slows discriminator learning.
    return (real_loss + generated_loss) * 0.8


def generator_loss(generated):
    """Adversarial generator loss: push discriminator logits on generated images toward 1 (unreduced)."""
    return _ADVERSARIAL_BCE(tf.ones_like(generated), generated)


def calc_cycle_loss(real_image, cycled_image, LAMBDA):
    """Cycle-consistency loss: LAMBDA * mean absolute error between an image and its round trip."""
    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * loss1


def identity_loss(real_image, same_image, LAMBDA):
    """Identity loss: 0.5 * LAMBDA * mean absolute error between an image and its identity mapping."""
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * loss


In [None]:

# Adam with beta_1=0.5 for all four networks; the discriminators start at a 10x higher rate.
GEN_LR, DISC_LR = 0.001, 0.01

monet_generator_optimizer, photo_generator_optimizer = (
    tf.keras.optimizers.Adam(GEN_LR, beta_1=0.5) for _ in range(2)
)
monet_discriminator_optimizer, photo_discriminator_optimizer = (
    tf.keras.optimizers.Adam(DISC_LR, beta_1=0.5) for _ in range(2)
)

In [None]:
#with strategy.scope():
#    monet_generator_optimizer = tf.keras.optimizers.RMSprop(2e-4,momentum=0.9)
#    photo_generator_optimizer = tf.keras.optimizers.RMSprop(2e-4,momentum=0.9)

#    monet_discriminator_optimizer = tf.keras.optimizers.RMSprop(2e-4,momentum=0.9)
#    photo_discriminator_optimizer = tf.keras.optimizers.RMSprop(2e-4,momentum=0.9)

In [None]:

#monet_generator = myGen() # transforms photos to Monet-esque paintings
#photo_generator = myGen() # transforms Monet paintings to be more like photos#

# Generators: photo -> Monet-esque painting, and Monet painting -> photo-like image.
monet_generator, photo_generator = Gen(), Gen()

# Discriminators: real vs generated Monet paintings, and real vs generated photos.
monet_discriminator, photo_discriminator = Judge(), Judge()

In [None]:
# Render the photo->Monet generator graph with layer output shapes (requires pydot + graphviz).
tf.keras.utils.plot_model(monet_generator, dpi=64, show_shapes=True)

In [None]:
def lr_sch(epoch, lr):
    """Piecewise epoch-indexed learning-rate schedule; the incoming ``lr`` is ignored.

    Every third epoch (0, 3, 6, ...) gets a higher "kick" rate:
      * epochs 0-9: 0.01 on kick epochs, otherwise 0.002
      * epochs 10+: 0.0007 on kick epochs, otherwise 0.0001
    """
    kick, base = (0.01, 0.002) if epoch < 10 else (0.0007, 0.0001)
    return kick if epoch % 3 == 0 else base
class _AllOptimizersLRScheduler(tf.keras.callbacks.Callback):
    """Apply an epoch-indexed learning-rate schedule to each of CycleGan's four optimizers.

    ``tf.keras.callbacks.LearningRateScheduler`` only updates ``model.optimizer``.
    ``CycleGan.compile`` calls ``super().compile()`` without an optimizer, so that
    attribute is a Keras default that ``train_step`` never uses. The optimizers
    that actually update weights live in the attributes listed in
    ``OPTIMIZER_ATTRS``, and this callback sets the scheduled rate on each of them.

    Args:
        schedule: ``schedule(epoch, current_lr) -> new_lr`` (same contract as
            ``LearningRateScheduler``); called once per optimizer per epoch.
    """

    OPTIMIZER_ATTRS = (
        'm_gen_optimizer',
        'p_gen_optimizer',
        'm_judge_optimizer',
        'p_judge_optimizer',
    )

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        for attr in self.OPTIMIZER_ATTRS:
            optimizer = getattr(self.model, attr)
            current_lr = float(tf.keras.backend.get_value(optimizer.learning_rate))
            tf.keras.backend.set_value(optimizer.learning_rate, self.schedule(epoch, current_lr))


callback = _AllOptimizersLRScheduler(lr_sch)

In [None]:

# Wire the two generators and two discriminators into one trainable model (default lambda_cycle).
cycle_gan_model = CycleGan(
    monet_gen=monet_generator,
    photo_gen=photo_generator,
    monet_judge=monet_discriminator,
    photo_judge=photo_discriminator,
)

In [None]:
# Loss callables first, then one optimizer per network.
cycle_gan_model.compile(
    gen_loss_fn=generator_loss,
    judge_loss_fn=discriminator_loss,
    cycle_loss_fn=calc_cycle_loss,
    identitty_loss_fn=identity_loss,  # spelling matches CycleGan.compile's parameter name
    m_gen_opt=monet_generator_optimizer,
    p_gen_opt=photo_generator_optimizer,
    m_jg_opt=monet_discriminator_optimizer,
    p_jg_opt=photo_discriminator_optimizer,
)

In [None]:
# Pair Monet and photo batches element-wise; Dataset.zip stops at the shorter of the two.
# No batch_size is passed: both datasets are already batched, and Keras documents that
# batch_size must not be given for tf.data input (it does not re-batch datasets).
hist = cycle_gan_model.fit(
    tf.data.Dataset.zip((monet_ds, photo_ds)),
    epochs=30,
    callbacks=[callback],
)

In [None]:
# Persist only the photo->Monet generator; the .h5 suffix selects the HDF5 format.
monet_generator.save(filepath='monet_gen.h5')

In [None]:
#monet_generator.load_weights('../input/mygan-wh5/monet_my_gen.h5')

In [None]:
# Fresh iterators over each domain for the qualitative previews.
photo_it, painting_it = iter(photo_ds), iter(monet_ds)


In [None]:
# Show ten photos next to their photo->Monet translations. The generator gets the raw
# [-1, 1] batch; both panels are mapped to [0, 1] for display.
# `im` and `m_y` are deliberately kept as module-level names (inspected afterwards).
for i in range(10):
    im = next(photo_it)
    m_y = monet_generator.predict(im) * 0.5 + 0.5
    im = im * 0.5 + 0.5
    plt.figure(figsize=(16, 8))
    for pos, (title, img) in enumerate((('real', im), ('generated', m_y)), start=1):
        plt.subplot(1, 2, pos)
        plt.title(title)
        plt.imshow(img[0])
    plt.show()

In [None]:
# Scalar type of one element of the last rescaled generator output.
type(m_y[0, 0, 0, 0])

In [None]:
# Show ten Monet paintings next to their Monet->photo translations. The generator must
# receive the raw [-1, 1] batch it was trained on; only the displayed copies (input and
# tanh output) are mapped to [0, 1], otherwise imshow clips the negative half of the output.
for i in range(10):
    im = next(painting_it)
    m_y = photo_generator.predict(im) * 0.5 + 0.5
    im = im * 0.5 + 0.5
    plt.figure(figsize=(16, 8))
    plt.subplot(1, 2, 1)
    plt.title('real')
    plt.imshow(im[0])
    plt.subplot(1, 2, 2)
    plt.title('generated')
    plt.imshow(m_y[0])
    plt.show()