In [1]:
# Enable IPython's autoreload extension so that edited modules can be
# re-imported without restarting the kernel.
%load_ext autoreload

In [2]:
# Ask the autoreload extension to reload changed modules now, so that
# re-running this cell picks up local edits to the modules imported below.
# NOTE(review): bare `%autoreload` is a one-shot reload; `%autoreload 2` would
# reload before every cell -- confirm which behaviour is wanted.
%autoreload

import sys
# Make the parent directory importable (the `color` package is imported from
# there below). Guarded so that re-running this cell does not keep appending
# duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

import torch
import torch.utils.data as D

import os
import pprint
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

import color.models.predict_name as pred_name_models
import color.data.dataset as color_dataset
import color.training as training
import color.utils.utils as utils

In [3]:
# Relative path of the saved run that the rest of the notebook inspects.
save_dir = '../trained_models/hp/predict_name_rnn_r3_grid/rnn_base'
# Fail early, and say which path was looked up, if the run directory is missing.
assert os.path.isdir(save_dir), 'Saved model directory not found: {!r}'.format(save_dir)

In [4]:
# Saved training configuration together with the per-epoch loss and
# duration history recorded for this run.
training_params = training.load_training_params(save_dir)
print(pprint.pformat(training_params))

{'curr_epoch': 1,
 'do_cv': True,
 'draw_plots': True,
 'epoch_cv_losses': [0.16623204999265417,
                     0.16565741479924295,
                     0.1629130101836888,
                     0.16191619265395985,
                     0.16062822595106818,
                     0.16160090623703677,
                     0.15907331483553996,
                     0.15942843614426333,
                     0.15890084325739767,
                     0.15876268707545457],
 'epoch_durations': [41.76962757110596,
                     40.77054166793823,
                     40.919513463974,
                     41.18310570716858,
                     42.378578662872314,
                     42.06674671173096,
                     41.40788984298706,
                     41.12403988838196,
                     41.59677815437317,
                     41.250587701797485],
 'epoch_train_losses': [0.1828087883063587,
                        0.16563212333120814,
                        0.163517298

In [5]:
# Saved dataset configuration plus the color names assigned to each split.
dataset_params, name_splits = color_dataset.load_dataset_params(save_dir)
color_names_train, color_names_cv, color_names_test = name_splits
pprint.pprint(dataset_params)
# Sizes of the train / cross-validation / test partitions, in that order.
print(*(len(split) for split in (color_names_train, color_names_cv, color_names_test)))

{'add_stop_word': True,
 'batch_size': 1,
 'cv_split': 0.1,
 'dataset': 'big',
 'emb_len': 300,
 'max_words': None,
 'normalize_rgb': True,
 'num_workers': 0,
 'pad_len': None,
 'test_split': 0,
 'use_cuda': True,
 'var_seq_len': True}
16273 1808 0


In [7]:
# Saved network weights and the keyword arguments used to build the model.
loaded_model_state = pred_name_models.load_model_params(save_dir)
model_weights, model_params = loaded_model_state
# Show the constructor arguments, then the number of saved weight entries.
print(model_params, len(model_weights), sep='\n')

{'name': 'rnn_base', 'emb_dim': 300, 'color_dim': 3, 'lr': 0.001, 'momentum': 0.91, 'weight_decay': 1e-05, 'lr_decay': (1, 0.9), 'hidden_dim': 300, 'num_layers': 3, 'dropout': 0, 'nonlinearity': 'relu'}
16


In [8]:
# Rebuild the network from its saved constructor arguments and restore the
# pre-trained weights.
# NOTE(review): this run is an RNN ('rnn_base'); for an LSTM run, construct
# pred_name_models.NamePredictorLSTM instead -- confirm against the saved params.
model = pred_name_models.NamePredictorRNN(**model_params)
model.load_state_dict(model_weights)
# The cells shown here only run inference, so put the module in evaluation
# mode to disable training-only behaviour such as dropout.
model.eval()
print(model)
print('Trainable Params:', utils.get_trainable_params(model))

NamePredictorRNN(
  (rgb2emb): Linear(in_features=3, out_features=300, bias=True)
  (rnn): RNN(300, 300, num_layers=3)
  (hidden2emb): Linear(in_features=300, out_features=300, bias=True)
)
Trainable Params: 633300


In [9]:
# Re-create the train / CV partitions as (name, rgb) rows from the saved
# lists of color names.
dataset = color_dataset.Dataset(**dataset_params)
# Each dataset item unpacks as (rgb, <unused>, name); index the RGB values by name.
color_name_dict = {color_name: color_rgb for color_rgb, _, color_name in dataset}
print(len(color_name_dict))


def _names_to_pairs(color_names):
    """Return an object array with one (name, rgb-as-list) row per color name."""
    pairs = [(name, color_name_dict[name].tolist()) for name in color_names]
    # dtype=object is required: every row mixes a string with a list, and newer
    # NumPy releases refuse to build such ragged arrays implicitly.
    return np.array(pairs, dtype=object)


train_set = _names_to_pairs(color_names_train)
cv_set = _names_to_pairs(color_names_cv)
print(train_set[0],len(train_set), len(cv_set))

Loading colors dataset
Loading embeddings
Splitting dataset
18080
['candyman' list([0.99609375, 0.6171875, 0.4609375])] 16273 1808


In [10]:
def predict_name(model, dataset, rgb, max_len=3, stop_word=False):
    """Predict a color name for an RGB triple by greedy nearest-word decoding.

    The model's name generator emits one predicted word embedding per step.
    Each prediction is matched to the vocabulary word whose embedding has the
    highest cosine similarity, and that word's embedding is sent back into the
    generator to obtain the next prediction.

    Args:
        model: Object exposing ``gen_name(rgb)``, a generator that yields a
            predicted embedding and receives the previously chosen word's
            embedding (shaped ``(1, 1, -1)``) through ``send``.
        dataset: Object exposing ``embeddings`` (one row per vocabulary word)
            and ``vocab`` (row index -> word).
        rgb: Three channel values, e.g. ``(255, 255, 255)``; they are divided
            by 256 before being passed to the model.
        max_len: Maximum number of words to generate.
        stop_word: When True, generation ends as soon as the nearest word is
            ``'STOP_WORD'``; a message is printed and the stop word itself is
            not included in the result.

    Returns:
        A ``(name, sims)`` tuple: the list of predicted words and the list of
        cosine similarities (floats) between each prediction and its word.
    """
    # Floor for the cosine denominator: a zero-norm embedding (or prediction)
    # then scores 0 instead of NaN, which could otherwise win the argmax.
    eps = 1e-8

    rgb = torch.FloatTensor(rgb).view(1, 3)
    rgb = rgb / 256

    embs = torch.FloatTensor(dataset.embeddings)
    embs_mag = torch.sqrt(torch.sum(embs * embs, dim=1)).reshape(-1)

    name = []
    sims = []

    with torch.no_grad():
        name_generator = model.gen_name(rgb)
        # The first send() to a fresh generator must be None.
        next_emb = None
        for i in range(max_len):
            emb_pred = name_generator.send(next_emb).view(-1, 1)
            emb_pred_mag = torch.sqrt(torch.sum(emb_pred * emb_pred))
            emb_dot = torch.mm(embs, emb_pred).view(-1)
            # Cosine similarity between the prediction and every vocabulary word.
            embs_sim = emb_dot / (embs_mag * emb_pred_mag).clamp(min=eps)

            nearest_idx = int(torch.argmax(embs_sim))
            word, sim = dataset.vocab[nearest_idx], float(embs_sim[nearest_idx])
            if stop_word and word == 'STOP_WORD':
                print('Stop Word', i+1)
                break
            name.append(word)
            sims.append(sim)
            # Feed the chosen word's embedding back in as the next input.
            next_emb = dataset.embeddings[nearest_idx]
            next_emb = torch.FloatTensor(next_emb).view(1, 1, -1)

    return name, sims

In [25]:
# Sanity check: ask the model to name pure white, stopping at the stop word.
white_rgb = (255, 255, 255)
predict_name(model, dataset, white_rgb, stop_word=True)

Stop Word 3


(['pink', 'pink'], [0.6771054863929749, 0.700982928276062])