In [1]:
import joblib
import numpy as np
import pandas as pd
import yaml
from brand import initializeRedisFromYAML
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split

from utils import decode_field, get_lagged_features, smooth_data

In [2]:
# Connect to Redis
r = initializeRedisFromYAML('replayTest.yaml')

connecting to Redis using: replayTest.yaml
Redis Socket Path /var/run/redis.sock
Initialized Redis


In [3]:
with open('stream_spec.yaml', 'r') as f:
    stream_spec = yaml.safe_load(f)

In [4]:
# Load data from Redis
# taskInput
all_entries = r.xread({b'taskInput': 0})
stream_entries = all_entries[0][1]
stream_data = [entry[1] for entry in stream_entries]

task_input = pd.DataFrame(stream_data)
for field in ['timestamps', 'samples']:
    task_input[field] = task_input[field.encode()].apply(
        decode_field, stream='taskInput', field=field, stream_spec=stream_spec)
task_input = task_input.set_index('timestamps')

# thresholdCrossings
all_entries = r.xread({b'thresholdCrossings': 0})
stream_entries = all_entries[0][1]
stream_data = [entry[1] for entry in stream_entries]

threshold_crossings = pd.DataFrame(stream_data)
for field in ['timestamps', 'crossings']:
    threshold_crossings[field] = threshold_crossings[field.encode()].apply(
        decode_field,
        stream='thresholdCrossings',
        field=field,
        stream_spec=stream_spec)

# Separate channels into their own columns
tc_timestamps = threshold_crossings['timestamps'].values
crossings = np.stack(threshold_crossings['crossings'])
n_chans = crossings.shape[1]

channel_labels = [f'ch{i :03d}' for i in range(n_chans)]
tc_df = pd.DataFrame(crossings,
                     index=tc_timestamps + 13,
                     columns=channel_labels)

In [5]:
joined_df = task_input.join(tc_df, how='inner')
joined_df.index = pd.to_timedelta(joined_df.index / 30, unit='ms')

In [6]:
samples = np.stack(joined_df['samples'])
joined_df['touch'] = samples[:, 0]
joined_df['x'] = samples[:, 1]
joined_df['y'] = samples[:, 2]

In [7]:
bin_size_ms = 5  # ms
gauss_width_ms = None # or 20 ms, for smoothing
binned_data = joined_df.resample(f'{bin_size_ms :d}ms').sum()
binned_data

Unnamed: 0,ch000,ch001,ch002,ch003,ch004,ch005,ch006,ch007,ch008,ch009,...,ch089,ch090,ch091,ch092,ch093,ch094,ch095,touch,x,y
00:45:15.803000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1330.0,75.0,-39.0
00:45:15.808000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1329.0,70.0,-46.0
00:45:15.813000,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1323.0,62.0,-39.0
00:45:15.818000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1320.0,61.0,-37.0
00:45:15.823000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1324.0,63.0,-45.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
00:46:43.623000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1325.0,157.0,6351.0
00:46:43.628000,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1320.0,157.0,6529.0
00:46:43.633000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1319.0,172.0,6845.0
00:46:43.638000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1320.0,157.0,7403.0


In [8]:
if gauss_width_ms is None:
    neural_data = binned_data[channel_labels].values
else:
    neural_data = smooth_data(binned_data[channel_labels].values,
                            bin_size=bin_size_ms,
                            gauss_width=gauss_width_ms)

X = get_lagged_features(neural_data, n_history=50)
y = binned_data[['x', 'y']].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

with joblib.parallel_backend('loky'):
    mdl = RidgeCV(alphas=np.logspace(-1, 3, 20))
mdl.fit(X_train, y_train)

print(f'Best L2: {mdl.alpha_}')
print(f'Train R^2: {mdl.score(X_train, y_train)}')
print(f'Test R^2: {mdl.score(X_test, y_test)}')

Best L2: 615.8482110660261
Train R^2: 0.5902117270226008
Test R^2: 0.4305570490634626
