In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import sys
sys.path.append('../src/')
import models
import generation_utils
import tokenizer
import data_utils

In [2]:
# Widen the classic-notebook container so wide DataFrames and figures are not clipped.
# `display`/`HTML` are imported from the public IPython.display module; importing
# `display` from IPython.core.display is deprecated and fails on recent IPython releases.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
# Let pandas print long cell contents (token lists) and many rows without truncation.
pd.options.display.max_colwidth = 999
pd.options.display.max_rows = 9999

In [4]:
# Use the GPU when one is present, otherwise fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Checkpoint file to analyse; the path is relative, so it resolves against the
# notebook's working directory.
checkpoint_path = '../models/base_24_rerun/checkpoints/138000_0.0754.pt'

In [6]:
# Load onto the CPU first so a checkpoint written from GPU tensors can still be opened
# on a machine without CUDA; the model is moved to `device` after its weights are restored.
# NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [7]:
# Configuration dictionary that was stored in the checkpoint under the 'args' key.
args = checkpoint['args']

In [8]:
# Show the data location recorded in the saved configuration.
args['data']['data_loc']

'./data/2^16.json'

In [9]:
# Install the module-level factor mapping that data_utils reads from `data_utils.gfm`.
# The saved data location is either a .json file or a prefix to which '2^<max_pow>.json'
# is appended. In both cases the leading '.' turns a './...' location into '../...'.
# The parentheses matter: without them the conditional expression binds looser than '+',
# so the '.' prefix was only applied to the .json branch.
data_utils.gfm = data_utils.GlobalFactorMapping(
    data_path='.' + (
        args['data']['data_loc']
        if args['data']['data_loc'].endswith('.json')
        else args['data']['data_loc'] + '2^%d.json' % args['data']['max_pow']
    )
)

In [10]:
# Build the tokenizer with the numeric base recorded in the saved configuration.
t = tokenizer.Tokenizer(base = args['data']['base'])

In [11]:
# List the top-level sections of the saved configuration.
args.keys()

dict_keys(['data', 'model_args', 'optimizer', 'scheduler', 'loader', 'io', 'metrics', 'verbose', 'tokenizer'])

In [12]:
# Rebuild the network from the hyper-parameters saved in the checkpoint, then restore
# the trained weights.
model = models.Factorizer(
    n_tokens=args['tokenizer']['n_tokens'],
    pad_token_id=args['tokenizer']['pad_token_id'],
    **args['model_args'],
)
model.load_state_dict(checkpoint['model_state_dict'])
# Module.to() and Module.eval() both return the module itself, so this chained call is
# still the cell's final expression and the notebook still prints the model summary.
model.to(device).eval()

Factorizer(
  (embedding): TransformerEmbedding(
    (embedding): Embedding(28, 128)
    (pe): PositionalEncoding(
      (dropout): Dropout(p=0.05, inplace=False)
    )
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.05, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.05, inplace=False)
          (dropout2): Dropout(p=0.05, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            

In [13]:
import importlib
# Re-execute generation_utils so edits made to the module file after the initial import
# are picked up without restarting the kernel.
importlib.reload(generation_utils)

<module 'generation_utils' from '../src\\generation_utils.py'>

In [16]:
# Ask the model to factor one large integer with a single beam.
# The result is used below like a pandas DataFrame (column access, .iloc, .any()).
example_row = generation_utils.factor(
    1522605027922533360535618378132637429718068114961380,
    args['data']['base'],
    model,
    t,
    device,
    args['model_args']['max_decode_size'],
    n_beams=1,
)

In [17]:
# Display the prediction record(s) returned for the example number.
example_row

Unnamed: 0,target_num,target_is_prime,input_string,pred_list,pred_str,beam_idx,log_prob,target_str,target_factor_list,n_target_factors,pred_factor_list,n_pred_factors,product,correct_product,correct_factorization,num_prime_factors_pred,percent_prime_factors_pred,pred_same_as_target
0,1522605027922533360535618378132637429718068114961380,,"[1, 7, 6, 5, 11, 20, 13, 3, 19, 22, 12, 21, 18, 11, 9, 21, 17, 16, 8, 21, 1, 21, 21, 4, 3, 14, 19, 12, 11, 21, 10, 7, 2, 22, 22, 21, 6, 20]","[27, 1, 17, 24, 23, 19, 26]",> 1 17 x 23 19 .,0,-1.488945,,,0,"[41, 571]",2,23411,False,False,2,1.0,False


In [39]:
# True if any returned row has a correct product.
# NOTE(review): the saved output of this cell (True) disagrees with the single row
# displayed above (correct_product=False), and its execution counter is out of sequence;
# the output looks stale from a different run. Re-run the notebook top to bottom to refresh.
example_row['correct_product'].any()

True

In [40]:
# True if any returned row has a correct factorization.
# NOTE(review): the saved output of this cell (True) disagrees with the single row
# displayed above (correct_factorization=False), and its execution counter is out of
# sequence; the output looks stale. Re-run the notebook top to bottom to refresh.
example_row['correct_factorization'].any()

True

In [16]:
# Encoder input: the target number converted to the configured base, wrapped by
# data_utils.form_input and tokenized.
# NOTE: `input` shadows the Python built-in of the same name for the rest of the notebook;
# later cells read this variable, so renaming it requires updating them as well.
input = t.encode(data_utils.form_input(data_utils.dec2base(example_row['target_num'].iloc[0], args['data']['base'])))
# Decoder input: the first predicted string with '_' removed, split on spaces and tokenized.
tgt = t.encode(example_row['pred_str'].iloc[0].replace('_','').strip().split(' '))

In [None]:
# Inspect the encoded encoder input and decoder target.
input, tgt

In [None]:
def _as_batch_of_one(token_ids):
    """Return ``token_ids`` as a tensor on ``device`` with a leading batch dimension.

    Accepts either the raw token sequence or the tensor produced by a previous run of
    this cell. A 1-D sequence gains a batch dimension; an already batched tensor is
    returned unchanged, so re-running the cell does not keep adding dimensions.
    """
    batch = torch.as_tensor(token_ids, device=device)
    if batch.dim() == 1:
        batch = batch.unsqueeze(0)
    return batch


# Same result as torch.tensor(x).unsqueeze(0).to(device) on the first run, but the cell
# can now be executed repeatedly without changing the shapes.
input = _as_batch_of_one(input)
tgt = _as_batch_of_one(tgt)

In [None]:
# Run the encoder once with autograd disabled (inference only); it returns the encoded
# memory together with its key-padding mask.
with torch.no_grad():
    memory, memory_key_padding_mask = model.encode(input)

In [None]:
# Decode under no_grad, matching the encode step: this is pure inference, so there is no
# reason to build an autograd graph for the outputs or the attention maps.
# memory and its mask are repeated tgt.size(0) times to line up with the number of
# target sequences (one here).
with torch.no_grad():
    res, mem_attn, self_attn = model.decode(
        tgt,
        memory.repeat(1, tgt.size(0), 1),
        memory_key_padding_mask.repeat(tgt.size(0), 1),
        return_enc_dec_attn=True,
    )

In [None]:
# Expected layouts (author's notes):
# res:       batch size x # tokens in tgt x # possible tokens
# mem_attn:  # layers x # heads x # tokens in tgt x # tokens in memory
# self_attn: # layers x # heads x # tokens in tgt x # tokens in tgt
res.size(), mem_attn.size(), self_attn.size()

In [None]:
# Human-readable tick labels for the attention plots.
# Tensor.tolist() copies to the host itself, so the .data.cpu().numpy() chain is not needed.
mem_label = t.decode(input[0].tolist(), decode_special=True).split(' ')
# Both target-side label lists come from the same sequence: decode it once and copy it,
# so the two names still refer to independent lists.
tgt_label_attended_to = t.decode(tgt[0].tolist(), decode_special=True).split(' ')
tgt_label_attended_for = list(tgt_label_attended_to)

In [None]:
# Labels decoded from the encoder input (x-axis of the memory-attention plots).
mem_label

In [None]:
# Labels decoded from the target (x-axis of the self-attention plots).
tgt_label_attended_to

In [None]:
# Labels decoded from the target (y-axis of every attention plot).
tgt_label_attended_for

In [None]:
# Size of a single attention map: index [0][0], which the plots below title as
# layer 0, head 0.
mem_attn[0][0].size()

In [None]:
import matplotlib as mpl
def show_attn(fig, ax, matrix, self_or_mem, title, row_labels=None, col_labels=None):
    """Draw one attention map as a labelled heat-map on ``ax``.

    Parameters
    ----------
    fig : figure that owns ``ax``; receives the colorbar and is resized (see note below).
    ax : axes to draw on.
    matrix : 2-D array-like handed to ``ax.imshow``.
    self_or_mem : ``'self'`` or ``'mem'``; selects the default x-axis labels and the
        figure size.
    title : axes title.
    row_labels : optional y tick labels. Defaults to the notebook-level
        ``tgt_label_attended_for``.
    col_labels : optional x tick labels. Defaults to the notebook-level
        ``tgt_label_attended_to`` for ``'self'`` and ``mem_label`` for ``'mem'``.

    Raises
    ------
    ValueError
        If ``self_or_mem`` is neither ``'self'`` nor ``'mem'``. The check runs before
        anything is drawn, so a bad argument leaves ``fig`` and ``ax`` untouched.
    """
    if self_or_mem not in ('self', 'mem'):
        raise ValueError('bad self or mem, got %s'%self_or_mem)

    # Fall back to the notebook globals only when the caller did not supply labels.
    if row_labels is None:
        row_labels = tgt_label_attended_for
    if col_labels is None:
        col_labels = tgt_label_attended_to if self_or_mem == 'self' else mem_label

    ax.set_title(title)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(labels=row_labels, fontsize=16)
    ax.set_ylabel('Predicting this token')
    ax.set_xlabel('Attending to this token')
    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_xticklabels(labels=col_labels, fontsize=16)

    im = ax.imshow(matrix, cmap='Blues')
    fig.colorbar(im, ax=ax)

    # NOTE: this resizes the whole figure, not just this axes. Callers that lay out a
    # grid of maps are expected to set the final figure size after the last call.
    if self_or_mem == 'self':
        fig.set_size_inches(7, 7)
    else:
        fig.set_size_inches(4, 7)

In [None]:
# One panel per (layer, head) of the encoder-decoder attention.
# squeeze=False keeps `ax` a 2-D array even when there is only one layer or one head,
# so ax[i,j] indexing always works.
fig, ax = plt.subplots(mem_attn.size(0), mem_attn.size(1), squeeze=False)

for i in range(mem_attn.size(0)):
    for j in range(mem_attn.size(1)):
        title = '%s Layer: %d Head: %d'%('Mem', i,j)
        show_attn(fig, ax[i,j], mem_attn[i][j].detach().cpu().numpy(), 'mem', title)
# show_attn resizes the figure on every call; set the final grid size last.
fig.set_size_inches(36,36)
fig.tight_layout()

In [None]:
# One panel per (layer, head) of the decoder self-attention.
# squeeze=False keeps `ax` a 2-D array even when there is only one layer or one head,
# so ax[i,j] indexing always works.
fig, ax = plt.subplots(self_attn.size(0), self_attn.size(1), squeeze=False)

for i in range(self_attn.size(0)):
    for j in range(self_attn.size(1)):
        title = '%s AttentionLayer: %d Head: %d'%('Self', i,j)
        # Weights are clipped to [0, 0.6] before plotting, which caps the colour scale.
        show_attn(fig, ax[i,j], np.clip(self_attn[i][j].detach().cpu().numpy(), a_min=0, a_max=.6), 'self', title)
# show_attn resizes the figure on every call; set the final grid size last.
fig.set_size_inches(36,36)
fig.tight_layout()