# Colorization
Class project - CS231N - Stanford University

Vincent Billaut  
Matthieu de Rochemonteix  
Marc Thibault  

See our GitHub [repo](https://github.com/vincentbillaut/all-colors-matter) for more details on the implementation.

## Imports

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread, imsave
import numpy as np

%matplotlib inline

from matplotlib.pyplot import imread
from matplotlib import animation
from IPython.display import display, HTML
from tqdm import tqdm_notebook

####################################################
# Setting working directory to enable relative paths
import os
os.chdir("../")
import pickle
####################################################

from models.coloringmodel import Config
from models.naive_convnet import NaiveConvColoringModel
from models.unet import UNetColoringModel
from utils.dataset import Dataset
from utils.color_utils import RGB_to_YUV, YUV_to_RGB
from utils.data_utils import load_image_jpg_to_YUV, dump_YUV_image_to_jpg
from utils.color_discretizer import ColorDiscretizer
from utils.data_augmentation import DataAugmenter
from utils.video_utils import smoothen_frame_list

## Load existing model

In [None]:
# Enter output folder name: this is the saved run that `model.load(output_folder)`
# reads further down in the notebook.
output_folder = "outputs/20180604_223338-2cd6/"

In [None]:
# Read the run configuration from its JSON file.
config = Config("configs/config_unet_suncoast2.json")

# Fresh discretizer sized from the configuration; its attributes are then
# overwritten with those of the pickled instance below.
cd = ColorDiscretizer(max_categories=config.max_categories)

# Open the pickle inside a `with` block so the file handle is always closed.
# NOTE(review): unpickling can execute arbitrary code -- only load a cd.pkl
# that comes from a trusted source.
with open("notebooks/cd.pkl", "rb") as cd_file:
    cd_loaded = pickle.load(cd_file)

# Copy every instance attribute of the pickled object onto `cd`.
cd.__dict__.update(cd_loaded.__dict__)

In [None]:
# Override the validation folder: the Dataset and the `image_paths` listing
# built in the following cells both read config.val_path.
# (presumably this folder holds frames extracted from a video -- confirm)
config.val_path = "data/long_video_frames_/"

In [None]:
# Build the objects the model needs: a default-constructed augmenter, a Dataset
# given the train/validation paths plus the color discretizer and augmenter,
# and finally the U-Net coloring model on top of that dataset.
da = DataAugmenter()
dataset = Dataset(config.train_path, config.val_path, cd, da)
model = UNetColoringModel(config, dataset)

In [None]:
# Load the model state stored under `output_folder`
# (presumably the trained weights -- confirm against UNetColoringModel.load).
model.load(output_folder)

In [None]:
# Sorted list of the full path of every entry found in the validation folder.
image_paths = sorted(
    os.path.join(config.val_path, entry)
    for entry in os.listdir(config.val_path)
)

### Utility functions

In [None]:
# Values passed as `temperature=` to cd.UVpixels_from_distribution:
# `temperatures` is swept by display_pred_array, `cold_temperatures` by
# display_pred_array_UV.
temperatures = [.1, 1., 3.]
cold_temperatures = [.05, .1, .38]
    
def output_prediction(image_paths, cd):
    """Run the model on each image and collect its outputs, cropped to the mask.

    Relies on the notebook-level ``model`` object.

    Args:
        image_paths: iterable of image paths, passed one at a time to
            ``model.pred_color_one_image``.
        cd: unused; kept so that existing calls keep working.

    Returns:
        Three lists with one entry per image: the cropped predictions (first
        batch element only), the cropped Y-scale image and the cropped UV
        ground truth.
    """
    prediction_list = []
    yscale_list = []
    uvtruth_list = []
    for image_path in tqdm_notebook(image_paths):
        _, pred_image_categories, (im_yscale, im_uvscale, msk) = model.pred_color_one_image(image_path)

        # Extent of the region to keep, read from the mask's first column
        # (rows) and first row (columns).
        # NOTE(review): assumes msk is 2-D, shares the images' two leading
        # axes and equals 1 on a top-left valid region -- confirm.
        first_column = msk[:, 0]
        first_row = msk[0, :]
        # A first column made only of ones means every row is kept, i.e. the
        # full length of axis 0 (this used to read msk.shape[1], which crops
        # wrongly whenever the mask is not square). Same reasoning for axis 1.
        n_rows = msk.shape[0] if first_column.mean() == 1. else np.argmin(first_column)
        n_cols = msk.shape[1] if first_row.mean() == 1. else np.argmin(first_row)

        prediction_list.append(pred_image_categories[0, :n_rows, :n_cols, :])
        yscale_list.append(im_yscale[:n_rows, :n_cols])
        uvtruth_list.append(im_uvscale[:n_rows, :n_cols, :])

    return prediction_list, yscale_list, uvtruth_list

def display_pred_array(prediction_list, yscale_list, uvtruth_list, cd):
    """Plot, for each frame, the ground truth next to one prediction per temperature.

    Each frame gets one row of subplots: the first column shows the image built
    from the Y-scale frame and the UV ground truth, the remaining columns show
    the image built from ``cd.UVpixels_from_distribution`` for every value of
    the notebook-level ``temperatures`` list.

    Args:
        prediction_list: per-frame predictions fed to
            ``cd.UVpixels_from_distribution``.
        yscale_list: per-frame Y-scale arrays (concatenated with UV on axis 2).
        uvtruth_list: per-frame UV ground-truth arrays.
        cd: object providing ``UVpixels_from_distribution``.
    """
    n_rows = len(prediction_list)
    n_cols = len(temperatures) + 1
    plt.figure(figsize=(16, n_rows * 5))

    for row in range(n_rows):
        distribution = prediction_list[row]
        luma = yscale_list[row]
        chroma_truth = uvtruth_list[row]

        # Subplot indices are 1-based; this is the first cell of the row.
        first_cell = row * n_cols + 1

        truth_yuv = np.concatenate([luma, chroma_truth], axis=2)
        truth_rgb = YUV_to_RGB(truth_yuv).astype("uint8")
        axes = plt.subplot(n_rows, n_cols, first_cell)
        plt.imshow(truth_rgb)
        plt.axis('off')
        axes.set_title("ground truth")

        for offset, temperature in enumerate(temperatures, start=1):
            chroma_pred = cd.UVpixels_from_distribution(distribution,
                                                        temperature=temperature)
            pred_yuv = np.concatenate([luma, chroma_pred], axis=2)
            pred_rgb = YUV_to_RGB(pred_yuv).astype("uint8")

            axes = plt.subplot(n_rows, n_cols, first_cell + offset)
            plt.imshow(pred_rgb)
            plt.axis('off')
            axes.set_title("prediction (temperature = {})".format(temperature))

    plt.show()

def display_pred_array_UV(prediction_list, yscale_list, uvtruth_list, cd):
    """Plot ground truth and predictions with the Y channel replaced by a constant.

    Each frame gets one row of subplots:
      1. the ground truth built from the real Y-scale frame and the true UV;
      2. the true UV combined with a constant Y of 0.5 ("ground truth UV");
      3. one prediction per value of the notebook-level ``cold_temperatures``
         list, also drawn on the constant Y.

    Args:
        prediction_list: per-frame predictions fed to
            ``cd.UVpixels_from_distribution``.
        yscale_list: per-frame Y-scale arrays (concatenated with UV on axis 2).
        uvtruth_list: per-frame UV ground-truth arrays.
        cd: object providing ``UVpixels_from_distribution``.
    """
    n_images = len(prediction_list)
    # Two reference columns plus one per plotted temperature. The width is
    # derived from `cold_temperatures`, the list actually iterated below; it
    # used to come from `temperatures`, which only worked because both lists
    # happen to have the same length.
    n_columns = len(cold_temperatures) + 2
    plt.figure(figsize=(16, n_images * 5))

    for i in range(n_images):
        pred_frame = prediction_list[i]
        real_yscale_frame = yscale_list[i]
        # Constant Y value of 0.5 everywhere, same shape as the real Y frame.
        flat_yscale_frame = np.ones(real_yscale_frame.shape) / 2.
        uvtruth_frame = uvtruth_list[i]

        # Subplot indices are 1-based; this is the first cell of row i.
        row_start = 1 + i * n_columns

        true_YUV_image = np.concatenate([real_yscale_frame, uvtruth_frame], axis=2)
        true_RGB_image = YUV_to_RGB(true_YUV_image).astype("uint8")
        ax = plt.subplot(n_images, n_columns, row_start)
        plt.imshow(true_RGB_image)
        plt.axis('off')
        ax.set_title("ground truth")

        true_UV_only_image = np.concatenate([flat_yscale_frame, uvtruth_frame], axis=2)
        true_UV_only_RGB = YUV_to_RGB(true_UV_only_image).astype("uint8")
        ax = plt.subplot(n_images, n_columns, row_start + 1)
        plt.imshow(true_UV_only_RGB)
        plt.axis('off')
        ax.set_title("ground truth UV")

        for j, temperature in enumerate(cold_temperatures):
            pred_UVimage = cd.UVpixels_from_distribution(pred_frame,
                                                         temperature=temperature)
            predicted_YUV_image = np.concatenate([flat_yscale_frame, pred_UVimage], axis=2)
            predicted_RGB_image = YUV_to_RGB(predicted_YUV_image).astype("uint8")

            ax = plt.subplot(n_images, n_columns, row_start + j + 2)
            plt.imshow(predicted_RGB_image)
            plt.axis('off')
            ax.set_title("prediction (temperature = {})".format(temperature))

    plt.show()
    

In [None]:
def double_image_size(in_image):
    """Return a uint8 copy of a 3-D image with every pixel repeated as a 2x2 block.

    Args:
        in_image: array of shape (rows, cols, channels).

    Returns:
        uint8 array of shape (2 * rows, 2 * cols, channels).
    """
    rows, cols, channels = in_image.shape[0], in_image.shape[1], in_image.shape[2]
    doubled = np.zeros((rows * 2, cols * 2, channels), dtype=np.uint8)
    # Write the source image into each of the four interleaved sub-grids.
    for row_offset in (0, 1):
        for col_offset in (0, 1):
            doubled[row_offset::2, col_offset::2, :] = in_image
    return doubled

In [None]:
def plot_movies_mp4(image_array):
    """Display several frame sequences, stacked vertically, as one HTML5 video.

    Args:
        image_array: sequence of movies, each a sequence of frames accepted by
            ``imshow``. The animation has ``len(image_array[0])`` frames, so
            every movie must contain at least that many.
    """
    n_movies = len(image_array)
    # squeeze=False always returns a 2-D array of Axes. Without it, asking for
    # a single row returns a bare Axes object and indexing it fails.
    fig, axes_grid = plt.subplots(n_movies, figsize=(7.5, 12), squeeze=False)
    axes = axes_grid[:, 0]

    # One image artist per movie, initialised with that movie's first frame.
    artists = []
    for axis, movie in zip(axes, image_array):
        artists.append(axis.imshow(movie[0]))
        axis.axis('off')

    def animate(j):
        # Advance every movie to frame j.
        for artist, movie in zip(artists, image_array):
            artist.set_array(movie[j])
        return artists

    anim = animation.FuncAnimation(fig, animate, frames=len(image_array[0]))
    display(HTML(anim.to_html5_video()))
    
def compare_pred_videos(prediction_list, prediction_list_smoothened, yscale_list, uvtruth_list, cd,
                      temperature=1.):
    """Show ground truth, greyscale, raw and smoothened predictions as synced videos.

    For every index of ``prediction_list`` four double-sized uint8 frames are
    built (true colors, greyscale, raw prediction, smoothened prediction) and
    the four resulting sequences are handed to ``plot_movies_mp4``.

    Args:
        prediction_list: per-frame predictions.
        prediction_list_smoothened: per-frame smoothened predictions, indexed
            in parallel with ``prediction_list``.
        yscale_list: per-frame Y-scale arrays.
        uvtruth_list: per-frame UV ground-truth arrays.
        cd: object providing ``UVpixels_from_distribution``.
        temperature: value forwarded to ``cd.UVpixels_from_distribution``.
    """
    def zoomed_rgb(luma, chroma):
        # Stack Y and UV on the channel axis, convert to RGB, cast, upscale.
        yuv = np.concatenate([luma, chroma], axis=2)
        return double_image_size(YUV_to_RGB(yuv).astype("uint8"))

    truth_movie, grey_movie, raw_movie, smooth_movie = [], [], [], []
    for idx in range(len(prediction_list)):
        raw_distribution = prediction_list[idx]
        smooth_distribution = prediction_list_smoothened[idx]
        luma = yscale_list[idx]
        chroma_truth = uvtruth_list[idx]

        truth_movie.append(zoomed_rgb(luma, chroma_truth))

        # Greyscale frame: the Y channel repeated three times, scaled by 255.
        grey = (np.concatenate([luma] * 3, axis=2) * 255.).astype("uint8")
        grey_movie.append(double_image_size(grey))

        raw_chroma = cd.UVpixels_from_distribution(raw_distribution, temperature=temperature)
        raw_movie.append(zoomed_rgb(luma, raw_chroma))

        smooth_chroma = cd.UVpixels_from_distribution(smooth_distribution, temperature=temperature)
        smooth_movie.append(zoomed_rgb(luma, smooth_chroma))

    plot_movies_mp4([truth_movie, grey_movie, raw_movie, smooth_movie])
    
def dump_pred_videos(prediction_list, prediction_list_smoothened, yscale_list, uvtruth_list, cd,
                      temperature=1., output_dir="outputs/video"):
    """Write ground truth, greyscale, raw and smoothened frames to image files.

    For every index ``i`` of ``prediction_list`` four files are written into
    ``output_dir``: ``true_frame{i}.png``, ``greyscale_frame{i}.png``,
    ``predicted_frame{i}.png`` and ``predicted_smooth_frame{i}.png``.

    Args:
        prediction_list: per-frame predictions.
        prediction_list_smoothened: per-frame smoothened predictions, indexed
            in parallel with ``prediction_list``.
        yscale_list: per-frame Y-scale arrays.
        uvtruth_list: per-frame UV ground-truth arrays.
        cd: object providing ``UVpixels_from_distribution``.
        temperature: value forwarded to ``cd.UVpixels_from_distribution``.
        output_dir: destination folder; defaults to the previously hard-coded
            ``outputs/video`` and is created when missing.
    """
    # Create the destination up front so the first write cannot fail on a
    # missing folder.
    os.makedirs(output_dir, exist_ok=True)

    def frame_path(stem, index):
        return os.path.join(output_dir, "{}{}.png".format(stem, index))

    for i in tqdm_notebook(range(len(prediction_list))):
        pred_frame = prediction_list[i]
        pred_frame_smooth = prediction_list_smoothened[i]
        yscale_frame = yscale_list[i]
        uvtruth_frame = uvtruth_list[i]

        # NOTE(review): the helper is named ..._to_jpg but is handed .png
        # paths, as before -- confirm which format it actually writes.
        true_YUV_image = np.concatenate([yscale_frame, uvtruth_frame], axis=2)
        dump_YUV_image_to_jpg(true_YUV_image, frame_path("true_frame", i))

        # Greyscale frame: the Y channel repeated three times, scaled by 255.
        true_greyscale_image = np.concatenate([yscale_frame] * 3, axis=2)
        true_greyscale_image = (true_greyscale_image * 255.).astype("uint8")
        imsave(fname=frame_path("greyscale_frame", i), arr=true_greyscale_image, format='png')

        pred_UVimage = cd.UVpixels_from_distribution(pred_frame, temperature=temperature)
        predicted_YUV_image = np.concatenate([yscale_frame, pred_UVimage], axis=2)
        dump_YUV_image_to_jpg(predicted_YUV_image, frame_path("predicted_frame", i))

        pred_UVimage = cd.UVpixels_from_distribution(pred_frame_smooth, temperature=temperature)
        predicted_YUV_image = np.concatenate([yscale_frame, pred_UVimage], axis=2)
        dump_YUV_image_to_jpg(predicted_YUV_image, frame_path("predicted_smooth_frame", i))
    

In [None]:
# Run the model on every 10th path of `image_paths` and keep the cropped outputs.
prediction_list, yscale_list, uvtruth_list = output_prediction(image_paths[::10], cd)

In [None]:
# Display every 50th predicted frame, starting at index 15.
display_pred_array_UV(prediction_list[15::50], yscale_list[15::50], uvtruth_list[15::50], cd)

### Smoothing the predictions

In [None]:
# Ten exponentially increasing weights exp(0.2 * k), k = 0..9, normalized so
# that they sum to 1.
w = np.exp(.2 * np.arange(10))
w = w / sum(w)
filter_size = len(w)
# Plot the weights against offsets -filter_size .. -1; the largest weight sits
# at offset -1 (presumably the most recent frame -- confirm).
plt.plot(range(-filter_size, 0), w)

In [None]:
# Smooth the per-frame predictions using the weights `w`.
# (presumably a weighted combination of consecutive frames that returns
# filter_size - 1 fewer frames than it receives -- confirm in utils.video_utils)
prediction_list_smoothened = smoothen_frame_list(prediction_list, conv_weights=w)

In [None]:
#dump_pred_videos(prediction_list[filter_size-1:], prediction_list_smoothened, 
#                   yscale_list[filter_size-1:], uvtruth_list[filter_size-1:], cd)

In [None]:
# Compare raw and smoothened predictions on every 5th frame. The raw lists are
# offset by filter_size - 1 before subsampling.
# (presumably so that each raw frame lines up with its smoothened counterpart
# -- confirm against smoothen_frame_list)
compare_pred_videos(prediction_list[filter_size-1::5], prediction_list_smoothened[::5], 
                   yscale_list[filter_size-1::5], uvtruth_list[filter_size-1::5], cd)

## Histograms of predicted colors

In [None]:
from collections import Counter

In [None]:
# Count, over all predicted frames, how often each index is the arg-max along
# axis 2 of the prediction (one count per pixel).
cpred = Counter()
for predicted_frame in prediction_list:    
    cpred.update(np.argmax(predicted_frame, axis=2).ravel())

In [None]:
# Count, over all ground-truth UV frames, the values returned by
# cd.categorize (presumably one color-bin index per pixel -- confirm).
ctruth = Counter()
for uvtruth_frame in uvtruth_list:
    ctruth.update(cd.categorize(uvtruth_frame).ravel())

In [None]:
def plot_counter(c, title="", n_categories=33):
    """Bar-plot the relative frequency of each bin on the current axes.

    Args:
        c: mapping with a ``get`` method from bin index to count
            (for example a ``Counter``).
        title: title of the plot.
        n_categories: number of bins drawn, indices ``0 .. n_categories - 1``;
            defaults to 33, the value that used to be hard-coded.
    """
    counts = [c.get(k, 0) for k in range(n_categories)]
    # Compute the total once instead of once per bin. When every count is zero
    # fall back to 1 so an empty counter draws a flat histogram instead of
    # raising ZeroDivisionError.
    total = sum(counts) or 1
    values = [v / total for v in counts]
    indexes = np.arange(n_categories)
    width = .35

    plt.bar(indexes, values, width, color='g')
    plt.xticks(indexes, indexes)
    plt.ylabel("frequency", fontsize=12)
    plt.title(title, fontsize=17)

    
def plot_counters(c1, c2, title="", n_categories=33):
    """Bar-plot the relative bin frequencies of two counters side by side.

    The first counter is labelled 'Truth' and the second 'Prediction' in the
    legend. A later cell of this notebook redefines ``plot_counters`` with an
    additional background counter.

    Args:
        c1: mapping with a ``get`` method from bin index to count.
        c2: second mapping of the same kind.
        title: title of the plot.
        n_categories: number of bins drawn, indices ``0 .. n_categories - 1``;
            defaults to 33, the value that used to be hard-coded.
    """
    def frequencies(counter):
        counts = [counter.get(k, 0) for k in range(n_categories)]
        # Total computed once; falls back to 1 when all counts are zero so an
        # empty counter yields zeros instead of raising ZeroDivisionError.
        total = sum(counts) or 1
        return [v / total for v in counts]

    values1 = frequencies(c1)
    values2 = frequencies(c2)
    indexes = np.arange(n_categories)
    width = .35

    rects1 = plt.bar(indexes, values1, width)
    # Shift the second series by one bar width so the two sit next to each other.
    rects2 = plt.bar(indexes + width, values2, width)

    plt.legend((rects1[0], rects2[0]), ('Truth', 'Prediction'), fontsize=17)
    plt.xticks(indexes + width / 2, indexes)
    plt.ylabel("frequency", fontsize=12)
    plt.title(title, fontsize=17)


In [None]:
# Two stacked panels: on top the bin frequencies counted on the sample frames
# (ground truth vs. prediction), below the discretizer's own
# `category_frequency`.
plt.figure(figsize=(12, 9))
plt.subplot(211)
plot_counters(ctruth, cpred, title="Color bin frequencies from sample images")
plt.subplot(212)
plot_counter(cd.category_frequency, title="Color bin frequencies of the Color Discretizer")

Merge both histograms into a single figure for the poster.

In [None]:
def plot_counters(c1, c2, cback, title="", n_categories=33):
    """Bar-plot two counters side by side over a wide, translucent third one.

    The legend labels the first counter 'Truth', the second 'Prediction' and
    the background counter 'Dataset'.

    Args:
        c1: mapping with a ``get`` method from bin index to count.
        c2: second mapping of the same kind.
        cback: mapping of the same kind drawn behind the other two.
        title: title of the plot.
        n_categories: number of bins drawn, indices ``0 .. n_categories - 1``;
            defaults to 33, the value that used to be hard-coded.
    """
    def frequencies(counter):
        counts = [counter.get(k, 0) for k in range(n_categories)]
        # Total computed once; falls back to 1 when all counts are zero so an
        # empty counter yields zeros instead of raising ZeroDivisionError.
        total = sum(counts) or 1
        return [v / total for v in counts]

    values1 = frequencies(c1)
    values2 = frequencies(c2)
    background_values = frequencies(cback)
    indexes = np.arange(n_categories)

    # Background series first so that the two narrow series are drawn over it.
    background_width = .9
    back = plt.bar(indexes, background_values, background_width, color='g', alpha=.3)

    width = .35
    rects1 = plt.bar(indexes, values1, width)
    # Shift the second series by one bar width so the two sit next to each other.
    rects2 = plt.bar(indexes + width, values2, width)

    plt.legend((rects1[0], rects2[0], back[0]), ('Truth', 'Prediction', 'Dataset'), fontsize=17)
    plt.xticks(indexes + width / 2, indexes)
    plt.ylabel("frequency", fontsize=12)
    plt.title(title, fontsize=17)


In [None]:
# Single figure: truth and prediction frequencies drawn over the discretizer's
# `category_frequency` as background.
plt.figure(figsize=(12, 5))
plot_counters(ctruth, cpred, cd.category_frequency, title="Color bin frequencies from sample images")