In [None]:
import os
import pathlib
import time
import datetime
import imageio
from glob import glob

import tensorflow as tf
import numpy as np 
import tensorflow.keras.backend as K
import tensorflow_addons as tfa

from matplotlib import pyplot as plt
from IPython import display
from termcolor import colored
from tqdm import tqdm
from IPython.display import Image
import PIL
from PIL import ImageDraw
from IPython import display

In [None]:
def color_print(print_str, print_color='green'):
    """Print ``print_str`` to stdout wrapped in a terminal colour.

    Args:
        print_str: text to print.
        print_color: termcolor colour name (defaults to ``'green'``).
    """
    coloured_text = colored(print_str, print_color)
    print(coloured_text)
    
def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and
    TensorFlow's global seed, and sets determinism-related environment
    variables.

    Args:
        seed: integer seed value.
    """
    import random  # local import: the top-level import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # processes launched after this point (e.g. worker subprocesses), not the
    # hash randomisation of the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NOTE(review): requests deterministic TF kernels on versions that honour this
    # variable; newer TF exposes tf.config.experimental.enable_op_determinism().
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    print(f'setting seed to {seed}')

class CFG:
    """Global hyper-parameters and image-geometry settings for the notebook."""

    # --- size of the images fed to the networks (pixels) ---
    IMG_WIDTH = 512
    IMG_HEIGHT = 512

    # --- size images are resized to before random cropping ---
    resize_height = 512
    resize_width = 512

    # NOTE(review): not referenced in the visible code; CycleGAN uses its own lambda_cycle
    LAMBDA = 10

    # --- tf.data pipeline ---
    BUFFER_SIZE = 100   # shuffle buffer size
    BATCH_SIZE = 2
    cache = 50          # NOTE(review): not referenced in the visible code

    # --- optimisation / reproducibility ---
    learning_rate = 0.00025
    seed = 7
    
# Seed all RNGs once, immediately after the configuration is defined.
set_seed(CFG.seed)

In [None]:
white_dir = '/kaggle/input/green-shirts-only/Green Shirts'
black_dir = '/kaggle/input/black-shirts-only/Black Shirts'

# Preview one raw image from the "white" domain.
# NOTE(review): the "white" domain directory actually holds green shirts.
plt.figure(figsize=(16,8))

sample_name = os.listdir(white_dir)[1]
img = plt.imread(white_dir + '/' + sample_name)

plt.imshow(img)
plt.axis('off')
plt.title('sample image')
print(f'Image dimensions {img.shape}')
plt.show()

In [None]:
def load_image(image_file):
    '''Read an image file and decode it as a 3-channel (RGB) uint8 tensor.

    Args:
        image_file: string (or scalar string tensor) holding the file path.

    Returns:
        A ``uint8`` tensor of shape ``(height, width, 3)``.
    '''
    raw_bytes = tf.io.read_file(image_file)
    # Force 3 channels: with the default (channels=0) a grayscale JPEG decodes
    # to a single channel and the later crop to [H, W, 3] fails.
    image = tf.io.decode_jpeg(raw_bytes, channels=3)
    return image


def random_crop(image):
    '''Cut a random window of shape (IMG_HEIGHT, IMG_WIDTH, 3) out of ``image``.'''
    crop_shape = [CFG.IMG_HEIGHT, CFG.IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)


def normalize(image):
    '''Cast to float32 and map pixel values from [0, 255] to [-1, 1].'''
    as_float = tf.cast(image, tf.float32)
    return as_float / 127.5 - 1


def de_normalize(image):
    '''Invert ``normalize`` for display: map values from [-1, 1] to [0, 1].'''
    return image * 0.5 + 0.5

def image_augmentations(image):
    '''Apply random spatial augmentations (90º rotations and flips) to one image.

    The ``if`` statements branch on scalar tensors. That works eagerly and,
    presumably through AutoGraph, inside ``tf.data.Dataset.map`` (where this is
    used). Keep that in mind before restructuring.

    Rotation: each of 270º / 180º / 90º with ~20% probability, none ~40%.
    Rotating by 90º/270º swaps height and width. That is harmless for the
    square crops produced with the current CFG.

    Flips: when p_flip > 0.7 a left-right flip is *offered*, and when
    p_flip < 0.3 an up-down flip is offered. Each ``random_flip_*`` op then
    flips with 50% probability itself, so each flip actually happens ~15% of
    the time.

    Args:
        image: image tensor (height, width, channels).

    Returns:
        The augmented image tensor.
    '''
    
    # --------------------rotations----------
    # rotation probability
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    if p_rotate > .8:
        image = tf.image.rot90(image, k=3) # rotate 270º (counter-clockwise)
    elif p_rotate > .6:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .4:
        image = tf.image.rot90(image, k=1) # rotate 90º (counter-clockwise)
        
    
    # ----------------------Flips---------------------
    p_flip = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    if p_flip > 0.7:
        image = tf.image.random_flip_left_right(image)  # flips with a further 50% chance
    elif p_flip < 0.3:
        image = tf.image.random_flip_up_down(image)  # flips with a further 50% chance
    
    return image

def random_jitter(image):
    '''Resize ``image`` to (resize_height, resize_width), then take a random crop.

    NOTE: with the current CFG the resize size equals the crop size (512x512).
    The crop therefore keeps the whole image, and this step is effectively a
    nearest-neighbour resize.
    '''
    target_size = (CFG.resize_height, CFG.resize_width)
    resized = tf.image.resize(image,
                              size=target_size,
                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return random_crop(resized)

def preprocess_image_train(image):
    '''Training pipeline for one file path: load -> resize/crop -> augment -> scale to [-1, 1].'''
    pipeline = (load_image, random_jitter, image_augmentations, normalize)
    for step in pipeline:
        image = step(image)
    return image


# Same pipeline as preprocess_image_train, minus the random augmentations.
def preprocess_image_eval(image):
    '''Evaluation pipeline for one file path: load -> resize/crop -> scale to [-1, 1].'''
    loaded = load_image(image)
    jittered = random_jitter(loaded)
    return normalize(jittered)

def create_img_dataset(directory,
                       image_preprocess_fn,
                       image_extension = 'jpg',
                       repeat=True
                      ):
    '''Build a shuffled, batched ``tf.data.Dataset`` from the images in ``directory``.

    Args:
        directory: folder containing the images (not searched recursively).
        image_preprocess_fn: function mapping a file-path tensor to an image tensor.
        image_extension: file-name suffix to match (glob pattern ``*<ext>``).
        repeat: if True, the dataset repeats indefinitely.

    Returns:
        A dataset yielding batches of ``CFG.BATCH_SIZE`` preprocessed images.

    Raises:
        ValueError: if no file in ``directory`` matches ``image_extension``.
    '''
    img_list = glob(os.path.join(directory, f'*{image_extension}'))
    if not img_list:
        # Fail early with a readable message instead of an opaque TF error
        # about an empty file-pattern list.
        raise ValueError(f'no *{image_extension} files found in {directory!r}')

    # NOTE: list_files treats each entry as a glob pattern and shuffles the file
    # order by default.
    dataset = tf.data.Dataset.list_files(img_list)
    dataset = dataset.map(image_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)

    if repeat:
        dataset = dataset.repeat()

    dataset = dataset.shuffle(CFG.BUFFER_SIZE)
    dataset = dataset.batch(CFG.BATCH_SIZE)
    # Prepare the next batch while the current one is being consumed.
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
white_Dataset = create_img_dataset(directory=white_dir,
                                   image_preprocess_fn=preprocess_image_train)

# Same source images, without the random rotations/flips.
white_eval = create_img_dataset(directory=white_dir,
                                image_preprocess_fn=preprocess_image_eval)

# Preview the first image of one augmented training batch.
fig, ax = plt.subplots(figsize=(16,8))
inp_img = next(iter(white_Dataset))
first_image = inp_img[0]
plt.imshow(de_normalize(first_image))
plt.title('Sample white Shirt image')
print(f'Image dimensions {first_image.shape}')
plt.axis('off')
plt.show()

In [None]:
black_Dataset = create_img_dataset(directory=black_dir,
                                   image_preprocess_fn=preprocess_image_train)

# Same source images, without the random rotations/flips.
black_eval = create_img_dataset(directory=black_dir,
                                image_preprocess_fn=preprocess_image_eval)

# Preview the first image of one augmented training batch.
fig, ax = plt.subplots(figsize=(16,8))
inp_img = next(iter(black_Dataset))
plt.imshow(de_normalize(inp_img[0]))
plt.title('Sample black image')
plt.axis('off')
plt.show()

In [None]:
# Pair white and black batches element-wise; each element is (white_batch, black_batch),
# which is how CycleGAN.train_step unpacks it.
Train_Dataset = tf.data.Dataset.zip((white_Dataset,black_Dataset))

In [None]:
# Kernel initialiser for the convolution / transposed-convolution layers: N(0, 0.02).
conv_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02)


# Initialiser for the InstanceNormalization scale (gamma) parameters: N(0, 0.02).
gamma_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    
def downsample(input_layer,
               filters,
               name,
               size=3,
               strides=2,
               activation=None,
               ):
    '''Conv2D -> InstanceNormalization -> activation (one encoder step).

    Args:
        input_layer: Keras tensor to transform.
        filters: number of output channels.
        name: suffix for the conv layer name (``encoder_<name>``); must be unique
            within one model.
        size: convolution kernel size.
        strides: convolution stride (2 halves the spatial size with 'same' padding).
        activation: a callable layer (e.g. ``tf.keras.layers.LeakyReLU(0.2)``), an
            activation name accepted by ``tf.keras.layers.Activation``, or None
            for a new ``ReLU`` layer.

    Returns:
        The output Keras tensor.
    '''
    if activation is None:
        # Build the layer per call. A layer instance used as a default argument is
        # created once at definition time and silently shared by every block of
        # every model.
        activation = tf.keras.layers.ReLU()
    elif isinstance(activation, str):
        activation = tf.keras.layers.Activation(activation)

    conv = tf.keras.layers.Conv2D(filters,
                                  size,
                                  strides=strides,
                                  padding='same',
                                  use_bias=False,
                                  kernel_initializer=conv_initializer,
                                  name=f'encoder_{name}')(input_layer)

    conv = tfa.layers.InstanceNormalization(axis=-1, gamma_initializer=gamma_initializer)(conv)

    return activation(conv)

def upsample(input_layer,
             filters,
             name,
             size=3,
             strides=2,
             activation='relu'):
    '''Conv2DTranspose -> InstanceNormalization -> activation (one decoder step).

    Args:
        input_layer: Keras tensor to upsample.
        filters: number of output channels.
        name: suffix for the transposed-conv layer name (``decoder_<name>``).
        size: kernel size.
        strides: stride (2 doubles the spatial size with 'same' padding).
        activation: anything accepted by ``tf.keras.layers.Activation``.

    Returns:
        The output Keras tensor.
    '''
    deconv_layer = tf.keras.layers.Conv2DTranspose(
        filters,
        size,
        strides=strides,
        padding='same',
        use_bias=False,
        kernel_initializer=conv_initializer,
        name=f'decoder_{name}',
    )
    norm_layer = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
    act_layer = tf.keras.layers.Activation(activation)

    return act_layer(norm_layer(deconv_layer(input_layer)))

In [None]:
def residual_block(input_layer, size=3, strides=1, name='block_x'):
    '''Two same-padded convolutions (ReLU in between) plus an identity skip.

    Keeps the channel count of ``input_layer``. ``strides`` must stay 1;
    otherwise the conv output no longer matches the input shape for the Add.
    NOTE(review): unlike the usual CycleGAN residual block there is no
    normalisation layer here. Confirm this is intended.

    Returns:
        Keras tensor equal to ``F(input_layer) + input_layer``.
    '''
    channels = input_layer.shape[-1]
    conv_kwargs = dict(strides=strides,
                       padding='same',
                       use_bias=False,
                       kernel_initializer=conv_initializer)

    hidden = tf.keras.layers.Conv2D(channels, size, name=f'residual_{name}', **conv_kwargs)(input_layer)
    hidden = tf.keras.layers.Activation('relu')(hidden)
    hidden = tf.keras.layers.Conv2D(channels, size, name=f'transformer_{name}_2', **conv_kwargs)(hidden)

    return tf.keras.layers.Add()([hidden, input_layer])

def concat_layer(layer_1,layer_2,name):
    '''Concatenate two Keras tensors along the last axis (used for skip connections).'''
    merge = tf.keras.layers.Concatenate(name=name)
    return merge([layer_1, layer_2])

In [None]:
def get_generator(num_residual_connections=6):
    '''Build the encoder / residual / decoder generator used for both directions.

    Architecture:
    - Encoder: 4 ``downsample`` blocks (the first with stride 1).
    - Bottleneck: ``num_residual_connections`` residual blocks at 1/8 resolution.
    - Decoder: 3 ``upsample`` blocks. Each is fed the concatenation of the
      previous features and the encoder output at the same resolution.
    - Output: a final 7x7 conv with tanh, mapping back to 3 channels in [-1, 1].

    Height and width must be divisible by 8 so the skip concatenations line up.

    Args:
        num_residual_connections: number of residual blocks in the bottleneck.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping (H, W, 3) -> (H, W, 3), where
        H = CFG.IMG_HEIGHT and W = CFG.IMG_WIDTH.
    '''
    # Keras image tensors are (height, width, channels).
    input_ = tf.keras.layers.Input(shape=(CFG.IMG_HEIGHT, CFG.IMG_WIDTH, 3), name='input_layer')

    # ----------------------- ENCODER -----------------------
    enc1 = downsample(input_layer=input_, filters=64,  size=7, strides=1, name='dwn_1')  # (H,   W,   64)
    enc2 = downsample(input_layer=enc1,   filters=128, size=3, strides=2, name='dwn_2')  # (H/2, W/2, 128)
    enc3 = downsample(input_layer=enc2,   filters=256, size=3, strides=2, name='dwn_3')  # (H/4, W/4, 256)
    enc4 = downsample(input_layer=enc3,   filters=256, size=3, strides=2, name='dwn_4')  # (H/8, W/8, 256)

    # ----------------------- RESIDUAL BOTTLENECK -----------------------
    x = enc4
    for n in range(num_residual_connections):
        x = residual_block(input_layer=x, name=f'res_block_{n+1}')  # (H/8, W/8, 256)

    # ----------------------- DECODER (with skip connections) -----------------------
    x_skip = concat_layer(layer_1=x, layer_2=enc4, name='skip_1')
    dec1 = upsample(x_skip, filters=256, name='upsam_1')      # (H/4, W/4, 256)

    x_skip = concat_layer(layer_1=dec1, layer_2=enc3, name='skip_2')
    dec_2 = upsample(x_skip, filters=128, name='upsam_2')     # (H/2, W/2, 128)

    x_skip = concat_layer(layer_1=dec_2, layer_2=enc2, name='skip_3')
    dec_3 = upsample(x_skip, filters=64, name='upsam_3')      # (H, W, 64)

    x_skip = concat_layer(layer_1=dec_3, layer_2=enc1, name='skip_final')

    # tanh keeps the output in [-1, 1], the same range produced by normalize().
    output = tf.keras.layers.Conv2D(filters=3, kernel_size=7, strides=1, padding='same',
                                    kernel_initializer=conv_initializer, use_bias=False,
                                    activation='tanh', name='output_layer')(x_skip)

    return tf.keras.models.Model(inputs=input_, outputs=output)

In [None]:
# Generator translating white-domain images into black-domain images.
white2black_gen = get_generator()

In [None]:
# Generator translating black-domain images into white-domain images.
black2white_gen = get_generator()

In [None]:
def PATCH_discriminator(leak_rate = 0.2):
    '''PatchGAN discriminator: maps an image to a grid of real/fake logits.

    Architecture:
    - Four stride-2 ``downsample`` blocks and one stride-1 block, all with 4x4
      kernels and LeakyReLU.
    - A 4x4 'valid' conv that outputs one raw logit per patch.

    For a 512x512 input the output is (29, 29, 1).

    Args:
        leak_rate: negative slope of the LeakyReLU activations.

    Returns:
        An uncompiled ``tf.keras.Model``. Its outputs are logits (no sigmoid),
        matching the ``from_logits=True`` losses used for training.
    '''
    # One LeakyReLU object is reused after every block; it has no weights, so
    # sharing it is harmless.
    leaky_relu = tf.keras.layers.LeakyReLU(leak_rate)

    # Keras image tensors are (height, width, channels).
    input_ = tf.keras.layers.Input(shape=(CFG.IMG_HEIGHT, CFG.IMG_WIDTH, 3), name='input_layer')

    x = downsample(input_layer=input_, filters=64,  strides=2, size=4, name='dwn_1', activation=leaky_relu)  # H/2  (256)
    x = downsample(input_layer=x,      filters=128, strides=2, size=4, name='dwn_2', activation=leaky_relu)  # H/4  (128)
    x = downsample(input_layer=x,      filters=256, strides=2, size=4, name='dwn_3', activation=leaky_relu)  # H/8  (64)
    x = downsample(input_layer=x,      filters=512, strides=2, size=4, name='dwn_4', activation=leaky_relu)  # H/16 (32)
    x = downsample(input_layer=x,      filters=512, strides=1, size=4, name='dwn_5', activation=leaky_relu)  # H/16 (32)

    # 'valid' 4x4 conv: (H/16 - 3, W/16 - 3, 1), i.e. (29, 29, 1) for 512x512 input.
    output = tf.keras.layers.Conv2D(1, 4, strides=1, padding='valid',
                                    kernel_initializer=conv_initializer)(x)

    return tf.keras.models.Model(inputs=input_, outputs=output)

In [None]:
# Discriminator judging black-domain images (real black vs. white2black output).
white2black_disc = PATCH_discriminator()  

In [None]:
# Discriminator judging white-domain images (real white vs. black2white output).
black2white_disc = PATCH_discriminator() 

In [None]:
def generate_cycle(gen_1,gen_2,input_image):
    '''Run ``input_image`` through ``gen_1`` and then ``gen_2``, both in training mode.

    Returns:
        Tuple ``(translated, reconstructed)``, where ``translated`` is
        ``gen_1(input_image)`` and ``reconstructed`` is ``gen_2(translated)``.
    '''
    translated = gen_1(input_image, training=True)
    return translated, gen_2(translated, training=True)


def calc_and_apply_gradients(tape,
                             model,
                             loss,
                             optimizer):
    '''Differentiate ``loss`` w.r.t. ``model``'s trainable variables and take one optimizer step.

    Args:
        tape: GradientTape that recorded ``loss``. It must still be usable, e.g.
            persistent when it is reused for several models.
        model: Keras model whose ``trainable_variables`` are updated.
        loss: scalar loss tensor.
        optimizer: optimizer whose ``apply_gradients`` performs the update.

    Returns:
        None.
    '''
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

def discriminator_loss(real, generated):
    '''Discriminator BCE loss on logits: real -> 1, generated -> 0, averaged.

    Args:
        real: discriminator logits for real images.
        generated: discriminator logits for generated images.

    Returns:
        ``0.5 * (BCE(1, real) + BCE(0, generated))``.
    '''
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    loss_on_real = bce(tf.ones_like(real), real)
    loss_on_fake = bce(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_fake) * 0.5

# Generator adversarial loss
def generator_loss(generated):
    '''Adversarial generator loss: BCE on logits that rewards the discriminator scoring fakes as real (1).'''
    targets = tf.ones_like(generated)
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return bce(targets, generated)

    
# Cycle-consistency loss
def calc_cycle_loss(real_image, cycled_image, LAMBDA):
    '''Cycle-consistency loss: ``LAMBDA`` times the mean absolute pixel error
    between an image and its reconstruction after a full round trip.'''
    abs_diff = tf.abs(real_image - cycled_image)
    return LAMBDA * tf.reduce_mean(abs_diff)


# Identity loss
def identity_loss(real_image, same_image, LAMBDA):
    '''Identity loss: ``0.5 * LAMBDA`` times the mean absolute error between
    ``real_image`` and ``same_image``.

    ``same_image`` is expected to be the output of the generator whose target
    domain is ``real_image``'s own domain.
    '''
    mae = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mae

In [None]:
class CycleGAN(tf.keras.models.Model):
    '''Keras model that jointly trains two generators and two discriminators (CycleGAN).

    Uses the notebook-level networks:
    - ``white2black_gen`` and ``black2white_gen`` (the generators);
    - ``white2black_disc``, which judges black-domain images;
    - ``black2white_disc``, which judges white-domain images.

    ``fit`` expects batches of ``(white_image, black_image)``.
    '''

    def __init__(self,
                 lambda_cycle=10):
        '''
        Args:
            lambda_cycle: weight of the cycle-consistency loss; the identity loss
                uses half of it.
        '''
        super(CycleGAN, self).__init__()
        self.gen_w2b = white2black_gen
        self.gen_b2w = black2white_gen
        self.disc_w2b = white2black_disc
        self.disc_b2w = black2white_disc
        self.lambda_cycle = lambda_cycle

    def compile(self,
                gen_loss_fn,
                disc_loss_fn,
                cycle_loss_fn,
                identity_loss_fn,
                common_opt=None):
        '''Register the loss functions and the optimizers.

        Args:
            gen_loss_fn: adversarial generator loss, ``f(disc_logits_on_fakes)``.
            disc_loss_fn: discriminator loss, ``f(real_logits, fake_logits)``.
            cycle_loss_fn: cycle loss, ``f(real, cycled, lambda)``.
            identity_loss_fn: identity loss, ``f(real, same, lambda)``.
            common_opt: optional single optimizer shared by all four networks.
                When omitted, each network gets its own
                ``Adam(learning_rate=CFG.learning_rate, beta_1=0.5)``.
        '''
        super(CycleGAN, self).compile()
        # NOTE(review): no optimizer is passed to Keras here, so callbacks acting
        # on ``self.optimizer`` (such as a LearningRateScheduler) presumably do not
        # reach the optimizers below. Confirm before relying on an LR schedule.

        # -------optimizers ---------
        if common_opt is None:
            # Create the optimizers per compile() call. A default-argument
            # instance would be built once at definition time, and one shared
            # Adam advances its step counter four times per train step.
            (self.opt_gen_w2b,
             self.opt_gen_b2w,
             self.opt_disc_w2b,
             self.opt_disc_b2w) = (
                tf.keras.optimizers.legacy.Adam(learning_rate=CFG.learning_rate, beta_1=0.5)
                for _ in range(4))
        else:
            self.opt_gen_w2b = common_opt
            self.opt_gen_b2w = common_opt
            self.opt_disc_w2b = common_opt
            self.opt_disc_b2w = common_opt

        # -------losses ---------
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        '''One optimisation step for all four networks.

        Args:
            batch_data: tuple ``(white_image, black_image)`` of image batches.

        Returns:
            Dict with the four scalar losses (reported by Keras during ``fit``).
        '''
        white_image, black_image = batch_data

        # persistent=True: the same tape is differentiated four times below.
        with tf.GradientTape(persistent=True) as tape:
            # white -> black -> white, and black -> white -> black
            fake_black, cycled_white = generate_cycle(self.gen_w2b, self.gen_b2w, white_image)
            fake_white, cycled_black = generate_cycle(self.gen_b2w, self.gen_w2b, black_image)

            # Identity mapping: a generator fed an image that is already in its
            # *target* domain should return it unchanged.
            same_black = self.gen_w2b(black_image, training=True)
            same_white = self.gen_b2w(white_image, training=True)

            # Discriminator logits on real images
            disc_black = self.disc_w2b(black_image, training=True)
            disc_white = self.disc_b2w(white_image, training=True)

            # Discriminator logits on generated images
            disc_fake_black = self.disc_w2b(fake_black, training=True)
            disc_fake_white = self.disc_b2w(fake_white, training=True)

            # 1) Adversarial generator losses
            black_gen_loss = self.gen_loss_fn(disc_fake_black)
            white_gen_loss = self.gen_loss_fn(disc_fake_white)

            # 2) Cycle-consistency loss (shared by both generators)
            total_cycle_loss = (self.cycle_loss_fn(black_image, cycled_black, self.lambda_cycle)
                                + self.cycle_loss_fn(white_image, cycled_white, self.lambda_cycle))

            # 3) Total generator losses. Each identity term must use the output of
            #    the generator being optimised, or it contributes no gradient to it.
            total_gen_w2b_loss = (black_gen_loss + total_cycle_loss
                                  + self.identity_loss_fn(black_image, same_black, self.lambda_cycle))
            total_gen_b2w_loss = (white_gen_loss + total_cycle_loss
                                  + self.identity_loss_fn(white_image, same_white, self.lambda_cycle))

            # Discriminator losses (real vs. generated)
            black_disc_loss = self.disc_loss_fn(disc_black, disc_fake_black)
            white_disc_loss = self.disc_loss_fn(disc_white, disc_fake_white)

        # Update generators first, then discriminators.
        updates = (
            (self.gen_w2b, total_gen_w2b_loss, self.opt_gen_w2b),
            (self.gen_b2w, total_gen_b2w_loss, self.opt_gen_b2w),
            (self.disc_w2b, black_disc_loss, self.opt_disc_w2b),
            (self.disc_b2w, white_disc_loss, self.opt_disc_b2w),
        )
        for model, loss, optimizer in updates:
            calc_and_apply_gradients(tape=tape, model=model, loss=loss, optimizer=optimizer)
        del tape  # release the resources held by the persistent tape

        return {'gen_w2b_loss': total_gen_w2b_loss,
                'gen_b2w_loss': total_gen_b2w_loss,
                'disc_white_loss': white_disc_loss,
                'disc_black_loss': black_disc_loss
               }
        
    
        

In [None]:
# Create the CycleGAN wrapper around the four networks built above.
gan = CycleGAN()

# Register the loss functions; the optimizers come from compile()'s defaults.
gan.compile(gen_loss_fn=generator_loss,
            disc_loss_fn=discriminator_loss,
            cycle_loss_fn=calc_cycle_loss,
            identity_loss_fn=identity_loss)

In [None]:
# Learning-rate schedule
def scheduler(epoch,
              lr,
              decay_rate = 0.05,
              warm_up_period = 10,
              fast_decay_start = 40):
    '''Per-epoch learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Schedule:
    - epochs ``< warm_up_period``: ``lr`` is kept unchanged;
    - epochs up to ``fast_decay_start``: ``lr`` shrinks by ``exp(-decay_rate)``
      per epoch;
    - afterwards: ``lr`` shrinks by ``exp(-2 * decay_rate)`` per epoch.

    Args:
        epoch: 0-based epoch index supplied by Keras.
        lr: current learning rate.
        decay_rate: per-epoch exponential decay rate (>= 0).
        warm_up_period: number of initial epochs with a constant rate.
        fast_decay_start: first epoch of the faster decay phase.

    Returns:
        The learning rate (float) to use for ``epoch``.
    '''
    if epoch < warm_up_period:
        return float(lr)
    # Decay needs a negative exponent; exp(+decay_rate) > 1 would grow the rate.
    # Note that epoch == warm_up_period already belongs to the slower decay phase.
    if epoch < fast_decay_start:
        return float(lr * np.exp(-decay_rate))
    return float(lr * np.exp(-2 * decay_rate))
        
# Keras callback that calls scheduler(epoch, lr) at the start of every epoch.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule=scheduler, verbose=0)

In [None]:
def display_samples(ds, n_samples):
    '''Show the first image of each of the first ``n_samples`` batches of ``ds``.

    Images are assumed to be in [-1, 1] and are rescaled to [0, 1] for display,
    drawn in the left panel of a 1x2 subplot layout.
    '''
    batches = iter(ds)
    for _ in range(n_samples):
        batch = next(batches)
        plt.subplot(121)
        plt.imshow(batch[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()
        
def display_generated_samples(ds, model, n_samples):
    '''For each of the first ``n_samples`` batches of ``ds``, plot the first input
    image next to ``model``'s translation of it.

    Both images are rescaled from [-1, 1] to [0, 1] for display.
    '''
    batches = iter(ds)
    for _ in range(n_samples):
        batch = next(batches)
        translated = model.predict(batch)

        plt.figure(figsize=(16,8))
        panels = (('Input image', batch), ('Generated image', translated))
        for position, (title, images) in enumerate(panels, start=121):
            plt.subplot(position)
            plt.title(title)
            plt.imshow(images[0] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()
        
def evaluate_cycle(ds, generator_a, generator_b, n_samples=1):
    '''Plot input / translated / cycled-back images, one row per sample.

    For each of the first ``n_samples`` batches, only the first image is shown:
    the input, ``generator_a``'s output, and ``generator_b`` applied to that
    output. All three are rescaled from [-1, 1] to [0, 1].
    '''
    fig, axes = plt.subplots(n_samples, 3, figsize=(22, (n_samples*6)))
    axes = axes.flatten()

    batches = iter(ds)
    for row in range(n_samples):
        batch = next(batches)
        translated = generator_a.predict(batch)
        cycled = generator_b.predict(translated)

        panels = (('Input image', batch),
                  ('Generated image', translated),
                  ('Cycled image', cycled))
        for col, (title, images) in enumerate(panels):
            ax = axes[row * 3 + col]
            ax.set_title(title, fontsize=18)
            ax.imshow(images[0] * 0.5 + 0.5)
            ax.axis('off')

    plt.show()
        
def predict_and_save(input_ds, generator_model, output_path):
    '''Translate every batch of ``input_ds`` and save the first image of each as a JPEG.

    Files are written to ``f'{output_path}{i}.jpg'``, with ``i`` counting from 1,
    so ``output_path`` should end with a path separator or a file-name prefix.
    '''
    for index, batch in enumerate(input_ds, start=1):
        generated = generator_model(batch, training=False)[0].numpy()
        pixels = (generated * 127.5 + 127.5).astype(np.uint8)  # [-1, 1] -> [0, 255]
        PIL.Image.fromarray(pixels).save(f'{output_path}{str(index)}.jpg')
        
        
def save_models(g_model_AtoB, g_model_BtoA):
    '''Save both generators as HDF5 files in the working directory and print their names.'''
    saved = []
    for model, filename in ((g_model_AtoB, 'g_model_AtoB.h5'),
                            (g_model_BtoA, 'g_model_BtoA.h5')):
        model.save(filename)
        saved.append(filename)
    print(f'--->Saved: {saved[0]}, {saved[1]}')

# Callback
class GANMonitor(tf.keras.callbacks.Callback):
    """Callback that saves sample translations and both generators after each epoch.

    Uses the notebook-level ``white_eval`` / ``black_eval`` datasets and the
    ``white2black_gen`` / ``black2white_gen`` models.
    """

    def __init__(self,
                 num_img=1,
                 white_paths='generated_white',
                 black_paths='generated_black'):
        '''
        Args:
            num_img: number of evaluation batches translated per direction; the
                first image of each batch is saved.
            white_paths: directory for generated white-domain images.
            black_paths: directory for generated black-domain images.
        '''
        # Let the Keras base class initialise its own attributes.
        super().__init__()
        self.num_img = num_img
        self.white_paths = white_paths
        self.black_paths = black_paths

        # Output directories for the generated images
        for out_dir in (self.white_paths, self.black_paths):
            os.makedirs(out_dir, exist_ok=True)

    def _save_translations(self, dataset, generator, out_dir, epoch):
        '''Translate the first ``num_img`` batches of ``dataset`` and save the first image of each as PNG.'''
        for i, img in enumerate(dataset.take(self.num_img)):
            prediction = generator(img, training=False)[0].numpy()
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)  # [-1, 1] -> [0, 255]
            PIL.Image.fromarray(prediction).save(f'{out_dir}/generated_{i}_{epoch+1}.png')

    def on_epoch_end(self, epoch, logs=None):
        # white -> black translations go to the black-images directory
        self._save_translations(white_eval, white2black_gen, self.black_paths, epoch)
        # black -> white translations go to the white-images directory
        self._save_translations(black_eval, black2white_gen, self.white_paths, epoch)

        save_models(white2black_gen, black2white_gen)

In [None]:
EPOCHS = 70
steps_per_epoch = 20
# LR schedule plus per-epoch sample images and generator checkpoints
callbacks = [lr_scheduler, GANMonitor()]

history = gan.fit(Train_Dataset,
                  epochs=EPOCHS,
                  steps_per_epoch=steps_per_epoch,
                  callbacks=callbacks)

In [None]:
# Black -> white translations for the first image of 5 evaluation batches.
display_generated_samples(black_eval.take(10), black2white_gen, 5)

In [None]:
# White -> black translations for the first image of 5 evaluation batches.
display_generated_samples(white_eval.take(10), white2black_gen, 5)

In [None]:
from tensorflow.keras.models import load_model

# Reload the generators written by save_models() during training.
model_AtoB, model_BtoA = (
    load_model('/kaggle/working/g_model_AtoB.h5'),  # white -> black
    load_model('/kaggle/working/g_model_BtoA.h5'),  # black -> white
)

In [None]:
# Check the reloaded white -> black generator on 2 evaluation batches.
display_generated_samples(white_eval.take(10), model_AtoB, 2)

In [None]:
# Check the reloaded black -> white generator on 2 evaluation batches.
display_generated_samples(black_eval.take(10), model_BtoA, 2)

In [None]:
import requests
from io import BytesIO

import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.inception_v3 import preprocess_input, decode_predictions


# URL of the image
img_url = 'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTZ2iDYguRA_8Yf_DdQ30FAtP9e4xLdh7HKkw&usqp=CAU'

# Download the image; fail loudly on HTTP errors instead of trying to decode an error page.
response = requests.get(img_url, timeout=30)
response.raise_for_status()
# Web images may be palette/RGBA/grayscale; the generator expects 3 channels.
img = Image.open(BytesIO(response.content)).convert('RGB')

# Resize to the generator's input size (PIL takes (width, height)).
img = img.resize((CFG.IMG_WIDTH, CFG.IMG_HEIGHT))

img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
# inception_v3.preprocess_input scales [0, 255] to [-1, 1], the range used in training.
img_array = preprocess_input(img_array)

predictions = model_AtoB.predict(img_array)

# The generator outputs tanh values in [-1, 1]; imshow expects floats in [0, 1].
output_image = predictions[0] * 0.5 + 0.5
print("Output image shape", output_image.shape)

plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title('Input Image')
plt.axis('off')

# Display the generated output image
plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.title('Generated Output')
plt.axis('off')

# Show the plot
plt.show()

In [None]:
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.inception_v3 import preprocess_input, decode_predictions
import numpy as np

img_path = '/kaggle/input/black-white-shirts/BlackAndWhiteShirts/white/1034.jpg'
# load_img takes target_size as (height, width) and returns an RGB image by default.
img = image.load_img(img_path, target_size=(CFG.IMG_HEIGHT, CFG.IMG_WIDTH))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
# Scale [0, 255] to [-1, 1], the range the generator was trained on.
img_array = preprocess_input(img_array)


predictions = model_AtoB.predict(img_array)


# The generator outputs tanh values in [-1, 1]; imshow expects floats in [0, 1].
output_image = predictions[0] * 0.5 + 0.5

plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title('Input Image')
plt.axis('off')

# Display the generated output image
plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.title('Generated Output')
plt.axis('off')

# Show the plot
plt.show()
