In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules
# before each cell runs, so edits to the aggregate_predictions module are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json

from aggregate_predictions import (aggregate_preds_raw_weighted,
                                   aggregate_preds_raw_appendage,
                                   get_default_and_extra_idxs, 
                                   load_preds_to_aggregate,
                                   get_vocab_set,
                                   merge_default_and_extra_preds,
                                   load_baseline_preds,
                                   create_submission,
                                   append_preds,
                                   delete_duplicates_stable)                    

In [3]:
# Move the working directory to the parent folder; the relative "data/..."
# paths used in the following cells are resolved against it.
# NOTE(review): not idempotent - every re-run of this cell moves one more
# level up, so execute it exactly once per kernel session.
%cd ..

c:\Users\proshian\Documents\yandex_cup_2023_ml_neuroswipe


In [4]:
DATA_ROOT = "data/data_separated_grid/"

# Prediction files (stored under data/saved_beamsearch_results/) to aggregate,
# keyed by grid name.
# NOTE(review): the order inside each list presumably matters to the
# aggregation ("ranged" ~ ranked) - confirm against aggregate_predictions.
grid_name_to_ranged_preds_names = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_12__14_51_49__0.13115__greed_acc_0.86034__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
    ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt.pkl"
    ]
}

vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

default_idxs, extra_idxs = get_default_and_extra_idxs(
    os.path.join(DATA_ROOT, "test.jsonl"))

# Aggregate the prediction files of each grid separately.
grid_name_to_aggregated_preds = {}

for grid_name in ('default', 'extra'):
    f_names = grid_name_to_ranged_preds_names[grid_name]
    f_paths = [os.path.join("data/saved_beamsearch_results/", f_name)
               for f_name in f_names]

    preds_to_aggregate = load_preds_to_aggregate(f_paths)

    aggregated_preds = aggregate_preds_raw_appendage(
        preds_to_aggregate,
        vocab_set,
        limit=4)

    grid_name_to_aggregated_preds[grid_name] = aggregated_preds

# Combine the per-grid results using the index lists read from test.jsonl.
full_preds = merge_default_and_extra_preds(
    grid_name_to_aggregated_preds['default'],
    grid_name_to_aggregated_preds['extra'],
    default_idxs,
    extra_idxs)

# The path is assembled with os.path.join instead of a backslash-separated
# literal so that it also resolves on non-Windows systems.
baseline_preds = load_baseline_preds(
    os.path.join("data", "submissions", "baseline.csv"))
full_preds = append_preds(full_preds, baseline_preds, limit=4)

create_submission(
    full_preds,
    "data/submissions/id3_with_baseline_without_old_preds_check.csv")


In [5]:
def patch_wrong_prediction_shape(prediciton):
    """Return a list holding the first item of every element of `prediciton`.

    Each element must be indexable; only its item at index 0 is kept and the
    order of the elements is preserved.
    """
    first_items = []
    for wrapped in prediciton:
        first_items.append(wrapped[0])
    return first_items

In [6]:
# Validation prediction files whose outputs are combined below; the last
# candidate is intentionally left commented out.
f_names = [
    "m1_bigger__m1_bigger_v2__2023_11_12__14_51_49__0.13115__greed_acc_0.86034__default_l2_0_ls0_switch_2.pt.pkl",
    "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
    "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
    "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
    "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
    # "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
]

f_paths = [
    os.path.join("data/saved_beamsearch_validation_results/", file_name)
    for file_name in f_names
]

# Load every file, then keep only the first item of each element of every
# loaded prediction (see patch_wrong_prediction_shape).
preds_to_aggregate = load_preds_to_aggregate(f_paths)
preds_to_aggregate = list(map(patch_wrong_prediction_shape, preds_to_aggregate))

# vocab_set is defined in the cell that builds the test submission, so that
# cell has to be executed first.
aggregated_preds = aggregate_preds_raw_appendage(
    preds_to_aggregate, vocab_set, limit=4)

In [7]:
def get_mmr(preds_list, ref, weights=(1, 0.1, 0.09, 0.08)):
    """Return the mean position-weighted hit score of predictions vs. targets.

    For every (preds, target) pair the predictions are de-duplicated with
    ``delete_duplicates_stable``; each remaining prediction that equals the
    target adds the weight of its position to the line score. The line scores
    are summed and divided by ``len(preds_list)``.

    Args:
        preds_list: Sized sequence holding one collection of candidate
            predictions per line.
        ref: Iterable of targets, paired with ``preds_list`` by position.
        weights: Score awarded for a hit at each position. Candidates beyond
            ``len(weights)`` are ignored; lines with fewer candidates than
            weights are scored on the candidates they have.

    Returns:
        The mean score; ``0.0`` when ``preds_list`` is empty.
    """
    n_lines = len(preds_list)
    if n_lines == 0:
        # Nothing to score; also avoids dividing by zero below.
        return 0.0

    total = 0
    for preds, target in zip(preds_list, ref):
        unique_preds = delete_duplicates_stable(preds)
        # zip() stops at the shorter input, so surplus candidates are skipped
        # instead of indexing past the end of `weights`.
        total += sum(weight * (pred == target)
                     for weight, pred in zip(weights, unique_preds))

    return total / n_lines

In [8]:
def get_targets(f_path,
                condition = lambda x: True):
    """Collect the ``'word'`` field from the selected lines of a JSON Lines file.

    Args:
        f_path: Path to a UTF-8 text file holding one JSON object per line.
        condition: Predicate called with the raw line text (not the parsed
            object); only lines for which it returns a truthy value are
            parsed. Accepts every line by default.

    Returns:
        List with the ``'word'`` value of every selected line, in file order.

    Raises:
        json.JSONDecodeError: If a selected non-blank line is not valid JSON.
        KeyError: If a selected object has no ``'word'`` key.
    """
    targets = []
    with open(f_path, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of materializing it with readlines().
        for line in f:
            if not line.strip():
                # Blank lines carry no JSON object and would make json.loads
                # fail, so they are skipped.
                continue
            if condition(line):
                targets.append(json.loads(line)['word'])

    return targets

In [9]:
# Targets (the 'word' field of every line) read from the default-only
# validation file under DATA_ROOT.
val_default_targets = get_targets(os.path.join(DATA_ROOT, "valid__in_train_format__default_only.jsonl"))

In [10]:
# Score aggregated_preds against the validation targets. aggregated_preds is
# whatever the most recently executed aggregation cell produced.
print(get_mmr(aggregated_preds, val_default_targets))

0.8958517417162353


In [16]:
# Reload the validation predictions and keep only the first item of each
# element (see patch_wrong_prediction_shape).
preds_to_aggregate = load_preds_to_aggregate(f_paths)
preds_to_aggregate = list(map(patch_wrong_prediction_shape, preds_to_aggregate))

# NOTE(review): four weights are passed here although f_names currently lists
# five files, while the cells that slice preds_to_aggregate pass one weight
# per prediction set - confirm how aggregate_preds_raw_weighted pairs the
# weights with its inputs.
aggregated_preds = aggregate_preds_raw_weighted(
    preds_to_aggregate, [1, 0.1, 0.09, 0.08], vocab_set, limit=4)

In [17]:
# Score aggregated_preds against the validation targets. aggregated_preds is
# whatever the most recently executed aggregation cell produced.
print(get_mmr(aggregated_preds, val_default_targets))

0.8982625318606707


In [28]:
# Reload the validation predictions and keep only the first item of each
# element (see patch_wrong_prediction_shape).
preds_to_aggregate = load_preds_to_aggregate(f_paths)
preds_to_aggregate = list(map(patch_wrong_prediction_shape, preds_to_aggregate))

# Only the first loaded prediction set is passed on, with a single weight.
aggregated_preds = aggregate_preds_raw_weighted(
    preds_to_aggregate[:1], [1], vocab_set, limit=4)

In [29]:
# Score aggregated_preds against the validation targets. aggregated_preds is
# whatever the most recently executed aggregation cell produced.
print(get_mmr(aggregated_preds, val_default_targets))

0.8929800339847141


In [72]:
# Reload the validation predictions and keep only the first item of each
# element (see patch_wrong_prediction_shape).
preds_to_aggregate = load_preds_to_aggregate(f_paths)
preds_to_aggregate = list(map(patch_wrong_prediction_shape, preds_to_aggregate))

# Only the first two loaded prediction sets are passed on, with two weights.
aggregated_preds = aggregate_preds_raw_weighted(
    preds_to_aggregate[:2], [100, 95], vocab_set, limit=4)

In [73]:
# Score aggregated_preds against the validation targets. aggregated_preds is
# whatever the most recently executed aggregation cell produced.
print(get_mmr(aggregated_preds, val_default_targets))

0.8962956669498799


In [None]:
# Prediction file names, keyed by grid name.
grid_name_to_f_names = {
    'default': [
        "m1_bigger__m1_bigger_v2__2023_11_12__14_51_49__0.13115__greed_acc_0.86034__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__12_30_29__0.13121__greed_acc_0.86098__default_l2_0_ls0_switch_2.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__22_18_35__0.13542_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_v2__m1_v2__2023_11_09__10_36_02__0.14229_default_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__00_39_33__0.13297_default_l2_0_ls0_switch_1.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_11__14_29_37__0.13679_default_l2_0_ls0_switch_0.pt.pkl",
    ],
    'extra': [
        "m1_v2__m1_v2__2023_11_09__17_47_40__0.14301_extra_l2_1e-05_switch_0.pt.pkl",
        "m1_bigger__m1_bigger_v2__2023_11_12__02_27_14__0.13413_extra_l2_0_ls0_switch_1.pt.pkl"
    ]
}

# NOTE(review): the variable is named "test" but points at the *validation*
# results directory, whereas the submission cell reads its predictions from
# data/saved_beamsearch_results/ - confirm which directory is intended.
test_preds_root = "data/saved_beamsearch_validation_results/"

# Same mapping as above, with every file name prefixed by test_preds_root.
grid_name_to_test_f_paths = {
    name: [os.path.join(test_preds_root, file_name) for file_name in file_names]
    for name, file_names in grid_name_to_f_names.items()
}


In [None]:
# Load the 'default' grid prediction files listed in grid_name_to_test_f_paths.
# NOTE(review): unlike the validation cells, patch_wrong_prediction_shape is
# not applied here - confirm these files already have the expected shape.
preds_to_aggregate = load_preds_to_aggregate(
    grid_name_to_test_f_paths["default"])

# NOTE(review): four weights are passed although the 'default' grid lists six
# files - confirm how aggregate_preds_raw_weighted pairs weights with inputs.
aggregated_default_preds = aggregate_preds_raw_weighted(
    preds_to_aggregate, [1, 0.1, 0.09, 0.08], vocab_set, limit=4)