In [35]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import sys
sys.path.append('../src/')
sys.path.append('../problems/')
sys.path.append('../scripts/')
from evaluation import Evaluation, EvaluationDataset
import models
import generation_utils
import tokenizer
import data_utils
import metrics_utils
from utils import get_best_checkpoint
from train_model import get_loaders
from tqdm.auto import tqdm
import itertools
from optimization_utils import test_on_loader

In [36]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
# Stretch the notebook container to the full browser width.
# display/HTML are imported from the public IPython.display module; importing them from the
# internal IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [38]:
# Raise pandas' display limits: per-cell text width and number of rows rendered.
pd.options.display.max_colwidth = 999
pd.options.display.max_rows = 9999

In [39]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [44]:
base_path = '../models/evaluation/v3/'

In [45]:
os.listdir(base_path)

['checkpoints', 'config.yaml', 'loss_hist.csv']

In [46]:
# get_best_checkpoint hands back the loaded checkpoint itself (it is indexed with ['args']
# and ['model_state_dict'] in later cells), so no separate torch.load call is needed here.
checkpoint = get_best_checkpoint(base_path)

Loading model at ../models/evaluation/v3/checkpoints/900000_0.3911.pt


In [47]:
args = checkpoint['args']

In [48]:
# Prefix the train/test/oos data paths from the checkpoint config with '../' (presumably
# because they were saved relative to the project root and this notebook runs one level below).
# `args` is the same dict object held by `checkpoint`, so the startswith guard keeps this cell
# idempotent: re-running it must not prepend the prefix a second time.
# NOTE(review): assumes the stored paths never start with '../' themselves — confirm.
for key in ['train', 'test', 'oos']:
    path_key = f'{key}_path'
    if not args['data'][path_key].startswith('../'):
        args['data'][path_key] = '../' + args['data'][path_key]

In [49]:
problem = Evaluation(args)

In [50]:
train_loader, test_loader, oos_loader = get_loaders(problem)

Loading data...


In [51]:
t = problem.get_tokenizer()

In [52]:
args.keys()

dict_keys(['data', 'problem_type', 'model_args', 'optimizer', 'scheduler', 'loader', 'io', 'metrics', 'verbose', 'resume_training', 'overwrite', 'tokenizer'])

In [53]:
# Rebuild the architecture from the tokenizer settings and model hyper-parameters stored in
# the checkpoint's args, then restore the saved weights.
model = models.Seq2SeqModel(n_tokens = args['tokenizer']['n_tokens'], 
                          pad_token_id = args['tokenizer']['pad_token_id'],
                          **args['model_args'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Switch to inference mode (dropout disabled). As the cell's last expression, the returned
# module is also what prints the layer tree below.
model.eval()

Seq2SeqModel(
  (src_embedding): TransformerEmbedding(
    (embedding): Embedding(39, 128)
  )
  (tgt_embedding): TransformerEmbedding(
    (embedding): Embedding(39, 128)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.05, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiHeadRelativeAttention(
            (w_q): Linear(in_features=128, out_features=128, bias=False)
            (w_k): Linear(in_features=128, out_features=128, bias=False)
            (w_v): Linear(in_features=128, out_features=128, bias=False)
            (out_proj): Linear(in_features=128, out_features=128, bias=False)
            (pe_mod): PositionalEncoding(
              (dropout): Dropout(p=0.05, inplace=False)
            )
            (w_k_pos): Linear(in_features=128, out_features=128, bias=False)
          )
          (linear1): Linear(in_features=128

In [54]:
# Total number of scalar parameters in the model; numel() gives each tensor's element
# count directly, so there is no need to multiply out the sizes by hand.
sum(p.numel() for p in model.parameters())

1215655

In [55]:
len(test_loader.dataset)

10000

In [65]:
test_df, test_metrics = problem.compute_metrics(model, device, test_loader, save=False, n_beams=64)

TODO FIX ME!!! FIGURE OUT WHAT ARGS TO CALL WITH!!!


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

In [66]:
test_df

Unnamed: 0,expression,value,model_input,beam_idx,log_prob,output_toks,pred,correct_value,log_prob_decile
0,60*(248—236),720,"[[SOS], 2, 0, *, (, 8, 8, —, 7, 26, ), [EOS]]",0,-0.004322,"[[SOS], 24, 0, [EOS], [PAD]]",720.0,True,"(-4.215, 0.0]"
1,60*(248—236),720,"[[SOS], 2, 0, *, (, 8, 8, —, 7, 26, ), [EOS]]",1,-5.467641,"[[SOS], 28, 0, [EOS], [PAD]]",840.0,False,"(-5.645, -4.215]"
2,60*(248—236),720,"[[SOS], 2, 0, *, (, 8, 8, —, 7, 26, ), [EOS]]",2,-9.214662,"[[SOS], 24, 0, [EOS], 0]",720.0,True,"(-9.964, -9.213]"
3,60*(248—236),720,"[[SOS], 2, 0, *, (, 8, 8, —, 7, 26, ), [EOS]]",3,-9.214662,"[[SOS], 24, 0, [EOS], 1]",720.0,True,"(-9.964, -9.213]"
4,60*(248—236),720,"[[SOS], 2, 0, *, (, 8, 8, —, 7, 26, ), [EOS]]",4,-9.214662,"[[SOS], 24, 0, [EOS], 2]",720.0,True,"(-9.964, -9.213]"
...,...,...,...,...,...,...,...,...,...
639995,111+167—222,56,"[[SOS], 3, 21, +, 5, 17, —, 7, 12, [EOS]]",59,-10.267776,"[[SOS], 1, 20, [EOS], 20]",50.0,False,"(-12.854, -9.964]"
639996,111+167—222,56,"[[SOS], 3, 21, +, 5, 17, —, 7, 12, [EOS]]",60,-10.267776,"[[SOS], 1, 20, [EOS], 21]",50.0,False,"(-12.854, -9.964]"
639997,111+167—222,56,"[[SOS], 3, 21, +, 5, 17, —, 7, 12, [EOS]]",61,-10.267776,"[[SOS], 1, 20, [EOS], 22]",50.0,False,"(-12.854, -9.964]"
639998,111+167—222,56,"[[SOS], 3, 21, +, 5, 17, —, 7, 12, [EOS]]",62,-10.267776,"[[SOS], 1, 20, [EOS], 23]",50.0,False,"(-12.854, -9.964]"


In [67]:
test_metrics

{'correct_value': 0.9213663227486727,
 'beam_accuracy': {'correct_value': {0: 0.5734,
   1: 0.2913,
   2: 0.3833,
   3: 0.4294,
   4: 0.4401,
   5: 0.4496,
   6: 0.4512,
   7: 0.4544,
   8: 0.4562,
   9: 0.4592,
   10: 0.463,
   11: 0.4607,
   12: 0.4623,
   13: 0.4615,
   14: 0.4618,
   15: 0.4615,
   16: 0.4612,
   17: 0.4622,
   18: 0.4621,
   19: 0.4623,
   20: 0.4615,
   21: 0.4624,
   22: 0.4605,
   23: 0.4619,
   24: 0.4612,
   25: 0.4613,
   26: 0.4615,
   27: 0.4618,
   28: 0.4607,
   29: 0.4611,
   30: 0.46,
   31: 0.4611,
   32: 0.4618,
   33: 0.4608,
   34: 0.4614,
   35: 0.4612,
   36: 0.4614,
   37: 0.4609,
   38: 0.2341,
   39: 0.1154,
   40: 0.0651,
   41: 0.0512,
   42: 0.0412,
   43: 0.0381,
   44: 0.0359,
   45: 0.0329,
   46: 0.0335,
   47: 0.0309,
   48: 0.0319,
   49: 0.0311,
   50: 0.0306,
   51: 0.0303,
   52: 0.0306,
   53: 0.03,
   54: 0.0297,
   55: 0.0293,
   56: 0.0289,
   57: 0.0291,
   58: 0.0299,
   59: 0.0297,
   60: 0.0293,
   61: 0.0298,
   62: 0.03,


In [None]:
sys.exit()

In [None]:
sys.exit()

In [None]:
# Sample roughly 1% of all ordered pairs (a, b), a != b, drawn from 0..399 and score them.
n_beams = 5
sample_pct = .01
rng = np.random.default_rng(0)  # fixed seed so the sampled pairs are the same on every run
pairs = np.array(list(itertools.permutations(range(400), 2)))
pairs = pairs[rng.random(pairs.shape[0]) < sample_pct]
print(pairs.shape[0])

# NOTE(review): the test-set call to compute_metrics earlier in this notebook is unpacked into
# (frame, metrics); if this call returns the same pair, random_addition_df is a tuple rather
# than a DataFrame — confirm before using it.
random_addition_df = problem.compute_metrics(model, device, problem.get_dataset(pairs), save=False, n_beams=n_beams, max_samples=-1)

### Explore the model a little

In [None]:
embeddings = model.src_embedding.embedding.weight.data.cpu().numpy()

In [None]:
# Decode every token id (0 .. len(t)-1) to its string form, special tokens included.
tokens = [''.join(t.decode([i], decode_special=True)) for i in range(len(t))]
special_tokens = set(problem.special_tokens)
# Special tokens are kept verbatim; every other token is parsed as an int and passed through
# base2dec with the configured base.
# NOTE(review): presumably this maps a digit token in the dataset's base to its decimal
# value — confirm against data_utils.base2dec. If base2dec returns numbers, np.array will
# coerce the mixed str/number list to a single (string) dtype.
tokens = np.array([tok if tok in special_tokens else data_utils.base2dec([int(tok)], args['data']['base']) for tok in tokens])

In [None]:
embeddings.shape

In [None]:
tokens.shape

In [None]:
tokens

In [None]:
from sklearn.manifold import TSNE

In [None]:
# Project the token embeddings to 2-D for plotting. t-SNE is stochastic, so the seed is
# pinned to make the layout reproducible between runs.
tsne = TSNE(random_state=0)
embeddings_for_plot = tsne.fit_transform(embeddings)

In [None]:
# Plot the 2-D t-SNE coordinates and write each token next to its point.
ax = plt.gca()
ax.scatter(embeddings_for_plot[:,0], embeddings_for_plot[:,1])
for tok, (x,y) in zip(tokens, embeddings_for_plot):
    # '.' and '_' are drawn at double size, every other token at 12pt
    fontsize = 24 if tok in ['.', '_'] else 12
    ax.annotate(tok, (x+.3,y), fontsize=fontsize)
plt.show()

## Cosine Similarity
* Some embeddings have relatively similar cosine similarities

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Pairwise cosine similarity between all token embeddings.
cs_sim_mat = cosine_similarity(embeddings)
# Keep each unordered pair exactly once by indexing the strict upper triangle. Zeroing the
# lower triangle and then filtering out zeros would also throw away genuine near-zero
# similarities and bias the histogram.
cs_sims = cs_sim_mat[np.triu_indices_from(cs_sim_mat, k=1)]

In [None]:
# Histogram of the pairwise cosine similarities, drawn on an explicit axes so the title and
# labels are guaranteed to land on the same plot as the bars.
fig, ax = plt.subplots()
pd.Series(cs_sims).hist(ax=ax)
ax.set_title('cosine similarity of embeddings')
ax.set_xlabel('cosine similarity')
ax.set_ylabel('number of token pairs')
plt.show()

In [None]:
# Heat-map of the full similarity matrix with token labels on both axes.
fig = plt.gcf()
fig.set_size_inches(20, 20)
ax = plt.gca()

# Cap the colour scale at 1.1x the largest entry below 0.99 (presumably so the
# self-similarity diagonal does not dominate the colour range).
color_cap = cs_sim_mat[cs_sim_mat<.99].max()*1.1
heatmap = ax.matshow(np.clip(cs_sim_mat, a_min=-1, a_max=color_cap))
plt.colorbar(heatmap, ax=ax)

n_shown = cs_sim_mat.shape[0]
tick_positions = np.arange(n_shown)
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(tokens[:n_shown])
ax.set_yticklabels(tokens[:n_shown])
plt.show()

## See what the attention looks at

In [None]:
example_row = problem.form_prediction_df(model, device, problem.get_dataset([[1,1]]), args['model_args']['max_decode_size'], n_beams=1, temperature=1.)

In [None]:
example_row

In [None]:
# Encode the example's model input (built from its n1/n2 columns) and its predicted label
# (built from pred_num) to token ids with the problem's tokenizer.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the
# notebook; later cells read it under this name, so a rename has to be done everywhere at once.
input = t.encode(problem.form_input(example_row['n1'].iloc[0], example_row['n2'].iloc[0], args['data']['base']))
# tgt = t.encode(data_utils.dec2base(example_row['pred_num'].iloc[0].replace('_', '').strip().split(' ')))
tgt = t.encode(problem.form_label(example_row['pred_num'].iloc[0], 0, args['data']['base']))

In [None]:
# Turn the two token-id sequences into (1, seq_len) batches on the target device.
# as_tensor + reshape(1, -1) keep this cell idempotent: re-running it on the tensors it
# produced leaves their shape unchanged, whereas torch.tensor(...).unsqueeze(0) would
# copy-construct from a tensor (with a warning) and stack on another batch dimension.
input = torch.as_tensor(input).reshape(1, -1).to(device)
tgt = torch.as_tensor(tgt).reshape(1, -1).to(device)

In [None]:
with torch.no_grad():
    (memory, encoder_attn_weights), memory_key_padding_mask = model.encode(input, need_weights=True)

In [None]:
memory.size(), memory_key_padding_mask.size(), encoder_attn_weights.size()

In [None]:
# Run the decoder on the target sequence and collect its attention maps. Wrapped in no_grad
# like the encoder pass above: the outputs are only inspected and plotted, so there is no
# reason to build an autograd graph.
with torch.no_grad():
    res, mem_attn, self_attn = model.decode(tgt, memory.repeat(1, tgt.size(0), 1), memory_key_padding_mask.repeat(tgt.size(0), 1), return_enc_dec_attn=True)

In [None]:
res.size(), mem_attn.size(), self_attn.size()

In [None]:
# Human-readable token strings for the first (only) sequence in each batch; show_attn uses
# these lists as axis tick labels.
src_ids = input[0].detach().cpu().tolist()
tgt_ids = tgt[0].detach().cpu().tolist()

mem_label = t.decode(src_ids, decode_special=True).split(' ')

# The decoder labels are the same target tokens for both axes; two separate lists are kept.
tgt_text = t.decode(tgt_ids, decode_special=True)
tgt_label_attended_to = tgt_text.split(' ')
tgt_label_attended_for = tgt_text.split(' ')

In [None]:
mem_label

In [None]:
tgt_label_attended_to

In [None]:
tgt_label_attended_for

In [None]:
import matplotlib as mpl
def show_attn(fig, ax, matrix, attn_type, title):
    """Draw one attention matrix on ``ax`` as a heat-map with token tick labels.

    Tick labels are read from the notebook-level lists ``mem_label``,
    ``tgt_label_attended_for`` and ``tgt_label_attended_to``, which must be
    defined before this function is called with a valid ``attn_type``.

    Args:
        fig: Figure that owns ``ax``; the colorbar is attached through it.
        ax: Axes to draw on.
        matrix: 2-D array handed to ``ax.imshow``.
        attn_type: ``'encoder_self'`` (source tokens on both axes),
            ``'decoder_self'`` (target tokens on both axes) or ``'mem'``
            (target tokens on y, source tokens on x).
        title: Title for the panel.

    Raises:
        ValueError: for any other ``attn_type``. The title and axis labels
            have already been set on ``ax`` when this is raised.
    """
    ax.set_title(title)
    ax.set_ylabel('Predicting the next token')
    ax.set_xlabel('Attending to this token')

    if attn_type not in ('encoder_self', 'decoder_self', 'mem'):
        raise ValueError(f'attn type {attn_type} not understood')

    # Pick the (row, column) label lists for this kind of attention.
    if attn_type == 'encoder_self':
        row_labels, col_labels = mem_label, mem_label
    elif attn_type == 'decoder_self':
        row_labels, col_labels = tgt_label_attended_for, tgt_label_attended_to
    else:
        row_labels, col_labels = tgt_label_attended_for, mem_label

    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(labels=row_labels, fontsize=16)

    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_xticklabels(labels=col_labels, fontsize=16)

    heatmap = ax.imshow(matrix, cmap='Blues')
    fig.colorbar(heatmap, ax=ax)

    
#     if 'self' in attn_type:
#         fig.set_size_inches(7,7)
#     else:
#         fig.set_size_inches(4,7)

In [None]:
mem_attn.size()

In [None]:
# One panel per (layer, head) of the encoder self-attention weights.
# squeeze=False keeps `ax` two-dimensional even when there is a single layer or a single
# head; without it plt.subplots returns a 1-D array (or a bare Axes) and ax[i,j] fails.
n_layers, n_heads = encoder_attn_weights.size(0), encoder_attn_weights.size(1)
fig, ax = plt.subplots(n_layers, n_heads, squeeze=False)

for i in range(n_layers):
    for j in range(n_heads):
        title = '%s Layer: %d Head: %d'%('Encoder Self Attention', i,j)
        # detach() is the supported way to drop autograd tracking before converting to numpy
        show_attn(fig, ax[i,j], encoder_attn_weights[i][j].detach().cpu().numpy(), 'encoder_self', title)
fig.set_size_inches(36,36)
fig.tight_layout()

In [None]:
# One panel per (layer, head) of the decoder's attention over the encoder memory.
# squeeze=False keeps `ax` two-dimensional even when there is a single layer or a single
# head; without it plt.subplots returns a 1-D array (or a bare Axes) and ax[i,j] fails.
n_layers, n_heads = mem_attn.size(0), mem_attn.size(1)
fig, ax = plt.subplots(n_layers, n_heads, squeeze=False)

for i in range(n_layers):
    for j in range(n_heads):
        title = '%s Layer: %d Head: %d'%('Mem', i,j)
        # detach() is the supported way to drop autograd tracking before converting to numpy
        show_attn(fig, ax[i,j], mem_attn[i][j].detach().cpu().numpy(), 'mem', title)
fig.set_size_inches(36,36)
fig.tight_layout()

In [None]:
# One panel per (layer, head) of the decoder self-attention weights.
# squeeze=False keeps `ax` two-dimensional even when there is a single layer or a single
# head; without it plt.subplots returns a 1-D array (or a bare Axes) and ax[i,j] fails.
n_layers, n_heads = self_attn.size(0), self_attn.size(1)
fig, ax = plt.subplots(n_layers, n_heads, squeeze=False)

for i in range(n_layers):
    for j in range(n_heads):
        title = '%s AttentionLayer: %d Head: %d'%('Self', i,j)
        # Weights are clipped to [0, 0.6] before plotting, which caps the colour scale.
        clipped = np.clip(self_attn[i][j].detach().cpu().numpy(), a_min=0, a_max=.6)
        show_attn(fig, ax[i,j], clipped, 'decoder_self', title)
fig.set_size_inches(36,36)
fig.tight_layout()