In [None]:
%load_ext autoreload
%autoreload 2

import glob
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage import io, transform, color

from colorize import network, util, dataset

# Run on the first GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# cuDNN autotuning can speed up fixed-size inputs, but it may select different
# algorithms between runs, so GPU results are not guaranteed to be bit-exact
# even with the seeds below.
torch.backends.cudnn.benchmark = True

# Seed the RNGs this notebook draws from: numpy (random preview-image choice)
# and torch (DataLoader shuffling and, presumably, default weight init).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

print('Device:', device)

In [None]:
# Collect the image paths used for training.
# - `recursive=True` is required for `**` to match any number of
#   sub-directories; without it `**` behaves exactly like `*`.
# - Sorting makes the capped subset independent of filesystem listing order.
N_IMAGES = 1000
images = sorted(glob.glob("data/**/*.jpg", recursive=True))[:N_IMAGES]
assert images, "no .jpg files found under data/"

# Sanity-check the colour quantization: show a few random images next to their
# soft_encode -> decode round trip. `replace=False` would raise if fewer than
# 3 images exist, so cap the sample size.
n_preview = min(3, len(images))
for path in np.random.choice(images, n_preview, replace=False):
    fig = plt.figure()

    # Left panel: the (L, ab) pair exactly as returned by util.imread
    L, ab = util.imread(path)
    fig.add_subplot(1, 2, 1)
    util.imshow(L, ab)

    # Right panel: ab resized to 56x56, encoded, decoded, resized to 224x224
    Y = transform.resize(ab, (56, 56))
    Z = util.soft_encode(Y)
    Y_decoded = util.decode(Z)
    Y_decoded = transform.resize(Y_decoded, (224, 224))
    fig.add_subplot(1, 2, 2)
    util.imshow(L, Y_decoded)

In [None]:
# Training configuration
params = {'batch_size': 32,
          'shuffle': True,
          'num_workers': 6}
epochs = 5

data = dataset.Dataset(images)
dataloader = torch.utils.data.DataLoader(data, **params)

net = network.Network()
net.to(device)
# util.w is a module-level tensor that util uses internally (presumably the
# loss' class-rebalancing weights — TODO confirm); move it to the same device
# as the network outputs.
util.w = util.w.to(device)

optimizer = torch.optim.Adam(net.parameters())

# torch.save does not create missing directories; without this the first
# checkpoint would fail only after a full (expensive) epoch.
os.makedirs('models', exist_ok=True)

for epoch in range(epochs):
    # Explicitly enable training mode (matters for dropout / batch-norm layers,
    # if the network has any).
    net.train()
    running_loss = 0.0
    n_batches = 0
    for X, Z in dataloader:
        X, Z = X.to(device), Z.to(device)
        optimizer.zero_grad()

        Z_hat = net(X)
        loss = util.multinomial_cross_entropy_loss(Z_hat, Z)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Mean per-batch loss. Count batches explicitly instead of relying on a
    # leaked loop index, which is undefined (or stale) if no batch ran.
    mean_loss = running_loss / max(n_batches, 1)
    print(f'Epoch: {epoch + 1}/{epochs}, Loss: {mean_loss}')
    # Checkpoint names use the 0-based epoch index.
    torch.save(net.state_dict(), f'models/model_{epoch}.pth')

In [None]:
# Colorize a single batch and show every prediction.
# Inference must run in eval mode: in train mode, batch-norm layers (if any)
# would update their running statistics from this batch and dropout would stay
# active. no_grad() avoids building an autograd graph nobody will use.
net.eval()
with torch.no_grad():
    X, _ = next(iter(dataloader))
    X = X.to(device)
    Z_hat = net(X)

# Outputs produced under no_grad() do not require grad, so .numpy() is safe
# (replaces the deprecated `.data` access).
X = X.cpu().numpy()
Z_hat = Z_hat.cpu().numpy()

for i in range(len(X)):
    # `* 50 + 50` presumably maps the network's normalized input back to the
    # 0..100 lightness range — TODO confirm against dataset.Dataset.
    X_img = util.reshape(X[i:i+1], 3)[0] * 50 + 50
    Z_hat_img = util.reshape(Z_hat[i:i+1], 3)[0]
    Y_img = util.decode(Z_hat_img)
    Y_img = transform.resize(Y_img, (224, 224))
    plt.figure()
    util.imshow(X_img, Y_img)
    plt.show()