In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm
import numpy as np

from model import SwipeCurveTransformer, get_m1_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv2
from word_generators import GreedyGenerator

In [3]:
IN_KAGGLE = False

# Pick the data / model locations for the current environment.
# MODELS_DIR is intentionally empty when running on Kaggle.
DATA_ROOT, MODELS_DIR = (
    ("/kaggle/input/yandex-cup-playground", "")
    if IN_KAGGLE
    else ("../data/data_separated_grid", "../data/trained_models/m1")
)

In [4]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """
    Read the JSON file at ``grids_path`` and return the entry stored under
    the key ``grid_name``.

    Raises KeyError if the file has no such key.
    """
    with open(grids_path, encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [6]:
MAX_TRAJ_LEN = 299

grid_name_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
grid_name_to_grid = {
    name: get_grid(name, grid_name_to_grid_path)
    for name in ("default", "extra")
}

kb_tokenizer = KeyboardTokenizerv1()
word_char_tokenizer = CharLevelTokenizerv2(os.path.join(DATA_ROOT, "voc.txt"))
keyboard_selection_set = set(kb_tokenizer.i2t)

# Constructor options shared by the validation and test datasets;
# only `data_path` and `has_target` differ between the two.
_shared_dataset_kwargs = dict(
    gridname_to_grid=grid_name_to_grid,
    kb_tokenizer=kb_tokenizer,
    max_traj_len=MAX_TRAJ_LEN,
    word_tokenizer=word_char_tokenizer,
    include_time=False,
    include_velocities=True,
    include_accelerations=True,
    has_one_grid_only=False,
    include_grid_name=True,
    keyboard_selection_set=keyboard_selection_set,
    total=10_000,
)

val_path = os.path.join(DATA_ROOT, "valid__in_train_format.jsonl")
val_dataset = NeuroSwipeDatasetv2(data_path=val_path,
                                  has_target=True,
                                  **_shared_dataset_kwargs)

test_path = os.path.join(DATA_ROOT, "test.jsonl")
test_dataset = NeuroSwipeDatasetv2(data_path=test_path,
                                   has_target=False,
                                   **_shared_dataset_kwargs)

100%|██████████| 10000/10000 [00:01<00:00, 6752.40it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6165.24it/s]


In [19]:
def create_raw_pred_list_greedy(dataset: NeuroSwipeDatasetv2,
                                grid_name_to_greedy_generator,
                                skip_grid_name = None):
    """
    Generates one greedy prediction per dataset sample.

    Arguments:
    ----------
    dataset: NeuroSwipeDatasetv2
        Each sample must unpack as
        ``((xyt, kb_tokens, _, traj_pad_mask, _), _, grid_name)``.
    grid_name_to_greedy_generator: dict
        Dict mapping grid names to GreedyGenerator objects. Each generator is
        called as ``generator(xyt, kb_tokens, traj_pad_mask)`` and must return
        a string.
    skip_grid_name: str, optional
        Samples whose grid name equals this value are skipped entirely and
        get no entry in the returned list.

    Returns:
    --------
    list of list of str
        One single-element list ``[pred]`` per non-skipped sample, in dataset
        order, with a leading "<sos>" removed from each prediction.

    Raises:
    -------
    KeyError
        If a non-skipped sample's grid name has no generator.

    Pressing Ctrl+C (KeyboardInterrupt) stops generation early; the
    predictions collected so far are returned.
    """
    preds = []
    # The try wraps the whole loop so that an interrupt arriving while the
    # dataset is producing the next sample is handled too, not only one
    # arriving during generation.
    try:
        for data in tqdm(dataset, total=len(dataset)):
            (xyt, kb_tokens, _, traj_pad_mask, _), _, grid_name = data
            if grid_name == skip_grid_name:
                continue
            generator = grid_name_to_greedy_generator[grid_name]
            pred = generator(xyt, kb_tokens, traj_pad_mask)
            preds.append([pred.removeprefix("<sos>")])
    except KeyboardInterrupt:
        print('Досрочно остановлено пользователем')
    return preds

In [8]:
def remove_duplicates(preds):
    """
    Return the items of ``preds`` as a new list with repeats removed,
    keeping each item's first occurrence and the original order.
    """
    # dict keys are unique and keep insertion order, so this is an
    # order-preserving de-duplication.
    return list(dict.fromkeys(preds))


def get_metric(preds_list, ref, weights = (1, 0.1, 0.09, 0.08)):
    """
    Computes a rank-weighted hit score averaged over samples.

    Duplicate predictions are removed from each sample first, keeping the
    first occurrence. A sample then scores ``weights[k]`` if its target equals
    the k-th remaining prediction, and 0 if the target is absent.

    Arguments:
    ----------
    preds_list: list of list of str
        Ranked candidate words for each sample.
    ref: list of str
        Target word for each sample, aligned with ``preds_list``.
    weights: sequence of float
        Score awarded for a hit at each rank. Candidates ranked beyond
        ``len(weights)`` are ignored; the previous version raised IndexError
        on them.

    Returns:
    --------
    float
        Sum of per-sample scores divided by ``len(preds_list)``.
        Returns 0.0 for an empty ``preds_list``.
    """
    if not preds_list:
        return 0.0

    total_score = 0
    for preds, target in zip(preds_list, ref):
        unique_preds = remove_duplicates(preds)
        # zip() stops at len(weights), so extra candidates cannot IndexError.
        total_score += sum(weight * (pred == target)
                           for weight, pred in zip(weights, unique_preds))

    # NOTE(review): zip(preds_list, ref) silently stops at the shorter list,
    # but the sum is still divided by len(preds_list), as before.
    return total_score / len(preds_list)

In [16]:
from typing import Callable, Dict, List


def get_grid_name_to_target_list(dataset: NeuroSwipeDatasetv2,
                                 tokenizer = None):
    """
    Decodes every target word in ``dataset`` and groups them by grid name.

    Arguments:
    ----------
    dataset: NeuroSwipeDatasetv2
        Each sample must unpack as ``((_, _, _, _, word_mask), target, grid_name)``,
        like the samples of ``val_dataset`` above.
    tokenizer: CharLevelTokenizerv2, optional
        Tokenizer used to decode targets. Defaults to the module-level
        ``word_char_tokenizer``, which is what the previous version always used.

    Returns:
    --------
    dict[str, list[str]]
        Always has the keys "default" and "extra" (possibly empty lists).
        Any other grid name found in the dataset gets its own key instead of
        raising KeyError. Within each grid, targets keep dataset order.
    """
    if tokenizer is None:
        tokenizer = word_char_tokenizer

    grid_name_to_target_list = {grid_name: [] for grid_name in ("default", "extra")}

    for data in tqdm(dataset, total=len(dataset)):
        (_, _, _, _, word_mask), target, grid_name = data
        # assumes ~word_mask marks the non-padded decoder positions (TODO
        # confirm); the slice then drops the last of them, presumably <eos>.
        target_len = torch.sum(~word_mask)
        target = tokenizer.decode(target[:target_len - 1])
        # NeuroSwipeDatasetv2 masks all tokens after <eos>, so the line below
        # is not needed for it. However, the current version of
        # NeuroSwipeDatasetv1 is erroneous and does not mask the first <pad>
        # token, so the line below is needed for NeuroSwipeDatasetv1.
        target = target.removesuffix('<pad>').removesuffix('<eos>')
        grid_name_to_target_list.setdefault(grid_name, []).append(target)
    return grid_name_to_target_list




def evaluate_model_greedy(val_dataset: NeuroSwipeDatasetv2,
                          model: nn.Module,
                          grid_name: str,
                          grid_name_to_target_list: Dict[str, List[str]],
                          word_char_tokenizer: CharLevelTokenizerv2,
                          device: torch.device):
    """
    Scores ``model`` on the samples of one keyboard grid using greedy decoding.

    ``model`` is switched to eval mode and moved to ``device`` in place.
    Predictions come from a ``GreedyGenerator`` wrapping it. Samples of the
    other grid are skipped. The returned value is ``get_metric`` of those
    predictions against ``grid_name_to_target_list[grid_name]``.
    """
    assert grid_name in ("extra", "default")
    other_grid_name = "extra" if grid_name != "extra" else "default"

    model.eval()
    model.to(device)

    generators = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
    preds = create_raw_pred_list_greedy(val_dataset,
                                        generators,
                                        skip_grid_name=other_grid_name)
    return get_metric(preds, grid_name_to_target_list[grid_name])


def evaluate_weights_greedy(val_dataset: NeuroSwipeDatasetv2,
                            model_getter: Callable,
                            weights_path: str,
                            grid_name: str,
                            grid_name_to_target_list: Dict[str, List[str]],
                            word_char_tokenizer: CharLevelTokenizerv2,
                            device: torch.device):
    """
    Builds a model from a checkpoint and scores it with ``evaluate_model_greedy``.

    ``model_getter`` is called as ``model_getter(device, weights_path)`` and must
    return the model to evaluate. All other arguments are forwarded unchanged.
    """
    return evaluate_model_greedy(val_dataset,
                                 model_getter(device, weights_path),
                                 grid_name,
                                 grid_name_to_target_list,
                                 word_char_tokenizer,
                                 device)

In [24]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [25]:
grid_name = "default"
model_getter = get_m1_model
# Built with os.path.join instead of the former raw Windows string
# r"..\data\...": the backslash form only resolves on Windows, where
# os.path.join produces exactly the same string.
weights_path = os.path.join(
    "..", "data", "trained_models", "m1_1",
    "best_model__2023_11_09__10_36_02__0.14229_default_switch_2_try_2.pt")
grid_name_to_greedy_generator = {
    grid_name: GreedyGenerator(model_getter(weights_path=weights_path, device=device),
                               word_char_tokenizer,
                               device)
}

In [16]:
# Display the generator object. The saved output below is its repr. The
# previous trailing "()" called the generator with no arguments, which does
# not match how it is called elsewhere (xyt, kb_tokens, traj_pad_mask) and
# does not reproduce that output.
grid_name_to_greedy_generator['default']

<word_generators.GreedyGenerator at 0x18ad4671d50>

In [18]:
# Load the same checkpoint as a standalone model for the manual preview below.
model = get_m1_model(device=device, weights_path=weights_path)

In [19]:
# The greedy generator is constructed at the top of the preview cell that
# follows, so building an identical one here was redundant and is removed.


In [21]:
greedy_generator = GreedyGenerator(model, word_char_tokenizer, device)

n_examples = 40

print("{:<20} {:<20}".format("target", "prediction"))
print("-" * 41)  # two 20-char columns plus the separating space

for i, data in enumerate(val_dataset):
    # Check before generating so exactly `n_examples` rows are printed.
    # The old check after printing (`i >= n_examples`) showed n_examples + 1.
    if i >= n_examples:
        break

    # The grid name is discarded with `_` rather than bound to `grid_name`,
    # which would overwrite the global configured earlier in the notebook.
    (xyt, kb_tokens, _, traj_pad_mask, _), target, _ = data

    # Same post-processing as create_raw_pred_list_greedy.
    pred = greedy_generator(xyt, kb_tokens, traj_pad_mask).removeprefix("<sos>")

    # Remove special tokens as whole tokens. The previous
    # `.strip("<eos><pad>")` stripped any of the *characters* < e o s > p a d
    # from both ends. That only worked because real words contain none of
    # those characters.
    target_word = word_char_tokenizer.decode(target).removeprefix("<sos>")
    while target_word.endswith(("<eos>", "<pad>")):
        target_word = target_word.removesuffix("<pad>").removesuffix("<eos>")

    print("{:<20} {:<20}".format(target_word, pred))

target               prediction          
-------------------------------
на                   на                  
все                  все                 
этом                 этом                
добрый               добрый              
девочка              девочка             
сказала              сказала             
скинь                скинь               
геев                 геев                
тобой                тобой               
была                 быстра              
есть                 есть                
да                   да                  
муж                  маж                 
щас                  щас                 
она                  она                 
проблема             проблема            
билайн               билайн              
уже                  уже                 
раньше               раньше              
рам                  нам                 
щас                  щас                 
купил                купил               
ты

In [32]:
# Targets grouped per grid, aligned with the order in which predictions are
# generated. evaluate_weights_greedy requires them.
grid_name_to_target_list = get_grid_name_to_target_list(val_dataset)

# The function defined above is `evaluate_weights_greedy`; the previous call
# to the undefined `evaluate_weights` also omitted two required arguments.
mmr = evaluate_weights_greedy(
    val_dataset = val_dataset,
    model_getter = get_m1_model,
    weights_path = weights_path,
    grid_name = "default",
    grid_name_to_target_list = grid_name_to_target_list,
    word_char_tokenizer = word_char_tokenizer,
    device = device)

100%|██████████| 10000/10000 [14:40<00:00, 11.36it/s] 


In [33]:
# Last expression of the cell: displayed via the notebook's rich output.
mmr

0.8512107051826678


In [36]:
# Record the score computed above instead of a hand-copied number, so the
# record cannot drift from `mmr` when the notebook is re-run. The file name
# is taken from `weights_path` whether it uses "/" or "\" separators.
# NOTE(review): the "m1_v2/" prefix was hand-written and does not match the
# "m1_1" directory in `weights_path`; confirm which label is intended.
{"m1_v2/" + weights_path.replace("\\", "/").rsplit("/", 1)[-1]: mmr}

{'m1_v2/best_model__2023_11_09__10_36_02__0.14229_default_switch_2_try_2.pt': 0.8512107051826678}

# Let's create a submission