In [None]:
# Hack to move to the parent directory
# and do it only once.

try:
    if In_PARENT_DIR:
        print("Already in parent directory.")
except NameError:
    %cd ..
    In_PARENT_DIR = True
    print("Moved to parent directory.")

In [None]:
# Load IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json

import torch

from feature_extraction.swipe_feature_extractors import (
    MultiFeatureExtractor,  
    TrajectoryFeatureExtractor,
    CoordinateFunctionFeatureExtractor
    )
from feature_extraction.key_weights_functions import (
    weights_function_v1,
    weights_function_v1_softmax,
    weights_function_sigmoid_normalized_v1,)
from feature_extraction.distance_getter import DistanceGetter
from feature_extraction.nearest_key_getter import NearestKeyGetter
from feature_extraction.key_weights_getter import KeyWeightsGetter
from feature_extraction.grid_lookup import GridLookup
from feature_extraction.normalizers import identity_function, MinMaxNormalizer, MeanStdNormalizer
from ns_tokenizers import KeyboardTokenizerv1, ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from grid_processing_utils import get_grid

In [None]:
#### CONFIGURATION ####

# NOTE(review): these relative paths are resolved against the current working
# directory, which the first cell changes with `%cd ..` -- confirm that they
# point at the intended data folder.
GRIDS_PATH = "../data/data_preprocessed/gridname_to_grid.json"
GRID_NAME = "default"
TRAJECTORY_FEATURES_STATISTICS_PATH = "../data/data_preprocessed/trajectory_features_statistics.json"

#### Create all resources for feature extractors ####

# Grid description selected by GRID_NAME from the grids file.
grid = get_grid(grid_name=GRID_NAME, grids_path=GRIDS_PATH)
tokenizer = KeyboardTokenizerv1()
weights_function = weights_function_v1
key_labels_of_interest = set(ALL_CYRILLIC_LETTERS_ALPHABET_ORD)

# All three getters are built from the same grid, tokenizer and key labels.
distance_getter = DistanceGetter(grid, tokenizer, key_labels_of_interest)
nearest_key_getter = NearestKeyGetter(grid, tokenizer, key_labels_of_interest)
key_weights_getter = KeyWeightsGetter(
    grid,
    tokenizer,
    weights_function=weights_function,
    key_labels_of_interest=key_labels_of_interest,
)

# Per-feature statistics; they are unpacked into the trajectory-feature
# normalizers further down.
with open(TRAJECTORY_FEATURES_STATISTICS_PATH, encoding="utf-8") as stats_file:
    trajectory_features_statistics = json.load(stats_file)

In [None]:
def grid_lookup_maker(width, height):
    """Return a factory that builds GridLookup objects for a fixed grid size.

    Args:
        width: Forwarded to GridLookup as ``grid_width``.
        height: Forwarded to GridLookup as ``grid_height``.

    Returns:
        A one-argument callable. Given ``value_fn`` it returns
        ``GridLookup(grid_width=width, grid_height=height, value_fn=value_fn)``.
    """
    def make_lookup(value_fn):
        # width and height are captured from the enclosing call.
        return GridLookup(
            value_fn=value_fn,
            grid_height=height,
            grid_width=width,
        )

    return make_lookup

# Lookup factory bound to this grid's "width" and "height" entries; calling it
# with a value function yields a GridLookup for that function.
get_lookup_fn = grid_lookup_maker(grid["width"], grid["height"])

In [None]:
### Example data for feature extractors ###

# Three float32 series of four samples each, fed to every extractor below
# as (x, y, t).
x, y, t = (
    torch.tensor(values, dtype=torch.float32)
    for values in (
        [1, 100, 200, 305],  # x
        [10, 20, 30, 20],    # y
        [0, 46, 64, 100],    # t
    )
)

In [None]:
# Extractor that wraps nearest_key_getter directly.
nearest_key_extractor = CoordinateFunctionFeatureExtractor(nearest_key_getter)

# Same getter, but wrapped in a GridLookup first; the output is cast to
# int32, the dtype of the expected tokens it is compared against later.
nearest_key_lookup = get_lookup_fn(nearest_key_getter)
nearest_key_extractor_with_lookup = CoordinateFunctionFeatureExtractor(
    nearest_key_lookup,
    cast_dtype=torch.int32,
)

In [None]:
# Extractor that wraps distance_getter directly.
distances_extractor = CoordinateFunctionFeatureExtractor(distance_getter)

# Same getter, but wrapped in a GridLookup first.
# NOTE(review): cast_dtype=torch.int32 would truncate non-integer distances;
# confirm that distance_getter yields integers, otherwise this cast makes the
# lookup variant differ from the direct one.
distance_lookup = get_lookup_fn(distance_getter)
distances_extractor_with_lookup = CoordinateFunctionFeatureExtractor(
    distance_lookup,
    cast_dtype=torch.int32,
)

In [None]:
# Extractor that wraps key_weights_getter directly.
key_weights_extractor = CoordinateFunctionFeatureExtractor(key_weights_getter)

# Same getter, but wrapped in a GridLookup first.
# NOTE(review): cast_dtype=torch.int32 looks copied from the nearest-key cell;
# if the weights are fractional, this cast truncates them -- confirm the
# intended dtype.
key_weights_lookup = get_lookup_fn(key_weights_getter)
key_weights_extractor_with_lookup = CoordinateFunctionFeatureExtractor(
    key_weights_lookup,
    cast_dtype=torch.int32,
)

In [None]:
# Every feature listed here gets a MeanStdNormalizer built from the entry of
# the same name in trajectory_features_statistics; the result is passed to
# the extractor as the "<feature>_normalizer" keyword argument.
mean_std_normalizers = {
    f"{feature_name}_normalizer": MeanStdNormalizer(
        **trajectory_features_statistics[feature_name]
    )
    for feature_name in (
        "dt",
        "velocity_x",
        "velocity_y",
        "acceleration_x",
        "acceleration_y",
    )
}

trajectory_feats_extractor = TrajectoryFeatureExtractor(
    include_dt=True,
    include_velocities=True,
    include_accelerations=True,
    # x and y use identity_function as their normalizer.
    x_normalizer=identity_function,
    y_normalizer=identity_function,
    **mean_std_normalizers,
)

In [None]:
# Combined extractor: trajectory features first, then nearest keys.
traj_feats_and_nearest_key_extractor = MultiFeatureExtractor(
    extractors=[
        trajectory_feats_extractor,
        nearest_key_extractor,
    ],
)

In [None]:
# Nearest keys for the example data, computed directly and through the
# GridLookup-backed extractor.
nearest_key_result = nearest_key_extractor(x, y, t)[0]
nearest_key_with_lookup_result = nearest_key_extractor_with_lookup(x, y, t)[0]

# Tokens of the keys the four example points are expected to map to.
expected_nearest_key_result = torch.tensor(
    [tokenizer.get_token(key_label) for key_label in ('й', 'ц', 'у', 'к')],
    dtype=torch.int32,
)

# (direct, with lookup) comparison against the expected tokens.
nearest_key_checks = (
    torch.equal(nearest_key_result, expected_nearest_key_result),
    torch.equal(nearest_key_with_lookup_result, expected_nearest_key_result),
)

# Fail loudly on a mismatch instead of only displaying False, so that
# "Restart & Run All" cannot silently pass over a broken extractor.
assert all(nearest_key_checks), (
    f"Nearest-key mismatch (direct, with lookup): {nearest_key_checks}"
)

nearest_key_checks

In [None]:
# Distances for the example data, computed directly and through the
# GridLookup-backed extractor.
distances_result = distances_extractor(x, y, t)[0]
distances_with_lookup_result = distances_extractor_with_lookup(x, y, t)[0]

# Expected layout: one row per input point, one column per key label of interest.
expected_distances_shape = (x.shape[0], len(key_labels_of_interest))

# (direct, with lookup) shape comparison.
distances_shape_checks = (
    distances_result.shape == expected_distances_shape,
    distances_with_lookup_result.shape == expected_distances_shape,
)

# Fail loudly on a mismatch instead of only displaying False.
assert all(distances_shape_checks), (
    f"Unexpected distances shape (direct, with lookup): "
    f"{tuple(distances_result.shape)}, {tuple(distances_with_lookup_result.shape)}; "
    f"expected {expected_distances_shape}"
)

distances_shape_checks

In [None]:
# Distances from the extractor that wraps distance_getter directly.
distances_result

In [None]:
# Distances from the GridLookup-backed extractor (cast to torch.int32), shown
# for comparison with the direct result.
distances_with_lookup_result

In [None]:
# First element of the trajectory extractor's output for the example data.
trajectory_feats_extractor_result = trajectory_feats_extractor(x, y, t)[0]
trajectory_feats_extractor_result

In [None]:
# Key weights for the example data, computed directly and through the
# GridLookup-backed extractor.
key_weights_result = key_weights_extractor(x, y, t)[0]
key_weights_with_lookup_result = key_weights_extractor_with_lookup(x, y, t)[0]

# Expected layout: one row per input point, one column per key label of interest.
expected_key_weights_shape = (x.shape[0], len(key_labels_of_interest))

# (direct, with lookup) shape comparison.
key_weights_shape_checks = (
    key_weights_result.shape == expected_key_weights_shape,
    key_weights_with_lookup_result.shape == expected_key_weights_shape,
)

# Fail loudly on a mismatch instead of only displaying False.
assert all(key_weights_shape_checks), (
    f"Unexpected key-weights shape (direct, with lookup): "
    f"{tuple(key_weights_result.shape)}, {tuple(key_weights_with_lookup_result.shape)}; "
    f"expected {expected_key_weights_shape}"
)

key_weights_shape_checks

In [None]:
# Key weights from the extractor that wraps key_weights_getter directly.
key_weights_result

In [None]:
# Key weights from the GridLookup-backed extractor (cast to torch.int32), shown
# for comparison with the direct result.
key_weights_with_lookup_result

In [None]:
# Run the combined extractor (trajectory features + nearest keys) on the example data.
traj_feats_and_nearest_key_extractor(x, y, t)

# Check that batches are formed correctly with the feature extractors

In [None]:
from torch.utils.data import DataLoader
from dataset import SwipeDataset, CollateFn

from ns_tokenizers import CharLevelTokenizerv2

In [None]:
# Word tokenizer built from voc.txt (presumably the vocabulary file -- confirm);
# it is passed to the dataset and provides the '<pad>' index used for collation.
word_tokenizer = CharLevelTokenizerv2('../data/data_preprocessed/voc.txt')

In [None]:
# Both grid names are served by the same combined extractor.
grid_name_to_swipe_feature_extractor = dict.fromkeys(
    ('default', 'extra'),
    traj_feats_and_nearest_key_extractor,
)

val_dataset = SwipeDataset(
    data_path='../data/data_preprocessed/valid.jsonl',
    word_tokenizer=word_tokenizer,
    grid_name_to_swipe_feature_extractor=grid_name_to_swipe_feature_extractor,
    total=10000,
)

In [None]:
# Unpack the first dataset item into encoder_in, decoder_in and decoder_out.
(encoder_in, decoder_in), decoder_out = val_dataset[0]

In [None]:
# Print every element of the first item's encoder input, one after another.
for swipe_feature in encoder_in:
    print(swipe_feature)

In [None]:
# Decoder input of the first dataset item.
decoder_in

In [None]:
# decoder_out of the first dataset item.
decoder_out

In [None]:
# Index of the '<pad>' entry in the word tokenizer; handed to CollateFn as
# its word_pad_idx.
WORD_PAD_IDX = word_tokenizer.char_to_idx['<pad>']

collate_fn = CollateFn(
    word_pad_idx=WORD_PAD_IDX,
    batch_first=False,
)

In [None]:
VAL_BATCH_SIZE = 4
NUM_WORKERS = 0
# DataLoader only accepts persistent workers when num_workers > 0, so this
# stays False while NUM_WORKERS is 0.
PERSISTENT_WORKERS = False

# Unshuffled loader over the validation dataset, batched with collate_fn.
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    persistent_workers=PERSISTENT_WORKERS,
    collate_fn=collate_fn,
)

In [None]:
# Print name, values, shape and dtype of every tensor in the first batch only.
for (encoder_in, decoder_in, swipe_mask, word_mask), decoder_out in val_loader:
    # The four non-swipe tensors come first, in a fixed order.
    named_tensors = {
        "decoder_in": decoder_in,
        "swipe_mask": swipe_mask,
        "word_mask": word_mask,
        "decoder_out": decoder_out,
    }
    # Then one entry per element of encoder_in, numbered by position.
    named_tensors.update(
        (f"swipe_feature_{i}", swipe_feature)
        for i, swipe_feature in enumerate(encoder_in)
    )

    for name, tensor in named_tensors.items():
        print(f"{name}:")
        print(tensor)
        print("Shape:", tensor.shape)
        print("Type:", tensor.dtype)
        print()

    # Only the first batch is inspected.
    break


In [None]:
# Index 2 along the last axis of the first encoder-input tensor of the batch
# printed above, transposed for display.
# NOTE(review): presumably this is the dt feature and the transpose yields
# (batch, seq) because batch_first=False -- TODO confirm against
# TrajectoryFeatureExtractor / CollateFn.
encoder_in[0][:, :, 2].T

In [None]:
# Index 6 along the last axis of the first encoder-input tensor of the batch
# printed above, transposed for display.
# NOTE(review): presumably this is the acceleration_y feature -- TODO confirm
# the feature order produced by TrajectoryFeatureExtractor.
encoder_in[0][:, :, 6].T