In [2]:
%load_ext autoreload
%autoreload 2

In [None]:
# !git clone https://github.com/proshian/yandex-cup-2023-ml-neuroswipe.git
# %cd yandex-cup-2023-ml-neuroswipe
# ! git checkout datasetv4

In [None]:
# ! pip install gdown
# ! python ./src/downloaders/download_weights.py

In [None]:
# !pip install dvc --quiet
# !pip install dvc_gdrive --quiet

In [None]:
# %cd /kaggle/working/yandex-cup-2023-ml-neuroswipe
# ! git pull
# ! git checkout datasetv4

In [None]:
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [2]:
############# Script arguments emulation #############
# Notebook-level settings that stand in for script arguments.

GRID_NAME = "default"  # key into GRID_NAME_TO_DS_PATHS and into the grids JSON file
BATCH_SIZE = 320  # batch size passed to train_eval_loop
IN_KAGGLE = False  # only referenced by the commented-out Kaggle override cell
RANDOM_SEED = 12  # value passed to init_random_seed

DATA_ROOT = "../data/data_separated_grid"  # directory holding the .jsonl datasets, voc.txt and gridname_to_grid.json
MODELS_DIR = "../data/trained_models/m1"  # NOTE(review): not referenced in the visible part of the notebook

In [3]:
import os
import json
import typing as tp
import traceback
from datetime import datetime
import copy

import torch
# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter


from model import SwipeCurveTransformer, get_m1_bigger_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from tokenizers import ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from dataset import CurveDataset, CollateFn
from word_generators import GreedyGenerator
from nearest_key_lookup import ExtendedNearestKeyLookup
from transforms import InitTransform, GetItemTransform

In [4]:
################ Other constants ####################
GRID_NAME_TO_DS_PATHS = {
    "extra": {
        "train": os.path.join(DATA_ROOT, "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl"),
        "val": os.path.join(DATA_ROOT, "valid__in_train_format__extra_only.jsonl")
    },
    "default": {
        "train": os.path.join(DATA_ROOT, "train__default_only_no_errors__2023_10_31__03_26_16.jsonl"),
        "val": os.path.join(DATA_ROOT, "valid__in_train_format__default_only.jsonl")
    }
}

In [5]:
# if IN_KAGGLE:
#     DATA_ROOT = "/kaggle/input/neuroswipe-defualt-only-v1"
#     MODELS_DIR = ""

In [6]:
def init_random_seed(value):
    """Seed the NumPy and PyTorch (CPU and CUDA) random number generators with `value`."""
    # Python's built-in `random` module and the cuDNN determinism flag are
    # deliberately left untouched, as before.
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(value)

In [7]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Read the JSON file at `grids_path` and return the entry stored under `grid_name`."""
    with open(grids_path, "r", encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [8]:
from typing import List, Dict, Tuple, Optional, Set

def get_gridname_to_out_of_bounds_coords_dict(
        data_paths: List[str], gridname_to_wh: dict,
        total: Optional[int] = None
        ) -> Dict[str, Set[Tuple[int, int]]]:
    """
    Collect, per grid name, the set of (x, y) curve points lying outside the grid.

    Every file in `data_paths` is read line by line as JSON; each line must
    contain a 'curve' dict with 'grid_name', 'x' and 'y'. A point is out of
    bounds when x is not in [0, width) or y is not in [0, height), where
    (width, height) = gridname_to_wh[grid_name].

    `total` is only forwarded to the tqdm progress bar.
    The result has one (possibly empty) set for every key of `gridname_to_wh`.
    """
    oob_by_grid = {name: set() for name in gridname_to_wh}

    for path in data_paths:
        with open(path, "r", encoding="utf-8") as jsonl_file:
            for raw_line in tqdm(jsonl_file, total=total):
                curve = json.loads(raw_line)['curve']
                name = curve['grid_name']
                width, height = gridname_to_wh[name]
                outside = set()
                for x, y in zip(curve['x'], curve['y']):
                    if x < 0 or x >= width or y < 0 or y >= height:
                        outside.add((x, y))
                oob_by_grid[name].update(outside)
    return oob_by_grid

In [9]:
def get_datasets(grid_name: str, grid_name_to_grid_path: str,
                 train_data_path: str, val_data_path: str,
                 nearest_key_candidates: tp.Set[str],
                 kb_tokenizer: KeyboardTokenizerv1,
                 word_char_tokenizer: CharLevelTokenizerv2
                 ) -> CurveDataset:
    """
    Build the validation CurveDataset for a single keyboard grid.

    Despite the plural name, only the validation dataset is created and
    returned (the return annotation used to claim a (train, val) tuple).
    The train file is still read, but only to collect out-of-bounds
    coordinates for the nearest-key lookup.

    Arguments:
    ----------
    grid_name: name of the grid to load from the grids JSON file.
    grid_name_to_grid_path: path to the JSON file mapping grid names to grids.
    train_data_path, val_data_path: .jsonl files scanned for out-of-bounds
        curve points; `val_data_path` is also the data of the returned dataset.
    nearest_key_candidates: passed to ExtendedNearestKeyLookup.
    kb_tokenizer: passed to InitTransform.
    word_char_tokenizer: passed to GetItemTransform as the word tokenizer.

    Returns:
    --------
    The validation CurveDataset.
    """
    gridname_to_grid = {grid_name: get_grid(grid_name, grid_name_to_grid_path)}

    gname_to_wh = {
        gname: (grid['width'], grid['height'])
        for gname, grid in gridname_to_grid.items()
    }

    print("Accumulating out-of-bounds coordinates...")
    # `total` only sizes the progress bar; it is a hard-coded upper estimate.
    gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
        [train_data_path, val_data_path], gname_to_wh, total=6_000_000
    )

    print("Creating ExtendedNearestKeyLookups...")
    gridname_to_nkl = {
        gname: ExtendedNearestKeyLookup(grid, nearest_key_candidates, gname_to_out_of_bounds[gname])
        for gname, grid in gridname_to_grid.items()
    }

    init_transform = InitTransform(
        grid_name_to_nk_lookup=gridname_to_nkl,
        kb_tokenizer=kb_tokenizer,
    )

    get_item_transform = GetItemTransform(
        grid_name_to_wh=gname_to_wh,
        word_tokenizer=word_char_tokenizer,
        include_time=False,
        include_velocities=True,
        include_accelerations=True,
    )

    val_ds = CurveDataset(
        data_path=val_data_path,
        store_gnames = False,
        init_transform=init_transform,
        get_item_transform=get_item_transform,
        total = 9_416,  # presumably the number of validation samples -- TODO confirm
    )

    return val_ds

In [10]:
# Seed the NumPy and PyTorch RNGs before any dataset or model is created.
init_random_seed(RANDOM_SEED)

In [11]:
# Pickling the dataset would be great to not waste
# around 20 minutes creating train_dataset.

# Tokenizers: one for keyboard keys, one for word characters
# (the latter is built from the vocabulary file in DATA_ROOT).
kb_tokenizer = KeyboardTokenizerv1()
voc_path=os.path.join(DATA_ROOT, "voc.txt")
word_char_tokenizer = CharLevelTokenizerv2(voc_path)

# get_datasets returns a single dataset (the validation one); the train path
# is only used there to collect out-of-bounds coordinates.
val_dataset = get_datasets(
    grid_name=GRID_NAME,
    grid_name_to_grid_path=os.path.join(DATA_ROOT, "gridname_to_grid.json"),
    train_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['train'],
    val_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['val'],
    nearest_key_candidates = ALL_CYRILLIC_LETTERS_ALPHABET_ORD,
    kb_tokenizer=kb_tokenizer,
    word_char_tokenizer=word_char_tokenizer,
)

Accumulating out-of-bounds coordinates...


  0%|          | 0/6000000 [00:00<?, ?it/s]

  0%|          | 0/6000000 [00:00<?, ?it/s]

Creating ExtendedNearestKeyLookups...


100%|██████████| 9416/9416 [00:01<00:00, 5503.85it/s]


In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the model on `device`; no weights argument is passed in this call.
transformer = get_m1_bigger_model(device)

In [None]:
def cross_entropy_with_reshape(pred, target, ignore_index=-100, label_smoothing=0.0):
    """
    Cross-entropy over a batch of sequences.

    The two leading axes of `pred` and the two axes of `target` are flattened
    together, so they only have to agree with each other; both
    (BatchSize, TargetLen) and (TargetLen, BatchSize) layouts work.

    pred - D0 x D1 x VocabSize (logits)
    target - D0 x D1 (class indices)
    ignore_index - target value that does not contribute to the loss
    label_smoothing - forwarded to F.cross_entropy

    Returns the scalar loss tensor produced by F.cross_entropy.
    """
    # reshape (unlike view) also accepts non-contiguous tensors,
    # e.g. logits that were transposed before being passed in.
    pred_flat = pred.reshape(-1, pred.shape[-1])  # D0*D1 x VocabSize
    target_flat = target.reshape(-1)  # D0*D1
    return F.cross_entropy(pred_flat,
                           target_flat,
                           ignore_index=ignore_index,
                           label_smoothing=label_smoothing)

In [None]:
def lr_scheduler(optimizer):
    """
    Create a ReduceLROnPlateau scheduler (patience=20, factor=0.5) for `optimizer`.

    Verbose reporting is requested when the installed PyTorch still accepts
    the `verbose` argument; the argument is deprecated in recent releases, so
    the scheduler is created without it if the constructor rejects it.
    """
    scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
    try:
        return scheduler_cls(optimizer, patience=20, factor=0.5, verbose=True)
    except TypeError:
        # Either `verbose` is not supported by this PyTorch version, or
        # `optimizer` is invalid; in the latter case the retry raises again.
        return scheduler_cls(optimizer, patience=20, factor=0.5)

In [None]:
def move_all_to_device(x, device):
    """
    Move a tensor, or a list/tuple of tensors, to `device`.

    A single tensor is returned as a tensor; a list or tuple is returned as a
    new list. Anything else (including non-tensor elements inside the
    sequence) raises ValueError.
    """
    def _checked_move(el):
        if torch.is_tensor(el):
            return el.to(device)
        raise ValueError(f'Unexpected data type {type(el)}')

    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, (list, tuple)):
        return [_checked_move(el) for el in x]
    raise ValueError(f'Unexpected data type {type(x)}')

In [None]:
# Batch collation callable for the DataLoaders; words are padded with the
# '<pad>' token index and batch_first=False is requested.
collate_fn = CollateFn(
    word_pad_idx = word_char_tokenizer.char_to_idx['<pad>'], batch_first = False)

In [None]:
# Sanity-check collate_fn (DataLoader calls it implicitly).
# The check runs on val_dataset, the only dataset this notebook builds;
# the previously used `train_dataset` is never defined here.

batch_size = 6


PAD_CHAR_TOKEN = word_char_tokenizer.char_to_idx["<pad>"]


check_dataset = val_dataset
check_dataloader = DataLoader(check_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, collate_fn=collate_fn)


# The same first `batch_size` elements, taken directly from the dataset
# (shuffle=False above guarantees the first batch holds exactly these).
dataset_els = [check_dataset[i] for i in range(batch_size)]
unproc_batch_x, unproc_batch_y = zip(*dataset_els)

batch_x, batch_y = next(iter(check_dataloader))


############### Check batch_y ###################
max_out_seq_len = max([len(y) for y in unproc_batch_y])

assert batch_y.shape == (max_out_seq_len, batch_size)


for i in range(batch_size):
    # Real tokens first, then padding up to the longest sequence.
    assert (batch_y[:len(unproc_batch_y[i]), i] == unproc_batch_y[i]).all()
    assert (batch_y[len(unproc_batch_y[i]):, i] == PAD_CHAR_TOKEN).all()

print("batch_y is correct")



############### Check batch_x ###################
unproc_batch_traj_feats, unproc_batch_kb_tokens, unproc_batch_dec_in_char_seq = zip(*unproc_batch_x)

(traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask) = batch_x


# Each entity unpacked from unproc_batch_x above is a tuple of length batch_size.
# For example, unproc_batch_traj_feats[i] = check_dataset[i][0][0]

N_TRAJ_FEATS = 6
max_curve_len = max([el.shape[0] for el in unproc_batch_traj_feats])

assert max_curve_len == max([el.shape[0] for el in unproc_batch_kb_tokens])

assert traj_feats.shape == (max_curve_len, batch_size, N_TRAJ_FEATS)
assert kb_tokens.shape == (max_curve_len, batch_size)
assert dec_in_char_seq.shape == (max_out_seq_len, batch_size)
assert traj_pad_mask.shape == (batch_size, max_curve_len)
assert word_pad_mask.shape == (batch_size, max_out_seq_len)


for i in range(batch_size):
    assert (traj_feats[:len(unproc_batch_traj_feats[i]), i] == unproc_batch_traj_feats[i]).all()
    assert (kb_tokens[:len(unproc_batch_kb_tokens[i]), i] == unproc_batch_kb_tokens[i]).all()

    assert (dec_in_char_seq[:len(unproc_batch_dec_in_char_seq[i]), i] == unproc_batch_dec_in_char_seq[i]).all()
    assert (dec_in_char_seq[len(unproc_batch_dec_in_char_seq[i]):, i] == PAD_CHAR_TOKEN).all()

    # Masks are False on real positions and True on padding.
    assert (traj_pad_mask[i, :len(unproc_batch_traj_feats[i])] == False).all()
    assert (traj_pad_mask[i, len(unproc_batch_traj_feats[i]):] == True).all()

    assert (word_pad_mask[i, :len(unproc_batch_dec_in_char_seq[i])] == False).all()
    assert (word_pad_mask[i, len(unproc_batch_dec_in_char_seq[i]):] == True).all()

print("batch_x is correct")

In [None]:
from typing import List


def predict_greedy_raw(dataset,
                       greedy_word_generator: GreedyGenerator,
                      ) -> List[str]:
    """
    Creates predictions using greedy generation.

    Supposed to be used with a dataset of a single grid

    Arguments:
    ----------
    dataset:
        Sized iterable whose elements have the form
        ((trajectory_features, kb_tokens, _), _).
    greedy_word_generator: GreedyGenerator
        Its generate_word_only(trajectory_features, kb_tokens) method is
        called once per dataset element.

    Returns:
    --------
    One predicted word per dataset element, in dataset order, with a
    leading "<sos>" removed when present.
    """
    preds = [None] * len(dataset)

    for i, ((xyt, kb_tokens, _), _) in tqdm(enumerate(dataset), total=len(dataset)):
        pred = greedy_word_generator.generate_word_only(xyt, kb_tokens)
        preds[i] = pred.removeprefix("<sos>")

    return preds


def get_targets(dataset: tp.Iterable, tokenizer=None) -> tp.List[str]:
    """
    Decode the target word of every dataset element.

    Arguments:
    ----------
    dataset:
        Iterable of (model_input, target_tokens) pairs.
    tokenizer:
        Object with a decode(tokens) method. When None (the default), the
        module-level `word_char_tokenizer` is used, as before.

    Returns:
    --------
    List of decoded target strings in dataset order.
    """
    if tokenizer is None:
        tokenizer = word_char_tokenizer
    # The last token is <eos>, so it is dropped before decoding.
    return [tokenizer.decode(target_tokens[:-1])
            for _, target_tokens in dataset]


def get_accuracy(preds, targets) -> float:
    """
    Share of positions where the prediction equals the target.

    Pairs are formed with zip (so extra elements of the longer input are
    ignored), while the denominator is always len(targets).
    """
    n_correct = 0
    for pred, target in zip(preds, targets):
        n_correct += (pred == target)
    return n_correct / len(targets)


def get_greedy_generator_accuracy(val_dataset, model,
                                  word_char_tokenizer, device) -> float:
    """
    Accuracy of greedy decoding with `model` on `val_dataset`.

    Targets are obtained via get_targets(val_dataset); predictions come from
    a GreedyGenerator built from `model`, `word_char_tokenizer` and `device`.
    """
    targets = get_targets(val_dataset)
    predictions = predict_greedy_raw(
        val_dataset, GreedyGenerator(model, word_char_tokenizer, device))
    return get_accuracy(predictions, targets)

In [None]:
###################### протестируем predict_greedy_raw ######################


# Checkpoint passed as `model_weights` to test_greedy_generator in this cell.
MODEL_TO_TEST_GREEDY_GEN__PATH = "../data/trained_models_for_final_submit/m1_bigger/" \
    "m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt"

# Disabled alternative: setting the path to None leads to very slow inference.
# The author's guess is that operations on small-amplitude floats are
# expensive -- TODO confirm.
# MODEL_TO_TEST_GREEDY_GEN__PATH = None


def test_greedy_generator(val_dataset, model_getter, model_weights, word_char_tokenizer, device) -> float:
    """
    Build a model via model_getter(device, model_weights) and return its
    greedy-decoding accuracy on `val_dataset`.
    """
    return get_greedy_generator_accuracy(
        val_dataset, model_getter(device, model_weights), word_char_tokenizer, device)



# Greedy-decoding accuracy of the checkpoint above on val_dataset;
# as the last expression of the cell, the value is shown as its output.
test_greedy_generator(val_dataset, get_m1_bigger_model, MODEL_TO_TEST_GREEDY_GEN__PATH, word_char_tokenizer, device)

In [None]:
def train_eval_loop(model, train_dataset, val_dataset, criterion,
                    tb, epoch_start, lr=1e-4, epoch_n=10, batch_size=32,
                    collate_fn = None,
                    device=None, early_stopping_patience=20, l2_reg_alpha=0,
                    max_batches_per_epoch_train=10000,
                    max_batches_per_epoch_val=1000,
                    optimizer_ctor=None,
                    lr_scheduler_ctor=None,
                    shuffle_train=True,
                    label_smoothing = 0.0,
                    dataloader_workers_n=0,
                    criterion_ignore_index = -100,
                    model_name_postfix = "",
                    model_save_root = ".",
                    ):
    """
    Training loop. After every epoch the model is evaluated on the validation set.

    :param model: torch.nn.Module - the model to train
    :param train_dataset: torch.utils.data.Dataset - training data
    :param val_dataset: torch.utils.data.Dataset - validation data
    :param criterion: loss function; called as
        criterion(pred, target, ignore_index=..., label_smoothing=...)
    :param tb: writer with an add_scalar(tag, value, step) method
        (a SummaryWriter in this notebook)
    :param epoch_start: index of the first epoch; epochs run over
        range(epoch_start, epoch_start + epoch_n)
    :param lr: learning rate
    :param epoch_n: maximum number of epochs
    :param batch_size: number of examples processed per iteration
    :param collate_fn: collate function passed to both DataLoaders
    :param device: cuda/cpu - device to compute on (auto-detected when None)
    :param early_stopping_patience: largest number of epochs without
        improvement of the validation loss for which training continues
    :param l2_reg_alpha: L2 regularization coefficient; used as weight_decay
        of the default Adam optimizer only (ignored when optimizer_ctor is given)
    :param max_batches_per_epoch_train: maximum number of training batches per epoch
    :param max_batches_per_epoch_val: maximum number of validation batches per epoch
    :param optimizer_ctor: optional callable (params, lr=...) -> optimizer
    :param lr_scheduler_ctor: optional callable optimizer -> scheduler whose
        step(mean_val_loss) is called after every epoch
    :param shuffle_train: whether the training DataLoader shuffles
    :param label_smoothing: forwarded to criterion
    :param dataloader_workers_n: num_workers of both DataLoaders
    :param criterion_ignore_index: forwarded to criterion as ignore_index
    :param model_name_postfix: inserted into the per-improvement checkpoint name
    :param model_save_root: directory checkpoints are written to
    :return: tuple of two elements:
        - mean validation loss on the best epoch (inf if no epoch improved)
        - a copy of the best model (a copy of the initial model if no epoch improved)

    Note: greedy accuracy is computed with the module-level `word_char_tokenizer`.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)

    if optimizer_ctor is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg_alpha)
    else:
        optimizer = optimizer_ctor(model.parameters(), lr=lr)

    # Named `scheduler` so that it does not shadow the notebook-level
    # `lr_scheduler` factory function.
    scheduler = lr_scheduler_ctor(optimizer) if lr_scheduler_ctor is not None else None

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train,
                                  num_workers=dataloader_workers_n, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=dataloader_workers_n, collate_fn=collate_fn)

    best_val_loss = float('inf')
    # Start counting "epochs without improvement" from the first epoch that
    # is actually run, not from epoch 0.
    best_epoch_i = epoch_start

    best_model_path = "m1_bigger_v2.pt"

    # Used only to scale the tensorboard step axis (examples seen so far).
    n_train_examples_in_epoch = (batch_size * max_batches_per_epoch_train
                                 if max_batches_per_epoch_train < len(train_dataset) // batch_size
                                 else len(train_dataset))

    # NOTE(review): the existence check uses the current working directory,
    # while checkpoints are saved under `model_save_root`; kept as is.
    if os.path.exists(best_model_path):
        # The weights are loaded into `model` (previously an unassigned
        # `best_model` was referenced here, which raised UnboundLocalError).
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"Загружено состояние модели {best_model_path}")

    # Fallback return value for the case when no epoch improves the loss.
    best_model = copy.deepcopy(model)

    for epoch_i in tqdm(range(epoch_start, epoch_start + epoch_n), position = 0):
        try:
            model.train()
            mean_train_loss = 0
            train_batches_n = 0
            for batch_i, (batch_x, batch_y) in tqdm(enumerate(train_dataloader), total = min(max_batches_per_epoch_train, len(train_dataset) // batch_size), position=1, leave = False):
                # `>=` so that exactly max_batches_per_epoch_train batches run.
                if batch_i >= max_batches_per_epoch_train:
                    break

                batch_x, batch_y = [move_all_to_device(el, device) for el in (batch_x, batch_y)]

                pred = model(*batch_x)
                loss = criterion(pred, batch_y, ignore_index = criterion_ignore_index,
                                 label_smoothing=label_smoothing)

                model.zero_grad()
                loss.backward()

                optimizer.step()

                mean_train_loss += float(loss)
                train_batches_n += 1

            mean_train_loss /= train_batches_n

            print('Среднее значение функции потерь на обучении', mean_train_loss)

            tb.add_scalar('mean_loss/train', mean_train_loss, epoch_i * n_train_examples_in_epoch)

            model.eval()
            mean_val_loss = 0
            val_batches_n = 0

            with torch.no_grad():
                for batch_i, (batch_x, batch_y) in enumerate(val_dataloader):
                    if batch_i >= max_batches_per_epoch_val:
                        break

                    batch_x, batch_y = [move_all_to_device(el, device) for el in (batch_x, batch_y)]

                    pred = model(*batch_x)
                    loss = criterion(pred, batch_y,
                                     ignore_index = criterion_ignore_index,
                                     label_smoothing=label_smoothing)

                    mean_val_loss += float(loss)
                    val_batches_n += 1

            mean_val_loss /= val_batches_n
            print('Среднее значение функции потерь на валидации', mean_val_loss)
            tb.add_scalar('mean_loss/val', mean_val_loss, epoch_i * n_train_examples_in_epoch)

            if mean_val_loss < best_val_loss:
                best_epoch_i = epoch_i
                best_val_loss = mean_val_loss
                best_model = copy.deepcopy(model)
                # Rolling "latest best" checkpoint with a fixed name.
                torch.save(model.state_dict(), os.path.join(model_save_root, best_model_path))

                cur_time = "{:%Y_%m_%d__%H_%M_%S}".format(datetime.now())

                greedy_accuracy = get_greedy_generator_accuracy(val_dataset, model, word_char_tokenizer, device)
                tb.add_scalar('greedy_accuracy/val', greedy_accuracy, epoch_i * n_train_examples_in_epoch)

                # Uniquely named checkpoint that records the metrics of this epoch.
                torch.save(model.state_dict(), os.path.join(model_save_root, f"m1_bigger_v2__{cur_time}__{mean_val_loss:.5f}__greed_acc_{greedy_accuracy:.5f}__{model_name_postfix}__epoch_i_{epoch_i}.pt"))
                print(f"Greedy accuracy = {greedy_accuracy}")
                print('Новая лучшая модель!')
            elif epoch_i - best_epoch_i > early_stopping_patience:
                print('Модель не улучшилась за последние {} эпох, прекращаем обучение'.format(
                    early_stopping_patience))
                break

            if scheduler is not None:
                scheduler.step(mean_val_loss)

            print()
        except KeyboardInterrupt:
            print('Досрочно остановлено пользователем')
            break
        except Exception as ex:
            # Deliberately best-effort: report the error and stop training,
            # still returning the best model found so far.
            print('Ошибка при обучении: {}\n{}'.format(ex, traceback.format_exc()))
            break

    return best_val_loss, best_model

In [None]:
# Experiment name encodes the grid, batch size and seed; it becomes the
# sub-directory of the tensorboard log.
EXPERIMENT_NAME = f"m1_bigger_model__{GRID_NAME}__from_random_weights__batch__{BATCH_SIZE}/SEED_{RANDOM_SEED}__run1"
# NOTE(review): the log root is a hard-coded Kaggle path even though
# IN_KAGGLE is set to False in the arguments cell -- confirm this is intended.
TENSORBOARD_LOG_PATH = f"/kaggle/working/tensorboard_log/{EXPERIMENT_NAME}"

tb = SummaryWriter(TENSORBOARD_LOG_PATH)


In [None]:
# Regularization settings for this run (the trailing comments hold
# alternative values) and the index of the first epoch.
l2_reg_alpha =  0 #5e-5
label_smoothing=  0 #0.045
epoch_start = 0

# NOTE(review): `train_dataset` is not defined anywhere in the visible part
# of this notebook (get_datasets returns only the validation dataset), so
# this cell needs it to be created first.
# The '<pad>' token index is passed as ignore_index to the criterion.
best_val_loss, best_model = train_eval_loop(
    transformer, train_dataset, val_dataset, cross_entropy_with_reshape, tb, epoch_start,
    lr=1e-4, epoch_n=10000, batch_size=BATCH_SIZE, collate_fn = collate_fn,
    device=device, early_stopping_patience=20, l2_reg_alpha=l2_reg_alpha,
    max_batches_per_epoch_train=2000,
    max_batches_per_epoch_val=1000,
    optimizer_ctor=None,
    lr_scheduler_ctor=lr_scheduler,
    shuffle_train=True,
    dataloader_workers_n=0,
    criterion_ignore_index = word_char_tokenizer.char_to_idx['<pad>'],
    model_name_postfix = f'{GRID_NAME}_l2_{l2_reg_alpha}_ls{label_smoothing}',
    model_save_root = "../..",
    label_smoothing=label_smoothing,
)

In [None]:
# An epoch should take 16 minutes

In [12]:
import json
from typing import Optional, List, Tuple, Dict, Set
import array

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1


class NeuroSwipeDatasetv3(Dataset):
    """
    Dataset class for NeuroSwipe dataset.

    The dataset uses all data from a given json in the same order as the json.
    There are separate json files for every grid.

    All lines are parsed eagerly in __init__ and stored compactly in
    `self.data_list`; feature tensors are built lazily in __getitem__.

    Given a NeuroSwipeDatasetv3 object nsd, nsd[i] is a tuple:
        ((trajectory_features, kb_tokens, decoder_in_char_seq, word_pad_mask),
         decoder_out_char_seq)
    When include_grid_name is True, grid_name is appended as a third element.
    When has_target is False, decoder_in_char_seq, word_pad_mask and
    decoder_out_char_seq are None.

    ! Warning: refactoring planned so that the output is a dictionary.

    WARNING:
    The class is in the process of refactoring. The padding will be done
    in DataLoaders collate_fn instead of __getitem__ method.
    Currently traj_pad_mask was removed. Word mask will be removed later.

    ! It seems reasonable for the dataset to always return grid_name as a
    dict property. We just won't use it in collate function.
    """

    # Tokenizer annotations are forward references (strings), so they are
    # not evaluated when the class is defined.
    def __init__(self,
                 data_path: str,
                 gridname_to_grid: dict,
                 kb_tokenizer: "KeyboardTokenizerv1",
                 word_tokenizer: "CharLevelTokenizerv2",  # should contain max word len
                 include_time: bool = False,
                 include_velocities: bool = True,
                 include_accelerations: bool = True,
                 include_grid_name: bool = False,
                 has_target: bool = True,
                 has_one_grid_only: bool = True,
                 keyboard_selection_set: Optional[Set[str]] = None,
                 total: Optional[int] = None):
        """
        Arguments:
        ----------
        data_path: str
            Path to the NeuroSwipe dataset in JSON format.
            A custom version of the dataset is used: "grid" property
            is replaced with "grid_name". The grid itself is stored in
            a separate gridname_to_grid dictionary.
            Dataset is a list of JSON lines. Each line is a dictionary
            with the following properties:
            - word (str): word that was typed. May be absent if has_target is False.
            - curve (dict): dictionary that contains the following properties:
                - x (List[int]): x coordinates of the swipe trajectory.
                - y (List[int]): y coordinates of the swipe trajectory.
                - t (List[int]): time in milliseconds from the beginning of the swipe.
                - grid_name (str): name of the keyboard grid.

        gridname_to_grid: dict
            Dictionary that maps grid_name to grid.
            Grid is a dictionary that contains the following properties:
                - width (int): width of the keyboard in pixels.
                - height (int): height of the keyboard in pixels.
                - keys (List[dict]): list of keys. Each key is a dictionary
                    that contains the following properties:
                    - label (str): label of the key. May be absent if the key
                        is a special key (e.g. backspace).
                    - action (str): action of the key. May be absent if the key
                        is a character key (e.g. 'a', 'б', 'в').
                    - hitbox (dict): dictionary that contains the following properties:
                        - x (int): x coordinate of the top left corner of the key.
                        - y (int): y coordinate of the top left corner of the key.
                        - w (int): width of the key.
                        - h (int): height of the key.

        kb_tokenizer:
            Object whose get_token(label) maps a key label to a token id.

        word_tokenizer:
            Object providing encode(word), max_word_len and
            char_to_idx['<pad>'].

        include_time, include_velocities, include_accelerations: bool
            Select the optional trajectory feature columns: time, first
            derivatives (dx/dt, dy/dt) and second derivatives.
            Accelerations require velocities.

        include_grid_name: bool
            If True, __getitem__ also returns the grid name.

        has_target: bool
            If True, every JSON line must have a 'word' property.

        has_one_grid_only: bool
            If True, gridname_to_grid must contain exactly one grid.

        keyboard_selection_set: Optional[Set[str]]
            Set of keyboard key labels allowed. When looking
            for a key with the nearest to a trajectory point
            center coordinates we only consider keys with labels
            from this set.
            If None, all keys are allowed.
            Isn't used explicitly: only in is_allowed_label method.

        total: Optional[int]
            Number of dataset elements. Is used only for progress bar.

        Raises:
        -------
        ValueError
            If accelerations are requested without velocities, or if
            has_one_grid_only is True and gridname_to_grid does not contain
            exactly one grid.
        """
        # The messages are built from adjacent literals: a backslash
        # continuation inside a literal would embed the indentation of the
        # next source line into the message.
        if include_accelerations and not include_velocities:
            raise ValueError("Accelerations are supposed to be an addition "
                             "to velocities. Add velocities.")

        if has_one_grid_only and len(gridname_to_grid) != 1:
            raise ValueError("has_one_grid_only is True "
                             "but len(gridname_to_grid) != 1")

        self.include_velocities = include_velocities
        self.include_accelerations = include_accelerations
        self.include_time = include_time
        self.has_target = has_target
        self.include_grid_name = include_grid_name
        self._keyboard_selection_set = keyboard_selection_set

        self.word_tokenizer = word_tokenizer

        self._grid_name_to_grid = gridname_to_grid

        # Per-grid lookup table: [x, y] -> label of the nearest allowed key.
        self._nearest_kb_label_dict = (
            self._create_nearest_kb_label_dict(gridname_to_grid))

        self.data_list = []
        self._set_data(data_path, gridname_to_grid,
                       kb_tokenizer, self.data_list, total = total)


    def is_allowed_label(self, label: str) -> bool:
        """Tell whether `label` may be chosen as a nearest key."""
        if self._keyboard_selection_set is None:
            return True
        return label in self._keyboard_selection_set


    def get_nearest_kb_label(self, x, y, grid_name, gridname_to_grid):
        """
        Given coords on a keyboard (x, y) and its grid_name returns the nearest keyboard key

        By default it uses an array associated with grid_name
        that stores the nearest key label for every possible coord pair.

        If coords are outside of the keyboard borders finds
        the nearest key by iterating over all keys.
        """
        grid = gridname_to_grid[grid_name]
        if x < 0 or x >= grid['width'] or y < 0 or y >= grid['height']:
            return self._get_kb_label_without_map(x, y, grid)
        else:
            return self._nearest_kb_label_dict[grid_name][x, y]


    def _get_key_center(self, hitbox: Dict[str, int]) -> Tuple[float, float]:
        """Return the (x, y) center of a key hitbox."""
        x = hitbox['x'] + hitbox['w'] / 2
        y = hitbox['y'] + hitbox['h'] / 2
        return x, y

    def _get_kb_label(self, key: dict) -> str:
        """Return the key's 'label' or, if it is absent, its 'action'."""
        if 'label' in key:
            return key['label']
        if 'action' in key:
            return key['action']
        raise ValueError("Key has no label or action property")


    def _get_kb_label_without_map(self, x, y, grid: dict) -> str:
        """
        Returns label of the nearest key on the keyboard without using a map.

        Iterates over all allowed keys and calculates the squared distance
        from the key center to (x, y) to find the nearest one.
        Returns None if the grid has no allowed keys.
        """
        nearest_kb_label = None
        min_dist = float("inf")

        for key in grid['keys']:
            label = self._get_kb_label(key)

            if not self.is_allowed_label(label):
                continue

            key_x, key_y = self._get_key_center(key['hitbox'])
            dist = (x - key_x)**2 + (y - key_y)**2
            if dist < min_dist:
                min_dist = dist
                nearest_kb_label = label
        return nearest_kb_label


    def _create_nearest_kb_label_dict(self, gridname_to_grid: dict
                                      ) -> Dict[str, np.ndarray]:
        """
        Creates a dict that maps grid_name to a map (np.ndarray)
        from coordinates [x, y] to nearest key label.
        """
        nearest_kb_label_dict = {}
        for grid_name, grid in gridname_to_grid.items():
            nearest_kb_label_dict[grid_name] = self._get_coord_to_kb_label(grid)
        return nearest_kb_label_dict


    def _get_coord_to_kb_label(self, grid: dict) -> np.ndarray:  # dtype = object
        """
        Build a (width, height) object array holding, for every pixel,
        the label of the nearest allowed key.
        """
        coord_to_kb_label = np.zeros(
            (grid['width'], grid['height']), dtype=object)  # 1080 x 640 in our case
        coord_to_kb_label.fill('')

        # Pixels inside the hitbox of an allowed key get that key's label.
        for key in grid['keys']:
            label = self._get_kb_label(key)

            if not self.is_allowed_label(label):
                continue

            x_left = key['hitbox']['x']
            x_right = x_left + key['hitbox']['w']
            y_top = key['hitbox']['y']
            y_bottom = y_top + key['hitbox']['h']

            coord_to_kb_label[x_left:x_right, y_top:y_bottom] = label

        # Remaining pixels (gaps and disallowed keys) get the nearest key
        # by distance to the key centers.
        for x in range(grid['width']):
            for y in range(grid['height']):
                if coord_to_kb_label[x, y] != '':
                    continue
                coord_to_kb_label[x, y] = self._get_kb_label_without_map(x, y, grid)

        return coord_to_kb_label


    def _set_data(self,
                  data_path: str,
                  gridname_to_grid: dict,
                  kb_tokenizer,
                  data_list: list,
                  total: Optional[int] = None):
        """Parse every line of the file at `data_path` and append the result to `data_list`."""
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total = total):
                data_list.append(self._get_data_from_json_line(line, gridname_to_grid, kb_tokenizer))


    def _get_dx_dt(self,
                   X: torch.Tensor,
                   T: torch.Tensor) -> torch.Tensor:
        """
        Calculates dx/dt with central differences.

        Arguments:
        ----------
        X : torch.Tensor
            x (position) coordinates.
        T : torch.Tensor
            T[i] = time from the beginning of the swipe corresponding to X[i].

        Returns:
        --------
        Tensor shaped like X; the first and the last elements are 0.

        Example:
        x0 x1 x2 x3
        t0 t1 t2 t3
        dx_dt[0] = 0
        dx_dt[1] = (x2 - x0) / (t2 - t0)
        dx_dt[2] = (x3 - x1) / (t3 - t1)
        dx_dt[3] = 0
        """
        dx_dt = torch.zeros_like(X)
        # NOTE: equal timestamps two steps apart give a zero denominator,
        # so the result may contain inf/NaN; it is not checked here.
        dx_dt[1:-1] = (X[2:] - X[:-2]) / (T[2:] - T[:-2])
        return dx_dt


    def _get_data_from_json_line(self,
                                 line,
                                 gridname_to_grid,
                                 kb_tokenizer) -> tuple:
        """
        Parses a JSON line into a compact tuple.

        Returns (X, Y, T, kb_tokens, word, grid_name) when has_target is
        True and (X, Y, T, kb_tokens, grid_name) otherwise. X, Y, T and
        kb_tokens are array.array objects with typecode 'h'.
        """
        data = json.loads(line)

        X = array.array('h', data['curve']['x'])
        Y = array.array('h', data['curve']['y'])
        T = array.array('h', data['curve']['t'])

        grid_name = data['curve']['grid_name']

        kb_labels = [self.get_nearest_kb_label(x, y, grid_name, gridname_to_grid) for x, y in zip(X, Y)]
        kb_tokens = [kb_tokenizer.get_token(label) for label in kb_labels]
        kb_tokens = array.array('h', kb_tokens)

        if not self.has_target:
            return X, Y, T, kb_tokens, grid_name
        else:
            word: str = data['word']
            return X, Y, T, kb_tokens, word, grid_name


    def __len__(self):
        return len(self.data_list)


    def __getitem__(self, idx):
        """
        Build the model inputs (and targets, if has_target) for sample `idx`.

        Feature columns, in order: x / width, y / height, [t], [dx/dt, dy/dt],
        [d2x/dt2, d2y/dt2]. Derivatives are computed from the raw
        (non-normalized) coordinates.
        """
        if self.has_target:
            X_list, Y_list, T_list, kb_tokens, word, grid_name = self.data_list[idx]
        else:
            X_list, Y_list, T_list, kb_tokens, grid_name = self.data_list[idx]

        X = torch.tensor(X_list, dtype=torch.float32)
        Y = torch.tensor(Y_list, dtype=torch.float32)
        T = torch.tensor(T_list, dtype=torch.float32)

        grid = self._grid_name_to_grid[grid_name]

        feature_columns = [X / grid['width'], Y / grid['height']]

        if self.include_time:
            feature_columns.append(T)

        if self.include_velocities:
            dx_dt = self._get_dx_dt(X, T)
            dy_dt = self._get_dx_dt(Y, T)
            feature_columns += [dx_dt, dy_dt]

        if self.include_accelerations:
            # __init__ guarantees velocities are enabled in this case.
            feature_columns += [self._get_dx_dt(dx_dt, T),
                                self._get_dx_dt(dy_dt, T)]

        xyt = torch.stack(feature_columns, dim=1)  # (curve_len, n_features)

        kb_tokens = torch.tensor(kb_tokens, dtype=torch.int64)

        decoder_out_char_seq = None
        decoder_in_char_seq = None
        word_mask = None

        if self.has_target:
            # <sos>, token1, token2, ... token_n, <eos>
            token_seq: List[int] = self.word_tokenizer.encode(word)
            token_seq = torch.tensor(token_seq, dtype = torch.int64)

            # model inputs and outputs are one token smaller than max_word,
            # Model inputs: <sos>, token1, ... token_n, <pad_0>, <pad_1>, ... <pad_k>
            # Model outputs:       token1, ... token_n, <EOS!>,  <pad_1>, ... <pad_k>
            decoder_seq_len = self.word_tokenizer.max_word_len - 1
            pad_idx = self.word_tokenizer.char_to_idx['<pad>']
            n_real = len(word) + 1  # number of non-pad positions

            # True marks padding: the first len(word) + 1 positions are real.
            word_mask = torch.ones(decoder_seq_len, dtype=torch.bool)
            word_mask[:n_real] = False

            # <sos>, token1, ... token_n
            decoder_in_char_seq = torch.full(
                (decoder_seq_len,), pad_idx, dtype=torch.int64)
            decoder_in_char_seq[:n_real] = token_seq[:-1]

            # token1, ... token_n, <eos>
            decoder_out_char_seq = torch.full(
                (decoder_seq_len,), pad_idx, dtype=torch.int64)
            decoder_out_char_seq[:n_real] = token_seq[1:]

        if self.include_grid_name:
            return (xyt, kb_tokens, decoder_in_char_seq, word_mask), decoder_out_char_seq, grid_name

        return (xyt, kb_tokens, decoder_in_char_seq, word_mask), decoder_out_char_seq


class NeuroSwipeGridSubset(Dataset):
    """
    View over `dataset` restricted to the samples of a single keyboard grid.

    The wrapped dataset must yield 3-tuples `(x, y, grid_name)` when iterated.
    The whole dataset is scanned once at construction time to collect the
    indices of the samples whose grid name equals `grid_name`.
    """

    def __init__(self, dataset: Dataset, grid_name: str):
        self.dataset = dataset
        self.grid_name = grid_name
        self.grid_name_idxs = self._get_grid_name_idxs()

    def _get_grid_name_idxs(self):
        """Return the indices of `self.dataset` items recorded on `self.grid_name`."""
        return [
            sample_idx
            for sample_idx, (_, _, sample_grid_name) in enumerate(self.dataset)
            if sample_grid_name == self.grid_name
        ]

    def __len__(self):
        """Number of samples that belong to the selected grid."""
        return len(self.grid_name_idxs)

    def __getitem__(self, idx):
        """Return the `idx`-th matching sample exactly as the wrapped dataset yields it."""
        source_idx = self.grid_name_idxs[idx]
        return self.dataset[source_idx]
    


def _stack_or_none(tensors, seq_first: bool):
    """
    Stack per-sample tensors into a single batch tensor.

    Returns None when every element is None, which is what the dataset yields
    for the target-related fields when it has no targets. A batch that mixes
    None and tensors still fails inside `torch.stack`.

    With `seq_first=True` the result is transposed to (seq_len, batch_size).
    """
    if all(item is None for item in tensors):
        return None
    stacked = torch.stack(tensors)  # (batch_size, seq_len)
    return stacked.transpose_(0, 1) if seq_first else stacked


def collate_fn(batch: list):
    """
    Collate dataset samples into one padded, sequence-first batch.

    Arguments:
    ----------
    batch: list of tuples
        ((traj_feats, kb_tokens, dec_in_char_seq, word_pad_mask), dec_out_char_seq)
        - traj_feats: (curve_len, n_coord_feats)
        - kb_tokens: (curve_len,)
        - dec_in_char_seq, dec_out_char_seq, word_pad_mask: (chars_seq_len - 1,),
          or None for every sample when the dataset has no targets.

    Returns:
    --------
    (traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask), dec_out_char_seq
        - traj_feats: (curves_len, batch_size, n_coord_feats), zero-padded
        - kb_tokens: (curves_len, batch_size), zero-padded
        - dec_in_char_seq: (chars_seq_len - 1, batch_size) or None
        - traj_pad_mask: (batch_size, curves_len), True at padded positions
        - word_pad_mask: (batch_size, chars_seq_len - 1) or None
        - dec_out_char_seq: (chars_seq_len - 1, batch_size) or None
    """
    x, dec_out_char_seq = zip(*batch)
    (traj_feats_no_pad, kb_tokens_no_pad,
     dec_in_char_seq, word_pad_mask) = zip(*x)

    # traj_feats_no_pad[i].shape = (curve_len, n_coord_feats)
    traj_feats = pad_sequence(traj_feats_no_pad, batch_first=False)  # (curves_len, batch_size, n_coord_feats)
    # kb_tokens_no_pad[i].shape = (curve_len,)
    kb_tokens = pad_sequence(kb_tokens_no_pad, batch_first=False)  # (curves_len, batch_size)

    # Target-related fields are None for target-less samples; pass None through
    # instead of crashing in torch.stack.
    dec_in_char_seq = _stack_or_none(dec_in_char_seq, seq_first=True)  # (chars_seq_len - 1, batch_size)
    dec_out_char_seq = _stack_or_none(dec_out_char_seq, seq_first=True)  # (chars_seq_len - 1, batch_size)
    word_pad_mask = _stack_or_none(word_pad_mask, seq_first=False)  # (batch_size, chars_seq_len - 1)

    max_curve_len = traj_feats.shape[0]
    traj_lens = torch.tensor([len(traj) for traj in traj_feats_no_pad])

    # Compare the position row [0, 1, ..., max_curve_len - 1] with the length of
    # each trajectory (broadcast over the batch). The result is True exactly at
    # the positions whose index is >= the trajectory length, i.e. at padding.
    # (batch_size, max_curve_len)
    traj_pad_mask = torch.arange(max_curve_len).unsqueeze(0) >= traj_lens.unsqueeze(1)

    return (traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask), dec_out_char_seq

In [14]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """
    Read the JSON file at `grids_path` (UTF-8) and return the value stored
    under the key `grid_name`.

    Raises KeyError if the file has no entry for `grid_name`.
    """
    with open(grids_path, "r", encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

# Only the grid selected by GRID_NAME is loaded from the grids JSON.
gridname_to_grid = {
    GRID_NAME: get_grid(
        GRID_NAME,
        os.path.join(DATA_ROOT, "gridname_to_grid.json"),
    ),
}

# Keyword arguments forwarded to the NeuroSwipeDatasetv3 constructor below.
DS_KWARGS = {
    "include_time": False,
    "include_velocities": True,
    "include_accelerations": True,
    "has_target": True,
    "has_one_grid_only": True,
    "include_grid_name": False,
    "keyboard_selection_set": set(ALL_CYRILLIC_LETTERS_ALPHABET_ORD),
}

# Validation split loaded through the old dataset class.
# NOTE(review): `total` is presumably the number of samples in the validation
# file (it matches the progress-bar output) — confirm before reusing elsewhere.
val_ds_old = NeuroSwipeDatasetv3(
    data_path=GRID_NAME_TO_DS_PATHS[GRID_NAME]["val"],
    gridname_to_grid=gridname_to_grid,
    kb_tokenizer=kb_tokenizer,
    word_tokenizer=word_char_tokenizer,
    total=9_416,
    **DS_KWARGS,
)

100%|██████████| 9416/9416 [00:01<00:00, 6017.32it/s]


In [26]:
# Equivalence check: every sample of `val_dataset` must match the sample at the
# same position in `val_ds_old`.
# NOTE(review): `zip` stops at the shorter iterable, so a length mismatch between
# the two datasets would pass silently — consider asserting equal lengths first.
# NOTE(review): the loop variables stay in the kernel namespace afterwards; a
# later cell reads `dec_in`.
for el1, el2 in tqdm(zip(val_dataset, val_ds_old)):
    (traj_feats, kb_tokens, dec_in), dec_out = el1
    # The old dataset yields a 4-element input tuple; its last element is
    # ignored here.
    (tf_o, kb_o, di_o, _), do_o = el2

    assert torch.equal(traj_feats, tf_o)
    assert torch.equal(kb_tokens , kb_o)
    # Keep only the first k entries of the old sequences, where k is the number
    # of entries not equal to 35, then compare with the new sequences.
    # NOTE(review): 35 is presumably the '<pad>' token index of the word
    # tokenizer and padding is presumably trailing-only — TODO confirm and
    # replace the literal with a tokenizer lookup.
    assert torch.equal(dec_in, di_o[:sum(di_o!=35)])
    assert torch.equal(dec_out, do_o[:sum(do_o!=35)])

9416it [00:42, 221.74it/s]


In [16]:
# Display `dec_in` as left in the kernel namespace by the dataset comparison loop.
# NOTE(review): this relies on state from another cell and the execution counts
# are out of order, so the shown value may not survive Restart & Run All.
dec_in

tensor([36, 14,  1])

tensor([36, 14,  1])