In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import sys
sys.path.append('../src/')
sys.path.append('../problems/')
sys.path.append('../scripts/')
sys.path.append('../scripts/dataset_generation/')
import models
import generation_utils
import tokenizer
import data_utils
from train_model import get_loaders
from utils import get_best_checkpoint
from factorization import Factorization

In [3]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to the imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Inject a CSS rule that sets `.container` elements to 100% width, so the
# notebook page uses the full browser width.
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is a deprecated location.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [5]:
# Raise pandas' display limits in one call (set_option accepts several
# option/value pairs): column width up to 999 characters and up to 9999 rows
# before a frame's repr is truncated.
pd.set_option(
    'display.max_colwidth', 999,
    'display.max_rows', 9999,
)

In [6]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook can still run on a machine without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# Directory of the training run to inspect; passed to get_best_checkpoint below.
base_path = '../models/factorization/2^22/scaled/'

In [8]:
# Load the checkpoint selected by get_best_checkpoint for this run directory.
# The result is indexed later by 'args' and 'model_state_dict'.
checkpoint = get_best_checkpoint(base_path)

Loading model at ../models/factorization/2^22/scaled/checkpoints/389150_0.0727.pt


In [9]:
# Configuration dict stored inside the checkpoint (data paths, model and
# optimizer settings, ...).
args = checkpoint['args']

In [10]:
# Display the stored configuration for inspection.
args

{'data': {'train_path': 'data/train_data_2^22.npy',
  'test_path': 'data/test_data_2^22.npy',
  'oos_path': 'data/oos_data_2^22.npy',
  'base': 30},
 'problem_type': 'factorization',
 'model_args': {'embed_dim': 256,
  'num_encoder_layers': 10,
  'num_decoder_layers': 10,
  'dim_feedforward': 1024,
  'dropout': 0.1,
  'shared_embeddings': False,
  'scale_embeddings': False,
  'scale_embeddings_at_init': False,
  'max_decode_size': 64,
  'norm_first': False,
  'learn_positional_encoding': True,
  'repeat_positional_encoding': False,
  'positional_encoding_query_key_only': False,
  'positional_encoding_type': 'relative-transfxl',
  'extra_positional_encoding_relative_decoder_mha': True,
  'attn_weight_xavier_init_constant': 0.5,
  'embedding_initialization': 'xavier'},
 'optimizer': {'type': 'AdamW',
  'opt_args': {'lr': 0.001, 'weight_decay': 0.1},
  'max_grad_norm': 1,
  'gradient_accumulation_steps': 8},
 'scheduler': {'type': 'get_linear_schedule_with_warmup',
  'nb_epochs': 200,
  '

In [11]:
# Rewrite the dataset paths stored in the checkpoint configuration: swap the
# '2^22' tag in each file name for '2^13' and prefix the path with '../'
# (the checkpoint stores paths such as 'data/train_data_2^22.npy').
# The prefix is added only when it is missing, so re-running this cell leaves
# already-rewritten paths unchanged instead of prepending '../' a second time.
for key in ['train', 'test', 'oos']:
    path_key = f'{key}_path'
    data_path = args['data'][path_key].replace('2^22', '2^13')
    if not data_path.startswith('../'):
        data_path = '../' + data_path
    args['data'][path_key] = data_path

In [12]:
# Build the problem object from the configuration; it is used below to obtain
# the tokenizer and to compute metrics.
problem = Factorization(args)

In [13]:
# NOTE(review): this rebinds the name `tokenizer`, shadowing the `tokenizer`
# module imported at the top of the notebook; after this cell runs the module
# is no longer reachable under that name.
tokenizer = problem.get_tokenizer()

In [14]:
# Number of entries in the tokenizer (compare with the token mapping shown in
# the next cell).
len(tokenizer)

34

In [15]:
# Token -> id mapping: the digit tokens '0'..'29', the '*' token that
# separates factors in labels, and the special tokens '[PAD]', '[EOS]', '[SOS]'.
tokenizer.token_mapper

{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '10': 10,
 '11': 11,
 '12': 12,
 '13': 13,
 '14': 14,
 '15': 15,
 '16': 16,
 '17': 17,
 '18': 18,
 '19': 19,
 '20': 20,
 '21': 21,
 '22': 22,
 '23': 23,
 '24': 24,
 '25': 25,
 '26': 26,
 '27': 27,
 '28': 28,
 '29': 29,
 '*': 30,
 '[PAD]': 31,
 '[EOS]': 32,
 '[SOS]': 33}

In [19]:
# Build the train / test / oos data loaders for this problem.
# NOTE(review): the execution count jumps from 15 to 19 at this cell, so some
# cells were run and then removed or reordered; confirm the notebook still
# works with Restart Kernel -> Run All.
train_loader, test_loader, oos_loader = get_loaders(problem)

Loading data...


In [20]:
# List the top-level sections of the configuration dict.
args.keys()

dict_keys(['data', 'problem_type', 'model_args', 'optimizer', 'scheduler', 'loader', 'io', 'metrics', 'verbose', 'wandb', 'resume_training', 'tokenizer'])

In [21]:
# Rebuild the network from the stored hyper-parameters: vocabulary size and
# padding id come from the 'tokenizer' section, everything else from
# 'model_args'.
model = models.Seq2SeqModel(
    n_tokens=args['tokenizer']['n_tokens'],
    pad_token_id=args['tokenizer']['pad_token_id'],
    **args['model_args'],
)
# Restore the trained weights saved in the checkpoint.
model.load_state_dict(checkpoint['model_state_dict'])
# Move the model to the target device and switch to evaluation mode; both calls
# return the module itself, so the cell still displays the model's repr.
model.to(device).eval()

Seq2SeqModel(
  (src_embedding): TransformerEmbedding(
    (embedding): Embedding(34, 256)
  )
  (tgt_embedding): TransformerEmbedding(
    (embedding): Embedding(34, 256)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-9): 10 x TransformerEncoderLayer(
          (self_attn): MultiHeadRelativeAttention(
            (w_q): Linear(in_features=256, out_features=256, bias=False)
            (w_k): Linear(in_features=256, out_features=256, bias=False)
            (w_v): Linear(in_features=256, out_features=256, bias=False)
            (out_proj): Linear(in_features=256, out_features=256, bias=False)
            (pe_mod): PositionalEncoding(
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (w_k_pos): Linear(in_features=256, out_features=256, bias=False)
          )
          (linear1): Linear(in_features=256,

In [22]:
# Peek at the first test example: a dict with 'input' and 'label' token lists.
test_loader.dataset[0]

{'input': ['[SOS]', 8, '[EOS]'],
 'label': ['[SOS]', 2, '*', 2, '*', 2, '[EOS]']}

In [23]:
%%time
# Evaluate the model on the test loader with n_beams=20. The call returns the
# table of predictions (test_preds) and a dict of aggregate metrics
# (test_metrics). save=False presumably skips writing the results to disk --
# TODO confirm in compute_metrics.
test_preds, test_metrics = problem.compute_metrics(model, device, test_loader, save=False, n_beams=20)
# Same evaluation on oos_loader, currently disabled:
# oos_preds, oos_metrics = problem.compute_metrics(model, device, oos_loader, save=False, n_beams=5)

  0%|          | 0/856 [00:00<?, ?it/s]

  information['percent_prime_factors_pred'] = information['num_prime_factors_pred'] / information['num_pred_factors']


  0%|          | 0/4 [00:00<?, ?it/s]

Wall time: 2min 29s


In [24]:
# Prediction table with one row per (input_num, beam_idx) pair.
# NOTE(review): this displays the whole frame; consider test_preds.head() to
# keep the saved notebook output small.
test_preds

Unnamed: 0,input_num,model_input,beam_idx,log_prob,pred_str,pred_factor_list,product,num_pred_factors,num_prime_factors_pred,percent_prime_factors_pred,correct_product,correct_factorization,pred_same_as_input,input_is_prime,target_factor_str,target_factor_list,num_target_factors,min_target_prime_factor_if_composite,log_prob_decile
0,8,"[[SOS], 8, [EOS]]",0,-0.001551,[SOS] 2 * 2 * 2 [EOS] [PAD] [PAD] [PAD],"[2, 2, 2]",8.0,3,3.0,1.000000,True,True,False,False,[SOS] 2 * 2 * 2 [EOS],"[2, 2, 2]",3,2,"(-3.285403, 0.0]"
1,8,"[[SOS], 8, [EOS]]",1,-6.526535,[SOS] 2 * 2 * 2 * 1 1 [EOS],"[2, 2, 2, 31]",248.0,4,4.0,1.000000,False,False,False,False,[SOS] 2 * 2 * 2 [EOS],"[2, 2, 2]",3,2,"(-6.983982, -4.74091]"
2,8,"[[SOS], 8, [EOS]]",2,-9.211891,[SOS] 2 * 2 * 2 [EOS] [PAD] [PAD] 0,"[2, 2, 60]",240.0,3,2.0,0.666667,False,False,False,False,[SOS] 2 * 2 * 2 [EOS],"[2, 2, 2]",3,2,"(-13.885378999999999, -9.210369]"
3,8,"[[SOS], 8, [EOS]]",3,-9.211891,[SOS] 2 * 2 * 2 [EOS] [PAD] [PAD] 1,"[2, 2, 61]",244.0,3,3.0,1.000000,False,False,False,False,[SOS] 2 * 2 * 2 [EOS],"[2, 2, 2]",3,2,"(-13.885378999999999, -9.210369]"
4,8,"[[SOS], 8, [EOS]]",4,-9.211891,[SOS] 2 * 2 * 2 [EOS] [PAD] [PAD] 2,"[2, 2, 62]",248.0,3,2.0,0.666667,False,False,False,False,[SOS] 2 * 2 * 2 [EOS],"[2, 2, 2]",3,2,"(-13.885378999999999, -9.210369]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17115,8184,"[[SOS], 9, 2, 24, [EOS]]",15,-9.210367,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS] 14,"[2, 2, 2, 3, 11, 944]",249216.0,6,5.0,0.833333,False,False,False,False,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS],"[2, 2, 2, 3, 11, 31]",6,2,"(-9.210369, -9.210348]"
17116,8184,"[[SOS], 9, 2, 24, [EOS]]",16,-9.210367,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS] 15,"[2, 2, 2, 3, 11, 945]",249480.0,6,5.0,0.833333,False,False,False,False,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS],"[2, 2, 2, 3, 11, 31]",6,2,"(-9.210369, -9.210348]"
17117,8184,"[[SOS], 9, 2, 24, [EOS]]",17,-9.210367,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS] 16,"[2, 2, 2, 3, 11, 946]",249744.0,6,5.0,0.833333,False,False,False,False,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS],"[2, 2, 2, 3, 11, 31]",6,2,"(-9.210369, -9.210348]"
17118,8184,"[[SOS], 9, 2, 24, [EOS]]",18,-9.210367,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS] 17,"[2, 2, 2, 3, 11, 947]",250008.0,6,6.0,1.000000,False,False,False,False,[SOS] 2 * 2 * 2 * 3 * 11 * 1 1 [EOS],"[2, 2, 2, 3, 11, 31]",6,2,"(-9.210369, -9.210348]"


In [25]:
# Share of inputs for which at least one beam has correct_factorization == True:
# group the predictions by input number, reduce each group with any(), then
# average the resulting booleans.
(
    test_preds
    .groupby('input_num')['correct_factorization']
    .any()
    .mean()
)

1.0

In [26]:
import pprint

In [27]:
# List the metric groups returned by compute_metrics.
test_metrics.keys()

dict_keys(['correct', 'beam_accuracy', 'by_prob', 'by_num_target_factors', 'by_input_num', 'pred_same_as_input_beam_0', 'by_min_factor_composite_only', 'loss', 'meta'])

In [28]:
# Tabulate the 'beam_accuracy' metrics (columns correct_product and
# correct_factorization); the row index presumably corresponds to the beam
# index -- TODO confirm in compute_metrics.
pd.DataFrame(test_metrics['beam_accuracy'])

Unnamed: 0,correct_product,correct_factorization
0,0.997664,0.867991
1,0.057243,0.046729
2,0.029206,0.029206
3,0.016355,0.014019
4,0.014019,0.010514
5,0.01285,0.011682
6,0.016355,0.014019
7,0.004673,0.004673
8,0.003505,0.003505
9,0.004673,0.004673


In [29]:
# Tabulate the 'pred_same_as_input_beam_0' metrics: group sizes and mean
# correctness split by input_is_prime / pred_same_as_input.
pd.DataFrame(test_metrics['pred_same_as_input_beam_0'])

Unnamed: 0,0,1
input_is_prime,False,True
pred_same_as_input,False,False
group_size,751,105
correct_product_mean,0.997337,1.0
correct_factorization_mean,0.849534,1.0
