In [76]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [77]:
import os
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm

from model import SwipeCurveTransformer, get_m1_model
from tokenizers import CharLevelTokenizerv1, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv1
from word_generators import GreedyGenerator

In [78]:
IN_KAGGLE = False

if IN_KAGGLE:
    DATA_ROOT = "/kaggle/input/yandex-cup-playground"
    MODELS_DIR = ""
else:
    DATA_ROOT = "../data/data_separated_grid"
    MODELS_DIR = "../data/trained_models/m1"

In [79]:
def truncate_padding(seq, mask):
    max_curve_len = int(torch.max(torch.sum(~mask, dim = 1)))
    seq = seq[:, :max_curve_len]
    mask = mask[:, :max_curve_len]
    return seq, mask

In [80]:
grid_name_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
with open(grid_name_to_grid_path, "r", encoding="utf-8") as f:
    grid_name_to_grid = json.load(f)

val_path = os.path.join(DATA_ROOT, f"valid__in_train_format.jsonl")

kb_tokenizer = KeyboardTokenizerv1()
word_tokenizer = CharLevelTokenizerv1(os.path.join(DATA_ROOT, "voc.txt"))


val_dataset = NeuroSwipeDatasetv1(
    data_path = val_path,
    gridname_to_grid = grid_name_to_grid,
    kb_tokenizer = kb_tokenizer,
    max_traj_len = 299,
    word_tokenizer = word_tokenizer,
    include_time = False,
    include_velocities = True,
    include_accelerations = True,
    has_target=True,
    has_one_grid_only=False,
    include_grid_name=True,
    total = 10_000
)



100%|██████████| 10000/10000 [00:03<00:00, 3185.73it/s]


In [81]:
val_loader = DataLoader(val_dataset, 5, shuffle=False)

In [82]:
batch = next(iter(val_loader))

In [83]:
(xyt, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask), target, g_name = batch

In [75]:
traj_pad_mask[0]

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, 

In [45]:
# torch.set_printoptions(threshold=10_000)

# print(kb_tokens)
# kb_tokens, traj_pad_mask = truncate_padding(kb_tokens, traj_pad_mask)
# print(kb_tokens)

In [46]:
model = get_m1_model(os.path.join(MODELS_DIR,
                                  "best_model__2023_11_04__18_31_37__0.02530_default_switch_2.pt"),
                     device='cpu')

In [48]:
output_before_truncation = model(xyt.transpose(0,1),
                                 kb_tokens.transpose(0,1),
                                 dec_in_char_seq.transpose(0,1),
                                 traj_pad_mask,
                                 word_pad_mask)



In [50]:
xyt, traj_pad_mask = truncate_padding(xyt, traj_pad_mask)
kb_tokens, traj_pad_mask = truncate_padding(kb_tokens, traj_pad_mask)
dec_in_char_seq, word_pad_mask = truncate_padding(kb_tokens, traj_pad_mask)

In [51]:
output_after_truncation = model(xyt.transpose(0,1),
                                 kb_tokens.transpose(0,1),
                                 dec_in_char_seq.transpose(0,1),
                                 traj_pad_mask,
                                 word_pad_mask)



In [66]:
torch.allclose(output_before_truncation, output_after_truncation, atol = 1e-6)

False

In [55]:
print(output_before_truncation)

tensor([[[-2.1640e+01,  3.9447e+00, -6.8112e-01, -3.2235e+00,  5.1712e+00,
           9.2946e-01,  2.3947e+00, -3.5701e+00,  1.9089e+00,  2.6484e-01,
          -5.5972e+00,  1.6402e+00,  3.9524e-02,  9.9606e-02,  1.3544e+01,
           1.5968e+00, -1.6720e-01, -1.7022e+00, -2.1444e+00, -8.1203e-01,
          -1.6630e+00, -6.4178e+00, -1.3355e+00, -5.3939e+00, -6.4922e+00,
           1.3956e+00,  9.5815e-02, -7.5364e+00, -4.4702e+00, -5.0071e+00,
          -4.6992e+00, -4.3204e+00, -1.8045e+00,  6.9030e-01, -9.2705e+00],
         [-1.8800e+01,  2.7956e+00, -2.8284e+00,  1.5690e+01, -6.3868e+00,
          -2.7300e+00,  2.6580e+00, -3.9543e+00, -1.5982e+00,  4.4504e-01,
          -4.7149e-01,  5.5964e-01, -9.7434e-01,  1.7409e-01, -1.8291e+00,
          -2.0795e-01, -2.3277e-01,  3.8752e-02,  3.9486e+00, -1.4775e+00,
           2.3515e+00,  2.8121e+00, -6.1756e+00,  1.8705e+00,  3.4870e+00,
          -3.1141e+00, -6.7650e+00, -5.4600e+00,  6.0021e-01, -6.1514e+00,
          -5.1412e+00, -

In [56]:
print(output_after_truncation)

tensor([[[-2.1640e+01,  3.9447e+00, -6.8112e-01, -3.2235e+00,  5.1712e+00,
           9.2946e-01,  2.3947e+00, -3.5701e+00,  1.9089e+00,  2.6484e-01,
          -5.5972e+00,  1.6402e+00,  3.9524e-02,  9.9606e-02,  1.3544e+01,
           1.5968e+00, -1.6720e-01, -1.7022e+00, -2.1444e+00, -8.1203e-01,
          -1.6630e+00, -6.4178e+00, -1.3355e+00, -5.3939e+00, -6.4922e+00,
           1.3956e+00,  9.5815e-02, -7.5364e+00, -4.4702e+00, -5.0071e+00,
          -4.6992e+00, -4.3204e+00, -1.8045e+00,  6.9030e-01, -9.2705e+00],
         [-1.8800e+01,  2.7956e+00, -2.8284e+00,  1.5690e+01, -6.3868e+00,
          -2.7300e+00,  2.6580e+00, -3.9543e+00, -1.5982e+00,  4.4504e-01,
          -4.7149e-01,  5.5964e-01, -9.7434e-01,  1.7409e-01, -1.8291e+00,
          -2.0795e-01, -2.3277e-01,  3.8752e-02,  3.9486e+00, -1.4775e+00,
           2.3515e+00,  2.8121e+00, -6.1756e+00,  1.8705e+00,  3.4870e+00,
          -3.1141e+00, -6.7650e+00, -5.4600e+00,  6.0021e-01, -6.1514e+00,
          -5.1412e+00, -

In [88]:
a = [1, 2,4]
b = torch.tensor([1, 4, 5, 7, 8, 9], dtype=torch.long)

In [86]:
b[:len(a)] = a

TypeError: can't assign a list to a torch.LongTensor