In [None]:
# Half-width of the uniform integer noise added to x/y coordinates of training
# swipes. RandIntToTrajTransform is built with (-NOISE_RANGE, NOISE_RANGE + 1)
# and np.random.randint excludes its upper bound, so every coordinate is
# shifted by a value in [-NOISE_RANGE, NOISE_RANGE].
NOISE_RANGE = 0  # set to 0 to avoid augmentation

In [None]:
# Automatically reload project modules (model, dataset, ...) when their source changes.
%load_ext autoreload
%autoreload 2

In [None]:
# Switch the repository to the branch this notebook was written against.
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe
!git checkout transformer-conformer-lightning

# Work from src/: relative paths such as DATA_ROOT = "../data/..." and the
# plain `from model import ...` style imports below assume this directory.
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [None]:
# !rm -r /kaggle/working/yandex-cup-2023-ml-neuroswipe/src/lightning_logs
# !rm -r /kaggle/working/yandex-cup-2023-ml-neuroswipe/src/checkpoints
# !rm -r /kaggle/working/yandex-cup-2023-ml-neuroswipe/src/checkpoint_epoch_end

In [None]:
# !zip -r /kaggle/working/src_uniform_int_noise_{NOISE_RANGE}.zip /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [None]:
# ls lightning_logs/transformer_m1_bigger__default__from_random_weights__batch__256/SEED_121/version_4

In [None]:
# NOTE(review): duplicates the autoreload cell at the top of the notebook;
# loading an already-loaded extension only prints a message, so this is harmless.
%load_ext autoreload
%autoreload 2

In [None]:
# !git clone https://github.com/proshian/yandex-cup-2023-ml-neuroswipe.git
# %cd yandex-cup-2023-ml-neuroswipe

In [None]:
# !pip install dvc --quiet
# !pip install dvc_gdrive --quiet

In [None]:
# ! pip install gdown
# ! python ./src/downloaders/download_weights.py

In [None]:
# Refresh the repository and make sure the expected branch is active.
# NOTE(review): `git pull` runs before `git checkout`, so it updates whichever
# branch is currently checked out, which is not necessarily transformer-conformer-lightning.
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe
! git pull
! git checkout transformer-conformer-lightning

In [None]:
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [None]:
# ---- Emulated command-line arguments (this notebook mirrors a training script) ----

# Data locations, relative to src/.
DATA_ROOT = "../data/data_separated_grid"
MODELS_DIR = "../data/trained_models/m1"

# Experiment settings.
GRID_NAME = "default"      # keyboard grid to train on: "default" or "extra"
TRAIN_BATCH_SIZE = 64
RANDOM_SEED = 12
IN_KAGGLE = False          # only consulted by the commented-out Kaggle path override

In [None]:
import os
import json
import typing as tp
import traceback
from datetime import datetime
import copy

import torch
# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter


from model import SwipeCurveTransformer, get_m1_bigger_model
from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from ns_tokenizers import ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from dataset import CurveDataset, CollateFn
from word_generators import GreedyGenerator
from nearest_key_lookup import ExtendedNearestKeyLookup
from transforms import KbTokens_InitTransform, KbTokens_GetItemTransform

In [None]:
################ Other constants ####################
# File names of the train/validation splits of every keyboard grid;
# they are resolved against DATA_ROOT right below.
_GRID_NAME_TO_DS_FILENAMES = {
    "extra": {
        "train": "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl",
        "val": "valid__in_train_format__extra_only.jsonl",
    },
    "default": {
        "train": "train__default_only_no_errors__2023_10_31__03_26_16.jsonl",
        "val": "valid__in_train_format__default_only.jsonl",
    },
}

GRID_NAME_TO_DS_PATHS = {
    grid_name: {
        split: os.path.join(DATA_ROOT, file_name)
        for split, file_name in split_to_file_name.items()
    }
    for grid_name, split_to_file_name in _GRID_NAME_TO_DS_FILENAMES.items()
}

# Split paths ("train", "val") of the grid selected by GRID_NAME.
DS_PATHS = GRID_NAME_TO_DS_PATHS[GRID_NAME]

In [None]:
# if IN_KAGGLE:
#     DATA_ROOT = "/kaggle/input/neuroswipe-defualt-only-v1"
#     MODELS_DIR = ""

In [None]:
def init_random_seed(value):
    """Seed NumPy's global RNG and PyTorch's CPU and current-CUDA-device RNGs.

    Python's `random` module and cuDNN determinism are not touched
    (the corresponding calls are kept commented out below).
    """
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(value)
    # random.seed(value)
    # torch.backends.cudnn.deterministic = True

In [None]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Return the grid description stored under `grid_name` in a JSON file.

    `grids_path` must point to a UTF-8 JSON object keyed by grid name;
    an unknown `grid_name` raises KeyError.
    """
    with open(grids_path, encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [None]:
from typing import List, Dict, Tuple, Optional, Set

def get_gridname_to_out_of_bounds_coords_dict(
        data_paths: tp.Iterable[str], gridname_to_wh: dict,
        totals: tp.Iterable[Optional[int]] = None
        ) -> Dict[str, Set[Tuple[int, int]]]:
    """
    Collect, per grid, every swipe point that lies outside the grid.

    Arguments:
    ----------
    data_paths: iterable of str
        Paths to jsonl files. Every non-blank line must be a JSON object whose
        'curve' entry holds 'grid_name', 'x' and 'y'. Any iterable is accepted
        (e.g. ``dict.values()`` or a generator); it is materialized once.
    gridname_to_wh: dict
        Grid name -> ``(width, height)``. A curve whose grid name is missing
        here raises KeyError.
    totals: iterable of Optional[int], optional
        Expected number of lines per file, used only as tqdm progress-bar
        totals. Missing entries (or ``totals=None``) default to ``None``.

    Returns:
    --------
    Dict mapping every grid name of ``gridname_to_wh`` to the *set* of
    ``(x, y)`` points with ``x < 0``, ``x >= width``, ``y < 0`` or ``y >= height``.
    """
    data_paths = list(data_paths)
    totals = list(totals) if totals is not None else []
    # Pad with None so a short (or empty) `totals` never makes zip() skip files.
    totals += [None] * (len(data_paths) - len(totals))

    gname_to_out_of_bounds = {gname: set() for gname in gridname_to_wh}

    for data_path, total in zip(data_paths, totals):
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total=total):
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                curve = json.loads(line)['curve']
                grid_name = curve['grid_name']
                w, h = gridname_to_wh[grid_name]
                gname_to_out_of_bounds[grid_name].update(
                    (x, y) for x, y in zip(curve['x'], curve['y'])
                    if not (0 <= x < w and 0 <= y < h)
                )
    return gname_to_out_of_bounds

In [None]:
from typing import Dict, Set, Tuple


def update_out_of_bounds_with_noise(
    noise_min, noise_max,
    gname_to_out_of_bounds, gridname_to_wh: dict,
    )-> Dict[str, Set[Tuple[int, int]]]:
    """
    Extend per-grid out-of-bounds point sets with every out-of-bounds point
    that integer noise in ``[noise_min, noise_max]`` (inclusive) can produce.

    For a grid of size ``w x h`` two kinds of points are added:

    * out-of-bounds points within noise distance of a point already in the set;
    * a frame around the grid covering ``x`` in ``[noise_min, w + noise_max]``
      and ``y`` in ``[noise_min, h + noise_max]`` minus the in-bounds
      rectangle ``[0, w) x [0, h)``. This includes the rows ``y == h`` and
      columns ``x == w``, which are reachable from in-bounds points with
      positive noise.

    Arguments:
    ----------
    noise_min: int
        Smallest noise value; must be <= 0.
    noise_max: int
        Largest noise value; must be >= 0.
    gname_to_out_of_bounds: dict
        Grid name -> set of ``(x, y)`` points. Updated in place.
    gridname_to_wh: dict
        Grid name -> ``(width, height)``.

    Returns:
    --------
    The same ``gname_to_out_of_bounds`` object after the in-place update.
    """
    assert noise_min <= 0
    assert noise_max >= 0

    noise_values = range(noise_min, noise_max + 1)

    for gname, out_of_bounds in gname_to_out_of_bounds.items():
        w, h = gridname_to_wh[gname]
        additional = set()

        # Noisy neighbours of points that are already out of bounds.
        for x, y in out_of_bounds:
            for dx in noise_values:
                for dy in noise_values:
                    nx, ny = x + dx, y + dy
                    if not (0 <= nx < w and 0 <= ny < h):
                        additional.add((nx, ny))

        # Frame rows above (y < 0) and below (y >= h) the grid, full width.
        # The bottom band must start at y == h: that row is already out of bounds.
        frame_rows = list(range(noise_min, 0)) + list(range(h, h + noise_max + 1))
        for x in range(noise_min, w + noise_max + 1):
            for y in frame_rows:
                additional.add((x, y))

        # Frame columns left (x < 0) and right (x >= w) of the grid, grid height.
        frame_cols = list(range(noise_min, 0)) + list(range(w, w + noise_max + 1))
        for x in frame_cols:
            for y in range(0, h):
                additional.add((x, y))

        # Added after the loops: the set must not change while being iterated.
        out_of_bounds.update(additional)

    return gname_to_out_of_bounds
        

In [None]:
import numpy as np

class RandIntToTrajTransform:
    """Augmentation that jitters swipe coordinates with uniform integer noise.

    Each x and each y coordinate gets an independent offset drawn with
    ``np.random.randint(min_, max_)``. NumPy's upper bound is *exclusive*,
    so offsets lie in ``[min_, max_ - 1]``; the defaults (-3, 3) therefore
    give the asymmetric range [-3, 2]. Pass ``(-r, r + 1)`` for [-r, r].

    Input and output are ``(X, Y, T, grid_name, tgt_word)`` tuples. X and Y
    come back as integer NumPy arrays; the other items are passed through.
    """

    def __init__(self, min_=-3, max_=3) -> None:
        self.min = min_
        self.max = max_

    def _jitter(self, coords):
        # Cast to int first, then add one offset per coordinate.
        base = np.array(coords, dtype=int)
        return base + np.random.randint(self.min, self.max, (len(coords),))

    def __call__(self, data):
        X, Y, T, grid_name, tgt_word = data
        # X is jittered before Y so the global RNG is consumed in a fixed order.
        noisy_x = self._jitter(X)
        noisy_y = self._jitter(Y)
        return noisy_x, noisy_y, T, grid_name, tgt_word
    
class SequentialTransform:
    """Compose callables: each transform receives the previous one's output."""

    def __init__(self, transforms) -> None:
        self.transforms = transforms

    def __call__(self, data):
        result = data
        for step in self.transforms:
            result = step(result)
        return result

In [None]:
from dataset import CurveDataset  # CurveDatasetWithMultiProcInit
from nearest_key_lookup import NearestKeyLookup, ExtendedNearestKeyLookup
from ns_tokenizers import KeyboardTokenizerv1, CharLevelTokenizerv2
from ns_tokenizers import ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from predict import get_grid_name_to_grid
from transforms import FullTransform
from dataset import _get_data_from_json_line




gridname_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
voc_path = os.path.join(DATA_ROOT, "voc.txt")
char_tokenizer = CharLevelTokenizerv2(voc_path)

# has one grid only
gridname_to_grid = get_grid_name_to_grid(gridname_to_grid_path)

# Number of lines in each split file. The same numbers are passed as `total`
# to CurveDataset below; for the out-of-bounds scan they only size the tqdm
# progress bars. NOTE: hard-coded, so they may need updating if GRID_NAME changes.
DS_SPLIT_TO_TOTAL = {"train": 5_237_584, "val": 10_000}

# Totals are looked up by split name so they stay aligned with DS_PATHS.values()
# (a positional list can silently pair a path with another split's total).
totals = [DS_SPLIT_TO_TOTAL[split] for split in DS_PATHS]

gname_to_wh = {
    gname: (grid['width'], grid['height'])
    for gname, grid in gridname_to_grid.items()
}

print("Accumulating out-of-bounds coordinates...")
gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
    list(DS_PATHS.values()), gname_to_wh, totals=totals
)

print("augmenting gname_to_out_of_bounds")
# The augmentation below shifts each coordinate by at most NOISE_RANGE;
# noise_max=NOISE_RANGE+1 keeps one extra pixel of margin.
gname_to_out_of_bounds = update_out_of_bounds_with_noise(
    noise_min = -NOISE_RANGE, noise_max=NOISE_RANGE+1,
    gname_to_out_of_bounds = gname_to_out_of_bounds, gridname_to_wh = gname_to_wh,
)

print("Creating ExtendedNearestKeyLookups...")
gridname_to_nkl = {
    gname: ExtendedNearestKeyLookup(grid, ALL_CYRILLIC_LETTERS_ALPHABET_ORD, gname_to_out_of_bounds[gname])
    for gname, grid in gridname_to_grid.items()
}

kb_tokenizer = KeyboardTokenizerv1()

full_transform = FullTransform(
    grid_name_to_nk_lookup=gridname_to_nkl,
    grid_name_to_wh=gname_to_wh,
    kb_tokenizer=kb_tokenizer,
    word_tokenizer=char_tokenizer,
    include_time=False,
    include_velocities=True,
    include_accelerations=True,
    kb_tokens_dtype=torch.int32,
    word_tokens_dtype=torch.int64
)

if NOISE_RANGE != 0:
    # np.random.randint excludes its upper bound: noise is in [-NOISE_RANGE, NOISE_RANGE].
    augmentation_transform = RandIntToTrajTransform(-NOISE_RANGE, NOISE_RANGE + 1)
    train_transform = SequentialTransform([augmentation_transform, full_transform])
else:
    train_transform = full_transform

val_transform = full_transform

print("Creating CurveDatasets with full_transform...")

train_dataset = CurveDataset(
    data_path=DS_PATHS['train'],
    store_gnames=False,
    init_transform=None,
    get_item_transform=train_transform,
    total=DS_SPLIT_TO_TOTAL['train'],
)

val_dataset = CurveDataset(
    data_path=DS_PATHS['val'],
    store_gnames=False,
    init_transform=None,
    get_item_transform=val_transform,
    total=DS_SPLIT_TO_TOTAL['val'],
)

In [None]:
def get_datasets(grid_name: str, grid_name_to_grid_path: str,
                 data_paths: tp.Iterable[str], totals: tp.Iterable[tp.Optional[int]],
                 nearest_key_candidates: tp.Set[str],
                 kb_tokenizer: KeyboardTokenizerv1,
                 word_char_tokenizer: CharLevelTokenizerv2
                 ) -> List[CurveDataset]:
    """
    Build one CurveDataset per data file for a single keyboard grid.

    The ExtendedNearestKeyLookup of the grid is built from the out-of-bounds
    points found in *all* given files. It feeds KbTokens_InitTransform
    (the dataset's ``init_transform``). Trajectory features and word tokens
    come from KbTokens_GetItemTransform (the ``get_item_transform``).

    Arguments:
    ----------
    grid_name: str
        Grid to load from ``grid_name_to_grid_path`` (see get_grid).
    data_paths: iterable of str
        jsonl files, one dataset each. Any iterable is accepted.
    totals: iterable of Optional[int]
        Per-file line counts passed as ``total``. Missing entries become None.
    nearest_key_candidates: set of str
        Keys the nearest-key lookup may return.
    kb_tokenizer, word_char_tokenizer:
        Tokenizers for keyboard keys and target words.

    Returns:
    --------
    List of datasets in the same order as ``data_paths``.
    """
    # Materialize once: the out-of-bounds scan and the dataset loop both
    # iterate over these, and a one-shot iterator would be empty the 2nd time.
    data_paths = list(data_paths)
    totals = list(totals) if totals is not None else []
    # Pad with None so a short `totals` never makes zip() silently drop files.
    totals += [None] * (len(data_paths) - len(totals))

    gridname_to_grid = {grid_name: get_grid(grid_name, grid_name_to_grid_path)}

    gname_to_wh = {
        gname: (grid['width'], grid['height'])
        for gname, grid in gridname_to_grid.items()
    }

    print("Accumulating out-of-bounds coordinates...")
    gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
        data_paths, gname_to_wh, totals=totals
    )

    print("Creating ExtendedNearestKeyLookups...")
    gridname_to_nkl = {
        gname: ExtendedNearestKeyLookup(grid, nearest_key_candidates, gname_to_out_of_bounds[gname])
        for gname, grid in gridname_to_grid.items()
    }

    init_transform = KbTokens_InitTransform(
        grid_name_to_nk_lookup=gridname_to_nkl,
        kb_tokenizer=kb_tokenizer,
    )

    get_item_transform = KbTokens_GetItemTransform(
        grid_name_to_wh=gname_to_wh,
        word_tokenizer=word_char_tokenizer,
        include_time=False,
        include_velocities=True,
        include_accelerations=True,
    )

    print("Creating datasets...")
    return [
        CurveDataset(
            data_path=d_path,
            store_gnames=False,
            init_transform=init_transform,
            get_item_transform=get_item_transform,
            total=total,
        )
        for d_path, total in zip(data_paths, totals)
    ]

In [None]:
# Seed the NumPy and PyTorch global RNGs (see init_random_seed). This runs after
# the datasets were built, so it affects only randomness consumed from here on
# (model weight initialization, and presumably DataLoader shuffling).
init_random_seed(RANDOM_SEED)

In [None]:
# Pickling the dataset would be great to not waste
# around 20 minutes creating train_dataset.

# NOTE(review): kb_tokenizer and voc_path were already created in the dataset
# cell above and are re-created here; word_char_tokenizer is built from the same
# vocabulary file as char_tokenizer. word_char_tokenizer is the instance used by
# collate_fn, the metric helpers, and the Lightning model below.
kb_tokenizer = KeyboardTokenizerv1()
voc_path=os.path.join(DATA_ROOT, "voc.txt")
word_char_tokenizer = CharLevelTokenizerv2(voc_path)

# data_paths = [
#     GRID_NAME_TO_DS_PATHS[GRID_NAME]['train'],
#     GRID_NAME_TO_DS_PATHS[GRID_NAME]['val']
# ]
# totals = [6_000_000, None]

# data_paths = [
#     GRID_NAME_TO_DS_PATHS[GRID_NAME]['val']
# ]
# totals = [None]


# # train_dataset, val_dataset = get_datasets(
# (val_dataset, ) = get_datasets(
#     grid_name=GRID_NAME,
#     grid_name_to_grid_path=os.path.join(DATA_ROOT, "gridname_to_grid.json"),
#     data_paths = data_paths, totals = totals,
#     nearest_key_candidates = ALL_CYRILLIC_LETTERS_ALPHABET_ORD,
#     kb_tokenizer=kb_tokenizer,
#     word_char_tokenizer=word_char_tokenizer,
# )

In [None]:
# %%time

# import pickle

# try:
#     with open(GRID_NAME_TO_FULLY_TRANSFORMED_DS_PATHS[GRID_NAME]['train'], 'rb') as f:
#         data_list = pickle.load(f)
# except e:
#     print(e)

# # train_dataset = CurveDataset.from_data_list(data_list)

In [None]:
# data_list

In [None]:
# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# NOTE(review): `transformer` is not used by the Lightning training below
# (LitNeuroswipeModel builds its own model through MODEL_GETTERS_DICT).
# Building it still allocates memory and presumably consumes the torch RNG,
# which would shift the initialization of the model that is actually trained.
transformer = get_m1_bigger_model(device)

In [None]:
def cross_entropy_with_reshape(pred, target, ignore_index=-100, label_smoothing=0.0):
    """
    Token-level cross-entropy over sequence outputs.

    pred - BatchSize x TargetLen x VocabSize
    target - BatchSize x TargetLen

    Any leading dimensions work (e.g. the sequence-first outputs used in this
    notebook) as long as ``pred.shape[:-1] == target.shape``; they are
    flattened together. ``reshape`` is used instead of ``view`` so that
    non-contiguous inputs (e.g. transposed tensors) are accepted; ``view``
    raises on them. Targets equal to ``ignore_index`` do not contribute.
    """
    pred_flat = pred.reshape(-1, pred.shape[-1])  # BatchSize*TargetLen x VocabSize
    target_flat = target.reshape(-1)  # BatchSize*TargetLen
    return F.cross_entropy(pred_flat,
                           target_flat,
                           ignore_index=ignore_index,
                           label_smoothing=label_smoothing)

In [None]:
def lr_scheduler(optimizer):
    """Build a ReduceLROnPlateau that halves the LR of `optimizer`.

    The LR is halved after 20 consecutive `step()` calls without improvement
    of the monitored metric (the Lightning model monitors 'val_loss').
    """
    scheduler_kwargs = dict(patience=20, factor=0.5, verbose=True)
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

In [None]:
# def move_all_to_device(x, device):
#     if torch.is_tensor(x):
#         return x.to(device)
#     elif not isinstance(x, (list, tuple)):
#         raise ValueError(f'Unexpected data type {type(x)}')
#     new_x = []
#     for el in x:
#         if not torch.is_tensor(el):
#             raise ValueError(f'Unexpected data type {type(el)}')
#         new_x.append(el.to(device))
#     return new_x

In [None]:
# Pads word-token sequences with the '<pad>' index. batch_first=False gives
# sequence-first batches, e.g. batch_y has shape (max_word_len, batch_size),
# as asserted in the commented-out test cell below.
collate_fn = CollateFn(
    word_pad_idx = word_char_tokenizer.char_to_idx['<pad>'], batch_first = False)

In [None]:
# # Test that collate_fn (called implicitly by DataLoader) is correct

# batch_size = 6


# PAD_CHAR_TOKEN = word_char_tokenizer.char_to_idx["<pad>"]


# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
#                               num_workers=0, collate_fn=collate_fn)


# dataset_els = [train_dataset[i] for i in range(batch_size)]
# unproc_batch_x, unproc_batch_y = zip(*dataset_els)

# batch_x, batch_y = next(iter(train_dataloader))


# ############### Checking that batch_y is correct ###################
# max_out_seq_len = max([len(y) for y in unproc_batch_y])

# assert batch_y.shape == (max_out_seq_len, batch_size)


# for i in range(batch_size):
#     assert (batch_y[:len(unproc_batch_y[i]), i] == unproc_batch_y[i]).all()
#     assert (batch_y[len(unproc_batch_y[i]):, i] == PAD_CHAR_TOKEN).all()

# print("batch_y is correct")



# ############### Checking that batch_x is correct ###################
# unproc_batch_traj_feats, unproc_batch_kb_tokens, unproc_batch_dec_in_char_seq = zip(*unproc_batch_x)

# (traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask) = batch_x


# # Each item unpacked above from unproc_batch_x is a tuple of length batch_size.
# # For example, unproc_batch_traj_feats[i] = train_dataset[i][0][0]

# N_TRAJ_FEATS = 6
# max_curve_len = max([el.shape[0] for el in unproc_batch_traj_feats]) 

# assert max_curve_len == max([el.shape[0] for el in unproc_batch_kb_tokens])

# assert traj_feats.shape == (max_curve_len, batch_size, N_TRAJ_FEATS)
# assert kb_tokens.shape == (max_curve_len, batch_size)
# assert dec_in_char_seq.shape == (max_out_seq_len, batch_size)
# assert traj_pad_mask.shape == (batch_size, max_curve_len)
# assert word_pad_mask.shape == (batch_size, max_out_seq_len)


# for i in range(batch_size):
#     assert (traj_feats[:len(unproc_batch_traj_feats[i]), i] == unproc_batch_traj_feats[i]).all()
#     assert (kb_tokens[:len(unproc_batch_kb_tokens[i]), i] == unproc_batch_kb_tokens[i]).all()

#     assert (dec_in_char_seq[:len(unproc_batch_dec_in_char_seq[i]), i] == unproc_batch_dec_in_char_seq[i]).all()
#     assert (dec_in_char_seq[len(unproc_batch_dec_in_char_seq[i]):, i] == PAD_CHAR_TOKEN).all()

#     assert (traj_pad_mask[i, :len(unproc_batch_traj_feats[i])] == False).all()
#     assert (traj_pad_mask[i, len(unproc_batch_traj_feats[i]):] == True).all()
    
#     assert (word_pad_mask[i, :len(unproc_batch_dec_in_char_seq[i])] == False).all()
#     assert (word_pad_mask[i, len(unproc_batch_dec_in_char_seq[i]):] == True).all()

# print("batch_x is correct")

In [None]:
from typing import List


def predict_greedy_raw(dataset,
                       greedy_word_generator: GreedyGenerator,
                       max_n_steps = 19, # length of the longest word in the validation set
                      ) -> List[str]:
    """
    Creates predictions using greedy generation.

    Supposed to be used with a dataset of a single grid.

    Arguments:
    ----------
    dataset: CurveDataset (any sized iterable works)
        Each item is ``((traj_feats, kb_tokens, decoder_in), target)``;
        only ``traj_feats`` and ``kb_tokens`` are used.
    greedy_word_generator: GreedyGenerator
        Its ``generate_word_only(traj_feats, kb_tokens, max_n_steps)`` decodes one word.
    max_n_steps: int
        Maximum number of decoding steps per word.

    Returns:
    --------
    One predicted word per dataset item, with a leading "<sos>" removed.
    """
    preds = []
    for (xyt, kb_tokens, _), _ in tqdm(dataset, total=len(dataset)):
        pred = greedy_word_generator.generate_word_only(xyt, kb_tokens, max_n_steps)
        preds.append(pred.removeprefix("<sos>"))
    return preds


def get_targets(dataset: CurveDataset, tokenizer=None) -> tp.List[str]:
    """
    Decode the target word of every dataset item.

    Each item is ``(model_inputs, target_tokens)``. The last target token is
    ``<eos>`` and is dropped before decoding.

    Arguments:
    ----------
    dataset: CurveDataset
    tokenizer: optional
        Object with a ``decode(tokens)`` method. Defaults to the notebook-level
        ``word_char_tokenizer``; pass one explicitly so the function does not
        depend on that global.
    """
    if tokenizer is None:
        tokenizer = word_char_tokenizer
    # [:-1] drops the trailing <eos> token.
    return [tokenizer.decode(target_tokens[:-1]) for _, target_tokens in dataset]


def get_accuracy(preds, targets) -> float:
    """Share of positions where ``preds[i] == targets[i]``.

    Pairs are formed with zip (surplus items of the longer sequence are
    ignored), while the denominator is ``len(targets)``.
    """
    n_correct = sum(p == t for p, t in zip(preds, targets))
    return n_correct / len(targets)


def get_greedy_generator_accuracy(val_dataset, model,
                                  word_char_tokenizer, device) -> float:
    """
    Word-level accuracy of greedy decoding with ``model`` on ``val_dataset``.

    Targets come from get_targets. Predictions come from predict_greedy_raw,
    using a GreedyGenerator built from ``model``, ``word_char_tokenizer``
    and ``device``.
    """
    # TODO: instead of generating each word to the end, decode letter by letter
    # and stop as soon as a letter differs from the target, counting that curve
    # as a miss right away; generating the rest of the word is wasted work.
    targets = get_targets(val_dataset)
    generator = GreedyGenerator(model, word_char_tokenizer, device)
    predictions = predict_greedy_raw(val_dataset, generator)
    return get_accuracy(predictions, targets)

In [None]:
# ###################### testing predict_greedy_raw ######################


# # Importantly, test with trained weights, not random ones: with random weights the generator produces words of length max_seq_len instead of short words


# MODEL_TO_TEST_GREEDY_GEN__PATH = "../data/trained_models_for_final_submit/m1_bigger/" \
#     "m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt"

# # Leads to super slow inference.  I think it's due to 
# # high price of operations on small-amplitude floats.
# # MODEL_TO_TEST_GREEDY_GEN__PATH = None


# def test_greedy_generator(val_dataset, model_getter, model_weights, word_char_tokenizer, device) -> float:
    
#     model = model_getter(device, model_weights)

#     return get_greedy_generator_accuracy(val_dataset, model, word_char_tokenizer, device)



# test_greedy_generator(val_dataset, get_m1_bigger_model, MODEL_TO_TEST_GREEDY_GEN__PATH, word_char_tokenizer, device)

In [None]:

# greedy_accuracy = get_greedy_generator_accuracy(val_dataset, model, word_char_tokenizer, device)
# tb.add_scalar('greedy_accuracy/val', greedy_accuracy, epoch_i * n_train_examples_in_epoch)


In [None]:
# MODEL_NAME selects the model from MODEL_GETTERS_DICT and prefixes checkpoint
# file names; EXPERIMENT_NAME is the TensorBoard run name (see tb_logger).
MODEL_NAME = "transformer_bb_model"
USE_AUGMENTATIONS_STR = f"uniform_int_noise_{NOISE_RANGE}__" if NOISE_RANGE else ""
# NOTE(review): the "1" after {RANDOM_SEED} is a literal character, so with
# RANDOM_SEED = 12 the run directory is "SEED_121", not "SEED_12". It looks like
# a manual run-disambiguation suffix; do not read the seed back from this name.
EXPERIMENT_NAME = f"{MODEL_NAME}__{GRID_NAME}__{USE_AUGMENTATIONS_STR}from_random_weights__batch__{TRAIN_BATCH_SIZE}/SEED_{RANDOM_SEED}1"
# TENSORBOARD_LOG_PATH = f"/kaggle/working/tensorboard_log/{EXPERIMENT_NAME}"

# tb = SummaryWriter(TENSORBOARD_LOG_PATH)


In [None]:
# from tqdm import tqdm

In [None]:
import multiprocessing

# Number of CPUs (shown as the cell output); a reference for dataloader_workers_n below.
multiprocessing.cpu_count()

In [None]:
# Root directory for TensorBoard logs (passed as save_dir to TensorBoardLogger below).
LOG_DIR = "lightning_logs/"

In [None]:
# NOTE(review): unpinned install; pin a version (and prefer %pip, which targets
# the kernel's environment) for reproducible runs.
!pip install lightning

In [None]:
def get_word_level_accuracy(y_true_batch: torch.Tensor,
                            pred_batch: torch.Tensor,
                            pad_token: int,
                            mask: torch.Tensor) -> float:
    """Share of rows (words) predicted exactly right, ignoring masked positions.

    Expects batch-first tensors of shape (batch_size, chars_seq_len). Collated
    batches are sequence-first, so callers transpose (``.T``) before calling.
    Positions where ``mask`` is True are replaced with ``pad_token`` in both
    tensors (out of place), so they can never cause a mismatch.
    """
    masked_true = y_true_batch.masked_fill(mask, pad_token)
    masked_pred = pred_batch.masked_fill(mask, pad_token)
    word_is_correct = (masked_true == masked_pred).all(dim=1)
    n_words = len(word_is_correct)
    return float(word_is_correct.sum() / n_words)


def decode_batch(seq_batch, tokenizer):
    """Decode every sequence of `seq_batch` with ``tokenizer.decode``; returns a list."""
    return [tokenizer.decode(seq) for seq in seq_batch]


def get_word_level_metric(metric_fn,
                          y_true_batch: torch.Tensor,
                          pred_batch: torch.Tensor,
                          tokenizer,
                          mask: torch.Tensor) -> float:
    """
    Compute a word-level metric on decoded (string) words.

    Masked positions are replaced with ``tokenizer.char_to_idx['<pad>']``,
    every row is decoded with ``tokenizer.decode``, and
    ``metric_fn(true_words, pred_words)`` is returned (e.g. sklearn's
    ``accuracy_score``).

    Expects batch-first tensors (batch_size, chars_seq_len), like
    get_word_level_accuracy. The input tensors are not modified.
    """
    pad_idx = tokenizer.char_to_idx['<pad>']
    # Out-of-place fill so the caller's tensors stay unchanged.
    y_true_batch = y_true_batch.masked_fill(mask, pad_idx)
    pred_batch = pred_batch.masked_fill(mask, pad_idx)

    # Decode with the tokenizer that was passed in, not a notebook global.
    true_words = [tokenizer.decode(seq) for seq in y_true_batch]
    pred_words = [tokenizer.decode(seq) for seq in pred_batch]

    return metric_fn(true_words, pred_words)

In [None]:
######  testing get_word_level_accuracy, get_word_level_metric
from sklearn.metrics import f1_score, accuracy_score
import torch

# Smoke test on random tokens: rows 0-2 of pred__rand are copies of y_true__rand,
# the other 7 rows are random. mask hides the first 2 of the 5 positions, so a row
# counts as correct only when its last 3 tokens match. With 32 possible tokens
# the random rows almost never match, so both prints should show about 0.3.
batch_size = 10
seq_len = 5
y_true__rand = torch.randint(0, 32, (batch_size, seq_len))
pred__rand = torch.randint(0, 32, (batch_size, seq_len))
pred__rand[:3] = y_true__rand[:3]

mask = torch.zeros((batch_size, seq_len), dtype = torch.bool)
mask[:, :-3] = True

print(
    get_word_level_accuracy(
        y_true__rand, pred__rand, pad_token = -1, mask = mask)
)

print(
    get_word_level_metric(accuracy_score, y_true__rand, pred__rand,
                      word_char_tokenizer, mask = mask)
)

In [None]:
# NOTE(review): unpinned install; pin the version for reproducible runs.
!pip install torchmetrics

In [None]:
from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import loggers as pl_loggers
import torchmetrics


from model import MODEL_GETTERS_DICT

# ! Make sure:
# * Add metrics

#! Maybe store:
# * batch_size
# * early_stopping_patience

#! Maybe:
# * Checkpointing by condition: if model improved on val_loss and val_loss < max_val_loss_to_save


class LitNeuroswipeModel(LightningModule):
    """
    LightningModule around an encoder-decoder swipe-to-word model.

    Batches are ``(batch_x, batch_y)``. ``batch_x`` is
    ``(curve_traj_feats, curve_kb_tokens, decoder_in, curve_pad_mask,
    dec_seq_pad_mask)``. ``batch_y`` is the decoder target, sequence-first:
    ``(chars_seq_len, batch_size)``.

    Logged metrics (prefix ``train_`` / ``val_``):
    * loss;
    * word-level accuracy (exact argmax-sequence match, padding ignored);
    * token-level accuracy and F1 from torchmetrics.
    """

    def __init__(self, model_name: str, criterion,
                 num_classes: int,
                 train_batch_size: int = None,
                 criterion_ignore_index: int = -100, optim_kwargs = None,
                 optimizer_ctor=None, lr_scheduler_ctor=None, label_smoothing=0.0,
                 ) -> None:
        """
        Arguments:
        ----------
        model_name: str
            Key into MODEL_GETTERS_DICT; the getter is called without arguments.
        criterion: callable
            Called as ``criterion(pred, target, ignore_index=..., label_smoothing=...)``.
        num_classes: int
            Number of classes for the torchmetrics token-level metrics.
        train_batch_size: int, optional
            Stored on the module only; not used by it.
        criterion_ignore_index: int
            Target index ignored by the loss and the token-level metrics. Also
            the fill value for padded positions in word-level accuracy.
        optim_kwargs: dict, optional
            Keyword arguments for ``optimizer_ctor``
            (default ``dict(lr=1e-4, weight_decay=0)``).
        optimizer_ctor: callable
            Called as ``optimizer_ctor(self.parameters(), **optim_kwargs)``.
            Must be provided before training.
        lr_scheduler_ctor: callable, optional
            Called as ``lr_scheduler_ctor(optimizer)``; monitored on 'val_loss'.
        label_smoothing: float
            Passed to ``criterion``.
        """
        super().__init__()

        self.optim_kwargs = optim_kwargs or dict(lr=1e-4, weight_decay=0)

        self.model_name = model_name
        self.train_batch_size = train_batch_size
        self.label_smoothing = label_smoothing
        self.criterion_ignore_index = criterion_ignore_index

        self.optimizer_ctor = optimizer_ctor
        self.lr_scheduler_ctor = lr_scheduler_ctor

        self.model = MODEL_GETTERS_DICT[model_name]()
        self.criterion = criterion

        self.train_token_acc = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.val_token_acc = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.train_token_f1 = torchmetrics.classification.F1Score(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.val_token_f1 = torchmetrics.classification.F1Score(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)

    def forward(self, x, kb_tokens, y, x_pad_mask, y_pad_mask):
        x_encoded = self.model.encode(x, kb_tokens, x_pad_mask)
        return self.model.decode(x_encoded, y, x_pad_mask, y_pad_mask)

    def configure_optimizers(self):
        optimizer = self.optimizer_ctor(self.parameters(), **self.optim_kwargs)

        optimizers_configuration = {'optimizer': optimizer}

        if self.lr_scheduler_ctor:
            lr_scheduler = self.lr_scheduler_ctor(optimizer)
            optimizers_configuration['lr_scheduler'] = lr_scheduler
            optimizers_configuration['monitor'] = 'val_loss'

        return optimizers_configuration

    def _shared_step(self, batch):
        """Run the model on a batch and compute the values both steps log.

        Returns ``(loss, word_level_accuracy, flat_preds, flat_y, batch_size)``.
        """
        batch_x, batch_y = batch
        # batch_y is sequence-first: (chars_seq_len, batch_size).
        batch_size = batch_y.shape[-1]
        _, _, _, _, dec_seq_pad_mask = batch_x

        # pred.shape = (chars_seq_len, batch_size, n_classes)
        pred = self.forward(*batch_x)

        loss = self.criterion(pred, batch_y, ignore_index=self.criterion_ignore_index,
                              label_smoothing=self.label_smoothing)

        argmax_pred = torch.argmax(pred, dim=2)
        # get_word_level_accuracy expects batch-first tensors, hence the .T.
        wl_accuracy = get_word_level_accuracy(
            argmax_pred.T, batch_y.T, pad_token=self.criterion_ignore_index,
            mask=dec_seq_pad_mask)

        n_classes = pred.shape[-1]
        flat_preds = pred.reshape(-1, n_classes)
        flat_y = batch_y.reshape(-1)
        return loss, wl_accuracy, flat_preds, flat_y, batch_size

    def training_step(self, batch, batch_idx):
        loss, wl_accuracy, flat_preds, flat_y, batch_size = self._shared_step(batch)

        self.train_token_acc(flat_preds, flat_y)
        self.log('train_token_level_accuracy', self.train_token_acc, on_step=True, on_epoch=False)

        self.train_token_f1(flat_preds, flat_y)
        self.log('train_token_level_f1', self.train_token_f1, on_step=True, on_epoch=False)

        self.log("train_word_level_accuracy", wl_accuracy, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=batch_size)

        self.log("train_loss", loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, wl_accuracy, flat_preds, flat_y, batch_size = self._shared_step(batch)

        # The val_* metric objects must be logged here; logging the train_*
        # objects would report training statistics under the val_ names.
        self.val_token_acc(flat_preds, flat_y)
        self.log('val_token_level_accuracy', self.val_token_acc, on_step=False, on_epoch=True)

        self.val_token_f1(flat_preds, flat_y)
        self.log('val_token_level_f1', self.val_token_f1, on_step=False, on_epoch=True)

        self.log("val_word_level_accuracy", wl_accuracy, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=batch_size)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=batch_size)
        return loss


tb_logger = pl_loggers.TensorBoardLogger(save_dir=LOG_DIR, name=EXPERIMENT_NAME)

# Stop training once val_loss has not improved for 25 consecutive checks.
early_stopping_cb = EarlyStopping(monitor='val_loss', mode='min', patience=25)

# Checkpoint file-name template shared by both checkpoint callbacks;
# Lightning fills in the {epoch}/{val_loss}/{val_word_level_accuracy} fields.
CHECKPOINT_FILENAME = (
    f'{MODEL_NAME}-{GRID_NAME}--'
    + '{epoch}-{val_loss:.2f}-{val_word_level_accuracy:.2f}'
)

# Keep the 10 best checkpoints by val_loss.
model_checkpoint_cb = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=10,
    dirpath='checkpoints/',
    filename=CHECKPOINT_FILENAME,
)

# It's more reliable to continue training from epoch-end-checkpoints:
# one checkpoint at the end of every train epoch, none deleted (save_top_k=-1).
model_checkpoint_on_train_epoch_end = ModelCheckpoint(
    save_on_train_epoch_end=True,
    dirpath='checkpoint_epoch_end/',
    save_top_k=-1,
    filename=CHECKPOINT_FILENAME,
)

In [None]:
# ls yandex-cup-2023-ml-neuroswipe/src/checkpoints

In [None]:
dataloader_workers_n = 4

VAL_BATCH_SIZE = 64

# Options shared by both loaders: worker processes persist between epochs,
# and collate_fn pads every batch (sequence-first).
_common_loader_kwargs = dict(
    num_workers=dataloader_workers_n,
    persistent_workers=True,
    collate_fn=collate_fn,
)

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                          **_common_loader_kwargs)

val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False,
                        **_common_loader_kwargs)

In [None]:
from lightning.pytorch.callbacks import Callback

class EmptyCudaCacheCallback(Callback):
    """Call torch.cuda.empty_cache() at the end of every training epoch."""
    def on_train_epoch_end(self, trainer, pl_module):
        torch.cuda.empty_cache()
        
# NOTE: the misspelled name `epmty_cuda_cache_cb` is the one the Trainer cell references.
epmty_cuda_cache_cb = EmptyCudaCacheCallback()

In [None]:
# ! Word-level accuracy of greedy search results and in-training-predictions are equal.
# ! Thus GreedyAccuracyCallback doesn't make sense and should be deleted.

# ! However if we would count char-level metrics,
# ! greedy search results and in-training-predictions would be different

class GreedyAccuracyCallback(Callback):
    """
    Periodically log the word-level accuracy of greedy decoding.

    At the end of a train batch, when ``global_step + 1`` is a multiple of
    ``each_n_steps``, the wrapped model greedily decodes every item of
    ``val_dataset``. The accuracy is logged as ``greedy_val_accuracy``
    through ``logger``.
    """

    def __init__(self, each_n_steps: int, val_dataset, word_char_tokenizer, logger):
        super().__init__()
        self.each_n_steps = each_n_steps
        self.val_dataset = val_dataset
        self.word_char_tokenizer = word_char_tokenizer
        self.logger = logger

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_index):
        if (pl_module.global_step + 1) % self.each_n_steps != 0:
            return

        device = next(pl_module.parameters()).device
        # Decode in eval mode without autograd (no dropout, no graph building),
        # then restore the module's previous mode so training is unaffected.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                greedy_accuracy = get_greedy_generator_accuracy(
                    self.val_dataset, pl_module.model, self.word_char_tokenizer, device)
        finally:
            pl_module.train(was_training)

        self.logger.log_metrics({"greedy_val_accuracy": greedy_accuracy},
                                step=pl_module.global_step)

In [None]:
# NOTE(review): this callback is created but not included in the Trainer's
# callbacks list below, so it never runs.
greedy_acc_callback = GreedyAccuracyCallback(
    each_n_steps = 9000, val_dataset=val_dataset, 
    word_char_tokenizer=word_char_tokenizer, logger = tb_logger)

In [None]:
# NOTE(review): stray empty-string expression; it only displays '' and can be removed.
""

In [None]:
label_smoothing = 0.045

# Number of word-token classes:
# len(word_char_tokenizer.idx_to_char) - len(['<pad>', '<unk>']) = 37 - 2
N_WORD_CLASSES = 35

# Checkpoint to resume from. trainer.fit fails when ckpt_path does not exist
# (e.g. on a clean Restart & Run All), so start from scratch in that case.
RESUME_CKPT_PATH = r'./checkpoints/transformer_bb_model-default--epoch=0-val_loss=0.49-val_word_level_accuracy=0.81.ckpt'
resume_ckpt_path = RESUME_CKPT_PATH if os.path.exists(RESUME_CKPT_PATH) else None
if resume_ckpt_path is None:
    print(f"Checkpoint {RESUME_CKPT_PATH} not found; training from scratch.")

pl_model = LitNeuroswipeModel(
    model_name = MODEL_NAME, criterion = cross_entropy_with_reshape,
    num_classes = N_WORD_CLASSES,
    train_batch_size = TRAIN_BATCH_SIZE,
    criterion_ignore_index = word_char_tokenizer.char_to_idx['<pad>'],
    optim_kwargs = dict(lr=1e-4, weight_decay=0),
    optimizer_ctor=torch.optim.Adam, lr_scheduler_ctor=lr_scheduler,
    # Use the variable above so there is a single place to change it.
    label_smoothing=label_smoothing,
)

trainer = Trainer(
#     limit_train_batches = 400,  # for validating code before actual training
    log_every_n_steps = 100,
    num_sanity_val_steps=1,
    accelerator = 'gpu',
    max_epochs=1,
    callbacks=[
        early_stopping_cb, model_checkpoint_cb,
        model_checkpoint_on_train_epoch_end, epmty_cuda_cache_cb,
    ],
    logger=tb_logger,
    val_check_interval=3000,
)

trainer.fit(pl_model, train_loader, val_loader, ckpt_path = resume_ckpt_path)

In [None]:
# List the best-by-val_loss checkpoints.
!ls checkpoints

In [None]:
# NOTE(review): `rm` without an operand removes nothing and only reports a
# "missing operand" error; complete or delete this cell.
!rm 

In [None]:
# Archive ./checkpoints into /kaggle/working/checkpoints.zip.
!zip -r '/kaggle/working/checkpoints.zip' './checkpoints' 

In [None]:
# NOTE(review): nothing in this notebook creates lightning_logs.zip (only
# checkpoints.zip above), so this fails unless the file was made elsewhere.
!rm /kaggle/working/lightning_logs.zip

In [None]:
# NOTE(review): relies on IPython automagic for `ls`. Runs live in
# <this dir>/SEED_*/version_* (and include uniform_int_noise_* when NOISE_RANGE != 0).
ls lightning_logs/transformer_bb_model__default__from_random_weights__batch__64