# Neural decoding

In [1]:
import pynwb

import numpy as np

from pynwb import NWBHDF5IO, NWBFile, TimeSeries

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import math
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import bisect
import ssm

import autograd.numpy as np
import autograd.numpy.random as npr


from data_loaders import *

## 1. Target Position (Cartesian Coordinates)

In [None]:
# Classifier

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

from sklearn.neighbors import KNeighborsClassifier
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

def clf(method, feature, cluster_assignments, cv=15):
    '''
    Cross-validate one classifier that predicts target ids from features.

    :param method: abbreviation of the classifier; one of
        'knn', 'svm', 'rf', 'gbm', 'lr' or 'nb'
    :param feature: feature matrix (input of the classifier), one row per trial
    :param cluster_assignments: target id of each trial (expected output)
    :param cv: number of cross-validation folds (default 15, as before)
    :return: array of per-fold scores from ``cross_val_score`` (the
        estimator's default score, i.e. mean accuracy for these classifiers)
    :raises ValueError: if ``method`` is not one of the supported names
    '''
    # The estimator is stored in ``model``, not ``clf``. Assigning to ``clf``
    # made the function's own name a local variable, so an unknown method
    # used to fall through to an UnboundLocalError instead of a clear error.
    if method == 'knn':
        model = KNeighborsClassifier(n_neighbors=1)
    elif method == 'svm':
        model = svm.SVC(kernel='linear')
    elif method == 'rf':
        model = RandomForestClassifier(n_estimators=100, random_state=42)
    elif method == 'gbm':
        model = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, random_state=42)
    elif method == 'lr':
        model = LogisticRegression(random_state=42)
    elif method == 'nb':
        model = GaussianNB()
    else:
        raise ValueError(
            f"There is no such method: {method!r}; expected one of "
            "'knn', 'svm', 'rf', 'gbm', 'lr', 'nb'")

    return cross_val_score(model, feature, cluster_assignments, cv=cv)

In [None]:
def CLFPerFile(filename, num_clusters, acqRate, newRate):
    '''
    Decode each trial's target from end-of-trial neural features of one NWB session.

    Trials are labelled by clustering their final cursor positions (``Cluster``).
    An input-driven LDS with Poisson emissions is fitted to the spike counts.
    Four end-of-trial feature sets are then scored with the six classifiers
    of ``clf`` using cross-validation.

    Side effects: draws two figures. The first shows the cluster centres. The
    second shows the LDS fixed points and each trial's final latent state,
    both projected onto the fixed points' first two principal components.

    :param filename: name of the NWB file
    :param num_clusters: number of clusters for k-means clustering (number of
        targets); at most 8, the number of plotting colours below
    :param acqRate: acquisition rate of data
    :param newRate: new rate of data (passed through to the loaders)
    :return: rObs: Mean classification accuracies for raw observations.
             ls: Mean classification accuracies for latent states.
             sObs: Mean classification accuracies for smoothed observations.
             ls2d: Mean classification accuracies for 2D projections of latent states.
             Each is an array of length 6 ordered knn, svm, rf, gbm, lr, nb.
    '''
    # Extract the per-trial cursor trajectories from the file
    with NWBHDF5IO(filename, "r") as io:
        read_nwbfile = io.read()
        lastTrialId = getLastTrialId(read_nwbfile, acqRate=acqRate, newRate=newRate)
        cursor_position_list = CursorPositionMatrix_list(
            read_nwbfile, lastTrialId, acqRate=acqRate, newRate=newRate)

    # Final cursor position of every trial; clustering these gives the target ids
    last_cursor = [cursor_position[-1, :] for cursor_position in cursor_position_list]

    # Row index of each trial's last time bin once trials are concatenated.
    # NOTE(review): assumes getData_for_LDS concatenates trials in the same
    # order and with the same lengths as cursor_position_list -- confirm.
    stop_Indices = []
    accumulated_idx = 0
    for cursor_position in cursor_position_list:
        accumulated_idx += len(cursor_position)
        stop_Indices.append(accumulated_idx - 1)

    colors = ['blue', 'black', 'red', 'green', 'm', 'orange', 'gray', 'c']
    cluster_assignments, center_xs, center_ys = Cluster(
        last_cursor, num_clusters=num_clusters, acqRate=acqRate, vis=False)
    plt.figure()
    for i, (center_x, center_y) in enumerate(zip(center_xs, center_ys)):
        plt.plot(center_x, center_y, 'o', color=colors[i], label=f'center {i+1}')
    plt.title('Target Position', fontsize=20)
    plt.legend()

    # Fit an input-driven LDS with Poisson emissions to the spike counts
    num_permutation = 5
    original_input, original_spike, p_input_matrices, p_spike_matrices = getData_for_LDS(
        filename, num_permutation=num_permutation,
        acqRate=acqRate, newRate=newRate, num_clusters=num_clusters)
    obs = original_spike
    obs_counts = obs.astype(int)
    obs_dim = obs.shape[1]
    input_dim = 4
    state_dim = 4

    lds_inp = ssm.LDS(obs_dim, state_dim, M=input_dim, emissions="poisson")
    elbos, q = lds_inp.fit(obs_counts, inputs=original_input, num_iters=20)
    state_means = q.mean_continuous_states[0]

    # Fixed point x* of x = A x + B u + b for a constant input u.
    # The first two entries of u are a target centre; the remaining entries
    # are 0. Solved as (I - A) x* = B u + b.
    A = lds_inp.dynamics.A
    b = lds_inp.dynamics.b
    B = lds_inp.dynamics.Vs
    FPs = []
    for center_x, center_y in zip(center_xs, center_ys):
        # A separate name keeps the LDS inputs (original_input) untouched
        fp_input = np.concatenate(
            [np.array([center_x, center_y]), np.zeros(input_dim - 2)]).reshape(-1, 1)
        x = np.linalg.solve(np.eye(A.shape[0]) - A,
                            (B @ fp_input).reshape(state_dim, 1) + b.reshape(state_dim, 1))
        FPs.append(x.reshape(1, -1))

    FPs = np.array(FPs).reshape(len(center_xs), state_dim)
    pca = PCA(n_components=2)
    pca.fit(FPs)
    FP_2d = pca.transform(FPs)
    plt.figure()
    for i, fp in enumerate(FP_2d):
        plt.scatter(fp[0], fp[1], s=200, marker='X', color=colors[i])
    plt.title('Fixed Points', fontsize=15)
    plt.xlabel('PC1', fontsize=15)
    plt.ylabel('PC2', fontsize=15)

    # Project each trial's final latent state onto the fixed points' PCs
    ls_2d = []
    for trialid, stop_idx in enumerate(stop_Indices):
        points_2d = pca.transform(state_means[stop_idx:stop_idx + 1, :])
        plt.plot(points_2d[:, 0], points_2d[:, 1], 'o',
                 color=colors[cluster_assignments[trialid]])
        ls_2d.append(points_2d[0])
    ls_2d = np.array(ls_2d)

    # The four end-of-trial feature sets
    # Raw observations
    last_original_spike = np.array([original_spike[idx] for idx in stop_Indices])
    # Latent state
    last_latent_state = np.array([state_means[idx] for idx in stop_Indices])
    # Smoothed observations (inferred firing rate). Pass the same inputs used
    # for fitting; without them ssm substitutes zeros and drops the emission
    # input term of the fitted model.
    smoothed_obs = lds_inp.smooth(state_means, obs_counts, input=original_input)
    last_smooth_obs = np.array([smoothed_obs[idx] for idx in stop_Indices])
    # (ls_2d above: 2-D projection of the last latent states)

    # Score every feature set with all six classifiers
    methods = ['knn', 'svm', 'rf', 'gbm', 'lr', 'nb']
    rObs = np.zeros(len(methods))
    ls = np.zeros(len(methods))
    sObs = np.zeros(len(methods))
    ls2d = np.zeros(len(methods))
    for idx, method in enumerate(methods):
        rObs[idx] = np.mean(clf(method, last_original_spike, cluster_assignments))
        ls[idx] = np.mean(clf(method, last_latent_state, cluster_assignments))
        sObs[idx] = np.mean(clf(method, last_smooth_obs, cluster_assignments))
        ls2d[idx] = np.mean(clf(method, ls_2d, cluster_assignments))

    return rObs, ls, sObs, ls2d


def plotClfs(matrix, dataname, chance=0.125):
    '''
    Plot cross-validation accuracy of each classifier, averaged across sessions.

    :param matrix: array with one row per session and one column per
        classifier, columns ordered knn, svm, rf, gbm, lr, nb
    :param dataname: the type of input data used for classification
        (goes into the saved file name)
    :param chance: height of the dashed chance-level line; the default 0.125
        is 1/8 for 8 targets (a 6-target session has chance 1/6)
    '''
    fontsize = 15
    means = np.mean(matrix, axis=0)
    std = np.std(matrix, axis=0)
    methods = ['knn', 'svm', 'rf', 'gbm', 'lr', 'nb']
    labels = [name.upper() for name in methods[:matrix.shape[1]]]
    fig, ax = plt.subplots()
    ax.bar(labels, means, yerr=std, width=0.6, edgecolor='black', capsize=15, color='gray')
    ax.axhline(y=chance, color='r', linestyle='--')
    # cross_val_score's default scoring for classifiers is accuracy, not a loss
    ax.set_ylabel('Accuracy', fontsize=fontsize)
    ax.set_title('Cross Validation Scores Across Classifiers', fontsize=fontsize + 5)
    ax.tick_params(axis='x', labelsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    fig.savefig(f'CV Scores Across Classifiers using {dataname}.svg')

In [None]:
# Run the target-classification pipeline on every session and collect the
# per-classifier mean accuracies (one row per session) for each feature set.
filenames = ["sub-monk-g_ses-session0.nwb",
             "sub-monk-g_ses-session1.nwb",
             "sub-monk-g_ses-session2.nwb",
             "sub-monk-g_ses-session3.nwb",
             "sub-monk-g_ses-session4.nwb",
             "sub-monk-j_ses-session0.nwb",
             "sub-monk-j_ses-session1.nwb",
             "sub-monk-j_ses-session2.nwb"]
robsMatrix = []
lsMatrix = []
sobsMatrix = []
ls2dMatrix = []
for filename in filenames:
    # Rates depend on the subject. Match the subject id rather than a bare
    # 'g', which would also match any other 'g' in a file name.
    if 'sub-monk-g' in filename:
        acqRate = 60
        newRate = 20
        print('g file')
    else:
        acqRate = 200
        newRate = 40
        print('j file')
    if filename == "sub-monk-j_ses-session1.nwb":
        num_cluster = 6
        print('The one with 6 targets')
    else:
        num_cluster = 8
    rObs, ls, sObs, ls2d = CLFPerFile(filename, num_cluster, acqRate, newRate)
    robsMatrix.append(rObs)
    lsMatrix.append(ls)
    sobsMatrix.append(sObs)
    ls2dMatrix.append(ls2d)

robsMatrix = np.array(robsMatrix)
lsMatrix = np.array(lsMatrix)
sobsMatrix = np.array(sobsMatrix)
ls2dMatrix = np.array(ls2dMatrix)

In [None]:
# One classifier-comparison bar plot per feature set:
# raw obs, latent states, smoothed obs, 2-D latent projection.
for _scores, _tag in ((robsMatrix, 'robs'),
                      (lsMatrix, 'ls'),
                      (sobsMatrix, 'sobs'),
                      (ls2dMatrix, 'ls2d')):
    plotClfs(_scores, _tag)

In [None]:
def plotFeatures(method, robsMatrix, lsMatrix, sobsMatrix, ls2dMatrix, chance=0.125):
    '''
    Plot one classifier's cross-validation accuracy for each feature set.

    Each matrix has one row per session and one column per classifier,
    ordered knn, svm, rf, gbm, lr, nb. An unknown method prints a message
    and draws nothing.

    :param method: classifier abbreviation (case-insensitive)
    :param robsMatrix: mean classification accuracies for raw observations.
    :param lsMatrix: mean classification accuracies for latent states.
    :param sobsMatrix: mean classification accuracies for smoothed observations.
    :param ls2dMatrix: mean classification accuracies for 2D projections of latent states.
    :param chance: height of the dashed chance-level line; the default 0.125
        is 1/8 for 8 targets
    '''
    method = method.lower()
    methods = ['knn', 'svm', 'rf', 'gbm', 'lr', 'nb']
    if method not in methods:
        print('This method is not included')
        return

    col = methods.index(method)
    # One column per feature set, one row per session
    matrix = np.vstack((robsMatrix[:, col], lsMatrix[:, col],
                        sobsMatrix[:, col], ls2dMatrix[:, col])).T
    fontsize = 15
    means = np.mean(matrix, axis=0)
    std = np.std(matrix, axis=0)
    features = ['ROBS', 'LS', 'SOBS', 'LS2D']
    fig, ax = plt.subplots()
    ax.bar(features, means, yerr=std, width=0.6, edgecolor='black', capsize=15, color='gray')
    ax.axhline(y=chance, color='r', linestyle='--')
    # cross_val_score's default scoring for classifiers is accuracy, not a loss
    ax.set_ylabel('Accuracy', fontsize=fontsize)
    ax.set_title(f'Cross Validation Scores Across Features({method.upper()})',
                 fontsize=fontsize + 5)
    ax.tick_params(axis='x', labelsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    fig.savefig(f'Cross Validation Scores Across Features({method.upper()}).svg')

In [None]:
# Compare the four feature sets separately for every classifier.
methods = ['knn', 'svm', 'rf', 'gbm', 'lr', 'nb']
for classifier_name in methods:
    plotFeatures(classifier_name, robsMatrix, lsMatrix, sobsMatrix, ls2dMatrix)

## 2. Velocity ($v_x, v_y$)

In [None]:
from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.linear_model import LinearRegression
from sklearn.kernel_ridge import KernelRidge
from sklearn.svm import SVR
from sklearn.datasets import make_regression
# Session used for velocity decoding; uncomment an alternative to switch.
# filename = "sub-monk-g_ses-session0.nwb"
filename = "sub-monk-g_ses-session2.nwb"
# filename = "sub-monk-g_ses-session3.nwb"
# filename = "sub-monk-j_ses-session1.nwb"
# Derive the rates and target count from the file name, using the same rule
# as the classification loop. Switching files above then cannot leave stale
# monkey-G settings in place for a monkey-J session.
if 'sub-monk-g' in filename:
    acqRate = 60
    newRate = 20
else:
    acqRate = 200
    newRate = 40
num_clusters = 6 if filename == "sub-monk-j_ses-session1.nwb" else 8
def LinearR(data, targets):
    '''
    Score ordinary least-squares regression with 8-fold cross-validation.

    The per-fold R^2 scores are printed before being averaged.

    :param data: inputs (one row per sample)
    :param targets: labels to regress
    :return: mean R^2 over the 8 folds
    '''
    fold_r2 = cross_val_score(LinearRegression(), data, targets, cv=8, scoring='r2')
    print('Linear score', fold_r2)
    return np.mean(fold_r2)

def kernelR(data, targets):
    '''
    Estimate the out-of-sample R^2 of kernel ridge regression with tuned hyperparameters.

    ``alpha`` and ``kernel`` are chosen by a 5-fold grid search. That search
    is nested inside each of the 8 outer cross-validation folds, so no outer
    score is computed on data that was used to pick the hyperparameters.
    The previous version tuned on all data first, which biases R^2 upward.

    :param data: inputs (one row per sample)
    :param targets: labels; 1-D or multi-output 2-D (R^2 is then averaged
        uniformly over outputs)
    :return: mean R^2 over the 8 outer folds
    '''
    param_grid = {
        'alpha': [0.1, 1, 10],
        'kernel': ['linear', 'poly', 'rbf']
    }
    grid_search = GridSearchCV(KernelRidge(), param_grid, cv=5, scoring='r2')
    # Nested CV: each outer fold re-runs the grid search on its training split only
    r2s = cross_val_score(grid_search, data, targets, cv=8, scoring='r2')
    print("Cross-Validation R^2 Scores:", r2s)

    # Refit on all data only to report which hyperparameters are preferred overall
    grid_search.fit(data, targets)
    print("Best Hyperparameters:", grid_search.best_params_)
    return np.mean(r2s)

def svr(data, targets):
    '''
    Estimate the out-of-sample R^2 of support vector regression with tuned hyperparameters.

    ``kernel``, ``C`` and ``gamma`` are chosen by a 5-fold grid search. That
    search is nested inside each of the 8 outer cross-validation folds, so no
    outer score is computed on data that was used for tuning.

    :param data: inputs (one row per sample)
    :param targets: 1-D labels (sklearn's SVR does not accept multi-output targets)
    :return: mean R^2 over the 8 outer folds
    '''
    param_grid = {
        'kernel': ['poly', 'rbf'],
        'C': [0.1, 1, 10],
        'gamma': [0.1, 1, 10]
    }
    grid_search = GridSearchCV(SVR(), param_grid, cv=5, scoring='r2')
    # Nested CV: each outer fold re-runs the grid search on its training split only
    r2s = cross_val_score(grid_search, data, targets, cv=8, scoring='r2')
    print("Cross-Validation R^2 Scores:", r2s)
    print("Mean Cross-Validation R^2:", np.mean(r2s))

    # Refit on all data only to report the preferred hyperparameters and their inner-CV score
    grid_search.fit(data, targets)
    print("Best Hyperparameters:", grid_search.best_params_)
    print("Best R^2 Score:", grid_search.best_score_)
    return np.mean(r2s)

# Load the per-trial cursor trajectories for the chosen session
with NWBHDF5IO(filename, "r") as io:
    read_nwbfile = io.read()
    lastTrialId = getLastTrialId(read_nwbfile, acqRate=acqRate, newRate=newRate)
    cursor_position_list = CursorPositionMatrix_list(
        read_nwbfile, lastTrialId, acqRate=acqRate, newRate=newRate)

# Final cursor position of every trial (clustered below into target ids)
last_cursor = [cursor_position[-1, :] for cursor_position in cursor_position_list]

# Velocity at the end of each trial: displacement over the last 3 bins
# (p[t] - p[t-3]) divided by 3 bin widths. The bin width is presumably
# 1/newRate -- confirm the units. The previous np.diff([p[-1], p[-4]])
# computed p[t-3] - p[t], i.e. the negated velocity.
vs = np.array([(cursor_position[-1, :] - cursor_position[-4, :]) * newRate / 3
               for cursor_position in cursor_position_list]).reshape(-1, 2)

# Row index of each trial's last time bin once trials are concatenated
stop_Indices = []
accumulated_idx = 0
for cursor_position in cursor_position_list:
    accumulated_idx += len(cursor_position)
    stop_Indices.append(accumulated_idx - 1)

colors = ['blue', 'black', 'red', 'green', 'm', 'orange', 'gray', 'c']
cluster_assignments, center_xs, center_ys = Cluster(
    last_cursor, num_clusters=num_clusters, acqRate=acqRate, vis=False)
plt.figure()
for i, center in enumerate(center_xs):
    plt.plot(center, center_ys[i], 'o', color=colors[i], label=f'center {i+1}')
plt.title('Target Position', fontsize=20)
plt.legend()

# Fit an input-driven LDS with Poisson emissions to the spike counts
num_permutation = 5
original_input, original_spike, p_input_matrices, p_spike_matrices = getData_for_LDS(
    filename, num_permutation=num_permutation,
    acqRate=acqRate, newRate=newRate, num_clusters=num_clusters)
obs = original_spike
inputs = original_input
obs_counts = obs.astype(int)
obs_dim = obs.shape[1]
input_dim = 4
state_dim = 4

lds_inp = ssm.LDS(obs_dim, state_dim, M=input_dim, emissions="poisson")
elbos, q = lds_inp.fit(obs_counts, inputs=inputs, num_iters=20)
state_means = q.mean_continuous_states[0]

# Fixed point x* of x = A x + B u + b for a constant input u.
# The first two entries of u are a target centre; the remaining entries
# are 0. Solved as (I - A) x* = B u + b.
A = lds_inp.dynamics.A
b = lds_inp.dynamics.b
B = lds_inp.dynamics.Vs
FPs = []
for center_x, center_y in zip(center_xs, center_ys):
    # A separate name keeps `inputs` (the LDS inputs) from being overwritten
    fp_input = np.concatenate(
        [np.array([center_x, center_y]), np.zeros(input_dim - 2)]).reshape(-1, 1)
    x = np.linalg.solve(np.eye(A.shape[0]) - A,
                        (B @ fp_input).reshape(state_dim, 1) + b.reshape(state_dim, 1))
    FPs.append(x.reshape(1, -1))

FPs = np.array(FPs).reshape(len(center_xs), state_dim)
pca = PCA(n_components=2)
pca.fit(FPs)
FP_2d = pca.transform(FPs)
plt.figure()
for i, fp in enumerate(FP_2d):
    plt.scatter(fp[0], fp[1], s=200, marker='X', color=colors[i])
plt.title('Fixed Points', fontsize=15)
plt.xlabel('PC1', fontsize=15)
plt.ylabel('PC2', fontsize=15)

# Project each trial's final latent state onto the fixed points' PCs
ls_2d = []
for trialid, stop_idx in enumerate(stop_Indices):
    points_2d = pca.transform(state_means[stop_idx:stop_idx + 1, :])
    plt.plot(points_2d[:, 0], points_2d[:, 1], 'o', color=colors[cluster_assignments[trialid]])
    ls_2d.append(points_2d[0])
ls_2d = np.array(ls_2d)

# The four end-of-trial feature sets
# Raw observations
last_original_spike = np.array([original_spike[idx] for idx in stop_Indices])
# Latent state
last_latent_state = np.array([state_means[idx] for idx in stop_Indices])
# Smoothed observations (inferred firing rate). Pass the inputs used for
# fitting; without them ssm substitutes zeros and drops the emission input term.
smoothed_obs = lds_inp.smooth(state_means, obs_counts, input=original_input)
last_smooth_obs = np.array([smoothed_obs[idx] for idx in stop_Indices])
# (ls_2d above: 2-D projection of the last latent states)

# Feature sets, in the order: raw obs, latent state, smoothed obs, 2-D latent projection
_feature_sets = (last_original_spike, last_latent_state, last_smooth_obs, ls_2d)

# Linear regression decodes v_x and v_y separately
x_robs_lr, x_ls_lr, x_sobs_lr, x_ls2d_lr = [LinearR(feats, vs[:, 0]) for feats in _feature_sets]
y_robs_lr, y_ls_lr, y_sobs_lr, y_ls2d_lr = [LinearR(feats, vs[:, 1]) for feats in _feature_sets]

# Kernel ridge decodes the 2-D velocity jointly
robs_kr, ls_kr, sobs_kr, ls2d_kr = [kernelR(feats, vs) for feats in _feature_sets]

In [None]:
# One R^2 bar plot per feature set. The kernel-ridge bar is the R^2 of the
# joint (v_x, v_y) fit, i.e. averaged uniformly over the two components.
labels = ['linearRx', 'linearRy', 'kernelR']
_r2_by_feature = {
    'robs': [x_robs_lr, y_robs_lr, robs_kr],
    'ls': [x_ls_lr, y_ls_lr, ls_kr],
    'sobs': [x_sobs_lr, y_sobs_lr, sobs_kr],
    'ls2d': [x_ls2d_lr, y_ls2d_lr, ls2d_kr],
}
for _feature_name, _r2_values in _r2_by_feature.items():
    plt.figure()
    plt.bar(labels, _r2_values)
    plt.ylabel('R2')
    # Titles make the four otherwise identical-looking figures distinguishable
    plt.title(f'Velocity decoding R2 ({_feature_name})')
    print(f'R2 of {_feature_name}_kr', _r2_values[2])