In [None]:
import tensorflow as tf
import numpy as np
import os
import imageio
from IPython.display import Image, display
import matplotlib.pyplot as plt
import time

In [None]:
# Directory holding the landscape pictures; only used below to count its entries.
dataset_path = '../input/landscape-pictures'

In [None]:
# Number of directory entries directly under dataset_path (not recursive).
dataset_len = len(os.listdir(dataset_path))

In [None]:
# Display the entry count as the cell output.
dataset_len

# **Dataset Preparation**

In [None]:
# Number of images per training batch.
batch_size = 32

# Build an unlabeled (label_mode=None) image dataset: every image is resized
# to 256x256 and the file order is shuffled with a fixed seed.
# NOTE(review): the root is '../input/' rather than `dataset_path` — presumably
# so that 'landscape-pictures' is picked up as a sub-directory; confirm this
# is intended, since any other folder under '../input/' would be loaded too.
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    '../input/', label_mode=None, image_size=(256, 256), batch_size=batch_size, shuffle=True, seed=16
)
AUTOTUNE = tf.data.AUTOTUNE

# Cache the loaded batches and prefetch upcoming ones while the current one is in use.
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)

In [None]:
# batch_size = 128

# datagen = tf.keras.preprocessing.image.ImageDataGenerator()
# dataset = datagen.flow_from_directory(
#     '../input/flickr-image-dataset/flickr30k_images/flickr30k_images', class_mode=None, target_size=(256, 256), batch_size=batch_size
# )

# step_len = dataset_len // 128

# **Color Space Conversion**

In [None]:
def rgb2ycbcr(img):
    """Convert an RGB image to YCbCr.

    The red, green and blue planes are read from indices 0, 1 and 2 of the
    third axis of ``img``.

    Returns:
        A float32 NumPy array whose third axis holds (Y, Cb, Cr), with every
        value clipped to the range [0, 255].
    """
    red, green, blue = img[:,:,0], img[:,:,1], img[:,:,2]

    # Luma is a weighted sum of the three colour planes.
    luma = 0.299*red + 0.587*green + 0.114*blue
    # Chroma planes are scaled colour differences, offset by 128.
    chroma_red = (red - luma)*0.713 + 128
    chroma_blue = (blue - luma)*0.564 + 128

    # Output channel order is Y, Cb, Cr (blue-difference before red-difference).
    stacked = np.stack([luma, chroma_blue, chroma_red], axis=2)
    return np.clip(stacked, 0, 255).astype('float32')

def ycbcr2rgb(img):
    """Convert a YCbCr image back to RGB.

    The Y, Cb and Cr planes are read from indices 0, 1 and 2 of the third
    axis of ``img``.

    Returns:
        A float32 NumPy array whose third axis holds (R, G, B), with every
        value clipped to the range [0, 255].
    """
    luma = img[:,:,0]
    # Remove the 128 offset once so the chroma planes are centred on zero.
    blue_diff = img[:,:,1] - 128
    red_diff = img[:,:,2] - 128

    red = luma + 1.403*red_diff
    green = luma - 0.714*red_diff - 0.344*blue_diff
    blue = luma + 1.773*blue_diff

    stacked = np.stack([red, green, blue], axis=2)
    return np.clip(stacked, 0, 255).astype('float32')

# **Deep Learning Model**

In [None]:
from tensorflow.keras import Sequential
from tensorflow.keras import Model

from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.applications import Xception

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import RepeatVector
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import ReLU
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import MaxPool2D

## Low-Level Features Network

In [None]:
class low_level(tf.keras.Model):
    """Low-level feature extractor: six 'same'-padded convolutions, each followed by ReLU.

    The convolutions alternate stride 2 and stride 1, so the spatial size is
    halved three times. The filter counts are 64, 128, 128, 256, 256, 512.

    Args:
        kernel_size: kernel size shared by all six convolutions.
    """

    def __init__(self, kernel_size=(3,3)):
        super(low_level, self).__init__()

        self.filters = [64, 128, 128, 256, 256, 512]

        self.conv1 = Conv2D(filters=self.filters[0], kernel_size=kernel_size, strides=(2,2), padding='same')
        self.conv2 = Conv2D(filters=self.filters[1], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.conv3 = Conv2D(filters=self.filters[2], kernel_size=kernel_size, strides=(2,2), padding='same')
        self.conv4 = Conv2D(filters=self.filters[3], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.conv5 = Conv2D(filters=self.filters[4], kernel_size=kernel_size, strides=(2,2), padding='same')
        self.conv6 = Conv2D(filters=self.filters[5], kernel_size=kernel_size, strides=(1,1), padding='same')

    def call(self, input_img):
        """Run the stack and expose the activations of the strided convolutions.

        Returns:
            A 4-tuple ``(out, out_s1, out_s2, out_s3)``: the final activation
            followed by the ReLU outputs of conv1, conv3 and conv5 (the three
            stride-2 layers).
        """
        # tf.nn.relu is stateless, so no new Keras layer object is created on
        # every forward pass (the activation itself is unchanged).
        out_s1 = tf.nn.relu(self.conv1(input_img))
        out = tf.nn.relu(self.conv2(out_s1))
        out_s2 = tf.nn.relu(self.conv3(out))
        out = tf.nn.relu(self.conv4(out_s2))
        out_s3 = tf.nn.relu(self.conv5(out))
        out = tf.nn.relu(self.conv6(out_s3))
        return out, out_s1, out_s2, out_s3

## Mid-Level Features Network

In [None]:
class mid_level(tf.keras.Model):
    """Mid-level feature network: two stride-1, 'same'-padded convolutions with ReLU.

    The filter counts are 512 then 256; the spatial size is left unchanged.

    Args:
        kernel_size: kernel size shared by both convolutions.
    """

    def __init__(self, kernel_size=(3,3)):
        super(mid_level, self).__init__()

        self.filters = [512, 256]

        self.conv1 = Conv2D(filters=self.filters[0], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.conv2 = Conv2D(filters=self.filters[1], kernel_size=kernel_size, strides=(1,1), padding='same')

    def call(self, input_img):
        """Apply conv1 -> ReLU -> conv2 -> ReLU and return the result."""
        # Stateless activation: avoids building a new ReLU layer on every call.
        out = tf.nn.relu(self.conv1(input_img))
        out = tf.nn.relu(self.conv2(out))
        return out

## Global-Level Features Network

In [None]:
class global_level(tf.keras.Model):
    """Global feature network: an ImageNet-pretrained Xception backbone plus a Dense head.

    The single-channel input is repeated three times along the channel axis,
    passed through Xception (without its classification top), flattened and
    reduced by three ReLU-activated Dense layers of 1024, 512 and 256 units.

    Earlier variants of this network (a plain four-convolution stack and a
    hand-written VGG-style stack) were previously kept here as commented-out
    code; only the Xception variant is active.

    Args:
        kernel_size: accepted for signature parity with the other sub-networks;
            the Xception variant does not use it.
    """

    def __init__(self, kernel_size=(3,3)):
        super(global_level, self).__init__()

        # Widths of the three Dense layers that follow the backbone.
        self.filters = [1024, 512, 256]
        # The attribute keeps its historical name `nasnet`, but the backbone
        # that is actually instantiated is Xception.
        self.nasnet = Xception(
            input_shape=(256,256,3),
            include_top=False,
            weights="imagenet",
        )
        self.nasnet.trainable = True
        # Created once here instead of on every forward pass.
        self.flatten = Flatten()
        self.fc1 = Dense(units=self.filters[0])
        self.fc2 = Dense(units=self.filters[1])
        self.fc3 = Dense(units=self.filters[2])

    def call(self, input_img):
        """Return the 256-unit global feature vector for each input image."""
        # Xception is built for 3-channel input, so the single channel is
        # replicated along axis 3.
        # NOTE(review): the replicated tensor is fed to Xception without
        # `tf.keras.applications.xception.preprocess_input`; confirm that the
        # caller's input scaling matches what the ImageNet weights expect.
        out = tf.repeat(input_img, repeats=[3], axis=3)
        out = self.nasnet(out)
        out = self.flatten(out)
        out = tf.nn.relu(self.fc1(out))
        out = tf.nn.relu(self.fc2(out))
        out = tf.nn.relu(self.fc3(out))

        return out

## Fusion Layer

In [None]:
class Fusion(tf.keras.layers.Layer):
    """Broadcast a per-image feature vector over a feature map and concatenate.

    ``call`` takes ``[feature_map, global_feature]``. The global vector is
    copied to every spatial position of the feature map and the two tensors
    are concatenated along axis 3.
    """

    def __init__(self):
        super(Fusion, self).__init__()

    def call(self, inputs):
        """Return ``feature_map`` with the tiled global feature appended on axis 3.

        Args:
            inputs: sequence whose first element is the 4-D feature map and
                whose second element is the 2-D global feature (one vector
                per image).
        """
        input_img = inputs[0]
        global_feature = inputs[1]

        # Prefer the static spatial size; fall back to the dynamic shape when a
        # dimension is unknown at trace time.
        dynamic_shape = tf.shape(input_img)
        height = input_img.shape[1] if input_img.shape[1] is not None else dynamic_shape[1]
        # Axis 2 must be tiled with the map's *width*. Using the height for
        # both axes only works for square feature maps.
        width = input_img.shape[2] if input_img.shape[2] is not None else dynamic_shape[2]

        # (batch, C) -> (batch, 1, 1, C) -> (batch, height, width, C)
        tiled = tf.tile(
            global_feature[:, tf.newaxis, tf.newaxis, :],
            [1, height, width, 1],
        )
        return tf.concat([input_img, tiled], axis=3)

## Colorization Network

In [None]:
class colorization(tf.keras.Model):
    """Decoder that fuses mid-level and global features and predicts two channels.

    The network applies a Fusion layer, six stride-1 'same'-padded
    convolutions (256, 128, 64, 64, 32, 2 filters) and two 2x up-samplings.
    The last convolution uses a sigmoid activation, so the two output channels
    lie in (0, 1).

    Args:
        kernel_size: kernel size shared by all convolutions.
        skip: when True, ``skip_3`` is added after conv1 and ``skip_2`` after
            conv2.
    """

    def __init__(self, kernel_size=(3,3), skip=True):
        super(colorization, self).__init__()

        self.filters = [256, 128, 64, 64, 32, 2]
        self.skip = skip

        self.fusion = Fusion()
        self.conv1 = Conv2D(filters=self.filters[0], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.upsm1 = UpSampling2D(size=(2,2))
        self.conv2 = Conv2D(filters=self.filters[1], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.conv3 = Conv2D(filters=self.filters[2], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.conv4 = Conv2D(filters=self.filters[3], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.upsm2 = UpSampling2D(size=(2,2))
        self.conv5 = Conv2D(filters=self.filters[4], kernel_size=kernel_size, strides=(1,1), padding='same')
        self.conv6 = Conv2D(filters=self.filters[5], kernel_size=kernel_size, strides=(1,1), padding='same', activation='sigmoid')

    def call(self, inputs):
        """Decode ``[mid, global_feature, skip_1, skip_2, skip_3]`` into two channels.

        ``skip_1`` (``inputs[2]``) is part of the expected input list but is
        not merged into the decoder.
        """
        input_img = inputs[0]
        global_feature = inputs[1]
        skip_2 = inputs[3]
        skip_3 = inputs[4]

        # Activations and additions use stateless TF ops so that no new Keras
        # layer objects are created on every forward pass.
        out = self.fusion([input_img, global_feature])
        out = tf.nn.relu(self.conv1(out))
        if self.skip:
            out = tf.add(out, skip_3)

        out = tf.nn.relu(self.upsm1(out))
        out = tf.nn.relu(self.conv2(out))

        if self.skip:
            out = tf.add(out, skip_2)
        out = tf.nn.relu(self.conv3(out))
        out = tf.nn.relu(self.conv4(out))

        out = tf.nn.relu(self.upsm2(out))
        out = tf.nn.relu(self.conv5(out))
        # conv6 carries its own sigmoid activation.
        out = self.conv6(out)
        return out

## Final Model

In [None]:
class Colnet(tf.keras.Model):
    """Full colorization model: low-, mid- and global-level encoders plus a decoder.

    The input image is fed to both the low-level network and the global-level
    network. The low-level output goes through the mid-level network, and the
    decoder combines the mid-level features, the global feature and the three
    skip activations. The decoder output is finally up-sampled by 2x.

    Args:
        kernel_size: kernel size forwarded to every sub-network.
        skip: forwarded to the decoder to enable or disable skip connections.
    """

    def __init__(self, kernel_size=(3,3), skip=True):
        super(Colnet, self).__init__()

        self.low_level = low_level(kernel_size=kernel_size)
        self.mid_level = mid_level(kernel_size=kernel_size)
        self.global_level = global_level(kernel_size=kernel_size)
        # NOTE(review): `trainable = False` is set on the whole global-level
        # sub-model, which presumably also freezes its randomly initialised
        # Dense head and not just the pretrained backbone — confirm this is
        # intended.
        self.global_level.trainable = False
        self.colorization = colorization(kernel_size=kernel_size, skip=skip)
        # Final 2x up-sampling, created once instead of on every forward pass.
        self.upsample = UpSampling2D(size=(2,2))

    def call(self, input_img):
        """Return the up-sampled two-channel prediction for ``input_img``."""
        low, skip_1, skip_2, skip_3 = self.low_level(input_img)
        mid = self.mid_level(low)
        glo = self.global_level(input_img)
        col = self.colorization([mid, glo, skip_1, skip_2, skip_3])
        out = self.upsample(col)
        return out

# Functions

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error
import numpy as np
from math import log

def convert_ycbcr2rgb(y, cbcr):
    """Join a Y plane and a CbCr pair, rescale by 255 and convert to RGB.

    Both inputs are multiplied by 255.0, concatenated along axis 2 (Y first)
    and passed to ``ycbcr2rgb``.
    """
    scaled_planes = [y * 255.0, cbcr * 255.0]
    ycbcr = np.concatenate(scaled_planes, axis=2)
    return ycbcr2rgb(ycbcr)

def show_img(step, y_img, cbcr_pred ,cbcr_img, skip):
    """Save and display a 4x4 mosaic of predicted RGB images.

    Up to the first 16 (Y, predicted CbCr) pairs are converted to RGB with
    ``convert_ycbcr2rgb`` and laid out row by row. If fewer than 16 images are
    available (for example a short final batch), the remaining cells are
    filled with black tiles instead of failing on the reshape. The tile size
    is taken from the images themselves.

    Args:
        step: value appended to the output file name.
        y_img: batch of Y planes.
        cbcr_pred: batch of predicted CbCr planes.
        cbcr_img: batch of ground-truth CbCr planes (currently unused; kept
            for interface compatibility).
        skip: selects the file-name prefix ("generated_img_skip" when truthy,
            otherwise "generated_img").

    Raises:
        ValueError: if the batch contains no images.
    """
    grid_side = 4
    max_tiles = grid_side * grid_side

    tiles = [
        convert_ycbcr2rgb(y, cbcr_p)
        for y, cbcr_p in zip(y_img[:max_tiles], cbcr_pred[:max_tiles])
    ]
    if not tiles:
        raise ValueError('show_img needs at least one image to render')

    tiles = np.stack(tiles)
    n_tiles, height, width, channels = tiles.shape
    if n_tiles < max_tiles:
        # Pad with black tiles so the grid is always complete.
        blank = np.zeros((max_tiles - n_tiles, height, width, channels), dtype=tiles.dtype)
        tiles = np.concatenate([tiles, blank], axis=0)

    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    mosaic = (
        tiles.reshape(grid_side, grid_side, height, width, channels)
        .swapaxes(1, 2)
        .reshape(grid_side * height, grid_side * width, channels)
    )

    mosaic_img = tf.keras.preprocessing.image.array_to_img(mosaic)
    if skip:
        name = "generated_img_skip" + str(step) + ".png"
    else:
        name = "generated_img" + str(step) + ".png"
    # Written to the current working directory.
    mosaic_img.save(name)
    display(Image(name))

def avgpsnr(psnrs):
    """Average PSNR values in the linear domain.

    Each value ``p`` is mapped to ``10 ** (p / 10)``, the mapped values are
    averaged, and the mean is converted back with ``10 * log10(mean)``.
    """
    count = len(psnrs)
    # Decibel values -> linear scale.
    linear = np.full(count, 10) ** (np.array(psnrs) / 10)
    mean_linear = linear.sum() / count
    # Linear mean -> decibels.
    return log(mean_linear, 10) * 10

def img_metrics(epoch, y_img, cbcr_pred ,cbcr_img):
    """Compute per-image and batch-average MSE, SSIM and PSNR in RGB space.

    For every (Y, predicted CbCr, ground-truth CbCr) triple, both the
    reference and the predicted image are rebuilt with ``convert_ycbcr2rgb``
    and compared.

    Args:
        epoch: unused; kept for interface compatibility.
        y_img: batch of Y planes.
        cbcr_pred: batch of predicted CbCr planes.
        cbcr_img: batch of ground-truth CbCr planes.

    Returns:
        ``(imgs_metrics, averages)``. ``imgs_metrics`` is a list with one
        ``{'mse', 'ssim', 'psnr'}`` dict per image, and ``averages`` is a dict
        with the same keys. MSE and SSIM are arithmetic means; PSNR is
        averaged with ``avgpsnr``.

    Raises:
        ValueError: if the batch contains no images.
    """
    imgs_metrics = []
    total_mse = 0
    total_ssim = 0
    total_psnr = []

    for y, cbcr_p, cbcr_o in zip(y_img, cbcr_pred, cbcr_img):
        img_orig = convert_ycbcr2rgb(y, cbcr_o)
        img_pred = convert_ycbcr2rgb(y, cbcr_p)

        # NOTE(review): data_range is taken from the spread of the prediction
        # rather than a fixed value, so it varies from image to image —
        # confirm this is the intended normalisation.
        data_range = img_pred.max() - img_pred.min()

        curr_mse = mean_squared_error(img_orig, img_pred)
        # NOTE(review): `multichannel` is believed to be replaced by
        # `channel_axis` in newer scikit-image releases — verify against the
        # installed version.
        curr_ssim = ssim(img_orig, img_pred, data_range=data_range, multichannel=True)
        curr_psnr = psnr(img_orig, img_pred, data_range=data_range)

        total_mse += curr_mse
        total_ssim += curr_ssim
        total_psnr.append(curr_psnr)

        # Record the per-image metrics. Storing the predicted image here would
        # discard them and turn the "detailed metrics" into raw pixel arrays.
        imgs_metrics.append({
            'mse': curr_mse,
            'ssim': curr_ssim,
            'psnr': curr_psnr
        })

    n_images = len(imgs_metrics)
    if n_images == 0:
        raise ValueError('img_metrics needs at least one image')

    avg_mse = total_mse / n_images
    avg_ssim = total_ssim / n_images
    avg_psnr = avgpsnr(total_psnr)

    return imgs_metrics, {'mse': avg_mse, 'ssim': avg_ssim, 'psnr': avg_psnr}

# Model Training

In [None]:
def train_step(real_image, model):
    """Run one optimisation step on a batch of RGB images.

    Each image is converted to YCbCr and divided by 255. The Y plane is the
    model input and the CbCr planes are the regression target. Gradients of
    the MSE loss are applied with the module-level ``optimizer``.

    Returns:
        ``(loss, y_img, logits, cbcr_img)``: the loss tensor, the Y input, the
        model prediction and the CbCr target.
    """
    def to_unit_ycbcr(img):
        # rgb2ycbcr clips to [0, 255]; dividing by 255 brings values to [0, 1].
        return rgb2ycbcr(img) / 255.0

    ycbcr = tf.map_fn(fn=to_unit_ycbcr, elems=real_image)
    # Slice 0:1 keeps the channel axis, giving a single-channel Y tensor.
    y_img = tf.cast(ycbcr[:, :, :, 0:1], tf.float32)
    cbcr_img = tf.cast(ycbcr[:, :, :, 1:], tf.float32)

    with tf.GradientTape() as tape:
        logits = model(y_img)
        loss = tf.keras.losses.MSE(cbcr_img, logits)

    weights = model.trainable_weights
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    return loss, y_img, logits, cbcr_img

def train(epochs, dataset, model, skip):
    """Train ``model`` for ``epochs`` passes over ``dataset``.

    Every batch is passed to ``train_step``. On every even-numbered epoch the
    last batch of that epoch is scored with ``img_metrics``, a summary line is
    printed and a preview mosaic is written by ``show_img``. The model is
    saved after every epoch.

    Args:
        epochs: number of passes over the dataset.
        dataset: iterable of image batches.
        model: model to train and save.
        skip: selects the save directory ('./model-skip' when truthy,
            otherwise './model') and is forwarded to ``show_img``.

    Returns:
        ``(metrics, metrics_detailed)``: the per-evaluation average metrics
        and the per-evaluation detailed results from ``img_metrics``.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    metrics = []
    metrics_detailed = []
    save_dir = './model-skip' if skip else './model'

    for epoch in range(epochs):
        start_epoch = time.time()

        last_result = None
        for img_batch in dataset:
            last_result = train_step(img_batch, model)

        # Fail with a clear message instead of an unbound-variable error below.
        if last_result is None:
            raise ValueError('dataset yielded no batches; nothing to train on')
        _, y_img, cbcr_pred, cbcr_img = last_result

        if epoch % 2 == 0:
            metrics_detailed_single, metrics_single = img_metrics(epoch, y_img, cbcr_pred, cbcr_img)
            print('epoch: %d - took: %2.f seconds | mse: %.4f | ssim: %.4f | psnr: %.4f' % (epoch, time.time()-start_epoch, metrics_single['mse'], metrics_single['ssim'], metrics_single['psnr']))
            show_img(epoch, y_img, cbcr_pred, cbcr_img, skip)
            metrics_detailed.append(metrics_detailed_single)
            metrics.append(metrics_single)

        model.save(save_dir, save_format='tf')

    return metrics, metrics_detailed

In [None]:
# Two model variants: without and with skip connections in the decoder.
colnet = Colnet(skip=False)
colnet_skip = Colnet(skip=True)
# Module-level optimizer; train_step reads it as a global.
optimizer = tf.keras.optimizers.Adadelta()
# Number of passes over the dataset.
epochs = 150

In [None]:
# Train the skip-connection variant; the model is saved to './model-skip' after each epoch.
metrics_skip, metrics_detailed_skip = train(epochs, dataset, colnet_skip, True)

In [None]:
# metrics, metrics_detailed = train(epochs, dataset, colnet, False)

In [None]:
import pandas as pd

# Directory the CSV files are written to (file names are appended by string concatenation).
out_path = './'

# Only the skip-connection run is exported; the lines for the plain model are disabled.
# df_metrics = pd.DataFrame(metrics)
df_metrics_skip = pd.DataFrame(metrics_skip)

# df_metrics_detailed = pd.DataFrame(metrics_detailed)
df_metrics_detailed_skip = pd.DataFrame(metrics_detailed_skip)

# df_metrics.to_csv(out_path + 'metrics.csv', index=False)
df_metrics_skip.to_csv(out_path + 'metrics_skip.csv', index=False)
# df_metrics_detailed.to_csv(out_path + 'metrics_detailed.csv', index=False)
df_metrics_detailed_skip.to_csv(out_path + 'metrics_detailed_skip.csv', index=False)

In [None]:
# def test(model)