In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('..')

In [2]:
import torch

from feature_extraction.swipe_feature_extractors import (
    MultiFeatureExtractor,  
    TrajectoryFeatureExtractor,
    CoordinateFunctionFeatureExtractor
    )
from feature_extraction.key_weights_functions import (
    weights_function_v1,
    weights_function_v1_softmax,
    weights_function_sigmoid_normalized_v1,)
from feature_extraction.distance_getter import DistanceGetter
from feature_extraction.nearest_key_getter import NearestKeyGetter
from feature_extraction.key_weights_getter import KeyWeightsGetter
from feature_extraction.grid_lookup import GridLookup


from ns_tokenizers import KeyboardTokenizerv1, ALL_CYRILLIC_LETTERS_ALPHABET_ORD
from grid_processing_utils import get_grid

In [3]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
GRIDS_PATH = "../../data/data_preprocessed/gridname_to_grid.json"
GRID_NAME = "default"

# ---------------------------------------------------------------------------
# Resources shared by the feature extractors
# ---------------------------------------------------------------------------
grid = get_grid(grid_name=GRID_NAME, grids_path=GRIDS_PATH)
key_labels_of_interest = set(ALL_CYRILLIC_LETTERS_ALPHABET_ORD)
weights_function = weights_function_v1
tokenizer = KeyboardTokenizerv1()

# All three getters are built from the same grid, tokenizer and key-label set.
distance_getter = DistanceGetter(grid, tokenizer, key_labels_of_interest)
nearest_key_getter = NearestKeyGetter(grid, tokenizer, key_labels_of_interest)
key_weights_getter = KeyWeightsGetter(
    grid,
    tokenizer,
    weights_function=weights_function,
    key_labels_of_interest=key_labels_of_interest,
)


In [4]:
def grid_lookup_maker(width, height):
    """Return a factory that wraps a value function in a ``GridLookup``.

    Args:
        width: Forwarded to ``GridLookup`` as ``grid_width``.
        height: Forwarded to ``GridLookup`` as ``grid_height``.

    Returns:
        A one-argument callable ``make_lookup(value_fn)`` that builds
        ``GridLookup(grid_width=width, grid_height=height, value_fn=value_fn)``.
    """
    # The inner factory has its own name: reusing ``grid_lookup_maker`` here
    # would shadow the outer function inside its own body and give the
    # returned callable a misleading ``__name__``.
    def make_lookup(value_fn):
        """Build a ``GridLookup`` for ``value_fn`` using the captured grid size."""
        return GridLookup(grid_width=width, grid_height=height, value_fn=value_fn)

    return make_lookup

# Factory bound to this grid's size; call it with a value function to get a GridLookup.
get_lookup_fn = grid_lookup_maker(width=grid["width"], height=grid["height"])

In [5]:
### Example data for feature extractors ###

# Three float32 sequences of equal length, passed to every extractor as (x, y, t).
x, y, t = (
    torch.tensor(values, dtype=torch.float32)
    for values in (
        [1, 100, 200, 305],  # x
        [10, 20, 30, 20],    # y
        [0, 46, 64, 100],    # t
    )
)

In [6]:
# Extractor built directly on nearest_key_getter.
nearest_key_extractor = CoordinateFunctionFeatureExtractor(nearest_key_getter)

# Variant whose coordinate function is the GridLookup wrapper around the same getter.
# NOTE(review): cast_dtype=torch.int32 presumably casts the coordinates before
# the lookup is indexed — confirm against CoordinateFunctionFeatureExtractor.
nearest_key_extractor_with_lookup = CoordinateFunctionFeatureExtractor(
    get_lookup_fn(nearest_key_getter),
    cast_dtype=torch.int32,
)

In [7]:
# Extractor built directly on distance_getter.
distances_extractor = CoordinateFunctionFeatureExtractor(distance_getter)

# Variant whose coordinate function is the GridLookup wrapper around the same getter.
# NOTE(review): cast_dtype=torch.int32 presumably applies to the coordinates,
# not to the returned distances — confirm against CoordinateFunctionFeatureExtractor.
distances_extractor_with_lookup = CoordinateFunctionFeatureExtractor(
    get_lookup_fn(distance_getter),
    cast_dtype=torch.int32,
)

In [8]:
# Extractor built directly on key_weights_getter.
key_weights_extractor = CoordinateFunctionFeatureExtractor(key_weights_getter)

# Variant whose coordinate function is the GridLookup wrapper around the same getter.
# NOTE(review): cast_dtype=torch.int32 presumably applies to the coordinates,
# not to the returned weights — confirm against CoordinateFunctionFeatureExtractor.
key_weights_extractor_with_lookup = CoordinateFunctionFeatureExtractor(
    get_lookup_fn(key_weights_getter),
    cast_dtype=torch.int32,
)

tensor([False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False,  True, False, False,
        False, False, False])
tensor([[ 423.3158,  912.8090,  347.8965,  645.0940,  870.4832,  451.9627,
          965.3695,  938.0225,  672.0119,  104.4713,  356.5729,  776.4974,
          599.4683,  548.2739,  684.2501,  506.2828,  593.8487,  533.4876,
          749.2144,  262.6404,  250.9308, 1035.0966,  173.8397,  476.8021,
          742.2239,  840.0530,  286.8314,  829.7765, 1059.9420,  998.1534,
          433.0774]])
tensor([[ 4.6410, 10.0075,  3.8141,  7.0724,  9.5434,  4.9550, 10.5837, 10.2839,
          7.3675,  1.1454,  3.9092,  8.5130,  6.5722,  6.0109,  7.5017,  5.5506,
          6.5106,  5.8488,  8.2139,  2.8794,  2.7511, 11.3482,  1.9059,  5.2274,
          8.1373,  9.2098,  3.1446,  9.0972, 11.6206, 10.9431,  4.7480]])
91.21268442851873
tens

In [9]:
# Trajectory extractor with dt, velocities and accelerations all enabled.
# Every normalizer is an identity function, so values pass through unchanged.
trajectory_feats_extractor = TrajectoryFeatureExtractor(
    coordinate_normalizer=lambda x, y: (x, y),
    include_dt=True,
    dt_normalizer=lambda x: x,
    include_velocities=True,
    velocities_normalizer=lambda x, y: (x, y),
    include_accelerations=True,
    accelerations_normalizer=lambda x, y: (x, y),
)

In [10]:
# Combine the trajectory features and the nearest-key extractor, in that order.
traj_feats_and_nearest_key_extractor = MultiFeatureExtractor(
    extractors=[
        trajectory_feats_extractor,
        nearest_key_extractor,
    ],
)

In [11]:
# Nearest-key output for the example points, computed directly and via the lookup.
nearest_key_result = nearest_key_extractor(x, y, t)[0]
nearest_key_with_lookup_result = nearest_key_extractor_with_lookup(x, y, t)[0]

# One expected token per example point, laid out as an int32 column.
expected_nearest_key_result = torch.tensor(
    [[tokenizer.get_token(label)] for label in ('й', 'ц', 'у', 'к')],
    dtype=torch.int32,
)

# Displays one boolean per variant: does it equal the expectation exactly?
tuple(
    torch.equal(result, expected_nearest_key_result)
    for result in (nearest_key_result, nearest_key_with_lookup_result)
)

(True, True)

In [12]:
# Distances for the example points, computed directly and via the lookup.
distances_result = distances_extractor(x, y, t)[0]
distances_with_lookup_result = distances_extractor_with_lookup(x, y, t)[0]

# Displays one boolean per variant: is the shape (points, key labels of interest)?
tuple(
    result.shape == (x.shape[0], len(key_labels_of_interest))
    for result in (distances_result, distances_with_lookup_result)
)

(True, True)

In [13]:
# Display the distances computed without the lookup wrapper.
distances_result

tensor([[ 416.7592,  907.5683,  340.1779,  642.7521,  866.7479,  449.0504,
               inf,  961.9003,  936.0984,  665.2977,   95.2694,  353.1519,
          772.4353,  592.0897,  545.6961,  679.7810,  500.6179,  588.8601,
          525.3608,  743.0721,  258.3583,  240.9320, 1033.2589,  167.8876,
          467.8913,  740.0569,  838.0215,       inf,  277.7737,  824.1213,
         1056.6910,  993.2685,  423.4678],
        [ 332.9508,  814.5675,  269.0576,  543.2921,  768.9610,  349.9861,
               inf,  863.5961,  836.6040,  581.3777,   87.9446,  254.8809,
          675.4319,  514.2589,  446.3454,  584.0020,  410.3441,  495.0922,
          456.5183,  654.9979,  162.7882,  231.5734,  933.7800,   86.2569,
          412.1168,  640.5593,  738.5181,       inf,  230.9378,  733.1637,
          958.0381,  898.7827,  385.6890],
        [ 259.8774,  722.4405,  220.8438,  442.8614,  670.7317,  250.3003,
               inf,  764.6426,  736.1157,  502.4938,  162.7705,  157.2395,
          578.

In [14]:
# Display the distances computed through the lookup wrapper, for comparison.
distances_with_lookup_result

tensor([[ 416.7592,  907.5683,  340.1779,  642.7521,  866.7479,  449.0504,
               inf,  961.9003,  936.0984,  665.2977,   95.2694,  353.1519,
          772.4353,  592.0897,  545.6961,  679.7810,  500.6179,  588.8601,
          525.3608,  743.0721,  258.3583,  240.9320, 1033.2589,  167.8876,
          467.8913,  740.0569,  838.0215,       inf,  277.7737,  824.1213,
         1056.6910,  993.2685,  423.4678],
        [ 332.9508,  814.5675,  269.0576,  543.2921,  768.9610,  349.9861,
               inf,  863.5961,  836.6040,  581.3777,   87.9446,  254.8809,
          675.4319,  514.2589,  446.3454,  584.0020,  410.3441,  495.0922,
          456.5183,  654.9979,  162.7882,  231.5734,  933.7800,   86.2569,
          412.1168,  640.5593,  738.5181,       inf,  230.9378,  733.1637,
          958.0381,  898.7827,  385.6890],
        [ 259.8774,  722.4405,  220.8438,  442.8614,  670.7317,  250.3003,
               inf,  764.6426,  736.1157,  502.4938,  162.7705,  157.2395,
          578.

In [15]:
# Take the first element of the extractor's output and display it.
# NOTE(review): judging by the displayed values the columns look like
# x, y, dt followed by velocity and acceleration components — confirm the order.
trajectory_feats_extractor_result = trajectory_feats_extractor(x, y, t)[0]
trajectory_feats_extractor_result

tensor([[ 1.0000e+00,  1.0000e+01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.0000e+02,  2.0000e+01,  4.6000e+01,  3.1094e+00,  3.1250e-01,
          5.9317e-02,  0.0000e+00],
        [ 2.0000e+02,  3.0000e+01,  1.8000e+01,  3.7963e+00,  0.0000e+00,
         -5.7581e-02, -5.7870e-03],
        [ 3.0500e+02,  2.0000e+01,  3.6000e+01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])

In [16]:
# Key weights for the example points, computed directly and via the lookup.
key_weights_result = key_weights_extractor(x, y, t)[0]
key_weights_with_lookup_result = key_weights_extractor_with_lookup(x, y, t)[0]

# Displays one boolean per variant: is the shape (points, key labels of interest)?
tuple(
    result.shape == (x.shape[0], len(key_labels_of_interest))
    for result in (key_weights_result, key_weights_with_lookup_result)
)

tensor([False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False,  True, False, False,
        False, False, False])
tensor([[ 416.7592,  907.5683,  340.1779,  642.7521,  866.7479,  449.0504,
          961.9003,  936.0984,  665.2977,   95.2694,  353.1519,  772.4353,
          592.0897,  545.6961,  679.7810,  500.6179,  588.8601,  525.3608,
          743.0721,  258.3583,  240.9320, 1033.2589,  167.8876,  467.8913,
          740.0569,  838.0215,  277.7737,  824.1213, 1056.6910,  993.2685,
          423.4678],
        [ 332.9508,  814.5675,  269.0576,  543.2921,  768.9610,  349.9861,
          863.5961,  836.6040,  581.3777,   87.9446,  254.8809,  675.4319,
          514.2589,  446.3454,  584.0020,  410.3441,  495.0922,  456.5183,
          654.9979,  162.7882,  231.5734,  933.7800,   86.2569,  412.1168,
          640.5593,  738.5181,  230.9378

(True, True)

In [17]:
# Display the key weights computed without the lookup wrapper.
key_weights_result

tensor([[1.4424e-02, 9.0980e-07, 6.2205e-02, 1.6922e-04, 2.0361e-06, 7.6786e-03,
         0.0000e+00, 3.1139e-07, 5.1812e-07, 1.0846e-04, 8.9283e-01, 4.8840e-02,
         1.3094e-05, 4.5976e-04, 1.1477e-03, 8.1497e-05, 2.7891e-03, 4.9000e-04,
         1.7135e-03, 2.3374e-05, 2.5003e-01, 3.1982e-01, 7.6160e-08, 6.6527e-01,
         5.3070e-03, 2.4807e-05, 3.5891e-06, 0.0000e+00, 1.8518e-01, 4.7219e-06,
         4.7963e-08, 1.6767e-07, 1.2658e-02],
        [7.1062e-02, 5.7016e-06, 2.1255e-01, 1.2034e-03, 1.4024e-05, 5.1825e-02,
         0.0000e+00, 2.1667e-06, 3.6909e-06, 5.6792e-04, 9.0589e-01, 2.6311e-01,
         8.8800e-05, 2.1323e-03, 8.0963e-03, 5.3927e-04, 1.6338e-02, 3.1094e-03,
         6.6334e-03, 1.3290e-04, 6.8730e-01, 3.6126e-01, 5.4238e-07, 9.0869e-01,
         1.5785e-02, 1.7670e-04, 2.5572e-05, 0.0000e+00, 3.6416e-01, 2.8422e-05,
         3.3605e-07, 1.0820e-06, 2.6308e-02],
        [2.4445e-01, 3.5120e-05, 4.1141e-01, 8.6675e-03, 9.7429e-05, 2.8101e-01,
         0.0000e+

In [18]:
# Display the key weights computed through the lookup wrapper, for comparison.
key_weights_with_lookup_result

tensor([[1.4424e-02, 9.0980e-07, 6.2205e-02, 1.6922e-04, 2.0361e-06, 7.6786e-03,
         0.0000e+00, 3.1139e-07, 5.1812e-07, 1.0846e-04, 8.9283e-01, 4.8840e-02,
         1.3094e-05, 4.5976e-04, 1.1477e-03, 8.1497e-05, 2.7891e-03, 4.9000e-04,
         1.7135e-03, 2.3374e-05, 2.5003e-01, 3.1982e-01, 7.6160e-08, 6.6527e-01,
         5.3070e-03, 2.4807e-05, 3.5891e-06, 0.0000e+00, 1.8518e-01, 4.7219e-06,
         4.7963e-08, 1.6767e-07, 1.2658e-02],
        [7.1062e-02, 5.7016e-06, 2.1255e-01, 1.2034e-03, 1.4024e-05, 5.1825e-02,
         0.0000e+00, 2.1667e-06, 3.6909e-06, 5.6792e-04, 9.0589e-01, 2.6311e-01,
         8.8800e-05, 2.1323e-03, 8.0963e-03, 5.3927e-04, 1.6338e-02, 3.1094e-03,
         6.6334e-03, 1.3290e-04, 6.8730e-01, 3.6126e-01, 5.4238e-07, 9.0869e-01,
         1.5785e-02, 1.7670e-04, 2.5572e-05, 0.0000e+00, 3.6416e-01, 2.8422e-05,
         3.3605e-07, 1.0820e-06, 2.6308e-02],
        [2.4445e-01, 3.5120e-05, 4.1141e-01, 8.6675e-03, 9.7429e-05, 2.8101e-01,
         0.0000e+

In [19]:
# Run the combined extractor on the example points and display its output:
# one tensor per wrapped extractor (trajectory features, then nearest-key tokens).
traj_feats_and_nearest_key_extractor(x, y, t)

[tensor([[ 1.0000e+00,  1.0000e+01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.0000e+02,  2.0000e+01,  4.6000e+01,  3.1094e+00,  3.1250e-01,
           5.9317e-02,  0.0000e+00],
         [ 2.0000e+02,  3.0000e+01,  1.8000e+01,  3.7963e+00,  0.0000e+00,
          -5.7581e-02, -5.7870e-03],
         [ 3.0500e+02,  2.0000e+01,  3.6000e+01,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]),
 tensor([[10],
         [23],
         [20],
         [11]])]