In [2]:
%load_ext autoreload
%autoreload 2

In [56]:
import os
import json
import copy
from multiprocessing import cpu_count

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm
import numpy as np

from model import SwipeCurveTransformer, get_m1_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv2
from word_generators import GreedyGenerator
from word_generation_v2 import predict_greedy_raw_multiproc, predict_greedy_raw

In [4]:
# Set to True when the notebook runs inside a Kaggle kernel.
IN_KAGGLE = False

if IN_KAGGLE:
    DATA_ROOT = "/kaggle/input/yandex-cup-playground"
    # Must be MODELS_ROOT (not MODELS_DIR): every later cell builds checkpoint
    # paths with os.path.join(MODELS_ROOT, ...).
    # TODO(review): fill in the Kaggle location of the trained checkpoints.
    MODELS_ROOT = ""
else:
    DATA_ROOT = "../data/data_separated_grid"
    MODELS_ROOT = "../data/trained_models"

In [5]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Load a single keyboard grid description from a JSON file.

    ``grids_path`` points to a UTF-8 JSON document whose top level maps grid
    names to grid descriptions; the entry stored under ``grid_name`` is
    returned. A missing ``grid_name`` raises KeyError.
    """
    with open(grids_path, encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [6]:
MAX_TRAJ_LEN = 299

grid_name_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
grid_name_to_grid = {
    name: get_grid(name, grid_name_to_grid_path) for name in ("default", "extra")
}

kb_tokenizer = KeyboardTokenizerv1()
word_char_tokenizer = CharLevelTokenizerv2(os.path.join(DATA_ROOT, "voc.txt"))
# Passed to the datasets below; it is simply every entry of kb_tokenizer.i2t.
keyboard_selection_set = set(kb_tokenizer.i2t)


def _make_swipe_dataset(data_path: str, has_target: bool):
    """Build a NeuroSwipeDatasetv2 with the feature settings shared by the
    validation and test splits; only the file path and ``has_target`` differ."""
    return NeuroSwipeDatasetv2(
        data_path=data_path,
        gridname_to_grid=grid_name_to_grid,
        kb_tokenizer=kb_tokenizer,
        max_traj_len=MAX_TRAJ_LEN,
        word_tokenizer=word_char_tokenizer,
        include_time=False,
        include_velocities=True,
        include_accelerations=True,
        has_target=has_target,
        has_one_grid_only=False,
        include_grid_name=True,
        keyboard_selection_set=keyboard_selection_set,
        total=10_000,
    )


# Validation split: samples carry reference words.
val_path = os.path.join(DATA_ROOT, "valid__in_train_format.jsonl")
val_dataset = _make_swipe_dataset(val_path, has_target=True)

# Test split: no reference words.
test_path = os.path.join(DATA_ROOT, "test.jsonl")
test_dataset = _make_swipe_dataset(test_path, has_target=False)

100%|██████████| 10000/10000 [00:00<00:00, 10075.54it/s]
100%|██████████| 10000/10000 [00:01<00:00, 9818.58it/s]


In [7]:
class NeuroSwipeGridSubset(Dataset):
    """Read-only view of ``dataset`` restricted to samples of one keyboard grid.

    Every sample of the wrapped dataset is unpacked as
    ``((a, b, c, d, e), target, grid_name)``; only ``grid_name`` is inspected.
    Matching positions are computed once, in ``__init__``, by reading every
    sample of ``dataset`` (one full pass, each sample fully materialized).

    Arguments:
    ----------
    dataset: Dataset
        A map-style dataset supporting ``len()`` and integer indexing.
    grid_name: str
        Grid name to keep (e.g. "default" or "extra").
    """

    def __init__(self, dataset: Dataset, grid_name: str):
        self.dataset = dataset
        self.grid_name = grid_name
        # Positions in ``dataset`` whose grid is ``grid_name``, ascending.
        self.grid_name_idxs = self._get_grid_name_idxs()

    def _get_grid_name_idxs(self):
        """Return the ascending list of indices whose sample grid equals ``self.grid_name``."""
        grid_name_idxs: list[int] = []
        # Index explicitly over range(len(...)) rather than relying on the
        # legacy __getitem__ iteration protocol, which only terminates if the
        # dataset raises IndexError past its end (map-style datasets need not).
        for i in range(len(self.dataset)):
            (_, _, _, _, _), _, sample_grid_name = self.dataset[i]
            if sample_grid_name == self.grid_name:
                grid_name_idxs.append(i)
        return grid_name_idxs

    def __len__(self):
        """Number of samples recorded on ``self.grid_name``."""
        return len(self.grid_name_idxs)

    def __getitem__(self, idx):
        """Return the ``idx``-th matching sample of the wrapped dataset, unchanged."""
        return self.dataset[self.grid_name_idxs[idx]]

In [8]:
val_default_dataset = NeuroSwipeGridSubset(val_dataset, "default")
val_extra_dataset = NeuroSwipeGridSubset(val_dataset, "extra")

test_default_dataset = NeuroSwipeGridSubset(test_dataset, "default")
test_extra_dataset = NeuroSwipeGridSubset(test_dataset, "extra")

In [9]:
def remove_duplicates(preds):
    """Return a new list with repeated items of ``preds`` dropped.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable.
    """
    # dict keys keep insertion order, so they are exactly the first
    # occurrences of every distinct prediction.
    return list(dict.fromkeys(preds))


def get_metric(preds_list, ref, weights=(1, 0.1, 0.09, 0.08)):
    """Rank-weighted accuracy of word predictions, averaged over curves.

    For each curve, predictions are de-duplicated (the first occurrence wins).
    The unique prediction at rank ``r`` scores ``weights[r]`` if it equals the
    reference word. Only the first ``len(weights)`` unique predictions are
    scored, so extra candidates are ignored instead of raising IndexError.

    Arguments:
    ----------
    preds_list: sequence of sequences of str
        preds_list[i] holds the ranked candidates for the i-th curve.
    ref: sequence of str
        ref[i] is the reference word for the i-th curve.
    weights: sequence of float
        Per-rank rewards. The default reproduces the original hard-coded values.

    Returns:
    --------
    float: the mean per-curve score.

    Raises:
    -------
    ValueError: if ``preds_list`` and ``ref`` differ in length, or are empty.
    """
    # zip() would silently truncate while the denominator stays
    # len(preds_list), deflating the score, so refuse mismatched inputs.
    if len(preds_list) != len(ref):
        raise ValueError(
            f"preds_list has {len(preds_list)} entries but ref has {len(ref)}")
    if not preds_list:
        raise ValueError("preds_list is empty")

    total_score = 0
    for preds, target in zip(preds_list, ref):
        # Order-preserving de-duplication: a repeated candidate must not earn
        # credit twice or occupy a scored rank.
        unique_preds = list(dict.fromkeys(preds))
        total_score += sum(
            weight for weight, pred in zip(weights, unique_preds) if pred == target)

    return total_score / len(preds_list)

In [26]:
from typing import Callable, Dict, List


def get_targets(dataset: NeuroSwipeDatasetv2,
                word_tokenizer=None) -> List[str]:
    """Decode the reference word of every sample in ``dataset``.

    Each sample is unpacked as ``((_, _, _, _, word_pad_mask), target_tokens, _)``.
    The target length is the number of positions where ``word_pad_mask`` is
    False, minus one. Presumably this drops a trailing end-of-word token
    (TODO confirm against NeuroSwipeDatasetv2).

    Arguments:
    ----------
    dataset: iterable of samples in the layout above.
    word_tokenizer: object with a ``decode(tokens) -> str`` method.
        Defaults to the notebook-level ``word_char_tokenizer``. Pass it
        explicitly, as the evaluate_* helpers do, to avoid the hidden global.

    Returns:
    --------
    List[str]: one decoded reference word per sample, in dataset order.
    """
    if word_tokenizer is None:
        word_tokenizer = word_char_tokenizer

    targets = []
    for (_, _, _, _, word_pad_mask), target_tokens, _ in dataset:
        target_len = int(torch.sum(~word_pad_mask)) - 1
        targets.append(word_tokenizer.decode(target_tokens[:target_len]))
    return targets

def evaluate_model_greedy(val_dataset: NeuroSwipeDatasetv2,
                          model: nn.Module,
                          grid_name: str,
                          targets: List[str],
                          word_char_tokenizer: CharLevelTokenizerv2,
                          device: torch.device):
    """
    Evaluate ``model`` on ``val_dataset`` with greedy word generation.

    The model is put into eval mode and moved to ``device`` in place. It is
    then wrapped in a GreedyGenerator registered under ``grid_name`` and run
    through ``predict_greedy_raw``.

    Returns:
    --------
    (MMR, preds): the score ``get_metric(preds, targets)`` and the raw predictions.
    """
    assert grid_name in ("extra", "default")

    model.eval()
    model.to(device)

    generators_by_grid = {
        grid_name: GreedyGenerator(model, word_char_tokenizer, device),
    }
    preds = predict_greedy_raw(val_dataset, generators_by_grid)
    return get_metric(preds, targets), preds


def evaluate_weights_greedy(val_dataset: NeuroSwipeDatasetv2,
                            model_getter: Callable,
                            weights_path: str,
                            grid_name: str,
                            targets: List[str],
                            word_char_tokenizer: CharLevelTokenizerv2,
                            device: torch.device):
    """Build a model with ``model_getter(device, weights_path)`` and evaluate it greedily.

    Thin wrapper around ``evaluate_model_greedy``; returns its
    ``(MMR, preds)`` tuple unchanged.
    """
    return evaluate_model_greedy(val_dataset,
                                 model_getter(device, weights_path),
                                 grid_name,
                                 targets,
                                 word_char_tokenizer,
                                 device)


In [11]:
# def get_i_to_grid_name(dataset: NeuroSwipeDatasetv2):
#     i_to_grid_name = []
#     for i, data in tqdm(enumerate(dataset), total=len(dataset)):
#         (_, _, _, _, _), _, grid_name = data
#         i_to_grid_name.append(grid_name)
#     return i_to_grid_name


# def combine_preds(i_to_grid_name, default_preds, extra_preds):
#     preds = []
#     default_i = 0
#     extra_i = 0
#     for i, grid_name in enumerate(i_to_grid_name):
#         if grid_name == "default":
#             preds.append(default_preds[default_i])
#             default_i += 1
#         elif grid_name == "extra":
#             preds.append(extra_preds[extra_i])
#             extra_i += 1
#         else:
#             raise ValueError(f"Unknown grid_name: {grid_name}")
        
#     return preds
        

In [39]:
def merge_preds(default_preds,
                extra_preds,
                default_idxs,
                extra_idxs):
    """Scatter per-grid predictions back into the original dataset order.

    ``default_preds[k]`` is placed at position ``default_idxs[k]`` and
    ``extra_preds[k]`` at position ``extra_idxs[k]``.

    Raises:
    -------
    ValueError: if a prediction list and its index list differ in length, or
        if the two index lists together are not a permutation of
        ``range(len(default_preds) + len(extra_preds))``. Either case would
        otherwise silently leave ``None`` holes or drop predictions.
    """
    if len(default_preds) != len(default_idxs):
        raise ValueError(
            f"{len(default_preds)} default predictions but {len(default_idxs)} default indices")
    if len(extra_preds) != len(extra_idxs):
        raise ValueError(
            f"{len(extra_preds)} extra predictions but {len(extra_idxs)} extra indices")

    n_total = len(default_preds) + len(extra_preds)
    if sorted([*default_idxs, *extra_idxs]) != list(range(n_total)):
        raise ValueError(
            "default_idxs and extra_idxs must together cover every position exactly once")

    preds = [None] * n_total
    for i, val in zip(default_idxs, default_preds):
        preds[i] = val
    for i, val in zip(extra_idxs, extra_preds):
        preds[i] = val
    return preds


In [13]:
device = torch.device('cpu')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [65]:
grid_name = "default"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1/best_model__2023_11_04__18_31_37__0.02530_default_switch_2.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [66]:
greedy_generator = GreedyGenerator(model, word_char_tokenizer, device)

# Qualitative sanity check: greedily decode the first `n_examples` curves of
# the "default" validation subset and print them next to the reference word.
n_examples = 40

print("{:<20} {:<20}".format("target", "prediction"))
print("-" * 41)  # width of one "{:<20} {:<20}" row

for i in range(min(n_examples, len(val_default_dataset))):
    # Unpack into fresh names: the original loop rebound the notebook-level
    # `grid_name` as a side effect.
    (xyt, kb_tokens, _, traj_pad_mask, word_pad_mask), target_tokens, _ = val_default_dataset[i]

    pred = greedy_generator(xyt, kb_tokens, traj_pad_mask)

    # Same decoding as get_targets: count unpadded positions and drop the last
    # one (presumably the end-of-word token - TODO confirm).
    target_len = int(torch.sum(~word_pad_mask)) - 1
    target = word_char_tokenizer.decode(target_tokens[:target_len])
    print("{:<20} {:<20}".format(target, pred))

target               prediction          
-------------------------------
на                   на                  
все                  фай                 
добрый               было                
девочка              будут               
сказала              сейчаск             
скинь                фий                 
геев                 груз                
тобой                бой                 
была                 был                 
да                   да                  
муж                  муй                 
щас                  хотя                
она                  она                 
проблема             пубей               
билайн               бы                  
уже                  же                  
раньше               буду                
рам                  ты                  
щас                  ты                  
купил                куйду               
ты                   ты                  
зовут                хэту                
ко

In [40]:
val_default_targets = get_targets(val_default_dataset)
val_extra_targets = get_targets(val_extra_dataset)

In [49]:
mmr, preds = evaluate_model_greedy(val_default_dataset,
                                    model,
                                    grid_name,
                                    val_default_targets,
                                    word_char_tokenizer,
                                    device)

  1%|▏         | 140/9416 [00:06<06:50, 22.62it/s]


KeyboardInterrupt: 

In [None]:
print(mmr)

In [43]:
print(preds[200:250])

[['скинь'], ['мазора'], ['то'], ['анюта'], ['звони'], ['лесник'], ['минут'], ['забрала'], ['на'], ['обуд'], ['завтра'], ['такими'], ['давай'], ['посади'], ['бон'], ['даже'], ['перчатка'], ['работа'], ['никого'], ['отресли'], ['не'], ['раз'], ['блин'], ['пока'], ['ну'], ['тогда'], ['башка'], ['был'], ['продал'], ['хочу'], ['хорошая'], ['кофе'], ['быть'], ['ты'], ['стиревем'], ['мойкой'], ['мы'], ['но'], ['мо'], ['нету'], ['ну'], ['так'], ['ты'], ['закрой'], ['сейчас'], ['пойми'], ['что'], ['поровну'], ['это'], ['не']]


In [36]:
{"m1_v2/best_model__2023_11_09__10_36_02__0.14229_default_switch_2_try_2.pt": 0.8512107051826678,
 "m1_bigger/m1_bigger_v2__2023_11_10__13_38_32__0.50552_default_l2_5e-05_ls0.045_switch_0.pt": 0.810429056924384,
 "m1_bigger/m1_bigger_v2__2023_11_10__16_36_38__0.49848_default_l2_5e-05_ls0.045_switch_0.pt": 0.818500424808836}

{'m1_v2/best_model__2023_11_09__10_36_02__0.14229_default_switch_2_try_2.pt': 0.8512107051826678}

In [59]:
print(cpu_count())

8


In [67]:
predictions = predict_greedy_raw_multiproc(val_default_dataset,
                                           grid_name_to_greedy_generator,
                                           num_workers=4)

  0%|          | 0/9416 [00:00<?, ?it/s]

In [53]:
predictions = predict_greedy_raw(val_default_dataset,
                                grid_name_to_greedy_generator)

  1%|          | 56/9416 [00:02<06:41, 23.30it/s]


KeyboardInterrupt: 

# Evaluate models separately and as a pair

In [27]:
val_default_targets = get_targets(val_default_dataset)
val_extra_targets = get_targets(val_extra_dataset)

In [17]:
grid_name = "default"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [18]:
default_predictions = predict_greedy_raw_multiproc(val_default_dataset,
                                                    grid_name_to_greedy_generator,
                                                    num_workers=4)

100%|██████████| 9416/9416 [04:47<00:00, 32.76it/s]


In [28]:
default_MMR =  get_metric(default_predictions, val_default_targets)
default_MMR

0.8512107051826678

In [29]:
grid_name = "extra"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [31]:
extra_predictions = predict_greedy_raw_multiproc(val_extra_dataset,
                                                 grid_name_to_greedy_generator,
                                                 num_workers=4)

100%|██████████| 584/584 [00:18<00:00, 31.60it/s]


In [33]:
extra_MMR = get_metric(extra_predictions, val_extra_targets)
extra_MMR

0.851027397260274

In [41]:
all_preds = merge_preds(default_predictions, extra_predictions, val_default_dataset.grid_name_idxs, val_extra_dataset.grid_name_idxs)

In [42]:
all_targets = None
with open(os.path.join(DATA_ROOT, "valid.ref"), 'r', encoding='utf-8') as f:
    all_targets = f.read().splitlines() 

In [46]:
full_MMR = get_metric(all_preds, all_targets)
full_MMR

0.8512

# Let's create a greedy submission

In [47]:
from typing import Set, List, Tuple

In [48]:
def separate_invalid_preds_beam(preds: List[List[Tuple[float, str]]],
                                vocab_set: Set[str],
                                max_real_words: int = 4) -> Tuple[List[List[str]], Dict[int, List[str]]]:
    """
    Split raw beam-search outputs into in-vocabulary and out-of-vocabulary words.

    Arguments:
    ----------
    preds: List[List[Tuple[float, str]]]
        preds[i] stores raw output of a word generator corresponding
        to the i-th curve. The raw output is a list of tuples, where
        each tuple is (-log_probability, word). The list is sorted
        by -log_probability in ascending order.
    vocab_set: Set[str]
        A set of all possible words.
    max_real_words: int
        Maximum number of in-vocabulary words kept per curve (default 4).
        Must be at least 1.

    Returns:
    --------
    all_real_word_preds: List[List[str]]
        all_real_word_preds[i] holds up to ``max_real_words`` in-vocabulary
        words of preds[i], in their original (ascending -log_probability) order.
    all_errorous_word_preds: Dict[int, List[str]]
        For every i where fewer than ``max_real_words`` real words were found,
        all_errorous_word_preds[i] lists every out-of-vocabulary word of
        preds[i] in original order. Curves that reached ``max_real_words``
        real words have no key here.

    Raises:
    -------
    ValueError: if ``max_real_words`` < 1.
    """
    if max_real_words < 1:
        raise ValueError(f"max_real_words must be >= 1, got {max_real_words}")

    all_real_word_preds = []
    all_errorous_word_preds = {}

    for i, pred in enumerate(preds):
        real_word_preds = []
        errorous_word_preds = []
        for _, word in pred:
            if word in vocab_set:
                real_word_preds.append(word)
                if len(real_word_preds) == max_real_words:
                    break
            else:
                errorous_word_preds.append(word)

        all_real_word_preds.append(real_word_preds)
        # Only curves that still have free slots need fallback candidates.
        if len(real_word_preds) < max_real_words:
            all_errorous_word_preds[i] = errorous_word_preds

    return all_real_word_preds, all_errorous_word_preds

In [49]:
def separate_invalid_preds_greedy(preds: List[List[str]],
                                  vocab_set: Set[str],
                                  max_real_words: int = 4) -> Tuple[List[List[str]], Dict[int, List[str]]]:
    """
    Split greedy-generator outputs into in-vocabulary and out-of-vocabulary words.

    Same contract as the beam variant, but preds[i] is a plain list of words
    (no scores).

    Arguments:
    ----------
    preds: List[List[str]]
        preds[i] is the ranked list of generated words for the i-th curve.
    vocab_set: Set[str]
        A set of all possible words.
    max_real_words: int
        Maximum number of in-vocabulary words kept per curve (default 4).
        Must be at least 1.

    Returns:
    --------
    all_real_word_preds: List[List[str]]
        all_real_word_preds[i] holds up to ``max_real_words`` in-vocabulary
        words of preds[i], in their original order.
    all_errorous_word_preds: Dict[int, List[str]]
        For every i where fewer than ``max_real_words`` real words were found,
        the out-of-vocabulary words of preds[i] in original order.

    Raises:
    -------
    ValueError: if ``max_real_words`` < 1.
    """
    if max_real_words < 1:
        raise ValueError(f"max_real_words must be >= 1, got {max_real_words}")

    all_real_word_preds = []
    all_errorous_word_preds = {}

    for i, pred in enumerate(preds):
        real_word_preds = []
        errorous_word_preds = []
        for word in pred:
            if word in vocab_set:
                real_word_preds.append(word)
                if len(real_word_preds) == max_real_words:
                    break
            else:
                errorous_word_preds.append(word)

        all_real_word_preds.append(real_word_preds)
        # Only curves that still have free slots need fallback candidates.
        if len(real_word_preds) < max_real_words:
            all_errorous_word_preds[i] = errorous_word_preds

    return all_real_word_preds, all_errorous_word_preds

In [88]:
def augment_predictions(preds, augment_list, max_preds: int = 4):
    """Pad each curve's predictions with fallback words up to ``max_preds`` entries.

    ``preds`` is left untouched; a deep copy is augmented and returned.
    For curve i, words from ``augment_list[i]`` are appended in order. Empty
    strings and words already present in the curve's list are skipped. A
    duplicate would waste one of the scored slots, because get_metric-style
    scoring de-duplicates before ranking.

    If ``augment_list`` is shorter than ``preds``, the remaining curves are
    returned unaugmented (``zip`` stops at the shorter input).
    """
    augmented_preds = copy.deepcopy(preds)
    for pred, aug in zip(augmented_preds, augment_list):
        for aug_el in aug:
            if len(pred) >= max_preds:
                break
            if not aug_el or aug_el in pred:
                continue
            pred.append(aug_el)
    return augmented_preds

In [52]:
def create_submission(preds_list, out_path) -> None:
    """Write one comma-joined line of predictions per curve to ``out_path`` (UTF-8).

    Never overwrites an existing file: raises ValueError if ``out_path``
    already exists.
    """
    try:
        # Exclusive-create mode fails atomically if the file exists. This closes
        # the race between a separate os.path.exists() check and open("w").
        submission_file = open(out_path, "x", encoding="utf-8")
    except FileExistsError:
        raise ValueError(f"File {out_path} already exists") from None

    with submission_file:
        for preds in preds_list:
            submission_file.write(",".join(preds) + "\n")

In [68]:
def get_vocab_set(vocab_path: str):
    """Return the set of lines of the UTF-8 vocabulary file at ``vocab_path``."""
    with open(vocab_path, encoding="utf-8") as vocab_file:
        contents = vocab_file.read()
    return set(contents.splitlines())

In [64]:
grid_name = "default"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
default_test_predictions = predict_greedy_raw_multiproc(test_default_dataset,
                                                        grid_name_to_greedy_generator,
                                                        num_workers=4)

100%|██████████| 9373/9373 [05:08<00:00, 30.40it/s]


In [65]:
grid_name = "extra"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
extra_test_predictions = predict_greedy_raw_multiproc(test_extra_dataset,
                                                 grid_name_to_greedy_generator,
                                                 num_workers=4)

100%|██████████| 627/627 [00:22<00:00, 27.47it/s]


In [66]:
all_test_preds = merge_preds(default_test_predictions, extra_test_predictions, test_default_dataset.grid_name_idxs, test_extra_dataset.grid_name_idxs)

In [69]:
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [83]:
clean_test_preds, invalid_test_preds = separate_invalid_preds_greedy(all_test_preds, vocab_set)

In [84]:
clean_test_preds[:10]

[['на'],
 ['что'],
 ['опоздания'],
 ['сколько'],
 [],
 ['не'],
 ['как'],
 ['садовод'],
 ['заметил'],
 ['ваги']]

In [85]:
# Fallback candidates used to pad curves with fewer than 4 in-vocabulary words.
# Presumably the file has one comma-separated line per test curve in test-set
# order and no header (TODO confirm).
sample_submission_path = os.path.join("..", "data", "submissions", "sample_submission.csv")
with open(sample_submission_path, 'r', encoding='utf-8') as f:
    augment_lines = f.read().splitlines()
augment_list = [line.split(",") for line in augment_lines]

# augment_predictions zips the two lists, so a length mismatch would silently
# leave trailing curves unpadded.
assert len(augment_list) == len(clean_test_preds), (len(augment_list), len(clean_test_preds))

In [86]:
augmented_test_preds = augment_predictions(clean_test_preds, augment_list)

In [87]:
augmented_test_preds[:10]

[['на'],
 ['что'],
 ['опоздания'],
 ['сколько'],
 [],
 ['не'],
 ['как'],
 ['садовод'],
 ['заметил'],
 ['ваги']]

In [None]:
submission_name = "m1_v2__0.14229_default__0.14301_extra__greedy.csv"
# os.path.join instead of a backslash raw string so the cell also runs on POSIX.
out_path = os.path.join("..", "data", "submissions", submission_name)
create_submission(augmented_test_preds, out_path)