In [12]:
import ast

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

# Filtering thresholds for bistrokes.tsv:
#   wpm            -- samples whose first field (presumably typing speed in WPM) is
#                     below this are discarded.
#   bg_min_samples -- bigrams left with fewer samples than this are dropped.
wpm, bg_min_samples = 90, 50

In [13]:
# Helper function IQR average for time processing later
def get_iqr_avg(data):
    Q1 = np.percentile(data, 25)
    Q3 = np.percentile(data, 75)
    IQR = Q3-Q1

    lower_bound = Q1 - 1.5 * IQR
    upper_bound = Q3 + 1.5 * IQR

    new_data = [x for x in data if x >= lower_bound and x <= upper_bound]
        
    return sum(new_data)/len(new_data)

# Hand-rolled parser: much faster than eval() for the simple integer tuples in
# bistrokes.tsv.
def str_to_tuple(s):
    """Parse an integer-tuple literal such as ``"(95, 120)"`` into a tuple of ints.

    Whitespace around the numbers is ignored, so ``"(95,120)"`` parses too.
    A single trailing comma is allowed (``"(7,)"`` -> ``(7,)``), and ``"()"``
    gives the empty tuple, matching what ``eval`` would return for these inputs.

    Raises:
        ValueError: if any comma-separated field is not an integer literal
            (including an empty field in the middle, e.g. ``"(1, , 2)"``).
    """
    parts = s.strip().strip("()").split(",")
    # Drop only a trailing empty field, produced by "(7,)" or "()".
    if not parts[-1].strip():
        parts.pop()
    # int() itself tolerates surrounding whitespace such as " 120".
    return tuple(int(part) for part in parts)

# Bigrams containing any of these characters are skipped: shifted keys
# (capital letters, ":", "<", ">", "?") and space.
EXCLUDED_CHARS = set("QWERTYUIOPASDFGHJKL:ZXCVBNM<>? ")

# bistrokes.tsv, one bigram per line, tab-separated:
#   key positions "((ax, ay), (bx, by))", the bigram, then one "(a, b)" sample per column.
# Samples whose first field is below `wpm` are dropped, then bigrams with fewer
# than `bg_min_samples` remaining samples are dropped.
# Each entry of bistroke_data is (positions, bigram, *samples).
bistroke_data = []
with open("bistrokes.tsv") as bistroke_file:
    for line in bistroke_file:
        if not line.strip():
            continue  # tolerate blank lines, e.g. at end of file
        positions, bigram, *raw_samples = line.strip().split("\t")
        if any(ch in EXCLUDED_CHARS for ch in bigram):
            continue
        samples = [s for s in map(str_to_tuple, raw_samples) if s[0] >= wpm]
        if len(samples) >= bg_min_samples:
            # literal_eval only accepts Python literals; eval() would execute any
            # code that ended up in the data file.
            bistroke_data.append((ast.literal_eval(positions), bigram, *samples))

# bigrams.txt: "<bigram>\t<count>" per line. The key is not stripped because a
# bigram may itself contain whitespace.
bigram_to_freq = {}
with open("bigrams.txt") as bigram_file:
    for line in bigram_file:
        if not line.strip():
            continue  # tolerate blank lines, e.g. at end of file
        bigram, count = line.rstrip("\n").split("\t")
        bigram_to_freq[bigram] = int(count)

In [14]:
# Per-bigram features for the model fit, one entry per row of bistroke_data.
n_bigrams = len(bistroke_data)
times = np.zeros(n_bigrams)  # outlier-trimmed mean of each sample's t[1]
sfb = np.zeros(n_bigrams)    # 1.0 for a same-finger bigram, 0.0 otherwise
freqs = np.zeros(n_bigrams)  # count looked up in bigram_to_freq
col = ["green"] * n_bigrams  # scatter colour; "red" marks same-finger bigrams


def _sign(x):
    """Return -1, 0 or 1 by the sign of ``x``.

    Replaces ``x // abs(x)``, which raises ZeroDivisionError when ``x == 0``.
    """
    return (x > 0) - (x < 0)


for i, bs in enumerate(bistroke_data):
    ((ax, ay), (bx, by)), bigram, *bistroke_times = bs

    times[i] = get_iqr_avg([t[1] for t in bistroke_times])

    # Same hand: both x coordinates share a sign (presumably negative x is the
    # left hand and positive x the right -- TODO confirm against the layout).
    shb = _sign(ax) == _sign(bx)
    # Same column: identical x coordinate.
    scb = ax == bx
    # Same finger: same column, or same hand with both keys in columns
    # |x| in (1, 2) (presumably the two index-finger columns -- TODO confirm).
    sfb[i] = scb or (shb and abs(ax) in (1, 2) and abs(bx) in (1, 2))

    if sfb[i]:
        col[i] = "red"

    freqs[i] = bigram_to_freq[bigram]

In [15]:
def get_time(features, p0, p1, p2, p3):
    """Typing-time model, used as the ``curve_fit`` target function.

    ``features`` is a pair ``(freq, sfb)``. The prediction is the frequency
    term ``p0 * log(freq + p1) + p2`` multiplied by ``sfb + p3``. Operates
    element-wise, so NumPy arrays and scalars both work.
    """
    freq_values, sfb_values = features
    frequency_term = p0 * np.log(freq_values + p1) + p2
    return frequency_term * (sfb_values + p3)

# Fit the four parameters of get_time with the Trust Region Reflective solver,
# using a generous budget of function evaluations. No initial guess (p0) is
# passed, so curve_fit starts from its default.
bg_popt, bg_pcov = curve_fit(
    get_time,
    [freqs, sfb],
    times,
    method="trf",
    maxfev=750000,
)

# Evaluate the fitted model on the training points and score it.
new_y = get_time([freqs, sfb], *bg_popt)
residuals = times - new_y
sum_of_squares = np.sum((times - np.mean(times)) ** 2)  # total sum of squares
r2 = 1 - np.sum(residuals ** 2) / sum_of_squares

print("R^2:", r2)
print("MAE:", np.mean(np.abs(residuals)))


R^2: 0.5645246941688662
MAE: 16.495250637296326


In [21]:
%matplotlib qt

import matplotlib.pyplot as plt

plt.figure()

#xx, yy, ll, fit_y, c = zip(*sorted([r for r in zip(freqs, times, bg_labels, new_y) if r[-1] != "blue"], key=lambda x: x[0], reverse=True))
xx, yy, fit_y, cc = zip(*sorted([r for r in zip(freqs, times, new_y, col)], key = lambda x: -x[0]))
# xx = list((range(len(xx))))
scatter = plt.scatter(xx, yy, s=50, c=cc)

plt.plot(xx, fit_y, c="black")
plt.xlabel("Number of Occurences")
plt.ylabel("Average Typing Time (Milliseconds)")
plt.xscale("log")

plt.show()