## Import Libs

In [None]:
import numpy as np
import os, sys, json, scipy, shutil
import random, datetime, time
from PIL import Image, ImageDraw
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

# Import tensorflow and keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import (Dense, Reshape, LeakyReLU, Conv2D, Conv2DTranspose,
                                     Flatten, Dropout, Concatenate, BatchNormalization,
                                    UpSampling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop, Adam

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# Working directory at notebook start-up; kept so that code which temporarily
# calls os.chdir() can return here afterwards.
cwd = os.getcwd()

In [None]:
# print(tf.__version__)
# print(keras.__version__)

## Set paths

In [None]:
# Input dataset location and the folders this notebook writes to.
IMAGE_PATH = '/kaggle/input/celebaalligned/celeb_id_aligned'
OUT_DIR = '/kaggle/working'
CHECKPOINT_DIR = os.path.join(OUT_DIR, 'model_checkpoints')
SAMPLE_DIR = os.path.join(OUT_DIR, 'sample')
CROPPED_IMG_PATH = os.path.join(OUT_DIR, 'cropped')

# Create each output folder unless something already exists at that path.
for out_folder in (CHECKPOINT_DIR, SAMPLE_DIR, CROPPED_IMG_PATH):
    if not os.path.exists(out_folder):
        os.makedirs(out_folder)


### Training params

In [None]:
NUM_EPOCHS = 20  # number of epochs handed to Pix2Pix.train
BATCH_SIZE = 1  # batch size handed to Pix2Pix.train
MAX_IMGS_TO_TRAIN = 2000  # at most this many source images are painted and used for training

## Utility functions

In [None]:
def show_images_on_window(path, img_filenames):
    """Show up to nine images from ``path`` in a 3x3 matplotlib grid.

    Args:
        path: Folder that contains the image files.
        img_filenames: Sequence of file names relative to ``path``; only the
            first nine are displayed. Fewer than nine names simply leave the
            remaining grid cells empty instead of raising ``IndexError``.
    """
    plt.figure(figsize=(12, 12))
    for cell, name in enumerate(img_filenames[:9], start=1):
        # Copy the pixels into an array and close the file straight away so
        # no file handle stays open for the lifetime of the figure.
        with Image.open(os.path.join(path, name)) as img:
            img_array = np.array(img)
        plt.subplot(3, 3, cell)
        plt.grid(False)
        plt.axis('off')
        plt.imshow(img_array)
        
def retrieve_eyes_coords(img_name, data):
    """Look up the eye annotations stored for one image.

    The person key is derived from the file name: the extension is removed
    and everything before the last ``-`` is used to index ``data``.

    Args:
        img_name: Image file name, e.g. ``"<person>-<n>.jpg"``.
        data: Mapping from person key to a list of per-image dicts, each with
            the keys ``filename``, ``eye_left``, ``box_left``, ``eye_right``
            and ``box_right``.

    Returns:
        ``(eye_left, box_left, eye_right, box_right)`` for the matching
        entry, or ``(None, None, None, None)`` when the person or the image
        is unknown or the entry is malformed.
    """
    not_found = (None, None, None, None)

    # Strip only the final extension so that dots inside the name itself are
    # kept (a plain rsplit('.') would cut the name at its first dot).
    stem = img_name.rsplit('.', 1)[0]
    person = stem.rsplit('-', 1)[0]

    try:
        person_data_list = data[person]
    except (KeyError, TypeError):
        return not_found

    try:
        for img_dict in person_data_list:
            if img_dict['filename'] == img_name:
                return (img_dict['eye_left'], img_dict['box_left'],
                        img_dict['eye_right'], img_dict['box_right'])
    except (KeyError, TypeError):
        # Malformed annotation entry: treat it like a missing one.
        return not_found

    return not_found

def draw_eye_on_image(img, eye_coords, clr):
    """Return an RGB copy of ``img`` with a filled rectangle composited on it.

    Args:
        img: PIL image to draw on (it is converted, not modified in place).
        eye_coords: ``(x0, y0, x1, y1)`` corners of the rectangle.
        clr: Fill colour handed to ``ImageDraw.rectangle``.

    Returns:
        A new PIL image in ``RGB`` mode.
    """
    left, top, right, bottom = eye_coords

    base = img.convert('RGBA')

    # Draw on a separate RGBA layer so the fill's alpha is honoured when the
    # layer is composited onto the picture.
    layer = Image.new('RGBA', base.size)
    ImageDraw.Draw(layer).rectangle(((left, top), (right, bottom)), fill=clr)

    return Image.alpha_composite(base, layer).convert("RGB")

def paint_eyes_on_image(IMAGE_PATH, imgs_list, data, eye_left_clr, eye_right_clr, number_of_images=9):
    """Plot random images with both eye boxes painted over.

    The caller is expected to have created the matplotlib figure; this
    function only adds subplots to the current one.

    Args:
        IMAGE_PATH: Folder that contains the images.
        imgs_list: List of image file names. It is shuffled IN PLACE.
        data: Annotation mapping understood by ``retrieve_eyes_coords``.
        eye_left_clr: Fill colour for the left-eye box.
        eye_right_clr: Fill colour for the right-eye box.
        number_of_images: How many images to show. At most ``len(imgs_list)``
            are shown; the grid is 3x3 and grows only when more than nine
            images are requested.
    """

    def eye_box(eye, box):
        # Rectangle of size (w, h) centred on the eye landmark. The vertical
        # offset uses the box height (the previous code used the width, which
        # shifted non-square boxes away from the eye).
        x0 = eye['x'] - box['w'] // 2
        y0 = eye['y'] - box['h'] // 2
        return (x0, y0, x0 + box['w'], y0 + box['h'])

    random.shuffle(imgs_list)

    # Never index past the end of the list.
    shown = min(number_of_images, len(imgs_list))

    # Smallest square grid (at least 3x3) that holds every requested image.
    side = 3
    while side * side < shown:
        side += 1

    for idx in range(shown):
        img_name = imgs_list[idx]
        try:
            eye_l, box_l, eye_r, box_r = retrieve_eyes_coords(img_name, data)
            if eye_l is None:
                print('It was impossible to read the data for person: ', img_name)
                continue

            with Image.open(os.path.join(IMAGE_PATH, img_name)) as src:
                img = draw_eye_on_image(src, eye_box(eye_l, box_l), eye_left_clr)
            img = draw_eye_on_image(img, eye_box(eye_r, box_r), eye_right_clr)

            plt.subplot(side, side, idx + 1)
            plt.axis('off')
            plt.grid(False)
            plt.imshow(np.array(img))

        except (OSError, KeyError, TypeError, ValueError):
            # Best effort: report the image that failed and keep going. The
            # file name is printed because no `person` variable exists here.
            print('It was impossible to read the data for person: ', img_name)
            
def save_painted_eye_images(img_folder, imgs_list, data, dest_folder, eye_left_clr, eye_right_clr):
    """Paint both eye boxes on every image and save the results.

    ``dest_folder`` is deleted (if it is a directory) and recreated, so it
    only ever contains the output of the latest call.

    Args:
        img_folder: Folder that contains the source images.
        imgs_list: List of source file names. It is shuffled IN PLACE.
        data: Annotation mapping understood by ``retrieve_eyes_coords``.
        dest_folder: Folder the painted images are written to.
        eye_left_clr: Fill colour for the left-eye box.
        eye_right_clr: Fill colour for the right-eye box.

    Returns:
        ``(input_imgs, painted_imgs)``: two index-aligned lists holding the
        source file names that were processed and the matching
        ``<stem>_painted<ext>`` output file names. Images without
        annotations or with a zero-sized eye box are skipped.
    """

    def eye_box(eye, box):
        # Rectangle of size (w, h) centred on the eye landmark. The vertical
        # offset uses the box height (the previous code used the width, which
        # shifted non-square boxes away from the eye).
        x0 = eye['x'] - box['w'] // 2
        y0 = eye['y'] - box['h'] // 2
        return (x0, y0, x0 + box['w'], y0 + box['h'])

    def is_degenerate(coords):
        x0, y0, x1, y1 = coords
        return x0 == x1 or y0 == y1

    input_imgs = []
    painted_imgs = []

    # Start from an empty destination folder.
    if os.path.isdir(dest_folder):
        shutil.rmtree(dest_folder)
    os.makedirs(dest_folder)

    random.shuffle(imgs_list)

    for filename in tqdm(imgs_list):
        try:
            # Check the annotations first so un-annotated images are skipped
            # without ever opening the file.
            eye_l, box_l, eye_r, box_r = retrieve_eyes_coords(filename, data)
            if eye_l is None:
                continue

            eye_l_coords = eye_box(eye_l, box_l)
            eye_r_coords = eye_box(eye_r, box_r)
            if is_degenerate(eye_l_coords) or is_degenerate(eye_r_coords):
                continue

            # draw_eye_on_image returns a converted copy, so the source file
            # can be closed as soon as the first box has been drawn.
            with Image.open(os.path.join(img_folder, filename)) as src:
                img = draw_eye_on_image(src, eye_l_coords, eye_left_clr)
            img = draw_eye_on_image(img, eye_r_coords, eye_right_clr)

            # splitext keeps dots that are part of the name itself.
            stem, ext = os.path.splitext(filename)
            dest_name = stem + '_painted' + ext
            img.save(os.path.join(dest_folder, dest_name))

            input_imgs.append(filename)
            painted_imgs.append(dest_name)

        except (OSError, KeyError, TypeError, ValueError):
            # Best effort: report the image that failed and keep going. The
            # file name is printed because no `person` variable exists here.
            print('It was impossible to read the data for person: ', filename)

    return input_imgs, painted_imgs

## Examine dataset

In [None]:
# Collect the bare file names of every .jpg directly inside IMAGE_PATH.
# The pattern is joined with the folder instead of os.chdir()-ing into it, so
# the process working directory is never changed (and cannot be left pointing
# at the dataset folder if listing fails). The set removes duplicates and
# sorted() returns the names in a deterministic order.
imgs_list = sorted({os.path.basename(img_file)
                    for img_file in glob(os.path.join(IMAGE_PATH, '*.jpg'))})

print(len(imgs_list))

In [None]:
# show_images_on_window(IMAGE_PATH, imgs_list)

## Read JSON

In [None]:
# Load the annotation file that ships next to the images.
json_cat = "/".join((IMAGE_PATH, "data.json"))
with open(json_cat, 'r') as json_file:
    data = json.loads(json_file.read())
    
# # Explore data dict
# print(data)
# print(data.keys())

In [None]:
# filename = imgs_list[0].rsplit('.')[0]
# person = filename.rsplit('-',1)[0]
# person_data_list = data[person]

# person_imgs = []
# for person_dict in person_data_list:
#     person_imgs.append(person_dict['filename']) 

# person_imgs = sorted(person_imgs)

# img_dict = [img_dict for img_dict in person_data_list if img_dict['filename'] == imgs_list[0]]
# img_dict = img_dict[0]
# # print(type(img_dict))
# eye_left, box_left, eye_right, box_right = img_dict['eye_left'], img_dict['box_left'], img_dict['eye_right'], img_dict['box_right']
# print(eye_left)
# print(box_left)

## Show eyes boxes on images

In [None]:
# transparency = 0.0  # Degree of transparency, 0-100%
# opacity = int(255 * (1.0-transparency))
# eye_left_clr = (255,255,255,opacity)
# eye_right_clr = (0,255,0,opacity)

# plt.figure(figsize=(12,12))

# paint_eyes_on_image(IMAGE_PATH, imgs_list, data, eye_left_clr, eye_right_clr)


## Obtain painted eyes images

In [None]:
# Box colour: white, with the alpha channel derived from the transparency
# (0.0 = fully opaque, 1.0 = fully transparent). Both eyes share one colour.
transparency = 0.0
opacity = int((1.0 - transparency) * 255)
eye_left_clr = (255, 255, 255, opacity)
eye_right_clr = eye_left_clr

random.shuffle(imgs_list)

# Paint only the first MAX_IMGS_TO_TRAIN images of the shuffled dataset.
train_subset = imgs_list[:MAX_IMGS_TO_TRAIN]
input_imgs, painted_imgs = save_painted_eye_images(
    IMAGE_PATH, train_subset, data, CROPPED_IMG_PATH, eye_left_clr, eye_right_clr
)
print(len(input_imgs), len(painted_imgs))

### Examine saved images

In [None]:
# show_images_on_window(CROPPED_IMG_PATH, painted_imgs)

## Define dataset

In [None]:
# pairs = list(zip(input_imgs, painted_imgs))  # make pairs out of the two lists
# pairs = random.sample(pairs, 3)  # pick 3 random pairs
# batch_images_A, batch_images_B = zip(*pairs)  # separate the pairs

# print(batch_images_A)
# print(batch_images_B)

# for img_A, img_B in pairs[:2]:
#     print(img_A, ' and ', img_B)

In [None]:
class Eyes_dataset():
    """Loader for index-aligned image pairs (A = original, B = painted eyes).

    ``files_A[i]`` and ``files_B[i]`` must describe the same picture. The
    pairs are ordered by the A path, but each pair is kept together: sorting
    the two lists independently is not safe, because a suffix such as
    ``_painted`` changes the relative order of names (``x-1.jpg`` sorts
    before ``x-10.jpg`` while ``x-10_painted.jpg`` sorts before
    ``x-1_painted.jpg``).

    Images are read with Pillow, because ``scipy.misc.imread`` /
    ``scipy.misc.imresize`` and the ``np.float`` alias are no longer available
    in current SciPy / NumPy releases.
    NOTE(review): the legacy ``imresize`` is understood to have min/max
    rescaled float input before resizing; that contrast stretch is
    intentionally not reproduced here - confirm if old results must match.
    """

    def __init__(self, files_A, files_B, img_res=(128, 128)):
        """Store the file pairs and the target resolution.

        Args:
            files_A: Paths of the original images.
            files_B: Paths of the painted images, index-aligned with
                ``files_A``. If the lengths differ, the unmatched trailing
                entries of the longer list are ignored.
            img_res: Target size as ``(rows, cols)``.
        """
        pairs = sorted(zip(files_A, files_B))
        self.files_A = [path_a for path_a, _ in pairs]
        self.files_B = [path_b for _, path_b in pairs]
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        """Return a random sample of image pairs.

        Args:
            batch_size: Number of pairs to draw without replacement. When the
                dataset holds fewer pairs, all of them are returned.
            is_testing: Accepted for interface compatibility; not used.

        Returns:
            ``(imgs_A, imgs_B)`` arrays scaled to the range [-1, 1].
        """
        pairs = list(zip(self.files_A, self.files_B))
        chosen = random.sample(pairs, min(batch_size, len(pairs)))
        return self._load_pairs(chosen)

    def load_batch(self, batch_size=1, is_testing=False):
        """Yield the dataset as consecutive batches.

        ``self.n_batches`` is set when iteration starts. A trailing partial
        batch is dropped so that every yielded batch has exactly
        ``batch_size`` pairs.

        Args:
            batch_size: Number of pairs per batch.
            is_testing: Accepted for interface compatibility; not used.

        Yields:
            ``(imgs_A, imgs_B)`` arrays scaled to the range [-1, 1].
        """
        pairs = list(zip(self.files_A, self.files_B))
        self.n_batches = len(pairs) // batch_size

        for i in range(self.n_batches):
            yield self._load_pairs(pairs[i * batch_size:(i + 1) * batch_size])

    def imread(self, path):
        """Read ``path`` as an RGB image and return it as a float array."""
        # Local import keeps this class usable on its own.
        from PIL import Image

        with Image.open(path) as img:
            return np.asarray(img.convert('RGB'), dtype=np.float64)

    def _read_resized(self, path):
        """Read ``path`` as RGB, resize it to ``img_res``, return floats."""
        from PIL import Image

        rows, cols = self.img_res
        with Image.open(path) as img:
            # Pillow expects (width, height), i.e. (cols, rows).
            resized = img.convert('RGB').resize((cols, rows), Image.BILINEAR)
        return np.asarray(resized, dtype=np.float64)

    def _load_pairs(self, pairs):
        """Load the given (path_A, path_B) pairs and scale them to [-1, 1]."""
        imgs_A = [self._read_resized(path_a) for path_a, _ in pairs]
        imgs_B = [self._read_resized(path_b) for _, path_b in pairs]

        # Map pixel values from [0, 255] to [-1, 1].
        imgs_A = np.array(imgs_A) / 127.5 - 1.
        imgs_B = np.array(imgs_B) / 127.5 - 1.

        return imgs_A, imgs_B

In [None]:
# Shuffle the (input, painted) pairs together so both lists stay aligned.
paired = list(zip(input_imgs, painted_imgs))
random.shuffle(paired)
# paired = paired[:MAX_IMGS_TO_TRAIN]
input_imgs, painted_imgs = [pair[0] for pair in paired], [pair[1] for pair in paired]

In [None]:
# Build the full paths and order the PAIRS by input path. Sorting the two
# lists independently is not safe: the "_painted" suffix changes the relative
# order of names (e.g. "x-1.jpg" < "x-10.jpg" but "x-10_painted.jpg" <
# "x-1_painted.jpg"), which would pair an input with the wrong painted image.
# NOTE: any consumer that re-sorts the two lists independently undoes this.
path_pairs = sorted(
    (os.path.join(IMAGE_PATH, src_name), os.path.join(CROPPED_IMG_PATH, painted_name))
    for src_name, painted_name in zip(input_imgs, painted_imgs)
)
input_imgs_path = [src_path for src_path, _ in path_pairs]
painted_imgs_path = [painted_path for _, painted_path in path_pairs]

In [None]:
# data_loader = Eyes_dataset(input_imgs_path, painted_imgs_path, img_res=(256, 256))
# imgs_A, imgs_B = data_loader.load_data(batch_size=3)

# fig = plt.figure()
# for idx, pair in enumerate(zip(imgs_A, imgs_B)):
#     img_A, img_B = pair
#     paired_img = np.concatenate([img_A, img_B])

#     # Rescale images 0 - 1
#     paired_img = 0.5 * paired_img + 0.5

#     plt.imshow(paired_img)

## Define model

In [None]:
class Pix2Pix():
    """Pix2Pix conditional GAN: U-Net generator plus PatchGAN discriminator.

    The generator receives the conditioning image B (painted eyes) and is
    trained to reproduce the real image A. The discriminator scores
    (image, condition) pairs patch-wise.
    """

    def __init__(self, files_A, files_B):
        """Build the data loader, the two networks and the combined model.

        Args:
            files_A: Paths of the real images.
            files_B: Paths of the conditioning images, aligned with files_A.
        """
        # Input shape
        self.img_rows = 256
        self.img_cols = 256
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.data_loader = Eyes_dataset(files_A, files_B,
                                        img_res=(self.img_rows, self.img_cols))

        # Output shape of D (PatchGAN): four stride-2 layers halve the
        # resolution four times.
        patch = self.img_rows // 2**4
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        # One optimizer per compiled model, with identical hyper-parameters.
        # NOTE(review): recent Keras optimizers are tied to the variables of
        # the first model they update, so one shared instance is avoided.
        d_optimizer = Adam(0.0002, 0.5)
        g_optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=d_optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # Input images and their conditioning images
        img_A = Input(shape=self.img_shape)  # Real
        img_B = Input(shape=self.img_shape)  # Conditioning image to generator

        # By conditioning on B generate a fake version of A
        fake_A = self.generator(img_B)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator scores the (generated image, condition) pair
        valid = self.discriminator([fake_A, img_B])

        # Combined model: adversarial loss (mse) + 100 x reconstruction (mae)
        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=g_optimizer)

    def build_generator(self):
        """Build the U-Net generator (7 down-sampling, 7 up-sampling steps)."""

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            # Skip connection from the encoder layer of the same resolution
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)

        # Upsampling
        u1 = deconv2d(d7, d6, self.gf*8)
        u2 = deconv2d(u1, d5, self.gf*8)
        u3 = deconv2d(u2, d4, self.gf*8)
        u4 = deconv2d(u3, d3, self.gf*4)
        u5 = deconv2d(u4, d2, self.gf*2)
        u6 = deconv2d(u5, d1, self.gf)

        u7 = UpSampling2D(size=2)(u6)
        # tanh keeps the output in [-1, 1], the range the loader produces
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(d0, output_img)

    def build_discriminator(self):
        """Build the PatchGAN discriminator over (image, condition) pairs."""

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_B])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_B], validity)

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Train the GAN, then save the models and plot the loss curves.

        After the call ``self.d_losses`` / ``self.g_losses`` hold one value
        per epoch: the mean discriminator loss and the mean total generator
        loss over that epoch's batches.

        Args:
            epochs: Number of passes over the dataset.
            batch_size: Number of image pairs per training step.
            sample_interval: Every this many batches the progress is printed
                and a sample figure is saved to ``OUT_DIR``. A value of 0
                disables sampling.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        self.d_losses = []
        self.g_losses = []

        for epoch in range(1, epochs + 1):
            epoch_d_losses = []
            epoch_g_losses = []

            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size), start=1):

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Condition on B and generate a translated version
                fake_A = self.generator.predict(imgs_B)

                # Train the discriminator (original images = real / generated = fake)
                d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
                d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------

                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

                # Record the per-batch losses so the loss plot has data.
                epoch_d_losses.append(d_loss[0])
                epoch_g_losses.append(g_loss[0])

                # If at save interval => report progress and save samples
                if sample_interval and batch_i % sample_interval == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" %
                          (epoch, epochs, batch_i, self.data_loader.n_batches,
                           d_loss[0], 100*d_loss[1], g_loss[0], elapsed_time))
                    self.sample_images(epoch, batch_i, OUT_DIR)

            # Store the epoch averages (skipped when the epoch had no batch).
            if epoch_d_losses:
                self.d_losses.append(float(np.mean(epoch_d_losses)))
                self.g_losses.append(float(np.mean(epoch_g_losses)))

        # ---------------------
        #  End Training
        # ---------------------

        # Save generator and combined models. `epochs` equals the last epoch
        # number and is also defined when the loop never ran.
        gen_filename = 'gen_model_' + str(epochs) + '.h5'
        self.save_model(self.generator, gen_filename, OUT_DIR)

        combined_filename = 'comb_model_' + str(epochs) + '.h5'
        self.save_model(self.combined, combined_filename, OUT_DIR)

        # Plot losses
        self.plot_losses(self.g_losses, self.d_losses, OUT_DIR)

    def save_model(self, model, filename, folder):
        """Save ``model`` and its weights into ``folder`` (best effort).

        Writes ``<folder>/<filename>`` and ``<folder>/<stem>_weights.h5``,
        replacing existing files. A failure is reported, not raised, so a
        save problem does not abort the run.

        Args:
            model: Object exposing ``save(path)`` and ``save_weights(path)``.
            filename: File name for the full model, e.g. ``gen_model_20.h5``.
            folder: Destination folder.
        """
        weights_filename = os.path.splitext(filename)[0] + '_weights.h5'

        filepath = os.path.join(folder, filename)
        weights_path = os.path.join(folder, weights_filename)
        try:
            # If a file exists, delete it first
            for stale in (filepath, weights_path):
                if os.path.exists(stale):
                    os.remove(stale)

            model.save(filepath)
            model.save_weights(weights_path)

        except Exception as err:
            # Deliberately broad: saving is best effort, but say why it failed.
            print('Model ', filename, ' could not be saved')
            print(err)

    def plot_losses(self, gen_loss, disc_loss, folder):
        """Plot both loss curves and save the figure as a PNG in ``folder``.

        Args:
            gen_loss: Generator loss values (one per epoch).
            disc_loss: Discriminator loss values (one per epoch).
            folder: Destination folder for ``pix2pix_loss_<timestamp>.png``.
        """
        fig = plt.figure(figsize=(10, 10))
        plt.plot(disc_loss, label='Discriminitive loss')
        plt.plot(gen_loss, label='Generative loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        timestamp = time.strftime("%d_%m_%y-%H_%M", time.localtime())
        filename = 'pix2pix_loss_' + timestamp + '.png'
        # Save BEFORE showing: once the figure has been shown it may already
        # be released, and saving afterwards writes an empty image.
        fig.savefig(os.path.join(folder, filename))
        plt.show()
        plt.close(fig)

    def sample_images(self, epoch, batch_i, out_dir):
        """Save a 3x3 figure: condition, generated and original images.

        Args:
            epoch: Current epoch number (used in the file name).
            batch_i: Current batch number (used in the file name).
            out_dir: Folder the PNG is written to.
        """
        r, c = 3, 3

        imgs_A, imgs_B = self.data_loader.load_data(batch_size=3, is_testing=True)
        fake_A = self.generator.predict(imgs_B)

        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Condition', 'Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[i])
                axs[i, j].axis('off')
                cnt += 1

        filename = 'epoch_' + str(epoch) + '_batch_' + str(batch_i) + '.png'
        fig.savefig(os.path.join(out_dir, filename))
        plt.show()
        # Close the figure: one is created per sample interval, and open
        # figures would otherwise pile up in memory during training.
        plt.close(fig)

## Train the model

In [None]:
# Build the GAN on the paired path lists and train it. With sample_interval=1
# progress is printed and a sample figure is saved after every batch.
gan = Pix2Pix(files_A=input_imgs_path, files_B=painted_imgs_path)
gan.train(NUM_EPOCHS, batch_size=BATCH_SIZE, sample_interval=1)

# Save outputs as a zip
# shutil.make_archive("output", "zip", "./output")