In [None]:
import numpy as np
import keras
from keras import layers
import tensorflow as tf
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as pp
import os
import random
import PIL
import shutil
import math

from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import BatchNormalization
from keras.layers import Conv2DTranspose
from tensorflow_addons.layers import InstanceNormalization
from keras.losses import MeanSquaredError
from keras.losses import MeanAbsoluteError
from keras.initializers import RandomNormal
from keras.losses import Reduction
from keras.callbacks import LearningRateScheduler

# Let tf.data pick parallelism / prefetch depth automatically.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Seed every RNG the notebook draws from (Python `random` for the image pool,
# NumPy, and TF for shuffling, random flips and weight initializers) so a
# Restart & Run All reproduces the same run (GPU kernels may still add noise).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Shape of a single image fed to the generators / discriminators.
img_size = (256, 256, 3)
# Shared loss objects: MSE for the least-squares adversarial terms,
# MAE for the cycle-consistency and identity terms.
MSE = MeanSquaredError()
MAE = MeanAbsoluteError()

def img2tensorm(image):
    """Decode a JPEG byte string into a randomly mirrored float image in [-1, 1].

    Args:
        image: scalar string tensor holding JPEG bytes.

    Returns:
        float32 tensor reshaped to [256, 256, 3].
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    mirrored = tf.image.random_flip_left_right(decoded)  # horizontal-flip augmentation
    scaled = tf.cast(mirrored, tf.float32) / 127.5 - 1  # [0, 255] -> [-1, 1]
    return tf.reshape(scaled, [256, 256, 3])

def get_tfrecm(img):
    """Parse one serialized TFRecord example and return its image via img2tensorm (with random flip).

    Args:
        img: scalar string tensor holding a serialized tf.train.Example.
    """
    features = tf.io.parse_single_example(
        img,
        {
            "image_name": tf.io.FixedLenFeature([], tf.string),
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.string),
        },
    )
    return img2tensorm(features['image'])

def img2tensorp(image):
    """Decode a JPEG byte string into a float image in [-1, 1] (no augmentation).

    Args:
        image: scalar string tensor holding JPEG bytes.

    Returns:
        float32 tensor reshaped to [256, 256, 3].
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 127.5 - 1  # [0, 255] -> [-1, 1]
    return tf.reshape(scaled, [256, 256, 3])

def get_tfrecp(img):
    """Parse one serialized TFRecord example and return its image via img2tensorp (no flip).

    Args:
        img: scalar string tensor holding a serialized tf.train.Example.
    """
    features = tf.io.parse_single_example(
        img,
        {
            "image_name": tf.io.FixedLenFeature([], tf.string),
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.string),
        },
    )
    return img2tensorp(features['image'])

# Monet training stream: 10 passes of a 30-record random draw (300 elements).
# With shuffle's default reshuffle_each_iteration=True each repeat re-draws
# its 30 records; cache() below then freezes the first epoch's sequence.
mfiles = tf.io.gfile.glob('/kaggle/input/gan-getting-started/monet_tfrec/*.tfrec')
mds = tf.data.TFRecordDataset(mfiles)
mds = mds.shuffle(1000)
mds = mds.take(30)
mds = mds.repeat(count=10)
# Decode + normalise WITHOUT augmentation (get_tfrecp does no flip) and cache
# that. Caching after the random flip, as before, froze the epoch-1 flips
# so later epochs never saw new augmentations.
mds = mds.map(get_tfrecp, num_parallel_calls=AUTOTUNE)
mds = mds.cache()
# Random mirroring after the cache so every epoch draws fresh flips.
mds = mds.map(tf.image.random_flip_left_right, num_parallel_calls=AUTOTUNE)
mds = mds.batch(1)
mds = mds.prefetch(AUTOTUNE)

# Photo stream: used both for training (zipped with mds) and for inference.
pfiles = tf.io.gfile.glob('/kaggle/input/gan-getting-started/photo_tfrec/*.tfrec')
pds = tf.data.TFRecordDataset(pfiles)
# Shuffle the raw, still-JPEG-encoded records (cheap to buffer). zip() stops at
# the shorter dataset (mds has 300 elements), so without shuffling every epoch
# would presumably pair mds with the same leading photos. Shuffle still emits
# every record exactly once per pass, so inference below covers all photos.
pds = pds.shuffle(2048)
pds = pds.map(get_tfrecp, num_parallel_calls=AUTOTUNE)
pds = pds.batch(1)
# No cache(): during training zip() never finishes reading pds, so a partial
# cache would be discarded every epoch, and a complete one would hold every
# photo as a float32 256x256x3 array (~0.75 MB each) in RAM.
pds = pds.prefetch(AUTOTUNE)

In [None]:
def discriminator():
    """Build a fully convolutional discriminator that outputs a spatial score map.

    Four stride-2 4x4 convolutions reduce the spatial size by 16 (a 256x256 input
    gives a 16x16x1 output), followed by a stride-1 512-filter stage and a 1-filter
    conv. No activation on the output; the LSGAN losses compare it against 0/1.
    """
    conv_init = tf.random_normal_initializer(0, 0.02)
    gamma_init = RandomNormal(mean=0, stddev=0.02)
    img = Input(shape=img_size)

    # First stage: conv + LeakyReLU, no normalization.
    x = Conv2D(64, 4, 2, padding='same', kernel_initializer=conv_init)(img)
    x = LeakyReLU()(x)

    # (filters, stride) for each conv -> InstanceNorm -> LeakyReLU stage.
    for n_filters, stride in ((128, 2), (256, 2), (512, 2), (512, 1)):
        x = Conv2D(n_filters, 4, stride, padding='same',
                   kernel_initializer=conv_init, use_bias=False)(x)
        x = InstanceNormalization(axis=-1, gamma_initializer=gamma_init)(x)
        x = LeakyReLU()(x)

    score_map = Conv2D(1, 4, 1, padding='same', kernel_initializer=conv_init)(x)
    return Model(img, score_map)
 
# One discriminator per domain: disc_m judges Monet images, disc_p judges photos.
disc_m, disc_p = discriminator(), discriminator()

In [None]:
def rblock(filters, out0):
    """Residual block: two reflect-padded 3x3 conv + InstanceNorm stages added onto the input.

    Args:
        filters: number of conv filters; must equal the channel count of `out0`
            so the skip connection can be summed.
        out0: input (Keras) tensor.

    Returns:
        Tensor with the same shape as `out0`.
    """
    init = tf.random_normal_initializer(0, 0.02)
    g_init = RandomNormal(mean=0, stddev=0.02)

    # Reflect padding + 'valid' conv keeps spatial size without zero borders.
    out = tf.pad(out0, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
    out = Conv2D(filters, 3, 1, padding='valid', kernel_initializer=init, use_bias=False)(out)
    out = InstanceNormalization(axis=-1, gamma_initializer=g_init)(out)
    out = Activation('relu')(out)

    out = tf.pad(out, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
    out = Conv2D(filters, 3, 1, padding='valid', kernel_initializer=init, use_bias=False)(out)
    out = InstanceNormalization(axis=-1, gamma_initializer=g_init)(out)

    # Additive skip connection (ResNet / CycleGAN style). The previous Concatenate
    # grew the channel count by `filters` on every block, so the 9 stacked blocks
    # reached 10 * filters channels and huge conv kernels.
    return layers.Add()([out, out0])

In [None]:
def generator():
    """Build the image-to-image generator (encoder -> 9 rblocks -> decoder).

    Layout: reflect-pad + 7x7 conv (64), three stride-2 convs (128/256/512, i.e.
    /8 spatially), nine `rblock(512, ...)`, three stride-2 transposed convs
    (256/128/64, x8 spatially), reflect-pad + 7x7 conv to 3 channels, tanh, so
    outputs lie in [-1, 1] like the normalized inputs.
    """
    conv_init = tf.random_normal_initializer(0, 0.02)
    gamma_init = RandomNormal(mean=0, stddev=0.02)

    def norm_relu(tensor):
        # InstanceNorm followed by ReLU, shared by every hidden stage.
        tensor = InstanceNormalization(axis=-1, gamma_initializer=gamma_init)(tensor)
        return Activation('relu')(tensor)

    img = Input(shape=img_size)

    x = tf.pad(img, [[0, 0], [3, 3], [3, 3], [0, 0]], mode="REFLECT")
    x = norm_relu(Conv2D(64, 7, kernel_initializer=conv_init, use_bias=False)(x))

    # Encoder: downsample by 2 at each stage.
    for n_filters in (128, 256, 512):
        x = norm_relu(Conv2D(n_filters, 3, 2, padding='same',
                             kernel_initializer=conv_init, use_bias=False)(x))

    for _ in range(9):
        x = rblock(512, x)

    # Decoder: upsample by 2 at each stage.
    for n_filters in (256, 128, 64):
        x = norm_relu(Conv2DTranspose(n_filters, 3, 2, padding='same',
                                      kernel_initializer=conv_init, use_bias=False)(x))

    x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode="REFLECT")
    x = Conv2D(3, 7, padding='valid')(x)
    out_img = Activation('tanh')(x)
    return Model(img, out_img)

# gen_m2p translates Monet -> photo, gen_p2m translates photo -> Monet.
gen_m2p, gen_p2m = generator(), generator()

In [None]:
def getfromPool(pool, img, pool_size=50):
    """Query an image-history buffer (the CycleGAN discriminator image pool).

    While the pool holds fewer than `pool_size` items, `img` is stored and returned.
    Once full, with probability 0.5 a random stored item is returned and replaced
    by `img`; otherwise `img` itself is returned and the pool is left unchanged.

    The previous version could return an old item without storing the new one, or
    overwrite one slot and then return a different random slot. Its `img in pool`
    test compared tensors elementwise, which raises for TF tensors.

    Args:
        pool: mutable list used as the buffer; modified in place.
        img: newly generated item.
        pool_size: buffer capacity; values <= 0 disable the buffer.

    Returns:
        The item to show the discriminator.
    """
    if pool_size <= 0:
        return img
    if len(pool) < pool_size:
        pool.append(img)
        return img
    if random.random() < 0.5:
        slot = random.randrange(len(pool))
        previous = pool[slot]
        pool[slot] = img
        return previous
    return img

In [None]:
class FullCycleGanModel(keras.Model):
    """CycleGAN training wrapper for the photo <-> Monet generator/discriminator pairs.

    gen_p2m maps photos to Monet-style images and gen_m2p maps them back; disc_m
    scores Monet-domain images and disc_p photo-domain images. `train_step`
    updates all four networks once per batch, each with its own optimizer.
    """

    # Loss weights used by this notebook: cycle x10, identity x(10 * 0.5).
    CYCLE_WEIGHT = 10
    IDENTITY_WEIGHT = 10 * 0.5
    # Capacity of the history buffer of generated Monet images shown to disc_m.
    POOL_SIZE = 50

    def __init__(self, gen_m2p, gen_p2m, disc_m, disc_p):
        super(FullCycleGanModel, self).__init__()
        self.gen_m2p = gen_m2p
        self.gen_p2m = gen_p2m
        self.disc_m = disc_m
        self.disc_p = disc_p
        # Python-side history of generated Monet images; updated eagerly on every
        # step from _query_pool (see there for why py_function is needed).
        self.pool_generated = list()

    def compile(self, gen_opt_m, gen_opt_p, disc_opt_m, disc_opt_p, gen_loss, disc_loss, c_loss, id_loss):
        """Store the four optimizers and the four loss callables used by train_step.

        No optimizer is passed to keras.Model.compile; callbacks that act on
        `self.optimizer` (e.g. LearningRateScheduler) do not reach these four.
        """
        super(FullCycleGanModel, self).compile()
        self.gen_opt_m = gen_opt_m
        self.gen_opt_p = gen_opt_p
        self.disc_opt_m = disc_opt_m
        self.disc_opt_p = disc_opt_p
        self.gen_loss = gen_loss
        self.disc_loss = disc_loss
        self.c_loss = c_loss
        self.id_loss = id_loss

    def _query_pool(self, images):
        """Return a batch for disc_m drawn from the image-history pool.

        Keras traces train_step into a graph, so plain Python list logic would run
        only once at trace time. Wrapping it in tf.py_function executes it on
        every step with eager tensors (not supported under XLA/TPU compilation).
        Gradients are stopped: the pool only feeds the discriminator.
        """
        def _update(batch):
            chosen = []
            for image in tf.unstack(batch):
                if len(self.pool_generated) < self.POOL_SIZE:
                    self.pool_generated.append(image)
                    chosen.append(image)
                elif random.random() < 0.5:
                    slot = random.randrange(len(self.pool_generated))
                    chosen.append(self.pool_generated[slot])
                    self.pool_generated[slot] = image
                else:
                    chosen.append(image)
            return tf.stack(chosen)

        pooled = tf.py_function(_update, [tf.stop_gradient(images)], images.dtype)
        pooled.set_shape(images.shape)  # py_function loses static shape information
        return pooled

    def train_step(self, batch):
        """One joint update of both generators and both discriminators.

        Args:
            batch: (real_p, real_m) pair of image batches.

        Returns:
            Dict of scalar losses for Keras logging.
        """
        real_p, real_m = batch
        with tf.GradientTape(persistent=True) as tape:
            # Translations and their cycle reconstructions.
            generated_m = self.gen_p2m(real_p, training=True)
            generated_p = self.gen_m2p(real_m, training=True)
            generated_real_p = self.gen_m2p(generated_m, training=True)
            generated_real_m = self.gen_p2m(generated_p, training=True)
            # Identity mappings (each generator applied to its own target domain).
            id_p = self.gen_m2p(real_p, training=True)
            id_m = self.gen_p2m(real_m, training=True)

            disc_real_m = self.disc_m(real_m, training=True)
            disc_real_p = self.disc_p(real_p, training=True)
            # Generator adversarial terms must score the CURRENT fakes so gradients
            # reach the generators; previously the pooled (possibly stale) Monet
            # image was used here.
            disc_generated_m = self.disc_m(generated_m, training=True)
            disc_generated_p = self.disc_p(generated_p, training=True)
            # disc_m is trained on a history-pool sample (Monet side only, as before).
            disc_pooled_m = self.disc_m(self._query_pool(generated_m), training=True)

            gen_loss_m = self.gen_loss(disc_generated_m)
            gen_loss_p = self.gen_loss(disc_generated_p)
            # CycleGAN's cycle loss covers both cycles and is optimized by both
            # generators; previously each generator saw only one of the two terms.
            total_cycle_loss = (self.c_loss(real_m, generated_real_m)
                                + self.c_loss(real_p, generated_real_p)) * self.CYCLE_WEIGHT
            id_loss_m = self.id_loss(real_m, id_m) * self.IDENTITY_WEIGHT
            id_loss_p = self.id_loss(real_p, id_p) * self.IDENTITY_WEIGHT

            loss_m = gen_loss_m + total_cycle_loss + id_loss_m  # objective of gen_p2m
            loss_p = gen_loss_p + total_cycle_loss + id_loss_p  # objective of gen_m2p
            # Halved discriminator loss slows D relative to G.
            disc_loss_m = self.disc_loss(disc_real_m, disc_pooled_m) * 0.5
            disc_loss_p = self.disc_loss(disc_real_p, disc_generated_p) * 0.5

        grads_m = tape.gradient(loss_m, self.gen_p2m.trainable_variables)
        grads_p = tape.gradient(loss_p, self.gen_m2p.trainable_variables)
        disc_grads_m = tape.gradient(disc_loss_m, self.disc_m.trainable_variables)
        disc_grads_p = tape.gradient(disc_loss_p, self.disc_p.trainable_variables)
        del tape  # release the persistent tape's resources

        self.gen_opt_m.apply_gradients(zip(grads_m, self.gen_p2m.trainable_variables))
        self.gen_opt_p.apply_gradients(zip(grads_p, self.gen_m2p.trainable_variables))
        self.disc_opt_m.apply_gradients(zip(disc_grads_m, self.disc_m.trainable_variables))
        self.disc_opt_p.apply_gradients(zip(disc_grads_p, self.disc_p.trainable_variables))

        return {"loss_m": loss_m, "loss_p": loss_p, "disc_loss_m": disc_loss_m, "disc_loss_p": disc_loss_p}

In [None]:
def disc_loss(real, generated):
    """Least-squares discriminator loss: push real scores to 1 and fake scores to 0.

    Args:
        real: discriminator output on real images.
        generated: discriminator output on generated images.
    """
    return MSE(tf.ones_like(real), real) + MSE(tf.zeros_like(generated), generated)

def gen_loss(generated):
    """Least-squares generator loss: reward fakes the discriminator scores as 1.

    Args:
        generated: discriminator output on generated images.
    """
    return MSE(tf.ones_like(generated), generated)

def c_loss(real, generated):
    """Cycle-consistency loss: mean absolute error between an image and its reconstruction."""
    return MAE(real, generated)

def id_loss(real, generated):
    """Identity loss: mean absolute error between an image and the generator's output on it."""
    return MAE(real, generated)

In [None]:
def decayLR(epoch, lr):
    """Step schedule: keep `lr` for epochs 0-24, then use 0.0002 from epoch 25 on.

    Args:
        epoch: zero-based epoch index.
        lr: current learning rate.

    Returns:
        Learning rate to use for this epoch.
    """
    return 0.0002 if epoch >= 25 else lr

model = FullCycleGanModel(gen_m2p, gen_p2m, disc_m, disc_p)
model.compile(
    Adam(learning_rate=0.002, beta_1=0.5), Adam(learning_rate=0.002, beta_1=0.5),
    Adam(learning_rate=0.002, beta_1=0.5), Adam(learning_rate=0.002, beta_1=0.5),
    gen_loss, disc_loss, c_loss, id_loss,
)


def _apply_decay_to_all_optimizers(epoch, logs=None):
    """Apply decayLR to each of the four GAN optimizers at the start of every epoch.

    keras' LearningRateScheduler only updates `model.optimizer`, but train_step
    uses the four optimizers stored by compile(), so the decay never took effect.
    """
    for opt in (model.gen_opt_m, model.gen_opt_p, model.disc_opt_m, model.disc_opt_p):
        current_lr = float(keras.backend.get_value(opt.learning_rate))
        opt.learning_rate = decayLR(epoch, current_lr)


sched = keras.callbacks.LambdaCallback(on_epoch_begin=_apply_decay_to_all_optimizers)

md = model.fit(tf.data.Dataset.zip((pds, mds)), epochs=50, callbacks=[sched])
hist = md.history
model.save_weights("/kaggle/working/cgan_w.h5")

In [None]:
# `import PIL` alone does not load the PIL.Image submodule; import it explicitly.
from PIL import Image

# One absolute output directory for both writing and archiving. The old
# `! mkdir ../images` depended on the cwd and failed on re-runs.
out_dir = "/kaggle/images"
os.makedirs(out_dir, exist_ok=True)

i = 1
for img in pds:
    preds = model.gen_p2m(img, training=False).numpy()
    # Save every image in the batch, not only index 0.
    for pred in preds:
        # tanh output in [-1, 1] -> uint8 pixels in [0, 255].
        pixels = np.clip(pred * 127.5 + 127.5, 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(os.path.join(out_dir, str(i) + ".jpg"))
        i += 1
shutil.make_archive("/kaggle/working/images", 'zip', out_dir)