In [None]:
%load_ext autoreload
%autoreload 2

import sys
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage import io, transform, color
from colorize import network, util

# Pick the first GPU when available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Let cuDNN auto-tune convolution algorithms (faster for fixed input sizes,
# but not bit-for-bit deterministic across runs).
torch.backends.cudnn.benchmark = True
print('Using device:', device)

# Seed the RNGs used in this notebook (np.random.choice for image sampling,
# torch e.g. for default layer weight initialisation) so that
# Restart & Run All picks the same images and starts from the same weights.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Each split file lists one image path per line; blank lines are skipped.
# Set TRAIN_LIMIT to a small int (e.g. 3) for a quick smoke-test run;
# None uses the full training list.
TRAIN_LIMIT = None

with open('data/train.txt', 'r') as f:
    train_images = [line for line in f.read().splitlines() if line.strip()][:TRAIN_LIMIT]

with open('data/test.txt', 'r') as f:
    test_images = [line for line in f.read().splitlines() if line.strip()]

# Show a few random training images next to their encode/decode round trip,
# to sanity-check util.soft_encode / util.decode.
N_PREVIEW = 3
# Clamp the sample size: np.random.choice(..., replace=False) raises if asked
# for more samples than the list holds.
for path in np.random.choice(train_images, min(N_PREVIEW, len(train_images)), replace=False):
    L, ab = util.imread(path)

    # Show encoded/decoded: resize ab to 56x56 (NOTE(review): presumably the
    # network's output resolution — confirm against network.Network), encode,
    # decode, then resize back to util.input_size for display.
    Y = transform.resize(ab, (56, 56))
    Z = util.soft_encode(Y)
    Y_decoded = util.decode(Z)
    Y_decoded = transform.resize(Y_decoded, util.input_size)

    util.side_by_side((L, ab, 'Original'), (L, Y_decoded, 'Encoded/decoded'))

In [None]:
# Train model
EPOCHS = 15  # keep in sync with the checkpoint name loaded in the next cell
net = network.Network()
# Dummy forward pass on a single 1-channel zero input of util.input_size, run
# only to print the network summary; no_grad avoids building an unused
# autograd graph.
with torch.no_grad():
    net(torch.zeros((1, 1, *util.input_size)), summary=True)
util.train(net, train_images, device, epochs=EPOCHS)

In [None]:
# Load trained model. map_location lets a checkpoint saved on GPU load on CPU.
MODEL_PATH = 'models/model_15_full.pth'
net = network.Network()
net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Move the weights onto the target device and switch to inference mode
# (affects dropout / batch-norm layers, if the network has any).
net = net.to(device).eval()

In [None]:
# Colorize a random sample of (at most N_COLORIZE) test images.
N_COLORIZE = 10
# Clamp the sample size: np.random.choice(..., replace=False) raises if asked
# for more samples than the list holds.
to_color = np.random.choice(test_images, min(N_COLORIZE, len(test_images)), replace=False)
util.colorize_images(net, to_color, device)