In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('..')

In [None]:
import torch

from feature_extraction.swipe_feature_extractors import (
    MultiFeatureExtractor, 
    KeyWeightsFeatureExtractor, 
    NearestKeyFeatureExtractor, 
    TrajectoryFeatureExtractor, 
    KeyDistancesFeatureExtractor)
from feature_extraction.key_weights_functions import (
    weights_function_v1,
    weights_function_v1_softmax,
    weights_function_sigmoid_normalized_v1,)
from ns_tokenizers import KeyboardTokenizerv1, ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from feature_extraction.nearest_key_lookup import NearestKeyLookup
from feature_extraction.distances_lookup import DistancesLookup
from grid_processing_utils import get_grid
from feature_extraction_utils import get_avg_half_key_diag

In [3]:
#### CONFIGURATION ####

GRIDS_PATH = "../../data/data_preprocessed/gridname_to_grid.json"
GRID_NAME = "default"


#### Create all resources for feature extractors ####

# Keyboard grid named GRID_NAME, loaded from the grids JSON file.
grid = get_grid(grid_name=GRID_NAME, grids_path=GRIDS_PATH)

# A single key list is shared by every lookup: both names refer to the same
# object, so distance columns and nearest-key candidates use one alphabet.
kb_key_list = ALL_CYRILLIC_LETTERS_ALPHABET_ORD
nearest_key_candidates = kb_key_list

# Weighting function handed to KeyWeightsFeatureExtractor further down.
weights_function = weights_function_v1

distances_lookup = DistancesLookup(grid=grid, kb_key_list=kb_key_list)
nearest_key_lookup = NearestKeyLookup(
    grid=grid,
    nearest_key_candidates=nearest_key_candidates,
)
keyboard_tokenizer = KeyboardTokenizerv1()

# Presumably the average half key-diagonal over kb_key_list (per the helper's
# name); it is passed to KeyWeightsFeatureExtractor as half_key_diag.
half_key_diag = get_avg_half_key_diag(grid=grid, allowed_keys=kb_key_list)

In [4]:
### Example data for feature extractors ###

# A four-point sample swipe: x and y coordinates plus time stamps t,
# each as a float32 tensor.
x, y, t = (
    torch.tensor(values, dtype=torch.float32)
    for values in (
        [1, 100, 200, 305],
        [10, 20, 30, 20],
        [0, 46, 64, 100],
    )
)

In [5]:
# Nearest key for each point, encoded as keyboard_tokenizer token ids.
nearest_key_extractor = NearestKeyFeatureExtractor(
    nearest_key_lookup=nearest_key_lookup,
    keyboard_tokenizer=keyboard_tokenizer)

# Distances from each point to every key in kb_key_list.
distances_extractor = KeyDistancesFeatureExtractor(
    distances_lookup=distances_lookup)

# Trajectory features with every optional group enabled. All normalizers are
# identities so the raw feature values can be inspected directly.
trajectory_feats_extractor = TrajectoryFeatureExtractor(
    include_dt=True,
    include_velocities=True,
    include_accelerations=True,
    coordinate_normalizer=lambda x, y: (x, y),
    dt_normalizer=lambda x: x,
    velocities_normalizer=lambda x, y: (x, y),
    accelerations_normalizer=lambda x, y: (x, y),
)

# Per-point, per-key weights computed from key distances with weights_function.
key_weights_extractor = KeyWeightsFeatureExtractor(
    distances_lookup=distances_lookup,
    half_key_diag=half_key_diag,
    weights_function=weights_function)

# Runs both extractors and returns their outputs together, in the order listed
# (trajectory features first, then nearest keys).
traj_feats_and_nearest_key_extractor = MultiFeatureExtractor(
    extractors=[trajectory_feats_extractor, nearest_key_extractor]
)

In [17]:
# The four sample points are expected to be nearest to й, ц, у, к respectively.
nearest_key_result = nearest_key_extractor(x, y, t)[0]
expected_nearest_key_result = torch.tensor(
    [keyboard_tokenizer.get_token(char) for char in 'йцук'],
    dtype=torch.int32,
)

nearest_keys_match = torch.equal(nearest_key_result, expected_nearest_key_result)
# Fail loudly under Restart & Run All instead of silently displaying False.
assert nearest_keys_match, (
    f"nearest keys {nearest_key_result.tolist()} "
    f"!= expected {expected_nearest_key_result.tolist()}"
)
nearest_keys_match



True

In [19]:
distances_result = distances_extractor(x, y, t)[0]

# Expect one row per sample point and one column per key in kb_key_list.
expected_distances_shape = (x.shape[0], len(kb_key_list))
distances_shape_ok = distances_result.shape == expected_distances_shape
# Fail loudly under Restart & Run All instead of silently displaying False.
assert distances_shape_ok, (
    f"distances shape {tuple(distances_result.shape)} != {expected_distances_shape}"
)
distances_shape_ok



True

In [20]:
# Raw distances: one row per sample point, one column per key in kb_key_list.
# NOTE(review): columns 6 and 27 are -1 in the displayed rows; with an
# alphabet-ordered key list these are presumably 'ё' and 'ъ', i.e. keys absent
# from the grid - TODO confirm the -1 sentinel in DistancesLookup.
distances_result

tensor([[ 4.1635e+02,  9.0712e+02,  3.4018e+02,  6.4226e+02,  8.6675e+02,
          4.4856e+02, -1.0000e+00,  9.6142e+02,  9.3560e+02,  6.6530e+02,
          9.5016e+01,  3.5267e+02,  7.7196e+02,  5.9171e+02,  5.4520e+02,
          6.7931e+02,  5.0018e+02,  5.8840e+02,  5.2536e+02,  7.4265e+02,
          2.5836e+02,  2.4083e+02,  1.0333e+03,  1.6745e+02,  4.6762e+02,
          7.3956e+02,  8.3802e+02, -1.0000e+00,  2.7751e+02,  8.2412e+02,
          1.0567e+03,  9.9281e+02,  4.2347e+02],
        [ 3.3258e+02,  8.1413e+02,  2.6906e+02,  5.4280e+02,  7.6896e+02,
          3.4950e+02, -1.0000e+00,  8.6311e+02,  8.3611e+02,  5.8138e+02,
          8.8233e+01,  2.5440e+02,  6.7496e+02,  5.1392e+02,  4.4585e+02,
          5.8354e+02,  4.0993e+02,  4.9465e+02,  4.5652e+02,  6.5459e+02,
          1.6279e+02,  2.3168e+02,  9.3378e+02,  8.5983e+01,  4.1192e+02,
          6.4006e+02,  7.3852e+02, -1.0000e+00,  2.3084e+02,  7.3316e+02,
          9.5804e+02,  8.9833e+02,  3.8569e+02],
        [ 2.59

In [21]:
# Trajectory features for the sample swipe (identity normalizers, so raw values).
# In the displayed output, columns 0-2 match x, y and the per-step dt
# (0, 46, 18, 36); the remaining four columns are presumably vx, vy, ax, ay -
# TODO confirm the column order in TrajectoryFeatureExtractor.
trajectory_feats_extractor_result = trajectory_feats_extractor(x, y, t)[0]
trajectory_feats_extractor_result

tensor([[ 1.0000e+00,  1.0000e+01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.0000e+02,  2.0000e+01,  4.6000e+01,  3.1094e+00,  3.1250e-01,
          5.9317e-02,  0.0000e+00],
        [ 2.0000e+02,  3.0000e+01,  1.8000e+01,  3.7963e+00,  0.0000e+00,
         -5.7581e-02, -5.7870e-03],
        [ 3.0500e+02,  2.0000e+01,  3.6000e+01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])

In [22]:
key_weights_result = key_weights_extractor(x, y, t)[0]

# Expect one row per sample point and one column per key in kb_key_list.
expected_key_weights_shape = (x.shape[0], len(kb_key_list))
key_weights_shape_ok = key_weights_result.shape == expected_key_weights_shape
# Fail loudly under Restart & Run All instead of silently displaying False.
assert key_weights_shape_ok, (
    f"key weights shape {tuple(key_weights_result.shape)} != {expected_key_weights_shape}"
)
key_weights_shape_ok



True

In [23]:
# Per-key weights for each sample point. In the displayed output the rows do
# not sum to 1 (weights_function_v1 is used, not the softmax variant), and
# columns 6 and 27 are 0 - the same columns whose distances are -1.
key_weights_result

tensor([[1.4540e-02, 9.1794e-07, 6.2205e-02, 1.7089e-04, 2.0361e-06, 7.7529e-03,
         0.0000e+00, 3.1438e-07, 5.2324e-07, 1.0846e-04, 8.9330e-01, 4.9288e-02,
         1.3218e-05, 4.6318e-04, 1.1590e-03, 8.2254e-05, 2.8134e-03, 4.9444e-04,
         1.7135e-03, 2.3571e-05, 2.5003e-01, 3.2025e-01, 7.6160e-08, 6.6719e-01,
         5.3358e-03, 2.5052e-05, 3.5891e-06, 0.0000e+00, 1.8597e-01, 4.7219e-06,
         4.7963e-08, 1.6920e-07, 1.2658e-02],
        [7.1542e-02, 5.7516e-06, 2.1255e-01, 1.2152e-03, 1.4024e-05, 5.2302e-02,
         0.0000e+00, 2.1875e-06, 3.7274e-06, 5.6792e-04, 9.0541e-01, 2.6495e-01,
         8.9629e-05, 2.1465e-03, 8.1748e-03, 5.4420e-04, 1.6471e-02, 3.1368e-03,
         6.6334e-03, 1.3397e-04, 6.8730e-01, 3.6076e-01, 5.4238e-07, 9.0914e-01,
         1.5845e-02, 1.7844e-04, 2.5572e-05, 0.0000e+00, 3.6463e-01, 2.8422e-05,
         3.3605e-07, 1.0918e-06, 2.6308e-02],
        [2.4546e-01, 3.5419e-05, 4.1141e-01, 8.7518e-03, 9.7429e-05, 2.8295e-01,
         0.0000e+

In [24]:
combined_result = traj_feats_and_nearest_key_extractor(x, y, t)

# The combined extractor should reproduce each individual extractor's output,
# in the order the extractors were listed (trajectory first, nearest key second).
assert len(combined_result) == 2, f"expected 2 outputs, got {len(combined_result)}"
assert torch.equal(combined_result[0], trajectory_feats_extractor(x, y, t)[0])
assert torch.equal(combined_result[1], nearest_key_extractor(x, y, t)[0])
combined_result



[tensor([[ 1.0000e+00,  1.0000e+01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.0000e+02,  2.0000e+01,  4.6000e+01,  3.1094e+00,  3.1250e-01,
           5.9317e-02,  0.0000e+00],
         [ 2.0000e+02,  3.0000e+01,  1.8000e+01,  3.7963e+00,  0.0000e+00,
          -5.7581e-02, -5.7870e-03],
         [ 3.0500e+02,  2.0000e+01,  3.6000e+01,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]),
 tensor([10, 23, 20, 11], dtype=torch.int32)]