In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json
import copy
from multiprocessing import cpu_count
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm
import numpy as np

from model import SwipeCurveTransformer, get_m1_model, get_m1_bigger_model, get_m1_smaller_model
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv2, NeuroSwipeGridSubset
from word_generators import GreedyGenerator, BeamGenerator
from word_generation_v2 import predict_raw_mp
from metrics import get_mmr
from get_individual_models_predictions import weights_to_raw_predictions
from aggregate_predictions import (separate_out_vocab_all_crvs,
                                   append_preds,
                                   create_submission,
                                   merge_default_and_extra_preds)

In [4]:
# Environment switch: set to True when running inside a Kaggle kernel.
IN_KAGGLE = False

# Both branches must define the same names, because later cells read
# DATA_ROOT and MODELS_ROOT unconditionally.
if IN_KAGGLE:
    DATA_ROOT = "/kaggle/input/yandex-cup-playground"
    MODELS_ROOT = ""
    # Kept for backward compatibility: this branch used to define only MODELS_DIR.
    MODELS_DIR = MODELS_ROOT
else:
    DATA_ROOT = "../data/data_separated_grid"
    MODELS_ROOT = "../data/trained_models"

In [5]:
def get_grid(grid_name: str, grids_path: str) -> dict:
    """Load the JSON file at ``grids_path`` and return the value stored under ``grid_name``.

    Raises ``KeyError`` if the file has no such top-level key.
    """
    with open(grids_path, "r", encoding="utf-8") as grids_file:
        all_grids = json.load(grids_file)
    return all_grids[grid_name]

In [6]:
# Passed to both datasets as `max_traj_len`.
MAX_TRAJ_LEN = 299

# Grid descriptions for both grid names, read from one JSON mapping file.
grid_name_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
grid_name_to_grid = {
    grid_name: get_grid(grid_name, grid_name_to_grid_path)
    for grid_name in ("default", "extra")
}

kb_tokenizer = KeyboardTokenizerv1()
word_char_tokenizer = CharLevelTokenizerv2(os.path.join(DATA_ROOT, "voc.txt"))
keyboard_selection_set = set(kb_tokenizer.i2t)

# Constructor arguments shared by the validation and test datasets;
# only `data_path` and `has_target` differ between the two.
_shared_dataset_kwargs = dict(
    gridname_to_grid=grid_name_to_grid,
    kb_tokenizer=kb_tokenizer,
    max_traj_len=MAX_TRAJ_LEN,
    word_tokenizer=word_char_tokenizer,
    include_time=False,
    include_velocities=True,
    include_accelerations=True,
    has_one_grid_only=False,
    include_grid_name=True,
    keyboard_selection_set=keyboard_selection_set,
    total=10_000,
)

# Validation data comes with target words.
val_path = os.path.join(DATA_ROOT, "valid__in_train_format.jsonl")
val_dataset = NeuroSwipeDatasetv2(
    data_path=val_path,
    has_target=True,
    **_shared_dataset_kwargs,
)

# Test data has no targets.
test_path = os.path.join(DATA_ROOT, "test.jsonl")
test_dataset = NeuroSwipeDatasetv2(
    data_path=test_path,
    has_target=False,
    **_shared_dataset_kwargs,
)

100%|██████████| 10000/10000 [00:02<00:00, 4816.94it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5583.61it/s]


In [7]:
# Build one NeuroSwipeGridSubset per grid name for each dataset
# (validation first, then test; "default" before "extra").
val_default_dataset, val_extra_dataset = (
    NeuroSwipeGridSubset(val_dataset, subset_grid) for subset_grid in ("default", "extra")
)

test_default_dataset, test_extra_dataset = (
    NeuroSwipeGridSubset(test_dataset, subset_grid) for subset_grid in ("default", "extra")
)

In [13]:
def get_targets(dataset: NeuroSwipeDatasetv2) -> List[str]:
    """Decode the target word of every item in ``dataset``.

    Each item is unpacked as ``((_, _, _, _, word_pad_mask), target_tokens, _)``.
    The number of tokens decoded is the count of ``False`` entries in
    ``word_pad_mask`` minus one (presumably dropping a trailing
    end-of-sequence token -- TODO confirm against the dataset).
    Decoding uses the notebook-level ``word_char_tokenizer``.
    """
    def _decode_item(item) -> str:
        (_, _, _, _, word_pad_mask), target_tokens, _ = item
        n_target_tokens = int(torch.sum(~word_pad_mask)) - 1
        return word_char_tokenizer.decode(target_tokens[:n_target_tokens])

    return [_decode_item(item) for item in dataset]

In [None]:
def get_vocab_set(vocab_path: str):
    """Return the set of lines of the UTF-8 text file at ``vocab_path``."""
    with open(vocab_path, 'r', encoding = "utf-8") as vocab_file:
        vocab_text = vocab_file.read()
    return set(vocab_text.splitlines())

In [None]:
# Set of the lines of voc.txt; passed to separate_out_vocab_all_crvs in later cells.
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [8]:
# Inference device is fixed to CPU; the commented-out alternative picks CUDA when available.
device = torch.device('cpu')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# Load the m1_smaller weights for the "default" grid and wrap the model in a
# greedy generator, keyed by grid name.
# NOTE: `grid_name`, `model_getter`, `weights_path` and `model` are notebook
# globals that later cells rebind.
grid_name = "default"
model_getter = get_m1_smaller_model
weights_path = os.path.join(MODELS_ROOT, "m1_smaller/m1_smaller_v2_2023_11_11_17_43_35_0_33179_default_l2_1e_05_ls0_02.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [10]:
# Print greedy predictions next to their targets for the first `n_examples`
# items of the default-grid validation subset.
greedy_generator = GreedyGenerator(model, word_char_tokenizer, device)


print("{:<20} {:<20} {:<20}".format("target", "prediction", "prob"))
print("-"*49)

n_examples = 40

for i, data in enumerate(val_default_dataset):
    # Check the limit before doing any work so that exactly `n_examples`
    # rows are printed (checking after the print produced one extra row).
    if i >= n_examples:
        break

    # Underscore-prefixed names keep this loop from overwriting the
    # notebook-level `grid_name` that other cells set explicitly.
    (xyt, kb_tokens, _dec_in_char_seq, traj_pad_mask, word_pad_mask), target_tokens, _item_grid_name = data

    # The generator returns a list of (score, word) pairs; take the first one.
    score, pred = greedy_generator(xyt, kb_tokens, traj_pad_mask)[0]

    # Same target-length rule as get_targets: non-padded positions minus one.
    target_len = int(torch.sum(~word_pad_mask)) - 1
    target = word_char_tokenizer.decode(target_tokens[:target_len])
    # exp(score) is shown in the "prob" column (assumes score is a log-probability -- TODO confirm).
    print("{:<20} {:<20} {:<20.5f}".format(target, pred, np.exp(score)))

target               prediction           prob                
-------------------------------------------------
на                   на                   0.95232             
все                  все                  0.94362             
добрый               добрый               0.90555             
девочка              девочка              0.87215             
сказала              сказала              0.90086             
скинь                скинь                0.92264             
геев                 гееев                0.25355             
тобой                тобой                0.91465             
была                 быса                 0.66717             
да                   да                   0.95149             
муж                  мад                  0.18687             
щас                  щас                  0.94379             
она                  она                  0.93583             
проблема             проблема             0.85841             
билай

In [14]:
# Decoded target words for each validation subset, in dataset iteration order.
val_default_targets = get_targets(val_default_dataset)
val_extra_targets = get_targets(val_extra_dataset)

# Evaluate models separately and as a pair

In [15]:
# Show the number of CPUs reported by multiprocessing (prediction cells below pass 3-4 workers).
print(cpu_count())

4


In [16]:
# NOTE(review): same two statements as the cell two above; this cell recomputes identical values.
val_default_targets = get_targets(val_default_dataset)
val_extra_targets = get_targets(val_extra_dataset)

In [None]:
# Greedy predictions of an m1_bigger checkpoint on the default-grid validation subset.
default_predictions = weights_to_raw_predictions(
    grid_name = "default",
    model_getter=get_m1_bigger_model,
    weights_path = os.path.join(MODELS_ROOT, "m1_bigger/m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt"),
    word_char_tokenizer=word_char_tokenizer,
    dataset=val_default_dataset,
    generator_ctor=GreedyGenerator,
    n_workers=4
)

In [135]:
# Score the default-grid predictions against the decoded validation targets and display the value.
default_MMR =  get_mmr(default_predictions, val_default_targets)
default_MMR

0.8531223449447749

In [136]:
# Keep a second reference to the m1_bigger predictions; `default_predictions` is rebound in a later cell.
default_predictions_best_bigger = default_predictions

In [137]:
# Filter the predictions with the vocabulary set; only the first returned value is kept.
default_predictions_best_bigger_clean, _ = separate_out_vocab_all_crvs(default_predictions_best_bigger, vocab_set)

In [138]:
# Count the entries that are still non-empty after the vocabulary filtering.
sum(bool(el) for el in default_predictions_best_bigger_clean)

8784

In [182]:
# Greedy predictions of the m1_smaller checkpoint on the default-grid validation subset.
# NOTE: rebinds `default_predictions`; the m1_bigger results remain available
# as `default_predictions_best_bigger`.
default_predictions = weights_to_raw_predictions(
    grid_name = "default",
    model_getter=get_m1_smaller_model,
    weights_path = os.path.join(MODELS_ROOT, "m1_smaller/m1_smaller_v2_2023_11_11_17_43_35_0_33179_default_l2_1e_05_ls0_02.pt"),
    word_char_tokenizer=word_char_tokenizer,
    dataset=val_default_dataset,
    generator_ctor=GreedyGenerator,
    n_workers=4
)

100%|██████████| 9416/9416 [04:08<00:00, 37.90it/s]


In [183]:
# Filter the m1_smaller predictions with the vocabulary set; only the first returned value is kept.
default_predictions_clean, _ = separate_out_vocab_all_crvs(default_predictions, vocab_set)

In [184]:
# Count the entries that are still non-empty after the vocabulary filtering.
sum(bool(el) for el in default_predictions_clean)

8449

In [185]:
# Score the (unfiltered) m1_smaller predictions; rebinds `default_MMR` from the earlier cell.
default_MMR =  get_mmr(default_predictions, val_default_targets)
default_MMR

0.8221112999150383

In [110]:
# grid_name = "default"
# model_getter = get_m1_bigger_model
# weights_path = os.path.join(MODELS_ROOT, "m1_bigger/m1_bigger_v2__2023_11_11__13_17_50__0.13845_default_l2_0_ls0_switch_0.pt")
# model = model_getter(device, weights_path)
# grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [111]:
# default_predictions = predict_greedy_raw_multiproc(val_default_dataset,
#                                                     grid_name_to_greedy_generator,
#                                                     num_workers=4)

100%|██████████| 9416/9416 [05:02<00:00, 31.10it/s]


In [29]:
# Load the m1 checkpoint trained for the "extra" grid and wrap it in a greedy generator.
# Rebinds the notebook globals `grid_name`, `model_getter`, `weights_path`,
# `model` and `grid_name_to_greedy_generator`.
grid_name = "extra"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}

In [31]:
# Greedy predictions for the extra-grid validation subset, computed with 4 workers.
extra_predictions = predict_raw_mp(val_extra_dataset,
                                   grid_name_to_greedy_generator,
                                   num_workers=4)

100%|██████████| 584/584 [00:18<00:00, 31.60it/s]


In [33]:
# Score the extra-grid predictions against the decoded extra-grid targets and display the value.
extra_MMR = get_mmr(extra_predictions, val_extra_targets)
extra_MMR

0.851027397260274

In [41]:
# Combine per-grid predictions into one list using each subset's `grid_name_idxs`
# (presumably restoring the original validation order -- verify in aggregate_predictions).
all_preds = merge_default_and_extra_preds(default_predictions, extra_predictions, val_default_dataset.grid_name_idxs, val_extra_dataset.grid_name_idxs)

In [40]:
# Reference words for the whole validation set: one list element per line of valid.ref.
all_targets = None
with open(os.path.join(DATA_ROOT, "valid.ref"), 'r', encoding='utf-8') as f:
    all_targets = f.read().splitlines() 

In [46]:
# Score the merged default+extra predictions against the full reference list and display the value.
full_MMR = get_mmr(all_preds, all_targets)
full_MMR

0.8512

In [365]:
# Length of the longest reference word in the validation set.
max(len(el) for el in all_targets)

19

In [None]:
# Hand-recorded scores keyed by weights file (presumably greedy validation MMR -- TODO confirm).
# The dict is only displayed, not assigned to a name.
{"m1_v2/best_model__2023_11_09__10_36_02__0.14229_default_switch_2_try_2.pt": 0.8512107051826678,
 "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt": 0.851027397260274,
 
 "m1_bigger/m1_bigger_v2__2023_11_10__13_38_32__0.50552_default_l2_5e-05_ls0.045_switch_0.pt": 0.810429056924384,
 "m1_bigger/m1_bigger_v2__2023_11_10__16_36_38__0.49848_default_l2_5e-05_ls0.045_switch_0.pt": 0.818500424808836,
 "m1_bigger/m1_bigger_v2__2023_11_10__21_51_01__0.49382_default_l2_5e-05_ls0.045_switch_0.pt": 0.8210492778249787,
 
 "m1_bigger/m1_bigger_v2__2023_11_11__13_17_50__0.13845_default_l2_0_ls0_switch_0.pt": 0.8512107051826678,
 "m1_bigger/m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt": 0.8531223449447749,
 
 "m1_smaller/m1_smaller_v2_2023_11_11_17_43_35_0_33179_default_l2_1e_05_ls0_02.pt": 0.8221112999150383}

# Let's create a greedy submission

In [73]:
from typing import Set, List, Tuple

In [None]:
# Greedy test-set predictions for the "default" grid with the m1 checkpoint.
# Rebinds the notebook globals `grid_name`, `model_getter`, `weights_path`,
# `model` and `grid_name_to_greedy_generator`.
grid_name = "default"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
default_test_predictions = predict_raw_mp(test_default_dataset,
                                          grid_name_to_greedy_generator,
                                          num_workers=3)

In [15]:
# Greedy test-set predictions for the "extra" grid with the m1 checkpoint.
# Rebinds the same notebook globals as the "default" cell above.
grid_name = "extra"
model_getter = get_m1_model
weights_path = os.path.join(MODELS_ROOT, "m1_v2/m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt")
model = model_getter(device, weights_path)
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, word_char_tokenizer, device)}
extra_test_predictions = predict_raw_mp(test_extra_dataset,
                                        grid_name_to_greedy_generator,
                                        num_workers=3)

100%|██████████| 627/627 [00:20<00:00, 30.18it/s]


In [66]:
# Combine per-grid test predictions into one list using each subset's `grid_name_idxs`.
all_test_preds = merge_default_and_extra_preds(default_test_predictions, extra_test_predictions, test_default_dataset.grid_name_idxs, test_extra_dataset.grid_name_idxs)

In [69]:
# NOTE(review): same statement as the earlier `vocab_set` cell; reloads the same file.
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [91]:
# Split the merged test predictions using the vocabulary set; both returned values are kept here.
clean_test_preds, invalid_test_preds = separate_out_vocab_all_crvs(all_test_preds, vocab_set)

In [92]:
# Preview the first ten entries; an empty list means nothing was kept for that item.
clean_test_preds[:10]

[['на'],
 ['что'],
 ['опоздания'],
 ['сколько'],
 [],
 ['не'],
 ['как'],
 ['садовод'],
 ['заметил'],
 ['ваги']]

In [81]:
# Baseline candidates from the sample submission: every line is split on commas.
# The path is built with os.path.join so it also resolves on non-Windows
# systems (the previous backslash literal only worked on Windows).
sample_submission_path = os.path.join("..", "data", "submissions", "sample_submission.csv")
with open(sample_submission_path, 'r', encoding = 'utf-8') as f:
    augment_lines = f.read().splitlines()
augment_list = [line.split(",") for line in augment_lines]

In [96]:
# Extend each prediction line with baseline candidates, passing limit = 4
# (the preview below shows four candidates per line).
augmented_test_preds = append_preds(clean_test_preds, augment_list, limit = 4)

In [97]:
# Preview the first ten augmented prediction lines.
augmented_test_preds[:10]

[['на', 'неа', 'на', 'ненка'],
 ['что', 'часто', 'частого', 'чисто'],
 ['опоздания', 'опоздания', 'опозданиям', 'оприходования'],
 ['сколько', 'сколько', 'сокольского', 'свердловского'],
 ['дремать', 'дописать', 'донимать', 'дюрренматт'],
 ['не', 'неук', 'нк', 'ненка'],
 ['как', 'как', 'капак', 'капе'],
 ['садовод', 'спародировал', 'садовод', 'сурдоперевод'],
 ['заметил', 'знаменито', 'знаменитого', 'замерил'],
 ['ваги', 'ваенги', 'венгрии', 'ванги']]

In [98]:
# Write the greedy submission file.
# The path is built with os.path.join so it also resolves on non-Windows
# systems (the previous backslash literal only worked on Windows).
submission_name = "m1_v2__0.14229_deault__0.14301_extra__greedy.csv"
out_path = os.path.join("..", "data", "submissions", submission_name)
create_submission(augmented_test_preds, out_path)

# BeamSearch

In [None]:
from typing import Tuple, List
def remove_beamsearch_probs(preds: List[List[Tuple[float, str]]]) -> List[List[str]]:
    """Strip the score from every ``(score, word)`` pair.

    Args:
        preds: one list of ``(score, word)`` pairs per input item.

    Returns:
        A new list with the same nesting, containing only the words in their original order.
    """
    return [[word for _, word in pred_line] for pred_line in preds]

In [None]:
def patch_wrong_prediction_shape(prediciton):
    """Return a list holding element ``[0]`` of every entry in ``prediciton``.

    Used on loaded beam-search results before ``remove_beamsearch_probs``
    (presumably each entry is wrapped in one extra list level -- TODO confirm).
    """
    unwrapped = []
    for pred_el in prediciton:
        unwrapped.append(pred_el[0])
    return unwrapped

## Beamsearch Evaluation

In [372]:
# Longest reference word in the validation set; requires `all_targets` from the valid.ref cell.
MAX_VAL_WORD_LEN = max(len(el) for el in all_targets)

# Keyword arguments forwarded to BeamGenerator through weights_to_raw_predictions.
# max_steps_n is one more than the longest validation word
# (presumably to leave room for an end-of-sequence token -- TODO confirm).
generator_kwargs = {
    'max_steps_n': MAX_VAL_WORD_LEN+1,
    'return_hypotheses_n': 7,
    'beamsize': 6,
    'normalization_factor': 0.5,
    'verbose': False
}

In [373]:
# Lookup from grid name to the matching validation subset, used by the beam-search loop below.
grid_name_to_val_dataset = {
    'default': val_default_dataset,
    'extra': val_extra_dataset
}

In [374]:
import pickle

# (grid name, model factory, weights file relative to MODELS_ROOT) for every
# checkpoint whose beam-search validation predictions should be computed.
bs_params = [
    ("default", get_m1_bigger_model, "m1_bigger/m1_bigger_v2__2023_11_12__14_51_49__0.13115__greed_acc_0.86034__default_l2_0_ls0_switch_2.pt"),
]

bs_val_results_dir = "../data/saved_beamsearch_validation_results/"
# Create the output directory up front: without it a long beam-search run
# would only fail at the very end, when the results are written.
os.makedirs(bs_val_results_dir, exist_ok=True)

for grid_name, model_getter, weights_f_name in bs_params:

    # Flatten "dir/file.pt" into "dir__file.pt.pkl" so results live in one folder.
    bs_preds_path = os.path.join(bs_val_results_dir,
                                f"{weights_f_name.replace('/', '__')}.pkl")

    # Cached result found: do not recompute.
    if os.path.exists(bs_preds_path):
        print(f"Path {bs_preds_path} exists. Skipping.")
        continue

    bs_predictions = weights_to_raw_predictions(
        grid_name = grid_name,
        model_getter=model_getter,
        weights_path = os.path.join(MODELS_ROOT, weights_f_name),
        word_char_tokenizer=word_char_tokenizer,
        dataset=grid_name_to_val_dataset[grid_name],
        generator_ctor=BeamGenerator,
        n_workers=4,
        generator_kwargs=generator_kwargs
    )

    with open(bs_preds_path, 'wb') as f:
        pickle.dump(bs_predictions, f, protocol=pickle.HIGHEST_PROTOCOL)

100%|██████████| 9416/9416 [41:23<00:00,  3.79it/s] 


In [375]:
# Load cached beam-search predictions for the default grid and score them on validation.
preds_name = "m1_bigger__m1_bigger_v2__2023_11_12__14_51_49__0.13115__greed_acc_0.86034__default_l2_0_ls0_switch_2.pt.pkl"
bs_preds_path = os.path.join("../data/saved_beamsearch_validation_results/",
                                preds_name)
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open(bs_preds_path, 'rb') as f:
    default_valid_preds_bs = pickle.load(f)

# Pipeline: unwrap each entry -> drop scores -> vocabulary filter (first return value) -> score.
default_valid_preds_bs = patch_wrong_prediction_shape(default_valid_preds_bs)
default_valid_preds_bs = remove_beamsearch_probs(default_valid_preds_bs)
default_valid_preds_bs, _ = separate_out_vocab_all_crvs(default_valid_preds_bs, vocab_set)
get_mmr(default_valid_preds_bs, val_default_targets)

0.8929800339847141

In [340]:
# Load cached beam-search predictions for the extra grid and score them on validation.
preds_name = "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl"
bs_preds_path = os.path.join("../data/saved_beamsearch_validation_results/",
                                preds_name)
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open(bs_preds_path, 'rb') as f:
    extra_valid_preds_bs = pickle.load(f)

# Pipeline: unwrap each entry -> drop scores -> vocabulary filter (first return value) -> score.
extra_valid_preds_bs = patch_wrong_prediction_shape(extra_valid_preds_bs)
extra_valid_preds_bs = remove_beamsearch_probs(extra_valid_preds_bs)
extra_valid_preds_bs, _ = separate_out_vocab_all_crvs(extra_valid_preds_bs, vocab_set)
get_mmr(extra_valid_preds_bs, val_extra_targets)

0.8876027397260277

In [None]:
# {
#     "m1_smaller__m1_smaller_v2_2023_11_12_01_21_45_0_31891_default_l2_1e_05_ls0_02.pt.pkl": 0.8811533559898118,
#     "m1_smaller__m1_smaller_v2_2023_11_12_08_17_33_0_31223_default_l2_1e_05_ls0_02.pt.pkl": 0.8835960067969484,
#     "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl": 0.8900881478334822,
#     "m1_bigger__m1_bigger_v2__2023_11_11__15_53_07__0.13636_default_l2_0_ls0_switch_0.pt.pkl": 0.8871590909090984,
#     "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl": 0.887674171622777,
#     "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl": 0.8877740016992428,
    
    
#     "m1_bigger__m1_bigger_v2__2023_11_11__16_45_33__0.13721_extra_l2_0_ls0_switch_0.pt.pkl": 0.8864383561643838,
#     "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl": 0.8876027397260277
    
# }

In [37]:
# Per-grid lists of cached beam-search result files. List order matters: the
# aggregation cell starts from the first file and appends the others in order.
# Trailing comments record each file's validation score.
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_12__14_51_49__0.13115__greed_acc_0.86034__default_l2_0_ls0_switch_2.pt.pkl", #: 0.8929800339847141,
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl", #: 0.8914698385726496,
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",  #: 0.8900881478334822,
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",  #: E 0.8877740016992428,
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",  #: 0.887674171622777,
        # "m1_bigger__m1_bigger_v2__2023_11_11__15_53_07__0.13636_default_l2_0_ls0_switch_0.pt.pkl",  #: 0.8871590909090984,
        "m1_smaller__m1_smaller_v2__2023_11_12__17_40_42__0.30909_default_l2_1e-05_ls0.02_switch_0.pt.pkl",  # 0.8849384027187835
        # "m1_smaller__m1_smaller_v2_2023_11_12_08_17_33_0_31223_default_l2_1e_05_ls0_02.pt.pkl",  #: 0.8835960067969484,
        # "m1_smaller__m1_smaller_v2_2023_11_12_01_21_45_0_31891_default_l2_1e_05_ls0_02.pt.pkl",  #: 0.8811533559898118,
        ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__16_45_33__0.13721_extra_l2_0_ls0_switch_0.pt.pkl"
        ]
}

# Ranked by beam-search quality on validation

In [None]:
# Aggregate cached validation beam-search predictions of several models per grid,
# then merge the two grids into one prediction list.
import pickle

default_idxs = val_default_dataset.grid_name_idxs
extra_idxs = val_extra_dataset.grid_name_idxs 

grid_name_to_augmented_preds = {}

for grid_name in ('default', 'extra'):
    bs_pred_list = []

    # Load every cached result file listed for this grid, in ranking order.
    # NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
    for f_name in grid_name_to_ranged_bs_model_preds_paths[grid_name]:
        f_path = os.path.join("../data/saved_beamsearch_validation_results/", f_name)
        with open(f_path, 'rb') as f:
            bs_pred_list.append(pickle.load(f))
        
    # Per model: unwrap each entry -> drop scores -> vocabulary filter (first return value).
    bs_pred_list = [patch_wrong_prediction_shape(bs_preds) for bs_preds in bs_pred_list] 
    bs_pred_list = [remove_beamsearch_probs(bs_preds) for bs_preds in bs_pred_list]
    bs_pred_list = [separate_out_vocab_all_crvs(bs_preds, vocab_set)[0] for bs_preds in bs_pred_list]


    # Start from the first (best-ranked) model ...
    augmented_preds = bs_pred_list.pop(0)

    # ... and append the remaining models' predictions in list order.
    while bs_pred_list:
        augmented_preds = append_preds(augmented_preds, bs_pred_list.pop(0))

    grid_name_to_augmented_preds[grid_name] = augmented_preds


full_preds = merge_default_and_extra_preds(
    grid_name_to_augmented_preds['default'],
    grid_name_to_augmented_preds['extra'],
    default_idxs,
    extra_idxs)

In [None]:
# Preview the first ten merged prediction lines.
full_preds[:10]

In [None]:
# Histogram: number of candidates in a prediction line -> how many lines have that many.
from collections import defaultdict

n_preds_in_line_dict = defaultdict(int)

for line in full_preds:
    n_preds_in_line_dict[len(line)] += 1

print(n_preds_in_line_dict)

In [None]:
# Reference words for the whole validation set: one list element per line of valid.ref.
# NOTE(review): same statements as the earlier valid.ref cell.
all_targets = None
with open(os.path.join(DATA_ROOT, "valid.ref"), 'r', encoding='utf-8') as f:
    all_targets = f.read().splitlines() 

In [None]:
# Score the aggregated predictions; results for several model orderings are recorded in the notes below.
get_mmr(full_preds, all_targets)

```python
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",#: 0.8900881478334822,
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",#: 0.8877740016992428,
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",#: 0.887674171622777,
        "m1_bigger__m1_bigger_v2__2023_11_11__15_53_07__0.13636_default_l2_0_ls0_switch_0.pt.pkl",#: 0.8871590909090984,
        "m1_smaller__m1_smaller_v2_2023_11_12_08_17_33_0_31223_default_l2_1e_05_ls0_02.pt.pkl",#: 0.8835960067969484,
        "m1_smaller__m1_smaller_v2_2023_11_12_01_21_45_0_31891_default_l2_1e_05_ls0_02.pt.pkl",#: 0.8811533559898118,
        ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__16_45_33__0.13721_extra_l2_0_ls0_switch_0.pt.pkl"
        ]
}
```

0.8936010000000082


``` python
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",#: 0.8877740016992428,
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",#: 0.8900881478334822,
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",#: 0.887674171622777,
        "m1_bigger__m1_bigger_v2__2023_11_11__15_53_07__0.13636_default_l2_0_ls0_switch_0.pt.pkl",#: 0.8871590909090984,
        "m1_smaller__m1_smaller_v2_2023_11_12_08_17_33_0_31223_default_l2_1e_05_ls0_02.pt.pkl",#: 0.8835960067969484,
        "m1_smaller__m1_smaller_v2_2023_11_12_01_21_45_0_31891_default_l2_1e_05_ls0_02.pt.pkl",#: 0.8811533559898118,
        ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__16_45_33__0.13721_extra_l2_0_ls0_switch_0.pt.pkl"
        ]
}
```

0.892801000000009


``` python
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",#: 0.8900881478334822,
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",#: 0.8877740016992428,
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",#: 0.887674171622777,
        "m1_smaller__m1_smaller_v2_2023_11_12_08_17_33_0_31223_default_l2_1e_05_ls0_02.pt.pkl",#: 0.8835960067969484,
        "m1_smaller__m1_smaller_v2_2023_11_12_01_21_45_0_31891_default_l2_1e_05_ls0_02.pt.pkl",#: 0.8811533559898118,
        "m1_bigger__m1_bigger_v2__2023_11_11__15_53_07__0.13636_default_l2_0_ls0_switch_0.pt.pkl",#: 0.8871590909090984,
        ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__16_45_33__0.13721_extra_l2_0_ls0_switch_0.pt.pkl"
        ]
}

# должны ранжироваться по качеству beamsearch на валидации
```

0.8936000000000084

## Beamsearch test

In [343]:
# Beam-search settings for the test set. Rebinds `generator_kwargs` from the validation section.
# NOTE(review): max_steps_n is MAX_WORD_LEN - 1 here but MAX_VAL_WORD_LEN + 1 in the
# validation settings -- confirm the difference is intentional.
MAX_WORD_LEN = 36

generator_kwargs = {
    'max_steps_n': MAX_WORD_LEN - 1,
    'return_hypotheses_n': 7,
    'beamsize': 6,
    'normalization_factor': 0.5,
    'verbose': False
}

In [344]:
# Lookup from grid name to the matching test subset, used by the beam-search loop below.
grid_name_to_test_dataset = {
    'default': test_default_dataset,
    'extra': test_extra_dataset
}

In [356]:
import pickle

# (grid name, model factory, weights file relative to MODELS_ROOT) for every
# checkpoint whose beam-search test predictions should be computed.
bs_params = [
    ("extra", get_m1_bigger_model, "m1_bigger/m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt"),

    # "m1_smaller__m1_smaller_v2_2023_11_12_08_17_33_0_31223_default_l2_1e_05_ls0_02.pt.pkl": 0.8835960067969484,
    # "m1_bigger__m1_bigger_v2__2023_11_11__15_53_07__0.13636_default_l2_0_ls0_switch_0.pt.pkl": 0.8871590909090984,
    # "m1_bigger__m1_bigger_v2__2023_11_11__16_45_33__0.13721_extra_l2_0_ls0_switch_0.pt.pkl": 0.8864383561643838,
]

bs_test_results_dir = "../data/saved_beamsearch_results/"
# Create the output directory up front: without it a long beam-search run
# would only fail at the very end, when the results are written.
os.makedirs(bs_test_results_dir, exist_ok=True)

for grid_name, model_getter, weights_f_name in bs_params:

    # Flatten "dir/file.pt" into "dir__file.pt.pkl" so results live in one folder.
    bs_preds_path = os.path.join(bs_test_results_dir,
                                f"{weights_f_name.replace('/', '__')}.pkl")

    # Cached result found: do not recompute.
    if os.path.exists(bs_preds_path):
        print(f"Path {bs_preds_path} exists. Skipping.")
        continue

    bs_predictions = weights_to_raw_predictions(
        grid_name = grid_name,
        model_getter=model_getter,
        weights_path = os.path.join(MODELS_ROOT, weights_f_name),
        word_char_tokenizer=word_char_tokenizer,
        dataset=grid_name_to_test_dataset[grid_name],
        generator_ctor=BeamGenerator,
        n_workers=4,
        generator_kwargs=generator_kwargs
    )

    with open(bs_preds_path, 'wb') as f:
        pickle.dump(bs_predictions, f, protocol=pickle.HIGHEST_PROTOCOL)

100%|██████████| 627/627 [06:02<00:00,  1.73it/s]


In [161]:
# len(clean_default_test_predictions), sum(bool(el) for el in clean_default_test_predictions)

(9373, 9193)

In [78]:
# NOTE(review): same statement as the earlier `vocab_set` cells; reloads the same file.
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [169]:
# from collections import defaultdict

# n_preds_in_line_dict = defaultdict(int)

# for line in clean_test_predictions:
#     n_preds_in_line_dict[len(line)] += 1

# print(n_preds_in_line_dict)

defaultdict(<class 'int'>, {4: 7404, 3: 1032, 2: 840, 1: 577, 0: 147})


In [170]:
# Load raw test predictions saved by earlier checkpoints.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
old_preds_path = os.path.join(DATA_ROOT, "test_raw_pred___best_model__2023_11_04__18_31_37__0.02530_default_switch_2.pt__best_model__2023_11_05__07_55_13__0.02516_extra_switch_2__with_pad_cutting.pt.pkl")
with open(old_preds_path, 'rb') as f:
    old_preds_list = pickle.load(f)

In [171]:
# Drop the scores, keeping only the words.
# NOTE: rebinds `old_preds_list` in place, so re-running this cell without reloading will fail.
old_preds_list = remove_beamsearch_probs(old_preds_list)

In [172]:
# Split the old predictions using the vocabulary set; both returned values are kept.
old_preds_list_valid, old_preds_list_invalid = separate_out_vocab_all_crvs(old_preds_list, vocab_set)

In [176]:
# submission_name = "default__m1_bigger_13679__m1_v2__14229___extra__14301___with_baseline__beam.csv"
# out_path = rf"..\data\submissions\{submission_name}"
# create_submission(clean_test_baseline_augmented, out_path)

# Submission creation

id 1

```
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
        
    ],
    'extra': ["m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl"]
}
```

id 2

```
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
        
    ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt.pkl"
    ]
}

# должны ранжироваться по качеству beamsearch на валидации
```

id 3

```python
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger_m1_bigger_v2_2023_11_12_14_51_49_0_13115_greed_acc_0_86034.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
        
    ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt.pkl"
    ]
}

# должны ранжироваться по качеству beamsearch на валидации
```

In [385]:
# Per-grid lists of cached TEST beam-search result files (the "id 3" configuration
# in the notes above). Rebinds the name used for the validation lists earlier.
# List order matters: the aggregation cell starts from the first file and appends the rest in order.
grid_name_to_ranged_bs_model_preds_paths = {
    'default': [
        "m1_bigger_m1_bigger_v2_2023_11_12_14_51_49_0_13115_greed_acc_0_86034.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
        
    ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt.pkl"
    ]
}

# Must be ranked by beam-search quality on validation

In [386]:
# Aggregate cached TEST beam-search predictions of several models per grid,
# then merge the two grids into one prediction list.
# Rebinds `default_idxs`, `extra_idxs`, `grid_name_to_augmented_preds` and `full_preds`
# from the validation aggregation cell.
default_idxs = test_default_dataset.grid_name_idxs
extra_idxs = test_extra_dataset.grid_name_idxs 

grid_name_to_augmented_preds = {}

for grid_name in ('default', 'extra'):
    bs_pred_list = []

    # Load every cached result file listed for this grid, in ranking order.
    # NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
    for f_name in grid_name_to_ranged_bs_model_preds_paths[grid_name]:
        f_path = os.path.join("../data/saved_beamsearch_results/", f_name)
        with open(f_path, 'rb') as f:
            bs_pred_list.append(pickle.load(f))
        
    # Per model: unwrap each entry -> drop scores -> vocabulary filter (first return value).
    bs_pred_list = [patch_wrong_prediction_shape(bs_preds) for bs_preds in bs_pred_list] 
    bs_pred_list = [remove_beamsearch_probs(bs_preds) for bs_preds in bs_pred_list]
    bs_pred_list = [separate_out_vocab_all_crvs(bs_preds, vocab_set)[0] for bs_preds in bs_pred_list]


    # Start from the first (best-ranked) model ...
    augmented_preds = bs_pred_list.pop(0)

    # ... and append the remaining models' predictions in list order.
    while bs_pred_list:
        augmented_preds = append_preds(augmented_preds, bs_pred_list.pop(0))

    grid_name_to_augmented_preds[grid_name] = augmented_preds


full_preds = merge_default_and_extra_preds(
    grid_name_to_augmented_preds['default'],
    grid_name_to_augmented_preds['extra'],
    default_idxs,
    extra_idxs)

In [387]:
# Histogram: number of candidates in a prediction line -> how many lines have that many.
from collections import defaultdict

n_preds_in_line_dict = defaultdict(int)

for line in full_preds:
    n_preds_in_line_dict[len(line)] += 1

print(n_preds_in_line_dict)

defaultdict(<class 'int'>, {4: 8188, 3: 706, 2: 606, 1: 407, 0: 93})


In [388]:
# Baseline candidates from the sample submission: every line is split on commas.
# The path is built with os.path.join so it also resolves on non-Windows systems.
sample_submission_path = os.path.join("..", "data", "submissions", "sample_submission.csv")
with open(sample_submission_path, 'r', encoding = 'utf-8') as f:
    baseline_lines = f.read().splitlines()
# Split the lines read in THIS cell. The previous version iterated over
# `augment_lines`, a variable left over from the greedy-submission section,
# so the cell depended on hidden kernel state.
baseline_preds = [line.split(",") for line in baseline_lines]

In [389]:
# Extend the aggregated predictions with the baseline candidates.
# NOTE(review): rebinds `full_preds` from its own previous value, so re-running this cell
# calls append_preds on already-augmented data -- confirm that is harmless.
full_preds = append_preds(full_preds, baseline_preds)

In [390]:
# Write the final submission file (plain string: the f-prefix had no placeholders).
# NOTE(review): unlike the greedy submission name above, this path has no ".csv"
# extension -- confirm create_submission does not need one.
create_submission(full_preds,
                  "../data/submissions/id3_with_baseline_without_old_preds")

In [363]:
# Empty placeholder list; nothing in the visible notebook fills or reads it.
full_preds_augmentations = [
    
]