# START OF TEST: THE PICKLED FULLY TRANSFORMED DATASET IS THE SAME AS THE DATASET BUILT WITH SEPARATE INIT AND GETITEM TRANSFORMS

In [None]:
from typing import List, Dict, Tuple, Optional, Set
import pickle
import json 

from tqdm import tqdm

from predict import get_grid_name_to_grid
from nearest_key_lookup import ExtendedNearestKeyLookup
from tokenizers import CharLevelTokenizerv2, KeyboardTokenizerv1, ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from transforms import InitTransform, GetItemTransform
from dataset import CurveDataset

In [None]:
# Load the precomputed, fully transformed validation dataset.
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
with open('../data/fully_transformed_datasets/valid_ft.pkl', mode='rb') as ft_pickle_file:
    valid_ds_pickled = pickle.load(ft_pickle_file)

In [None]:
# Load the fully transformed validation dataset stored in the "4 workers" pickle.
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
with open('../data/fully_transformed_datasets/valid_ft_4_workers.pkl', mode='rb') as ft_4w_pickle_file:
    valid_ds_4_workers_pickled = pickle.load(ft_4w_pickle_file)

In [None]:
VAL_DS_PATH = "../data/data_separated_grid/valid__in_train_format.jsonl"  # validation split, JSON Lines


def get_gridname_to_out_of_bounds_coords_dict(
        data_paths: List[str], gridname_to_wh: dict,
        total: Optional[int] = None
        ) -> Dict[str, Set[Tuple[int, int]]]:
    """
    Collect every out-of-bounds curve point, grouped by grid name.

    Each file is read line by line; every line is parsed as a JSON object
    with a ``curve`` mapping holding ``grid_name``, ``x`` and ``y``.
    A point ``(x, y)`` counts as out of bounds when ``x < 0``, ``x >= w``,
    ``y < 0`` or ``y >= h``, where ``(w, h)`` is that grid's entry in
    ``gridname_to_wh``.

    Args:
        data_paths: Paths of the JSON Lines files to scan. A single path
            string is also accepted and treated as a one-element list.
        gridname_to_wh: Maps each grid name to its ``(width, height)``.
        total: Optional line-count hint forwarded to ``tqdm`` for each
            file's progress bar; it does not limit how many lines are read.

    Returns:
        A dict with one key per key of ``gridname_to_wh``. Each value is the
        *set* of distinct out-of-bounds ``(x, y)`` pairs seen for that grid
        (empty if there were none).

    Raises:
        KeyError: If a curve references a grid name missing from
            ``gridname_to_wh``.
    """
    if isinstance(data_paths, str):
        # A bare path would otherwise be iterated character by character.
        data_paths = [data_paths]

    gname_to_out_of_bounds = {gname: set() for gname in gridname_to_wh}

    for data_path in data_paths:
        with open(data_path, "r", encoding="utf-8") as json_file:
            for line in tqdm(json_file, total=total):
                curve = json.loads(line)['curve']
                grid_name = curve['grid_name']
                w, h = gridname_to_wh[grid_name]
                # zip() pairs x/y coordinates point by point; the set
                # de-duplicates points repeated across curves.
                gname_to_out_of_bounds[grid_name].update(
                    (x, y) for x, y in zip(curve['x'], curve['y'])
                    if x < 0 or x >= w or y < 0 or y >= h
                )
    return gname_to_out_of_bounds


# Tokenizers handed to the init / get-item transforms below.
kb_tokenizer = KeyboardTokenizerv1()
word_char_tokenizer = CharLevelTokenizerv2('../data/data_separated_grid/voc.txt')


gridname_to_grid = get_grid_name_to_grid('../data/data_separated_grid/gridname_to_grid.json')

# (width, height) per grid: used for the out-of-bounds scan and by GetItemTransform.
gname_to_wh = {
    gname: (grid['width'], grid['height'])
    for gname, grid in gridname_to_grid.items()
}

# Count the validation records (one JSON object per line) instead of
# hard-coding the number, so the progress-bar totals always match the file.
with open(VAL_DS_PATH, encoding="utf-8") as val_file:
    n_val_samples = sum(1 for _ in val_file)

print("Accumulating out-of-bounds coordinates...")
gname_to_out_of_bounds = get_gridname_to_out_of_bounds_coords_dict(
    [VAL_DS_PATH], gname_to_wh, total=n_val_samples
)

print("Creating ExtendedNearestKeyLookups...")
nearest_key_candidates = ALL_CYRILLIC_LETTERS_ALPHABET_ORD
gridname_to_nkl = {
    gname: ExtendedNearestKeyLookup(grid, nearest_key_candidates, gname_to_out_of_bounds[gname])
    for gname, grid in gridname_to_grid.items()
}


init_transform = InitTransform(
    grid_name_to_nk_lookup=gridname_to_nkl,
    kb_tokenizer=kb_tokenizer,
)

get_item_transform = GetItemTransform(
    grid_name_to_wh=gname_to_wh,
    word_tokenizer=word_char_tokenizer,
    include_time=False,
    include_velocities=True,
    include_accelerations=True,
)


print("Creating datasets...")
# NOTE(review): `total` appears to be only a progress-bar hint (a run with a
# smaller hint still produced every record) -- confirm in CurveDataset.
val_ds = CurveDataset(
    data_path=VAL_DS_PATH,
    store_gnames=False,
    init_transform=init_transform,
    get_item_transform=get_item_transform,
    total=n_val_samples,
)

Accumulating out-of-bounds coordinates...


  0%|          | 10000/6000000 [00:00<07:05, 14068.35it/s]


Creating ExtendedNearestKeyLookups...
Creating datasets...


10000it [00:02, 4811.73it/s]                         


In [None]:
def test_datasets_equal(ds1: CurveDataset, ds2: CurveDataset) -> None:
    """
    Assert that two datasets yield identical samples in the same order.

    Each sample is unpacked as ``(model_inputs, target)``. The elements of
    ``model_inputs`` are compared pairwise with ``.equal``, and so are the
    targets. Presumably these are ``torch.Tensor.equal`` (same shape and same
    values) -- confirm against ``CurveDataset.__getitem__``.

    Raises:
        AssertionError: If the dataset lengths differ, a sample has a
            different number of model inputs, or any compared element is
            unequal. The message names the offending sample index.
    """
    assert len(ds1) == len(ds2), f"Dataset lengths differ: {len(ds1)} != {len(ds2)}"
    for sample_idx, (data_sample1, data_sample2) in enumerate(
            tqdm(zip(ds1, ds2), total=len(ds1))):
        model_inputs1, target1 = data_sample1
        model_inputs2, target2 = data_sample2
        # Materialize so the counts can be compared: zip() alone would stop at
        # the shorter sequence and hide a missing or extra model input.
        inputs1, inputs2 = tuple(model_inputs1), tuple(model_inputs2)
        assert len(inputs1) == len(inputs2), (
            f"Sample {sample_idx}: model input counts differ "
            f"({len(inputs1)} != {len(inputs2)})"
        )
        for input_idx, (mi1_tensor, mi2_tensor) in enumerate(zip(inputs1, inputs2)):
            assert mi1_tensor.equal(mi2_tensor), (
                f"Sample {sample_idx}: model input {input_idx} differs"
            )
        assert target1.equal(target2), f"Sample {sample_idx}: targets differ"

In [None]:
# Pickled fully transformed dataset vs. the dataset built here with separate init / get-item transforms.
test_datasets_equal(ds1=valid_ds_pickled, ds2=val_ds)
print("Test passed! All samples are equal!")

100%|██████████| 10000/10000 [00:17<00:00, 560.17it/s]

Test passed! All samples are equal!





In [None]:
# Dataset pickled in valid_ft.pkl vs. the one in valid_ft_4_workers.pkl.
test_datasets_equal(ds1=valid_ds_pickled, ds2=valid_ds_4_workers_pickled)
print("Test passed! All samples are equal!")

100%|██████████| 10000/10000 [00:00<00:00, 38911.34it/s]

Test passed! All samples are equal!





python ./src/create_and_save_fully_transformed_ds.py --jsonl_path ./data/data_separated_grid/train__default_only_no_errors__2023_10_31__03_26_16.jsonl --output_path ./train_default_grid_no_errors__2023_10_31_ft.pkl --vocab_path ./data/data_separated_grid/voc.txt --gridname_to_grid_path ./data/data_separated_grid/gridname_to_grid.json --n_workers 4 

# END OF TEST: THE PICKLED FULLY TRANSFORMED DATASET IS THE SAME AS THE DATASET BUILT WITH SEPARATE INIT AND GETITEM TRANSFORMS