In [1]:
%load_ext autoreload
%autoreload 2

In [179]:
import os
import json
import copy
from multiprocessing import cpu_count

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm
import numpy as np

from model import SwipeCurveTransformer, get_m1_model, get_m1_bigger_model, get_m1_smaller_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv2
from word_generators import GreedyGenerator, BeamGenerator
from word_generation_v2 import predict_greedy_raw, predict_raw_mp  #  , predict_greedy_raw_multiproc

In [3]:
# Switch between the Kaggle kernel layout and the local repository layout.
IN_KAGGLE = False

if IN_KAGGLE:
    DATA_ROOT = "/kaggle/input/yandex-cup-playground"
    # Later cells read MODELS_ROOT, so it has to exist in both branches;
    # previously only MODELS_DIR was set here, which led to a NameError on Kaggle.
    MODELS_ROOT = ""
    MODELS_DIR = MODELS_ROOT  # legacy name kept as an alias
else:
    DATA_ROOT = "../data/data_separated_grid"
    MODELS_ROOT = "../data/trained_models"

In [4]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Read the JSON file at `grids_path` and return the entry stored under `grid_name`.

    Raises:
        KeyError: if the parsed JSON has no `grid_name` key.
    """
    with open(grids_path, "r", encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [5]:
MAX_TRAJ_LEN = 299

# One JSON file holds every keyboard grid; pick out the two grids used in this notebook.
grid_name_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
grid_name_to_grid = {grid_name: get_grid(grid_name, grid_name_to_grid_path) for grid_name in ("default", "extra")}


kb_tokenizer = KeyboardTokenizerv1()
word_char_tokenizer = CharLevelTokenizerv2(os.path.join(DATA_ROOT, "voc.txt"))
keyboard_selection_set = set(kb_tokenizer.i2t)


def _build_swipe_dataset(data_path: str, has_target: bool, total: int = 10_000) -> NeuroSwipeDatasetv2:
    """Create a NeuroSwipeDatasetv2 with the feature configuration shared by all splits.

    Only `data_path` and `has_target` differ between the validation and test
    splits, so every other constructor argument lives here in one place.
    `total` is forwarded unchanged (presumably the number of samples to read —
    TODO confirm against NeuroSwipeDatasetv2).
    """
    return NeuroSwipeDatasetv2(
        data_path=data_path,
        gridname_to_grid=grid_name_to_grid,
        kb_tokenizer=kb_tokenizer,
        max_traj_len=MAX_TRAJ_LEN,
        word_tokenizer=word_char_tokenizer,
        include_time=False,
        include_velocities=True,
        include_accelerations=True,
        has_target=has_target,
        has_one_grid_only=False,
        include_grid_name=True,
        keyboard_selection_set=keyboard_selection_set,
        total=total,
    )


val_path = os.path.join(DATA_ROOT, "valid__in_train_format.jsonl")
val_dataset = _build_swipe_dataset(val_path, has_target=True)

test_path = os.path.join(DATA_ROOT, "test.jsonl")
test_dataset = _build_swipe_dataset(test_path, has_target=False)

100%|██████████| 10000/10000 [00:01<00:00, 9313.98it/s]
100%|██████████| 10000/10000 [00:00<00:00, 10859.10it/s]


In [6]:
from torch.utils.data import Dataset

class NeuroSwipeGridSubset(Dataset):
    """View over `dataset` limited to the samples whose grid name equals `grid_name`.

    Every sample of the wrapped dataset must unpack as
    ``((a, b, c, d, e), target, grid_name)``. The matching positions are
    collected once, at construction time, by iterating over the whole dataset.
    """

    def __init__(self, dataset: Dataset, grid_name: str):
        self.dataset = dataset
        self.grid_name = grid_name
        # Positions (in the wrapped dataset) of the samples that belong to this subset.
        self.grid_name_idxs = self._get_grid_name_idxs()

    def _get_grid_name_idxs(self):
        """Return the positions in `self.dataset` whose grid name matches `self.grid_name`."""
        return [
            position
            for position, ((_, _, _, _, _), _, sample_grid_name) in enumerate(self.dataset)
            if sample_grid_name == self.grid_name
        ]

    def __len__(self):
        return len(self.grid_name_idxs)

    def __getitem__(self, idx):
        # Translate the subset-local index into an index of the wrapped dataset.
        source_idx = self.grid_name_idxs[idx]
        return self.dataset[source_idx]

In [7]:
val_default_dataset = NeuroSwipeGridSubset(val_dataset, "default")
val_extra_dataset = NeuroSwipeGridSubset(val_dataset, "extra")

test_default_dataset = NeuroSwipeGridSubset(test_dataset, "default")
test_extra_dataset = NeuroSwipeGridSubset(test_dataset, "extra")

In [8]:
def remove_duplicates(preds):
    """Return the items of `preds` without repeats, keeping each first occurrence in order."""
    # dict keys are unique and keep insertion order, which is exactly the needed behaviour.
    return list(dict.fromkeys(preds))


def get_metric(preds_list, ref):
    """Compute the mean weighted hit score of ranked candidate lists.

    For every (candidates, target) pair the candidates are de-duplicated and
    only the first four unique ones are scored: a hit at rank 1, 2, 3 or 4
    contributes 1, 0.1, 0.09 or 0.08 respectively. The per-line scores are
    averaged over ``len(preds_list)``.

    Args:
        preds_list: sequence of candidate lists, best candidate first.
        ref: sequence of target words aligned with `preds_list`.

    Returns:
        The mean score as a float; 0.0 when `preds_list` is empty.
    """
    weights = (1, 0.1, 0.09, 0.08)

    # Guard against ZeroDivisionError on an empty evaluation set.
    if len(preds_list) == 0:
        return 0.0

    total_score = 0
    for preds, target in zip(preds_list, ref):
        unique_preds = remove_duplicates(preds)
        # zip() stops after len(weights) candidates, so lines with more than four
        # unique candidates (e.g. beam-search output) no longer raise IndexError.
        total_score += sum(
            weight * (pred == target) for weight, pred in zip(weights, unique_preds)
        )

    return total_score / len(preds_list)

In [9]:
from typing import Callable, Dict, List


def get_targets(dataset: NeuroSwipeDatasetv2) -> List[str]:
    """Decode the target word of every sample in `dataset`.

    Each sample must unpack as ``((_, _, _, _, word_pad_mask), target_tokens, _)``.
    The number of tokens decoded is the count of non-padded mask positions minus
    one (presumably to drop a trailing end-of-sequence token — TODO confirm).
    Decoding uses the notebook-level `word_char_tokenizer`.
    """
    def _decode_target(target_tokens, word_pad_mask) -> str:
        n_tokens = int(torch.sum(~word_pad_mask)) - 1
        return word_char_tokenizer.decode(target_tokens[:n_tokens])

    return [
        _decode_target(target_tokens, word_pad_mask)
        for (_, _, _, _, word_pad_mask), target_tokens, _ in dataset
    ]

def evaluate_model_greedy(val_dataset: NeuroSwipeDatasetv2,
                          model: nn.Module,
                          grid_name: str,
                          targets: List[str],
                          word_char_tokenizer: CharLevelTokenizerv2,
                          device: torch.device):
    """
    Score `model` on `val_dataset` with greedy decoding.

    The model is switched to eval mode and moved to `device`, wrapped in a
    GreedyGenerator registered under `grid_name`, and run through
    `predict_greedy_raw`. The predictions are scored against `targets`
    with `get_metric`.

    Returns a ``(metric, predictions)`` tuple.
    """
    assert grid_name in ("extra", "default")
    model.eval()
    model.to(device)
    greedy_generators = {
        grid_name: GreedyGenerator(model, word_char_tokenizer, device),
    }
    predictions = predict_greedy_raw(val_dataset, greedy_generators)
    return get_metric(predictions, targets), predictions


def evaluate_weights_greedy(val_dataset: NeuroSwipeDatasetv2,
                            model_getter: Callable,
                            weights_path: str,
                            grid_name: str,
                            targets: List[str],
                            word_char_tokenizer: CharLevelTokenizerv2,
                            device: torch.device):
    """
    Build a model via ``model_getter(device, weights_path)`` and score it
    with `evaluate_model_greedy`.

    Returns the same ``(metric, predictions)`` tuple as `evaluate_model_greedy`.
    """
    loaded_model = model_getter(device, weights_path)
    metric, predictions = evaluate_model_greedy(
        val_dataset, loaded_model, grid_name, targets, word_char_tokenizer, device
    )
    return metric, predictions


In [10]:
# def get_i_to_grid_name(dataset: NeuroSwipeDatasetv2):
#     i_to_grid_name = []
#     for i, data in tqdm(enumerate(dataset), total=len(dataset)):
#         (_, _, _, _, _), _, grid_name = data
#         i_to_grid_name.append(grid_name)
#     return i_to_grid_name


# def combine_preds(i_to_grid_name, default_preds, extra_preds):
#     preds = []
#     default_i = 0
#     extra_i = 0
#     for i, grid_name in enumerate(i_to_grid_name):
#         if grid_name == "default":
#             preds.append(default_preds[default_i])
#             default_i += 1
#         elif grid_name == "extra":
#             preds.append(extra_preds[extra_i])
#             extra_i += 1
#         else:
#             raise ValueError(f"Unknown grid_name: {grid_name}")
        
#     return preds
        

In [291]:
def merge_preds(default_preds,
                extra_preds,
                default_idxs,
                extra_idxs):
    """Interleave per-grid predictions back into the original dataset order.

    Args:
        default_preds: predictions for the "default" grid samples.
        extra_preds: predictions for the "extra" grid samples.
        default_idxs: position in the merged output for each item of `default_preds`.
        extra_idxs: position in the merged output for each item of `extra_preds`.

    Returns:
        A list of length ``len(default_preds) + len(extra_preds)`` where every
        prediction is deep-copied into its target position.

    Raises:
        ValueError: if a prediction list and its index list differ in length.
            Previously zip() silently truncated, leaving ``None`` holes in the
            result that only failed much later.
    """
    if len(default_preds) != len(default_idxs):
        raise ValueError(
            f"default_preds has {len(default_preds)} items "
            f"but default_idxs has {len(default_idxs)}"
        )
    if len(extra_preds) != len(extra_idxs):
        raise ValueError(
            f"extra_preds has {len(extra_preds)} items "
            f"but extra_idxs has {len(extra_idxs)}"
        )

    preds = [None] * (len(default_preds) + len(extra_preds))

    for idxs, group_preds in ((default_idxs, default_preds), (extra_idxs, extra_preds)):
        for i, val in zip(idxs, group_preds):
            # Deep copy so later in-place edits of the merged list do not leak
            # back into the per-grid prediction lists.
            preds[i] = copy.deepcopy(val)

    return preds


In [12]:
device = torch.device('cpu')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [180]:
grid_name = "default"
model_getter = get_m1_smaller_model
weights_path = os.path.join(MODELS_ROOT, "m1_smaller/m1_smaller_v2_2023_11_11_17_43_35_0_33179_default_l2_1e_05_ls0_02.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [181]:
greedy_generator = GreedyGenerator(model, word_char_tokenizer, device)


print("{:<20} {:<20}".format("target", "prediction"))
print("-"*31)

n_examples = 40

for i, data in enumerate(val_default_dataset):
    # Check the limit before doing any work so exactly n_examples rows are
    # printed (the check used to sit after the print and produced one extra row).
    if i >= n_examples:
        break

    # The per-sample grid name is discarded on purpose: unpacking it into
    # `grid_name` used to overwrite the notebook-level variable of that name.
    (xyt, kb_tokens, _, traj_pad_mask, word_pad_mask), target_tokens, _ = data

    pred = greedy_generator(xyt, kb_tokens, traj_pad_mask)

    # Number of non-padded target positions minus one
    # (presumably drops a trailing end-of-sequence token — TODO confirm).
    target_len = int(torch.sum(~word_pad_mask)) - 1
    target = word_char_tokenizer.decode(target_tokens[:target_len])
    print("{:<20} {:<20}".format(target, pred))

target               prediction          
-------------------------------
на                   на                  
все                  все                 
добрый               добрый              
девочка              девочка             
сказала              сказала             
скинь                скинь               
геев                 гееев               
тобой                тобой               
была                 быса                
да                   да                  
муж                  мад                 
щас                  щас                 
она                  она                 
проблема             проблема            
билайн               билайн              
уже                  уже                 
раньше               раньше              
рам                  рам                 
щас                  щас                 
купил                купил               
ты                   ты                  
зовут                зовут               
ко

In [13]:
val_default_targets = get_targets(val_default_dataset)
val_extra_targets = get_targets(val_extra_dataset)

In [None]:
# mmr, preds = evaluate_model_greedy(val_default_dataset,
#                                     model,
#                                     grid_name,
#                                     val_default_targets,
#                                     word_char_tokenizer,
#                                     device)

In [None]:
# print(mmr)

In [43]:
# print(preds[200:250])

[['скинь'], ['мазора'], ['то'], ['анюта'], ['звони'], ['лесник'], ['минут'], ['забрала'], ['на'], ['обуд'], ['завтра'], ['такими'], ['давай'], ['посади'], ['бон'], ['даже'], ['перчатка'], ['работа'], ['никого'], ['отресли'], ['не'], ['раз'], ['блин'], ['пока'], ['ну'], ['тогда'], ['башка'], ['был'], ['продал'], ['хочу'], ['хорошая'], ['кофе'], ['быть'], ['ты'], ['стиревем'], ['мойкой'], ['мы'], ['но'], ['мо'], ['нету'], ['ну'], ['так'], ['ты'], ['закрой'], ['сейчас'], ['пойми'], ['что'], ['поровну'], ['это'], ['не']]


In [None]:
# predictions = predict_greedy_raw_multiproc(val_default_dataset,
#                                            grid_name_to_greedy_generator,
#                                            num_workers=4)

In [None]:
# predictions = predict_greedy_raw(val_default_dataset,
#                                 grid_name_to_greedy_generator)

# Evaluate models separately and as a pair

In [143]:
from typing import Callable

def weights_to_raw_predictions(grid_name: str,
                               model_getter: Callable,
                               weights_path: str,
                               word_char_tokenizer: CharLevelTokenizerv2,
                               dataset: Dataset,
                               generator_ctor,
                               n_workers: int = 4,
                               generator_kwargs = None
                               ):
    """Load weights, wrap the model in a generator and predict over `dataset`.

    The model is created with ``model_getter(cpu_device, weights_path)``, wrapped
    with ``generator_ctor(model, word_char_tokenizer, cpu_device)`` and registered
    under `grid_name`. Prediction is delegated to `predict_raw_mp` with
    `n_workers` workers; `generator_kwargs` (``None`` means an empty dict) is
    forwarded to it unchanged.

    Returns whatever `predict_raw_mp` returns.
    """
    # Always run on CPU: avoid multiprocessing with GPU.
    cpu_device = torch.device('cpu')

    model = model_getter(cpu_device, weights_path)
    grid_name_to_generator = {
        grid_name: generator_ctor(model, word_char_tokenizer, cpu_device),
    }
    return predict_raw_mp(
        dataset,
        grid_name_to_generator,
        num_workers=n_workers,
        generator_kwargs={} if generator_kwargs is None else generator_kwargs,
    )


In [59]:
print(cpu_count())

8


In [27]:
val_default_targets = get_targets(val_default_dataset)
val_extra_targets = get_targets(val_extra_dataset)

In [177]:
default_predictions = weights_to_raw_predictions(
    grid_name = "default",
    model_getter=get_m1_bigger_model,
    weights_path = os.path.join(MODELS_ROOT, "m1_bigger/m1_bigger_v2__2023_11_11__13_17_50__0.13845_default_l2_0_ls0_switch_0.pt"),
    word_char_tokenizer=word_char_tokenizer,
    dataset=val_default_dataset,
    generator_ctor=GreedyGenerator,
    n_workers=4
)

NameError: name 'get_m1_smaller_model' is not defined

In [135]:
default_MMR =  get_metric(default_predictions, val_default_targets)
default_MMR

0.8531223449447749

In [136]:
default_predictions_best_bigger = default_predictions

In [137]:
default_predictions_best_bigger_clean, _ = separate_invalid_preds_greedy(default_predictions_best_bigger, vocab_set)

In [138]:
sum(bool(el) for el in default_predictions_best_bigger_clean)

8784

In [182]:
default_predictions = weights_to_raw_predictions(
    grid_name = "default",
    model_getter=get_m1_smaller_model,
    weights_path = os.path.join(MODELS_ROOT, "m1_smaller/m1_smaller_v2_2023_11_11_17_43_35_0_33179_default_l2_1e_05_ls0_02.pt"),
    word_char_tokenizer=word_char_tokenizer,
    dataset=val_default_dataset,
    generator_ctor=GreedyGenerator,
    n_workers=4
)

100%|██████████| 9416/9416 [04:08<00:00, 37.90it/s]


In [183]:
default_predictions_clean, _ = separate_invalid_preds_greedy(default_predictions, vocab_set)

In [184]:
sum(bool(el) for el in default_predictions_clean)

8449

In [185]:
default_MMR =  get_metric(default_predictions, val_default_targets)
default_MMR

0.8221112999150383

In [110]:
# grid_name = "default"
# model_getter = get_m1_bigger_model
# weights_path = os.path.join(MODELS_ROOT, "m1_bigger/m1_bigger_v2__2023_11_11__13_17_50__0.13845_default_l2_0_ls0_switch_0.pt")
# model = model_getter(device, weights_path)
# grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [111]:
# default_predictions = predict_greedy_raw_multiproc(val_default_dataset,
#                                                     grid_name_to_greedy_generator,
#                                                     num_workers=4)

100%|██████████| 9416/9416 [05:02<00:00, 31.10it/s]


In [29]:
grid_name = "extra"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [31]:
# `predict_greedy_raw_multiproc` is no longer imported (its import is commented
# out in the imports cell), so this cell raised NameError on a fresh kernel.
# `predict_raw_mp` is the imported multiprocess predictor used by the other cells.
extra_predictions = predict_raw_mp(val_extra_dataset,
                                   grid_name_to_greedy_generator,
                                   num_workers=4)

100%|██████████| 584/584 [00:18<00:00, 31.60it/s]


In [33]:
extra_MMR = get_metric(extra_predictions, val_extra_targets)
extra_MMR

0.851027397260274

In [41]:
all_preds = merge_preds(default_predictions, extra_predictions, val_default_dataset.grid_name_idxs, val_extra_dataset.grid_name_idxs)

In [42]:
all_targets = None
with open(os.path.join(DATA_ROOT, "valid.ref"), 'r', encoding='utf-8') as f:
    all_targets = f.read().splitlines() 

In [46]:
full_MMR = get_metric(all_preds, all_targets)
full_MMR

0.8512

In [None]:
{"m1_v2/best_model__2023_11_09__10_36_02__0.14229_default_switch_2_try_2.pt": 0.8512107051826678,
 "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt": 0.851027397260274,
 
 "m1_bigger/m1_bigger_v2__2023_11_10__13_38_32__0.50552_default_l2_5e-05_ls0.045_switch_0.pt": 0.810429056924384,
 "m1_bigger/m1_bigger_v2__2023_11_10__16_36_38__0.49848_default_l2_5e-05_ls0.045_switch_0.pt": 0.818500424808836,
 "m1_bigger/m1_bigger_v2__2023_11_10__21_51_01__0.49382_default_l2_5e-05_ls0.045_switch_0.pt": 0.8210492778249787,
 
 "m1_bigger/m1_bigger_v2__2023_11_11__13_17_50__0.13845_default_l2_0_ls0_switch_0.pt": 0.8512107051826678,
 "m1_bigger/m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt": 0.8531223449447749,
 
 "m1_smaller/m1_smaller_v2_2023_11_11_17_43_35_0_33179_default_l2_1e_05_ls0_02.pt": 0.8221112999150383}

# Let's create a greedy submission

In [73]:
from typing import Set, List, Tuple

In [48]:
# def separate_invalid_preds_beam(preds: List[List[Tuple[float, str]]],
#                                 vocab_set: Set[str]) -> Tuple[List[List[str]], Dict[int, List[str]]]:
#     """
#     Arguments:
#     ----------
#     preds: List[List[Tuple[float, str]]]
#         preds[i] stores raw output of a word generator corresponding
#         to the i-th curve. The raw output is a list of tuples, where
#         each tuple is (-log_probability, word). The list is sorted
#         by -log_probability in ascending order.
#     vocab_set: Set[str]
#         A set of all possible words.

#     Returns:
#     --------
#     all_real_word_preds: List[List[str]]
#         all_real_word_preds stores 4 lists of real words sorted by
#         -log_probability in ascending order.
#     all_errorous_word_preds: Dict[int, List[str]]
#         all_errorous_word_preds[i] stores a list of errorous words
#         sorted by -log_probability in ascending order if all_real_word_preds[i]
#         has less than 4 words. Otherwise, all_errorous_word_preds does not
#         have the key i.
#     """

#     all_real_word_preds = []
#     all_errorous_word_preds = {}

#     for i, pred in enumerate(preds):
#         real_word_preds = []
#         errorous_word_preds = []
#         for _, word in pred:
#             if word in vocab_set:
#                 real_word_preds.append(word)
#                 if len(real_word_preds) == 4:
#                     break
#             else:
#                 errorous_word_preds.append(word)
        
#         all_real_word_preds.append(real_word_preds)
#         if len(real_word_preds) < 4:
#             all_errorous_word_preds[i] = errorous_word_preds

#     return all_real_word_preds, all_errorous_word_preds

In [74]:
def separate_invalid_preds_greedy(preds: List[List[str]],
                                  vocab_set: Set[str],
                                  max_real_words: int = 4) -> Tuple[List[List[str]], Dict[int, List[str]]]:
    """Split each line of candidate words into in-vocabulary and out-of-vocabulary words.

    Arguments:
    ----------
    preds: List[List[str]]
        preds[i] is the ordered list of candidate words for the i-th curve.
    vocab_set: Set[str]
        Words that count as real.
    max_real_words: int
        Scanning of a line stops once this many real words were collected
        (default 4, the previously hard-coded value).

    Returns:
    --------
    all_real_word_preds: List[List[str]]
        all_real_word_preds[i] holds at most `max_real_words` words of preds[i]
        found in `vocab_set`, in their original order.
    all_errorous_word_preds: Dict[int, List[str]]
        For every line i that ended up with fewer than `max_real_words` real
        words, the out-of-vocabulary words of preds[i] in their original order.
        Lines that reached the limit have no key.

    Raises:
    -------
    ValueError if `max_real_words` is smaller than 1.
    """
    if max_real_words < 1:
        raise ValueError(f"max_real_words must be >= 1, got {max_real_words}")

    all_real_word_preds = []
    all_errorous_word_preds = {}

    for i, pred in enumerate(preds):
        real_word_preds = []
        errorous_word_preds = []
        for word in pred:
            if word in vocab_set:
                real_word_preds.append(word)
                if len(real_word_preds) == max_real_words:
                    # The line is full; remaining candidates are irrelevant.
                    break
            else:
                errorous_word_preds.append(word)

        all_real_word_preds.append(real_word_preds)
        if len(real_word_preds) < max_real_words:
            all_errorous_word_preds[i] = errorous_word_preds

    return all_real_word_preds, all_errorous_word_preds

In [75]:
# def augment_predictions(preds, augment_list):
#     augmented_preds = copy.deepcopy(preds)
#     for pred, aug in zip(augmented_preds, augment_list):
#         for aug_el in aug:
#             if len(pred) >= 4:
#                 break
#             pred.append(aug_el)
#     return augmented_preds

In [145]:
def augment_predictions(preds:List[List[str]], augment_list: List[List[str]]):
    """Top up every prediction line to at most four words using fallback candidates.

    Works on a deep copy of `preds`. For each line, candidates from the aligned
    line of `augment_list` are appended in order, skipping words the line
    already contains, until the line holds four words. Lines without an aligned
    entry in `augment_list` are returned unchanged.
    """
    filled_preds = copy.deepcopy(preds)
    for line, candidates in zip(filled_preds, augment_list):
        for candidate in candidates:
            if len(line) >= 4:
                break
            if candidate in line:
                continue
            line.append(candidate)
    return filled_preds

In [76]:
def create_submission(preds_list, out_path) -> None:
    """Write one comma-joined line per entry of `preds_list` to `out_path` (UTF-8).

    Refuses to overwrite: raises ValueError when `out_path` already exists.
    """
    if os.path.exists(out_path):
        raise ValueError(f"File {out_path} already exists")

    with open(out_path, "w", encoding="utf-8") as submission_file:
        submission_file.writelines(",".join(preds) + "\n" for preds in preds_list)

In [77]:
def get_vocab_set(vocab_path: str):
    """Return the set of lines found in the UTF-8 text file at `vocab_path`."""
    with open(vocab_path, 'r', encoding = "utf-8") as vocab_file:
        vocab_text = vocab_file.read()
    return set(vocab_text.splitlines())

In [115]:
grid_name = "default"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
default_test_predictions = predict_raw_mp(test_default_dataset,
                                          grid_name_to_greedy_generator,
                                          num_workers=3)

  1%|          | 117/9373 [00:04<05:26, 28.34it/s]


KeyboardInterrupt: 

In [15]:
grid_name = "extra"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
extra_test_predictions = predict_raw_mp(test_extra_dataset,
                                        grid_name_to_greedy_generator,
                                        num_workers=3)

100%|██████████| 627/627 [00:20<00:00, 30.18it/s]


In [66]:
all_test_preds = merge_preds(default_test_predictions, extra_test_predictions, test_default_dataset.grid_name_idxs, test_extra_dataset.grid_name_idxs)

In [69]:
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [91]:
clean_test_preds, invalid_test_preds = separate_invalid_preds_greedy(all_test_preds, vocab_set)

In [92]:
clean_test_preds[:10]

[['на'],
 ['что'],
 ['опоздания'],
 ['сколько'],
 [],
 ['не'],
 ['как'],
 ['садовод'],
 ['заметил'],
 ['ваги']]

In [81]:
# Baseline submission file: every line is a comma-separated list of candidate
# words, used below to pad prediction lines that have fewer than four words.
# Built with os.path.join so the path also resolves on POSIX systems
# (the previous backslash literal only worked on Windows).
sample_submission_path = os.path.join("..", "data", "submissions", "sample_submission.csv")
with open(sample_submission_path, 'r', encoding = 'utf-8') as f:
    augment_lines = f.read().splitlines()
augment_list = [line.split(",") for line in augment_lines]

In [96]:
augmented_test_preds = augment_predictions(clean_test_preds, augment_list)

In [97]:
augmented_test_preds[:10]

[['на', 'неа', 'на', 'ненка'],
 ['что', 'часто', 'частого', 'чисто'],
 ['опоздания', 'опоздания', 'опозданиям', 'оприходования'],
 ['сколько', 'сколько', 'сокольского', 'свердловского'],
 ['дремать', 'дописать', 'донимать', 'дюрренматт'],
 ['не', 'неук', 'нк', 'ненка'],
 ['как', 'как', 'капак', 'капе'],
 ['садовод', 'спародировал', 'садовод', 'сурдоперевод'],
 ['заметил', 'знаменито', 'знаменитого', 'замерил'],
 ['ваги', 'ваенги', 'венгрии', 'ванги']]

In [98]:
submission_name = "m1_v2__0.14229_deault__0.14301_extra__greedy.csv"
# os.path.join keeps the path valid on POSIX as well as on Windows
# (the previous backslash literal was Windows-only).
out_path = os.path.join("..", "data", "submissions", submission_name)
create_submission(augmented_test_preds, out_path)

# Evaluation via beamsearch

# BeamSearch

In [140]:
generator_kwargs = {
    'max_steps_n': 35,
    'return_hypotheses_n': 7,
    'beamsize': 6,
    'normalization_factor': 0.5,
    'verbose': False
}

In [144]:
weights_f_name = "m1_bigger/m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt"

default_test_predictions = weights_to_raw_predictions(
    grid_name = "default",
    model_getter=get_m1_bigger_model,
    weights_path = os.path.join(MODELS_ROOT, weights_f_name),
    word_char_tokenizer=word_char_tokenizer,
    dataset=test_default_dataset,
    generator_ctor=BeamGenerator,
    n_workers=4,
    generator_kwargs=generator_kwargs
)

100%|██████████| 9373/9373 [1:28:45<00:00,  1.76it/s]


In [147]:
weights_f_name = "m1_bigger/m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt"

In [35]:
# grid_name = "default"
# model_getter = get_m1_bigger_model
# weights_f_name = "m1_v2/m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt"
# weights_path = os.path.join(MODELS_ROOT, weights_f_name)
# model = model_getter(device, weights_path)
# grid_name_to_beam_generator = {grid_name: BeamGenerator(model, word_char_tokenizer, device)}
# default_test_predictions = predict_raw_mp(test_default_dataset,
#                                           grid_name_to_beam_generator,
#                                           num_workers=5,
#                                           generator_kwargs=generator_kwargs)

100%|██████████| 9373/9373 [1:12:28<00:00,  2.16it/s]


In [148]:
import pickle

default_test_preds_path = os.path.join("../data/saved_beamsearch_results/",
                                       f"{weights_f_name.replace('/', '__')}.pkl")

with open(default_test_preds_path, 'wb') as f:
    pickle.dump(default_test_predictions, f, protocol=pickle.HIGHEST_PROTOCOL)

In [98]:
with open(default_test_preds_path, 'rb') as f:
    check_default_test_predictions = pickle.load(f)

In [45]:
check_default_test_predictions == default_test_predictions, check_default_test_predictions is default_test_predictions

(True, False)

In [41]:
grid_name = "extra"
model_getter = get_m1_model
weights_f_name = "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt"
weights_path = os.path.join(MODELS_ROOT, weights_f_name)
model = model_getter(device, weights_path)
grid_name_to_beam_generator = {grid_name: BeamGenerator(model, word_char_tokenizer, device)}
extra_test_predictions = predict_raw_mp(test_extra_dataset,
                                        grid_name_to_beam_generator,
                                        num_workers=5,
                                        generator_kwargs=generator_kwargs)

100%|██████████| 627/627 [04:45<00:00,  2.19it/s]


In [49]:
extra_test_predictions[:2]

[[[(0.0007509778079111129, 'на'),
   (3.309941194903711, 'не-а-а-а'),
   (3.3550884596756756, 'неа'),
   (3.890410980826455, 'нас'),
   (4.074969764595153, 'не'),
   (4.373856262813206, 'ну'),
   (4.41572101401283, 'не-а-а-а-а')]],
 [[(0.4908975681421362, 'рядов'),
   (0.6450710900287799, 'рядомы'),
   (0.7488804344352544, 'рядом'),
   (0.958582528002698, 'рядовы'),
   (1.2750473936050235, 'ряды'),
   (1.2825134230126898, 'рядым'),
   (1.3310995005909219, 'рядова')]]]

In [50]:
import pickle

extra_test_preds_path = os.path.join("../data/saved_beamsearch_results/",
                                       f"{weights_f_name.replace('/', '__')}.pkl")

with open(extra_test_preds_path, 'wb') as f:
    pickle.dump(extra_test_predictions, f, protocol=pickle.HIGHEST_PROTOCOL)

In [149]:
weights_f_name = "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt"
extra_test_preds_path = os.path.join("../data/saved_beamsearch_results/",
                                       f"{weights_f_name.replace('/', '__')}.pkl")

with open(extra_test_preds_path, 'rb') as f:
    extra_test_predictions = pickle.load(f)

In [150]:
weights_f_name = "m1_v2/m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt"
default_test_preds_path_m1_v2_14229 = os.path.join("../data/saved_beamsearch_results/",
                                       f"{weights_f_name.replace('/', '__')}.pkl")

with open(default_test_preds_path_m1_v2_14229, 'rb') as f:
    default_test_preds_m1_v2_14229 = pickle.load(f)

In [151]:
# default_test_predictions = [el[0] for el in default_test_predictions]

In [152]:
# default_test_preds_m1_v2_14229 = [el[0] for el in default_test_preds_m1_v2_14229]

In [153]:
# extra_test_predictions = [el[0] for el in extra_test_predictions]

In [155]:
from typing import Tuple, List
def remove_beamsearch_probs(preds: List[List[Tuple[float, str]]]) -> List[List[str]]:
    """Drop the score from every ``(score, word)`` pair, keeping only the words.

    The line structure and the order of words within each line are preserved.
    """
    return [[word for _, word in pred_line] for pred_line in preds]

In [156]:
default_test_predictions = remove_beamsearch_probs(default_test_predictions)

In [157]:
default_test_preds_m1_v2_14229 = remove_beamsearch_probs(default_test_preds_m1_v2_14229)

In [158]:
extra_test_predictions = remove_beamsearch_probs(extra_test_predictions)

In [160]:
clean_default_test_predictions, _ = separate_invalid_preds_greedy(default_test_predictions, vocab_set)

In [161]:
len(clean_default_test_predictions), sum(bool(el) for el in clean_default_test_predictions)

(9373, 9193)

In [162]:
clean_default_test_preds_m1_v2_14229, _ = separate_invalid_preds_greedy(default_test_preds_m1_v2_14229, vocab_set)

In [163]:
len(clean_default_test_preds_m1_v2_14229), sum(bool(el) for el in clean_default_test_preds_m1_v2_14229)

(9373, 9185)

In [164]:
all_default_predictions = augment_predictions(clean_default_test_predictions, clean_default_test_preds_m1_v2_14229)

In [165]:
len(all_default_predictions), sum(bool(el) for el in all_default_predictions)

(9373, 9244)

In [166]:
full_test_predictions = merge_preds(all_default_predictions, extra_test_predictions, test_default_dataset.grid_name_idxs, test_extra_dataset.grid_name_idxs)

In [103]:
sum(not el for el in full_test_predictions)

0

In [78]:
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [167]:
clean_test_predictions, invalid_test_predictions =  separate_invalid_preds_greedy(full_test_predictions, vocab_set)

In [168]:
sum(not el for el in clean_test_predictions)

147

In [169]:
from collections import defaultdict

n_preds_in_line_dict = defaultdict(int)

for line in clean_test_predictions:
    n_preds_in_line_dict[len(line)] += 1

print(n_preds_in_line_dict)

defaultdict(<class 'int'>, {4: 7404, 3: 1032, 2: 840, 1: 577, 0: 147})


In [82]:
clean_test_baseline_augmented = augment_predictions(clean_test_predictions, augment_list)

In [84]:
sum(not el for el in clean_test_baseline_augmented)

0

In [85]:
submission_name = "m1_v2__0.14229_deault__0.14301_extra__beam.csv"
# os.path.join keeps the path valid on POSIX as well as on Windows
# (the previous backslash literal was Windows-only).
out_path = os.path.join("..", "data", "submissions", submission_name)
create_submission(clean_test_baseline_augmented, out_path)

In [170]:
old_preds_path = os.path.join(DATA_ROOT, "test_raw_pred___best_model__2023_11_04__18_31_37__0.02530_default_switch_2.pt__best_model__2023_11_05__07_55_13__0.02516_extra_switch_2__with_pad_cutting.pt.pkl")
with open(old_preds_path, 'rb') as f:
    old_preds_list = pickle.load(f)

In [171]:
old_preds_list = remove_beamsearch_probs(old_preds_list)

In [172]:
old_preds_list_valid, old_preds_list_invalid = separate_invalid_preds_greedy(old_preds_list, vocab_set)

In [173]:
clean_test_old_preds_augmented = augment_predictions(clean_test_predictions, old_preds_list_valid)

In [174]:
from collections import defaultdict

n_preds_in_line_dict = defaultdict(int)

for line in clean_test_old_preds_augmented:
    n_preds_in_line_dict[len(line)] += 1

print(n_preds_in_line_dict)

defaultdict(<class 'int'>, {4: 7981, 2: 657, 3: 776, 0: 126, 1: 460})


In [175]:
clean_test_old_preds_baseline_augmented = augment_predictions(clean_test_old_preds_augmented, augment_list)

In [176]:
submission_name = "default__m1_bigger_13679__m1_v2__14229___extra__14301___with_baseline__beam.csv"
# os.path.join keeps the path valid on POSIX as well as on Windows.
out_path = os.path.join("..", "data", "submissions", submission_name)
# Write the predictions assembled in the preceding cell. The cell used to write
# `clean_test_baseline_augmented`, a stale variable from a much earlier run that
# does not include the older-model fallback candidates.
# NOTE(review): confirm this is the list intended for this submission.
create_submission(clean_test_old_preds_baseline_augmented, out_path)

In [294]:
grid_name_to_ranged_bs_model_preds_paths = {
    'default': ["m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
                "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl"],
    'extra': ["m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl"]
}

# The files for each grid must be ordered by their beam-search quality on the validation set (best first)

In [295]:
def patch_wrong_prediction_shape(prediciton):
    """Unwrap one nesting level: return the first element of every item of `prediciton`."""
    unwrapped = []
    for wrapped_line in prediciton:
        unwrapped.append(wrapped_line[0])
    return unwrapped

In [316]:
# Positions of the default-grid / extra-grid samples inside the full test set.
default_idxs = test_default_dataset.grid_name_idxs
extra_idxs = test_extra_dataset.grid_name_idxs

grid_name_to_augmented_preds = {}

for grid_name in ('default', 'extra'):
    # Cleaned predictions of every model for this grid, in the configured order.
    bs_pred_list = []

    for f_name in grid_name_to_ranged_bs_model_preds_paths[grid_name]:
        f_path = os.path.join("../data/saved_beamsearch_results/", f_name)
        # pickle.load can execute arbitrary code: only load files produced by this project.
        with open(f_path, 'rb') as f:
            bs_preds = pickle.load(f)

        # Unwrap the extra nesting level, drop the scores, keep vocabulary words only.
        bs_preds = patch_wrong_prediction_shape(bs_preds)
        bs_preds = remove_beamsearch_probs(bs_preds)
        bs_pred_list.append(separate_invalid_preds_greedy(bs_preds, vocab_set)[0])

    # The first model provides the base lines; each following model only fills
    # the slots that are still free.
    augmented_preds = bs_pred_list[0]
    for fallback_preds in bs_pred_list[1:]:
        augmented_preds = augment_predictions(augmented_preds, fallback_preds)

    grid_name_to_augmented_preds[grid_name] = augmented_preds


full_preds = merge_preds(
    grid_name_to_augmented_preds['default'],
    grid_name_to_augmented_preds['extra'],
    default_idxs,
    extra_idxs)

In [318]:
from collections import defaultdict

n_preds_in_line_dict = defaultdict(int)

for line in full_preds:
    n_preds_in_line_dict[len(line)] += 1

print(n_preds_in_line_dict)

defaultdict(<class 'int'>, {4: 7404, 3: 1032, 2: 840, 1: 577, 0: 147})


In [300]:
full_preds_augmentations = [
    
]

In [None]:
print(*full_preds, sep = '\n')