In [1]:
import torch
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from KeyboardModel import KeyboardModel

In [2]:
sentence_h5_path = '../../Data/r01/recordings_01_typing_enrollment01_text_typing01.h5'

df = pd.read_hdf(sentence_h5_path)
df = df[['thumb3', 'thumb4', 'dev_pos', 'keys']]

In [66]:
def get_restored_sentence(df_):
    df_tmp = df_.copy()
    # restore typed sentence
    df_tmp['keys_prev'] = df_tmp['keys'].shift(1).bfill()
    pressed_keys_vectors = df_tmp.apply(lambda x: x['keys_prev'] if not np.array_equal(x['keys_prev'], x['keys']) else np.nan, axis=1).dropna() # get rows where pressing state was changed
    pressed_keys_indices = pressed_keys_vectors.apply(lambda x: np.argmax(x) if x.sum() > 0 else -1) # extract indices from one-hot vectors (pressing states)
    pressed_keys_asciis = pressed_keys_indices.values[np.where(pressed_keys_indices.values >= 0)[0]] + 65 # extract ascii chars
    pressed_keys_asciis[np.where(pressed_keys_asciis == 91)[0]] = 32 # fix space ascii char
    restored_sentence = ''.join(chr(i) for i in pressed_keys_asciis) # convert array of ascii chars to visible text
    restored_sentence = restored_sentence.lower()
    
    return restored_sentence

In [5]:
def get_minmax_values(df_):
    df_tmp = df_.copy()

    minmax_values = {'thumb3_rel': {'min': [np.inf, np.inf, np.inf], 'max': [-np.inf, -np.inf, -np.inf]},
                     'thumb4_rel': {'min': [np.inf, np.inf, np.inf], 'max': [-np.inf, -np.inf, -np.inf]}}

    # Get the relative positions
    df_tmp['thumb3_rel'] = df_tmp['thumb3'] - df_tmp['dev_pos']
    df_tmp['thumb4_rel'] = df_tmp['thumb4'] - df_tmp['dev_pos']

    # Create separation between x, y and z values
    for pos in range(3):
        df_tmp[f'thumb3_rel_{pos}'] = df_tmp['thumb3_rel'].apply(lambda x: x[pos])
        df_tmp[f'thumb4_rel_{pos}'] = df_tmp['thumb4_rel'].apply(lambda x: x[pos])

    # Calculate the min and max values
    for pos in range(3):
        minmax_values['thumb3_rel']['min'][pos] = min(minmax_values['thumb3_rel']['min'][pos], min(df_tmp[f'thumb3_rel_{pos}']))
        minmax_values['thumb3_rel']['max'][pos] = max(minmax_values['thumb3_rel']['max'][pos], max(df_tmp[f'thumb3_rel_{pos}']))

        minmax_values['thumb4_rel']['min'][pos] = min(minmax_values['thumb4_rel']['min'][pos], min(df_tmp[f'thumb4_rel_{pos}']))
        minmax_values['thumb4_rel']['max'][pos] = max(minmax_values['thumb4_rel']['max'][pos], max(df_tmp[f'thumb4_rel_{pos}']))

    return minmax_values

def calculate_avg_distance_diff_for_pressed_seq(df_):
    distance_diff_dict = {'thumb3_rel': [], 'thumb4_rel': []}

    curr_row = 0
    while curr_row < len(df_):
        if df_['press'].iloc[curr_row] == 1:
            start_point_thumb3_rel = np.array(df_['thumb3_rel'].iloc[curr_row])
            start_point_thumb4_rel = np.array(df_['thumb4_rel'].iloc[curr_row])

            while curr_row < len(df_) and df_['press'].iloc[curr_row] == 1:
                curr_row += 1

            end_point_thumb3_rel = np.array(df_['thumb3_rel'].iloc[curr_row - 1])
            end_point_thumb4_rel = np.array(df_['thumb4_rel'].iloc[curr_row - 1])

            distance_diff_dict['thumb3_rel'].append(np.linalg.norm(end_point_thumb3_rel - start_point_thumb3_rel))
            distance_diff_dict['thumb4_rel'].append(np.linalg.norm(end_point_thumb4_rel - start_point_thumb4_rel))

        else:
            curr_row += 1

    distance_diff_dict['thumb3_rel'] = np.mean(distance_diff_dict['thumb3_rel'])
    distance_diff_dict['thumb4_rel'] = np.mean(distance_diff_dict['thumb4_rel'])

    return distance_diff_dict


def is_rows_close(row1, row2, distance_diff_dict, threshold=0.5):
    thumb3_diff = np.linalg.norm(np.array(row1['thumb3_rel']) - np.array(row2['thumb3_rel']))
    thumb4_diff = np.linalg.norm(np.array(row1['thumb4_rel']) - np.array(row2['thumb4_rel']))

    if thumb3_diff < distance_diff_dict['thumb3_rel'] * threshold or thumb4_diff < distance_diff_dict['thumb4_rel'] * threshold:
        return True

    return False


def create_filtered_df(df_):
    avg_diff_dist_for_pressed_seq = calculate_avg_distance_diff_for_pressed_seq(df_)

    rows = []

    curr_row = 0
    while curr_row < len(df_):
        rows.append(df_.iloc[curr_row])

        next_row = curr_row + 1
        while next_row < len(df_) and is_rows_close(df_.iloc[curr_row], df_.iloc[next_row],  avg_diff_dist_for_pressed_seq):
            next_row += 1

        curr_row = next_row

    return pd.DataFrame(rows)

def normalize_and_filter_df(df_, to_filter=False):
    minmax_values = get_minmax_values(df_)

    # Get the relative positions
    df_['thumb3_rel'] = df_['thumb3'] - df_['dev_pos']
    df_['thumb4_rel'] = df_['thumb4'] - df_['dev_pos']

    # Normalize the relative positions
    df_['thumb3_rel'] = df_['thumb3_rel'].apply(
        lambda x: [(x[i] - minmax_values['thumb3_rel']['min'][i]) / (minmax_values['thumb3_rel']['max'][i] - minmax_values['thumb3_rel']['min'][i]) for i in range(3)])

    df_['thumb4_rel'] = df_['thumb4_rel'].apply(
        lambda x: [(x[i] - minmax_values['thumb4_rel']['min'][i]) / (minmax_values['thumb4_rel']['max'][i] - minmax_values['thumb4_rel']['min'][i]) for i in range(3)])

    df_['press'] = df_['keys'].apply(lambda x: 1 if x.any() else 0)

    df_['coords'] = df_[['thumb3_rel', 'thumb4_rel']].apply(lambda row: list(row['thumb3_rel']) + list(row['thumb4_rel']), axis=1)

    df_ = df_[['coords', 'thumb3_rel', 'thumb4_rel', 'press', 'keys']]
    
    if to_filter:
        df_ = create_filtered_df(df_)

    return df_

In [6]:
normalized_df = normalize_and_filter_df(df)
get_restored_sentence(normalized_df)

'IN THE PROCESS THEY DEMONSTRATED THET THERE IS STILL A LOT OF LIFE IN HONG KONG CINEMA'

In [7]:
filtered_df = normalize_and_filter_df(df, to_filter=True)
get_restored_sentence(filtered_df)

'IN THEPOCESTHEY DEMONSTRATEDTHET THEE IS TILLALO OF LIFE IN HONGKON INEMA'

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
keyboard_model = KeyboardModel(device=device)

In [9]:
normalized_df.head()

Unnamed: 0,coords,thumb3_rel,thumb4_rel,press,keys
0,"[0.273785481510138, 0.2301122053542864, 0.0311...","[0.273785481510138, 0.2301122053542864, 0.0311...","[0.3615041616703537, 0.1579752358411023, 0.003...",0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,"[0.27412782461549884, 0.22963383239731122, 0.0...","[0.27412782461549884, 0.22963383239731122, 0.0...","[0.36138120817745994, 0.15777145608280213, 0.0...",0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"[0.27463423783243324, 0.22896647409568494, 0.0...","[0.27463423783243324, 0.22896647409568494, 0.0...","[0.3612606622535732, 0.1573589728755944, 0.002...",0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,"[0.274367760849553, 0.22853135414597187, 0.030...","[0.274367760849553, 0.22853135414597187, 0.030...","[0.3617540682712919, 0.15708120207817772, 0.00...",0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,"[0.274645093004048, 0.22795418624399272, 0.029...","[0.274645093004048, 0.22795418624399272, 0.029...","[0.36194800535651106, 0.1561074795908664, 0.00...",0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [10]:
def predict_number_of_typing_press(keyboard_model_: KeyboardModel, df_: pd.DataFrame, device_, prefix_percentage=0.5):
    total_number_of_presses_pred = 0
    total_number_of_presses_gt = 0
    
    seq_len = keyboard_model_.typing_seq_len
    
    curr_position = int(len(df_) * prefix_percentage)
    end_position = len(df_)
    
    while curr_position < end_position - seq_len:
        # Gather the input data
        input_data = df_.iloc[curr_position - seq_len:curr_position]['coords'].tolist()
        input_data = torch.tensor(input_data)
        input_data = input_data.type(torch.float32).to(device_)
        
        # Predict the next key
        outputs = keyboard_model_.typing_model(input_data)
        next_probability = outputs.squeeze(-1)[-1].detach().cpu()
        next_prediction = (torch.sigmoid(next_probability) > 0.5).item()
        
        if next_prediction:
            total_number_of_presses_pred += 1
            
        total_number_of_presses_gt += df_.iloc[curr_position]['press']
        
        curr_position += 1
    
    return {'pred': total_number_of_presses_pred, 'gt': total_number_of_presses_gt}

In [38]:
def predict_number_of_new_chars(keyboard_model_: KeyboardModel, df_: pd.DataFrame, device_, prefix_percentage=0.5):
    results_typing_press = predict_number_of_typing_press(keyboard_model_, df_, device_, prefix_percentage)
    typing_press_pred = results_typing_press['pred']
    typing_press_gt = results_typing_press['gt']
    
    # calculate ratio of typing presses / characters
    ratio = typing_press_gt / len(get_restored_sentence(df_))
    
    return {'typing_press_pd': typing_press_pred, 'typing_press_gt': typing_press_gt, 'ratio': ratio,
            'new_chars_pred': int(typing_press_pred / ratio), 'new_chars_gt': int(typing_press_gt / ratio)}

In [40]:
def get_tokenization_dict():
    char_to_idx = {'pad':0, 'eos': 1}
    for i in range(0, 255):
        char_to_idx[chr(i)] = i + 2
        
    idx_to_char = {v: k for k, v in char_to_idx.items()}
    
    return char_to_idx, idx_to_char

In [41]:
def tokenize_sentence(sentence, char_to_idx):
    return torch.tensor([char_to_idx[char] for char in sentence])

In [70]:
def restore_sentence(keyboard_model_: KeyboardModel, df_: pd.DataFrame, device_, to_concat_gt = False, prefix_percentage=0.5):
    sentence_prefix = get_restored_sentence(df_.iloc[:int(len(df_) * prefix_percentage)])
    restored_sentence = sentence_prefix

    sentence_all = get_restored_sentence(df_)
    curr_sentence_idx = len(sentence_prefix)

    char_to_idx, idx_to_char = get_tokenization_dict()
    sentence_prefix_tokenized = tokenize_sentence(sentence_prefix, char_to_idx)

    input_ids = sentence_prefix_tokenized.view(-1, 1).to(device_)
    results_num_new_chars = predict_number_of_new_chars(keyboard_model_, df_, device_, prefix_percentage)

    for i in range(results_num_new_chars['new_chars_gt']):
        outputs = keyboard_model_.char_model(input_ids)['logits']
        
        _, predicted = torch.max(outputs, -1)
        predicted = predicted.view(-1)[-1]
        
        restored_sentence += idx_to_char[predicted.item()]
        
        if to_concat_gt and curr_sentence_idx < len(sentence_all):
            next_char = char_to_idx[sentence_all[curr_sentence_idx]]
            next_char = torch.tensor([next_char]).view(-1, 1).to(device_)
            
        else:
            next_char = predicted.view(-1, 1).to(device_)
            
        input_ids = torch.cat((input_ids, next_char), dim=0)
        
        curr_sentence_idx += 1
        
    return {
        'sentence_prefix': sentence_prefix,
        'restored_sentence': restored_sentence,
        'sentence_all': sentence_all,
    }

In [73]:
def calculate_recording_files(data_dir):
    recordings = {}
    for folder_name in os.listdir(data_dir):
        if folder_name.startswith('r'):
            for file_name in os.listdir(f'{data_dir}/{folder_name}'):
                if file_name.endswith('.h5'):
                    if folder_name not in recordings:
                        recordings[folder_name] = []
                    recordings[folder_name].append(f'{data_dir}/{folder_name}/{file_name}')

    return recordings

In [74]:
calculate_recording_files('../../Data')

{'r24': ['../../Data/r24/recordings_24_typing_enrollment03_text_typing03.h5',
  '../../Data/r24/recordings_24_typing_enrollment02_text_typing02.h5',
  '../../Data/r24/recordings_24_typing_enrollment01_text_typing01.h5'],
 'r01': ['../../Data/r01/recordings_01_typing_enrollment01_text_typing01.h5',
  '../../Data/r01/recordings_01_typing_enrollment02_text_typing02.h5',
  '../../Data/r01/recordings_01_typing_enrollment03_text_typing03.h5'],
 'r23': ['../../Data/r23/recordings_23_typing_enrollment03_text_typing03.h5',
  '../../Data/r23/recordings_23_typing_enrollment02_text_typing02.h5',
  '../../Data/r23/recordings_23_typing_enrollment01_text_typing01.h5'],
 'r04': ['../../Data/r04/recordings_04_typing_enrollment03_text_typing03.h5',
  '../../Data/r04/recordings_04_typing_enrollment02_text_typing02.h5',
  '../../Data/r04/recordings_04_typing_enrollment01_text_typing01.h5'],
 'r25': ['../../Data/r25/recordings_25_typing_enrollment03_text_typing04.h5',
  '../../Data/r25/recordings_25_typing