In [2]:
import os
import sys
import warnings

import numpy as np
import pandas as pd
import cv2

import matplotlib.pyplot as plt

from glob import glob

from tqdm import tqdm
import skimage
from skimage.transform import resize
from skimage.color import rgb2gray, gray2rgb, rgb2lab, lab2rgb

from sklearn.model_selection import train_test_split

from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from keras.models import Model, load_model, Sequential
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Dense, UpSampling2D, RepeatVector, Reshape
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import backend as K

import tensorflow as tf

# Suppress UserWarnings raised from skimage modules.
warnings.filterwarnings('ignore', category=UserWarning, module='skimage')
seed = 42
# Call the seeding function. Assigning to it (`np.random.seed = seed`) replaces
# the function with an int and leaves NumPy's global RNG unseeded.
np.random.seed(seed)

Setting constant values

In [3]:
# Spatial size every loaded image is resized to, and channel count of the colour arrays.
IMG_HEIGHT = IMG_WIDTH = 256
IMG_CHANNELS = 3
# Shape of one single-channel image: (height, width, 1).
INPUT_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 1)

# All four image folders sit under one dataset root.
DATA_ROOT = '../input/image-colorization-dataset/data/'
TRAIN_PATH_GRAY = DATA_ROOT + 'train_black/'
TRAIN_PATH_RGB = DATA_ROOT + 'train_color/'
TEST_PATH_GRAY = DATA_ROOT + 'test_black/'
TEST_PATH_RGB = DATA_ROOT + 'test_color/'

Getting train and test paths for inputs and target images

In [4]:
# glob() makes no ordering guarantee, so each list is sorted when it is created.
# The previous loop only rebound its loop variable and left the lists unsorted.
# Sorting keeps the i-th grayscale path lined up with the i-th colour path
# (assumes both folders use matching file names — TODO confirm).
X_train_gray_paths = sorted(glob(TRAIN_PATH_GRAY + '/*'))
X_train_rgb_paths = sorted(glob(TRAIN_PATH_RGB + '/*'))
X_test_gray_paths = sorted(glob(TEST_PATH_GRAY + '/*'))
X_test_rgb_paths = sorted(glob(TEST_PATH_RGB + '/*'))
for paths in [X_train_gray_paths, X_train_rgb_paths, X_test_gray_paths, X_test_rgb_paths]:
    print(len(paths))

In [4]:
# Sanity check: print the input path and the target path stored at the same index.
for sample_path in (X_train_gray_paths[1000], X_train_rgb_paths[1000]):
    print(sample_path)

Creating train and test arrays of images

In [5]:
# Pre-allocate one uint8 buffer per path list; the loading loop fills them in place.
# Grayscale buffers hold a single channel, colour buffers hold IMG_CHANNELS.
gray_image_shape = (IMG_HEIGHT, IMG_WIDTH, 1)
rgb_image_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

X_train_gray = np.zeros((len(X_train_gray_paths),) + gray_image_shape, dtype=np.uint8)
X_train_rgb = np.zeros((len(X_train_rgb_paths),) + rgb_image_shape, dtype=np.uint8)
X_test_gray = np.zeros((len(X_test_gray_paths),) + gray_image_shape, dtype=np.uint8)
X_test_rgb = np.zeros((len(X_test_rgb_paths),) + rgb_image_shape, dtype=np.uint8)

for buffer in (X_train_gray, X_train_rgb, X_test_gray, X_test_rgb):
    print(buffer.shape)

In [6]:
# Each path list is paired with its destination buffer and an explicit flag that says
# whether the buffer stores single-channel grayscale images. This replaces the
# previous odd/even counter, which silently depended on list order.
img_paths = [X_train_gray_paths, X_train_rgb_paths, X_test_gray_paths, X_test_rgb_paths]
img_arrays = [X_train_gray, X_train_rgb, X_test_gray, X_test_rgb]
is_gray_flags = [True, False, True, False]

for paths, img_array, is_gray in zip(img_paths, img_arrays, is_gray_flags):
    for i, path in tqdm(enumerate(paths), total=len(paths), leave=False):
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising when a file cannot be decoded.
            raise IOError(f'Could not read image: {path}')
        # cv2.resize expects dsize as (width, height).
        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        if is_gray:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # Restore the trailing channel axis that the gray conversion drops.
            image = image.reshape(IMG_HEIGHT, IMG_WIDTH, 1)
        else:
            # OpenCV loads images as BGR; the colour buffers are stored as RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img_array[i] = image

In [7]:
# Show the first grayscale input next to the first colour target.
fig, (ax_input, ax_target) = plt.subplots(1, 2, figsize=(12, 6))
ax_input.set_title('Input')
ax_input.imshow(X_train_gray[0].squeeze(), cmap='gray')
ax_target.set_title('Target')
ax_target.imshow(X_train_rgb[0])
plt.show()

**Train / Valid split**

In [7]:
# Split both arrays in a single call so the same indices land in train/valid for the
# colour and grayscale sets. train_test_split also raises if their lengths differ,
# instead of silently producing misaligned pairs.
X_train_rgb, X_valid_rgb, X_train_gray, X_valid_gray = train_test_split(
    X_train_rgb, X_train_gray, test_size=0.1, random_state=seed)
print(f'Train RGB: {X_train_rgb.shape[0]}, train GRAY: {X_train_gray.shape[0]}')
print(f'Valid RGB: {X_valid_rgb.shape[0]}, valid GRAY: {X_valid_gray.shape[0]}')

In [9]:
# Re-plot the first input/target pair to check they still match after the split.
fig, (ax_input, ax_target) = plt.subplots(1, 2, figsize=(12, 6))
ax_input.set_title('Input')
ax_input.imshow(X_train_gray[0].squeeze(), cmap='gray')
ax_target.set_title('Target')
ax_target.imshow(X_train_rgb[0])
plt.show()

In [8]:
def _to_unit_float(images):
    """Return `images` cast to float32 and divided by 255."""
    return images.astype(np.float32) / 255.


# Convert one array at a time, so each replaced array can be freed before the
# next conversion allocates.
X_train_gray = _to_unit_float(X_train_gray)
X_train_rgb = _to_unit_float(X_train_rgb)
X_valid_rgb = _to_unit_float(X_valid_rgb)
X_valid_gray = _to_unit_float(X_valid_gray)
X_test_gray = _to_unit_float(X_test_gray)
X_test_rgb = _to_unit_float(X_test_rgb)

# Create the Model

The model is a combination of an autoencoder and an Inception-ResNet-v2 classifier. The best an autoencoder can do by itself is shade everything in a brownish tone, so the model uses the classifier's predictions to give the neural network an "idea" of what things should be colored.

In [9]:
# Build the classifier with weights=None, then load the weights from a local file.
INCEPTION_WEIGHTS_PATH = '../input/inception-resnet-v2-weights/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
inception = InceptionResNetV2(include_top=True, weights=None)
inception.load_weights(INCEPTION_WEIGHTS_PATH)
# Keep a handle on the current default graph; create_inception_embedding enters
# `inception.graph` before calling predict.
inception.graph = tf.get_default_graph()

We are trying to generate 2 channels (the green–red and blue–yellow axes, i.e. the a and b channels of the Lab color space) in addition to the given grayscale image (which has 1 channel).

In [10]:
def Colorize():
    """Build the (uncompiled) colorization network.

    The model takes two inputs, in this order: a (256, 256, 1) image tensor and a
    1000-element vector. The encoder reduces the image to a 32x32 feature map
    through three 2x poolings. The vector is repeated over that 32x32 grid and
    concatenated with the feature map. The decoder then produces a 2-channel map
    with tanh activation, upsampled three times by 2x.

    Returns:
        A keras `Model` with inputs `[encoder_input, embed_input]` and the decoder
        output as its single output.
    """
    def conv(tensor, filters, size, activation='relu'):
        # 'same'-padded convolution with a square kernel and the default stride of 1.
        return Conv2D(filters, (size, size), activation=activation, padding='same')(tensor)

    def pool(tensor):
        return MaxPooling2D((2, 2), padding='same')(tensor)

    def upsample(tensor):
        return UpSampling2D((2, 2))(tensor)

    # Layers are created in the same order as before so auto-generated names match.
    embedding_in = Input(shape=(1000,))
    image_in = Input(shape=(256, 256, 1,))

    # Encoder: 256 -> 128 -> 64 -> 32 spatially.
    features = pool(conv(image_in, 128, 3))
    features = conv(features, 128, 4)
    features = pool(conv(features, 128, 3))
    features = conv(features, 256, 4)
    features = pool(conv(features, 256, 3))
    for kernel in (4, 3, 3):
        features = conv(features, 256, kernel)

    # Fusion: tile the vector over the 32x32 grid and merge along the channel axis.
    tiled = RepeatVector(32 * 32)(embedding_in)
    tiled = Reshape((32, 32, 1000))(tiled)
    fused = concatenate([features, tiled], axis=3)
    fused = conv(fused, 256, 1)

    # Decoder: three 2x upsamplings; the 2-channel tanh conv runs before the last one.
    decoded = conv(fused, 128, 3)
    decoded = conv(decoded, 64, 3)
    decoded = upsample(decoded)
    decoded = conv(decoded, 128, 3)
    decoded = upsample(decoded)
    decoded = conv(decoded, 64, 4)
    decoded = conv(decoded, 64, 3)
    decoded = conv(decoded, 32, 2)
    decoded = conv(decoded, 2, 3, activation='tanh')
    decoded = upsample(decoded)
    return Model(inputs=[image_in, embedding_in], outputs=decoded)

# Instantiate the network, compile it with Adam on a mean-squared-error loss,
# and print the layer summary.
model = Colorize()
model.compile(loss='mean_squared_error', optimizer='adam')  # lr=1e-3
model.summary()

# Prepare datasets

In [11]:
# Image transformer: random augmentation is configured for the training generator only.
augmentation_options = dict(
    shear_range=0.2,
    zoom_range=0.2,
    rotation_range=20,
    horizontal_flip=True,
)
train_datagen = ImageDataGenerator(**augmentation_options)

# The validation generator is created with no augmentation options.
valid_datagen = ImageDataGenerator()

# Create the embedding that helps decide which part of the image to color
def create_inception_embedding(grayscaled_rgb):
    """Run a batch of grayscale images through the classifier and return its predictions.

    Args:
        grayscaled_rgb: batch of 3-channel images. Every caller in this notebook
            passes float images scaled to [0, 1].

    Returns:
        The array returned by `inception.predict` for the resized batch.
    """
    def resize_gray(x):
        return resize(x, (299, 299, 3), mode='constant')
    grayscaled_rgb_resized = np.array([resize_gray(x) for x in grayscaled_rgb])
    # preprocess_input maps pixel values from [0, 255] to [-1, 1]. The images here are
    # in [0, 1], so they are scaled back to [0, 255] first. Without this, every input
    # collapses to roughly -1 and the embedding carries almost no image information.
    # NOTE: models trained with the old, unscaled embeddings must be retrained.
    grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized * 255.)
    with inception.graph.as_default():
        embed = inception.predict(grayscaled_rgb_resized)
    return embed

def _lab_batches(datagen, X, batch_size):
    """Yield `([L_batch, embedding], ab_batch)` for each batch from `datagen.flow`.

    Each RGB batch is converted to Lab. Channel 0 becomes the model input, with a
    trailing channel axis added. Channels 1-2 divided by 128 become the target
    (presumably to match the tanh output range of the network). The embedding is
    computed from a grayscale copy of the batch expanded back to 3 channels.
    """
    for rgbs in datagen.flow(X, batch_size=batch_size):
        grayscaled_rgb = gray2rgb(rgb2gray(rgbs))  # gray copy with 3 channels for the classifier
        lab_batch = rgb2lab(rgbs)  # convert colored into Lab format
        X_batch = lab_batch[:, :, :, 0]  # take grayscale (L) channel as X
        X_batch = X_batch.reshape(X_batch.shape + (1,))  # reshape X to fit model input
        Y_batch = lab_batch[:, :, :, 1:] / 128  # take the 2 non-gray channels as target
        yield [X_batch, create_inception_embedding(grayscaled_rgb)], Y_batch


# Generate training data: [L channel, embedding of the grayed image], ab target
def image_train_generator(X, batch_size = 20):
    """Yield augmented training batches built from the RGB array `X`."""
    yield from _lab_batches(train_datagen, X, batch_size)


# Same as image_train_generator, except there is no augmentation and the default batch_size differs
def image_valid_generator(X, batch_size=8):
    """Yield non-augmented validation batches built from the RGB array `X`."""
    yield from _lab_batches(valid_datagen, X, batch_size)
        
# later we can use GT RGBS and GRAYS to compare the results in addition to test imgs
# test -> gray + embed = predict, compare with GT

# Checkpoints

In [None]:
# naive custom lr scheduler
def scheduler(epoch, lr):
    """Return the learning rate for `epoch`, derived from the current rate `lr`.

    The current rate is multiplied by 0.0001 for epochs 0-1, by 10 for epochs 2-7,
    and by 0.65 from epoch 8 onwards.
    """
    if epoch >= 8:
        factor = 0.65
    elif epoch >= 2:
        factor = 10
    else:
        factor = 0.0001
    return lr * factor

# Wrap `scheduler` in a LearningRateScheduler callback.
# NOTE(review): this callback is not added to `model_callbacks`, and it comes from
# `tf.keras` while the model is built with standalone `keras` — confirm both before use.
custom_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

Link to WarmupCosineDecay Scheduler & Callback
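[https://www.dlology.com/blog/bag-of-tricks-for-image-classification-with-convolutional-neural-networks-in-keras/](https://www.dlology.com/blog/bag-of-tricks-for-image-classification-with-convolutional-neural-networks-in-keras/)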

In [12]:
# Learning rate annealer: halve the rate after 3 epochs without val_loss improvement,
# never going below 1e-5.
learning_rate_reduction = ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, min_lr=0.00001, verbose=1)

# Checkpoint: save to `filepath` only when val_loss reaches a new minimum.
filepath = "color.h5"
checkpoint = ModelCheckpoint(
    filepath, monitor='val_loss', mode='min', save_best_only=True)

# Early stopping: give up after 10 epochs without val_loss improvement and
# restore the best weights seen.
early_stop = EarlyStopping(
    monitor='val_loss', mode='min', patience=10, restore_best_weights=True)

model_callbacks = [learning_rate_reduction, checkpoint, early_stop]
# Possible improvement: try LearningRateScheduler with a custom scheduler function,
# e.g. a few warmup steps with a low lr, then a significant increase,
# followed by a monotonic lr decay.

# Train the Model

In [14]:
BATCH_SIZE = 20
VALID_BATCH_SIZE = 8

train_batches = image_train_generator(X_train_rgb, BATCH_SIZE)
valid_batches = image_valid_generator(X_valid_rgb, VALID_BATCH_SIZE)

# Step counts use floor division, so one epoch covers only the full batches.
model.fit_generator(
    train_batches,
    steps_per_epoch=X_train_rgb.shape[0] // BATCH_SIZE,
    epochs=30,
    validation_data=valid_batches,
    validation_steps=X_valid_rgb.shape[0] // VALID_BATCH_SIZE,
    callbacks=model_callbacks,
    verbose=1,
)

In [13]:
# Load the weights stored at `filepath`, the file the ModelCheckpoint callback
# writes whenever val_loss improves.
model.load_weights(filepath)  # after training

In [None]:
# Save the full model to `filepath` and the weights alone to a separate file.
# NOTE(review): `filepath` is also the ModelCheckpoint target, so this overwrites
# the best checkpoint with the model's current state — confirm that is intended.
model.save(filepath)
model.save_weights("Color_Weights.h5")

### Evaluate on test images

In [14]:
# Rebuild the two model inputs from the first 50 colour test images: the classifier
# embedding of the grayed image and its Lab L channel with a trailing channel axis.
sample = X_test_rgb[:50]
color_me = gray2rgb(rgb2gray(sample))
color_me_embed = create_inception_embedding(color_me)
color_me = rgb2lab(color_me)[:, :, :, 0]
color_me = color_me[..., np.newaxis]

# Predictions are multiplied by 128, undoing the /128 applied to the training targets.
output = model.predict([color_me, color_me_embed]) * 128

decoded_imgs = np.zeros((len(output), 256, 256, 3))

for i in tqdm(range(len(output)), total=len(output), leave=False):
    # Assemble a Lab image (input L + predicted ab) and convert it to RGB.
    lab_image = np.zeros((256, 256, 3))
    lab_image[..., :1] = color_me[i]
    lab_image[..., 1:] = output[i]
    decoded_imgs[i] = lab2rgb(lab_image)

#### Plot random test image

In [19]:
# Draw the index from the images that were actually decoded, so a test set with
# fewer than 50 images cannot produce an out-of-range index.
idx = np.random.randint(len(decoded_imgs))

fig, (ax_gray, ax_pred, ax_true) = plt.subplots(3, 1, figsize=(8, 16))

ax_gray.set_title('Grayscale')
ax_gray.imshow(X_test_gray[idx].squeeze(), cmap='gray')
ax_gray.axis('off')

ax_pred.set_title('Colorized')
# decoded_imgs entries are already (256, 256, 3), so no reshape is needed.
ax_pred.imshow(decoded_imgs[idx])
ax_pred.axis('off')

ax_true.set_title('Original')
ax_true.imshow(X_test_rgb[idx])
ax_true.axis('off')

fig.tight_layout()
plt.show()