In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
from torch.utils.data import DataLoader

from full_vocab_estimation import estimate_probs_of_words
from model import MODEL_GETTERS_DICT
from dataset import CurveDataset, CollateFn
from feature_extractors import get_transforms
from feature_extractors import weights_function_v1, weights_function_v1_softmax
from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1

In [None]:
######### Command line arguments emulation #########
# Every value below stands in for what would be a command-line argument if
# this notebook were a script. Alternatives that were used before are kept
# in the comments next to each setting.

# Key into MODEL_GETTERS_DICT.
MODEL_NAME = "v2_weighted_transformer_bigger"

# Keyboard grid to evaluate on; selects the dataset files.
GRID_NAME = "default"

# Checkpoint file handed to the model getter.
# NOTE(review): machine-specific absolute path; the Kaggle alternative is
#   r'/kaggle/input/weighted-transformer-bigger-sigmoid-epoch26-pt/weighted_transformer_bigger-sigmoid--epoch26.pt'
CPT_PATH = r'/home/proshian/Downloads/weighted_transformer_bigger-sigmoid--epoch=26.pt'

# Name of the feature transform; alternative: "traj_feats_and_nearest_key"
TRANSFORM_NAME = "traj_feats_and_distances"

# Key into DIST_WEIGHTS_FUNCS_DICT; alternative: 'weights_function_v1'
DIST_WEIGHTS_FUNC_NAME = "weights_function_v1_softmax"

# Directory that holds the datasets, the vocabulary and the grid description.
DATA_ROOT = "../data/data_separated_grid"

# Device passed to the model getter and to the estimation routine;
# alternative: 'cuda'
DEVICE = 'cpu'

# Batch size shared by both data loaders.
BATCH_SIZE = 2

In [None]:
######### Other constants #########

# Registry mapping the name of a distance-weighting function to the callable.
DIST_WEIGHTS_FUNCS_DICT = dict(
    weights_function_v1_softmax=weights_function_v1_softmax,
    weights_function_v1=weights_function_v1,
)

# Dataset file names (relative to DATA_ROOT) for each grid and split.
_DS_FILE_NAMES = {
    "extra": {
        "train": "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl",
        "val": "valid__in_train_format__extra_only.jsonl",
    },
    "default": {
        "train": "train__default_only_no_errors__2023_10_31__03_26_16.jsonl",
        "val": "valid__in_train_format__default_only.jsonl",
    },
}

# Same nesting as above, with every file name resolved against DATA_ROOT.
GRID_NAME_TO_DS_PATHS = {
    grid: {
        split: os.path.join(DATA_ROOT, file_name)
        for split, file_name in splits.items()
    }
    for grid, splits in _DS_FILE_NAMES.items()
}

# Train / val paths for the grid selected in the configuration cell.
DS_PATHS = GRID_NAME_TO_DS_PATHS[GRID_NAME]

In [None]:
# Resolve the configured function name to the actual callable; an unknown
# name fails here with a KeyError.
DIST_WEIGHTS_FUNC = DIST_WEIGHTS_FUNCS_DICT[DIST_WEIGHTS_FUNC_NAME]

In [None]:
# Auxiliary files that live in the data directory.
voc_path = os.path.join(DATA_ROOT, "voc.txt")
gridname_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")

# Character-level tokenizer constructed from the vocabulary file path.
char_tokenizer = CharLevelTokenizerv2(voc_path)
# Keyboard tokenizer; takes no arguments.
kb_tokenizer = KeyboardTokenizerv1()

In [None]:
# get_transforms returns a pair; only the second element is needed here and
# is used as the validation transform (the first is presumably the train
# transform — confirm in feature_extractors).
_, val_transform = get_transforms(
    gridname_to_grid_path=gridname_to_grid_path,
    grid_name=GRID_NAME,
    transform_name=TRANSFORM_NAME,
    char_tokenizer=char_tokenizer,
    dist_weights_func=DIST_WEIGHTS_FUNC,
)

In [None]:
# Validation dataset built from the 'val' file of the selected grid.
# No init_transform is given; val_transform is passed as get_item_transform
# (presumably applied each time an item is fetched — confirm in CurveDataset).
val_dataset = CurveDataset(
    data_path=DS_PATHS['val'],
    store_gnames=False,
    init_transform=None,
    get_item_transform=val_transform,
    # total=val_total  # 9416
)

In [None]:
# Number of DataLoader worker processes. Set to 0 to load data in the main
# process (handy for debugging).
dataloader_workers_n = 4

# Batch collation: pads with the tokenizer's '<pad>' index and is configured
# with batch_first=False, matching the batch_first=False passed to
# estimate_probs_of_words later in the notebook.
collate_fn = CollateFn(
    word_pad_idx=char_tokenizer.char_to_idx['<pad>'], batch_first=False)

# shuffle=False keeps the validation samples in dataset order.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=dataloader_workers_n,
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request it when workers are actually used.
    persistent_workers=dataloader_workers_n > 0,
    collate_fn=collate_fn,
)

In [None]:
# Look up the model getter by name, call it with the device and the
# checkpoint path, and call .eval() on the result (inference mode, assuming
# the getter returns a torch nn.Module).
model = MODEL_GETTERS_DICT[MODEL_NAME](DEVICE, CPT_PATH).eval()

In [None]:
# Load the vocabulary as a list of lines, without the trailing newlines.
# The encoding is given explicitly so the result does not depend on the
# platform's locale default.
# NOTE(review): assumes voc.txt is UTF-8 encoded — confirm.
with open(voc_path, 'r', encoding='utf-8') as f:
    voc = f.read().splitlines()

In [None]:
# How many leading vocabulary words are scored against every sample.
N_CANDIDATE_WORDS = 3

# One candidate-word list per validation sample. Only the number of samples
# matters, so use len() instead of iterating the dataset, which would fetch
# (and transform) every item just to discard it.
# Assumes CurveDataset implements __len__ (it is used as a map-style dataset
# by val_loader) — TODO confirm.
word_lsts = [voc[:N_CANDIDATE_WORDS] for _ in range(len(val_dataset))]

# Same batch size and order as val_loader so that batches line up.
# NOTE(review): no collate_fn is given, so the default collate regroups a
# batch of string lists position-wise (N_CANDIDATE_WORDS groups of
# batch-size words) — confirm estimate_probs_of_words expects that layout.
word_lsts_loader = DataLoader(
    word_lsts,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=dataloader_workers_n,
    # persistent_workers=True is rejected by DataLoader when num_workers == 0.
    persistent_workers=dataloader_workers_n > 0,
)

In [None]:
# Score the candidate words of every validation sample with the model.
# batch_first=False matches the setting given to CollateFn.
# NOTE(review): the result is passed to torch.exp in the last cell, so it
# presumably holds log-probabilities — confirm in full_vocab_estimation.
probs = estimate_probs_of_words(model, val_loader, 
                                word_lsts_loader, 
                                char_tokenizer, batch_first=False, device=DEVICE)

In [None]:
# Quick sanity check of the size of the result.
probs.shape

In [None]:
# Exponentiate the scores for display (turns log-probabilities into
# probabilities, assuming that is what `probs` holds — confirm).
# `torch` is imported in the notebook's top import cell.
torch.exp(probs)