## Imports

In [None]:
import numpy as np
import sys
import os
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython import display
from skimage import io, color
from skimage.transform import rescale, resize, downscale_local_mean
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import TensorflowUtils as utils
%matplotlib inline

## Global variables and model downloads

In [None]:
# Side length (in pixels) of the square input images
input_size = 128

# Folder holding the images to colorize
input_directory = 'Data/OldImages/'

# Checkpoint folders (save / restore models), one per training loss,
# and the zip archives they are extracted from.
l2_checkpoints = "Models/CheckpointsL2/"            # L2 loss model
huber_checkpoints = "Models/CheckpointsHuber/"      # Huber loss model
pmse_checkpoints = "Models/CheckpointsPairwiseMSE"  # pairwise mean squared error model

checkpoints_l2_link = "https://www.dropbox.com/s/0nv5qm7f3h06avv/CheckpointsL2.zip?dl=1"
checkpoints_huber_link = "https://www.dropbox.com/s/r1k1dol4wzwdaqe/CheckpointsHuber.zip?dl=1"
checkpoints_mse_link = "https://www.dropbox.com/s/xux39xsmgyo0y1r/CheckpointsPairwiseMSE.zip?dl=1"

# Download and unpack each model archive into "Models" if necessary (L2, Huber, pairwise MSE)
for _archive_link in (checkpoints_l2_link, checkpoints_huber_link, checkpoints_mse_link):
    utils.maybe_download_and_extract("Models", _archive_link, is_zipfile=True)
del _archive_link  # keep the loop variable out of the notebook namespace

# This notebook evaluates the pairwise-MSE model
checkpoints = pmse_checkpoints

# Where the pre-trained VGG-19 weights are stored, and where to fetch them from
model_dir = "Models/VGGModel/"
model_url = 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'

## Get the data

In [None]:
# Collect the file names of all JPEG images in the input directory (case-insensitive
# .jpg / .jpeg). os.listdir returns entries in arbitrary order, so sort them to make
# the processing order reproducible across machines and runs.
input_images = np.asarray(sorted(
    name for name in os.listdir(input_directory)
    if name.lower().endswith((".jpg", ".jpeg"))
))

## Define the AutoEncoder network architecture

In [None]:
# Build the pre-trained VGG-19 trunk that the generator uses as its encoder
def vgg_net(weights, image):
    """Stack the pre-trained VGG-19 layers, starting at 'conv1_2', on top of `image`.

    Conv layers are initialised from the pre-trained weights via
    ``utils.get_variable``; ReLU layers use ``tf.nn.relu``; every pooling
    step uses ``utils.avg_pool_2x2``.

    Args:
        weights: the squeezed ``model_data['layers']`` array returned by
            ``utils.get_model_data``, indexed by position in the full VGG layer list.
        image: input tensor; presumably 64 channels (the generator feeds the
            output of its own 1->64 channel conv here) - TODO confirm for other callers.

    Returns:
        dict mapping every layer name below (e.g. 'relu5_3', 'pool4') to its output tensor.
    """
    # 'conv1_1' / 'relu1_1' are intentionally absent: the generator replaces
    # them with its own single-channel input layer.
    layer_names = (
        'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
        'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
        'relu5_3', 'conv5_4', 'relu5_4'
    )
    # Presumably entries 0 and 1 of `weights` belong to the two skipped layers,
    # so layer_names[k] is found at weights[k + 2].
    first_weight_index = 2

    net = {}
    x = image
    for weight_index, layer_name in enumerate(layer_names, start=first_weight_index):
        layer_type = layer_name[:4]
        if layer_type == 'conv':
            raw_kernels, raw_bias = weights[weight_index][0][0][0][0]
            # Swap the two spatial axes before wrapping as a variable (presumably the
            # stored kernel layout is the transpose of TensorFlow's) - TODO confirm.
            kernel_var = utils.get_variable(np.transpose(raw_kernels, (1, 0, 2, 3)), name=layer_name + "_w")
            bias_var = utils.get_variable(raw_bias.reshape(-1), name=layer_name + "_b")
            x = utils.conv2d_basic(x, kernel_var, bias_var)
        elif layer_type == 'relu':
            x = tf.nn.relu(x, name=layer_name)
        elif layer_type == 'pool':
            x = utils.avg_pool_2x2(x)
        net[layer_name] = x

    return net

In [None]:
# Build the full colorization network: input layer + VGG encoder + upsampling decoder
def generator(images, train_phase):
    """Build the colorization network and return the predicted LAB image tensor.

    Encoder: a 3x3 conv + ReLU that lifts the 1-channel L input to 64
    channels, followed by the pre-trained VGG trunk built by ``vgg_net`` and
    one extra 2x2 max-pool. Decoder: three transposed convolutions that
    upsample back to the input resolution, adding the 'pool4' and 'pool3'
    encoder activations as skip connections, and ending in 2 channels
    (the predicted A and B channels).

    Args:
        images: float tensor of shape (batch, height, width, 1) holding the L channel.
        train_phase: boolean tensor; accepted for interface compatibility but
            not used by this network.

    Returns:
        Tensor of shape (batch, height, width, 3): the input L channel
        concatenated with the predicted A/B channels.
    """
    # Load the pre-trained VGG-19 weights via utils.get_model_data
    # (presumably downloading model_url into model_dir on first use).
    print("setting up vgg initialized conv layers ...")
    model_data = utils.get_model_data(model_dir, model_url)
    weights = np.squeeze(model_data['layers'])

    # NOTE: tf.train.Saver restores variables by name, so the names below
    # (W0, b0, W_t1, ...) must stay in sync with the trained checkpoints.
    with tf.variable_scope("generator"):

        # First layer: 3x3 2D conv with bias followed by ReLU.
        # Needed because the input has only 1 channel while the VGG trunk
        # (which skips conv1_1) starts at a 64-channel conv.
        W0 = utils.weight_variable([3, 3, 1, 64], name="W0")
        b0 = utils.bias_variable([64], name="b0")
        conv0 = utils.conv2d_basic(images, W0, b0)
        hrelu0 = tf.nn.relu(conv0, name="relu")

        # Encoder: the pre-trained VGG trunk, then one more 2x2 max-pool
        image_net = vgg_net(weights, hrelu0)
        vgg_final_layer = image_net["relu5_3"]
        pool5 = utils.max_pool_2x2(vgg_final_layer)

        # Decoder level 1: 4x4 transposed conv upsampling pool5 to pool4's shape,
        # reducing the filter count to pool4's, then fused (added) with pool4
        deconv_shape1 = image_net["pool4"].get_shape()
        W_t1 = utils.weight_variable([4, 4, deconv_shape1[3].value, pool5.get_shape()[3].value], name="W_t1")
        b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
        conv_t1 = utils.conv2d_transpose_strided(pool5, W_t1, b_t1, output_shape=tf.shape(image_net["pool4"]))
        fuse_1 = tf.add(conv_t1, image_net["pool4"], name="fuse_1")

        # Decoder level 2: upsample to pool3's shape and fuse with pool3
        deconv_shape2 = image_net["pool3"].get_shape()
        W_t2 = utils.weight_variable([4, 4, deconv_shape2[3].value, deconv_shape1[3].value], name="W_t2")
        b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
        conv_t2 = utils.conv2d_transpose_strided(fuse_1, W_t2, b_t2, output_shape=tf.shape(image_net["pool3"]))
        fuse_2 = tf.add(conv_t2, image_net["pool3"], name="fuse_2")

        # Decoder level 3: 16x16 transposed conv with stride 8 back to the input's
        # height/width, producing 2 channels (the predicted A and B channels)
        shape = tf.shape(images)
        deconv_shape3 = tf.stack([shape[0], shape[1], shape[2], 2])
        W_t3 = utils.weight_variable([16, 16, 2, deconv_shape2[3].value], name="W_t3")
        b_t3 = utils.bias_variable([2], name="b_t3")
        pred = utils.conv2d_transpose_strided(fuse_2, W_t3, b_t3, output_shape=deconv_shape3, stride=8)

    # Prepend the input L channel so the result is a full 3-channel LAB image
    return tf.concat(axis=3, values=[images, pred], name="pred_image")

## Build the network graph (placeholders, generator, and loss)

In [None]:
print("Setting up network...")

# Placeholders: the 1-channel L input, the full 3-channel LAB target (only
# needed by the loss), and a train/inference flag the generator does not use.
train_phase = tf.placeholder(tf.bool, name="train_phase")
images = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='L_image')
lab_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name="LAB_image")

# pred_image is the predicted LAB image: the input L channel plus the predicted A/B channels
pred_image = generator(images, train_phase)

# Pairwise mean squared error between the target and predicted LAB images
# (presumably the loss the restored pairwise-MSE checkpoint was trained with).
# tf.losses.mean_pairwise_squared_error already returns a scalar, so no extra reduction is needed.
gen_loss_pmse = tf.losses.mean_pairwise_squared_error(lab_images, pred_image)

## Function that takes the LAB layers and outputs the RGB image

In [None]:
# Combine separate L, A, B channel planes into one LAB image and convert it to RGB
def labChannelsToRGB(l, a, b):
    """Stack L, A and B channel planes along a new last axis and convert to RGB.

    Args:
        l, a, b: equally shaped arrays, presumably (H, W) planes, so that the
            stack yields an (H, W, 3) LAB image.

    Returns:
        The RGB image produced by ``skimage.color.lab2rgb``.

    The caller's ``l`` array is left unmodified.
    """
    # Cap L at 99 on a copy rather than in place, so the caller's array is not
    # altered (presumably keeps L strictly below 100 for lab2rgb - TODO confirm).
    l_capped = np.minimum(l, 99)
    lab_image = np.stack((l_capped, a, b), axis=2).astype('float64')
    return color.lab2rgb(lab_image)

## Function to compare the three models' predictions with the input and the ground truth

In [None]:
# For each example, plot the input, the three models' predictions, and the ground truth side-by-side
def showNetPredictions(l_image, output_l2, output_huber, output_pmse, color_images):
    """Draw one 5-panel figure per test example.

    Panels, left to right: the black-and-white input, the L2, Huber and
    pairwise-MSE model predictions (converted from LAB to RGB), and the
    ground-truth image (converted from LAB to RGB).

    Args:
        l_image: array of shape (N, H, W, 1) holding the L channel of each example.
        output_l2, output_huber, output_pmse: each model's prediction;
            presumably the raw ``sess.run([pred_image], ...)`` result, i.e. a
            one-element list wrapping an (N, H, W, 3) LAB batch.
        color_images: array of shape (N, H, W, 3) with the ground-truth LAB images.
    """
    num_tests = l_image.shape[0]

    # Unwrap the leading one-element axis of each model output -> (N, H, W, 3)
    predictions = (
        ("L2 Loss Predicted Image", np.asarray(output_l2)[0]),
        ("Huber Loss Predicted Image", np.asarray(output_huber)[0]),
        ("Pairwise MSE Loss Predicted Image", np.asarray(output_pmse)[0]),
    )

    for i in range(num_tests):
        plt.figure(figsize=(18, 16), dpi=80, facecolor='w', edgecolor='k')

        # Panel 1: the black-and-white (L channel) input
        plt.subplot(1, 5, 1)
        plt.title("Black and White Image")
        plt.imshow(l_image[i, :, :, 0], cmap='gray')

        # Panels 2-4: each model's reconstructed / predicted image
        for panel, (title, batch) in enumerate(predictions, start=2):
            plt.subplot(1, 5, panel)
            plt.title(title)
            plt.imshow(color.lab2rgb(batch[i].astype('float64')))

        # Panel 5: the original / ground-truth image
        plt.subplot(1, 5, 5)
        plt.title("Ground Truth Image")
        plt.imshow(color.lab2rgb(color_images[i].astype('float64')))

## Test the trained network

In [None]:
# Restore the pairwise-MSE model and show its colorization of every image in `input_directory`
saver = tf.train.Saver()
should_resize = False
# NOTE: overrides the `input_size` from the configuration cell; only used when should_resize is True.
input_size = 224

with tf.Session() as sess:
    # Restore the trained weights. Raise explicitly instead of `assert`, which is
    # skipped under `python -O` and gives no hint about what went wrong.
    ckpt = tf.train.get_checkpoint_state(checkpoints)
    if not (ckpt and ckpt.model_checkpoint_path):
        raise RuntimeError("No checkpoint found in %r" % (checkpoints,))
    saver.restore(sess, ckpt.model_checkpoint_path)

    for image_name in input_images:
        im = io.imread(os.path.join(input_directory, image_name))
        # The network needs 3-channel RGB input. Check ndim first: grayscale images
        # typically load as 2-D arrays, where im.shape[2] would raise IndexError.
        if im.ndim != 3 or im.shape[2] != 3:
            raise ValueError("INCORRECT IMAGE CHANNELS: %s has shape %s" % (image_name, im.shape))

        # if we are resizing then do it
        if should_resize:
            im = resize(im, (input_size, input_size))

        # Convert to LAB and add a batch axis; the L channel (index 0) is the network input
        lab_img = color.rgb2lab(im)[np.newaxis]  # (1, H, W, 3)
        lab_l = lab_img[..., :1]                 # (1, H, W, 1)

        # Predict; pred_image has a leading batch axis of size 1, drop it -> (H, W, 3)
        feed_dict = {images: lab_l, lab_images: lab_img, train_phase: False}
        output_pmse = sess.run(pred_image, feed_dict=feed_dict)[0]

        # Plot the L-channel input next to the predicted colorization
        fig, (ax_bw, ax_pred) = plt.subplots(1, 2, figsize=(14, 16), dpi=80, facecolor='w', edgecolor='k')
        ax_bw.set_title("Black and White Image")
        ax_bw.imshow(lab_l[0, :, :, 0], cmap='gray')
        ax_pred.set_title("Pairwise MSE Loss Predicted Image")
        ax_pred.imshow(color.lab2rgb(output_pmse.astype('float64')))
        plt.show()
    