In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
# !git clone https://github.com/proshian/yandex-cup-2023-ml-neuroswipe.git
# %cd yandex-cup-2023-ml-neuroswipe
# ! git checkout datasetv4

In [None]:
# !pip install dvc --quiet
# !pip install dvc_gdrive --quiet

In [None]:
# %cd /kaggle/working/yandex-cup-2023-ml-neuroswipe
# ! git pull
# ! git checkout datasetv4

In [None]:
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [2]:
############# Script arguments emulation #############

GRID_NAME = "default"
BATCH_SIZE = 320
IN_KAGGLE = False
RANDOM_SEED = 12

DATA_ROOT = "../data/data_separated_grid"
MODELS_DIR = "../data/trained_models/m1"

In [3]:
import os
import json
import typing as tp
import traceback
from datetime import datetime
import copy

import torch
# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter


from model import SwipeCurveTransformer, get_m1_bigger_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from tokenizers import ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from dataset import CurveDataset, CollateFn
from word_generators import GreedyGenerator
from nearest_key_lookup import NearestKeyLookup, ExtendedNearestKeyLookup
from transforms import TransformerInputOutputGetter

In [4]:
################ Other constants ####################
GRID_NAME_TO_DS_PATHS = {
    "extra": {
        "train": os.path.join(DATA_ROOT, "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl"),
        "val": os.path.join(DATA_ROOT, "valid__in_train_format__extra_only.jsonl")
    },
    "default": {
        "train": os.path.join(DATA_ROOT, "train__default_only_no_errors__2023_10_31__03_26_16.jsonl"),
        "val": os.path.join(DATA_ROOT, "valid__in_train_format__default_only.jsonl")
    }
}

In [5]:
# if IN_KAGGLE:
#     DATA_ROOT = "/kaggle/input/neuroswipe-defualt-only-v1"
#     MODELS_DIR = ""

In [5]:
def init_random_seed(value):
    # random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    # torch.backends.cudnn.deterministic = True

In [6]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    with open(grids_path, "r", encoding="utf-8") as f:
        return json.load(f)[grid_name]

In [7]:
from typing import List, Dict, Tuple, Optional, Set

def get_gridname_to_out_of_bounds_coords_dict(
        data_paths: List[str], gridname_to_wh: dict,
        total: Optional[int] = None
        ) -> Dict[str, Set[Tuple[int, int]]]:
    """
    Returns a dictionary with grid names as keys and lists of out of bounds coordinates as values.
    """
    gname_to_out_of_bounds = {gname: set() for gname in gridname_to_wh.keys()}

    for data_path in data_paths:
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total=total):
                json_data = json.loads(line)
                curve = json_data['curve']
                grid_name = curve['grid_name']
                w, h = gridname_to_wh[grid_name]
                X, Y = curve['x'], curve['y']
                out_of_bounds = set((x, y) for x, y in zip(X, Y) 
                                    if x < 0 or x >= w or y < 0 or y >= h)
                gname_to_out_of_bounds[grid_name].update(out_of_bounds)
    return gname_to_out_of_bounds

In [8]:
def get_datasets(grid_name: str, grid_name_to_grid_path: str,
                 train_data_path: str, val_data_path: str,
                 nearest_key_candidates: tp.Set[str],
                 kb_tokenizer: KeyboardTokenizerv1,
                 word_char_tokenizer: CharLevelTokenizerv2
                 ) -> tuple[CurveDataset, CurveDataset]:
    
    gridname_to_grid  = {grid_name: get_grid(grid_name, grid_name_to_grid_path)}

    gname_to_wh = {
        gname: (grid['width'], grid['height']) 
        for gname, grid in gridname_to_grid.items()
    }
    
    print("Accumulating out-of-bounds coordinates...")
    gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
        [train_data_path, val_data_path], gname_to_wh, total=6_000_000
    )
    
    print("Creating ExtendedNearestKeyLookups...")
    gridname_to_nkl = {
        gname: ExtendedNearestKeyLookup(grid, nearest_key_candidates, gname_to_out_of_bounds[gname])
        for gname, grid in gridname_to_grid.items()
    }
    
    
    transformer_in_out_getter = TransformerInputOutputGetter(
        grid_name_to_nk_lookup=gridname_to_nkl,
        grid_name_to_wh=gname_to_wh,
        kb_tokenizer=kb_tokenizer,
        word_tokenizer=word_char_tokenizer,
        include_time=False,
        include_velocities=True,
        include_accelerations=True
    )
    
    print("Creating datasets...")
    train_ds = CurveDataset(
        data_path=train_data_path,
        transform = transformer_in_out_getter,
        total = 5_237_584,  # 349_172 for extra
    )

    val_ds = CurveDataset(
        data_path=val_data_path,
        transform = transformer_in_out_getter,
        total = 9_416,
    )
    
    return train_ds, val_ds

In [9]:
init_random_seed(RANDOM_SEED)

In [11]:
# Pickling the dataset would be great to not waste
# around 20 minutes creating train_dataset.

kb_tokenizer = KeyboardTokenizerv1()
voc_path=os.path.join(DATA_ROOT, "voc.txt")
word_char_tokenizer = CharLevelTokenizerv2(voc_path)

train_dataset, val_dataset = get_datasets(
    grid_name=GRID_NAME,
    grid_name_to_grid_path=os.path.join(DATA_ROOT, "gridname_to_grid.json"),
    train_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['train'],
    val_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['val'],
    nearest_key_candidates = ALL_CYRILLIC_LETTERS_ALPHABET_ORD,
    kb_tokenizer=kb_tokenizer,
    word_char_tokenizer=word_char_tokenizer,
)

Accumulating out-of-bounds coordinates...


 87%|████████▋ | 5237584/6000000 [04:18<00:37, 20274.45it/s]
  0%|          | 9416/6000000 [00:00<04:45, 20953.72it/s]


Creating ExtendedNearestKeyLookups...
Creating datasets...


100%|██████████| 5237584/5237584 [05:14<00:00, 16649.41it/s]
100%|██████████| 9416/9416 [00:00<00:00, 17750.98it/s]


In [20]:
train_dataset.transform.get_encoder_feats.grid_name_to_nk_lookup['default'].extended_coord_to_kb_label[624, -24]

'г'

In [49]:
n_iters = 100000

In [50]:
for i in range(n_iters):
    train_dataset.transform.get_encoder_feats.grid_name_to_nk_lookup['default'](624, -24)

In [51]:
for i in range(n_iters):
    train_dataset.transform.get_encoder_feats.grid_name_to_nk_lookup['default'](1, 1)

In [52]:
for i in range(n_iters):
    train_dataset.transform.get_encoder_feats.grid_name_to_nk_lookup['default'](-1111, -1111)

In [57]:
for i in tqdm(range(len(train_dataset))):
    train_dataset[i]

  0%|          | 16512/5237584 [01:27<7:42:23, 188.19it/s] 


KeyboardInterrupt: 

In [59]:
from dataset import NeuroSwipeDatasetv3

In [61]:
def get_datasets_old(grid_name: str, grid_name_to_grid_path: str,
                 train_data_path: str, val_data_path: str,
                 ds_kwargs: dict, kb_tokenizer: KeyboardTokenizerv1,
                 word_char_tokenizer: CharLevelTokenizerv2
                 ) -> tuple[NeuroSwipeDatasetv3, NeuroSwipeDatasetv3]:
    
    gridname_to_grid  = {grid_name: get_grid(grid_name, grid_name_to_grid_path)}

    train_ds = NeuroSwipeDatasetv3(
        data_path=train_data_path,
        gridname_to_grid = gridname_to_grid,
        kb_tokenizer=kb_tokenizer,
        word_tokenizer =word_char_tokenizer,
        total = 5_237_584,  # 349_172 for extra
        **ds_kwargs
    )

    val_ds = NeuroSwipeDatasetv3(
        data_path=val_data_path,
        gridname_to_grid =gridname_to_grid,
        kb_tokenizer=kb_tokenizer,
        word_tokenizer =word_char_tokenizer,
        total = 9_416,
        **ds_kwargs
    )

    return train_ds, val_ds

In [63]:
DS_KWARGS = dict(
    include_time = False,
    include_velocities = True,
    include_accelerations = True,
    has_target=True,
    has_one_grid_only=True,
    include_grid_name=False,
    keyboard_selection_set=set(ALL_CYRILLIC_LETTERS_ALPHABET_ORD)
)


train_dataset, val_dataset = get_datasets_old(
    grid_name=GRID_NAME,
    grid_name_to_grid_path=os.path.join(DATA_ROOT, "gridname_to_grid.json"),
    train_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['train'],
    val_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['val'],
    ds_kwargs=DS_KWARGS,
    kb_tokenizer=kb_tokenizer,
    word_char_tokenizer=word_char_tokenizer,
)

100%|██████████| 5237584/5237584 [13:35<00:00, 6425.51it/s] 
100%|██████████| 9416/9416 [00:02<00:00, 4328.65it/s]


In [65]:
for i in tqdm(range(len(train_dataset))):
    train_dataset[i]

  0%|          | 1984/5237584 [00:02<2:05:15, 696.63it/s] 


KeyboardInterrupt: 

In [10]:
import json
from typing import Optional, List, Tuple, Dict, Set, Callable
import array

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1

class CurveDatasetv2(Dataset):
    """
    Dataset class for NeuroSwipe jsonl dataset
    
    curve_dataset_obj[i] is a tuple (X, Y, T, grid_name, tgt_word)
    If there is no 'word' property in .json file, `tgt_word` is None.

    Extracting features (for example nearest keyboard  key label) 
    is be done via transforms.  Transfroms are be applied in __getitem__
    but if they may be split into two args in the future: 
    init_transforms and get_item_transforms.
    """

    def __init__(self,
                 data_path: str,
                 transform: Optional[Callable] = None,
                 total: Optional[int] = None):
        """
        Arguments:
        ----------
        data_path: str
            Path to the NeuroSwipe dataset in JSON format.
            A custom version of the dataset is used: "grid" property
            is replaced with "grid_name". The grid itself is stored in
            a separate gridname_to_grid dictionary.
            Dataset is a list of JSON lines. Each line is a dictionary
            with the following properties:
            - word (str): word that was typed. 
                Is abscent in test and val datasets.
            - curve (dict): dictionary that contains the following properties:
                - x (List[int]): x coordinates of the swipe trajectory.
                - y (List[int]): y coordinates of the swipe trajectory.
                - t (List[int]): time (in ms) from the beginning of the swipe.
                - grid_name (str): name of the keyboard grid.
        transform: Optional[Callable]
            A function that takes raw data (X, Y, T, grid_name, tgt_word)
            and returns a tuple (model_input, target).
        total: Optional[int]
            Number of dataset elements. Is used only for progress bar.
        """
        self.data_list = self._get_data(data_path, total = total)
        self.transform = transform

    def _get_data(self,
                  data_path: str,
                  total: Optional[int] = None):
        data_list = []
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total = total):
                data_list.append(
                    self.transform(
                        self._get_data_from_json_line(line)
                    )
                )
        return data_list

    def _get_data_from_json_line(self,
                                 line
                                 ) -> Tuple[list, list, list, str]:
        """
        Parses a JSON line and returns a dictionary with data.
        """
        data = json.loads(line)

        X = array.array('h', data['curve']['x'])
        Y = array.array('h', data['curve']['y'])
        T = array.array('h', data['curve']['t'])

        grid_name = data['curve']['grid_name']   

        tgt_word = data['word'] if 'word' in data else None

        return X, Y, T, grid_name, tgt_word

    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        return self.data_list[idx] 

In [11]:
def get_datasetsv2(grid_name: str, grid_name_to_grid_path: str,
                 train_data_path: str, val_data_path: str,
                 nearest_key_candidates: tp.Set[str],
                 kb_tokenizer: KeyboardTokenizerv1,
                 word_char_tokenizer: CharLevelTokenizerv2
                 ) -> tuple[CurveDataset, CurveDataset]:
    
    gridname_to_grid  = {grid_name: get_grid(grid_name, grid_name_to_grid_path)}

    gname_to_wh = {
        gname: (grid['width'], grid['height']) 
        for gname, grid in gridname_to_grid.items()
    }
    
    print("Accumulating out-of-bounds coordinates...")
    gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
        [train_data_path, val_data_path], gname_to_wh, total=6_000_000
    )
    
    print("Creating ExtendedNearestKeyLookups...")
    gridname_to_nkl = {
        gname: ExtendedNearestKeyLookup(grid, nearest_key_candidates, gname_to_out_of_bounds[gname])
        for gname, grid in gridname_to_grid.items()
    }
    
    
    transformer_in_out_getter = TransformerInputOutputGetter(
        grid_name_to_nk_lookup=gridname_to_nkl,
        grid_name_to_wh=gname_to_wh,
        kb_tokenizer=kb_tokenizer,
        word_tokenizer=word_char_tokenizer,
        include_time=False,
        include_velocities=True,
        include_accelerations=True
    )
    
    print("Creating datasets...")
    train_ds = CurveDatasetv2(
        data_path=train_data_path,
        transform = transformer_in_out_getter,
        total = 5_237_584,  # 349_172 for extra
    )

    val_ds = CurveDatasetv2(
        data_path=val_data_path,
        transform = transformer_in_out_getter,
        total = 9_416,
    )
    
    return train_ds, val_ds

In [12]:
# Pickling the dataset would be great to not waste
# around 20 minutes creating train_dataset.

kb_tokenizer = KeyboardTokenizerv1()
voc_path=os.path.join(DATA_ROOT, "voc.txt")
word_char_tokenizer = CharLevelTokenizerv2(voc_path)

train_dataset, val_dataset = get_datasets(
    grid_name=GRID_NAME,
    grid_name_to_grid_path=os.path.join(DATA_ROOT, "gridname_to_grid.json"),
    train_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['train'],
    val_data_path = GRID_NAME_TO_DS_PATHS[GRID_NAME]['val'],
    nearest_key_candidates = ALL_CYRILLIC_LETTERS_ALPHABET_ORD,
    kb_tokenizer=kb_tokenizer,
    word_char_tokenizer=word_char_tokenizer,
)

Accumulating out-of-bounds coordinates...


 87%|████████▋ | 5237584/6000000 [04:40<00:40, 18701.45it/s]
  0%|          | 9416/6000000 [00:00<06:25, 15539.80it/s]


Creating ExtendedNearestKeyLookups...
Creating datasets...


100%|██████████| 5237584/5237584 [04:44<00:00, 18403.42it/s]
100%|██████████| 9416/9416 [00:00<00:00, 22207.71it/s]


In [13]:
for i in tqdm(range(len(train_dataset))):
    train_dataset[i]

  1%|          | 38451/5237584 [02:31<5:40:36, 254.40it/s] 


KeyboardInterrupt: 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
transformer = get_m1_bigger_model(device)

In [None]:
def cross_entropy_with_reshape(pred, target, ignore_index=-100, label_smoothing=0.0):
    """
    pred - BatchSize x TargetLen x VocabSize
    target - BatchSize x TargetLen
    """
    pred_flat = pred.view(-1, pred.shape[-1])  # BatchSize*TargetLen x VocabSize
    target_flat = target.reshape(-1)  # BatchSize*TargetLen
    return F.cross_entropy(pred_flat,
                           target_flat,
                           ignore_index=ignore_index,
                           label_smoothing=label_smoothing)

In [None]:
def lr_scheduler(optimizer):
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      patience=20,
                                                      factor=0.5,
                                                      verbose=True)

In [None]:
def move_all_to_device(x, device):
    if torch.is_tensor(x):
        return x.to(device)
    elif not isinstance(x, (list, tuple)):
        raise ValueError(f'Unexpected data type {type(x)}')
    new_x = []
    for el in x:
        if not torch.is_tensor(el):
            raise ValueError(f'Unexpected data type {type(el)}')
        new_x.append(el.to(device))
    return new_x

In [None]:
collate_fn = CollateFn(
    word_pad_idx = word_char_tokenizer.char_to_idx['<pad>'], batch_first = False)

In [None]:
# Протестируем корректность collate_fn (вызывается неявно в DataLoader)

batch_size = 6


PAD_CHAR_TOKEN = word_char_tokenizer.char_to_idx["<pad>"]


train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, collate_fn=collate_fn)


dataset_els = [train_dataset[i] for i in range(batch_size)]
unproc_batch_x, unproc_batch_y = zip(*dataset_els)

batch_x, batch_y = next(iter(train_dataloader))


############### Проверка корректности batch_y ###################
max_out_seq_len = max([len(y) for y in unproc_batch_y])

assert batch_y.shape == (max_out_seq_len, batch_size)


for i in range(batch_size):
    assert (batch_y[:len(unproc_batch_y[i]), i] == unproc_batch_y[i]).all()
    assert (batch_y[len(unproc_batch_y[i]):, i] == PAD_CHAR_TOKEN).all()

print("batch_y is correct")



############### Проверка корректности batch_x ###################
unproc_batch_traj_feats, unproc_batch_kb_tokens, unproc_batch_dec_in_char_seq = zip(*unproc_batch_x)

(traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask) = batch_x


# каждая сущность, полученная выше из unpoc_batch_x - это tuple длины batch_size.
# Например, unproc_batch_traj_feats[i] = train_dataset[i][0][0]

N_TRAJ_FEATS = 6
max_curve_len = max([el.shape[0] for el in unproc_batch_traj_feats]) 

assert max_curve_len == max([el.shape[0] for el in unproc_batch_kb_tokens])

assert traj_feats.shape == (max_curve_len, batch_size, N_TRAJ_FEATS)
assert kb_tokens.shape == (max_curve_len, batch_size)
assert dec_in_char_seq.shape == (max_out_seq_len, batch_size)
assert traj_pad_mask.shape == (batch_size, max_curve_len)
assert word_pad_mask.shape == (batch_size, max_out_seq_len)


for i in range(batch_size):
    assert (traj_feats[:len(unproc_batch_traj_feats[i]), i] == unproc_batch_traj_feats[i]).all()
    assert (kb_tokens[:len(unproc_batch_kb_tokens[i]), i] == unproc_batch_kb_tokens[i]).all()

    assert (dec_in_char_seq[:len(unproc_batch_dec_in_char_seq[i]), i] == unproc_batch_dec_in_char_seq[i]).all()
    assert (dec_in_char_seq[len(unproc_batch_dec_in_char_seq[i]):, i] == PAD_CHAR_TOKEN).all()

    assert (traj_pad_mask[i, :len(unproc_batch_traj_feats[i])] == False).all()
    assert (traj_pad_mask[i, len(unproc_batch_traj_feats[i]):] == True).all()
    
    assert (word_pad_mask[i, :len(unproc_batch_dec_in_char_seq[i])] == False).all()
    assert (word_pad_mask[i, len(unproc_batch_dec_in_char_seq[i]):] == True).all()

print("batch_x is correct")

In [None]:
from typing import List

# def predict_greedy_raw(dataset,
#                        greedy_word_generator: GreedyGenerator,
#                       ) -> List[List[str]]:
#     """
#     Creates predictions using greedy generation.

#     Supposed to be used with a dataset of a single grid
    
#     Arguments:
#     ----------
#     dataset: NeuroSwipeDatasetv2
#     grid_name_to_greedy_generator: dict
#         Dict mapping grid names to GreedyGenerator objects.
#     """
#     preds = [None] * len(dataset)

#     for data in tqdm(enumerate(dataset), total=len(dataset)):
# #         i, ((xyt, kb_tokens, _, traj_pad_mask, _), _, grid_name) = data
#         i, ((xyt, kb_tokens, _, _), _) = data

#         pred = greedy_word_generator.generate_word_only(xyt, kb_tokens)
#         pred = pred.removeprefix("<sos>")
# #         preds[i] = [pred]
#         preds[i] = pred

#     return preds




def predict_greedy_raw(dataset,
                       greedy_word_generator: GreedyGenerator,
                      ) -> List[List[str]]:
    """
    Creates predictions using greedy generation.

    Supposed to be used with a dataset of a single grid
    
    Arguments:
    ----------
    dataset: NeuroSwipeDatasetv2
    grid_name_to_greedy_generator: dict
        Dict mapping grid names to GreedyGenerator objects.
    """
    preds = [None] * len(dataset)

    for data in tqdm(enumerate(dataset), total=len(dataset)):
#         i, ((xyt, kb_tokens, _, traj_pad_mask, _), _, grid_name) = data
        i, ((xyt, kb_tokens, _), _) = data

        pred = greedy_word_generator.generate_word_only(xyt, kb_tokens)
        pred = pred.removeprefix("<sos>")
#         preds[i] = [pred]
        preds[i] = pred

    return preds




# def get_targets(dataset: NeuroSwipeDatasetv3) -> tp.List[str]:
#     targets = []
#     for (_, _, _, word_pad_mask), target_tokens in dataset:
#         target_len = int(torch.sum(~word_pad_mask)) - 1
#         target = word_char_tokenizer.decode(target_tokens[:target_len])
#         targets.append(target)
#     return targets

def get_targets(dataset: CurveDataset) -> tp.List[str]:
    targets = []
    for (_, _, _), target_tokens in dataset:
        target_len = int(torch.sum(~word_pad_mask)) - 1
        target = word_char_tokenizer.decode(target_tokens[:target_len])
        targets.append(target)
    return targets


def get_accuracy(preds, targets) -> float:
    return sum(pred == target for pred, target 
               in zip(preds, targets)) / len(targets)


def get_greedy_generator_accuracy(val_dataset, model, 
                                  word_char_tokenizer, device) -> float:
    val_targets = get_targets(val_dataset)
    greedy_generator = GreedyGenerator(model, word_char_tokenizer, device)
    greedy_preds = predict_greedy_raw(val_dataset, greedy_generator)
    return get_accuracy(greedy_preds, val_targets)

In [None]:
###################### протестируем predict_greedy_raw ######################


MODEL_TO_TEST_GREEDY_GEN__PATH = None

# MODEL_TO_TEST_GREEDY_GEN__PATH = "/kaggle/input/m1-bigger-v2-0-13413-extra-l2-0-ls0-switch-1-pt/" \
#     "m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt"
    

def test_greedy_generator(val_dataset, model_getter, model_weights, word_char_tokenizer, device) -> float:
    
    model = model_getter(device, model_weights)

    return get_greedy_generator_accuracy(val_dataset, model, word_char_tokenizer, device)



test_greedy_generator(val_dataset, get_m1_bigger_model, MODEL_TO_TEST_GREEDY_GEN__PATH, word_char_tokenizer, device)

In [None]:
def train_eval_loop(model, train_dataset, val_dataset, criterion,
                    tb, epoch_start, lr=1e-4, epoch_n=10, batch_size=32,
                    collate_fn = None,
                    device=None, early_stopping_patience=20, l2_reg_alpha=0,
                    max_batches_per_epoch_train=10000,
                    max_batches_per_epoch_val=1000,
                    optimizer_ctor=None,
                    lr_scheduler_ctor=None,
                    shuffle_train=True,
                    label_smoothing = 0.0,
                    dataloader_workers_n=0,
                    criterion_ignore_index = -100,
                    model_name_postfix = "",
                    model_save_root = ".",
                    ):
    """
    Цикл для обучения модели. После каждой эпохи качество модели оценивается по отложенной выборке.
    :param model: torch.nn.Module - обучаемая модель
    :param train_dataset: torch.utils.data.Dataset - данные для обучения
    :param val_dataset: torch.utils.data.Dataset - данные для оценки качества
    :param criterion: функция потерь для настройки модели
    :param lr: скорость обучения
    :param epoch_n: максимальное количество эпох
    :param batch_size: количество примеров, обрабатываемых моделью за одну итерацию
    :param device: cuda/cpu - устройство, на котором выполнять вычисления
    :param early_stopping_patience: наибольшее количество эпох, в течение которых допускается
        отсутствие улучшения модели, чтобы обучение продолжалось.
    :param l2_reg_alpha: коэффициент L2-регуляризации
    :param max_batches_per_epoch_train: максимальное количество итераций на одну эпоху обучения
    :param max_batches_per_epoch_val: максимальное количество итераций на одну эпоху валидации
    :return: кортеж из двух элементов:
        - среднее значение функции потерь на валидации на лучшей эпохе
        - лучшая модель
    """
    if device is None:
        device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)

    if optimizer_ctor is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg_alpha)
    else:
        optimizer = optimizer_ctor(model.parameters(), lr=lr)

    if lr_scheduler_ctor is not None:
        lr_scheduler = lr_scheduler_ctor(optimizer)
    else:
        lr_scheduler = None

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train,
                                        num_workers=dataloader_workers_n, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                      num_workers=dataloader_workers_n, collate_fn=collate_fn)

    best_val_loss = float('inf')
    best_epoch_i = 0

    best_model_path = "m1_bigger_v2.pt"
    
    n_train_examples_in_epoch = (batch_size * max_batches_per_epoch_train 
                                 if max_batches_per_epoch_train < len(train_dataset) // batch_size
                                 else len(train_dataset))

    if os.path.exists(best_model_path):
        best_model.load_state_dict(torch.load(best_model_path))
        print(f"Загружено состояние модели {best_model_path}")

    for epoch_i in tqdm(range(epoch_start, epoch_start + epoch_n), position = 0):
        try:
            model.train()
            mean_train_loss = 0
            train_batches_n = 0
            for batch_i, (batch_x, batch_y) in tqdm(enumerate(train_dataloader), total = min(max_batches_per_epoch_train, len(train_dataset) // batch_size), position=1, leave = False):
                if batch_i > max_batches_per_epoch_train:
                    break
                    

                batch_x, batch_y = [move_all_to_device(el, device) for el in (batch_x, batch_y)]

                pred = model(*batch_x)
                loss = criterion(pred, batch_y, ignore_index = criterion_ignore_index, 
                                 label_smoothing=label_smoothing)

                model.zero_grad()
                loss.backward()

                optimizer.step()

                mean_train_loss += float(loss)
                train_batches_n += 1

            mean_train_loss /= train_batches_n
            
            print('Среднее значение функции потерь на обучении', mean_train_loss)

            tb.add_scalar('mean_loss/train', mean_train_loss, epoch_i * n_train_examples_in_epoch)



            model.eval()
            mean_val_loss = 0
            val_batches_n = 0

            with torch.no_grad():
                for batch_i, (batch_x, batch_y) in enumerate(val_dataloader):
                    if batch_i > max_batches_per_epoch_val:
                        break

                    batch_x, batch_y = [move_all_to_device(el, device) for el in (batch_x, batch_y)]

                    pred = model(*batch_x)
                    loss = criterion(pred, batch_y, 
                                     ignore_index = criterion_ignore_index, 
                                     label_smoothing=label_smoothing)

                    mean_val_loss += float(loss)
                    val_batches_n += 1

            mean_val_loss /= val_batches_n
            print('Среднее значение функции потерь на валидации', mean_val_loss)
            tb.add_scalar('mean_loss/val', mean_val_loss, epoch_i * n_train_examples_in_epoch)

            if mean_val_loss < best_val_loss:
                best_epoch_i = epoch_i
                best_val_loss = mean_val_loss
                best_model = copy.deepcopy(model)
                torch.save(model.state_dict(), os.path.join(model_save_root, best_model_path))
                
                cur_time = "{:%Y_%m_%d__%H_%M_%S}".format(datetime.now())
                
               
                greedy_accuracy = get_greedy_generator_accuracy(val_dataset, model, word_char_tokenizer, device)
                tb.add_scalar('greedy_accuracy/val', greedy_accuracy, epoch_i * n_train_examples_in_epoch)
                
                torch.save(model.state_dict(), os.path.join(model_save_root, f"m1_bigger_v2__{cur_time}__{mean_val_loss:.5f}__greed_acc_{greedy_accuracy:.5f}__{model_name_postfix}__epoch_i_{epoch_i}.pt"))
                print(f"Greedy accuracy = {greedy_accuracy}")
                print('Новая лучшая модель!')
            elif epoch_i - best_epoch_i > early_stopping_patience:
                print('Модель не улучшилась за последние {} эпох, прекращаем обучение'.format(
                    early_stopping_patience))
                break

            if lr_scheduler is not None:
                lr_scheduler.step(mean_val_loss)

            print()
        except KeyboardInterrupt:
            print('Досрочно остановлено пользователем')
            break
        except Exception as ex:
            print('Ошибка при обучении: {}\n{}'.format(ex, traceback.format_exc()))
            break

In [None]:
EXPERIMENT_NAME = f"m1_bigger_model__{GRID_NAME}__from_random_weights__batch__{BATCH_SIZE}/SEED_{RANDOM_SEED}__run1"
TENSORBOARD_LOG_PATH = f"/kaggle/working/tensorboard_log/{EXPERIMENT_NAME}"

tb = SummaryWriter(TENSORBOARD_LOG_PATH)


In [None]:
l2_reg_alpha =  0 #5e-5
label_smoothing=  0 #0.045
epoch_start = 0

best_val_loss, best_model = train_eval_loop(
    transformer, train_dataset, val_dataset, cross_entropy_with_reshape, tb, epoch_start,
    lr=1e-4, epoch_n=10000, batch_size=BATCH_SIZE, collate_fn = collate_fn,
    device=device, early_stopping_patience=20, l2_reg_alpha=l2_reg_alpha,
    max_batches_per_epoch_train=2000,
    max_batches_per_epoch_val=1000,
    optimizer_ctor=None,
    lr_scheduler_ctor=lr_scheduler,
    shuffle_train=True,
    dataloader_workers_n=0,
    criterion_ignore_index = word_char_tokenizer.char_to_idx['<pad>'],
    model_name_postfix = f'{GRID_NAME}_l2_{l2_reg_alpha}_ls{label_smoothing}',
    model_save_root = "../..",
    label_smoothing=label_smoothing,
)

In [None]:
# Эпоха должна длиться 16 минут