In [None]:
from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import matplotlib.pyplot as plt
from Area2_analysis.lr_funcs import nans
from pyglmnet import GLMCV, GLM
from scipy import stats
from Area2_analysis.glm_funcs import resample, mp_glm_pr2
import multiprocessing as mp

In [None]:
# Load one Han session (center-out task, active and passive trials) from its NWB file.
foldername = "~/area2_population_analysis/s1-kinematics/actpas_NWB/"
monkey = "Han_20171207"
filename = f"{foldername}{monkey}_COactpas_TD.nwb"
# NOTE(review): confirm NWBDataset expands the leading '~' in this path.
dataset = NWBDataset(filename, split_heldout=False)

# Hand acceleration = first difference of hand velocity along time. This is a per-bin
# change; it is not divided by bin_width. The first velocity sample is prepended so the
# output keeps the input's length, which makes the first row zero.
xy_vel = dataset.data['hand_vel'].to_numpy()
first_vel_row = [xy_vel[0]]
xy_acc = np.diff(xy_vel, axis=0, prepend=first_vel_row)
dataset.add_continuous_data(xy_acc, 'hand_acc', chan_names=['x', 'y'])

bin_width = dataset.bin_width
print(bin_width)

In [None]:
# Boolean trial masks. Only trials assigned to a split (split != 'none') are kept.
# Active trials have no center-hold bump; passive trials have one.
all_mask = dataset.trial_info.split != 'none'
active_mask = all_mask & ~dataset.trial_info.ctr_hold_bump
passive_mask = all_mask & dataset.trial_info.ctr_hold_bump

# Number of recorded units (columns of the spikes table)
n_neurons = dataset.data['spikes'].shape[1]
print(n_neurons, 'neurons')

In [None]:
# Per-trial condition labels (direction index 0-7) for stratified cross-validation.
# Active trials are labelled by reach direction ('cond_dir'); passive trials by bump
# direction ('bump_dir').
directions = [0, 45, 90, 135, 180, 225, 270, 315]
# direction (deg) -> condition index. Python dict lookup also matches float-valued
# directions (45.0 hashes/compares equal to 45).
dir_to_cond = {direction: cond for cond, direction in enumerate(directions)}

active_info = dataset.trial_info.loc[active_mask]
active_trials_idx = np.array(active_info['trial_id'])
active_n_trials = active_info.shape[0]
print(active_n_trials,'active trials')

passive_info = dataset.trial_info.loc[passive_mask]
passive_trials_idx = np.array(passive_info['trial_id'])
passive_n_trials = passive_info.shape[0]
print(passive_n_trials,'passive trials')

# Positional row indices (into trial_info) of the selected trials in each direction
active_cond_dir_idx = [np.where(active_mask & (dataset.trial_info['cond_dir'] == d))[0] for d in directions]
passive_cond_dir_idx = [np.where(passive_mask & (dataset.trial_info['bump_dir'] == d))[0] for d in directions]

# Label each selected trial directly from its own row, in trial_info order (the same
# order as *_trials_idx). This avoids comparing trial_id values against positional
# indices, which only agree when trial_id equals the row position, and replaces the
# O(n_trials^2) membership scan. Trials whose direction is not in `directions` stay NaN.
active_cond_dict = np.array([dir_to_cond.get(d, np.nan) for d in active_info['cond_dir']], dtype=float)
print(active_cond_dict)
print(len(active_cond_dict))

passive_cond_dict = np.array([dir_to_cond.get(d, np.nan) for d in passive_info['bump_dir']], dtype=float)
print(passive_cond_dict)
print(len(passive_cond_dict))

In [None]:
#GLM params
distr = 'poisson'
random_state = 0
score_metric = 'pseudo_R2'
# NOTE(review): distr / random_state / score_metric are not passed to mp_glm_pr2 below;
# confirm the helper uses matching settings internally.

#Data selection params
align_range = (0,120)  # window relative to move_onset_time (assumed ms)
lag_range = np.arange(-200,201,20)  # lags to test: -200..200 in steps of 20
encoding_bin_size = int(20)  # samples per encoding bin passed to resample

active_behav_df = dataset.make_trial_data(align_field='move_onset_time', align_range=align_range, ignored_trials=~active_mask)
# NOTE(review): x1000 presumably converts per-1ms-bin spike rates to Hz; confirm against bin_width.
active_spikes_resampled = resample(active_behav_df.spikes.to_numpy(),encoding_bin_size)*1000
# Behavioural covariates: hand velocity and acceleration side by side (columns)
active_behav = np.concatenate((active_behav_df['hand_vel'].to_numpy(),active_behav_df['hand_acc'].to_numpy()),axis=1)
active_behav_resampled = resample(active_behav,encoding_bin_size)

passive_behav_df = dataset.make_trial_data(align_field='move_onset_time', align_range=align_range, ignored_trials=~passive_mask)
passive_spikes_resampled = resample(passive_behav_df.spikes.to_numpy(),encoding_bin_size)*1000
passive_behav = np.concatenate((passive_behav_df['hand_vel'].to_numpy(),passive_behav_df['hand_acc'].to_numpy()),axis=1)
passive_behav_resampled = resample(passive_behav,encoding_bin_size)

# Keep only neurons whose mean rate exceeds fr_thresh in BOTH conditions
fr_thresh = 1 #in Hz
neuron_filter = np.logical_and(np.mean(active_spikes_resampled,axis = 0) > fr_thresh, np.mean(passive_spikes_resampled,axis = 0) > fr_thresh)
n_high_neurons = np.sum(neuron_filter)
print(n_high_neurons,'high fr neurons')
n_timepoints = int((align_range[1] - align_range[0])/encoding_bin_size)
print(n_timepoints,'timepoints')

# Fit one GLM per lag in parallel. starmap returns one result per lag; np.array stacks them
# so downstream cells can use .shape and 2-D indexing (rows = lags). The `with` block
# releases the worker processes once starmap has returned.

#Active
X_reshaped = active_behav_resampled.reshape(active_n_trials, n_timepoints, -1)
with mp.Pool(mp.cpu_count()) as pool:
    active_pR2 = np.array(pool.starmap(mp_glm_pr2, [(dataset, X_reshaped, active_mask, active_cond_dict, neuron_filter, encoding_bin_size, align_range, lag) for lag in lag_range]))

#Passive
X_reshaped = passive_behav_resampled.reshape(passive_n_trials, n_timepoints, -1)
with mp.Pool(mp.cpu_count()) as pool:
    passive_pR2 = np.array(pool.starmap(mp_glm_pr2, [(dataset, X_reshaped, passive_mask, passive_cond_dict, neuron_filter, encoding_bin_size, align_range, lag) for lag in lag_range]))

In [None]:
# np.savez(monkey+'_encoding_VA20_pR2', active_pR2 = active_pR2, passive_pR2 = passive_pR2) 
# np.savez(monkey+'_encoding_VA20_result', active_true = active_true, passive_true = passive_true, active_pred = active_pred, passive_pred = passive_pred)

In [None]:
# Per-neuron mean firing rate in each condition. The mean is taken over axis 0 of the
# resampled spikes (presumably all time bins of all selected trials).
active_rates = np.mean(active_spikes_resampled,axis = 0)
passive_rates = np.mean(passive_spikes_resampled,axis = 0)
# Compute one set of bin edges from both conditions so the overlaid bars line up;
# separate plt.hist calls would each pick their own bins.
rate_bins = np.histogram_bin_edges(np.concatenate([active_rates, passive_rates]), bins=10)
plt.hist(active_rates, bins=rate_bins, alpha=0.5, label='Active')
plt.hist(passive_rates, bins=rate_bins, alpha=0.5, label='Passive')
plt.xlabel('Firing rates (Hz)')
plt.ylabel('Neuron count')
plt.legend()
plt.title('Distribution of trial-averaged firing rates')

### Plotting

In [None]:
# foldername = "~/area2_population_analysis/s1-kinematics/actpas_NWB/"
# monkey = "Han_20171207"
# with np.load(monkey+'_encoding_VA20_pR2.npz') as data:
#     active_pR2 = data['active_pR2']
#     passive_pR2 = data['passive_pR2']
# with np.load(monkey+'_encoding_VA20_result.npz') as data:
#     active_true = data['active_true']
#     passive_true = data['passive_true']
#     active_pred = data['active_pred']
#     passive_pred = data['passive_pred']

In [None]:
# Lag grid used for the GLM fits. It is re-declared so this section also runs when pR2
# is loaded from disk, and it MUST match the grid used for fitting.
lag_range = np.arange(-200, 201, 20)
# Coerce to arrays so .shape and 2-D indexing work even if the fits produced lists
active_pR2 = np.asarray(active_pR2)
passive_pR2 = np.asarray(passive_pR2)
assert active_pR2.shape == passive_pR2.shape, 'active/passive pR2 shapes differ'
assert active_pR2.shape[0] == len(lag_range), 'pR2 rows do not match lag_range'
n_high_neurons = active_pR2.shape[1]
print(n_high_neurons,'neurons')

In [None]:
# Best (max over lags, NaN lags ignored) pseudo-R2 per neuron
active_r2s = np.nanmax(active_pR2,axis = 0) # active condition
passive_r2s = np.nanmax(passive_pR2,axis = 0) # passive condition
# Shared bin edges (computed from finite values only) so the overlaid bars are comparable
all_r2s = np.concatenate([active_r2s, passive_r2s])
r2_bins = np.histogram_bin_edges(all_r2s[np.isfinite(all_r2s)], bins=10)
plt.hist(active_r2s, bins=r2_bins, alpha=0.5, label='Active')
plt.hist(passive_r2s, bins=r2_bins, alpha=0.5, label='Passive')
plt.xlabel("Best pseudo-R2")
plt.ylabel("Neuron count")
plt.title('Distribution of best pR2')
plt.legend()

In [None]:
# Best lag per neuron = the lag with the highest pR2. NaNs are treated as -inf so the chosen
# lag agrees with the np.nanmax-based best pR2 (plain argmax would return a NaN's index).
active_lags = lag_range[np.argmax(np.where(np.isnan(active_pR2), -np.inf, active_pR2), axis=0)]
passive_lags = lag_range[np.argmax(np.where(np.isnan(passive_pR2), -np.inf, passive_pR2), axis=0)]
# Bin edges centred on the discrete lag grid, so each bar is exactly one lag value
lag_step = lag_range[1] - lag_range[0]
lag_bins = np.append(lag_range, lag_range[-1] + lag_step) - lag_step / 2
plt.hist(active_lags, bins=lag_bins, alpha=0.5, label='Active')
plt.hist(passive_lags, bins=lag_bins, alpha=0.5, label='Passive')
plt.xlabel("Best time lag (ms)")
plt.ylabel("Neuron count")
plt.title('Distribution of best time lags')
plt.legend()
print([stats.mode(active_lags), stats.mode(passive_lags)])

In [None]:
# Add uniform jitter of up to +/- jitter_ms to the best time lags, for visualization only,
# so overlapping neurons are visible. The generator is seeded so figures are reproducible.
jitter_ms = 5
jitter_rng = np.random.default_rng(0)
active_lags_rand = np.add(active_lags, jitter_rng.uniform(-jitter_ms, jitter_ms, size=len(active_lags)))
passive_lags_rand = np.add(passive_lags, jitter_rng.uniform(-jitter_ms, jitter_ms, size=len(passive_lags)))

In [None]:
# Each neuron's jittered best lag, active (x) vs passive (y). The blue X is the mean of the
# un-jittered lags; the dashed line is the identity.
mean_active_lag, mean_passive_lag = np.mean(active_lags), np.mean(passive_lags)
fig, ax = plt.subplots(figsize=(5, 5))
ax.set(xlim=(-225, 225), ylim=(-225, 225),
       ylabel="Passive best time lag (ms)", xlabel="Active best time lag (ms)")
ax.scatter(active_lags_rand, passive_lags_rand, color='k')
ax.scatter(mean_active_lag, mean_passive_lag, marker='X', color='blue')
ax.plot(ax.get_xlim(), ax.get_ylim(), ls="--", c="k")
print('Mean:', [mean_active_lag, mean_passive_lag])
print('Median:', [np.median(active_lags), np.median(passive_lags)])

In [None]:
# Each neuron's best pR2, active (x) vs passive (y). The blue X is the mean; the dashed
# line is the identity.
mean_active_r2, mean_passive_r2 = np.mean(active_r2s), np.mean(passive_r2s)
fig, ax = plt.subplots(figsize=(5, 5))
ax.set(xlim=(0, 0.8), ylim=(0, 0.8),
       ylabel="Passive best pR2", xlabel="Active best pR2")
ax.scatter(active_r2s, passive_r2s, color='k')
ax.scatter(mean_active_r2, mean_passive_r2, marker='X', color='blue')
ax.plot(ax.get_xlim(), ax.get_ylim(), ls="--", c="k")
print([mean_active_r2, mean_passive_r2])

In [None]:
# Keep neurons whose best lag lies strictly after time_threshold_1 in BOTH conditions and
# whose best pR2 reaches pR2_threshold in both. Record their best pR2s and best lags.
time_threshold_1 = 20
# time_threshold_2 = 200
pR2_threshold = 0.05
thresh_active_lags = []
thresh_passive_lags = []
thresh_active_r2s = []
thresh_passive_r2s = []
# Index of the first lag AFTER time_threshold_1. It is loop-invariant, so it is computed once.
time_thresh_idx1 = np.argwhere(lag_range==time_threshold_1)[0,0]+1
# time_thresh_idx2 = np.argwhere(lag_range==time_threshold_2)[0,0]+1
for i in range(n_high_neurons):
    act = active_pR2[:,i]
    pas = passive_pR2[:,i]
    # NaN-safe argmax so the chosen lag is the lag of np.nanmax (plain argmax returns a NaN's index)
    act_best = np.argmax(np.where(np.isnan(act), -np.inf, act))
    pas_best = np.argmax(np.where(np.isnan(pas), -np.inf, pas))
    if act_best >= time_thresh_idx1 and pas_best >= time_thresh_idx1:
        if np.nanmax(act) >= pR2_threshold and np.nanmax(pas) >= pR2_threshold:
        # Alternative selections (active-dominant / passive-dominant / balanced):
        # if np.nanmax(act)/np.nanmax(pas) > 1.4 and np.nanmax(act) >= pR2_threshold:
        # if np.nanmax(act)/np.nanmax(pas) < 0.7 and np.nanmax(pas) >= pR2_threshold:
        # if 0.7 < np.nanmax(act)/np.nanmax(pas) < 1.4 and np.nanmax(act) >= pR2_threshold and np.nanmax(pas) >= pR2_threshold:
            thresh_active_r2s.append(np.nanmax(act))
            thresh_passive_r2s.append(np.nanmax(pas))
            thresh_active_lags.append(lag_range[act_best])
            thresh_passive_lags.append(lag_range[pas_best])
print(len(thresh_active_lags),'neurons after thresholding')

In [None]:
# Uniform jitter of up to +/- thresh_jitter_ms on the thresholded best lags, for
# visualization only. The generator is seeded so the figure is reproducible.
thresh_jitter_ms = 5
thresh_jitter_rng = np.random.default_rng(0)
thresh_active_lags_rand = np.add(thresh_active_lags, thresh_jitter_rng.uniform(-thresh_jitter_ms, thresh_jitter_ms, size=len(thresh_active_lags)))
thresh_passive_lags_rand = np.add(thresh_passive_lags, thresh_jitter_rng.uniform(-thresh_jitter_ms, thresh_jitter_ms, size=len(thresh_passive_lags)))

In [None]:
# Best lags of the thresholded neurons (jittered), active (x) vs passive (y). The blue X is
# the mean of the un-jittered lags; the dashed line is the identity.
mean_thresh_active_lag = np.mean(thresh_active_lags)
mean_thresh_passive_lag = np.mean(thresh_passive_lags)
fig, ax = plt.subplots(figsize=(5, 5))
ax.set(xlim=(15, 225), ylim=(15, 225),
       ylabel="Passive best time lag (ms)", xlabel="Active best time lag (ms)")
ax.scatter(thresh_active_lags_rand, thresh_passive_lags_rand, color='k')
ax.scatter(mean_thresh_active_lag, mean_thresh_passive_lag, marker='X', color='blue')
ax.plot(ax.get_xlim(), ax.get_ylim(), ls="--", c="k")
print([mean_thresh_active_lag, mean_thresh_passive_lag])
print([np.median(thresh_active_lags), np.median(thresh_passive_lags)])

In [None]:
# Best pR2 of the thresholded neurons, active (x) vs passive (y). The blue X is the mean;
# the dashed line is the identity.
mean_thresh_active_r2 = np.mean(thresh_active_r2s)
mean_thresh_passive_r2 = np.mean(thresh_passive_r2s)
fig, ax = plt.subplots(figsize=(5, 5))
ax.set(xlim=(0, 0.8), ylim=(0, 0.8),
       ylabel="Passive best pR2", xlabel="Active best pR2")
ax.scatter(thresh_active_r2s, thresh_passive_r2s, color='k')
ax.scatter(mean_thresh_active_r2, mean_thresh_passive_r2, marker='X', color='blue')
ax.plot(ax.get_xlim(), ax.get_ylim(), ls="--", c="k")
print([mean_thresh_active_r2, mean_thresh_passive_r2])