In [None]:
# Mount Google Drive at /content/drive (only works inside Google Colab).
# NOTE(review): presumably the paths configured in the next cell point into this mount — confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Paths used by the rest of the notebook; fill these in before running.
# The lowercase spellings are the ones the later cells actually read.
test_dir_path = ""      # directory holding the images to colorize at inference time
train_dir_path = ""     # directory (walked recursively) holding the training images
save_weight_path = ""   # path for the weights produced by training
load_weight_path = ""   # weights checkpoint to restore into the freshly built model
save_model_path = ""    # path the full trained model is saved to
save_output_path = ""   # filename prefix for the saved comparison figures

# Original spellings, kept as aliases so any existing reference keeps working.
Test_dir_path = test_dir_path
Train_dir_path = train_dir_path
save_outpput_path = save_output_path

# Imports

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.io import imsave
from skimage.color import rgb2lab, lab2rgb
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, RepeatVector,Reshape
from tensorflow.keras.layers import  Dense, Flatten, Input, Concatenate
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img

# Import Data

In [None]:
def loadImagesToArray(dir_path, num_of_img=-1, search_inside =False):
  """
  Load images from a directory into a single float array scaled to [0, 1].

  dir_path : path of directory from which images will be imported
  num_of_img (Integer, default : -1) : maximum number of images to import; any
              negative value (the default) imports every image found
  search_inside (boolean, default : False) : If true all images inside that directory
              along with the images in its subdirectories will be added to the output array

  Returns np.array(images, dtype=float) / 255.0. All images must share the same
  shape for the result to be a regular (N, H, W, C) array.
  """
  def candidate_paths():
    # Yield file paths lazily so that loading can stop as soon as the limit is hit.
    if search_inside:
      for root, _dirs, files in os.walk(dir_path):
        for filename in files:
          yield os.path.join(root, filename)
    else:
      for filename in os.listdir(dir_path):
        path = os.path.join(dir_path, filename)
        # os.listdir also returns sub-directories, which load_img cannot open.
        if os.path.isfile(path):
          yield path

  images = []
  for path in candidate_paths():
    # A single loop over all candidates makes the limit global; the previous
    # nested loops only broke out of the current directory, so every later
    # directory was loaded in full.
    if 0 <= num_of_img <= len(images):
      break
    images.append(img_to_array(load_img(path)))
  return np.array(images,dtype=float)/255.0

def DataGenerator():
    """Create the ImageDataGenerator used to augment the training images.

    The generator applies random shear (0.2), random zoom (0.2), random
    rotation (rotation_range=20) and random horizontal flips.
    """
    augmentation_options = dict(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True,
    )
    return ImageDataGenerator(**augmentation_options)

def RGB2GRAY(img,add_channel_dim=False):
  """
  Reduce the trailing RGB axis of img to one gray value per pixel.

  img : array whose last axis has length 3 (R, G, B)
  add_channel_dim (boolean, default : False) : if True a trailing axis of
              length 1 is appended to the result, e.g. (H, W) -> (H, W, 1)

  The weights are the Y row of the RGB -> XYZ matrix used in RGB2ab.
  """
  luminance_weights = np.array([0.212671, 0.715160, 0.072169])
  luminance = np.matmul(img, luminance_weights)
  if add_channel_dim==True:
    return luminance[..., np.newaxis]
  return luminance

def RGB2ab(img,use_skimage=True):
  """
  Convert RGB images to the two CIELAB chroma channels, rescaled for training.

  img : array whose last axis holds R, G, B; values are expected in [0, 1]
              (as produced by loadImagesToArray). The input is not modified.
  use_skimage (boolean, default : True) : if True the conversion is delegated to
              skimage.color.rgb2lab; if False a NumPy-only conversion
              (sRGB -> linear RGB -> XYZ -> Lab, D65 white point) is used

  Returns an array with the same leading shape as img and a last axis of
  length 2 holding (a + 127) / 255 and (b + 127) / 255.

  References
  * https://en.wikipedia.org/wiki/Lab_color_space
  * https://github.com/scikit-image/scikit-image/blob/main/skimage/color/colorconv.py#L990-L1050
  """
  if use_skimage==False:
    def lab_f(t):
      # Forward CIELAB companding: cube root above the threshold, linear below it.
      return np.where(t > 0.008856, np.cbrt(t), 7.787 * t + 16. / 116.)

    rgb = np.asarray(img, dtype=float)
    # Undo the sRGB gamma so the matrix below is applied to linear RGB, which is
    # what the skimage branch does internally.
    linear_rgb = np.where(rgb > 0.04045, ((rgb + 0.055) / 1.055) ** 2.4, rgb / 12.92)

    conv_matrix =np.array( [[0.412453, 0.357580, 0.180423],
            [0.212671, 0.715160, 0.072169],
            [0.019334, 0.119193, 0.950227]])
    xyz = np.matmul(linear_rgb,conv_matrix.T)
    # Normalise by the D65 white point along the channel (last) axis; indexing
    # with [0] / [2] would scale image rows instead of the X and Z channels.
    white_point = np.array([0.95047, 1.0, 1.08883])
    f_xyz = lab_f(xyz / white_point)
    x, y, z = f_xyz[..., 0], f_xyz[..., 1], f_xyz[..., 2]
    a =  (500*(x-y)+127)/255.0
    b =  (200*(y-z)+127)/255.0
    return np.stack([a, b], axis=-1)
  else:
    Lab = rgb2lab(img)
    a = (Lab[...,1]+127)/255.0
    b = (Lab[...,2]+127)/255.0
    return np.stack([a, b], axis=-1)

def Lab2RGB(gray,ab):
  """
  Rebuild an RGB image from a gray channel and rescaled ab channels.

  gray : array with a trailing channel axis; channel 0 multiplied by 100 is
              used as the L channel
  ab : array whose last axis holds the two chroma channels in the
              (value + 127) / 255 encoding, which is undone here

  Returns the result of skimage.color.lab2rgb on the stacked (L, a, b) image.
  """
  lightness = (gray*100)[..., 0]
  chroma = ab*255.0 -127
  lab_image = np.stack([lightness, chroma[..., 0], chroma[..., 1]], axis=-1)
  return lab2rgb(lab_image)

def compare_results(img_gt,img_in,img_out,save_results=False,save_as=""):
  """
  Show ground truth, grayscale input and model output side by side.

  img_gt : ground-truth colour image
  img_in : grayscale input image (drawn with the 'gray' colormap)
  img_out : colorized output image
  save_results (boolean, default : False) : if True the figure is also written to disk
  save_as : output path without extension; '.svg' is appended
  """
  fig, axes = plt.subplots(1, 3)
  panels = [
      (img_gt,  'Ground Truth', None),
      (img_in,  'Input',        'gray'),
      (img_out, 'Output',       None),
  ]
  for ax, (image, title, cmap) in zip(axes, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

  # Save before showing: some backends tear the figure down once it has been shown.
  if save_results:
    fig.savefig(save_as+'.svg',dpi=300)
  plt.show()
  # Release the figure so repeated calls in a loop do not accumulate open figures.
  plt.close(fig)

def BatchGenerator(data,imgDataGen,batch_size=64):
  """
  Yield (gray_input, ab_target) training pairs from augmented RGB batches.

  data : RGB image array passed to imgDataGen.flow
  imgDataGen : object whose flow(data, batch_size=...) yields RGB batches
  batch_size (Integer, default : 64) : batch size forwarded to flow

  Each RGB batch is converted with RGB2GRAY (channel axis added) for the
  input and with RGB2ab for the target.
  """
  augmented_batches = imgDataGen.flow(data, batch_size=batch_size)
  for rgb_batch in augmented_batches:
    gray_input = RGB2GRAY(rgb_batch, True)
    ab_target = RGB2ab(rgb_batch)
    yield gray_input, ab_target


# Model

In [None]:
# Spatial size of the single-channel input the network is built for.
HEIGHT, WIDTH = 256, 256

# Module-level hyper-parameter values.
ks = (3, 3)            # convolution kernel size
actt = 'sigmoid'       # activation name handed to build_model
learning_rate = 0.001


def build_model(ks=(3,3),act='sigmoid',learning_rate=1e-2):
  """
  Build and compile the grayscale -> ab colorization network.

  ks : kernel size used by every convolution / transposed convolution
  act : activation used by every hidden layer (the output layer is always sigmoid)
  learning_rate : learning rate of the Adam optimizer

  The input has shape (HEIGHT, WIDTH, 1). A shared low-level feature stack
  feeds two paths: a mid-level convolutional path and a global path that ends
  in a 256-vector. The global vector is tiled over the mid-level feature map,
  concatenated with it, fused, and upsampled back to the input resolution
  with 2 output channels.

  Returns the compiled tf.keras Model (MSE loss; 'accuracy' and cosine
  similarity metrics).
  """
  def conv_bn(tensor, filters, strides):
    # Conv2D followed by BatchNormalization, the repeating unit of the encoder.
    tensor = Conv2D(filters,kernel_size=ks,strides=strides,activation=act,padding='SAME')(tensor)
    return layers.BatchNormalization()(tensor)

  def deconv_bn(tensor, filters, strides):
    # Conv2DTranspose followed by BatchNormalization, the repeating unit of the decoder.
    tensor = Conv2DTranspose(filters,kernel_size=ks,strides=strides,padding='SAME',activation=act)(tensor)
    return layers.BatchNormalization()(tensor)

  input_lvl = Input(shape = (HEIGHT,WIDTH,1))

  # Initial Shared Network of Low - Level Features (three stride-2 steps: /8 resolution)
  low_lvl = input_lvl
  for filters, strides in [(64,(2,2)), (128,(1,1)), (128,(2,2)),
                           (256,(1,1)), (256,(2,2)), (512,(1,1))]:
    low_lvl = conv_bn(low_lvl, filters, strides)

  #Path one for  Mid-Level Features Network
  mid_lvl = conv_bn(low_lvl, 512, (1,1))
  mid_lvl = conv_bn(mid_lvl, 256, (1,1))

  #Path two for Global Features Network
  global_lvl = low_lvl
  for strides in [(2,2), (1,1), (2,2), (1,1)]:
    global_lvl = conv_bn(global_lvl, 512, strides)
  global_lvl = Flatten()(global_lvl)
  global_lvl = Dense(1024,activation=act)(global_lvl)
  global_lvl = Dense(512 ,activation=act)(global_lvl)
  global_lvl = Dense(256 ,activation=act)(global_lvl)

  # Fusing the output of above two paths: tile the global vector over every
  # spatial position of the mid-level map. Height and width are read
  # separately so non-square inputs (HEIGHT != WIDTH) also work.
  fused_h, fused_w = mid_lvl.shape[1], mid_lvl.shape[2]
  fusion_lvl = RepeatVector(fused_h * fused_w)(global_lvl)
  fusion_lvl = Reshape((fused_h, fused_w, 256))(fusion_lvl)
  fusion_lvl = Concatenate( axis=3)([mid_lvl, fusion_lvl])
  fusion_lvl = Conv2D(256, kernel_size=ks,strides =(1, 1), activation=act,padding='SAME')(fusion_lvl)

  #Colorization Network (two stride-2 steps here, the third is the output layer)
  color_lvl = fusion_lvl
  for filters, strides in [(128,(1,1)), (64,(2,2)), (64,(1,1)), (32,(2,2))]:
    color_lvl = deconv_bn(color_lvl, filters, strides)

  #Output Layer: 2 channels (a, b) squashed to [0, 1] by the sigmoid
  output_lvl = Conv2DTranspose(2,kernel_size=ks,strides=(2,2),padding='SAME',activation='sigmoid')(color_lvl)

  #Model Parameters
  model = Model(inputs = input_lvl, outputs = output_lvl)
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  # NOTE(review): 'accuracy' is of doubtful use for a regression target; it is
  # kept so the columns of the training history stay unchanged.
  model.compile(loss = tf.keras.losses.MeanSquaredError(), optimizer = optimizer,metrics = ['accuracy',tf.keras.metrics.CosineSimilarity(
      name="cosine_similarity", dtype=None, axis=-1
      )])

  return model

# Instantiate the network. The learning rate passed here (1e-4) is a literal,
# not the module-level `learning_rate` value.
mymod = build_model(
    act=actt,
    learning_rate=0.0001,
)

In [None]:
# Restore a previous checkpoint only when one has been configured; with an
# empty path the model simply keeps its freshly initialised weights instead of
# failing on a non-existent file.
if load_weight_path:
  mymod.load_weights(load_weight_path)

In [None]:
# Print the layer-by-layer summary of the model.
mymod.summary()

In [None]:
# Build the augmenter and load up to 500 training images, including those in
# sub-directories of the training directory.
datagen = DataGenerator()
data = loadImagesToArray(train_dir_path, num_of_img=500, search_inside=True)

In [None]:
# Train on augmented batches of 50 images; the generator never ends, so the
# epoch length is fixed with steps_per_epoch.
history = mymod.fit(BatchGenerator(data,datagen,50),steps_per_epoch = 100,epochs=10)
# Persist the trained weights. This used to call load_weights, which would
# overwrite the weights just trained (or fail if the file did not exist).
mymod.save_weights(save_weight_path)
mymod.save(save_model_path)


In [None]:
import pandas as pd

log = pd.DataFrame(history.history)
# Plot the first history column (presumably the training loss — confirm against history.history keys).
ax = log[log.columns[0]].plot(figsize=(8,5))
ax.set_title('Loss')
ax.set_xlabel('Epoch')
# Save only once the figure is fully decorated, otherwise the file lacks the title.
ax.figure.savefig('loss.svg')
plt.show()

In [None]:
# Colorize up to 40 test images and save a ground-truth / input / output
# comparison figure for each one.
# NOTE(review): test_dir_path and save_output_path must be defined in the configuration cell.
color_me = loadImagesToArray(test_dir_path,40,False)
gray = RGB2GRAY(color_me,True)
# Same gray values without the channel axis, for display.
gray2 = gray[..., 0]
output = mymod.predict(gray)
# Iterate over the images actually loaded; the directory may hold fewer than 40.
for i in range(len(color_me)):
  pred = Lab2RGB(gray[i],output[i])
  path = save_output_path+str(i)
  compare_results(color_me[i],gray2[i],pred.reshape(color_me[i].shape),save_results=True,save_as=path)