In [1]:
import random
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable

import torchvision.transforms as transforms
import torchvision.models as models

import glob
from PIL import Image
from skimage import color
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import zoom

from colormath.color_objects import AdobeRGBColor, sRGBColor, LabColor
from colormath.color_conversions import convert_color

from utils import *
from image_colorization import *

%matplotlib inline

In [2]:
# Notebook-wide switches.
# `classification` presumably selects between the *_class and *_reg helpers
# defined further down — TODO confirm how it is meant to be consumed.
classification = False
# `use_cuda` is not read by any cell in this notebook; presumably a GPU toggle.
use_cuda = False

In [3]:
# Gather the images the dataset below is built from. `get_images` is not defined
# in this notebook (presumably star-imported from `utils` or `image_colorization`);
# n=100 presumably caps how many images are returned — TODO confirm.
image_list = get_images(n=100)

In [4]:
def _load_pickle(path):
    """Return the object pickled at `path`, closing the file once it is read."""
    # NOTE: pickle.load can execute arbitrary code — only point this at cache
    # files you generated yourself.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Cached lookups produced elsewhere; presumably ab-bin <-> category-id mappings
# and a per-bin neighbour table (TODO confirm against the script that wrote them).
ab2cat = _load_pickle('cache/ab2cat_10.pkl')
cat2ab = _load_pickle('cache/cat2ab_10.pkl')
nearest_neighbors = _load_pickle('cache/nearest_neighbors.pkl')
class_weights = np.loadtxt('cache/class_weights.txt')

# Number of entries in ab2cat; passed to the model as its class count.
n_spaces = len(ab2cat)

In [5]:
def preprocessor(label):
    """Pass `label` to get_smoothed_label together with the cached lookup tables."""
    return get_smoothed_label(label, nearest_neighbors, ab2cat)


# Dataset over the collected images, with `preprocessor` handed in as the
# label transform; batches of 4 are drawn in shuffled order.
dset = ColorizationDataset(image_list, preprocessor)
loader = data.DataLoader(dataset=dset, batch_size=4, shuffle=True)

In [None]:
def train_model_class(model, optimizer, num_epochs=10, show_every=20,
                      data_loader=None, weight_table=None):
    """Train `model` with a class-weighted soft-label cross-entropy loss.

    For each batch the loss is the sum over all pixels of
    ``weight[argmax_c labels] * sum_c(-log_softmax(output)[c] * labels[c])``,
    where the softmax and the argmax are both taken over dimension 1.

    Args:
        model: callable mapping a batch of inputs to per-class scores with the
            class axis at dimension 1.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        num_epochs: number of full passes over the batches.
        show_every: print the running mean loss every `show_every` iterations;
            ``None`` disables the per-iteration print.
        data_loader: iterable of ``(input, labels)`` batches. Defaults to the
            notebook-level ``loader``.
        weight_table: 1-D float tensor indexed by class id. Defaults to the
            notebook-level ``weights``.

    Returns:
        The same `model` object, updated in place.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    batches = loader if data_loader is None else data_loader
    class_weight_table = weights if weight_table is None else weight_table

    for epoch in range(num_epochs):
        print('Epoch %s' % epoch)
        print('=' * 10)

        running_loss = []
        for step, (inputs, labels) in enumerate(batches):
            inputs, labels = Variable(inputs), Variable(labels)
            output = model(inputs)

            optimizer.zero_grad()
            # Explicit class axis: relying on the implicit dim is deprecated.
            neg_log_probs = -F.log_softmax(output, dim=1)
            loss_per_pixel = torch.sum(neg_log_probs * labels, 1)

            # Hard label per pixel = most probable class of the soft label.
            _, true_labels = torch.max(labels.data, 1)
            true_labels = true_labels.type(torch.LongTensor)
            n, rows, cols = true_labels.size()

            # Look up one weight per pixel from its hard label.
            px_weights = torch.index_select(class_weight_table, 0, true_labels.view(-1))
            px_weights = px_weights.view(n, rows, cols)

            loss = torch.sum(loss_per_pixel * Variable(px_weights))

            loss.backward()
            optimizer.step()

            # `.item()` replaces `loss.data[0]`, which raises on 0-dim tensors
            # in PyTorch >= 0.4.
            running_loss.append(loss.item())
            if show_every is not None and step % show_every == 0:
                print('Iter %s: %s' % (step, np.mean(running_loss)))
        print('Average loss: %s' % (np.mean(running_loss)))

    return model

In [None]:
def train_model_reg(model, optimizer, num_epochs=10, show_every=20, data_loader=None):
    """Train `model` as a regressor using mean-squared error against the labels.

    Args:
        model: callable mapping a batch of inputs to predictions shaped like
            the labels.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        num_epochs: number of full passes over the batches.
        show_every: print the running mean loss every `show_every` iterations;
            ``None`` disables the per-iteration print.
        data_loader: iterable of ``(input, labels)`` batches. Defaults to the
            notebook-level ``loader``.

    Returns:
        The same `model` object, updated in place.
    """
    # Fall back to the notebook global so existing two-argument calls keep working.
    batches = loader if data_loader is None else data_loader

    for epoch in range(num_epochs):
        print('Epoch %s' % epoch)
        print('=' * 10)

        running_loss = []
        for step, (inputs, labels) in enumerate(batches):
            inputs, labels = Variable(inputs), Variable(labels)
            output = model(inputs)

            optimizer.zero_grad()
            loss = F.mse_loss(output, labels)

            loss.backward()
            optimizer.step()

            # `.item()` replaces `loss.data[0]`, which raises on 0-dim tensors
            # in PyTorch >= 0.4.
            running_loss.append(loss.item())
            if show_every is not None and step % show_every == 0:
                print('Iter %s: %s' % (step, np.mean(running_loss)))
        print('Average loss: %s' % (np.mean(running_loss)))

    return model

In [None]:
# Per-class weights as a float tensor; train_model_class reads this through
# the global name `weights`.
weights = torch.FloatTensor(class_weights)

# Network with one output class per entry in ab2cat, passed through
# `weights_init` via `apply`.
model = ImageColorizer(n_classes=n_spaces)
model.apply(weights_init)

# RMSprop with its default hyper-parameters over all model parameters.
optimizer = optim.RMSprop(params=model.parameters())

In [None]:
# NOTE(review): this notebook defines train_model_class / train_model_reg but no
# `train_model`, so on a fresh kernel the call below would only work if a
# star-import happened to provide one. Dispatch on the `classification` flag
# instead; drop this line if an imported `train_model` was really intended.
train_model = train_model_class if classification else train_model_reg
model = train_model(model, optimizer, num_epochs=10, show_every=None)

In [None]:
def predict_class(model):
    """Colorize one random sample of `dset` with a classification model and plot it.

    The per-pixel arg-max over dimension 0 of the model output picks a category,
    `cat2ab` turns it into an (a, b) pair (each shifted by 5), and the result is
    upsampled 4x along both spatial axes before being stacked under the L
    channel and converted from Lab to RGB. Besides the grayscale and predicted
    images, a histogram of the predicted categories is drawn.

    Args:
        model: callable returning per-category scores for a batch of one input.
    """
    sample_idx = random.choice(np.arange(len(dset)))
    gray, _ = dset[sample_idx]
    # Shift by 50 — presumably undoing the dataset's centring of L (TODO confirm).
    lightness = gray.numpy() + 50

    batch = Variable(gray.unsqueeze(0))
    scores = model(batch).squeeze(0)
    _, cat_map = torch.max(scores, 0)
    cat_map = cat_map.data.numpy()
    plt.hist(cat_map.reshape(-1))

    # Fill the two ab planes pixel by pixel from the category lookup.
    ab_pred = np.zeros((2,) + cat_map.shape)
    for row, col in np.ndindex(ab_pred.shape[1], ab_pred.shape[2]):
        a, b = cat2ab[cat_map[row, col]]
        ab_pred[:, row, col] = [a + 5, b + 5]
    ab_pred = zoom(ab_pred, (1, 4, 4))

    lab = np.concatenate((lightness, ab_pred), axis=0)
    lab = lab.transpose(1, 2, 0).clip(-128, 128)
    rgb = color.lab2rgb(lab.astype(np.float64))

    panels = (
        ('Grayscale', lightness.squeeze(0), {'cmap': 'gray'}),
        ('Predicted', rgb, {}),
    )
    for title, image, options in panels:
        plt.figure()
        plt.title(title)
        plt.imshow(image, **options)


In [None]:
def predict_reg(model):
    """Colorize one random sample of `dset` with a regression model and plot it.

    The model output is used directly as the colour planes: it is upsampled 4x
    along both spatial axes, stacked under the L channel, clipped to
    [-128, 128] and converted from Lab to RGB. Two figures are drawn, the
    grayscale input and the prediction.

    Args:
        model: callable returning the colour planes for a batch of one input.
    """
    sample_idx = random.choice(np.arange(len(dset)))
    gray, _ = dset[sample_idx]
    # Shift by 50 — presumably undoing the dataset's centring of L (TODO confirm).
    lightness = gray.numpy() + 50

    batch = Variable(gray.unsqueeze(0))
    ab_pred = model(batch).squeeze(0).data.numpy()
    ab_pred = zoom(ab_pred, (1, 4, 4))

    lab = np.concatenate((lightness, ab_pred), axis=0)
    lab = lab.transpose(1, 2, 0).clip(-128, 128)
    rgb = color.lab2rgb(lab.astype(np.float64))

    panels = (
        ('Grayscale', lightness.squeeze(0), {'cmap': 'gray'}),
        ('Predicted', rgb, {}),
    )
    for title, image, options in panels:
        plt.figure()
        plt.title(title)
        plt.imshow(image, **options)


In [None]:
# NOTE(review): only predict_class / predict_reg are defined in this notebook, so
# a bare `predict` would have to come from a star-import. Pick the variant that
# matches the `classification` flag; drop this line if an imported `predict`
# was really intended.
predict = predict_class if classification else predict_reg
predict(model)