In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [20]:
import os
import json

import torch
from tqdm import tqdm

from model import get_m1_model
from word_generators import GreedyGenerator
from tokenizers import CharLevelTokenizerv1, KeyboardTokenizerv1
from dataset import NeuroSwipeDatasetv1

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
# Switch between the Kaggle directory layout and the local repository layout.
IN_KAGGLE = False

# Root directory holding the dataset files (vocabulary, grids, validation set).
DATA_ROOT = (
    "/kaggle/input/yandex-cup-playground"
    if IN_KAGGLE
    else "../data/data_separated_grid"
)
# Directory the trained model checkpoints are loaded from.
MODELS_DIR = "" if IN_KAGGLE else "../data/trained_models/m1"

In [57]:
# Checkpoint file names, one per keyboard grid.
default_model_fname = "best_model__2023_11_03__16_34_37__0.02697_default_switch_1.pt"
extra_model_fname = "best_model__2023_11_05__07_55_13__0.02516_extra_switch_2__with_pad_cutting.pt"

# Load each checkpoint from MODELS_DIR onto `device`, keyed by grid name.
# The tuple order keeps the original load order: "default" first, then "extra".
grid_name_to_model = {
    name: get_m1_model(os.path.join(MODELS_DIR, checkpoint_fname), device)
    for name, checkpoint_fname in (
        ("default", default_model_fname),
        ("extra", extra_model_fname),
    )
}

In [58]:
# Character-level word tokenizer built from the vocabulary file under DATA_ROOT.
voc_path = os.path.join(DATA_ROOT, "voc.txt")
word_char_tokenizer = CharLevelTokenizerv1(voc_path)

In [59]:
# One greedy word generator per keyboard grid, each wrapping that grid's model
# and sharing the same word tokenizer and device.
grid_name_to_greedy_generator = dict(
    (name, GreedyGenerator(grid_name_to_model[name], word_char_tokenizer, device))
    for name in ("default", "extra")
)

In [60]:
# Value forwarded as `max_traj_len` to the dataset.
# NOTE(review): presumably the maximum number of trajectory points per swipe - confirm.
MAX_TRAJ_LEN = 299
# Value forwarded as `total` to the dataset; the recorded progress bar for this
# cell counts up to this number.
VAL_TOTAL = 10_000

# Mapping from grid name to grid description, read from a JSON file.
grid_name_to_grid_path = os.path.join(DATA_ROOT, "gridname_to_grid.json")
with open(grid_name_to_grid_path, "r", encoding="utf-8") as f:
    grid_name_to_grid = json.load(f)

# Plain string: the previous f-string had no placeholders.
val_path = os.path.join(DATA_ROOT, "valid__in_train_format.jsonl")

kb_tokenizer = KeyboardTokenizerv1()
# Reuse the tokenizer already built from the same voc.txt instead of reading the
# vocabulary a second time; this also guarantees the dataset and the generators
# share one vocabulary object.
word_tokenizer = word_char_tokenizer


val_dataset = NeuroSwipeDatasetv1(
    data_path=val_path,
    gridname_to_grid=grid_name_to_grid,
    kb_tokenizer=kb_tokenizer,
    max_traj_len=MAX_TRAJ_LEN,
    word_tokenizer=word_tokenizer,
    include_time=False,
    include_velocities=True,
    include_accelerations=True,
    has_target=True,
    has_one_grid_only=False,
    include_grid_name=True,  # each item also carries its grid name (used to pick the model)
    total=VAL_TOTAL,
)



100%|██████████| 10000/10000 [00:02<00:00, 4418.88it/s]


In [61]:
def create_pred_list_greedy_witout_vocab_check(dataset,
                                               grid_name_to_greedy_generator,
                                               verbose=True):
    """
    Greedily decode one word per dataset item and return the predictions.

    Predictions are returned exactly as generated (minus a leading "<sos>");
    they are NOT checked against the vocabulary.

    Args:
        dataset: sized iterable whose items unpack as
            ``((xyt, kb_tokens, _, traj_pad_mask, word_mask), target, grid_name)``.
        grid_name_to_greedy_generator: mapping from grid name to a callable
            ``generator(xyt, kb_tokens, traj_pad_mask)`` returning the predicted
            word as a string.
        verbose: when True (the default, matching the previous behaviour) each
            prediction is printed next to its decoded target. Target decoding
            uses the notebook-level ``word_char_tokenizer`` and is skipped
            entirely when False.

    Returns:
        List of predicted words in dataset order. If the run is interrupted
        with Ctrl+C / "Interrupt kernel", the predictions collected so far are
        returned instead of being lost.
    """

    preds = []

    # The try wraps the whole loop so that an interrupt raised while the
    # dataset is producing the next item is handled too, not only one raised
    # inside the loop body.
    try:
        for data in tqdm(dataset, total=len(dataset)):
            (xyt, kb_tokens, _, traj_pad_mask, word_mask), target, grid_name = data
            pred = grid_name_to_greedy_generator[grid_name](xyt, kb_tokens, traj_pad_mask)
            pred = pred.removeprefix("<sos>")
            # Store the prediction before any printing so it is kept even if
            # the interrupt arrives while the debug line is being written.
            preds.append(pred)

            if verbose:
                # Number of positions where word_mask is False.
                # NOTE(review): assumes word_mask is a boolean padding mask
                # (True on padded positions) - confirm against the dataset.
                target_len = int(torch.sum(~word_mask))
                target_word = word_char_tokenizer.decode(target[:target_len])
                target_word = target_word.removesuffix('<pad>').removesuffix('<eos>')
                # tqdm.write prints without breaking the active progress bar.
                tqdm.write(f"{pred} {target_word}")
    except KeyboardInterrupt:
        # Message reads "Stopped early by the user".
        print('Досрочно остановлено пользователем')
    return preds

In [62]:
create_pred_list_greedy_witout_vocab_check(val_dataset, grid_name_to_greedy_generator)

  0%|          | 1/10000 [00:00<32:49,  5.08it/s]

на на


  0%|          | 2/10000 [00:00<1:06:56,  2.49it/s]

все все


  0%|          | 3/10000 [00:01<1:01:05,  2.73it/s]

этом этом


  0%|          | 4/10000 [00:01<1:03:28,  2.62it/s]

добрый добрый


  0%|          | 5/10000 [00:01<1:04:54,  2.57it/s]

девочка девочка


  0%|          | 6/10000 [00:02<1:06:44,  2.50it/s]

сказала сказала


  0%|          | 7/10000 [00:02<1:04:30,  2.58it/s]

скинь скинь


  0%|          | 8/10000 [00:02<1:00:24,  2.76it/s]

геев геев


  0%|          | 9/10000 [00:03<56:40,  2.94it/s]  

тобой тобой


  0%|          | 10/10000 [00:03<1:06:12,  2.51it/s]

быстра была


  0%|          | 12/10000 [00:04<48:38,  3.42it/s]  

есть есть
да да


  0%|          | 13/10000 [00:04<1:01:07,  2.72it/s]

муж муж
щас щас


  0%|          | 15/10000 [00:05<46:19,  3.59it/s]  

она она


  0%|          | 16/10000 [00:05<53:09,  3.13it/s]

проблема проблема


  0%|          | 17/10000 [00:05<55:10,  3.02it/s]

билайн билайн


  0%|          | 18/10000 [00:06<49:53,  3.33it/s]

уже уже


  0%|          | 20/10000 [00:06<45:00,  3.70it/s]

раньше раньше
рам рам


  0%|          | 21/10000 [00:06<40:46,  4.08it/s]

щас щас


  0%|          | 23/10000 [00:07<37:40,  4.41it/s]

купил купил
ты ты


  0%|          | 24/10000 [00:07<39:53,  4.17it/s]

зовут зовут


  0%|          | 25/10000 [00:07<43:52,  3.79it/s]

короче короче


  0%|          | 26/10000 [00:08<44:13,  3.76it/s]

лучше лучше


  0%|          | 27/10000 [00:08<49:03,  3.39it/s]

приедем приедем


  0%|          | 28/10000 [00:08<57:10,  2.91it/s]

размыто размыто


  0%|          | 30/10000 [00:09<52:56,  3.14it/s]  

давай давай
ты ты


  0%|          | 31/10000 [00:09<54:47,  3.03it/s]

отдать отдать


  0%|          | 33/10000 [00:10<45:56,  3.62it/s]

привет привет
не не


  0%|          | 34/10000 [00:10<39:37,  4.19it/s]

да да


  0%|          | 35/10000 [00:10<43:46,  3.79it/s]

будете будете


  0%|          | 36/10000 [00:11<45:38,  3.64it/s]

связи связи


  0%|          | 37/10000 [00:11<53:57,  3.08it/s]

колывань колывань


  0%|          | 38/10000 [00:11<50:37,  3.28it/s]

меня меня


  0%|          | 39/10000 [00:12<54:06,  3.07it/s]

напиши напиши


  0%|          | 40/10000 [00:12<50:07,  3.31it/s]

знаю знаю


  0%|          | 42/10000 [00:13<43:48,  3.79it/s]

мамой мамой
не не


  0%|          | 43/10000 [00:13<38:28,  4.31it/s]

ты ты


  0%|          | 45/10000 [00:13<50:22,  3.29it/s]

только только
они они


  0%|          | 47/10000 [00:14<44:00,  3.77it/s]

саминг свинг
спи спи


  0%|          | 48/10000 [00:15<58:45,  2.82it/s]

соскучилась соскучилась


  0%|          | 49/10000 [00:15<1:01:19,  2.70it/s]

целую целую


  0%|          | 50/10000 [00:15<56:57,  2.91it/s]  

что что


  1%|          | 51/10000 [00:16<1:00:00,  2.76it/s]

почему почему


  1%|          | 52/10000 [00:16<1:02:11,  2.67it/s]

шакалы шакалов
мне мне


  1%|          | 53/10000 [00:17<53:14,  3.11it/s]  

Досрочно остановлено пользователем





['на',
 'все',
 'этом',
 'добрый',
 'девочка',
 'сказала',
 'скинь',
 'геев',
 'тобой',
 'быстра',
 'есть',
 'да',
 'муж',
 'щас',
 'она',
 'проблема',
 'билайн',
 'уже',
 'раньше',
 'рам',
 'щас',
 'купил',
 'ты',
 'зовут',
 'короче',
 'лучше',
 'приедем',
 'размыто',
 'давай',
 'ты',
 'отдать',
 'привет',
 'не',
 'да',
 'будете',
 'связи',
 'колывань',
 'меня',
 'напиши',
 'знаю',
 'мамой',
 'не',
 'ты',
 'только',
 'они',
 'саминг',
 'спи',
 'соскучилась',
 'целую',
 'что',
 'почему',
 'шакалы',
 'мне']