In [None]:
# Load IPython's autoreload extension (same effect as the `%load_ext autoreload` magic)
get_ipython().run_line_magic('load_ext', 'autoreload')

In [None]:
%autoreload

import sys
sys.path.append('..')

import torch
import torch.utils.data as D

import os
import pprint
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

import color.models as models
import color.models.model_utils as model_utils
import color.models.predict_name as predict_name_models
import color.data.dataset as color_dataset
import color.training as training
import color.utils.utils as utils

In [None]:
# Directory of the trained run; training params, dataset and model weights are all
# loaded from here.
save_dir = '../trained_models/hp/predict_name_seq_hp2_stop_word/rnn_003'
# Fail fast with a readable message if the path is wrong, before doing anything else.
assert os.path.isdir(save_dir), 'Model directory not found: {}'.format(save_dir)
# Only the model class is needed here; the second return value is unused.
model_class, _ = models.get_model('predict_name_rnn')

In [None]:
# Load the saved training parameters (including per-epoch losses) and pretty-print them
training_params = training.load_training_params(save_dir)
pprint.PrettyPrinter().pprint(training_params)

In [None]:
# Plot learning curves: training vs. cross-validation loss per epoch.
train_losses = training_params['epoch_train_losses']
cv_losses = training_params['epoch_cv_losses']
# 1-based epoch numbers; both series share this x-axis, so they must be the same length.
epochs = np.arange(len(train_losses)) + 1

fig, ax = plt.subplots()
ax.plot(epochs, train_losses, label='Training Loss')
ax.plot(epochs, cv_losses, label='CV Loss')
ax.set_title('Learning curves')
ax.set_xlabel('Epochs')
ax.set_ylabel('Losses')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Re-create the dataset that was saved alongside the model in save_dir.
dataset = color_dataset.Dataset.load(save_dir)
dataset_size = len(dataset)
# The sampling cells further down draw batches from this dataset; fail early if it is empty.
assert dataset_size > 0, 'Loaded dataset is empty: {}'.format(save_dir)
print(dataset.params)
print('Dataset Size:', dataset_size)

In [None]:
# Load the saved constructor kwargs and the trained weights, then show the kwargs
# and how many entries the weights mapping holds (one per line).
model_weights, model_params = model_utils.load_model_params(save_dir)
print(model_params, len(model_weights), sep='\n')

In [None]:
# Initialize the model from the saved hyper-parameters, cast it to float64 (.double()),
# and copy the trained weights in (load_state_dict is strict by default, so key
# mismatches raise).
model = model_class(**model_params).double()
model.load_state_dict(model_weights)
# Switch to inference mode: layers such as dropout behave differently while training.
# predict_names may already do this internally; calling eval() here is harmless either way.
model.eval()
print(model)
print('Trainable Params:', utils.get_trainable_params(model))

In [None]:
# Convenience wrapper around the predict names function
def predict_names(rgb, num_names=5, max_len=3):
    '''
    Predict candidate names for one color using the trained `model` and `dataset`.

    rgb: the color to name, forwarded unchanged to predict_name_models.predict_names
    num_names: number of candidate names requested (default 5)
    max_len: forwarded as `max_len` (default 3); presumably the maximum number of
        words per name — confirm against predict_name_models.predict_names

    Returns whatever predict_name_models.predict_names returns (an iterable of
    predictions exposing `.words` and `.similarity`, as used by predict_and_plot).
    '''
    return predict_name_models.predict_names(
        model, dataset, rgb, num_names=num_names, max_len=max_len)

In [None]:
def predict_and_plot(color_rgbs, actual_color_names=None):
    '''
    Predict names for a list of colors.
    Plot each color as a swatch alongside its actual and predicted names.

    color_rgbs: sequence of colors; each one is passed both to matplotlib's
        `color=` argument and to `predict_names` (the sampling cells pass lists
        of floats).
    actual_color_names: optional sequence of ground-truth names, aligned
        index-for-index with `color_rgbs`.

    Raises ValueError if `actual_color_names` is given but its length differs
    from `color_rgbs`.
    '''
    if actual_color_names is not None and len(actual_color_names) != len(color_rgbs):
        raise ValueError('Got {} colors but {} actual names'.format(
            len(color_rgbs), len(actual_color_names)))

    # Arbitrary axis extent; the rectangle spans the whole axes, so only the colour shows.
    swatch_size = 100
    for i, color_rgb in enumerate(color_rgbs):
        fig, ax = plt.subplots(figsize=(2, 2))
        ax.set_xlim([0, swatch_size])
        ax.set_ylim([0, swatch_size])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.add_patch(patches.Rectangle((0, 0), swatch_size, swatch_size, color=color_rgb))
        plt.show()
        # One figure is created per color; close it so they do not accumulate.
        plt.close(fig)

        if actual_color_names is not None:
            print('Actual Color:', actual_color_names[i])

        color_names = predict_names(color_rgb)
        print('Predicted colors:')
        for pred in color_names:
            print('{}\tScore:{}'.format(' '.join(pred.words), pred.similarity))

In [None]:
# Plot some training samples
select = 3
train_loader = color_dataset.DataLoader(dataset.train_set, shuffle=True, seq_len_first=True)
# zip() with range() stops after `select` batches, or earlier if the split runs out
# (repeated next() calls would raise StopIteration instead).
batches = [batch for _, batch in zip(range(select), train_loader)]
# NOTE(review): assumes each batch holds a single sample — `.view(-1)` flattens the
# whole batch into one RGB list and only the first name is kept. Confirm loader batch size.
color_rgb_selected = [color_rgb.view(-1).tolist() for color_rgb, _, _ in batches]
color_names_selected = [color_name[0] for _, _, color_name in batches]
predict_and_plot(color_rgb_selected, color_names_selected)

In [None]:
# Plot some cross-validation samples (drawn from the CV split, not the training split)
select = 3
cv_loader = color_dataset.DataLoader(dataset.cv_set, shuffle=True, seq_len_first=True)
# zip() with range() stops after `select` batches, or earlier if the split runs out
# (repeated next() calls would raise StopIteration instead).
batches = [batch for _, batch in zip(range(select), cv_loader)]
# NOTE(review): assumes each batch holds a single sample — `.view(-1)` flattens the
# whole batch into one RGB list and only the first name is kept. Confirm loader batch size.
color_rgb_selected = [color_rgb.view(-1).tolist() for color_rgb, _, _ in batches]
color_names_selected = [color_name[0] for _, _, color_name in batches]
predict_and_plot(color_rgb_selected, color_names_selected)