In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# !git clone https://github.com/proshian/yandex-cup-2023-ml-neuroswipe.git
# %cd yandex-cup-2023-ml-neuroswipe
# ! git checkout remake_after_finals 

In [7]:
# !pip install dvc --quiet
# !pip install dvc_gdrive --quiet

In [None]:
# %cd /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [10]:
import os
import json
import typing as tp
import traceback
from datetime import datetime
import copy

import torch
# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import numpy as np

from model import SwipeCurveTransformer, get_m1_bigger_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from tokenizers import ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from dataset import NeuroSwipeDatasetv3, collate_fn
from word_generators import GreedyGenerator
from utils import prepare_batch


In [11]:
############# Script arguments emulation #############

# Directory names used for the input data and for trained models.
DATA_ROOT = "../data/data_separated_grid"
MODELS_DIR = "../data/trained_models/m1"

# Run settings: which keyboard grid to train on, batch size, environment
# flag and the seed passed to the RNG initialisation.
GRID_NAME = "extra"
BATCH_SIZE = 320
IN_KAGGLE = False
RANDOM_SEED = 12

# Keyword arguments forwarded to the dataset constructor.
#
# keyboard_selection_set holds the labels of the keys that are taken into
# account when the nearest keyboard label is looked up. Only cyrillic
# letters are listed here, so action keys such as "shift", "backspace"
# or "enter" are excluded.
DS_KWARGS = {
    "include_time": False,
    "include_velocities": True,
    "include_accelerations": True,
    "has_target": True,
    "has_one_grid_only": True,
    "include_grid_name": False,
    "keyboard_selection_set": set(ALL_CYRILLIC_LETTERS_ALPHABET_ORD),
}

In [12]:
################ Other constants ####################
# For every supported grid name: the train and validation .jsonl files,
# both resolved relative to DATA_ROOT.
GRID_NAME_TO_DS_PATHS = dict(
    extra=dict(
        train=os.path.join(DATA_ROOT, "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl"),
        val=os.path.join(DATA_ROOT, "valid__in_train_format__extra_only.jsonl"),
    ),
    default=dict(
        train=os.path.join(DATA_ROOT, "train__default_only_no_errors__2023_10_31__03_26_16.jsonl"),
        val=os.path.join(DATA_ROOT, "valid__in_train_format__default_only.jsonl"),
    ),
)

In [13]:
# if IN_KAGGLE:
#     DATA_ROOT = "/kaggle/input/neuroswipe-defualt-only-v1"
#     MODELS_DIR = ""

In [14]:
def init_random_seed(value=42):
    """
    Seed NumPy's global random generator and PyTorch's random generators.

    Python's built-in `random` module is not seeded here, and cuDNN
    deterministic mode is not switched on.

    :param value: integer seed handed to every seeding function.
    """
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(value)

In [15]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """
    Read a JSON file that maps grid names to grids and return one of them.

    :param grid_name: key of the wanted grid inside the JSON object.
    :param grids_path: path to the UTF-8 encoded JSON file.
    :return: the value stored under `grid_name`.
    :raises KeyError: if the file has no entry called `grid_name`.
    """
    with open(grids_path, "r", encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [16]:
def get_datasets(grid_name: str, grid_name_to_grid_path: str,
                 train_data_path: str, val_data_path: str,
                 ds_kwargs: dict, kb_tokenizer: KeyboardTokenizerv1,
                 word_char_tokenizer: CharLevelTokenizerv2,
                 train_total: int = 349_172,
                 val_total: int = 10_000,
                 ) -> tuple[NeuroSwipeDatasetv3, NeuroSwipeDatasetv3]:
    """
    Build the train and validation datasets for a single keyboard grid.

    The grid is read once from the JSON file at `grid_name_to_grid_path`
    and shared by both datasets.

    :param grid_name: name of the grid to load; also the only key of the
        grid mapping handed to the datasets.
    :param grid_name_to_grid_path: path to the JSON file mapping grid
        names to grids.
    :param train_data_path: path to the training data file.
    :param val_data_path: path to the validation data file.
    :param ds_kwargs: extra keyword arguments forwarded unchanged to both
        NeuroSwipeDatasetv3 constructors.
    :param kb_tokenizer: keyboard tokenizer passed to both datasets.
    :param word_char_tokenizer: word tokenizer passed to both datasets.
    :param train_total: value forwarded as `total` to the train dataset
        (appears to size the loading progress bar - TODO confirm).
        The default keeps the previously hard-coded number.
    :param val_total: same as `train_total` for the validation dataset.
        Pass the real size when it differs from the default.
    :return: tuple `(train_ds, val_ds)`.
    """
    gridname_to_grid = {grid_name: get_grid(grid_name, grid_name_to_grid_path)}

    train_ds = NeuroSwipeDatasetv3(
        data_path=train_data_path,
        gridname_to_grid=gridname_to_grid,
        kb_tokenizer=kb_tokenizer,
        word_tokenizer=word_char_tokenizer,
        total=train_total,
        **ds_kwargs
    )

    val_ds = NeuroSwipeDatasetv3(
        data_path=val_data_path,
        gridname_to_grid=gridname_to_grid,
        kb_tokenizer=kb_tokenizer,
        word_tokenizer=word_char_tokenizer,
        total=val_total,
        **ds_kwargs
    )

    return train_ds, val_ds

In [17]:
# Seed the NumPy / PyTorch random generators before any dataset,
# dataloader or model is created in the cells below.
init_random_seed(RANDOM_SEED)

In [18]:
# Pickling the dataset would be great to not waste
# around 20 minutes creating train_dataset.

# Tokenizers shared by both datasets and by the word generators used later.
# The word-level tokenizer is built from the vocabulary file under DATA_ROOT.
kb_tokenizer = KeyboardTokenizerv1()
voc_path=os.path.join(DATA_ROOT, "voc.txt")
word_char_tokenizer = CharLevelTokenizerv2(voc_path)

# Build the train/validation datasets for the single grid chosen by GRID_NAME.
# NOTE(review): get_datasets passes fixed `total` values to the datasets; the
# validation progress bar recorded in this cell's output stops at 584/10000
# for the "extra" grid, so the validation total looks stale for it - confirm.
train_dataset, val_dataset = get_datasets(
    grid_name=GRID_NAME,
    grid_name_to_grid_path=os.path.join(DATA_ROOT, "gridname_to_grid.json"),
    train_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['train'],
    val_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['val'],
    ds_kwargs=DS_KWARGS,
    kb_tokenizer=kb_tokenizer,
    word_char_tokenizer=word_char_tokenizer,
)

100%|██████████| 349172/349172 [00:43<00:00, 7947.14it/s]
  6%|▌         | 584/10000 [00:00<00:01, 8698.26it/s]


In [19]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [20]:
# Create the model to be trained through the project's factory function.
# No checkpoint path is passed here (unlike the call made inside
# test_greedy_generator), so presumably the weights are freshly
# initialised - TODO confirm against get_m1_bigger_model.
transformer = get_m1_bigger_model(device)



In [21]:
def cross_entropy_with_reshape(pred, target, ignore_index=-100, label_smoothing=0.0):
    """
    Cross-entropy over a batch of sequences.

    All leading dimensions are flattened before calling `F.cross_entropy`,
    so the function does not care whether the inputs are batch-first or
    sequence-first, as long as `pred` and `target` use the same layout.

    pred - D1 x D2 x VocabSize logits (e.g. TargetLen x BatchSize x VocabSize)
    target - D1 x D2 class indices (e.g. TargetLen x BatchSize)
    ignore_index - target value whose positions are excluded from the loss
    label_smoothing - smoothing factor forwarded to `F.cross_entropy`

    Returns the scalar mean loss over the non-ignored positions.
    """
    # reshape (not view) so that non-contiguous logits, e.g. the result of a
    # transpose/permute, are accepted just like the target already is.
    pred_flat = pred.reshape(-1, pred.shape[-1])  # D1*D2 x VocabSize
    target_flat = target.reshape(-1)  # D1*D2
    return F.cross_entropy(pred_flat,
                           target_flat,
                           ignore_index=ignore_index,
                           label_smoothing=label_smoothing)

In [22]:
def lr_scheduler(optimizer):
    """
    Create the learning-rate scheduler used for training.

    Returns a `ReduceLROnPlateau` bound to `optimizer` that multiplies the
    learning rate by 0.5 once the monitored value has stopped improving
    for longer than a patience of 20 scheduler steps. Verbose reporting
    is requested; every other option keeps its library default.
    """
    plateau_options = {"patience": 20, "factor": 0.5, "verbose": True}
    scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
    return scheduler_cls(optimizer, **plateau_options)

In [23]:
def move_all_to_device(x, device):
    """
    Move a tensor, or a flat list/tuple of tensors, to `device`.

    :param x: a tensor, or a list/tuple whose elements are all tensors.
    :param device: target device passed to `Tensor.to`.
    :return: the moved tensor, or a new list with the moved tensors
        (a list is returned even when a tuple was given).
    :raises ValueError: if `x` is neither a tensor nor a list/tuple, or if
        any element of the sequence is not a tensor.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if not isinstance(x, (list, tuple)):
        raise ValueError(f'Unexpected data type {type(x)}')

    moved = []
    for item in x:
        if torch.is_tensor(item):
            moved.append(item.to(device))
        else:
            raise ValueError(f'Unexpected data type {type(item)}')
    return moved

In [24]:
# Sanity-check collate_fn (it is invoked implicitly by DataLoader).
# The first `batch_size` samples are fetched twice: directly from the dataset
# (unprocessed) and through a non-shuffled DataLoader (collated), and the
# two views are compared.
# NOTE: this cell binds the notebook globals `batch_size`, `PAD_CHAR_TOKEN`
# and `train_dataloader`, among others.

batch_size = 6


PAD_CHAR_TOKEN = word_char_tokenizer.char_to_idx["<pad>"]


train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                                        num_workers=0, collate_fn=collate_fn)


dataset_els = [train_dataset[i] for i in range(batch_size)]
unproc_batch_x, unproc_batch_y = zip(*dataset_els)

batch_x, batch_y = next(iter(train_dataloader))


############### Check that batch_y is correct ###################
max_out_seq_len = max([len(y) for y in unproc_batch_y])

# Targets are collated sequence-first: (sequence length, batch size).
assert batch_y.shape == (max_out_seq_len, batch_size)


for i in range(batch_size):
    # The original tokens come first, everything after them is padding.
    assert (batch_y[:len(unproc_batch_y[i]), i] == unproc_batch_y[i]).all()
    assert (batch_y[len(unproc_batch_y[i]):, i] == PAD_CHAR_TOKEN).all()

print("batch_y is correct")



############### Check that batch_x is correct ###################
unproc_batch_traj_feats, unproc_batch_kb_tokens, unproc_batch_dec_in_char_seq, \
    unproc_batch_word_pad_mask = zip(*unproc_batch_x)

(traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask) = batch_x


# Each object unpacked above from unproc_batch_x is a tuple of length batch_size.
# For example, unproc_batch_traj_feats[i] = train_dataset[i][0][0]

# Number of per-point trajectory features expected with the current DS_KWARGS
# (presumably x, y plus two velocity and two acceleration components, since
# time is excluded - TODO confirm against the dataset implementation).
N_TRAJ_FEATS = 6
max_curve_len = max([el.shape[0] for el in unproc_batch_traj_feats]) 

# Trajectory features and keyboard tokens must describe the same points.
assert max_curve_len == max([el.shape[0] for el in unproc_batch_kb_tokens])

# Sequences are sequence-first, padding masks are batch-first.
assert traj_feats.shape == (max_curve_len, batch_size, N_TRAJ_FEATS)
assert kb_tokens.shape == (max_curve_len, batch_size)
assert dec_in_char_seq.shape == (max_out_seq_len, batch_size)
assert traj_pad_mask.shape == (batch_size, max_curve_len)
assert word_pad_mask.shape == (batch_size, max_out_seq_len)


for i in range(batch_size):
    assert (traj_feats[:len(unproc_batch_traj_feats[i]), i] == unproc_batch_traj_feats[i]).all()
    assert (kb_tokens[:len(unproc_batch_kb_tokens[i]), i] == unproc_batch_kb_tokens[i]).all()

    assert (dec_in_char_seq[:len(unproc_batch_dec_in_char_seq[i]), i] == unproc_batch_dec_in_char_seq[i]).all()
    assert (dec_in_char_seq[len(unproc_batch_dec_in_char_seq[i]):, i] == PAD_CHAR_TOKEN).all()

    # In the trajectory mask, False marks real points and True marks padding.
    assert (traj_pad_mask[i, :len(unproc_batch_traj_feats[i])] == False).all()
    assert (traj_pad_mask[i, len(unproc_batch_traj_feats[i]):] == True).all()
    
    # Only the unpadded prefix of the word mask is compared here; the values
    # in the padded tail are not checked.
    assert (word_pad_mask[i, :len(unproc_batch_dec_in_char_seq[i])] == unproc_batch_word_pad_mask[i]).all()

print("batch_x is correct")

batch_y is correct
batch_x is correct


In [26]:
from typing import List

def predict_greedy_raw(dataset,
                       greedy_word_generator: GreedyGenerator,
                      ) -> List[str]:
    """
    Greedily decode one word for every sample of `dataset`.

    Supposed to be used with a dataset of a single grid, because the same
    generator is applied to every sample.

    Arguments:
    ----------
    dataset:
        Sized iterable whose items unpack as
        ``((traj_feats, kb_tokens, _, _), _)``, e.g. the NeuroSwipeDatasetv3
        instances created in this notebook.
    greedy_word_generator: GreedyGenerator
        Object whose ``generate_word_only(traj_feats, kb_tokens)`` returns
        the predicted word as a string.

    Returns:
    --------
    List[str]
        One predicted word per sample, in dataset order, with a leading
        "<sos>" removed when present.
    """
    preds = [None] * len(dataset)

    for i, ((traj_feats, kb_tokens, _, _), _) in tqdm(enumerate(dataset), total=len(dataset)):
        pred = greedy_word_generator.generate_word_only(traj_feats, kb_tokens)
        preds[i] = pred.removeprefix("<sos>")

    return preds

In [31]:
###################### Test predict_greedy_raw ######################

# Checkpoint loaded by test_greedy_generator below.
# NOTE(review): this is an absolute Kaggle input path, so the check only
# works where that dataset is mounted.
MODEL_TO_TEST_GREEDY_GEN__PATH = "/kaggle/input/m1-bigger-v2-0-13413-extra-l2-0-ls0-switch-1-pt/" \
    "m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt"


def get_targets(dataset: NeuroSwipeDatasetv3) -> tp.List[str]:
    """
    Decode the target word of every sample in `dataset`.

    For each sample the number of target tokens kept is the count of
    unmasked (False) positions in its word padding mask minus one
    (presumably this drops the end-of-sequence token - TODO confirm).
    The kept tokens are decoded with the notebook-level
    `word_char_tokenizer`.

    :param dataset: iterable of ``((_, _, _, word_pad_mask), target_tokens)``.
    :return: decoded target words in dataset order.
    """
    def _decode_target(sample) -> str:
        (_, _, _, word_pad_mask), target_tokens = sample
        n_kept = int(torch.sum(~word_pad_mask)) - 1
        return word_char_tokenizer.decode(target_tokens[:n_kept])

    return [_decode_target(sample) for sample in dataset]


def test_greedy_generator(val_dataset, model_getter, word_char_tokenizer, device) -> float:
    """
    Measure exact-match accuracy of greedy decoding on `val_dataset`.

    The model is obtained by calling `model_getter(device, path)` with the
    checkpoint path stored in MODEL_TO_TEST_GREEDY_GEN__PATH.

    :param val_dataset: dataset to decode and to take the target words from.
    :param model_getter: callable `(device, weights_path) -> model`.
    :param word_char_tokenizer: tokenizer handed to the GreedyGenerator.
    :param device: device used for the model and the generator.
    :return: fraction of samples whose prediction equals the target word.
    """
    expected_words = get_targets(val_dataset)

    generator = GreedyGenerator(
        model_getter(device, MODEL_TO_TEST_GREEDY_GEN__PATH),
        word_char_tokenizer,
        device,
    )
    predicted_words = predict_greedy_raw(val_dataset, generator)

    n_correct = 0
    for predicted, expected in zip(predicted_words, expected_words):
        n_correct += predicted == expected

    return n_correct / len(expected_words)



# Run the check on the validation set; the displayed value is the share of
# words predicted exactly right by the reference checkpoint.
test_greedy_generator(val_dataset, get_m1_bigger_model, word_char_tokenizer, device)

  0%|          | 0/584 [00:00<?, ?it/s]

0.8561643835616438

In [35]:
def train_eval_loop(model, train_dataset, val_dataset, criterion,
                    lr=1e-4, epoch_n=10, batch_size=32,
                    collate_fn = None,
                    device=None, early_stopping_patience=10, l2_reg_alpha=0,
                    max_batches_per_epoch_train=10000,
                    max_batches_per_epoch_val=1000,
                    optimizer_ctor=None,
                    lr_scheduler_ctor=None,
                    shuffle_train=True,
                    label_smoothing = 0.0,
                    dataloader_workers_n=0,
                    criterion_ignore_index = -100,
                    model_name_postfix = "",
                    model_save_root = "."
                    ):
    """
    Training loop. After every epoch the model is evaluated on the held-out set.

    Whenever the mean validation loss improves, the weights are saved twice
    under `model_save_root` (a fixed "m1_bigger_v2.pt" file and a timestamped
    file whose name also carries the loss, the greedy accuracy and
    `model_name_postfix`), and the greedy decoding accuracy on `val_dataset`
    is computed and printed.

    Relies on notebook-level names: `move_all_to_device`, `get_targets`,
    `predict_greedy_raw`, `GreedyGenerator` and `word_char_tokenizer`.

    :param model: torch.nn.Module - the model to train
    :param train_dataset: torch.utils.data.Dataset - training data
    :param val_dataset: torch.utils.data.Dataset - data for quality evaluation
    :param criterion: loss function, called as
        `criterion(pred, target, ignore_index=..., label_smoothing=...)`
    :param lr: learning rate
    :param epoch_n: maximum number of epochs
    :param batch_size: number of samples processed by the model per iteration
    :param collate_fn: collate function passed to both DataLoaders
    :param device: cuda/cpu - device to run on; chosen automatically when None
    :param early_stopping_patience: largest number of epochs without
        improvement that is tolerated before training is stopped
    :param l2_reg_alpha: L2 regularisation coefficient (Adam weight decay);
        only used when `optimizer_ctor` is None
    :param max_batches_per_epoch_train: maximum number of iterations per training epoch
    :param max_batches_per_epoch_val: maximum number of iterations per validation epoch
    :param optimizer_ctor: optional callable `(params, lr=...) -> optimizer`;
        Adam is used when None
    :param lr_scheduler_ctor: optional callable `(optimizer) -> scheduler`;
        the scheduler is stepped with the mean validation loss
    :param shuffle_train: whether the training DataLoader shuffles samples
    :param label_smoothing: label smoothing passed to `criterion`
    :param dataloader_workers_n: number of DataLoader worker processes
    :param criterion_ignore_index: target value ignored by `criterion`
    :param model_name_postfix: suffix of the timestamped checkpoint name
    :param model_save_root: directory where checkpoints are read and written
    :return: a tuple of two elements:
        - mean validation loss on the best epoch
        - the best model
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)

    if optimizer_ctor is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg_alpha)
    else:
        optimizer = optimizer_ctor(model.parameters(), lr=lr)

    # Named `scheduler` to avoid shadowing the notebook-level lr_scheduler().
    scheduler = lr_scheduler_ctor(optimizer) if lr_scheduler_ctor is not None else None

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train,
                                  num_workers=dataloader_workers_n, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=dataloader_workers_n, collate_fn=collate_fn)

    best_val_loss = float('inf')
    best_epoch_i = 0

    # The checkpoint is looked up in the same place it is written to.
    best_model_path = os.path.join(model_save_root, "m1_bigger_v2.pt")
    best_model = copy.deepcopy(model)

    if os.path.exists(best_model_path):
        # map_location lets a checkpoint saved on another device be loaded.
        best_model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"Загружено состояние модели {best_model_path}")

    # Reference words for the greedy-accuracy check; decoded once up front
    # because they do not change between epochs.
    val_targets = get_targets(val_dataset)

    # len(dataloader) also counts the last, possibly incomplete, batch.
    train_batches_total = min(max_batches_per_epoch_train, len(train_dataloader))

    for epoch_i in tqdm(range(epoch_n), position = 0):
        try:
            model.train()
            mean_train_loss = 0
            train_batches_n = 0
            for batch_i, (batch_x, batch_y) in tqdm(enumerate(train_dataloader),
                                                    total=train_batches_total,
                                                    position=1, leave=False):
                # `>=` so that at most max_batches_per_epoch_train batches run.
                if batch_i >= max_batches_per_epoch_train:
                    break

                batch_x, batch_y = [move_all_to_device(el, device) for el in (batch_x, batch_y)]

                pred = model(*batch_x)
                loss = criterion(pred, batch_y,
                                 ignore_index=criterion_ignore_index,
                                 label_smoothing=label_smoothing)

                model.zero_grad()
                loss.backward()

                optimizer.step()

                mean_train_loss += float(loss)
                train_batches_n += 1

            mean_train_loss /= train_batches_n

            print('Среднее значение функции потерь на обучении', mean_train_loss)

            model.eval()
            mean_val_loss = 0
            val_batches_n = 0

            with torch.no_grad():
                for batch_i, (batch_x, batch_y) in enumerate(val_dataloader):
                    if batch_i >= max_batches_per_epoch_val:
                        break

                    batch_x, batch_y = [move_all_to_device(el, device) for el in (batch_x, batch_y)]

                    pred = model(*batch_x)
                    loss = criterion(pred, batch_y,
                                     ignore_index=criterion_ignore_index,
                                     label_smoothing=label_smoothing)

                    mean_val_loss += float(loss)
                    val_batches_n += 1

            mean_val_loss /= val_batches_n
            print('Среднее значение функции потерь на валидации', mean_val_loss)

            if mean_val_loss < best_val_loss:
                best_epoch_i = epoch_i
                best_val_loss = mean_val_loss
                best_model = copy.deepcopy(model)
                torch.save(model.state_dict(), best_model_path)

                cur_time = "{:%Y_%m_%d__%H_%M_%S}".format(datetime.now())

                # The model is still in eval mode here.
                greedy_word_generator = GreedyGenerator(model, word_char_tokenizer, device)
                greedy_predictions = predict_greedy_raw(val_dataset, greedy_word_generator)
                greedy_accuracy = sum(
                    greedy_prediction == val_target
                    for greedy_prediction, val_target in zip(greedy_predictions, val_targets)
                ) / len(val_targets)

                torch.save(model.state_dict(), os.path.join(model_save_root, f"m1_bigger_v2__{cur_time}__{mean_val_loss:.5f}__greed_acc_{greedy_accuracy:.5f}__{model_name_postfix}.pt"))
                print(f"Greedy accuracy = {greedy_accuracy}")
                print('Новая лучшая модель!')
            elif epoch_i - best_epoch_i > early_stopping_patience:
                print('Модель не улучшилась за последние {} эпох, прекращаем обучение'.format(
                    early_stopping_patience))
                break

            if scheduler is not None:
                scheduler.step(mean_val_loss)

            print()
        except KeyboardInterrupt:
            print('Досрочно остановлено пользователем')
            break
        except Exception as ex:
            # Best effort: report the failure and return the best model so far.
            print('Ошибка при обучении: {}\n{}'.format(ex, traceback.format_exc()))
            break

    return best_val_loss, best_model


In [None]:
# Regularisation settings for this run. Both are switched off; the trailing
# comments keep alternative values (presumably tried earlier - TODO confirm).
l2_reg_alpha =  0 # alternative: 5e-5
label_smoothing=  0 # alternative: 0.045

# Launch training of `transformer` on the datasets built above.
# - epoch_n is very large, so in practice the run ends through early stopping
#   or a manual interrupt.
# - criterion_ignore_index is the "<pad>" token id, so padded target
#   positions are ignored by the cross-entropy loss.
# - Checkpoints are written to model_save_root ("../.."), not to MODELS_DIR.
best_val_loss, best_model = train_eval_loop(
    transformer, train_dataset, val_dataset, cross_entropy_with_reshape,
    lr=1e-4, epoch_n=10000, batch_size=BATCH_SIZE, collate_fn = collate_fn,
    device=device, early_stopping_patience=10, l2_reg_alpha=l2_reg_alpha,
    max_batches_per_epoch_train=2000,
    max_batches_per_epoch_val=1000,
    optimizer_ctor=None,
    lr_scheduler_ctor=lr_scheduler,
    shuffle_train=True,
    dataloader_workers_n=0,
    criterion_ignore_index = word_char_tokenizer.char_to_idx['<pad>'],
    model_name_postfix = f'{GRID_NAME}_l2_{l2_reg_alpha}_ls{label_smoothing}_switch_2',
    model_save_root = "../..",
    label_smoothing=label_smoothing
)