# Colorization
Class project - CS231N - Stanford University

Vincent Billaut  
Matthieu de Rochemonteix  
Marc Thibault  

See our GitHub [repo](https://github.com/vincentbillaut/all-colors-matter) for more details on the implementation.

## Imports

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

from matplotlib.pyplot import imread

####################################################
# Setting working directory to enable relative paths
import os
os.chdir("../")
####################################################

from models.coloringmodel import Config
from models.naive_convnet import NaiveConvColoringModel
from utils.dataset import Dataset
from utils.color_utils import RGB_to_YUV, YUV_to_RGB
from utils.data_utils import load_image_jpg_to_YUV
from utils.color_discretizer import ColorDiscretizer
from utils.data_augmentation import DataAugmenter

## Load an existing model and visualize its predictions and confidence maps

In [None]:
# Enter the name of the training-run output folder to restore; it is passed
# to model.load(output_folder) further down.
output_folder = "outputs/20180605_074014-ab2b/"

In [None]:
from models.unet import UNetColoringModel

In [None]:
# Rebuild the U-Net coloring model (and everything it depends on) from its
# JSON config so that a saved run can be restored into it with model.load.
config = Config("configs/config_unet_suncoast.json")
# Color discretizer capped at config.max_categories categories, trained via
# cd.train on config.cd_train_path. NOTE(review): the meaning of the 1000
# argument is not visible here (sample count?) -- confirm in ColorDiscretizer.
cd = ColorDiscretizer(max_categories=config.max_categories)
cd.train(config.cd_train_path, 1000)
da = DataAugmenter()
# Train/validation dataset wrapper, using the discretizer and augmenter above.
dataset = Dataset(config.train_path, config.val_path, cd, da)
model = UNetColoringModel(config, dataset)

In [None]:
# Restore the saved run from output_folder into the freshly built model.
model.load(output_folder)

In [None]:
# os.listdir returns entries in an arbitrary, filesystem-dependent order.
# Sort so that positional picks such as image_paths[70:75] or
# image_paths[1402] select the same images on every machine.
image_paths = [os.path.join(config.val_path, impath) for impath in sorted(os.listdir(config.val_path))]
image_paths_insample = [os.path.join(config.train_path, impath) for impath in sorted(os.listdir(config.train_path))]

In [None]:
%matplotlib inline

In [None]:
# Temperature handed to cd.UVpixels_from_distribution by the plotting helpers.
temperature = 1.0

def display_confidence(image_paths, cd):
    """Plot colorization results and per-pixel confidence maps, one row per image.

    Each row has five panels: the ground-truth image, the prediction obtained
    from the predicted category scores at the notebook-level ``temperature``,
    and three heat maps built from those scores: the highest score, the gap
    between the highest and second-highest score, and their ratio.

    Relies on the notebook globals ``model`` and ``temperature``.

    Args:
        image_paths: Paths of the images to colorize with
            ``model.pred_color_one_image``.
        cd: Color discretizer whose ``UVpixels_from_distribution`` maps the
            per-pixel category scores back to UV values.
    """
    n_images = len(image_paths)
    n_methods = 4
    n_panels = n_methods + 1  # ground truth + one panel per method
    plt.figure(figsize=(16, n_images * 6))

    for i, image_path in enumerate(image_paths):
        _, pred_image_categories, (im_yscale, im_uvscale, msk) = model.pred_color_one_image(image_path)

        # The mask marks the valid (non-padding) region. This assumes padding is
        # appended after the image along both axes -- TODO confirm. msk[:, 0]
        # runs along axis 0, so a fully valid first column means all
        # msk.shape[0] rows are valid. Otherwise the first zero ends the image.
        height = msk.shape[0] if msk[:, 0].mean() == 1. else np.argmin(msk[:, 0])
        width = msk.shape[1] if msk[0, :].mean() == 1. else np.argmin(msk[0, :])

        scores = pred_image_categories[0, :height, :width, :]
        y_channel = im_yscale[:height, :width]
        uv_channels = im_uvscale[:height, :width, :]

        first_panel = 1 + i * n_panels

        true_RGB_image = YUV_to_RGB(np.concatenate([y_channel, uv_channels], axis=2)).astype("uint8")
        ax = plt.subplot(n_images, n_panels, first_panel)
        plt.imshow(true_RGB_image)
        plt.axis('off')
        ax.set_title("ground truth")

        pred_UVimage = cd.UVpixels_from_distribution(scores, temperature=temperature)
        predicted_RGB_image = YUV_to_RGB(np.concatenate([y_channel, pred_UVimage], axis=2)).astype("uint8")
        ax = plt.subplot(n_images, n_panels, first_panel + 1)
        plt.imshow(predicted_RGB_image)
        plt.axis('off')
        ax.set_title("prediction (temperature = {})".format(temperature))

        # Best and second-best score per pixel (assumes >= 2 categories).
        # Sorting keeps tied maxima (gap 0, ratio 1) and works for any score
        # range, unlike masking out the maximum by subtracting a constant.
        sorted_scores = np.sort(scores, axis=2)
        best = sorted_scores[:, :, -1]
        second = sorted_scores[:, :, -2]
        # A zero second score gives inf/nan, which imshow masks out; silence
        # the corresponding warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = best / second

        confidence_maps = [
            ("max score", best),
            ("max score - second", best - second),
            ("max score / second", ratio),
        ]
        for offset, (title, confidence) in enumerate(confidence_maps, start=2):
            ax = plt.subplot(n_images, n_panels, first_panel + offset)
            plt.imshow(confidence, cmap=plt.cm.hot)
            plt.axis('off')
            ax.set_title(title)
            plt.colorbar()
    plt.show()
    

In [None]:
# Show ground truth, prediction and confidence maps for five validation images.
display_confidence(image_paths[70:75], cd)

In [None]:
# Directory holding exported figures; default output folder of export_confidence.
SAVE_PATH = "chosen_outputs/"

In [None]:
from PIL import Image

# Convert every PNG in SAVE_PATH to an RGB JPEG with the same base name,
# written next to the original.
for imname in os.listdir(SAVE_PATH):
    if imname.endswith(".png"):
        # Slice the suffix off: str.rstrip(".png") strips any trailing '.', 'p',
        # 'n' or 'g' characters ("ping.png" -> "pi"), corrupting the name.
        imname_strip = imname[:-len(".png")]
        # Context manager closes the file handle Image.open keeps open.
        with Image.open(os.path.join(SAVE_PATH, imname)) as im:
            rgb_im = im.convert('RGB')
        rgb_im.save(os.path.join(SAVE_PATH, imname_strip + ".jpg"))

In [None]:

def export_confidence(image_path, cd, save_path=SAVE_PATH):
    """Save a confidence-map figure for one image as ``<save_path>/<name>_maps.png``.

    The figure has four panels:
    - the ground-truth image;
    - the prediction at the notebook-level ``temperature``;
    - the per-pixel highest category score;
    - the ratio between the highest and second-highest score.

    It is written at 600 dpi, and ``save_path`` is created if it does not
    exist. Relies on the notebook globals ``model`` and ``temperature``.

    Args:
        image_path: Path of the image to colorize with
            ``model.pred_color_one_image``.
        cd: Color discretizer whose ``UVpixels_from_distribution`` maps the
            category scores back to UV values.
        save_path: Output directory (defaults to ``SAVE_PATH``).
    """
    n_methods = 3
    n_panels = n_methods + 1  # ground truth + one panel per method
    # basename/splitext keep dots inside the name ("a.b.jpg" -> "a.b") and use
    # the platform path separator, unlike splitting on "/" and ".".
    imname = os.path.splitext(os.path.basename(image_path))[0]
    print("saving {}".format(imname))

    fig = plt.figure(figsize=(20, 6))
    _, pred_image_categories, (im_yscale, im_uvscale, msk) = model.pred_color_one_image(image_path)

    # The mask marks the valid (non-padding) region. This assumes padding is
    # appended after the image along both axes -- TODO confirm. msk[:, 0] runs
    # along axis 0, so a fully valid first column means all msk.shape[0] rows
    # are valid. Otherwise the first zero ends the image.
    height = msk.shape[0] if msk[:, 0].mean() == 1. else np.argmin(msk[:, 0])
    width = msk.shape[1] if msk[0, :].mean() == 1. else np.argmin(msk[0, :])

    scores = pred_image_categories[0, :height, :width, :]
    y_channel = im_yscale[:height, :width]
    uv_channels = im_uvscale[:height, :width, :]

    true_RGB_image = YUV_to_RGB(np.concatenate([y_channel, uv_channels], axis=2)).astype("uint8")
    ax = fig.add_subplot(1, n_panels, 1)
    ax.imshow(true_RGB_image)
    ax.axis('off')
    ax.set_title("ground truth")

    pred_UVimage = cd.UVpixels_from_distribution(scores, temperature=temperature)
    predicted_RGB_image = YUV_to_RGB(np.concatenate([y_channel, pred_UVimage], axis=2)).astype("uint8")
    ax = fig.add_subplot(1, n_panels, 2)
    ax.imshow(predicted_RGB_image)
    ax.axis('off')
    ax.set_title("prediction (temperature = {})".format(temperature))

    # Best and second-best score per pixel (assumes >= 2 categories).
    # Sorting keeps tied maxima (ratio 1) and works for any score range,
    # unlike masking out the maximum by subtracting a constant.
    sorted_scores = np.sort(scores, axis=2)
    best = sorted_scores[:, :, -1]
    second = sorted_scores[:, :, -2]
    # A zero second score gives inf/nan, which imshow masks out; silence the
    # corresponding warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = best / second

    confidence_maps = [("max score", best), ("max score / second", ratio)]
    for panel, (title, confidence) in enumerate(confidence_maps, start=3):
        ax = fig.add_subplot(1, n_panels, panel)
        z = ax.imshow(confidence, cmap=plt.cm.hot)
        ax.axis('off')
        ax.set_title(title)
        fig.colorbar(z, ax=ax, fraction=0.0330, pad=0.01)

    if save_path:
        os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path, imname + "_maps.png"), dpi=600)


In [None]:
# Export the confidence-map figure for one validation image into SAVE_PATH.
export_confidence(image_paths[1402],cd)