In [None]:
%load_ext autoreload
%autoreload 2

import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML
from skimage import io, transform, color
from tqdm import tqdm

from colorize import network, util, dataset

# Pick the first CUDA GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# cuDNN autotuning speeds up convolution workloads with fixed input sizes, but
# it means results are not bit-exact across runs even with the seeds below.
torch.backends.cudnn.benchmark = True

# Seed every RNG that the training/data code might draw from, so that a
# Restart-&-Run-All produces comparable results.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

print('Using device:', device)

In [None]:
# Load the videos' image paths.
# Each row of ranges.npy is unpacked into three fields; the first is ignored and
# the other two are passed to util.video_paths as the start/end bounds.
ranges = np.load('resources/ranges.npy')
video_paths = [
    util.video_paths('data/dog', range_start, range_end)
    for _unused, range_start, range_end in ranges
]

In [None]:
# Train model
# Run settings, kept here as named constants instead of magic values in the calls.
CNN_WEIGHTS = 'models/cnn_30_full.pth'
RNN_HIDDEN_SIZE = 128
EPOCHS = 1
SEQ_LEN = 16

# Pretrained CNN. map_location only controls where the loaded tensors land;
# load_state_dict copies them into the module's existing parameters (created
# on the CPU by default), so the model itself must be moved to `device`
# explicitly. Moving it again inside util.train_rnn, if that happens, is harmless.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints;
# on torch >= 1.13 consider weights_only=True.
cnn = network.CNN()
cnn.load_state_dict(torch.load(CNN_WEIGHTS, map_location=device))
cnn.to(device)

rnn = network.RNN(hidden_size=RNN_HIDDEN_SIZE)
rnn.to(device)

util.train_rnn(rnn, cnn, video_paths, device, epochs=EPOCHS, seq_len=SEQ_LEN)

In [None]:
# Colorize a video
# Inference only. eval() turns off training-time behaviour of any dropout or
# batch-norm layers the networks may contain, and no_grad() skips building
# autograd graphs, which saves memory. The training cell recreates both models,
# so re-running training afterwards starts from fresh modules in train mode.
rnn.eval()
cnn.eval()
with torch.no_grad():
    gt, colorized = util.colorize_video(rnn, cnn, video_paths[0], device)

In [None]:
# Show original video
# Render the frames in `gt` as an inline HTML5 video at 30 fps.
original_animation = util.animate(gt, fps=30)
HTML(original_animation.to_html5_video())

In [None]:
# Show colorized video
# Render the colorized frames returned by util.colorize_video as an inline HTML5 video at 30 fps.
colorized_animation = util.animate(colorized, fps=30)
HTML(colorized_animation.to_html5_video())