In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
from typing import List, Tuple, Callable

import torch
import numpy as np

# from model import get_m1_model, get_m1_bigger_model, get_m1_smaller_model
from model import _get_transformer__vn1
from ns_tokenizers import CharLevelTokenizerv2, KeyboardTokenizer
from dataset import SwipeDataset, SwipeDatasetSubset
from word_generators import GreedyGenerator, BeamGenerator, WordGenerator
from metrics import get_mmr
from feature_extraction.swipe_feature_extractor_factory import swipe_feature_extractor_factory
from logit_processors import VocabularyLogitProcessor
from model import get_transformer__from_spe_config__vn1


In [3]:
# Directory with the preprocessed data files joined onto it below; MODELS_DIR is
# not referenced in the visible cells (NOTE(review): remove if unused elsewhere).
DATA_ROOT, MODELS_DIR = "../data/data_preprocessed", ""

In [4]:
# --- Files inside DATA_ROOT -------------------------------------------------
VAL_PATH = os.path.join(DATA_ROOT, "valid.jsonl")
TEST_PATH = os.path.join(DATA_ROOT, "test.jsonl")
VOCAB_PATH = os.path.join(DATA_ROOT, "voc.txt")
GRID_NAME_TO_GRID_PATH = os.path.join(DATA_ROOT, "gridname_to_grid.json")
TRAJECTORY_FEATURES_STATISTICS_PATH = os.path.join(DATA_ROOT, "trajectory_features_statistics.json")
BOUNDING_BOXES_PATH = os.path.join(DATA_ROOT, "key_bounding_boxes.json")

# --- Tokenizer and component configs (outside DATA_ROOT) --------------------
TOKENIZER_PATH = "../tokenizers/keyboard/ru.json"
WEIGHTED_SWIPE_FEATURE_EXTRACTOR_CONFIG_PATH = "../configs/feature_extractor/traj_and_weights_v1.json"
TRAJ_AND_NEAREST_SWIPE_FEATURE_EXTRACTOR_CONFIG_PATH = "../configs/feature_extractor/traj_and_nearest.json"
TRAJ_AND_NEAREST_SWIPE_POINT_EMBEDDER_CONFIG_PATH = "../configs/swipe_point_embedder/separate_traj_and_nearest__6_coord.json"

In [5]:
def read_json(path: str):
    """Parse the UTF-8 encoded JSON file at ``path`` and return the decoded object."""
    with open(path, encoding="utf-8") as json_file:
        parsed = json.load(json_file)
    return parsed

In [6]:
kb_tokenizer = KeyboardTokenizer(TOKENIZER_PATH)
subword_tokenizer = CharLevelTokenizerv2(VOCAB_PATH)

data_paths = [VAL_PATH, TEST_PATH]

gridname_to_grid = read_json(GRID_NAME_TO_GRID_PATH)
trajectory_features_statistics = read_json(TRAJECTORY_FEATURES_STATISTICS_PATH)
bounding_boxes = read_json(BOUNDING_BOXES_PATH)

# One swipe feature extractor per grid in gridname_to_grid, built from the
# trajectory + nearest-key config. For the weighted variant, pass
# read_json(WEIGHTED_SWIPE_FEATURE_EXTRACTOR_CONFIG_PATH) as component_configs.
# The config is re-read for every grid so each extractor gets its own copy of it.
grid_name_to_traj_and_nearest_swipe_feature_extractor = {
    grid_name: swipe_feature_extractor_factory(
        grid=grid,
        keyboard_tokenizer=kb_tokenizer,
        trajectory_features_statistics=trajectory_features_statistics,
        bounding_boxes=bounding_boxes,
        grid_name=grid_name,
        component_configs=read_json(TRAJ_AND_NEAREST_SWIPE_FEATURE_EXTRACTOR_CONFIG_PATH)
    )
    for grid_name, grid in gridname_to_grid.items()
}

In [7]:
# Both splits share the word tokenizer and the per-grid traj+nearest feature extractors.
_swipe_dataset_kwargs = dict(
    word_tokenizer=subword_tokenizer,
    grid_name_to_swipe_feature_extractor=grid_name_to_traj_and_nearest_swipe_feature_extractor,
)
val_dataset = SwipeDataset(data_path=VAL_PATH, **_swipe_dataset_kwargs)
test_dataset = SwipeDataset(data_path=TEST_PATH, **_swipe_dataset_kwargs)

0it [00:00, ?it/s]

10000it [00:00, 55067.57it/s]
10000it [00:00, 41164.82it/s]
10000it [00:00, 70957.72it/s]
10000it [00:00, 71013.59it/s]


In [8]:
# Split each dataset by grid name; presumably these match keys of gridname_to_grid — TODO confirm.
val_default_dataset, val_extra_dataset = (
    SwipeDatasetSubset(val_dataset, grid) for grid in ("default", "extra")
)
test_default_dataset, test_extra_dataset = (
    SwipeDatasetSubset(test_dataset, grid) for grid in ("default", "extra")
)

In [9]:
def get_targets(dataset: SwipeDataset, 
                subword_tokenizer: CharLevelTokenizerv2) -> List[str]:
    """Decode the target word of every item in ``dataset``, in iteration order.

    Each item is unpacked as ``(model_inputs, target_tokens)``. The last target
    token is dropped before decoding (presumably an end-of-sequence marker —
    TODO confirm against SwipeDataset).
    """
    return [subword_tokenizer.decode(token_ids[:-1]) for _, token_ids in dataset]

In [10]:
def get_vocab_set(vocab_path: str):
    """Return the set of distinct lines of the UTF-8 vocabulary file at ``vocab_path``."""
    with open(vocab_path, encoding="utf-8") as vocab_file:
        words = vocab_file.read().splitlines()
    return set(words)

In [11]:
# Reuse VOCAB_PATH (the same file the subword tokenizer is built from) instead of rebuilding the path by hand.
vocab_set = get_vocab_set(VOCAB_PATH)

In [12]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Ground-truth words of the default-grid validation subset, in dataset order.
val_default_targets = get_targets(dataset=val_default_dataset, subword_tokenizer=subword_tokenizer)

In [14]:
from utils.ckpt_to_pt import convert_and_save_file

In [15]:
ckpt_path = "../checkpoints/traj_and_nearest/epoch_end/v3_nearest_and_traj_transformer_bigger-default--epoch=36-val_loss=0.441-val_word_level_accuracy=0.872.ckpt"
# Mirror the checkpoint under a 'state_dicts' directory with a '.pt' extension.
# Only a path component that is exactly 'checkpoints' and only the file extension
# are rewritten. A plain str.replace would also mangle names that merely contain
# 'checkpoints' or '.ckpt'.
weights_path = "/".join(
    "state_dicts" if part == "checkpoints" else part
    for part in os.path.splitext(ckpt_path)[0].split("/")
) + ".pt"
# Never let the conversion overwrite the source checkpoint.
assert weights_path != ckpt_path, f"weights_path would overwrite the checkpoint: {weights_path}"
convert_and_save_file(ckpt_path, out_path=weights_path, device=device, make_dir_if_absent=True)

In [16]:
GRID_NAME = "default"  # NOTE(review): not referenced in the visible cells
NUM_CLASSES = 35
MAX_OUT_SEQ_LEN = 35

# Swipe-point-embedder ("spe") config for the trajectory + nearest-key features.
spe_config = read_json(TRAJ_AND_NEAREST_SWIPE_POINT_EMBEDDER_CONFIG_PATH)

# Build the transformer and load the state dict produced by the conversion cell.
model = get_transformer__from_spe_config__vn1(
    spe_config=spe_config,
    weights_path=weights_path,
    device=device,
    n_classes=NUM_CLASSES,
    max_out_seq_len=MAX_OUT_SEQ_LEN,
    # Size of the character tokenizer's vocabulary.
    n_word_tokens=len(subword_tokenizer.char_to_idx),
)



In [17]:
def remove_suffix(s: str, suffix: str) -> str:
    """Return ``s`` without a trailing ``suffix``; return ``s`` unchanged if it does not end with it.

    An empty ``suffix`` leaves ``s`` untouched. The guard is needed because
    ``s[:-0]`` is ``s[:0] == ""``, which would otherwise wipe the whole string.
    Equivalent to ``str.removesuffix`` (Python 3.9+).
    """
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    return s

In [18]:
# Keep the processor's token-id range in sync with the model's output size
# (NUM_CLASSES is what the model was built with) instead of repeating the literal 35.
n_classes = NUM_CLASSES
vocab_logits_processor = VocabularyLogitProcessor(subword_tokenizer, vocab_set, max_token_id=n_classes - 1)

In [19]:
BEAMSEARCH_NORMALIZATION_FACTOR = 0.5
BEAMSEARCH_BEAM_SIZE = 6

# Keyword arguments shared by every generator, plus the beam-search-only extras.
_greedy_kwargs = dict(max_steps_n=MAX_OUT_SEQ_LEN)
_beam_kwargs = dict(
    max_steps_n=MAX_OUT_SEQ_LEN,
    beamsize=BEAMSEARCH_BEAM_SIZE,
    normalization_factor=BEAMSEARCH_NORMALIZATION_FACTOR,
)

# Unconstrained decoding.
greedy_generator__no_vocab = GreedyGenerator(model, subword_tokenizer, device, **_greedy_kwargs)
beam_generator__no_vocab = BeamGenerator(model, subword_tokenizer, device, **_beam_kwargs)

# Decoding with vocabulary masking through the logit processor.
greedy_generator_with_vocab = GreedyGenerator(model, subword_tokenizer, device, logit_processor=vocab_logits_processor, **_greedy_kwargs)
beam_generator__with_vocab = BeamGenerator(model, subword_tokenizer, device, logit_processor=vocab_logits_processor, **_beam_kwargs)

In [20]:
def denormalize_beamsearch(score, pred_len, normalization_factor):
    """Undo beam-search length normalization.

    Returns ``score * (pred_len + 1) ** normalization_factor``.
    """
    length_scale = (pred_len + 1) ** normalization_factor
    return score * length_scale

def beamsearch_score_to_prob(score, pred_len, normalization_factor):
    """Turn a normalized beam-search score into a probability.

    Presumably ``score`` is a length-normalized negative log-probability (TODO
    confirm against BeamGenerator). The result is ``exp(-score * (pred_len + 1) ** normalization_factor)``.
    """
    log_prob = denormalize_beamsearch(-score, pred_len, normalization_factor)
    return np.exp(log_prob)

In [21]:
MAX_WORD_LEN = 35

def predict(word_generator: WordGenerator, dataset, n_hypotheses, scores_to_prob: Callable, 
            verbose = True, n_examples = None,) -> List[List[Tuple[str, float]]]:
    """Decode the first ``n_examples`` items of ``dataset`` one swipe at a time.

    Args:
        word_generator: called as ``word_generator(encoder_in)``; must return an
            iterable of ``(score, pred_word)`` pairs. They are kept (and printed)
            in the returned order, presumably best first.
        dataset: iterable of ``((encoder_in, dec_in), target_token_ids)`` items.
            These are the items that get decoded, so any targets you score the
            result against must come from this same dataset.
        n_hypotheses: number of hypotheses per swipe to print when ``verbose``.
        scores_to_prob: ``scores_to_prob(pred_word, score) -> float``.
        verbose: print a target / predictions table.
        n_examples: number of items to decode; ``None`` means the whole dataset.

    Returns:
        One list per decoded item with every ``(pred_word, prob)`` pair produced
        by the generator. The list is not truncated to ``n_hypotheses``.
    """
    if n_examples is None:
        n_examples = len(dataset)

    if verbose:
        print(("{:<22}" + "{:<29}" * n_hypotheses).format("target", *[f"pred{i}" for i in range(1, n_hypotheses+1)]))
        print("-"*(15+30*n_hypotheses))

    curve_id_to_hypotheses = []
    for i, data in enumerate(dataset):
        if i >= n_examples:
            break

        (encoder_in, _dec_in), target = data

        scores_and_preds_full = word_generator(encoder_in)
        preds_and_probs_full = [(pred, scores_to_prob(pred, score))
                                for score, pred in scores_and_preds_full]
        curve_id_to_hypotheses.append(preds_and_probs_full)

        if verbose:
            # The last target token is dropped before decoding, like in get_targets.
            true_label = subword_tokenizer.decode(target[:-1])
            shown = preds_and_probs_full[:n_hypotheses]
            # Format only the hypotheses that exist: the generator may return
            # fewer than n_hypotheses, which used to raise IndexError here.
            row_format = "{:<15}   |   " + "{:<16}{:.4f}   |   " * len(shown)
            print(row_format.format(true_label, *(item for pair in shown for item in pair)))

    return curve_id_to_hypotheses

In [22]:
def remove_scores_from_results(swipe_id_to_hypotheses_lst: list) -> list:
    """Strip the scores from per-swipe hypothesis lists.

    Each element of ``swipe_id_to_hypotheses_lst`` is a list of ``(pred, score)``
    pairs. The result keeps the same nesting with only the ``pred`` entries.
    """
    words_per_swipe = []
    for hypotheses in swipe_id_to_hypotheses_lst:
        words_per_swipe.append([word for word, _score in hypotheses])
    return words_per_swipe

In [23]:
N_EXAMPLES: int = 40  # number of validation swipes decoded in each comparison below

In [24]:
# Decode the default-grid validation subset: the MMR cell below scores the
# result against val_default_targets, which are built from that same subset.
# Greedy score -> probability via exp(-score); presumably the score is a
# negative log-probability.
curve_id_to_hypotheses = predict(
    greedy_generator_with_vocab, val_default_dataset, n_hypotheses=1,
    scores_to_prob=lambda pred, score: np.exp(-score), n_examples=N_EXAMPLES)

target                pred1                        
---------------------------------------------
на                |   на              0.8943   |   
все               |   все             0.8726   |   
добрый            |   добрый          0.8490   |   
девочка           |   девочка         0.8544   |   
сказала           |   сказала         0.8611   |   
скинь             |   скинь           0.8763   |   
геев              |   геев            0.6401   |   
тобой             |   тобой           0.8852   |   
была              |   быстра          0.5474   |   
да                |   да              0.8956   |   
муж               |   муж             0.8123   |   
щас               |   щас             0.9444   |   
она               |   она             0.8963   |   
проблема          |   проблема        0.8444   |   
билайн            |   билайн          0.7736   |   
уже               |   уже             0.9006   |   
раньше            |   раньше          0.8695   |   
рам               

In [25]:
# Ranking metric (presumably mean reciprocal rank) of the target among the hypotheses.
_hypothesis_words = remove_scores_from_results(curve_id_to_hypotheses)
_reference_words = val_default_targets[:N_EXAMPLES]
get_mmr(_hypothesis_words, _reference_words)

0.925

In [26]:
# Same default-grid subset that val_default_targets (used by the MMR cell) is built from.
curve_id_to_hypotheses = predict(
    greedy_generator__no_vocab, val_default_dataset, n_hypotheses=1, 
    scores_to_prob=lambda pred, score: np.exp(-score), n_examples=N_EXAMPLES)

target                pred1                        
---------------------------------------------
на                |   на              0.8725   |   
все               |   все             0.8322   |   
добрый            |   добрый          0.7148   |   
девочка           |   девочка         0.6902   |   
сказала           |   сказала         0.6949   |   
скинь             |   скинь           0.7548   |   
геев              |   геев            0.5662   |   
тобой             |   тобой           0.7656   |   
была              |   быса            0.7354   |   
да                |   да              0.8700   |   
муж               |   муж             0.7696   |   
щас               |   щас             0.8363   |   
она               |   она             0.8290   |   
проблема          |   проблема        0.6674   |   
билайн            |   билайн          0.6373   |   
уже               |   уже             0.8241   |   
раньше            |   раньше          0.7333   |   
рам               

In [27]:
# Ranking metric (presumably mean reciprocal rank) of the target among the hypotheses.
_hypothesis_words = remove_scores_from_results(curve_id_to_hypotheses)
_reference_words = val_default_targets[:N_EXAMPLES]
get_mmr(_hypothesis_words, _reference_words)

0.925

In [28]:
n_examples = N_EXAMPLES

# Same default-grid subset that val_default_targets (used by the MMR cell) is built from.
# Beam scores are de-normalized back into probabilities before display.
curve_id_to_hypotheses = predict(
    beam_generator__with_vocab, val_default_dataset, n_hypotheses=4, n_examples=n_examples, 
    scores_to_prob=lambda pred, score: beamsearch_score_to_prob(score, len(pred) + 1, BEAMSEARCH_NORMALIZATION_FACTOR))

target                pred1                        pred2                        pred3                        pred4                        
---------------------------------------------------------------------------------------------------------------------------------------
на                |   на              0.8943   |   наама           0.0013   |   наособицу       0.0002   |   наощупь         0.0002   |   
все               |   все             0.8726   |   всенародная     0.0002   |   всепожирающего  0.0001   |   все-все         0.0008   |   
добрый            |   добрый          0.8490   |   доброй          0.0119   |   добрый-добрый   0.0013   |   добрые          0.0028   |   
девочка           |   девочка         0.8544   |   девочки         0.0056   |   девочка-волшебница0.0003   |   девочку         0.0042   |   
сказала           |   сказала         0.8611   |   сказал-мужик    0.0008   |   сказали         0.0027   |   сказал          0.0033   |   
скинь             |   скинь 

In [29]:
# Ranking metric (presumably mean reciprocal rank) of the target among the hypotheses.
_hypothesis_words = remove_scores_from_results(curve_id_to_hypotheses)
_reference_words = val_default_targets[:N_EXAMPLES]
get_mmr(_hypothesis_words, _reference_words)

0.9525

In [30]:
n_examples = N_EXAMPLES

# Same default-grid subset that val_default_targets (used by the MMR cell) is built from.
# Beam scores are de-normalized back into probabilities before display.
curve_id_to_hypotheses = predict(
    beam_generator__no_vocab, val_default_dataset, n_hypotheses=4, n_examples=n_examples,
    scores_to_prob=lambda pred, score: beamsearch_score_to_prob(score, len(pred) + 1, BEAMSEARCH_NORMALIZATION_FACTOR))

target                pred1                        pred2                        pred3                        pred4                        
---------------------------------------------------------------------------------------------------------------------------------------
на                |   на              0.8725   |   нас             0.0014   |   нам             0.0013   |   наа             0.0012   |   
все               |   все             0.8322   |   всем            0.0012   |   все-            0.0012   |   всео            0.0012   |   
добрый            |   добрый          0.7148   |   доброй          0.0102   |   добрые          0.0024   |   добрый-         0.0010   |   
девочка           |   девочка         0.6902   |   девочки         0.0045   |   девочку         0.0033   |   девочкам        0.0015   |   
сказала           |   сказала         0.6949   |   сказали         0.0022   |   сказал          0.0028   |   сказалаща       0.0007   |   
скинь             |   скинь   

In [31]:
# Ranking metric (presumably mean reciprocal rank) of the target among the hypotheses.
_hypothesis_words = remove_scores_from_results(curve_id_to_hypotheses)
_reference_words = val_default_targets[:N_EXAMPLES]
get_mmr(_hypothesis_words, _reference_words)

0.9275

In [32]:
# Top-6 (word, probability) pairs of the first two swipes, separated by a blank line.
print(curve_id_to_hypotheses[0][:6], curve_id_to_hypotheses[1][:6], sep="\n\n")

[('на', np.float64(0.8725337144269064)), ('нас', np.float64(0.0013806192692899187)), ('нам', np.float64(0.0012561376235014612)), ('наа', np.float64(0.0012083352994650847)), ('нан', np.float64(0.0012065850960813914)), ('нао', np.float64(0.0011999743315939924))]

[('все', np.float64(0.8322106148932569)), ('всем', np.float64(0.0011985692692927101)), ('все-', np.float64(0.0011786678999622794)), ('всео', np.float64(0.0011669255543384498)), ('всеп', np.float64(0.0011558147330768588)), ('всен', np.float64(0.0011325914587499237))]


---------------
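Как и ожидалось, после денормализации скоров бимсерча вероятность лучшей гипотезы совпадает с вероятностью из greedy (ср. pred1 в выводах выше: 0.8943 в обоих вариантах со словарём и 0.8725 в обоих вариантах без словаря).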

Нужно отметить, что в общем случае должно быть так:
* greedy_search_no_vocab__results == beamsearch__no_vocab__results
* greedy_search_with_voc__results == beamsearch__with_voc__results
* greedy_search_with_voc__results != beamsearch__no_vocab__results
* greedy_search__no_vocab__results != beamsearch__with_voc__results



Некоторые заметки:
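1. Иногда вероятности гипотез убывают немонотонно (например, «все-все» с вероятностью 0.0008 стоит ниже «всепожирающего» с 0.0001). Это происходит именно потому, что мы убрали нормализацию бимсерча, то есть штраф за краткость: при ранжировании гипотез он учитывался, а в выведенных вероятностях его нет.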
2. Даже при beamsize = n_classes бимсерч не оценивает явно вероятности всех слов: из-за overconfidence модели вероятности некоторых допустимых токенов оказываются нулевыми, и соответствующие ветки обрываются.
---------------

# Extra: Batched Greedy Search example


Batched greedy search pros:
* produces the same results as the unbatched variant (compare it with the no_vocab version, since the batched generator does not apply vocabulary masking)
* works faster

Batched greedy search cons:
* doesn't support vocabulary masking yet
* has a different interface

In [33]:
from torch.utils.data import DataLoader

from dataset import CollateFn
from word_generators import GreedyGeneratorBatched

In [34]:
# Pad word sequences with the tokenizer's '<pad>' id. With batch_first=False the
# batches are presumably (seq_len, batch), which is why predict_batched transposes them.
collate_fn = CollateFn(
    batch_first=False,
    word_pad_idx=subword_tokenizer.char_to_idx['<pad>'],
)
val_default_dataloader = DataLoader(
    val_default_dataset,
    batch_size=N_EXAMPLES,  # one batch == the same N_EXAMPLES swipes decoded above
    shuffle=False,
    collate_fn=collate_fn,
)

In [35]:
# Batched counterpart of greedy_generator__no_vocab (no vocabulary masking support yet).
greedy_generator_batched__no_vocab = GreedyGeneratorBatched(
    model, subword_tokenizer, device,
    max_steps_n=MAX_OUT_SEQ_LEN,
)

In [36]:
MAX_WORD_LEN = 35  # same value as in the earlier predict cell; default max_steps of predict_batched


def remove_tokens(tensor, tokens_to_remove):
    """Return the elements of ``tensor`` that are not in ``tokens_to_remove``.

    Boolean-mask indexing flattens the selection, so the result is 1-D.
    ``tokens_to_remove`` may be a list of ids or a tensor. ``torch.as_tensor``
    places it on ``tensor``'s device without the copy-and-warn that
    ``torch.tensor`` does when given a tensor.
    """
    banned_ids = torch.as_tensor(tokens_to_remove, device=tensor.device)
    return tensor[~torch.isin(tensor, banned_ids)]

def predict_batched(word_generator: WordGenerator, dataloader, max_steps: int = MAX_WORD_LEN, 
                    verbose: bool = True, n_batches: int = None 
                    ):
    """Run a batched word generator over ``dataloader`` and optionally print its predictions.

    Produces a single hypothesis per swipe, so the printed table has one
    prediction column.

    Args:
        word_generator: called as ``word_generator(encoder_in, encoder_pad_mask)``
            and returning ``(char_token_ids, log_probs)``. ``char_token_ids`` is
            transposed before decoding, so it is presumably (seq_len, batch) —
            TODO confirm. ``word_generator.tokenizer`` supplies both the
            special-token ids and the decoding.
        dataloader: yields ``((encoder_in, _, encoder_pad_mask, _), target)``
            batches. ``target`` is transposed like ``char_token_ids``.
        max_steps: currently unused; kept for interface compatibility.
        verbose: print a target / prediction / probability table.
        n_batches: stop after this many batches (``None`` = the whole dataloader).
    """
    n_hypotheses = 1

    # Decode with the same tokenizer the special-token ids come from.
    tokenizer = word_generator.tokenizer
    special_token_ids = [tokenizer.char_to_idx[tok] for tok in ('<pad>', '<eos>', '<sos>')]

    def decode_rows(token_id_rows):
        # Decode each row of a (batch, seq_len) id tensor, dropping special tokens.
        return [tokenizer.decode(remove_tokens(row, special_token_ids)) for row in token_id_rows]

    if verbose:
        print(("{:<22}" + "{:<29}" * n_hypotheses).format("target", *[f"pred{i}" for i in range(1, n_hypotheses+1)]))
        print("-"*(15+30*n_hypotheses))

    row_format = "{:<15}   |   " + "{:<16}{:.4f}   |   " * n_hypotheses

    for batch_idx, ((encoder_in, _, encoder_pad_mask, _), target) in enumerate(dataloader):
        # Stop before doing any work on batches past the limit.
        if n_batches is not None and batch_idx >= n_batches:
            break

        with torch.no_grad():
            char_token_ids, log_probs = word_generator(encoder_in, encoder_pad_mask)

            pred_words = decode_rows(char_token_ids.T)
            target_words = decode_rows(target.T)
            probs = np.exp(log_probs.cpu().numpy())

        if verbose:
            for true_label, pred_word, prob in zip(target_words, pred_words, probs):
                print(row_format.format(true_label, pred_word, prob))
        


In [37]:
predict_batched(
    greedy_generator_batched__no_vocab,
    val_default_dataloader,
    n_batches=1,  # a single batch of N_EXAMPLES swipes, matching the unbatched runs
)

target                pred1                        
---------------------------------------------
на                |   на              0.8725   |   
все               |   все             0.8322   |   
добрый            |   добрый          0.7148   |   
девочка           |   девочка         0.6902   |   
сказала           |   сказала         0.6949   |   
скинь             |   скинь           0.7548   |   
геев              |   геев            0.5662   |   
тобой             |   тобой           0.7656   |   
была              |   быса            0.7354   |   
да                |   да              0.8700   |   
муж               |   муж             0.7696   |   
щас               |   щас             0.8363   |   
она               |   она             0.8290   |   
проблема          |   проблема        0.6674   |   
билайн            |   билайн          0.6373   |   
уже               |   уже             0.8241   |   
раньше            |   раньше          0.7333   |   
рам               

