In [None]:
# Switch the working directory to the project folder on Google Drive (uncomment the chdir below when needed)
import os
# os.chdir('drive/MyDrive/ComVis_20211')

In [None]:
import numpy as np # linear algebra
import os
import cv2
from keras import backend as K
from keras.layers import Conv2D,UpSampling2D,Input
from keras.layers.merge import concatenate
from keras.models import Model
from keras.preprocessing.image import img_to_array, load_img
import tensorflow as tf
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
import math

# Reproducibility setup: pin the TensorFlow and NumPy seeds and limit
# TensorFlow to one thread per op pool.
# The original cell seeded TensorFlow twice (123, then 2). No random op ran in
# between, so only the last call had any effect; the dead first call is removed
# and the effective seed (2) is kept.
TF_SEED = 2
NUMPY_SEED = 1

# Single-threaded intra-/inter-op execution -- presumably chosen to reduce
# run-to-run nondeterminism (TODO confirm this is still wanted; it is slow).
session_conf = tf.compat.v1.ConfigProto(
    intra_op_parallelism_threads=1,
    inter_op_parallelism_threads=1,
)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.backend.set_session(sess)

tf.random.set_seed(TF_SEED)
np.random.seed(NUMPY_SEED)

In [None]:
# Spatial size (in pixels) that images are resized to before entering the model.
HEIGHT, WIDTH = 256, 256

# Image folders for each data split, given relative to the working directory.
trainPath, validPath, testPath = 'Train', 'Valid', 'Test'

In [None]:
# backbone from pretrained model
def create_resnet_embedding(grayscaled_rgb):
    """Run each grayscale image through the module-level ``resnet`` model.

    Args:
        grayscaled_rgb: iterable of 3-channel images (the gray channel
            replicated over RGB). Values are assumed to be floats in [0, 1],
            which is what ``DataSequence.__getitem__`` passes in.

    Returns:
        The array returned by ``resnet.predict`` -- one 1000-element vector per
        image for the ResNet50 classifier loaded in this notebook.
    """
    # ResNet50 expects 224x224x3 inputs.
    resized = np.array([
        resize(image, (224, 224, 3), mode='constant')
        for image in grayscaled_rgb
    ])
    # FIX: the original applied the InceptionResNetV2 `preprocess_input`
    # (x / 127.5 - 1) to data already scaled to [0, 1], so ResNet50 received
    # values squeezed into roughly [-1, -0.99] and produced near-identical
    # embeddings for every image. ResNet50 weights expect their own
    # preprocessing applied to 0-255 pixel values.
    resized = tf.keras.applications.resnet50.preprocess_input(resized * 255.0)
    # verbose=0 keeps the per-batch predict progress bar out of the training log.
    return resnet.predict(resized, verbose=0)

In [None]:
# extract color
def color_extraction(rgb_batch, HEIGHT, WIDTH):
    """Compute a normalised 256-bin histogram per colour channel per image.

    Args:
        rgb_batch: sequence of images shaped (rows, cols, 3) with channel
            values in [0, 1]; values outside that interval are not counted.
        HEIGHT: image height used for normalisation.
        WIDTH: image width used for normalisation.

    Returns:
        Tuple ``(color_r, color_g, color_b)`` of float32 arrays, each shaped
        ``(len(rgb_batch), 256)``. Every row holds bin counts divided by
        ``HEIGHT * WIDTH``.
    """
    n_bins = 256
    pixel_count = HEIGHT * WIDTH
    # Pre-allocating keeps the (n_images, 256) shape even for an empty batch.
    histograms = np.zeros((3, len(rgb_batch), n_bins), dtype=np.float32)
    for idx, img in enumerate(rgb_batch):
        for channel in range(3):
            # FIX: cv2.calcHist treats the upper bound of `ranges` as exclusive,
            # so fully saturated pixels (value 1.0) were silently dropped.
            # np.histogram includes the right edge in the last bin.
            counts, _ = np.histogram(img[..., channel], bins=n_bins, range=(0.0, 1.0))
            histograms[channel, idx] = counts / pixel_count
    color_r, color_g, color_b = histograms
    return color_r, color_g, color_b

In [None]:
class DataSequence(tf.keras.utils.Sequence):
    """Keras Sequence that builds colourisation batches from a folder of images.

    Each batch is ``([gray, embed, color_r, color_g, color_b], rgb)`` where
    ``rgb`` is the image batch scaled to [0, 1], ``gray`` is its single-channel
    grayscale version, ``embed`` comes from ``create_resnet_embedding`` and the
    ``color_*`` arrays come from ``color_extraction``.
    """

    def __init__(self, imagePath, batch_size):
        """Index the files in ``imagePath``.

        Args:
            imagePath: directory containing the images.
            batch_size: number of images per batch.
        """
        super().__init__()
        self.imagePath = imagePath
        # FIX: os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. `.ipynb_checkpoints`), which would make batch
        # contents irreproducible and crash load_img. Sort and keep files only.
        self.img_list = sorted(
            name for name in os.listdir(imagePath)
            if os.path.isfile(os.path.join(imagePath, name))
        )
        self.batch_size = batch_size

    def __len__(self):
        """Number of batches per epoch (the last one may be smaller)."""
        return math.ceil(len(self.img_list) / self.batch_size)

    def __getitem__(self, idx):
        """Load, resize and featurise batch number ``idx``."""
        start = idx * self.batch_size
        names = self.img_list[start:start + self.batch_size]

        images = [
            resize(img_to_array(load_img(os.path.join(self.imagePath, name))),
                   (HEIGHT, WIDTH, 3))
            for name in names
        ]
        # Scale pixel values from 0-255 down to [0, 1].
        rgb = np.array(images, dtype=np.float32) * (1.0 / 255)

        # Gray image replicated over three channels for the ResNet backbone.
        grayscaled_rgb = gray2rgb(rgb2gray(rgb))
        embed = create_resnet_embedding(grayscaled_rgb)
        color_r, color_g, color_b = color_extraction(rgb, HEIGHT, WIDTH)

        # The encoder takes a single gray channel: (batch, HEIGHT, WIDTH, 1).
        gray = grayscaled_rgb[:, :, :, 0]
        gray = gray.reshape(gray.shape + (1,))

        return ([gray, embed, color_r, color_g, color_b], rgb)

In [None]:
# Reset the state Keras has accumulated so far before creating any model.
K.clear_session()

# ResNet50 with ImageNet weights and its 1000-class classification head;
# its predictions are what create_resnet_embedding returns as the embedding.
resnet = tf.keras.applications.resnet50.ResNet50(
    weights='imagenet',
    include_top=True,
    classes=1000,
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5


In [None]:
CHECKPOINT = 'models/unet_color_rgb'

# Sizes shared by the model inputs below.
EMBED_DIM = 1000   # length of the ResNet50 output vector fed in as `embed_input`
COLOR_BINS = 256   # histogram bins per colour channel
DOWNSAMPLE = 8     # the encoder halves the resolution three times (2 * 2 * 2)


def _tile_over_grid(vector, rows, cols, features):
    """Copy a (batch, features) tensor onto every cell of a rows x cols grid.

    Returns a (batch, rows, cols, features) tensor so the vector can be
    concatenated with a convolutional feature map of the same spatial size.
    """
    repeated = tf.keras.layers.RepeatVector(rows * cols)(vector)
    return tf.keras.layers.Reshape((rows, cols, features))(repeated)


# Build a fresh network only when no saved model exists; otherwise resume from it.
if not os.path.exists(CHECKPOINT):
    # The input and bottleneck sizes were hard-coded to 256 and 32; they are now
    # derived from HEIGHT/WIDTH so the model stays in step with DataSequence.
    if HEIGHT % DOWNSAMPLE or WIDTH % DOWNSAMPLE:
        raise ValueError(
            'HEIGHT and WIDTH must be multiples of %d so the skip connections line up'
            % DOWNSAMPLE)
    grid_h = HEIGHT // DOWNSAMPLE
    grid_w = WIDTH // DOWNSAMPLE

    embed_input = Input(shape=(EMBED_DIM,))
    color_r_input = Input(shape=(COLOR_BINS,))  # channel r color
    color_g_input = Input(shape=(COLOR_BINS,))  # channel g color
    color_b_input = Input(shape=(COLOR_BINS,))  # channel b color

    # Encoder: stride-2 convolutions downsample full -> 1/2 -> 1/4 -> 1/8 resolution.
    # (The *_256/_128/_64/_32 suffixes name the resolutions for the default 256x256 input.)
    encoder_input = Input(shape=(HEIGHT, WIDTH, 1,))
    encoder_256 = Conv2D(64, (3,3), activation='relu', padding='same')(encoder_input)
    encoder_128 = Conv2D(64, (3,3), activation='relu', padding='same', strides=2)(encoder_256)
    encoder_128 = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_128)
    encoder_64 = Conv2D(128, (3,3), activation='relu', padding='same', strides=2)(encoder_128)
    encoder_64 = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_64)
    encoder_32 = Conv2D(256, (3,3), activation='relu', padding='same', strides=2)(encoder_64)
    encoder_32 = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_32)
    encoder_32 = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_32)
    encoder_32 = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_32)

    # Fusion: broadcast the global vectors (colour histograms and ResNet
    # embedding) over the bottleneck grid, then mix them with a 1x1 convolution.
    fusion_r = _tile_over_grid(color_r_input, grid_h, grid_w, COLOR_BINS)
    fusion_g = _tile_over_grid(color_g_input, grid_h, grid_w, COLOR_BINS)
    fusion_b = _tile_over_grid(color_b_input, grid_h, grid_w, COLOR_BINS)
    fusion_output = _tile_over_grid(embed_input, grid_h, grid_w, EMBED_DIM)
    fusion_output = concatenate([fusion_output, fusion_r, fusion_g, fusion_b], axis=3)
    fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)

    # Decoder: upsample back to full resolution, concatenating the encoder
    # feature map of matching size at each level (U-Net style skip connections).
    decoder_32 = concatenate([encoder_32, fusion_output], axis=3)
    decoder_32 = Conv2D(256, (3,3), activation='relu', padding='same')(decoder_32)
    decoder_64 = UpSampling2D((2, 2))(decoder_32)
    decoder_64 = concatenate([encoder_64, decoder_64], axis=3)
    decoder_64 = Conv2D(256, (3,3), activation='relu', padding='same')(decoder_64)
    decoder_64 = Conv2D(128, (3,3), activation='relu', padding='same')(decoder_64)
    decoder_128 = UpSampling2D((2, 2))(decoder_64)
    decoder_128 = concatenate([encoder_128, decoder_128], axis=3)
    decoder_128 = Conv2D(128, (3,3), activation='relu', padding='same')(decoder_128)
    decoder_128 = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_128)
    decoder_256 = UpSampling2D((2, 2))(decoder_128)
    decoder_256 = concatenate([encoder_256, decoder_256], axis=3)
    decoder_256 = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_256)
    # NOTE(review): tanh outputs lie in [-1, 1] while the RGB targets produced by
    # DataSequence lie in [0, 1]; consider 'sigmoid' -- left unchanged here.
    decoder_output = Conv2D(3, (3,3), activation='tanh', padding='same')(decoder_256)

    model = Model(
        inputs=[encoder_input, embed_input, color_r_input, color_g_input, color_b_input],
        outputs=decoder_output)
else:
    model = tf.keras.models.load_model(CHECKPOINT)

In [None]:
# Training hyper-parameters.
LEARNING_RATE = 0.0005
BATCH_SIZE = 32
EPOCHS = 50

# NOTE(review): compile() also runs when the model was loaded from CHECKPOINT,
# so a resumed run starts with a freshly created Adam optimizer.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='mean_absolute_error',
)

# Stop once val_loss has not improved for 10 epochs and roll the in-memory
# model back to its best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    verbose=1,
    mode='auto',
    restore_best_weights=True,
)

# FIX: val_loss is a quantity to minimise; the original mode='max' would have
# tracked the *worst* epoch as soon as save_best_only was enabled.
# NOTE(review): with save_best_only=False the full model is written after every
# epoch, so the file on disk is the latest epoch, not necessarily the best one
# that early stopping restores in memory.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT,
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=False,
)

model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 256, 256, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 64  640         ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 64  36928       ['conv2d[0][0]']                 
                                )                                                             

In [None]:
# DataSequence already yields batches of BATCH_SIZE images, so fit() is not
# given a batch_size: Keras documents that argument as not to be specified
# when the input is a keras.utils.Sequence.
train_sequence = DataSequence(trainPath, BATCH_SIZE)
valid_sequence = DataSequence(validPath, BATCH_SIZE)

# Keep the returned History object so the loss curves can be inspected later.
history = model.fit(
    train_sequence,
    epochs=EPOCHS,
    validation_data=valid_sequence,
    shuffle=True,
    callbacks=[early_stop, model_checkpoint],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50