In [None]:
# !git clone https://github.com/proshian/yandex-cup-2023-ml-neuroswipe.git
# %cd yandex-cup-2023-ml-neuroswipe

In [None]:
!pip install lightning --quiet
!pip install torchmetrics --quiet

In [None]:
%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe
! git pull
!git checkout embeding_experiments

%cd /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# to identify repo state and reproduce the experiment
!git show -s --format=%H

In [None]:
# !rm -r /kaggle/working/yandex-cup-2023-ml-neuroswipe/src/lightning_logs
# !rm -r /kaggle/working/yandex-cup-2023-ml-neuroswipe/src/checkpoints
# !rm -r /kaggle/working/yandex-cup-2023-ml-neuroswipe/src/checkpoint_epoch_end

In [None]:
# !zip -r /kaggle/working/src.zip /kaggle/working/yandex-cup-2023-ml-neuroswipe/src

In [None]:
# !rm /kaggle/working/src.zip 

In [None]:
############# Script arguments emulation #############
# Constants standing in for the command-line arguments of an equivalent script.

GRID_NAME = "default"  # key of GRID_NAME_TO_DS_PATHS below: "default" or "extra"
TRAIN_BATCH_SIZE = 256  # train DataLoader batch size; also part of EXPERIMENT_NAME
VAL_BATCH_SIZE = 512
IN_KAGGLE = False  # only referenced by a commented-out cell below
RANDOM_SEED = 12
NOISE_RANGE = 0  # train-time integer noise in [-NOISE_RANGE, NOISE_RANGE] added to x and y; set to 0 to avoid augmentation
LOG_DIR = "lightning_logs/"  # TensorBoard logger root directory
MODEL_NAME = "v2_weighted_transformer_bigger"  # key of model.MODEL_GETTERS_DICT; other options tried: "weighted_transformer_bigger", "transformer_m1_bigger"
TRANSFORM_NAME =  "traj_feats_and_distances"  # or "traj_feats_and_nearest_key" (see get_transforms)
DIST_WEIGHTS_FUNC_NAME =  "weights_function_v1_softmax"  # or 'weights_function_v1'; only used by the "traj_feats_and_distances" transform

# The experiment name encodes the run configuration; it is used as the
# TensorBoard run name (tb_logger below).
USE_AUGMENTATIONS_STR = f"uniform_int_noise_{NOISE_RANGE}__" if NOISE_RANGE else ""
EXPERIMENT_NAME = f"{MODEL_NAME}__{GRID_NAME}__{USE_AUGMENTATIONS_STR}from_random_weights__batch__{TRAIN_BATCH_SIZE}/SEED_{RANDOM_SEED}"

DATA_ROOT = "../data/data_separated_grid"  # relative to the src/ working directory set by %cd above

In [None]:
import os
import json
import typing as tp


import torch
# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import numpy as np
# from torch.utils.tensorboard import SummaryWriter


# from model import SwipeCurveTransformer, get_m1_bigger_model
from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from ns_tokenizers import ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from dataset import CurveDataset, CollateFn
# from word_generators import GreedyGenerator
from nearest_key_lookup import ExtendedNearestKeyLookup  # NearestKeyLookup
from distances_lookup import DistancesLookup
from transforms import FullTransform, TrajFeats_KbWeights_FullTransform
from transforms import weights_function_v1_softmax, weights_function_v1
from metrics import get_word_level_accuracy, get_word_level_metric

In [None]:
################ Other constants ####################
# Per-grid dataset files (train / validation .jsonl), located in DATA_ROOT.
GRID_NAME_TO_DS_PATHS = {
    "extra": {
        "train": os.path.join(DATA_ROOT, "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl"),
        "val": os.path.join(DATA_ROOT, "valid__in_train_format__extra_only.jsonl")
    },
    "default": {
        "train": os.path.join(DATA_ROOT, "train__default_only_no_errors__2023_10_31__03_26_16.jsonl"),
        "val": os.path.join(DATA_ROOT, "valid__in_train_format__default_only.jsonl")
    }
}


# Key order ("train", then "val") matters: get_transforms zips
# DS_PATHS.values() with totals=(train_total, val_total).
DS_PATHS =  GRID_NAME_TO_DS_PATHS[GRID_NAME]


# Name -> function used to turn trajectory-to-key distances into weights
# (consumed by TrajFeats_KbWeights_FullTransform in get_transforms).
DIST_WEIGHTS_FUNCS_DICT = {
    'weights_function_v1_softmax': weights_function_v1_softmax,
    'weights_function_v1': weights_function_v1
}

DIST_WEIGHTS_FUNC = DIST_WEIGHTS_FUNCS_DICT[DIST_WEIGHTS_FUNC_NAME]

In [None]:
# if IN_KAGGLE:
#     DATA_ROOT = "/kaggle/input/neuroswipe-defualt-only-v1"
#     MODELS_DIR = ""

In [None]:
def init_random_seed(value):
    """Seed every global RNG the notebook relies on so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG (used by
    ``RandIntToTrajTransform``) and torch on the CPU and on all CUDA devices.

    Args:
        value: integer seed.
    """
    import random  # local: the notebook's import cell does not import `random`

    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    # Explicitly cover *all* GPUs (torch.cuda.manual_seed only seeds the
    # current one). The call is lazy, so it is safe on CPU-only machines.
    torch.cuda.manual_seed_all(value)
    # torch.backends.cudnn.deterministic = True

In [None]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Return the grid description stored under ``grid_name`` in a JSON file.

    ``grids_path`` must point to a UTF-8 JSON object keyed by grid name.
    Raises ``KeyError`` if ``grid_name`` is not a key of that object.

    NOTE: a later cell runs ``from predict import get_grid``, which shadows
    this definition for all code executed after it.
    """
    with open(grids_path, encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [None]:
from typing import List, Dict, Tuple, Optional, Set

def get_gridname_to_out_of_bounds_coords_dict(
        data_paths: tp.Iterable[str], gridname_to_wh: dict,
        totals: Optional[tp.Iterable[Optional[int]]] = None
        ) -> Dict[str, Set[Tuple[int, int]]]:
    """
    Collect the swipe coordinates that fall outside their keyboard grid.

    Args:
        data_paths: paths to JSON-lines dataset files. Each line must hold a
            ``curve`` object with ``grid_name``, ``x`` and ``y`` fields.
            Any iterable of paths works (e.g. ``dict.values()``).
        gridname_to_wh: grid name -> (width, height). Every grid name that
            occurs in the data must be present, otherwise ``KeyError``.
        totals: optional per-file line counts, only used as tqdm progress-bar
            totals. Missing entries (``None`` or a shorter iterable) are treated
            as unknown; they never cause a data file to be skipped.

    Returns:
        A dict mapping every grid name of ``gridname_to_wh`` to the *set* of
        (x, y) points with x outside [0, width) or y outside [0, height).
    """
    data_paths = list(data_paths)
    totals = list(totals) if totals is not None else []
    # Pad rather than rely on zip() alone: zip() would silently drop every
    # data file that has no matching entry in `totals`.
    totals += [None] * (len(data_paths) - len(totals))

    gname_to_out_of_bounds = {gname: set() for gname in gridname_to_wh}

    for data_path, total in zip(data_paths, totals):
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total=total):
                curve = json.loads(line)['curve']
                grid_name = curve['grid_name']
                w, h = gridname_to_wh[grid_name]
                gname_to_out_of_bounds[grid_name].update(
                    (x, y) for x, y in zip(curve['x'], curve['y'])
                    if x < 0 or x >= w or y < 0 or y >= h
                )
    return gname_to_out_of_bounds

In [None]:
from typing import Dict, Set, Tuple


def update_out_of_bounds_with_noise(
    noise_min, noise_max,
    gname_to_out_of_bounds, gridname_to_wh: dict,
    )-> Dict[str, Set[Tuple[int, int]]]:
    """
    Extend out-of-bounds coordinate sets to everything integer noise in
    ``[noise_min, noise_max]`` (on each axis) can push outside the grid.

    A point (x, y) is out of bounds when ``x < 0 or x >= w or y < 0 or y >= h``.
    For each grid two kinds of points are added:
      * every out-of-bounds point within the noise window of a point that is
        already in the set;
      * every out-of-bounds point of the box
        ``[noise_min, w + noise_max] x [noise_min, h + noise_max]``, which
        covers in-bounds points that the noise pushes over an edge.

    Args:
        noise_min: smallest noise offset, must be <= 0.
        noise_max: largest noise offset, must be >= 0.
        gname_to_out_of_bounds: grid name -> set of (x, y); updated IN PLACE.
        gridname_to_wh: grid name -> (width, height).

    Returns:
        ``gname_to_out_of_bounds`` itself (mutated).

    Raises:
        AssertionError: if ``noise_min > 0`` or ``noise_max < 0``.
    """
    assert noise_min <= 0
    assert noise_max >= 0

    offsets = range(noise_min, noise_max + 1)

    for gname, known_out_of_bounds in gname_to_out_of_bounds.items():
        w, h = gridname_to_wh[gname]
        additional = set()

        # Out-of-bounds neighbours of already-known out-of-bounds points.
        for x, y in known_out_of_bounds:
            for i in offsets:
                for j in offsets:
                    nx, ny = x + i, y + j
                    if nx < 0 or nx >= w or ny < 0 or ny >= h:
                        additional.add((nx, ny))

        # Frame around the grid: together the four bands are exactly the
        # out-of-bounds points of [noise_min, w+noise_max] x [noise_min, h+noise_max].
        full_width = range(noise_min, w + noise_max + 1)
        # Above the grid (y < 0).
        additional.update((x, y) for x in full_width for y in range(noise_min, 0))
        # Below the grid (y >= h). Must start at y == h: valid rows are
        # 0..h-1, so row h is already out of bounds.
        additional.update((x, y) for x in full_width for y in range(h, h + noise_max + 1))
        # Right of the grid (x >= w); rows >= h are covered by the band below.
        additional.update((x, y) for x in range(w, w + noise_max + 1) for y in range(0, h))
        # Left of the grid (x < 0); rows >= h are covered by the band below.
        additional.update((x, y) for x in range(noise_min, 0) for y in range(0, h))

        known_out_of_bounds.update(additional)

    return gname_to_out_of_bounds
        

In [None]:
import numpy as np

class RandIntToTrajTransform:
    """Jitter every x and y coordinate of a swipe with independent integer noise.

    Offsets come from ``np.random.randint(min_, max_)``, whose upper bound is
    exclusive: each offset satisfies ``min_ <= offset < max_``. With the
    defaults that is -3..2 (not symmetric); for ±r pass ``(-r, r + 1)``.
    """

    def __init__(self, min_ = -3, max_ = 3) -> None:
        self.min, self.max = min_, max_

    def __call__(self, data):
        """Return ``(X, Y, T, grid_name, tgt_word)`` with X and Y jittered.

        X and Y come back as integer numpy arrays; T, grid_name and tgt_word
        pass through untouched. X's offsets are drawn before Y's.
        """
        X, Y, T, grid_name, tgt_word = data
        jittered = [
            np.array(coords, dtype = int)
            + np.random.randint(self.min, self.max, (len(coords),))
            for coords in (X, Y)
        ]
        return jittered[0], jittered[1], T, grid_name, tgt_word
    
class SequentialTransform:
    """Chain callables left to right: ``SequentialTransform([f, g])(x) == g(f(x))``."""

    def __init__(self, transforms) -> None:
        # Kept as given; iterated afresh on every call.
        self.transforms = transforms

    def __call__(self, data):
        result = data
        for step in self.transforms:
            result = step(result)
        return result

In [None]:
from typing import Callable, Tuple, Optional
from predict import get_grid


def get_transforms(gridname_to_grid_path: str,
                   grid_name: str,
                   transform_name: str,
                   char_tokenizer: CharLevelTokenizerv2,
                   uniform_noise_range: int,
                   totals: Tuple[Optional[int], Optional[int]] = (None, None),
                   data_paths: Optional[tp.Iterable[str]] = None,
                   weights_func: Optional[Callable] = None,
                   ) -> Tuple[Callable, Callable]:
    """Returns train and validation transforms.

    Args:
        gridname_to_grid_path: JSON file mapping grid names to grid descriptions.
        grid_name: the grid to build transforms for.
        transform_name: "traj_feats_and_nearest_key" or "traj_feats_and_distances".
        char_tokenizer: tokenizer for the target words.
        uniform_noise_range: r >= 0. When non-zero, the train transform first
            adds integer noise in [-r, r] to x and y; 0 disables augmentation.
        totals: line counts of ``data_paths`` (progress bar only; used by the
            nearest-key transform).
        data_paths: data files scanned for out-of-bounds coordinates (used by
            the nearest-key transform). Defaults to ``DS_PATHS.values()``.
        weights_func: distance-weighting function for the distances
            transform. Defaults to ``DIST_WEIGHTS_FUNC``.

    Returns:
        ``(train_transform, val_transform)``. They are the same object when
        ``uniform_noise_range == 0``.

    Raises:
        ValueError: unknown ``transform_name``.
    """
    if data_paths is None:
        data_paths = DS_PATHS.values()
    if weights_func is None:
        weights_func = DIST_WEIGHTS_FUNC

    grid = get_grid(grid_name, gridname_to_grid_path)
    gname_to_wh = {grid_name: (grid['width'], grid['height'])}
    kb_tokenizer = KeyboardTokenizerv1()

    if transform_name == "traj_feats_and_nearest_key":
        print("Accumulating out-of-bounds coordinates...")
        gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
            data_paths,
            gridname_to_wh = gname_to_wh,
            totals=totals
        )

        print("augmenting gname_to_out_of_bounds")
        # Size the lookup from this call's noise range (not the NOISE_RANGE
        # global) so it always matches the augmentation built below. The
        # upper bound keeps the original extra +1 (one pixel of over-coverage).
        gname_to_out_of_bounds = update_out_of_bounds_with_noise(
            noise_min = -uniform_noise_range, noise_max=uniform_noise_range+1,
            gname_to_out_of_bounds = gname_to_out_of_bounds, gridname_to_wh = gname_to_wh,
        )

        print("Creating ExtendedNearestKeyLookups...")
        gridname_to_nkl = {
            grid_name: ExtendedNearestKeyLookup(
                grid, ALL_CYRILLIC_LETTERS_ALPHABET_ORD,
                gname_to_out_of_bounds[grid_name]
            )
        }

        full_transform = FullTransform(
            grid_name_to_nk_lookup=gridname_to_nkl,
            grid_name_to_wh=gname_to_wh,
            kb_tokenizer=kb_tokenizer,
            word_tokenizer=char_tokenizer,
            include_time=False,
            include_velocities=True,
            include_accelerations=True,
            kb_tokens_dtype=torch.int32,
            word_tokens_dtype=torch.int64
        )

    elif transform_name == "traj_feats_and_distances":
        assert isinstance(kb_tokenizer.i2t, list)
        grid_name_to_dist_lookup = {
            # Extra token is for legacy reasons
            grid_name: DistancesLookup(grid, kb_tokenizer.i2t + ['<extra_token>'])
        }

        full_transform = TrajFeats_KbWeights_FullTransform(
            grid_name_to_grid={grid_name: grid},
            grid_name_to_dist_lookup=grid_name_to_dist_lookup,
            word_tokenizer=char_tokenizer,
            include_time=False,
            include_velocities=True,
            include_accelerations=True,
            weights_func=weights_func,
            word_tokens_dtype=torch.int64,
        )

    else:
        raise ValueError(f"Unknown transform name: '{transform_name}'")

    val_transform = full_transform

    if uniform_noise_range != 0:
        # randint's upper bound is exclusive, so +1 gives noise in [-r, r].
        augmentation_transform = RandIntToTrajTransform(-uniform_noise_range, uniform_noise_range + 1)
        train_transform = SequentialTransform([augmentation_transform, full_transform])
    else:
        train_transform = full_transform

    return train_transform, val_transform
                

In [None]:
# Line (sample) counts of the train / validation files. Passed to
# get_transforms (progress-bar totals) and to CurveDataset(total=...).
train_total = 5_237_584
val_total = 9_416

In [None]:
gridname_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
voc_path=os.path.join(DATA_ROOT, "voc.txt")
# Character-level tokenizer for target words, built from the vocabulary file.
char_tokenizer = CharLevelTokenizerv2(voc_path)
# Not used by the cells below (get_transforms creates its own KeyboardTokenizerv1).
kb_tokenizer = KeyboardTokenizerv1()


# With NOISE_RANGE == 0, train_transform and val_transform are the same object.
train_transform, val_transform = get_transforms(
    gridname_to_grid_path=gridname_to_grid_path,
    grid_name=GRID_NAME,
    transform_name=TRANSFORM_NAME,
    char_tokenizer=char_tokenizer,
    uniform_noise_range=NOISE_RANGE,
    totals=(train_total, val_total)
)

In [None]:
# get_item_transform differs between train (possibly augmented) and val;
# presumably it is applied per item on access (going by the name) — TODO confirm in dataset.py.
train_dataset = CurveDataset(
    data_path=DS_PATHS['train'],
    store_gnames=False,
    init_transform=None,
    get_item_transform=train_transform,
    total=train_total  # line count of the train file
)

val_dataset = CurveDataset(
    data_path=DS_PATHS['val'],
    store_gnames=False,
    init_transform=None,
    get_item_transform=val_transform,
    total=val_total  # line count of the validation file
)

In [None]:
# import matplotlib.pyplot as plt


# val_ds_idx = 3
# model_in, model_out = val_dataset[val_ds_idx]
# seq_of_key_weights = model_in[1]

# swipe_dot_idx = 13

# weights = seq_of_key_weights[swipe_dot_idx]

# print(weights)
# print(len(weights))



# plt.hist(weights, bins = len(weights))

# plt.show()

In [None]:
# Placed before the DataLoaders and the model are created further below.
init_random_seed(RANDOM_SEED)

In [None]:
# Not referenced by the cells below: device placement is left to the Lightning Trainer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def cross_entropy_with_reshape(pred, target, ignore_index=-100, label_smoothing=0.0):
    """
    Token-level cross-entropy over a whole batch of sequences.

    pred   - Dim0 x Dim1 x VocabSize logits. The training loop in this notebook
             passes them sequence-first: TargetLen x BatchSize x VocabSize.
    target - Dim0 x Dim1 class indices with the same leading dims as ``pred``.

    Both are flattened to (Dim0*Dim1) rows and passed to ``F.cross_entropy``
    (mean over positions whose target != ``ignore_index``).
    """
    n_classes = pred.shape[-1]
    # reshape (not view): also works for non-contiguous logits, e.g. after a
    # transpose/permute inside the model; it only copies when it has to.
    pred_flat = pred.reshape(-1, n_classes)  # Dim0*Dim1 x VocabSize
    target_flat = target.reshape(-1)  # Dim0*Dim1
    return F.cross_entropy(pred_flat,
                           target_flat,
                           ignore_index=ignore_index,
                           label_smoothing=label_smoothing)

In [None]:
def lr_scheduler(optimizer):
    """Build a ReduceLROnPlateau scheduler for ``optimizer``.

    Halves the learning rate (factor=0.5) once the monitored value has not
    improved for ``patience=20`` scheduler steps (mode 'min', the default).
    Passed as ``lr_scheduler_ctor`` to the LightningModule, which monitors 'val_loss'.
    """
    plateau_kwargs = {"patience": 20, "factor": 0.5, "verbose": True}
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)

In [None]:
# word_pad_idx is the tokenizer's '<pad>' index (also used as the loss ignore_index
# further below). batch_first=False: batches are sequence-first, matching the
# (chars_seq_len, batch_size, ...) layout assumed in LitNeuroswipeModel.
collate_fn = CollateFn(
    word_pad_idx = char_tokenizer.char_to_idx['<pad>'], batch_first = False)

In [None]:
import multiprocessing

# Displayed to help choose `dataloader_workers_n` for the DataLoaders below.
multiprocessing.cpu_count()

In [None]:
######  testing get_word_level_accuracy, get_word_level_metric
# Smoke test of the word-level metrics on random data; prints their values.
from sklearn.metrics import f1_score, accuracy_score  # f1_score is not used in this cell
import torch

# NOTE: these assignments create/overwrite notebook globals (`batch_size`, `mask`, ...).
batch_size = 10
seq_len = 5
# Random token ids in [0, 32). The first 3 predicted rows are copied from the
# targets; the remaining rows stay random.
y_true__rand = torch.randint(0, 32, (batch_size, seq_len))
pred__rand = torch.randint(0, 32, (batch_size, seq_len))
pred__rand[:3] = y_true__rand[:3]

# True for the first seq_len - 3 positions of every row, False for the last 3.
# How True/False is interpreted is defined in metrics.py — TODO confirm.
mask = torch.zeros((batch_size, seq_len), dtype = torch.bool)
mask[:, :-3] = True

# pad_token=-1 never occurs in the random tensors above.
print(
    get_word_level_accuracy(
        y_true__rand, pred__rand, pad_token = -1, mask = mask)
)

print(
    get_word_level_metric(accuracy_score, y_true__rand, pred__rand,
                      char_tokenizer, mask = mask)
)

In [None]:
from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import loggers as pl_loggers
import torchmetrics


from model import MODEL_GETTERS_DICT

# ! Make sure:
# * Add metrics

#! Maybe store:
# * batch_size
# * early_stopping_patience

#! Maybe:
# * Checkpointing conditionally: save only if the model improved on val_loss and val_loss < max_val_loss_to_save


class LitNeuroswipeModel(LightningModule):
    """LightningModule wrapping a swipe-to-word model from ``MODEL_GETTERS_DICT``.

    Logs loss, word-level accuracy and token-level accuracy / F1 in both the
    train and validation loops. Each loop has its own torchmetrics objects,
    because metric objects accumulate state and must not be shared.

    Args:
        model_name: key of ``MODEL_GETTERS_DICT``; the getter is called with no arguments.
        criterion: called as ``criterion(pred, target, ignore_index=..., label_smoothing=...)``.
        num_classes: number of classes for the token-level metrics.
        train_batch_size: stored on the instance; not used by this class.
        criterion_ignore_index: target index ignored by the loss and by the token
            metrics; also passed as ``pad_token`` to ``get_word_level_accuracy``.
        optim_kwargs: kwargs for ``optimizer_ctor``; defaults to ``lr=1e-4, weight_decay=0``.
        optimizer_ctor: called as ``optimizer_ctor(params, **optim_kwargs)`` in
            ``configure_optimizers`` (must be provided for training).
        lr_scheduler_ctor: optional ``lr_scheduler_ctor(optimizer)``; when given,
            the scheduler monitors 'val_loss'.
        label_smoothing: forwarded to ``criterion``.
    """

    def __init__(self, model_name: str, criterion, 
                 num_classes: int,
                 train_batch_size: int = None,
                 criterion_ignore_index: int = -100, optim_kwargs = None, 
                 optimizer_ctor=None, lr_scheduler_ctor=None, label_smoothing=0.0,
                 ) -> None:
        super().__init__()

        self.optim_kwargs = optim_kwargs or dict(lr=1e-4, weight_decay=0)

        self.model_name = model_name
        self.train_batch_size = train_batch_size
        self.label_smoothing = label_smoothing
        self.criterion_ignore_index = criterion_ignore_index

        self.optimizer_ctor = optimizer_ctor
        self.lr_scheduler_ctor = lr_scheduler_ctor

        self.model = MODEL_GETTERS_DICT[model_name]()
        self.criterion = criterion

        self.train_token_acc = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.val_token_acc = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.train_token_f1 = torchmetrics.classification.F1Score(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.val_token_f1 = torchmetrics.classification.F1Score(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)

    def forward(self, traj_feats, kb_tokens, y, x_pad_mask, y_pad_mask):
        return self.model.forward(traj_feats, kb_tokens, y, x_pad_mask, y_pad_mask)

    def configure_optimizers(self):
        optimizer = self.optimizer_ctor(self.parameters(), **self.optim_kwargs)

        optimizers_configuration = {'optimizer': optimizer}

        if self.lr_scheduler_ctor:
            lr_scheduler = self.lr_scheduler_ctor(optimizer)
            optimizers_configuration['lr_scheduler'] = lr_scheduler
            optimizers_configuration['monitor'] = 'val_loss'

        return optimizers_configuration

    def _shared_step(self, batch):
        """Run the model on ``batch`` and compute what both loops log.

        Returns ``(loss, word_level_accuracy, flat_preds, flat_y, batch_size)``.
        """
        # * batch_x is a Tuple of (curve_traj_feats, curve_kb_tokens,
        #   decoder_in, curve_pad_mask, dec_seq_pad_mask).
        # * batch_y is decoder_out.
        batch_x, batch_y = batch

        # Targets are sequence-first (see the pred shape below), so the
        # batch size is the last dimension.
        batch_size = batch_y.shape[-1]

        _, _, _, _, dec_seq_pad_mask = batch_x

        # pred.shape = (chars_seq_len, batch_size, n_classes)
        pred = self.forward(*batch_x)

        loss = self.criterion(pred, batch_y, ignore_index=self.criterion_ignore_index,
                              label_smoothing=self.label_smoothing)

        argmax_pred = torch.argmax(pred, dim=2)
        wl_accuracy = get_word_level_accuracy(
            argmax_pred.T, batch_y.T, pad_token = self.criterion_ignore_index, mask = dec_seq_pad_mask)

        flat_y = batch_y.reshape(-1)
        n_classes = pred.shape[-1]
        flat_preds = pred.reshape(-1, n_classes)

        return loss, wl_accuracy, flat_preds, flat_y, batch_size

    def training_step(self, batch, batch_idx):
        loss, wl_accuracy, flat_preds, flat_y, batch_size = self._shared_step(batch)

        self.train_token_acc(flat_preds, flat_y)
        self.log('train_token_level_accuracy', self.train_token_acc, on_step=True, on_epoch=False)

        self.train_token_f1(flat_preds, flat_y)
        self.log('train_token_level_f1', self.train_token_f1, on_step=True, on_epoch=False)

        self.log("train_word_level_accuracy", wl_accuracy, on_step=True, on_epoch=True, 
                 prog_bar=True, logger=True, batch_size = batch_size)

        self.log("train_loss", loss, on_step=True, on_epoch=True, 
                 prog_bar=True, logger=True, batch_size = batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, wl_accuracy, flat_preds, flat_y, batch_size = self._shared_step(batch)

        # These must be the val_* metric objects: logging the train_* ones here
        # would report training statistics under the val_* names.
        self.val_token_acc(flat_preds, flat_y)
        self.log('val_token_level_accuracy', self.val_token_acc, on_step=False, on_epoch=True)

        self.val_token_f1(flat_preds, flat_y)
        self.log('val_token_level_f1', self.val_token_f1, on_step=False, on_epoch=True)

        self.log("val_word_level_accuracy", wl_accuracy, on_step=False, on_epoch=True, 
                 prog_bar=True, logger=True, batch_size = batch_size)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, 
                 logger=True, batch_size = batch_size)
        return loss


# TensorBoard logs go to LOG_DIR/EXPERIMENT_NAME/...
tb_logger = pl_loggers.TensorBoardLogger(save_dir=LOG_DIR, name=EXPERIMENT_NAME)

# Currently not passed to the Trainer (commented out in its callbacks list).
early_stopping_cb = EarlyStopping(
    monitor='val_loss', mode = 'min', patience=35)

# Keeps the 10 best checkpoints by val_loss. In `filename`, the f-string part
# (MODEL_NAME, GRID_NAME) is filled in now; the plain-string part
# ({epoch}, {val_loss}, ...) is a template Lightning fills in at save time.
model_checkpoint_cb = ModelCheckpoint(
    monitor='val_loss', mode = 'min', save_top_k=10, 
    dirpath='checkpoints/', filename=f'{MODEL_NAME}-{GRID_NAME}--' + '{epoch}-{val_loss:.3f}-{val_word_level_accuracy:.3f}')

# It's more reliable to continue training from epoch-end checkpoints.
# save_top_k=-1 keeps every checkpoint written at the end of a training epoch.
model_checkpoint_on_train_epoch_end = ModelCheckpoint(
    save_on_train_epoch_end = True, dirpath='checkpoint_epoch_end/', 
    save_top_k=-1,
    filename=f'{MODEL_NAME}-{GRID_NAME}--' + '{epoch}-{val_loss:.3f}-{val_word_level_accuracy:.3f}')

In [None]:
# ls yandex-cup-2023-ml-neuroswipe/src/checkpoints

In [None]:
# Hard-coded; the multiprocessing.cpu_count() cell above shows what is available.
dataloader_workers_n = 4


# persistent_workers=True keeps the worker processes alive between epochs
# (PyTorch requires num_workers > 0 for this option).
train_loader = DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
    num_workers=dataloader_workers_n, persistent_workers = True, 
    collate_fn=collate_fn)

val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False,
                        num_workers=dataloader_workers_n, persistent_workers = True, 
                        collate_fn=collate_fn)

In [None]:
from lightning.pytorch.callbacks import Callback

class EmptyCudaCacheCallback(Callback):
    """Lightning callback that calls ``torch.cuda.empty_cache()`` after every training epoch.

    This releases PyTorch's cached-but-unused CUDA memory back to the device.
    """

    def on_train_epoch_end(self, trainer, pl_module):
        # trainer / pl_module are required by the hook signature but not needed here.
        release_cached_cuda_memory = torch.cuda.empty_cache
        release_cached_cuda_memory()
        
# NOTE: name is misspelled ("epmty"); it is referenced under this spelling in the Trainer cell below.
epmty_cuda_cache_cb = EmptyCudaCacheCallback()

In [None]:
# Single source of truth for the loss' label smoothing, passed to the model below.
label_smoothing = 0.045


pl_model = LitNeuroswipeModel(
    model_name = MODEL_NAME, criterion = cross_entropy_with_reshape, 
    num_classes = 35,  # = len(char_tokenizer.idx_to_char) - len(['<pad>', '<unk>']) = 37 - 2
    train_batch_size = TRAIN_BATCH_SIZE,
    criterion_ignore_index = char_tokenizer.char_to_idx['<pad>'], 
    optim_kwargs = dict(lr=1e-4, weight_decay=0), 
    optimizer_ctor=torch.optim.Adam, lr_scheduler_ctor=lr_scheduler,
    label_smoothing=label_smoothing,
)

trainer = Trainer(
#     limit_train_batches = 400,  # for validating code before actual training
    log_every_n_steps = 100,
    num_sanity_val_steps=0,
    accelerator = 'gpu',  # hard-coded: this cell expects a GPU to be available
    # max_epochs=100,
    callbacks=[
        # early_stopping_cb, 
        model_checkpoint_cb, 
        model_checkpoint_on_train_epoch_end, epmty_cuda_cache_cb,
    ],
    logger=tb_logger,
    val_check_interval=3000,  # validate every 3000 training batches
)

trainer.fit(pl_model, train_loader, val_loader,
#             ckpt_path = r'./checkpoint_epoch_end/PASTE-PATH-HERE'
           )