In [None]:
def decode_and_score(bst, tc, pos):
    """Decode position from binned spikes and score it against the tracked position.

    Decoding is done with ``nel.decoding.decode1D`` over ``[0, nanmax(pos.data)]``.
    The decoded mode path is compared with the actual position sampled at the bin
    centres.

    Parameters
    ----------
    bst : binned spike train passed to ``decode1D``; ``bst.bin_centers`` is used to
        sample ``pos``.
    tc : tuning curves passed to ``decode1D``.
    pos : position signal; called as ``pos(times)``, and ``pos.data`` gives the
        upper decoding bound.

    Returns
    -------
    rvalue : float
        Pearson r (``scipy.stats.linregress``) of decoded vs. actual position over
        the bins where both are finite; NaN if fewer than 2 such bins exist.
    mean_error : float
        Mean absolute difference between actual and decoded position over the same
        bins; NaN if there are none.
    """
    # Decoding accuracy on the behavioral time scale; only the mode path is scored.
    _, _, mode_pth, _ = nel.decoding.decode1D(bst,
                                              tc,
                                              xmin=0,
                                              xmax=np.nanmax(pos.data))
    actual_pos = np.ravel(np.asarray(pos(bst.bin_centers), dtype=float))
    decoded_pos = np.ravel(np.asarray(mode_pth, dtype=float))

    # np.mean and linregress do not skip NaN: a single missing bin would turn both
    # scores into NaN, so score only bins where both positions are finite.
    valid = np.isfinite(actual_pos) & np.isfinite(decoded_pos)
    if not valid.any():
        return np.nan, np.nan
    actual_pos = actual_pos[valid]
    decoded_pos = decoded_pos[valid]

    mean_error = np.mean(np.abs(actual_pos - decoded_pos))
    # linregress needs at least two samples.
    if actual_pos.size < 2:
        rvalue = np.nan
    else:
        rvalue = stats.linregress(actual_pos, decoded_pos).rvalue

    return rvalue, mean_error

# Session to analyse (swap the comment to use the alternative session).
# session = 'LS19_S20170523165204'
session = 'LEM3116_S20180802100324'

maze_size_cm,pos,st_all = replay_run.get_base_data(data_path,spike_path,session)

# to make everything simpler, restrict to just the linear track
# (taken to be the first entry returned by get_base_data)
pos = pos[0]
st_all = st_all[0]
maze_size_cm = maze_size_cm[0]

# compute and smooth speed
speed1 = nel.utils.ddt_asa(pos, smooth=True, sigma=0.1, norm=True)

# find epochs where the animal ran > 4cm/sec
run_epochs = nel.utils.get_run_epochs(speed1, v1=4, v2=4)

# Binning / smoothing parameters (identical for every area, so set once).
ds_50ms = 0.05  # initial bin width [s]
ds_run = 0.5    # final bin width after re-binning [s]
# Number of 50 ms bins merged into one run bin; rounded so float division error
# can never yield e.g. 9.999999 bins.
rebin_factor = int(round(ds_run / ds_50ms))
sigma = 3  # tuning-curve smoothing std dev in cm

# Per-area decoding summary. Areas without enough place cells map to {};
# the others map to {'unit_ids', 'n_units', 'median_error'}.
results = {}

# loop through each area separately
areas = df_cell_class.area[df_cell_class.session == session]
for current_area in pd.unique(areas):

    # subset units to current area (unit ids are 1-based)
    st = st_all._unit_subset(np.where(areas==current_area)[0]+1)
    # reset unit ids like the other units never existed
    st.series_ids = np.arange(0,len(st.series_ids))+1

    # restrict spike trains to those epochs during which the animal was running
    st_run = st[run_epochs]
    # 50 ms bins, smoothed with a 300 ms kernel, then merged into ds_run bins
    bst_run = st_run.bin(ds=ds_50ms).smooth(sigma=0.3 , inplace=True).rebin(w=rebin_factor)

    tc = nel.TuningCurve1D(bst=bst_run,
                           extern=pos,
                           n_extern=40,
                           extmin=0,
                           extmax=maze_size_cm,
                           sigma=sigma,
                           min_duration=0)

    # Place-cell criteria: pyramidal, >= 100 spikes, peak rate >= 1 Hz and
    # peak/mean rate ratio >= 1.5.
    peak_firing_rates = tc.max(axis=1)
    mean_firing_rates = tc.mean(axis=1)
    # Silent units have a mean of 0; their ratio becomes nan/inf and is handled
    # by the comparisons below instead of emitting runtime warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = peak_firing_rates/mean_firing_rates

    temp_df = df_cell_class[(df_cell_class.session == session) & (df_cell_class.area == current_area)]
    keep_mask = ((temp_df.cell_type == "pyr").to_numpy() &
                 (temp_df.n_spikes >= 100).to_numpy() &
                 (peak_firing_rates >= 1) &
                 (ratio >= 1.5))
    # Always a list (possibly empty) of 1-based unit ids.
    unit_ids_to_keep = (np.flatnonzero(keep_mask) + 1).tolist()

    if len(unit_ids_to_keep) == 0:
        print('warning: no units')
        results[current_area] = {}
        continue
    if len(unit_ids_to_keep) == 1:
        print('warning: only 1 unit')
        results[current_area] = {}
        continue

    sta_placecells = st._unit_subset(unit_ids_to_keep)
    tc = tc._unit_subset(unit_ids_to_keep)
    total_units = sta_placecells.n_active

    posteriors, lengths, mode_pth, mean_pth = nel.decoding.decode1D(bst_run.loc[:,unit_ids_to_keep],
                                                                tc,
                                                                xmin=0,
                                                                xmax=np.nanmax(pos.data))
    actual_pos = pos(bst_run.bin_centers)

    # nanmedian so that bins with a NaN position do not poison the summary
    median_error = np.nanmedian(np.abs(actual_pos - mode_pth))
    print(median_error)

    # Keep the per-area outcome instead of only printing it.
    results[current_area] = {
        'unit_ids': unit_ids_to_keep,
        'n_units': total_units,
        'median_error': median_error,
    }

# NOTE: later cells use tc, sta_placecells, bst_run, unit_ids_to_keep, posteriors
# and mode_pth as left by the LAST area processed in this loop.

In [None]:
# Tuning curves of the selected place cells (those left by the last area processed),
# drawn with one 'tab20b' palette color per active unit.
npl.set_palette('tab20b', sta_placecells.n_active)
with npl.FigureManager(show=True, figsize=(8, 8)) as (fig, ax):
    npl.utils.skip_if_no_output(fig)
    npl.plot_tuning_curves1D(
        tc.smooth(sigma=0),  # sigma=0: presumably no extra smoothing — confirm
        normalize=True,
        pad=.5,
    )

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib

def decode_example_fig(bst, spiketrainarray, tc_placecells, pos, idx=None, title_str=None, vmax=.1):
    """Plot an example of position decoding: posterior image, actual position and a raster.

    The posterior from ``nel.decoding.decode1D`` is drawn as an image, with time bins
    on x and position on y. The actual position is overlaid in teal. A raster of the
    place cells is appended above it, ordered by the tuning curves' peak-firing order.

    Parameters
    ----------
    bst : binned spike train to decode. ``bst.bin_centers``, ``bst.support`` and
        ``bst.ds`` (bin width, used in the x label) are read.
    spiketrainarray : spike trains of the place cells shown in the raster. A deep copy
        is reordered, so the caller's object is not modified.
    tc_placecells : tuning curves used for decoding and for the raster unit order.
    pos : position signal; called as ``pos(times)``, and ``pos.data`` gives the upper
        decoding bound.
    idx : unused; kept so existing calls keep working.
    title_str : optional title drawn above the raster (nothing is drawn when None).
    vmax : upper color limit of the posterior image.

    Returns
    -------
    fig, ax, axRaster
        The figure, the posterior axis and the raster axis.
    """
    xmin = 0
    xmax = np.nanmax(pos.data)
    actual_pos = pos(bst.bin_centers)

    # Decode with the tuning curves passed in (previously the global ``tc`` was used).
    posteriors, lengths, mode_pth, mean_pth = nel.decoding.decode1D(bst,
                                                                    tc_placecells,
                                                                    xmin=xmin,
                                                                    xmax=xmax)

    # decode1D was given [xmin, xmax]; the posterior rows are assumed to tile that
    # range evenly. Place the image at the row centres so the y axis shares units with
    # actual_pos, instead of the old hard-coded 0..120 (a 120 cm track).
    n_pos_bins = posteriors.shape[0]
    pos_bin_width = (xmax - xmin) / n_pos_bins
    pos_bin_centers = xmin + pos_bin_width * (np.arange(n_pos_bins) + 0.5)

    with npl.FigureManager(show=True, figsize=(10,2.7)) as (fig, ax):

        # Reorder a private copy of the units by peak-firing position.
        st = copy.deepcopy(spiketrainarray)
        no = tc_placecells.get_peak_firing_order_ids()
        st.reorder_units_by_ids(no, inplace=True)

        st_cut = st[bst.support]
        st_cut._support = bst.support # hacky fix so that we can plot events out of order
        st_cut = nel.utils.collapse_time(st_cut)

        npl.imagesc(x=np.arange(posteriors.shape[1]), y=pos_bin_centers, data=posteriors,
                    cmap=plt.cm.bone_r, ax=ax, vmax=vmax)
        ax.plot(actual_pos, color='teal', linewidth=1)

        divider = make_axes_locatable(ax)
        axRaster = divider.append_axes("top", size=1, pad=0)

        # One distinct color per plotted unit. plt.get_cmap replaces
        # matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
        unit_ids = list(spiketrainarray.series_ids)
        cmap = plt.get_cmap('tab20b', len(unit_ids))
        for i, ids in enumerate(unit_ids):
            npl.rasterplot(st_cut[:,ids], vertstack=True, ax=axRaster, lh=.5, color=cmap(i))

        axRaster.set_xlim(st_cut.support.time.squeeze())
        if title_str is not None:
            axRaster.set_title(title_str)

        npl.utils.no_xticks(axRaster)
        npl.utils.no_xticklabels(axRaster)
        npl.utils.no_yticklabels(axRaster)
        npl.utils.no_yticks(axRaster)
        ax.set_ylabel('position [cm]')
        # Label with the actual bin width of ``bst`` (it was hard-coded to 20 ms).
        ax.set_xlabel(f'time bins ({bst.ds * 1000:g} ms)')

        npl.utils.clear_left_right(axRaster)
        npl.utils.clear_top_bottom(axRaster)
        return fig,ax,axRaster
    
# Example decoding figure for the place cells of the last area processed, saved as SVG.
fig, ax, axRaster = decode_example_fig(
    bst=bst_run.loc[:, unit_ids_to_keep],
    spiketrainarray=sta_placecells,
    tc_placecells=tc,
    pos=pos,
    idx=None,
    title_str=None,
    vmax=.1,
)

fig.savefig(
    os.path.join(fig_save_path, 'decoding_example.svg'),
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# Full posterior over the run period: one column per time bin, one row per position bin.
fig_post, ax_post = plt.subplots(figsize=(18, 3))
im_post = ax_post.imshow(posteriors, aspect='auto', origin='lower', vmax=.1, cmap=plt.cm.bone_r)
fig_post.colorbar(im_post, ax=ax_post, label='posterior probability')
ax_post.set_xlabel('time bin')
ax_post.set_ylabel('position bin')
ax_post.set_title('Decoded posterior');

# Actual position at the bin centres (also used by the confusion-matrix cell).
actual_pos = pos(bst_run.bin_centers)

# Actual vs. decoded (mode) position, both rounded to whole units.
fig_traj, ax_traj = plt.subplots(figsize=(15, 3))
ax_traj.plot(np.round(actual_pos), color='grey', label='actual')
ax_traj.plot(np.round(mode_pth), color='black', label='decoded (mode)')
ax_traj.axis('tight')
ax_traj.set_xlabel('time bin')
ax_traj.set_ylabel('position [cm]')
ax_traj.legend(loc='upper right');

In [None]:
from sklearn.metrics import confusion_matrix
# Confusion matrix of decoded vs. actual position in rounded 1-unit (cm) bins.
# sklearn's confusion_matrix puts the TRUE label on rows, so rows (the y axis of the
# image) are the true position and columns (x axis) are the estimated position;
# normalize='true' makes each row sum to 1.
actual_flat = np.ravel(actual_pos)
decoded_flat = np.ravel(mode_pth)
# sklearn rejects NaN targets, so keep only bins where both positions are finite.
finite = np.isfinite(actual_flat) & np.isfinite(decoded_flat)
true_bin = np.round(actual_flat[finite]).astype(int)
decoded_bin = np.round(decoded_flat[finite]).astype(int)
# Size the label set from the data instead of a hard-coded 120 (which silently
# dropped every sample at position >= 120).
n_pos_labels = max(true_bin.max(initial=0), decoded_bin.max(initial=0)) + 1

cm = confusion_matrix(true_bin, decoded_bin, labels=np.arange(n_pos_labels), normalize='true')

fig_cm, ax_cm = plt.subplots()
im_cm = ax_cm.imshow(cm, cmap=plt.cm.bone_r, vmax=.8, origin='lower')
fig_cm.colorbar(im_cm, ax=ax_cm, label='P(estimated | true)')
ax_cm.set_xlabel('Estimated position')
ax_cm.set_ylabel('True position');