In [1]:
import pickle
import random
import glob
import os

import torch
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from image_colorization import *
from utils import *
from optim import *
from fusion_model import *

%matplotlib inline

In [2]:
# Global configuration. NOTE(review): `use_cuda` is not read by any cell in
# this notebook section.
use_cuda = False

# Seed every RNG the notebook draws from -- `random` (category/image sampling
# in PlacesDataset), numpy, and torch (DataLoader shuffling; presumably also
# the model's weight initialisation) -- so Restart & Run All reproduces the
# same data subset and training run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class PlacesDataset(data.Dataset):
    """Places-style image folder dataset for joint colorization/classification.

    Each item is ``(lightness, ab_label, category_id)``:

    * ``lightness`` -- Lab L channel minus 50, shape (1, 224, 224);
    * ``ab_label`` -- Lab a/b channels mapped by ``(ab + 128) / 256``,
      shape (2, 224, 224);
    * ``category_id`` -- integer id of the image's category
      (see ``name_id_map`` / ``id_name_map``).

    Args:
        root: directory whose immediate sub-directories are the category names.
        n: if given, keep at most ``n`` randomly sampled images per category.
        n_classes: if given, keep ``n_classes`` randomly sampled categories.
        index_path: pickle file mapping category name -> list of image paths.
            Defaults to ``'images_per_category'`` in the working directory.
    """

    def __init__(self, root, n=None, n_classes=None, index_path='images_per_category'):
        super(PlacesDataset, self).__init__()
        # Sort so category ids do not depend on the filesystem's listing
        # order, and ignore stray files that sit next to the category folders.
        self.categories = sorted(
            os.path.basename(path)
            for path in glob.glob(os.path.join(root, '*'))
            if os.path.isdir(path)
        )
        if n_classes is not None:
            self.categories = random.sample(self.categories, n_classes)
        # Store the number of classes actually kept: the raw argument is None
        # when every category is used, but callers size the classifier from it.
        self.n_classes = len(self.categories)

        # NOTE: unpickling can execute arbitrary code -- only load a trusted index.
        with open(index_path, 'rb') as index_file:
            all_images = pickle.load(index_file)
        self.images_per_category = {c: all_images[c] for c in self.categories}
        if n is not None:
            for c in self.categories:
                images = self.images_per_category[c]
                # min() keeps small categories whole instead of raising ValueError.
                self.images_per_category[c] = random.sample(images, min(n, len(images)))

        self.name_id_map = {cat: i for i, cat in enumerate(self.categories)}
        self.id_name_map = {i: cat for cat, i in self.name_id_map.items()}
        self.size = sum(len(images) for images in self.images_per_category.values())

    def category_from_path(self, path):
        """Return the id of the category named by ``path``'s parent folder."""
        folder = os.path.basename(os.path.split(path)[0])
        return self.name_id_map[folder]

    def load_process_image(self, image_path):
        """Load an image and split it into a lightness input and an ab target.

        Returns ``(lightness, ab)`` float tensors: the L channel minus 50 with a
        leading channel axis, and the a/b channels mapped by ``(ab + 128) / 256``.
        """
        # convert('RGB') lets grayscale / palette / CMYK images through
        # rgb2lab, which needs a trailing 3-channel axis; `with` closes the file.
        with Image.open(image_path) as image:
            rgb = np.array(image.convert('RGB').resize((224, 224)))
        lab = color.rgb2lab(rgb).transpose(2, 0, 1)
        lightness = lab[0, :, :] - 50
        ab = (lab[1:, :, :] + 128) / 256
        return torch.FloatTensor(lightness).unsqueeze(0), torch.FloatTensor(ab)

    def __getitem__(self, index):
        """Return ``(lightness, ab_label, category_id)`` for a flat ``index``.

        Items are laid out category by category in ``self.categories`` order.
        Negative indices count from the end, as for a list.
        """
        position = index + len(self) if index < 0 else index
        if not 0 <= position < len(self):
            raise IndexError('Index %s out of range for size %s' % (index, len(self)))
        for category in self.categories:
            images = self.images_per_category[category]
            if position < len(images):
                break
            position -= len(images)
        image_path = images[position]
        # The category is already known from the lookup above. Re-deriving it
        # from the path's parent folder only works when images sit directly
        # inside their category folder.
        lightness, ab_label = self.load_process_image(image_path)
        return lightness, ab_label, self.name_id_map[category]

    def __len__(self):
        return self.size

In [4]:
def train_model(model, alpha, optimizer, loader, num_epochs=10, show_every=20):
    """Jointly train ``model`` for colorization and scene classification.

    Each batch is scored with ``MSE(ab) + alpha * CrossEntropy(category)``.

    Args:
        model: called as ``model(input, input)``; must return
            ``(ab_prediction, class_scores)``.
        alpha: weight of the classification loss relative to the ab MSE.
        optimizer: optimizer over ``model``'s parameters.
        loader: iterable of ``(input, ab_label, category_label)`` batches.
        num_epochs: number of passes over ``loader``.
        show_every: print the running mean loss every this many batches;
            ``None`` prints only the per-epoch average.

    Returns:
        ``model`` itself, whose parameters were updated in place.
    """
    criterion_ab, criterion_class = nn.MSELoss(), nn.CrossEntropyLoss()
    # Make sure dropout / batch-norm layers (if any) are in training mode,
    # e.g. after an earlier inference call switched the model to eval().
    model.train()
    for epoch in range(num_epochs):
        print('Epoch %s' % epoch)
        print('=' * 10)

        running_loss = []
        for i, batch in enumerate(loader):
            input_l, label_ab, label_cat = [Variable(t) for t in batch]
            # The same lightness tensor feeds both of the model's inputs.
            output_ab, output_cat = model(input_l, input_l)

            optimizer.zero_grad()
            loss_ab = criterion_ab(output_ab, label_ab)
            loss_class = criterion_class(output_cat, label_cat)
            loss = loss_ab + alpha * loss_class

            loss.backward()
            optimizer.step()

            # Tensor.item() exists from PyTorch 0.4 on; `.data[0]` is the 0.3
            # idiom and raises IndexError on the 0-dim losses newer versions return.
            running_loss.append(loss.item() if hasattr(loss, 'item') else loss.data[0])
            if show_every is not None and i % show_every == 0:
                print('Iter %s: %s' % (i, np.mean(running_loss)))
        print('Average loss: %s' % (np.mean(running_loss)))

    return model

In [49]:
def predict(model, dset, index=None):
    """Colorize one dataset image and plot it beside the ground truth.

    Prints the predicted category name and shows two panels: the original
    image (left) and the image rebuilt from the predicted ab channels (right).

    Args:
        model: called as ``model(input, input)`` on a batch of one; must return
            ``(ab_prediction, class_scores)``.
        dset: dataset whose items are ``(lightness, ab_label, category_id)``
            and which has an ``id_name_map`` from category id to name.
        index: dataset index to show; a random one when ``None``.
    """
    if index is None:
        index = random.randrange(len(dset))
    lightness, label, category = dset[index]
    l_channel = lightness.numpy() + 50  # undo the -50 shift applied to the input

    # Run inference in eval mode (affects dropout / batch-norm layers, if any)
    # and restore the caller's train/eval mode afterwards, even on error.
    was_training = model.training
    model.eval()
    try:
        batch = Variable(lightness.unsqueeze(0))
        out_ab, out_cat = model(batch, batch)
    finally:
        model.train(was_training)
    out_ab, out_cat = out_ab.squeeze(0), out_cat.squeeze(0)
    _, pred_id = torch.max(out_cat, 0)
    # `.item()` on PyTorch >= 0.4 (0-dim result); `.data[0]` on older versions.
    pred_id = pred_id.item() if hasattr(pred_id, 'item') else pred_id.data[0]
    print(dset.id_name_map[pred_id])

    # Undo the (ab + 128) / 256 target scaling to recover Lab a/b values.
    actual_ab = label.numpy() * 256 - 128
    actual = np.concatenate((l_channel, actual_ab), axis=0).transpose(1, 2, 0)
    actual = color.lab2rgb(actual.astype(np.float64))

    # Clip predictions to [0, 1] before undoing the same scaling.
    pred_ab = out_ab.data.numpy().clip(0, 1) * 256 - 128
    pred = np.concatenate((l_channel, pred_ab), axis=0).transpose(1, 2, 0)
    pred = color.lab2rgb(pred.astype(np.float64))

    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(actual)
    ax1.set_title('actual: %s' % dset.id_name_map[category])
    ax2.imshow(pred)
    ax2.set_title('predicted: %s' % dset.id_name_map[pred_id])
    for ax in (ax1, ax2):
        ax.axis('off')

In [60]:
# Experiment subset: 5 randomly chosen categories with 100 randomly sampled
# images each, shuffled into batches of 32.
dset = PlacesDataset(
    root='data/data/vision/torralba/deeplearning/images256',
    n=100,
    n_classes=5,
)
loader = data.DataLoader(dset, batch_size=32, shuffle=True)

In [61]:
model = FusionColorizer(dset.n_classes)
optimizer = optim.Adadelta(model.parameters())
# Weight of the classification loss relative to the colorization MSE (it
# multiplies the cross-entropy term inside train_model). Written as a float
# division: under Python 2 `1/300` is integer division and evaluates to 0,
# which would silently disable the classification loss.
alpha = 1.0 / 300

In [None]:
# Long run: 100 epochs; with show_every=None only per-epoch averages are printed.
model = train_model(
    model=model,
    alpha=alpha,
    optimizer=optimizer,
    loader=loader,
    num_epochs=100,
    show_every=None,
)

Epoch 0
Average loss: 0.0167534859502
Epoch 1
Average loss: 0.010117443162
Epoch 2
Average loss: 0.00934561429312
Epoch 3
Average loss: 0.00994461291702
Epoch 4
Average loss: 0.0132441690657
Epoch 5


In [None]:
# Show one randomly chosen image: ground truth vs. model colorization.
predict(model=model, dset=dset)