In [113]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits to project .py files are picked up.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [114]:
import os
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm

from model import SwipeCurveTransformer, get_m1_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv2
from word_generators import GreedyGenerator

In [115]:
# Switch between the Kaggle environment and a local checkout.
IN_KAGGLE = False

# (data root, trained-models directory) for the selected environment.
DATA_ROOT, MODELS_DIR = (
    ("/kaggle/input/yandex-cup-playground", "")
    if IN_KAGGLE
    else ("../data/data_separated_grid", "../data/trained_models/m1")
)

In [116]:
import pickle

# Character-level tokenizer built from the vocabulary file.
word_tokenizer = CharLevelTokenizerv2(os.path.join(DATA_ROOT, "voc.txt"))

# Local runs only: serialize the tokenizer to disk and read it straight back,
# so `word_tokenizer` ends up being the unpickled copy.
if not IN_KAGGLE:
    word_tokenizer_save_path = os.path.join(DATA_ROOT, "word_tokenizer.pkl")

    with open(word_tokenizer_save_path, 'wb') as f:
        f.write(pickle.dumps(word_tokenizer, protocol=pickle.HIGHEST_PROTOCOL))

    with open(word_tokenizer_save_path, 'rb') as f:
        word_tokenizer = pickle.loads(f.read())

In [117]:
# Show the tokenizer's index -> character table (letters plus special tokens).
print(word_tokenizer.idx_to_char)

{0: '-', 1: 'а', 2: 'б', 3: 'в', 4: 'г', 5: 'д', 6: 'е', 7: 'ж', 8: 'з', 9: 'и', 10: 'й', 11: 'к', 12: 'л', 13: 'м', 14: 'н', 15: 'о', 16: 'п', 17: 'р', 18: 'с', 19: 'т', 20: 'у', 21: 'ф', 22: 'х', 23: 'ц', 24: 'ч', 25: 'ш', 26: 'щ', 27: 'ъ', 28: 'ы', 29: 'ь', 30: 'э', 31: 'ю', 32: 'я', 33: '<eos>', 34: '<unk>', 35: '<pad>', 36: '<sos>'}


In [118]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Load a single keyboard grid description from a JSON file.

    Args:
        grid_name: Key of the requested grid in the top-level JSON object.
        grids_path: Path to a UTF-8 encoded JSON file that maps grid names
            to grid descriptions.

    Returns:
        The value stored under ``grid_name``.

    Raises:
        KeyError: If ``grid_name`` is not a key of the JSON object. The
            message names the file and lists the grids that are available.
    """
    with open(grids_path, "r", encoding="utf-8") as f:
        gridname_to_grid = json.load(f)

    try:
        return gridname_to_grid[grid_name]
    except KeyError:
        # Same exception type as a plain lookup, but with enough context to
        # see which file was read and which names would have worked.
        raise KeyError(
            f"grid {grid_name!r} not found in {grids_path!r}; "
            f"available grids: {sorted(gridname_to_grid)}"
        ) from None

In [120]:
# Keyboard layout to use and the files to read it and the swipe samples from.
grid_name = "default"
grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
sample_data = os.path.join(DATA_ROOT, "valid__in_train_format__default_only.jsonl")

grid = get_grid(grid_name, grid_path)
kb_tokenizer = KeyboardTokenizerv1()
# NOTE(review): this rebuilds `word_tokenizer` from voc.txt, replacing whatever
# instance an earlier cell left in the notebook namespace.
word_tokenizer = CharLevelTokenizerv2(os.path.join(DATA_ROOT, "voc.txt"))

dataset = NeuroSwipeDatasetv2(
    data_path=sample_data,
    gridname_to_grid={grid_name: grid},
    kb_tokenizer=kb_tokenizer,
    word_tokenizer=word_tokenizer,
    keyboard_selection_set=set(KeyboardTokenizerv1.i2t),
    max_traj_len=299,
    include_time=False,
    include_velocities=True,
    include_accelerations=True,
    # presumably the number of samples in `sample_data` — TODO confirm
    total=9_416,
)

100%|██████████| 9416/9416 [00:01<00:00, 5319.79it/s]


In [89]:
# Create the dataset and look at several examples, especially the accelerations and velocities.

In [90]:
# Take one (inputs, target) example and unpack the five input components.
x, target = dataset[5]
xyt, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask = x

In [91]:
# Shapes of every component of the model input, in order.
print(list(map(lambda el: el.shape, x)))

[torch.Size([299, 6]), torch.Size([299]), torch.Size([35]), torch.Size([299]), torch.Size([35])]


In [92]:
# Print each row of the trajectory feature tensor as plain Python floats.
for embed in xyt:
    print([float(value) for value in embed])

[0.32592591643333435, 0.5877061486244202, 0.0, 0.0, 0.0, 0.0]
[0.32592591643333435, 0.5877061486244202, -0.02777777798473835, -0.02777777798473835, -0.002314814832061529, -0.05092592537403107]
[0.32499998807907104, 0.5862069129943848, -0.0833333358168602, -1.8333333730697632, 0.0005787037080153823, -0.10706018656492233]
[0.3222222328186035, 0.4557721018791199, 0.0, -5.166666507720947, 0.002314814832061529, -0.09890571981668472]
[0.32499998807907104, 0.30734631419181824, 0.0, -5.39393949508667, -0.007792208343744278, 0.03708513453602791]
[0.3222222328186035, 0.18890555202960968, -0.2571428716182709, -3.942857027053833, -0.009523809887468815, 0.07878787815570831]
[0.3166666626930237, 0.10044977813959122, -0.3333333432674408, -2.6363637447357178, 0.0025974027812480927, 0.07792206853628159]
[0.31203705072402954, 0.05847076326608658, -0.17142857611179352, -1.3714286088943481, 0.007002801634371281, 0.05263560265302658]
[0.31111112236976624, 0.028485756367444992, -0.0882352963089943, -0.79411

In [68]:
# Decode the decoder-input token ids back to text to check the <sos>/<pad> layout.
word_tokenizer.decode(dec_in_char_seq)

'<sos>скинь<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [69]:
# Decode the target token ids back to text to check the <eos>/<pad> layout.
word_tokenizer.decode(target)

'скинь<eos><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [128]:
# Run the standalone dataset test script in a shell subprocess.
!python dataset_test.py

selection set and nearest labels set are different!
{'г', 'х', 'я', 'ъ', 'е', 'ы', 'п', 'с', 'ю', 'з', 'т', 'р', 'б', 'ч', 'о', 'м', 'ж', 'ш', 'л', 'ц', 'ь', 'щ', 'и', 'у', 'а', 'э', 'д', 'н', 'й', 'в', 'к', 'ф'}
{'г', 'х', 'я', 'ъ', 'е', 'п', 'ы', 'с', '<pad>', 'ю', '<unk>', 'з', 'т', 'р', 'б', 'ч', 'ж', 'м', 'о', 'л', 'ш', 'ц', 'ь', 'щ', 'и', 'у', '-', 'а', 'э', 'д', 'н', 'й', 'в', 'к', 'ф', 'ë'}
keyboard_selection_set success



  0%|          | 0/10000 [00:00<?, ?it/s]
  4%|▍         | 389/10000 [00:00<00:02, 3852.92it/s]
 11%|█▏        | 1144/10000 [00:00<00:01, 6019.07it/s]
 19%|█▉        | 1944/10000 [00:00<00:01, 6920.43it/s]
 27%|██▋       | 2727/10000 [00:00<00:01, 7250.53it/s]
 35%|███▍      | 3472/10000 [00:00<00:00, 7295.97it/s]
 42%|████▏     | 4240/10000 [00:00<00:00, 7401.92it/s]
 50%|████▉     | 4981/10000 [00:00<00:00, 7262.79it/s]
 57%|█████▋    | 5708/10000 [00:00<00:00, 7087.61it/s]
 65%|██████▍   | 6489/10000 [00:00<00:00, 7307.03it/s]
 72%|███████▏  | 7222/10000 [00:01<00:00, 7291.53it/s]
 80%|████████  | 8009/10000 [00:01<00:00, 7465.24it/s]
 88%|████████▊ | 8757/10000 [00:01<00:00, 7448.07it/s]
 95%|█████████▌| 9515/10000 [00:01<00:00, 7487.60it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7209.80it/s]
