In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import sys
# Put the project's source, problem, and script directories on the import
# path so the project-local imports below resolve.
# NOTE: these are relative paths, so they are resolved against the kernel's
# current working directory.
sys.path.append('../src/')
sys.path.append('../problems/')
sys.path.append('../scripts/')
sys.path.append('../scripts/dataset_generation/')
from evaluation import Evaluation, EvaluationDataset
import models
import generation_utils
import tokenizer
import data_utils
import metrics_utils
from utils import get_best_checkpoint
from train_model import get_loaders
from tqdm.auto import tqdm
import itertools
from optimization_utils import test_on_loader

In [2]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported automatically before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [3]:
# Inject a CSS rule that forces the `.container` element to the full browser
# width, so wide outputs (e.g. DataFrames with long token lists) get more room.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is a deprecated path.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [4]:
# Raise pandas' display limits: show up to 999 characters per cell and up to
# 9999 rows before the rendered frame is truncated.
pd.options.display.max_colwidth = 999
pd.options.display.max_rows = 9999

In [5]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Directory holding the addition-evaluation model artifacts (relative to the
# kernel's working directory); the checkpoint file is loaded from beneath it.
base_path = '../models/evaluation/addition/'

In [7]:
# checkpoint = get_best_checkpoint(base_path)
# Load one specific, pinned checkpoint file instead of selecting the best one.
# map_location='cpu' deserializes every tensor onto the CPU, so a checkpoint
# saved on a GPU machine still loads on a CPU-only host; the model is moved to
# `device` explicitly after its weights are restored.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
checkpoint = torch.load(base_path + 'checkpoints/80000_0.5956.pt', map_location='cpu')

In [8]:
# Configuration dictionary stored inside the checkpoint; it is reused below to
# build the problem and to construct the model.
args = checkpoint['args']

In [9]:
# The dataset paths saved in the checkpoint config are prefixed with '../' so
# they resolve from this notebook's directory. Entries starting with
# 'streaming_' are left untouched, as before.
# Paths that already start with '../' are also skipped: this keeps the cell
# idempotent, so re-running it does not stack another '../' onto a path that
# was already rewritten (the config dict is mutated in place).
for key in ['train', 'test', 'oos']:
    path_key = f'{key}_path'
    split_path = args['data'][path_key]
    if not split_path.startswith(('streaming_', '../')):
        args['data'][path_key] = '../' + split_path

In [10]:
# Build the Evaluation problem object from the checkpoint's configuration
# (with the dataset paths adjusted in the previous cell).
problem = Evaluation(args)

In [11]:
# Create the three data loaders for the problem: train, test, and "oos"
# (presumably out-of-sample / out-of-distribution — confirm in get_loaders).
train_loader, test_loader, oos_loader = get_loaders(problem)

Loading data...


In [12]:
# Tokenizer belonging to the problem. It is not referenced again in the cells
# visible here; kept for interactive inspection.
t = problem.get_tokenizer()

In [13]:
# Show the top-level sections of the saved configuration.
args.keys()

dict_keys(['data', 'problem_type', 'model_args', 'optimizer', 'scheduler', 'loader', 'io', 'metrics', 'verbose', 'resume_training', 'overwrite', 'tokenizer'])

In [14]:
# Rebuild the sequence-to-sequence model from the saved configuration: the
# vocabulary size and padding id come from the tokenizer section, every other
# constructor argument from 'model_args'.
tokenizer_args = args['tokenizer']
model = models.Seq2SeqModel(
    n_tokens=tokenizer_args['n_tokens'],
    pad_token_id=tokenizer_args['pad_token_id'],
    **args['model_args'],
)

# Restore the trained weights from the checkpoint.
model.load_state_dict(checkpoint['model_state_dict'])

# Move the model to the selected device, then switch it to evaluation mode.
# The eval() call is the cell's last expression, so the model summary is shown.
model.to(device)
model.eval()

Seq2SeqModel(
  (src_embedding): TransformerEmbedding(
    (embedding): Embedding(19, 128)
  )
  (tgt_embedding): TransformerEmbedding(
    (embedding): Embedding(19, 128)
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.05, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiHeadRelativeAttention(
            (w_q): Linear(in_features=128, out_features=128, bias=False)
            (w_k): Linear(in_features=128, out_features=128, bias=False)
            (w_v): Linear(in_features=128, out_features=128, bias=False)
            (out_proj): Linear(in_features=128, out_features=128, bias=False)
            (pe_mod): PositionalEncoding(
              (dropout): Dropout(p=0.05, inplace=False)
            )
            (w_k_pos): Linear(in_features=128, out_features=128, bias=False)
          )
          (linear1): Linear(in_features=128

In [15]:
# Total number of scalar parameters in the model (every tensor yielded by
# model.parameters(), whether or not it requires gradients).
# Tensor.numel() gives each tensor's element count directly, and the builtin
# sum keeps the result a plain int (0 for a model with no parameters) instead
# of round-tripping torch.Size objects through numpy.
sum(p.numel() for p in model.parameters())

1210515

In [16]:
# Number of examples in the test loader's underlying dataset.
len(test_loader.dataset)

2048

In [17]:
# Evaluate the model on the test loader with n_beams=64. This returns a
# per-beam results DataFrame and a dictionary of aggregate metrics.
# NOTE(review): save=False presumably skips writing results to disk — confirm
# in Evaluation.compute_metrics.
test_df, test_metrics = problem.compute_metrics(model, device, test_loader, save=False, n_beams=64)

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

In [18]:
# Display the per-beam results: each row pairs an input expression with one
# beam's decoded output, its log-probability, and whether the value is correct.
test_df

Unnamed: 0,expression,value,model_input,beam_idx,log_prob,output_toks,pred,correct_value,log_prob_decile
0,-3+-22+-65+64+-86,-112,"[[SOS], -, 3, +, -, 2, 2, +, -, 6, 5, +, 6, 4, +, -, 8, 6, [EOS]]",0,-0.948943,"[[SOS], -, 1, 1, 0, [EOS], [PAD]]",-110.0,False,"(-4.206, -0.509]"
1,-3+-22+-65+64+-86,-112,"[[SOS], -, 3, +, -, 2, 2, +, -, 6, 5, +, 6, 4, +, -, 8, 6, [EOS]]",1,-1.362660,"[[SOS], -, 1, 1, 1, [EOS], [PAD]]",-111.0,False,"(-4.206, -0.509]"
2,-3+-22+-65+64+-86,-112,"[[SOS], -, 3, +, -, 2, 2, +, -, 6, 5, +, 6, 4, +, -, 8, 6, [EOS]]",2,-1.725591,"[[SOS], -, 1, 0, 9, [EOS], [PAD]]",-109.0,False,"(-4.206, -0.509]"
3,-3+-22+-65+64+-86,-112,"[[SOS], -, 3, +, -, 2, 2, +, -, 6, 5, +, 6, 4, +, -, 8, 6, [EOS]]",3,-2.361022,"[[SOS], -, 1, 1, 2, [EOS], [PAD]]",-112.0,True,"(-4.206, -0.509]"
4,-3+-22+-65+64+-86,-112,"[[SOS], -, 3, +, -, 2, 2, +, -, 6, 5, +, 6, 4, +, -, 8, 6, [EOS]]",4,-2.911325,"[[SOS], -, 1, 0, 8, [EOS], [PAD]]",-108.0,False,"(-4.206, -0.509]"
...,...,...,...,...,...,...,...,...,...
131067,-56+-27+-55+84+7,-47,"[[SOS], -, 5, 6, +, -, 2, 7, +, -, 5, 5, +, 8, 4, +, 7, [EOS]]",59,-11.166413,"[[SOS], -, 4, 8, [EOS], (]",-48.0,False,"(-11.223, -10.905]"
131068,-56+-27+-55+84+7,-47,"[[SOS], -, 5, 6, +, -, 2, 7, +, -, 5, 5, +, 8, 4, +, 7, [EOS]]",60,-11.166413,"[[SOS], -, 4, 8, [EOS], )]",-48.0,False,"(-11.223, -10.905]"
131069,-56+-27+-55+84+7,-47,"[[SOS], -, 5, 6, +, -, 2, 7, +, -, 5, 5, +, 8, 4, +, 7, [EOS]]",61,-11.166413,"[[SOS], -, 4, 8, [EOS], [EOS]]",-48.0,False,"(-11.223, -10.905]"
131070,-56+-27+-55+84+7,-47,"[[SOS], -, 5, 6, +, -, 2, 7, +, -, 5, 5, +, 8, 4, +, 7, [EOS]]",62,-11.212212,"[[SOS], -, 5, 2, [EOS], [PAD]]",-52.0,False,"(-11.223, -10.905]"


In [19]:
# Display the aggregate metrics, including correctness broken down by beam
# index under 'beam_accuracy'.
test_metrics

{'correct_value': 0.98828125,
 'beam_accuracy': {'correct_value': {0: 0.23681640625,
   1: 0.1884765625,
   2: 0.13232421875,
   3: 0.10498046875,
   4: 0.078125,
   5: 0.05078125,
   6: 0.0419921875,
   7: 0.0478515625,
   8: 0.0859375,
   9: 0.115234375,
   10: 0.1435546875,
   11: 0.16357421875,
   12: 0.16552734375,
   13: 0.16943359375,
   14: 0.17822265625,
   15: 0.177734375,
   16: 0.18017578125,
   17: 0.18310546875,
   18: 0.18408203125,
   19: 0.1875,
   20: 0.18896484375,
   21: 0.1904296875,
   22: 0.1904296875,
   23: 0.19091796875,
   24: 0.181640625,
   25: 0.16845703125,
   26: 0.16015625,
   27: 0.1572265625,
   28: 0.16455078125,
   29: 0.1630859375,
   30: 0.1630859375,
   31: 0.16552734375,
   32: 0.166015625,
   33: 0.16552734375,
   34: 0.16455078125,
   35: 0.16650390625,
   36: 0.16552734375,
   37: 0.1650390625,
   38: 0.1650390625,
   39: 0.16455078125,
   40: 0.16552734375,
   41: 0.16357421875,
   42: 0.142578125,
   43: 0.12646484375,
   44: 0.1171875,
   

In [None]:
# Raise SystemExit at this point. NOTE(review): presumably a deliberate stop
# marker so that "Run All" does not continue into any cells below — confirm
# before removing.
sys.exit()