In [None]:
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from nlb_tools.nwb_interface import NWBDataset
# Default font size (pt) for every figure in this notebook
matplotlib.rcParams['font.size'] = 18

In [None]:
# Session to analyse and where its NWB file lives
foldername = "~/area2_population_analysis/s1-kinematics/actpas_NWB/"
monkey = "Han_20171207"
filename = f"{foldername}{monkey}_COactpas_TD.nwb"
# Load the whole session with held-out splitting disabled
dataset = NWBDataset(filename, split_heldout=False)

In [None]:
# force - forces and torques applied to the manipulandum. x, y, and z are the forces along their respective axes, while xmo, ymo, and zmo are the torques about those axes
# hand_pos - x and y position of the hand, in cm
# hand_vel - x and y velocity of the hand, in cm/s
# joint_ang - angle of various monkey arm joints, in degrees
# joint_vel - velocity of various monkey arm joints, in degrees/s
# muscle_len - length of various monkey arm muscles, in m
# muscle_vel - velocity of various monkey arm muscles, in m/s
# spikes - spike trains binned at 1 ms

In [None]:
# Top-level field names of the continuous data (first column level)
dataset.data.columns.unique(level=0)

In [None]:
# Binned spike counts, one column per neuron
dataset.data['spikes']

In [None]:
# trial_id - a number assigned to each trial during loading
# start_time - time when the trial begins
# end_time - time when the trial ends
# result - whether a trial was rewarded (R), aborted (A), incomplete (I), or failed (F)
# ctr_hold - the amount of time the monkey had to hold in the center before the reach
# ctr_hold_bump - whether there was a bump during the center hold period
# bump_dir - direction of the bump, in degrees. If there was no bump, bump_dir is NaN
# target_on_time - time of target presentation
# target_dir - direction to the target, in degrees
# go_cue_time - time of go cue
# bump_time - time of center hold bump, if there was one
# move_onset_time - time of movement onset, either in response to the go cue for active trials or the center hold bump for passive trials
# cond_dir - bump_dir for passive trials and target_dir for active trials. Though it is redundant with information already in other fields, it is provided for convenience when filtering trials
# split - data split label of each trial; trials labeled 'none' are invalid and are excluded from the analyses below

In [None]:
# Names of the per-trial metadata fields
dataset.trial_info.columns.unique(level=0)

In [None]:
# Valid trials (split != 'none'), divided by whether a center-hold bump occurred:
# no bump -> active reach, bump -> passive movement.
valid_trials = dataset.trial_info.split != 'none'
bumped = dataset.trial_info.ctr_hold_bump
active_mask = ~bumped & valid_trials
passive_mask = bumped & valid_trials

# Report how many trials fall in each group
for trial_mask, label in ((active_mask, 'active'), (passive_mask, 'passive')):
    n_trials = len(dataset.trial_info.loc[trial_mask])
    print(n_trials, f'{label} trials')

In [None]:
# Distinct condition directions (degrees)
dataset.trial_info['cond_dir'].unique()

In [None]:
# Distinct split labels ('none' marks invalid trials)
dataset.trial_info['split'].unique()

In [None]:
# Field descriptions stored with the dataset (presumably one entry per data / trial_info field — confirm)
dataset.descriptions

In [None]:
# Current bin width of dataset.data in ms, checked before the resampling step below
dataset.bin_width

In [None]:
# Rebin all continuous data to 5 ms bins and confirm the new bin width
target_bin_ms = 5
dataset.resample(target_bin_ms)
print(dataset.bin_width)

In [None]:
# All 16 conditions, in the format (ctr_hold_bump, cond_dir)
unique_conditions = [(False, 0.0), (False, 45.0), (False, 90.0), (False, 135.0),
                     (False, 180.0), (False, 225.0), (False, 270.0), (False, 315.0),
                     (True, 0.0), (True, 45.0), (True, 90.0), (True, 135.0),
                     (True, 180.0), (True, 225.0), (True, 270.0), (True, 315.0)]
# This figure shows active reaches only (no center-hold bump). Selecting them up front
# avoids extracting trial data for passive conditions that would never be plotted.
active_conditions = [cond for cond in unique_conditions if not cond[0]]

# Time windows (ms relative to movement onset) for the left and right panels
whole_range = (0, 500)
after_range = (0, 120)

fig, (ax_whole, ax_after) = plt.subplots(1, 2, figsize=(8, 4))

for cond in active_conditions:
    # Keep valid trials (split != 'none') that belong to this condition
    cond_mask = (np.all(dataset.trial_info[['ctr_hold_bump', 'cond_dir']] == cond, axis=1)) & \
                (dataset.trial_info.split != 'none')
    color = plt.cm.hsv(cond[1] / 360)  # hue encodes cond_dir
    for ax, align_range in ((ax_whole, whole_range), (ax_after, after_range)):
        # Extract the selected window of the selected trials and draw each hand path
        cond_data = dataset.make_trial_data(align_field='move_onset_time', align_range=align_range,
                                            ignored_trials=~cond_mask)
        for _, trial in cond_data.groupby('trial_id'):
            ax.plot(trial.hand_pos.x, trial.hand_pos.y, color=color, linewidth=0.5)

# Title each panel with the window it actually shows
for ax, (start, stop) in ((ax_whole, whole_range), (ax_after, after_range)):
    ax.set_title(f'{start} to {stop}')
    ax.axis("off")

# Output directory under the user's home; created if missing so savefig does not fail
figDir = os.path.expanduser('~/area2_population_analysis/figures_plus/')
os.makedirs(figDir, exist_ok=True)
fig.tight_layout()
fig.savefig(os.path.join(figDir, monkey + '_active_traj.pdf'), dpi='figure')
plt.show()

In [None]:
# All 16 conditions, in the format (ctr_hold_bump, cond_dir)
unique_conditions = [(False, 0.0), (False, 45.0), (False, 90.0), (False, 135.0),
                     (False, 180.0), (False, 225.0), (False, 270.0), (False, 315.0),
                     (True, 0.0), (True, 45.0), (True, 90.0), (True, 135.0),
                     (True, 180.0), (True, 225.0), (True, 270.0), (True, 315.0)]
# This figure shows passive (bump) trials only. Selecting them up front avoids
# extracting trial data for active conditions that would never be plotted.
passive_conditions = [cond for cond in unique_conditions if cond[0]]

# Time windows (ms relative to movement onset) for the left and right panels.
# The left panel starts 100 ms before onset; its title is derived from this range
# so the label matches the data actually plotted.
whole_range = (-100, 500)
after_range = (0, 120)

fig, (ax_whole, ax_after) = plt.subplots(1, 2, figsize=(8, 4))

for cond in passive_conditions:
    # Keep valid trials (split != 'none') that belong to this condition
    cond_mask = (np.all(dataset.trial_info[['ctr_hold_bump', 'cond_dir']] == cond, axis=1)) & \
                (dataset.trial_info.split != 'none')
    color = plt.cm.hsv(cond[1] / 360)  # hue encodes cond_dir
    for ax, align_range in ((ax_whole, whole_range), (ax_after, after_range)):
        # Extract the selected window of the selected trials and draw each hand path
        cond_data = dataset.make_trial_data(align_field='move_onset_time', align_range=align_range,
                                            ignored_trials=~cond_mask)
        for _, trial in cond_data.groupby('trial_id'):
            ax.plot(trial.hand_pos.x, trial.hand_pos.y, color=color, linewidth=0.5)

# Title each panel with the window it actually shows
for ax, (start, stop) in ((ax_whole, whole_range), (ax_after, after_range)):
    ax.set_title(f'{start} to {stop}')
    ax.axis("off")

# Output directory under the user's home; created if missing so savefig does not fail
figDir = os.path.expanduser('~/area2_population_analysis/figures_plus/')
os.makedirs(figDir, exist_ok=True)
fig.tight_layout()
fig.savefig(os.path.join(figDir, monkey + '_passive_traj.pdf'), dpi='figure')
plt.show()

## Smoothing effect

In [None]:
# Number of neurons = number of spike columns
spike_frame = dataset.data['spikes']
n_neurons = spike_frame.shape[1]
print(n_neurons)

In [None]:
# Usable active trials: no center-hold bump, and not labeled 'none' (invalid) in split
active_mask = dataset.trial_info['ctr_hold_bump'].eq(False) & (dataset.trial_info['split'] != 'none')
# Window from 100 ms before to 500 ms after movement onset, for the selected trials only
active_data = dataset.make_trial_data(
    align_field='move_onset_time',
    align_range=(-100, 500),
    ignored_trials=~active_mask,
)
n_trials = active_data.trial_id.nunique()
print(n_trials)

In [None]:
# Number of time bins per trial = window length / bin width.
# The per-trial arrays built next assume every trial has the same number of bins,
# so verify that across all trials instead of reading only the first one.
bins_per_trial = active_data.groupby('trial_id').size()
assert bins_per_trial.nunique() == 1, f'unequal or missing trial lengths: {sorted(bins_per_trial.unique())}'
n_timepoints = int(bins_per_trial.iloc[0])
print(n_timepoints)

In [None]:
# Add smoothed spike fields (spikes_smth_40 / _25 / _10) to dataset.data.
# Widths whose field already exists are skipped so this cell can be re-run safely.
existing_fields = set(dataset.data.columns.get_level_values(0))
for width in (40, 25, 10):
    if f'spikes_smth_{width}' not in existing_fields:
        dataset.smooth_spk(width, name=f'smth_{width}')

# active_data was extracted before these smoothed fields existed, so rebuild it here
# (same mask and window) to make the spikes_smth_* fields available per trial.
active_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-100, 500),
                                      ignored_trials=~active_mask)
trial_groups = list(active_data.groupby('trial_id'))


def _stack_field(field):
    """Stack one spike field across trials into a float array (n_trials, n_timepoints, n_neurons)."""
    stacked = np.stack([trial[field].to_numpy(dtype=float) for _, trial in trial_groups])
    assert stacked.shape == (n_trials, n_timepoints, n_neurons), stacked.shape
    return stacked


# Trial ids as a float (n_trials, 1) column, in the same order as the stacked arrays
active_trials_idx_array = np.array([idx for idx, _ in trial_groups], dtype=float).reshape(-1, 1)
raw_active_trials_array = _stack_field('spikes')
active_trials_smth_10_array = _stack_field('spikes_smth_10')
active_trials_smth_25_array = _stack_field('spikes_smth_25')
active_trials_smth_40_array = _stack_field('spikes_smth_40')

In [None]:
# Sanity-check the stacked array shapes
for arr in (active_trials_smth_40_array, active_trials_idx_array):
    print(arr.shape)

In [None]:
# Raw vs. smoothed firing rate for one example trial and neuron.
example_trial = 0
example_neuron = 10
# Start of the window used to build active_data (align_range[0], ms relative to movement onset)
align_start_ms = -100
# Build the time axis from the array length so x and y always have the same number of points
x_axis = align_start_ms + dataset.bin_width * np.arange(raw_active_trials_array.shape[1])
counts_to_rate = 1000 / dataset.bin_width  # spikes per bin -> spikes per second

fig, ax = plt.subplots()
for label, trials_array in (('raw', raw_active_trials_array),
                            ('smooth_10', active_trials_smth_10_array),
                            ('smooth_25', active_trials_smth_25_array),
                            ('smooth_40', active_trials_smth_40_array)):
    ax.plot(x_axis, trials_array[example_trial, :, example_neuron] * counts_to_rate, label=label)
ax.set_title(f'Spike smoothing ({dataset.bin_width}ms bin)')
ax.set_ylabel('Firing rate (spks/s)')
ax.set_xlabel('time (ms)')
ax.legend()
plt.show()