In [0]:
import numpy as np
import tensorflow as tf
%matplotlib inline 
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import InputLayer, Conv2D, UpSampling2D,BatchNormalization,Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.callbacks import ReduceLROnPlateau
from skimage.color import rgb2lab, lab2rgb
import os
import sys
import cv2

%load_ext tensorboard

In [0]:
# Hyper-parameters
IMAGE_SIZE = 32   # side length in pixels of the square images fed to the model
BATCH_SIZE = 16   # images per training batch
EPOCHS = 25       # number of training epochs

In [0]:
# Load CIFAR-10; the class labels are discarded because colorization only needs the images.
(train_images, _), (test_images, _) = cifar10.load_data()

# Random augmentation applied to the training images streamed by train_generator.
datagen = ImageDataGenerator(
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)


def generate_dataset(images, debug=False):
    """Convert RGB images into (lightness, chroma) training pairs.

    Each image is divided by 255 and converted to CIELAB with ``rgb2lab``.
    The L channel becomes the model input; the a/b channels, divided by 128,
    become the regression target.

    Args:
        images: iterable of RGB images with pixel values on a 0-255 scale.
        debug: if True, plot each original image next to its Lab round-trip
            reconstruction (useful for sanity-checking the conversion).

    Returns:
        (X, Y): X stacks the L channels with a trailing channel axis,
        shape (N, H, W, 1); Y stacks the scaled a/b channels, shape (N, H, W, 2).
    """
    X = []
    Y = []

    for rgb in images:
        lab = rgb2lab(rgb / 255)
        lightness = lab[:, :, 0]
        # Scale a/b down by 128 so targets sit roughly in [-1, 1], the range of
        # the model's tanh output. (Division creates a copy; no in-place view edit.)
        chroma = lab[:, :, 1:] / 128

        if debug:
            fig, (ax_orig, ax_lab) = plt.subplots(1, 2)
            ax_orig.imshow(rgb / 255)
            ax_lab.imshow(lab2rgb(np.dstack((lightness, chroma * 128))))
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across many images

        # Add the channel axis from the image's own shape instead of assuming
        # IMAGE_SIZE x IMAGE_SIZE, so other image sizes work too.
        X.append(lightness[:, :, np.newaxis])
        Y.append(chroma)

    return np.array(X), np.array(Y)


# Training pairs are produced on the fly (with augmentation) by train_generator,
# so only the test split is converted up front; it is also used as validation data.
X_test, Y_test = generate_dataset(images=test_images)

In [0]:
# Encoder-decoder CNN mapping the L (lightness) channel to the two a/b chroma
# channels. The final tanh bounds predictions to [-1, 1], matching the targets
# built by generate_dataset (a/b divided by 128).
# Three stride-2 convolutions are undone by three 2x upsamplings, so the output
# only has the input's spatial size when IMAGE_SIZE is a multiple of 8.
assert IMAGE_SIZE % 8 == 0, 'IMAGE_SIZE must be divisible by 8 for this architecture'

model = Sequential([
    # Input shape follows the configured IMAGE_SIZE instead of a hard-coded 32.
    InputLayer(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1)),
    # Encoder: downsample 8x in total.
    Conv2D(8, (3, 3), activation='relu', padding='same', strides=2),
    Conv2D(8, (3, 3), activation='relu', padding='same'),
    Conv2D(16, (3, 3), activation='relu', padding='same'),
    Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    Conv2D(32, (3, 3), activation='relu', padding='same', strides=2),
    # Decoder: upsample back to the input resolution.
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(16, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(2, (3, 3), activation='tanh', padding='same'),
])
model.compile(optimizer='rmsprop', loss='mse')

# Cut the learning rate to 20% of its value whenever val_loss has not improved
# for 2 epochs, with a floor of 1e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=2,
    min_lr=1e-5,
)

# Write training logs for the %tensorboard viewer.
tensorboard = TensorBoard(log_dir='logs/run')

In [0]:
def train_generator(batch_size, shuffle=True):
    """Yield augmented (L, scaled a/b) training batches for ``model.fit``.

    Augmented RGB batches are drawn from ``datagen.flow`` over ``train_images``
    and converted with ``generate_dataset``.

    Args:
        batch_size: number of images per batch.
        shuffle: reshuffle the training images between passes. Defaults to
            True; previously the order was fixed (shuffle=False), so every
            epoch saw the same batch sequence, which can hurt SGD training.

    Yields:
        (X_batch, Y_batch) tuples as returned by ``generate_dataset``.
    """
    for rgb_batch in datagen.flow(train_images, batch_size=batch_size, shuffle=shuffle):
        yield generate_dataset(rgb_batch)

In [0]:
# Checkpoint location; the path looks like a mounted Google Drive in Colab (adjust locally).
checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/colorize/model-checkpoints"
checkpoint_path = os.path.join(checkpoint_dir, "cp.ckpt")

# Persist only the model weights, silently, at ModelCheckpoint's default save frequency.
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=0,
)

In [0]:
# Train on augmented batches streamed from train_generator. The returned History
# is kept so per-epoch loss / val_loss can be inspected or plotted afterwards
# (assigning it also keeps its repr out of the cell output).
history = model.fit(
    train_generator(BATCH_SIZE),
    epochs=EPOCHS,
    # train_generator loops over datagen.flow, which (per Keras docs) repeats
    # indefinitely, so the epoch length must be given explicitly.
    steps_per_epoch=len(train_images) // BATCH_SIZE,
    validation_data=(X_test, Y_test),
    callbacks=[tensorboard, cp_callback, reduce_lr],
)

In [0]:
# Restore the most recent weights written by cp_callback. latest_checkpoint
# returns None when nothing has been saved (e.g. fresh kernel, training cell
# skipped); fail with a clear message instead of passing None to load_weights.
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest is None:
    raise FileNotFoundError(
        'No checkpoint found in {!r}; run the training cell first.'.format(checkpoint_dir))
model.load_weights(latest)

In [0]:
# Open the TensorBoard dashboard over the logs written by the TensorBoard callback (log_dir='logs/run').
%tensorboard --logdir logs/run/

In [0]:
# Indices (within the first 250 test images) of the examples exported below.
interested_ids = [
    23, 24, 30, 36, 40, 45, 48, 54, 56, 60, 65, 72, 73, 89, 130,
    133, 138, 171, 179, 207, 219, 83, 97, 192, 81, 123, 246,
]

# Predicted (scaled) a/b channels for the first 250 test images.
Y_hat = model.predict(X_test[:250])
total_count = Y_hat.shape[0]


In [0]:
# Export grayscale input, ground truth and prediction for each selected test
# image. plt.imsave writes the RGB array at its native resolution, so each PNG
# is a pixel-exact copy with no axes or margins, and no figures are left open.
results_dir = os.path.join("/content/drive/My Drive/Colab Notebooks/colorize/", 'results')
os.makedirs(results_dir, exist_ok=True)  # savefig/imsave do not create missing dirs
selected_ids = set(interested_ids)  # constant-time membership checks


def _lab_to_rgb(lightness, scaled_ab):
    """Rebuild an RGB image from an L channel and a/b channels scaled by 1/128."""
    return lab2rgb(np.dstack((lightness, scaled_ab * 128)))


for idx, (x, y, y_hat) in enumerate(zip(X_test[:250], Y_test[:250], Y_hat)):
    if idx not in selected_ids:
        continue

    renders = {
        'bw': _lab_to_rgb(x, np.zeros_like(y)),  # lightness only, zero chroma
        'gt': _lab_to_rgb(x, y),                 # ground-truth colors
        'tanhcnn': _lab_to_rgb(x, y_hat),        # model prediction
    }
    for suffix, rgb in renders.items():
        plt.imsave(os.path.join(results_dir, '{}-{}.png'.format(idx, suffix)), rgb)

    # Write first, then flush, so the latest progress line is actually displayed.
    sys.stdout.write('\r{} / {}'.format(idx + 1, total_count))
    sys.stdout.flush()
