# Setting up environment (Skip when running locally)
# Only for kaggle / colab

In [None]:
# Directory that contains the cloned `neuroswipe` repository on the hosted runtime.
# The trailing alternative is presumably the Colab location — confirm before use.
WORKING_DIRECTORY =  "/kaggle/working" # "/content"
# Git branch that is checked out in the setup cells below.
GIT_BRANCH = "refactoring"

In [None]:
# !git clone https://github.com/proshian/neuroswipe.git
# %cd {WORKING_DIRECTORY}/neuroswipe
# !git pull
# !git checkout {GIT_BRANCH}

In [None]:
# !!!!!!!!!! Pull data from dvc in this cell !!!!!!!!!!

In [None]:
# !ls {WORKING_DIRECTORY}/neuroswipe/src/checkpoint_epoch_end

In [None]:
!pip install lightning --quiet
!pip install torchmetrics --quiet

In [None]:
%cd {WORKING_DIRECTORY}/neuroswipe
!git pull
!git checkout {GIT_BRANCH}

%cd {WORKING_DIRECTORY}/neuroswipe/src

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# To identify repo state and reproduce the experiment
!git show -s --format=%H

In [None]:
# !rm -r /kaggle/working/neuroswipe/src/lightning_logs
# !rm -r /kaggle/working/neuroswipe/src/checkpoints
# !rm -r /kaggle/working/neuroswipe/src/checkpoint_epoch_end

In [None]:
# !zip -r /kaggle/working/src.zip /kaggle/working/neuroswipe/src

In [None]:
# !rm /kaggle/working/src.zip

# Script arguments emulation

In [None]:
############# Script arguments emulation #############
# Every name below stands in for a command-line argument of the future
# training script; edit the values here instead of passing flags.

# --- Run identity, data location and logging ---
CKPT_FNAME = None
GRID_NAME = "default"
IN_KAGGLE = False
RANDOM_SEED = 12
DATA_ROOT = "../data/data_separated_grid"
LOG_DIR = "lightning_logs/"

# --- Data loading ---
TRAIN_BATCH_SIZE = 256
VAL_BATCH_SIZE = 512
DATALOADER_NUM_WORKERS = 4
NOISE_RANGE = 0  # set to 0 to avoid augmentation

# --- Model and feature pipeline (other known values are kept in the trailing comments) ---
MODEL_NAME = "v3_nearest_and_traj_transformer_bigger"   # "v2_weighted_transformer_bigger"  #"weighted_transformer_bigger"  # "transformer_m1_bigger"
TRANSFORM_NAME = "traj_feats_and_nearest_key"  # "traj_feats_and_distances"  # "nearest_key_only"
DIST_WEIGHTS_FUNC_NAME = "weights_function_v1"  # "weights_function_sigmoid_normalized_v1"  # "weights_function_v1_softmax"

# --- Trajectory feature switches ---
USE_COORDS = True
USE_TIME = False
USE_VELOCITY = True
USE_ACCELERATION = True

# Tag describing the enabled trajectory features (part of the experiment name).
# NOTE: USE_COORDS is not reflected in the tag, and the tag starts with "_"
# whenever USE_TIME is False.
_TRAJ_FEAT_TAGS = (
    (USE_TIME, "time"),
    (USE_ACCELERATION, "_acceleration"),
    (USE_VELOCITY, "_velocity"),
)
TRAJ_FEATS_STR = "".join(tag for enabled, tag in _TRAJ_FEAT_TAGS if enabled)

# The augmentation tag is present only when noise augmentation is enabled.
if NOISE_RANGE:
    USE_AUGMENTATIONS_STR = f"uniform_int_noise_{NOISE_RANGE}__"
else:
    USE_AUGMENTATIONS_STR = ""

# "<run description>/SEED_<seed>" — used as the logger's experiment name.
EXPERIMENT_NAME = (
    f"{MODEL_NAME}__{GRID_NAME}__{TRAJ_FEATS_STR}__"
    f"{USE_AUGMENTATIONS_STR}from_random_weights__batch__{TRAIN_BATCH_SIZE}"
    f"/SEED_{RANDOM_SEED}"
)

In [8]:
# sanity check for amount of workers
import multiprocessing
multiprocessing.cpu_count()

4

# Imports

In [None]:
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import numpy as np


from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import CurveDataset, CollateFnV2
from feature_extractors import weights_function_v1_softmax, weights_function_v1, weights_function_sigmoid_normalized_v1
from feature_extractors import get_transforms
from metrics import get_word_level_accuracy, get_word_level_metric

# Other constants (that need the imports above)

In [None]:
################ Other constants ####################
# (train file, validation file) inside DATA_ROOT for every supported grid.
_DS_FILENAMES = {
    "extra": (
        "train__extra_only_no_errors__2023_11_01__19_49_14.jsonl",
        "valid__in_train_format__extra_only.jsonl",
    ),
    "default": (
        "train__default_only_no_errors__2023_10_31__03_26_16.jsonl",
        "valid__in_train_format__default_only.jsonl",
    ),
}

# grid name -> {"train": path, "val": path}; "train" is inserted before "val".
GRID_NAME_TO_DS_PATHS = {
    grid_name: {
        "train": os.path.join(DATA_ROOT, train_fname),
        "val": os.path.join(DATA_ROOT, val_fname),
    }
    for grid_name, (train_fname, val_fname) in _DS_FILENAMES.items()
}

# Dataset paths of the grid selected in the script arguments.
DS_PATHS = GRID_NAME_TO_DS_PATHS[GRID_NAME]

# Lookup from the name given in the script arguments to the weights function.
DIST_WEIGHTS_FUNCS_DICT = dict(
    weights_function_v1_softmax=weights_function_v1_softmax,
    weights_function_v1=weights_function_v1,
    weights_function_sigmoid_normalized_v1=weights_function_sigmoid_normalized_v1,
)
DIST_WEIGHTS_FUNC = DIST_WEIGHTS_FUNCS_DICT[DIST_WEIGHTS_FUNC_NAME]

# Coordinates, velocities and accelerations count two features each when
# enabled; time counts one.
N_COORD_FEATS = 2 * sum((USE_COORDS, USE_VELOCITY, USE_ACCELERATION)) + USE_TIME

In [None]:
# if IN_KAGGLE:
#     DATA_ROOT = "/kaggle/input/neuroswipe-defualt-only-v1"
#     MODELS_DIR = ""

# Function definitions

In [None]:
def init_random_seed(value):
    """Seed the NumPy and PyTorch random number generators with ``value``.

    Python's built-in ``random`` module and the cuDNN determinism flag are
    deliberately left untouched (see the commented-out lines).
    """
    # random.seed(value)
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_rng in seeders:
        seed_rng(value)
    # torch.backends.cudnn.deterministic = True

In [None]:
def cross_entropy_with_reshape(pred, target, ignore_index=-100, label_smoothing=0.0):
    """
    Cross-entropy between per-position class scores and per-position targets.

    All leading dimensions are flattened before calling ``F.cross_entropy``,
    so only the last dimension of ``pred`` is treated as the class dimension.

    pred - BatchSize x TargetLen x VocabSize
    target - BatchSize x TargetLen
    (the two leading dims may come in any order as long as ``pred`` and
    ``target`` use the same one)

    ignore_index - target value that does not contribute to the loss
    label_smoothing - forwarded to ``F.cross_entropy``

    Returns the loss tensor produced by ``F.cross_entropy``.
    """
    # `reshape` instead of `view`: `view` raises on non-contiguous tensors
    # (e.g. the result of `transpose` / `permute`), `reshape` copies only when
    # it has to. `target` was already flattened this way.
    pred_flat = pred.reshape(-1, pred.shape[-1])  # BatchSize*TargetLen x VocabSize
    target_flat = target.reshape(-1)  # BatchSize*TargetLen
    return F.cross_entropy(pred_flat,
                           target_flat,
                           ignore_index=ignore_index,
                           label_smoothing=label_smoothing)

In [None]:
def get_lr_scheduler(optimizer):
    """Build the ``ReduceLROnPlateau`` scheduler used for training.

    The learning rate is multiplied by 0.5 after 20 scheduler steps without
    improvement of the monitored value.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        A ``torch.optim.lr_scheduler.ReduceLROnPlateau`` instance.
    """
    scheduler_kwargs = dict(patience=20, factor=0.5)
    try:
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, verbose=True, **scheduler_kwargs)
    except TypeError:
        # `verbose` is deprecated in recent PyTorch releases; a release that
        # no longer accepts the keyword raises TypeError, in which case the
        # same scheduler is built without it.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, **scheduler_kwargs)

# Dataset creation

In [None]:
# Sample counts passed to get_transforms (`totals`) and CurveDataset (`total`).
# NOTE(review): presumably the number of lines in the train / val .jsonl files
# of the selected grid — re-check whenever GRID_NAME or the data files change.
train_total = 5_237_584
val_total = 9_416

In [None]:
# Auxiliary files that live next to the datasets in DATA_ROOT.
gridname_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
voc_path=os.path.join(DATA_ROOT, "voc.txt")
# char_tokenizer is built from the vocabulary file; it is passed to
# get_transforms and later provides the '<pad>' index.
char_tokenizer = CharLevelTokenizerv2(voc_path)
# NOTE(review): kb_tokenizer is not referenced by any other visible cell.
kb_tokenizer = KeyboardTokenizerv1()


# Build the per-item transforms for the train and validation datasets from the
# script arguments and the constants derived from them.
train_transform, val_transform = get_transforms(
    gridname_to_grid_path=gridname_to_grid_path,
    grid_names=[GRID_NAME],
    transform_name=TRANSFORM_NAME,
    char_tokenizer=char_tokenizer,
    uniform_noise_range=NOISE_RANGE,  # 0 means no augmentation (see NOISE_RANGE)
    include_time=USE_TIME,
    include_velocities=USE_VELOCITY,
    include_accelerations=USE_ACCELERATION,
    dist_weights_func=DIST_WEIGHTS_FUNC,
    ds_paths_list=DS_PATHS.values(),  # dict order: train, val — same order as `totals`
    totals=(train_total, val_total)
)

In [None]:
# Settings that are identical for the training and the validation split.
_shared_dataset_kwargs = dict(store_gnames=False, init_transform=None)

# Training split: items go through train_transform when they are fetched.
train_dataset = CurveDataset(
    data_path=DS_PATHS['train'],
    get_item_transform=train_transform,
    total=train_total,
    **_shared_dataset_kwargs,
)

# Validation split: same construction, with the validation transform.
val_dataset = CurveDataset(
    data_path=DS_PATHS['val'],
    get_item_transform=val_transform,
    total=val_total,
    **_shared_dataset_kwargs,
)

# WORD_PAD_IDX definition


In [6]:
# Index of the '<pad>' token in the character vocabulary; used below as the
# collate function's word_pad_idx and as the criterion / metrics ignore index.
WORD_PAD_IDX = char_tokenizer.char_to_idx['<pad>']

# Random seed

In [None]:
# Seed the NumPy and PyTorch RNGs before the model and the data loaders are created.
init_random_seed(RANDOM_SEED)

# Model components

In [None]:
# CUDA is required here; the commented-out line is the CPU-fallback variant.
# NOTE(review): `device` is not referenced by any other visible cell — the
# Trainer below is configured with accelerator='gpu' instead.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda')

In [None]:
# Collate function shared by the train and validation DataLoaders.
# batch_first = False: LitNeuroswipeModel reads the batch size from the last
# axis of the targets, so batches are presumably sequence-first — TODO confirm
# against CollateFnV2.
collate_fn = CollateFnV2(
    word_pad_idx = WORD_PAD_IDX, batch_first = False)

# Lightning Module; Training

In [None]:
from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import loggers as pl_loggers
import torchmetrics


from model import MODEL_GETTERS_DICT

# ! Make sure:
# * Add metrics

#! Maybe store:
# * batch_size
# * early_stopping_patience

#! Maybe:
# * Checkpointing by condition: if model improved on val_loss and val_loss < max_val_loss_to_save


class LitNeuroswipeModel(LightningModule):
    """LightningModule that trains a model obtained from ``MODEL_GETTERS_DICT``.

    For both the training and the validation loop it logs the loss, the
    word-level accuracy and the token-level accuracy / F1 score.

    Args:
        model_name: key into ``MODEL_GETTERS_DICT``; the getter is called with
            ``n_coord_feats`` to build the network.
        n_coord_feats: forwarded to the model getter.
        criterion: callable invoked as
            ``criterion(pred, target, ignore_index=..., label_smoothing=...)``.
        num_classes: number of classes for the token-level torchmetrics.
        train_batch_size: kept only as a hyperparameter so that the batch size
            can be recovered from a checkpoint.
        criterion_ignore_index: target value ignored by the criterion and the
            token-level metrics; also passed as ``pad_token`` to the
            word-level accuracy.
        optim_kwargs: keyword arguments for ``optimizer_ctor``; defaults to
            ``dict(lr=1e-4, weight_decay=0)``.
        optimizer_ctor: callable ``(parameters, **optim_kwargs) -> optimizer``.
            It must be provided before ``configure_optimizers`` runs; it is
            excluded from the saved hyperparameters.
        lr_scheduler_ctor: optional callable ``(optimizer) -> lr_scheduler``;
            excluded from the saved hyperparameters.
        label_smoothing: forwarded to ``criterion``.
    """

    def __init__(self, model_name: str, n_coord_feats: int, criterion,
                 num_classes: int,
                 train_batch_size: int = None,  # to be able to know batch size from checkpoint
                 criterion_ignore_index: int = -100, optim_kwargs = None,
                 optimizer_ctor=None, lr_scheduler_ctor=None, label_smoothing=0.0,
                 ) -> None:
        super().__init__()

        # Callables are excluded from the stored hyperparameters.
        self.save_hyperparameters(ignore = ["criterion", 'lr_scheduler_ctor', 'optimizer_ctor'])

        self.optim_kwargs = optim_kwargs or dict(lr=1e-4, weight_decay=0)

        self.model_name = model_name
        self.train_batch_size = train_batch_size
        self.label_smoothing = label_smoothing
        self.criterion_ignore_index = criterion_ignore_index

        self.optimizer_ctor = optimizer_ctor
        self.lr_scheduler_ctor = lr_scheduler_ctor

        self.model = MODEL_GETTERS_DICT[model_name](n_coord_feats=n_coord_feats)
        self.criterion = criterion

        # Separate metric objects per loop so that their accumulated states
        # never mix training and validation batches.
        self.train_token_acc = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.val_token_acc = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.train_token_f1 = torchmetrics.classification.F1Score(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)
        self.val_token_f1 = torchmetrics.classification.F1Score(
            task="multiclass", num_classes=num_classes, ignore_index=criterion_ignore_index)

    def forward(self, encoder_in, y, encoder_in_pad_mask, y_pad_mask):
        """Run the wrapped model and return its output unchanged."""
        # Call the module itself rather than `.forward` so that registered
        # module hooks are honoured.
        return self.model(encoder_in, y, encoder_in_pad_mask, y_pad_mask)

    def configure_optimizers(self):
        """Create the optimizer and, if a constructor was given, the LR scheduler.

        Returns a dict with the key ``'optimizer'``; when ``lr_scheduler_ctor``
        is set, the dict also holds ``'lr_scheduler'`` and
        ``'monitor': 'val_loss'``.
        """
        optimizer = self.optimizer_ctor(self.parameters(), **self.optim_kwargs)

        optimizers_configuration = {'optimizer': optimizer}

        if self.lr_scheduler_ctor:
            lr_scheduler = self.lr_scheduler_ctor(optimizer)
            optimizers_configuration['lr_scheduler'] = lr_scheduler
            optimizers_configuration['monitor'] = 'val_loss'

        return optimizers_configuration

    def _run_step(self, batch):
        """Forward pass plus everything the train and validation steps share.

        ``batch`` is a pair ``(batch_x, batch_y)`` where ``batch_x`` is the
        tuple ``(encoder_in, decoder_in, swipe_pad_mask, dec_seq_pad_mask)``
        and ``batch_y`` holds the decoder targets.

        Returns ``(loss, word_level_accuracy, flat_preds, flat_y, batch_size)``.
        """
        batch_x, batch_y = batch

        # The batch dimension is the last axis of the targets.
        batch_size = batch_y.shape[-1]

        # Only the decoder padding mask is needed outside of the forward pass.
        _, _, _, dec_seq_pad_mask = batch_x

        # pred.shape = (chars_seq_len, batch_size, n_classes)
        pred = self.forward(*batch_x)

        loss = self.criterion(pred, batch_y, ignore_index=self.criterion_ignore_index,
                              label_smoothing=self.label_smoothing)

        argmax_pred = torch.argmax(pred, dim=2)
        wl_accuracy = get_word_level_accuracy(
            argmax_pred.T, batch_y.T, pad_token = self.criterion_ignore_index, mask = dec_seq_pad_mask)

        # Token-level metrics take one row of class scores per target token.
        flat_y = batch_y.reshape(-1)
        n_classes = pred.shape[-1]
        flat_preds = pred.reshape(-1, n_classes)

        return loss, wl_accuracy, flat_preds, flat_y, batch_size

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log the training metrics.

        Token-level metrics are logged per step only; word-level accuracy and
        loss are logged per step and per epoch. Returns the loss.
        """
        loss, wl_accuracy, flat_preds, flat_y, batch_size = self._run_step(batch)

        self.train_token_acc(flat_preds, flat_y)
        self.log('train_token_level_accuracy', self.train_token_acc, on_step=True, on_epoch=False)

        self.train_token_f1(flat_preds, flat_y)
        self.log('train_token_level_f1', self.train_token_f1, on_step=True, on_epoch=False)

        self.log("train_word_level_accuracy", wl_accuracy, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, batch_size = batch_size)

        self.log("train_loss", loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, batch_size = batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log the validation metrics per epoch.

        Returns the loss.
        """
        loss, wl_accuracy, flat_preds, flat_y, batch_size = self._run_step(batch)

        # Log the *validation* metric objects: logging the training ones here
        # would report training-metric state under the validation names.
        self.val_token_acc(flat_preds, flat_y)
        self.log('val_token_level_accuracy', self.val_token_acc, on_step=False, on_epoch=True)

        self.val_token_f1(flat_preds, flat_y)
        self.log('val_token_level_f1', self.val_token_f1, on_step=False, on_epoch=True)

        self.log("val_word_level_accuracy", wl_accuracy, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, batch_size = batch_size)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True,
                 logger=True, batch_size = batch_size)
        return loss


# Both checkpoint callbacks name their files the same way; the part in plain
# quotes is a template that the checkpoint callback fills in.
_ckpt_filename = f'{MODEL_NAME}-{GRID_NAME}--' + '{epoch}-{val_loss:.3f}-{val_word_level_accuracy:.3f}'

# TensorBoard logs go to LOG_DIR/EXPERIMENT_NAME.
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir=LOG_DIR,
    name=EXPERIMENT_NAME,
)

# Early stopping on validation loss.
early_stopping_cb = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=35,
)

# Keeps the 10 checkpoints with the lowest validation loss.
model_checkpoint_cb = ModelCheckpoint(
    dirpath='checkpoints/',
    filename=_ckpt_filename,
    monitor='val_loss',
    mode='min',
    save_top_k=10,
)

# It's more reliable to continue training from epoch-end-checkpoints,
# so every one of them is kept (save_top_k=-1).
model_checkpoint_on_train_epoch_end = ModelCheckpoint(
    dirpath='checkpoint_epoch_end/',
    filename=_ckpt_filename,
    save_on_train_epoch_end=True,
    save_top_k=-1,
)

In [None]:
# ls neuroswipe/src/checkpoints

In [None]:
# Worker and collation settings shared by both loaders.
_loader_kwargs = dict(
    num_workers=DATALOADER_NUM_WORKERS,
    persistent_workers=True,
    collate_fn=collate_fn,
)

# Training data is reshuffled every epoch; validation data keeps its order.
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    **_loader_kwargs,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    **_loader_kwargs,
)

In [None]:
from lightning.pytorch.callbacks import Callback

class EmptyCudaCacheCallback(Callback):
    """Lightning callback that calls ``torch.cuda.empty_cache()`` at the end of every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        # Neither hook argument is used; the call only touches the CUDA cache.
        torch.cuda.empty_cache()

# Instance that is added to the Trainer's callback list.
epmty_cuda_cache_cb = EmptyCudaCacheCallback()

In [None]:
label_smoothing = 0.045

# Resume from an epoch-end checkpoint only when a checkpoint file name was set.
if CKPT_FNAME is None:
    ckpt_path = None
else:
    ckpt_path = fr'./checkpoint_epoch_end/{CKPT_FNAME}'

pl_model = LitNeuroswipeModel(
    model_name=MODEL_NAME,
    n_coord_feats=N_COORD_FEATS,
    criterion=cross_entropy_with_reshape,
    num_classes=35,  # = len(char_tokenizer.idx_to_char) - len(['<pad>', '<unk>']) = 37 - 2
    train_batch_size=TRAIN_BATCH_SIZE,
    criterion_ignore_index=WORD_PAD_IDX,
    label_smoothing=label_smoothing,
    optimizer_ctor=torch.optim.Adam,
    optim_kwargs=dict(lr=1e-4, weight_decay=0),
    lr_scheduler_ctor=get_lr_scheduler,
)

# Callbacks that are active for this run (early stopping is switched off).
trainer_callbacks = [
    # early_stopping_cb,
    model_checkpoint_cb,
    model_checkpoint_on_train_epoch_end,
    epmty_cuda_cache_cb,
]

trainer = Trainer(
    # limit_train_batches=400,  # for validating code before actual training
    # max_epochs=100,
    accelerator='gpu',
    logger=tb_logger,
    callbacks=trainer_callbacks,
    log_every_n_steps=100,
    num_sanity_val_steps=0,
    val_check_interval=3000,  # run validation every 3000 training batches
)

# Start (or, when ckpt_path is set, resume) training.
trainer.fit(
    pl_model,
    train_loader,
    val_loader,
    ckpt_path=ckpt_path,
)