In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from typing import List, Tuple, Callable

import torch
import numpy as np

# from model import get_m1_model, get_m1_bigger_model, get_m1_smaller_model
from model import MODEL_GETTERS_DICT
from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1
from dataset import CurveDataset, CurveDatasetSubset
from word_generators_v2 import GreedyGenerator, BeamGenerator, WordGenerator
from metrics import get_mmr
from feature_extraction.feature_extractors import weights_function_v1
from feature_extraction.feature_extractors import get_val_transform

In [4]:
IN_KAGGLE = False

if IN_KAGGLE:
    DATA_ROOT = "/kaggle/input/yandex-cup-playground"
    MODELS_DIR = ""
else:
    DATA_ROOT = "../data/data_separated_grid"
    MODELS_ROOT = "../data/trained_models_for_final_submit"

In [5]:
VAL_PATH = os.path.join(DATA_ROOT, "valid__in_train_format.jsonl")
TEST_PATH = os.path.join(DATA_ROOT, "test.jsonl")

VOCAB_PATH = os.path.join(DATA_ROOT, "voc.txt")
GRID_NAME_TO_GRID_PATH = os.path.join(DATA_ROOT, "gridname_to_grid.json")

In [6]:
kb_tokenizer = KeyboardTokenizerv1()
char_tokenizer = CharLevelTokenizerv2(VOCAB_PATH)

data_paths = [VAL_PATH, TEST_PATH]

# takes almost no time
dist_transform_v1 = get_val_transform(
    gridname_to_grid_path=GRID_NAME_TO_GRID_PATH,
    grid_names=('default', 'extra'),
    transform_name="traj_feats_and_distance_weights",
    char_tokenizer=char_tokenizer,
    dist_weights_func=weights_function_v1,
    include_time=False,
    include_velocities=True,
    include_accelerations=True
)

# takes a lot of time
kb_transform = get_val_transform(
    gridname_to_grid_path=GRID_NAME_TO_GRID_PATH,
    grid_names=('default', 'extra'),
    transform_name="traj_feats_and_nearest_key",
    char_tokenizer=char_tokenizer,
    uniform_noise_range=0,
    ds_paths_list=data_paths,
    totals = [10_000, 10_000],
    include_time=False,
    include_velocities=True,
    include_accelerations=True
)

Accumulating out-of-bounds coordinates...


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

augmenting gname_to_out_of_bounds


In [7]:
val_dataset = CurveDataset(
    data_path=VAL_PATH,
    store_gnames=True,
    init_transform=None,
    # get_item_transform=kb_transform,
    get_item_transform=dist_transform_v1,
)

test_dataset = CurveDataset(
    data_path=TEST_PATH,
    store_gnames=True,
    init_transform=None,
    # get_item_transform=kb_transform,
    get_item_transform=dist_transform_v1,
)

10000it [00:00, 19441.20it/s]
10000it [00:00, 19562.92it/s]


In [8]:
val_default_dataset = CurveDatasetSubset(val_dataset, "default")
val_extra_dataset = CurveDatasetSubset(val_dataset, "extra")

test_default_dataset = CurveDatasetSubset(test_dataset, "default")
test_extra_dataset = CurveDatasetSubset(test_dataset, "extra")

In [9]:
def get_targets(dataset: CurveDataset, 
                char_tokenizer: CharLevelTokenizerv2) -> List[str]:
    targets = []
    for _, target_tokens in dataset:
        target = char_tokenizer.decode(target_tokens[:-1])
        targets.append(target)
    return targets

In [10]:
def get_vocab_set(vocab_path: str):
    with open(vocab_path, 'r', encoding = "utf-8") as f:
        return set(f.read().splitlines())

In [11]:
vocab_set = get_vocab_set(os.path.join(DATA_ROOT, "voc.txt"))

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
MODEL_GETTERS_DICT.keys()

dict_keys(['v3_weighted_and_traj_transformer_bigger', 'v3_nearest_and_traj_transformer_bigger', 'v3_nearest_only_transformer_bigger', 'v3_trainable_gaussian_weights_and_traj_transformer_bigger', 'm1', 'm1_bigger', 'm1_smaller'])

In [14]:
val_default_targets = get_targets(val_default_dataset, char_tokenizer)

In [15]:
grid_name = "default"
model_getter = MODEL_GETTERS_DICT['v3_weighted_and_traj_transformer_bigger']
weights_path = r"../results/models_for_debug/weighted_transformer_bigger-default--epoch=60-val_loss=0.442-val_word_level_accuracy=0.875.pt"

# model_getter = MODEL_GETTERS_DICT['v3_nearest_and_traj_transformer_bigger']
# weights_path = r"..\results\models_for_debug\my_features_1\v3_nearest_and_traj_transformer_bigger-default--epoch=32-val_loss=0.441-val_word_level_accuracy=0.864.pt"

model = model_getter(device, weights_path).eval()
grid_name_to_greedy_generator = {grid_name: GreedyGenerator(model, char_tokenizer, device)}

In [16]:
def remove_suffix(s: str, suffix: str) -> str:
    if s.endswith(suffix):
        return s[:-len(suffix)]
    return s

In [17]:
n_classes = 35

greedy_generator__no_vocab = GreedyGenerator(model, char_tokenizer, device)
beam_generator__no_vocab = BeamGenerator(model, char_tokenizer, device)

greedy_generator_with_vocab = GreedyGenerator(model, char_tokenizer, device, vocab_set, max_token_id=n_classes-1)
beam_generator__with_vocab = BeamGenerator(model, char_tokenizer, device, vocab=vocab_set, max_token_id=n_classes-1)

In [18]:
def denormalize_beamsearch(score, pred_len, normalization_factor):
    return score * (pred_len + 1)**normalization_factor

def beamsearch_score_to_prob(score, pred_len, normalization_factor):   
    return np.exp(denormalize_beamsearch(-score, pred_len, normalization_factor))

In [19]:
MAX_WORD_LEN = 35

def predict(word_generator: WordGenerator, dataset, n_hypotheses, scores_to_prob: Callable, 
            max_steps_n: int = MAX_WORD_LEN, verbose = True, 
            n_examples = None, generator_call_kwargs = None, ) -> List[List[Tuple[str, float]]]:
    
    curve_id_to_hypotheses = []

    n_examples = n_examples or len(dataset)
    generator_call_kwargs = {} if generator_call_kwargs is None else generator_call_kwargs
    

    if verbose:
        print(("{:<22}" + "{:<29}" * n_hypotheses).format("target", *[f"pred{i}" for i in range(1, n_hypotheses+1)]))
        print("-"*(15+30*n_hypotheses))

    for i, data in enumerate(val_default_dataset):
        if i >= n_examples:
            return curve_id_to_hypotheses

        (encoder_in, dec_in), target = data

        scores_and_preds_full = word_generator(encoder_in, max_steps_n = max_steps_n, 
            **generator_call_kwargs)

        preds_and_probs_full = [(pred, scores_to_prob(pred, score)) 
                                for score, pred in scores_and_preds_full]

        curve_id_to_hypotheses.append(preds_and_probs_full)

        true_label = char_tokenizer.decode(target[:-1])

        flat_preds_and_scores = (item for pair in preds_and_probs_full[:n_hypotheses] for item in pair)
        if verbose:
            print(("{:<15}   |   " + "{:<16}{:.4f}   |   " *n_hypotheses ).format(true_label, *flat_preds_and_scores))
    
    return curve_id_to_hypotheses

In [20]:
def remove_scores_from_results(swipe_id_to_hypotheses_lst: list) -> list:
    return [[pred for pred, score in item_hypothses_lst] for item_hypothses_lst in swipe_id_to_hypotheses_lst]

In [21]:
N_EXAMPLES = 40

In [22]:
curve_id_to_hypotheses = predict(
    greedy_generator_with_vocab, val_dataset, n_hypotheses=1, 
    scores_to_prob=lambda pred, score: np.exp(-score), n_examples=N_EXAMPLES)

target                pred1                        
---------------------------------------------
на                |   на              0.8950   |   
все               |   все             0.8705   |   
добрый            |   добрый          0.8608   |   
девочка           |   девочка         0.8550   |   
сказала           |   сказала         0.8633   |   
скинь             |   скинь           0.8786   |   
геев              |   геев            0.8842   |   
тобой             |   тобой           0.8880   |   
была              |   баса            0.4085   |   
да                |   да              0.8957   |   
муж               |   муж             0.8350   |   
щас               |   щас             0.9516   |   
она               |   она             0.9011   |   
проблема          |   проблема        0.8468   |   
билайн            |   билайн          0.8468   |   
уже               |   уже             0.9064   |   
раньше            |   раньше          0.8657   |   
рам               

In [23]:
get_mmr(
    remove_scores_from_results(curve_id_to_hypotheses), 
    val_default_targets[:N_EXAMPLES])

0.925

In [24]:
curve_id_to_hypotheses = predict(
    greedy_generator__no_vocab, val_dataset, n_hypotheses=1, 
    scores_to_prob=lambda pred, score: np.exp(-score), n_examples=N_EXAMPLES)

target                pred1                        
---------------------------------------------
на                |   на              0.8729   |   
все               |   все             0.8297   |   
добрый            |   добрый          0.7269   |   
девочка           |   девочка         0.6857   |   
сказала           |   сказала         0.7000   |   
скинь             |   скинь           0.7596   |   
геев              |   геев            0.7808   |   
тобой             |   тобой           0.7647   |   
была              |   баса            0.3738   |   
да                |   да              0.8705   |   
муж               |   муж             0.7908   |   
щас               |   щас             0.8447   |   
она               |   она             0.8367   |   
проблема          |   проблема        0.6630   |   
билайн            |   билайн          0.7044   |   
уже               |   уже             0.8313   |   
раньше            |   раньше          0.7320   |   
рам               

In [25]:
get_mmr(
    remove_scores_from_results(curve_id_to_hypotheses), 
    val_default_targets[:N_EXAMPLES])

0.925

In [26]:
n_examples = N_EXAMPLES
normalization_factor = 0.5
beamsize = 6

curve_id_to_hypotheses = predict(
    beam_generator__with_vocab, val_dataset, n_hypotheses=4, n_examples=n_examples,
    generator_call_kwargs={'normalization_factor': normalization_factor, 'beamsize': beamsize}, 
    scores_to_prob =lambda pred, score: beamsearch_score_to_prob(score, len(pred) + 1, normalization_factor))

target                pred1                        pred2                        pred3                        pred4                        
---------------------------------------------------------------------------------------------------------------------------------------
на                |   на              0.8950   |   нан             0.0014   |   нам             0.0013   |   нас             0.0013   |   
все               |   все             0.8705   |   всенародная     0.0002   |   всенародно      0.0001   |   всенародную     0.0001   |   
добрый            |   добрый          0.8608   |   добрый-добрый   0.0011   |   добрым          0.0039   |   добрые          0.0019   |   
девочка           |   девочка         0.8550   |   девочка-волшебница0.0002   |   девочки         0.0029   |   девочку         0.0025   |   
сказала           |   сказала         0.8633   |   сказал-сделал   0.0008   |   сказал          0.0037   |   сказали         0.0026   |   
скинь             |   скинь 

In [27]:
get_mmr(
    remove_scores_from_results(curve_id_to_hypotheses), 
    val_default_targets[:N_EXAMPLES])

0.9295

In [28]:
n_examples = N_EXAMPLES
normalization_factor = 0.5

curve_id_to_hypotheses = predict(
    beam_generator__no_vocab, val_dataset, n_hypotheses=4, n_examples=n_examples,
    generator_call_kwargs={'normalization_factor': normalization_factor, 'beamsize': beamsize}, 
    scores_to_prob =lambda pred, score: beamsearch_score_to_prob(score, len(pred) + 1, normalization_factor))

target                pred1                        pred2                        pred3                        pred4                        
---------------------------------------------------------------------------------------------------------------------------------------
на                |   на              0.8729   |   нан             0.0014   |   нак             0.0013   |   наз             0.0012   |   
все               |   все             0.8297   |   всем            0.0015   |   вссе            0.0012   |   всен            0.0012   |   
добрый            |   добрый          0.7269   |   добрым          0.0032   |   добрыйо         0.0011   |   добрыйи         0.0010   |   
девочка           |   девочка         0.6857   |   девочки         0.0023   |   девочку         0.0020   |   девочкам        0.0013   |   
сказала           |   сказала         0.7000   |   сказал          0.0031   |   сказали         0.0021   |   сказалась       0.0009   |   
скинь             |   скинь   

In [29]:
get_mmr(
    remove_scores_from_results(curve_id_to_hypotheses), 
    val_default_targets[:N_EXAMPLES])

0.9272500000000001

In [30]:
print(curve_id_to_hypotheses[0][:6])
print()
print(curve_id_to_hypotheses[1][:6])

[('на', 0.8729091854152667), ('нан', 0.0013687703829175547), ('нак', 0.0012917311480830328), ('наз', 0.001240692764716203), ('нас', 0.0012353012520421001), ('нам', 0.0012299431799445055)]

[('все', 0.8296965372726712), ('всем', 0.0014837087743218484), ('вссе', 0.001230527051960393), ('всен', 0.001221845670644851), ('всел', 0.0011628419218424532), ('всев', 0.0011594028348149632)]


---------------
Как и ожидается, после денормализации бимсерча вероятности не такие же как в greedy

Нужно отметить, что в общем случае должно быть так:
* greedy_search_no_vocab__results == beamsearch__no_vocab__results
* greedy_search_with_voc__results == beamsearch__with_voc__results
* greedy_search_with_voc__results != beamsearch__no_vocab__results
* greedy_search__no_vocab__results != beamsearch__with_voc__results



Некоторые заметки:
1. Иногда вероятность оказывается немонотонной. Это именно из-за того, что мы убрали нормализацию бимсерча. То есть убрали штраф за краткость, а при ранжировании он был
2. Даже если beamsize = n_classes, beamsearch не оценивает явно вероятность всех слов, потому что из-за over confidence вероятности некоторых возможных токенов оказываются нулевыми и эта ветка обрывается
---------------

# Extra: Batched Greedy Search example


Batched greedy search Pros:
* produces the same results as unbatched variant (make sure to compare with no_vocab version)
* works faster
Batched greedy search Cons:
* doesn't support vocab masking yet
* has a different interface

In [31]:
from torch.utils.data import DataLoader

from dataset import CollateFnV2
from word_generators_v2 import GreedyGeneratorBatched

In [32]:
collate_fn = CollateFnV2(batch_first=False, word_pad_idx=char_tokenizer.char_to_idx['<pad>'])
val_default_dataloader = DataLoader(val_default_dataset, batch_size=N_EXAMPLES, shuffle=False, collate_fn=collate_fn)

In [33]:
greedy_generator_batched__no_vocab = GreedyGeneratorBatched(model, char_tokenizer, device)

In [34]:
MAX_WORD_LEN = 35


def remove_tokens(tensor, tokens_to_remove):
    return tensor[~torch.isin(tensor, torch.tensor(tokens_to_remove))]

def predict_batched(word_generator: WordGenerator, dataloader, max_steps: int = MAX_WORD_LEN, 
                    verbose: bool = True, n_batches: int = None, generator_call_kwargs = None, 
                    ):
    n_hypotheses = 1

    generator_call_kwargs = {} if generator_call_kwargs is None else generator_call_kwargs

    pad_token_id = word_generator.tokenizer.char_to_idx['<pad>']
    eos_token_id = word_generator.tokenizer.char_to_idx['<eos>']
    sos_token_id = word_generator.tokenizer.char_to_idx['<sos>']



    if verbose:
        print(("{:<22}" + "{:<29}" * n_hypotheses).format("target", *[f"pred{i}" for i in range(1, n_hypotheses+1)]))
        print("-"*(15+30*n_hypotheses))

    for i, ((encoder_in, _, encoder_pad_mask, _), target) in enumerate(dataloader):
        target = target.T
        if n_batches is not None and i >= n_batches:
            break

        with torch.no_grad():
            char_token_ids, log_probs = word_generator(encoder_in, encoder_pad_mask, max_steps, **generator_call_kwargs)
            char_token_ids = char_token_ids.T

            pred_words = [char_tokenizer.decode(remove_tokens(char_token_ids[i], [pad_token_id, eos_token_id, sos_token_id])) 
                          for i in range(char_token_ids.size(0))]
            target_words = [char_tokenizer.decode(remove_tokens(target[i], [pad_token_id, eos_token_id, sos_token_id]))
                            for i in range(target.size(0))]
            log_probs = log_probs.cpu().numpy()
            probs = np.exp(log_probs)

            if verbose:
                for true_label, pred_word, log_prob in zip(target_words, pred_words, probs):
                    print(("{:<15}   |   " + "{:<16}{:.4f}   |   " *n_hypotheses ).format(true_label, pred_word, log_prob))
        


In [35]:
predict_batched(greedy_generator_batched__no_vocab, val_default_dataloader, n_batches=1)

target                pred1                        
---------------------------------------------
на                |   на              0.8729   |   
все               |   все             0.8297   |   
добрый            |   добрый          0.7269   |   
девочка           |   девочка         0.6857   |   
сказала           |   сказала         0.7000   |   
скинь             |   скинь           0.7596   |   
геев              |   геев            0.7808   |   
тобой             |   тобой           0.7647   |   
была              |   баса            0.3738   |   
да                |   да              0.8705   |   
муж               |   муж             0.7908   |   
щас               |   щас             0.8447   |   
она               |   она             0.8367   |   
проблема          |   проблема        0.6630   |   
билайн            |   билайн          0.7044   |   
уже               |   уже             0.8313   |   
раньше            |   раньше          0.7320   |   
рам               