In [1]:
# This section is to move to the directory on Google Drive
import os
# os.chdir('drive/MyDrive/ComVis_20211')

In [2]:
import numpy as np # linear algebra
import os
import cv2
from keras import backend as K
from keras.layers import Conv2D,UpSampling2D,Input
from keras.layers.merge import concatenate
from keras.models import Model
from keras.preprocessing.image import img_to_array, load_img
import tensorflow as tf
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
import math

# Global seeds for reproducible runs. Only the last tf.random.set_seed call is
# ever in effect, so a single call with the effective value is made.
TF_SEED = 2
NP_SEED = 1
tf.random.set_seed(TF_SEED)
np.random.seed(NP_SEED)

# Single-threaded TF1-style session, registered as the Keras backend session.
# NOTE(review): ConfigProto thread limits only apply to graph-mode (tf.compat.v1)
# sessions; under TF2 eager execution they are believed not to limit op threads
# (tf.config.threading is the TF2 API for that). Kept so `sess` stays defined.
session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1,
                                        inter_op_parallelism_threads=1)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.backend.set_session(sess)

In [3]:
# Every image is resized to HEIGHT x WIDTH pixels before colorization.
HEIGHT = WIDTH = 256

# Directories holding the train / validation / test image files.
trainPath, validPath, testPath = 'Train', 'Valid', 'Test'

In [4]:
# backbone from pretrained model
def create_resnet_embedding(grayscaled_rgb, model=None, target_size=(224, 224)):
    """Embed a batch of grayscale-as-RGB images with a pretrained ResNet50.

    Parameters
    ----------
    grayscaled_rgb : iterable of arrays
        Images of shape (H, W, 3), expected in [0, 1] (DataSequence passes
        ``gray2rgb(rgb2gray(x / 255))``).
    model : keras Model, optional
        Network producing the embedding; defaults to the global ``resnet``
        (created in a later cell, so it is looked up at call time).
    target_size : (int, int)
        Spatial size the images are resized to; 224x224 for ImageNet ResNet50.

    Returns
    -------
    numpy.ndarray
        ``model.predict`` output for the batch (for ``resnet``: shape (N, 1000)).

    NOTE: this uses ResNet50's own preprocessing on 0-255 inputs. A model saved
    with embeddings computed by the old preprocessing (an existing CHECKPOINT)
    must be retrained to benefit.
    """
    if model is None:
        model = resnet
    resized = np.array([
        resize(img, (target_size[0], target_size[1], 3), mode='constant')
        for img in grayscaled_rgb
    ])
    # ResNet50 expects "caffe" preprocessing (RGB->BGR, ImageNet mean subtraction)
    # applied to raw pixel values in [0, 255]. InceptionResNetV2's preprocess_input
    # ("tf" mode: x / 127.5 - 1) applied to [0, 1] images maps every pixel to
    # roughly [-1, -0.99], which makes the embedding almost input-independent.
    batch = tf.keras.applications.resnet50.preprocess_input(resized * 255.0)
    # verbose=0: this runs once per training batch; avoid a progress bar each time.
    return model.predict(batch, verbose=0)

In [5]:
# extract color
def color_extraction(lab_batch, HEIGHT, WIDTH):
    """Per-image normalized 256-bin histograms of the a* and b* channels.

    Parameters
    ----------
    lab_batch : array-like
        Batch of L*a*b* images, shape (N, H, W, 3) (output of ``rgb2lab``).
    HEIGHT, WIDTH : int
        Image size used to normalize the counts; histograms sum to ~1 only when
        the images really are HEIGHT x WIDTH (DataSequence resizes them so).

    Returns
    -------
    (color_a, color_b) : tuple of float32 arrays, each shape (N, 256)
        Bin i covers values [-128 + i, -127 + i) of channel a* / b*.
    """
    n_bins = 256
    n_pixels = HEIGHT * WIDTH
    color_a = []
    color_b = []
    for img in lab_batch:
        # cv2.calcHist only accepts 8U, 16U or 32F images; depending on the
        # skimage version rgb2lab can return float64, which calcHist rejects.
        img = np.ascontiguousarray(img, dtype=np.float32)
        for channel, hists in ((1, color_a), (2, color_b)):
            hist = cv2.calcHist([img], [channel], mask=None, histSize=[n_bins],
                                ranges=[-128, 128])
            hists.append(hist[:, 0] / n_pixels)
    # reshape keeps the (N, n_bins) shape even for an empty batch.
    return (np.asarray(color_a, dtype=np.float32).reshape(-1, n_bins),
            np.asarray(color_b, dtype=np.float32).reshape(-1, n_bins))

In [6]:
class DataSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding colorization batches from an image folder.

    Batch ``idx`` holds files ``idx*batch_size`` .. ``(idx+1)*batch_size - 1`` of
    the file list and is returned as ``([L, embed, color_a, color_b], ab)``:

    * ``L``        -- L* channel of each image, shape (N, HEIGHT, WIDTH, 1);
    * ``embed``    -- ``create_resnet_embedding`` of the grayscale image;
    * ``color_a`` / ``color_b`` -- ``color_extraction`` histograms of a* / b*;
    * ``ab``       -- a* / b* channels divided by 128, shape (N, HEIGHT, WIDTH, 2).

    Parameters
    ----------
    imagePath : str
        Directory whose entries are loaded with ``load_img`` (presumably all
        images; any other file will make ``__getitem__`` fail).
    batch_size : int
        Images per batch; the last batch may be smaller.
    shuffle : bool, default False
        If True, shuffle the file order at construction and after every epoch
        (NumPy global RNG), so each epoch's batches contain different images.
    """

    def __init__(self, imagePath, batch_size, shuffle=False):
        super().__init__()
        self.imagePath = imagePath
        # os.listdir order is filesystem-dependent; sort so batches are reproducible.
        self.img_list = sorted(os.listdir(imagePath))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        return math.ceil(len(self.img_list) / self.batch_size)

    def on_epoch_end(self):
        """Reshuffle the file order in place when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.img_list)

    def _load_rgb(self, names):
        """Load ``names``, resize each to (HEIGHT, WIDTH, 3) and scale by 1/255."""
        images = [
            resize(img_to_array(load_img(os.path.join(self.imagePath, name))),
                   (HEIGHT, WIDTH, 3))
            for name in names
        ]
        return 1.0 / 255 * np.array(images, dtype=np.float32)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        rgb_batch = self._load_rgb(self.img_list[start:start + self.batch_size])

        # The ResNet branch only sees luminance, replicated to 3 channels.
        grayscaled_rgb = gray2rgb(rgb2gray(rgb_batch))
        embed = create_resnet_embedding(grayscaled_rgb)

        lab_batch = rgb2lab(rgb_batch)
        color_a, color_b = color_extraction(lab_batch, HEIGHT, WIDTH)
        X_batch = lab_batch[..., :1]          # L* channel, kept 4-D: (N, H, W, 1)
        Y_batch = lab_batch[..., 1:] / 128    # a*, b* scaled toward the tanh range

        return ([X_batch, embed, color_a, color_b], Y_batch)

In [7]:
# Reset Keras' global state before building the networks used below.
# NOTE(review): this probably also discards the tf.compat.v1 session registered
# in the setup cell; confirm whether that configuration is still meant to apply.
K.clear_session()

# ImageNet-pretrained ResNet50 including its 1000-class head; its predictions are
# the image embedding computed by create_resnet_embedding (global name `resnet`).
resnet = tf.keras.applications.resnet50.ResNet50(
    classes=1000,
    weights='imagenet',
    include_top=True,
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5


In [8]:
# Saved-model directory: reloaded below when present, written by ModelCheckpoint.
CHECKPOINT = 'models/encoder_color_lab'


def build_colorization_model(height, width, embed_dim=1000, n_bins=256):
    """Build the encoder / fusion / decoder colorization network.

    Model inputs, in order:
      * L* channel, shape (height, width, 1);
      * global image embedding, shape (embed_dim,) (ResNet50 predictions);
      * a* histogram and b* histogram, shape (n_bins,) each.
    Output: shape (height, width, 2), produced by a ``tanh`` conv, so in [-1, 1].

    The encoder downsamples by 8 (three stride-2 convs). The embedding and both
    histograms are tiled over the resulting (height/8, width/8) grid and
    concatenated with the encoder features. The decoder then upsamples by 8.

    Raises:
        ValueError: if ``height`` or ``width`` is not a multiple of 8.
    """
    if height % 8 or width % 8:
        raise ValueError(f'height and width must be multiples of 8, got {height}x{width}')
    grid_h, grid_w = height // 8, width // 8

    def tile(vector, depth):
        """Repeat a (batch, depth) tensor over the grid -> (batch, grid_h, grid_w, depth)."""
        repeated = tf.keras.layers.RepeatVector(grid_h * grid_w)(vector)
        return tf.keras.layers.Reshape((grid_h, grid_w, depth))(repeated)

    embed_input = Input(shape=(embed_dim,))
    color_a_input = Input(shape=(n_bins,))  # channel a color histogram
    color_b_input = Input(shape=(n_bins,))  # channel b color histogram

    # Encoder: one 3x3 ReLU conv per (filters, stride) entry.
    encoder_input = Input(shape=(height, width, 1))
    encoder_output = encoder_input
    for filters, stride in [(64, 2), (128, 1), (128, 2), (256, 1),
                            (256, 2), (512, 1), (512, 1), (256, 1)]:
        encoder_output = Conv2D(filters, (3, 3), activation='relu', padding='same',
                                strides=stride)(encoder_output)

    # Fusion: broadcast global features over the encoder grid, then mix with a 1x1 conv.
    fusion_a = tile(color_a_input, n_bins)
    fusion_b = tile(color_b_input, n_bins)
    fusion_embed = tile(embed_input, embed_dim)
    fusion_output = concatenate([encoder_output, fusion_embed, fusion_a, fusion_b], axis=3)
    fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)

    # Decoder: back to full resolution; the last upsampling follows the tanh conv.
    decoder_output = Conv2D(128, (3, 3), activation='relu', padding='same')(fusion_output)
    decoder_output = UpSampling2D((2, 2))(decoder_output)
    decoder_output = Conv2D(64, (3, 3), activation='relu', padding='same')(decoder_output)
    decoder_output = UpSampling2D((2, 2))(decoder_output)
    decoder_output = Conv2D(32, (3, 3), activation='relu', padding='same')(decoder_output)
    decoder_output = Conv2D(16, (3, 3), activation='relu', padding='same')(decoder_output)
    decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
    decoder_output = UpSampling2D((2, 2))(decoder_output)

    return Model(inputs=[encoder_input, embed_input, color_a_input, color_b_input],
                 outputs=decoder_output)


if not os.path.exists(CHECKPOINT):
    model = build_colorization_model(HEIGHT, WIDTH)
else:
    model = tf.keras.models.load_model(CHECKPOINT)

In [9]:
LEARNING_RATE = 0.001
BATCH_SIZE = 32
EPOCHS = 50

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
              loss='mean_absolute_error')

# Stop after 10 epochs without val_loss improvement and roll back to the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1,
                                              mode='min', restore_best_weights=True)

# Keep only the best (lowest val_loss) model on disk; this is what the model cell
# reloads from CHECKPOINT. mode must be 'min' for a loss: 'max' would treat the
# worst epoch as the best one once save_best_only is on.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT,
    save_weights_only=False,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 256, 256, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 64  640         ['input_5[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 12  73856       ['conv2d[0][0]']                 
                                8)                                                            

In [10]:
# DataSequence already yields batches of BATCH_SIZE, so `batch_size` is not passed
# to fit() (Keras documents it must not be set for Sequence inputs). With a
# Sequence, shuffle=True presumably reorders whole batches, not their contents.
# The returned History is kept for inspecting/plotting the loss curves.
history = model.fit(
    DataSequence(trainPath, BATCH_SIZE),
    epochs=EPOCHS,
    validation_data=DataSequence(validPath, BATCH_SIZE),
    shuffle=True,
    callbacks=[early_stop, model_checkpoint],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
INFO:tensorflow:Assets written to: model_resnet_color/assets
Epoch 00028: early stopping


<keras.callbacks.History at 0x7f66bafb3d10>