In [185]:
%load_ext autoreload
%autoreload 2

import os
import glob
import numpy as np
import pandas as pd
import seaborn as sns
from collections import defaultdict
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt

from utils import ioutils
from utils import figutils
from utils import styleutils
from utils import phy_utils
from utils import psth_utils
from utils import sort_utils
import configs

# Restore matplotlib rcParams to their defaults via seaborn before the project
# style is applied, so the settings below start from a known baseline.
sns.reset_defaults()
# Project helper from utils.styleutils; presumably pushes the requested font
# size and line width into matplotlib's rcParams -- confirm in styleutils.
styleutils.update_mpl_params(fontsize=14, linewidth=1)
# Render figures inline in the notebook output area.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [186]:
# --- session selection ---
# Everything below is driven by one mouse entry from the project config.
mouse = configs.NPIX1
mouse_path = os.path.join(configs.BASE_PATH, mouse.path)
dates = list(mouse.exp_params.keys())
trial_duration = mouse.duration

# --- cell-type labels and CCG settings ---
celltypes_names = ['ss', '_ss', 'cs', 'mli', '>40']  # '_ss' marks cliques
fs = 30000  # presumably the sampling rate; reassigned from package['fs'] in a later cell
ccg_win_size, ccg_bin_size, ccg_smooth_win = 80, 1, 3  # all in ms
default_refractory_period = 1  # ms

# --- PSTH settings ---
# Bin widths are multiplied by 1e-3 wherever they are compared with seconds,
# which suggests they are in milliseconds -- confirm against phy_utils.
psthb = 250
psthb_cs = 1000
# Analysis window: 0 .. trial_duration * 1000 (same unit as the bin widths,
# assuming mouse.duration is in seconds -- TODO confirm).
psthw = np.array([0, trial_duration * 1000])

# Tick positions on the time axis and the labels drawn at them.
xticks = np.array([5, 7, 9, 10])
xticklabels = ['O', '+2', '', 'W']
# Same tick positions expressed in units of PSTH bins.
times_in_bins = xticks / (psthb * 1e-3)

xlabel, ylabel = 'Time (s)', 'FR (spk/s)'
stim_colors = ['darkgreen', 'green', 'red', 'magenta', 'turquoise']

# Resolve, for every recording date, the session folder, its CatGT output
# folder and the phy/Kilosort folder underneath it.
#
# The three lists are index-aligned: entry i of each list belongs to the same
# session. A date is only recorded once ALL of its folders were found, so a
# partially processed session can no longer shift the lists out of step
# (previously base_folders/catgt_folders were appended before the lookup that
# could fail, and a bare `except: pass` hid the failure).
base_folders = []
catgt_folders = []
phy_folders = []
for date in dates:
    date_folder = os.path.join(mouse_path, date)

    # sorted() makes the "[0]" pick deterministic if several folders match.
    catgt_matches = sorted(glob.glob(os.path.join(date_folder, 'catgt*')))
    if not catgt_matches:
        print(f'Skipping {date}: no catgt* folder found in {date_folder}')
        continue
    catgt_folder = catgt_matches[0]

    imec_matches = sorted(glob.glob(os.path.join(catgt_folder, '*imec0')))
    if not imec_matches:
        print(f'Skipping {date}: no *imec0 folder found in {catgt_folder}')
        continue
    catgt_imec_folder = imec_matches[0]
    data_folder = os.path.join(catgt_imec_folder, 'imec0_ks2')

    base_folders.append(date_folder)
    catgt_folders.append(catgt_folder)
    phy_folders.append(data_folder)

catgt_folders

['D:\\NPIX\\NPIX1\\2023.11.10\\catgt_2023_11_10_all_g0']

In [187]:
# Work on the first discovered session; every path below hangs off it.
catgt_folder, base_folder, phy_folder = catgt_folders[0], base_folders[0], phy_folders[0]

# Per-session sub-folders for behavior files, processed data and figures.
behavior_folder, processed_folder, fig_folder = (
    os.path.join(base_folder, subfolder)
    for subfolder in ('behavior', 'processed', 'FIGURES')
)

In [188]:
# Processed per-session bundle written by an earlier pipeline step.
package_path = os.path.join(processed_folder, 'package')
package = ioutils.pload(package_path)

# Probe geometry exported alongside the phy output.
channel_pos = np.load(os.path.join(phy_folder, 'channel_positions.npy'))
channel_map = np.load(os.path.join(phy_folder, 'channel_map.npy')).flatten()
# The channel map must be the identity 0..n-1; presumably this is what allows
# peak_channel values to index rows of channel_pos directly below -- confirm.
identity_map = np.arange(len(channel_map))
assert np.array_equal(channel_map, identity_map)

# Per-cluster metrics table; depth is column 1 of each cluster's peak-channel
# position.
metrics = pd.read_csv(os.path.join(phy_folder, 'metrics.csv'))
cluster_ids = metrics['cluster_id'].to_numpy()
peak_channels = metrics['peak_channel'].to_numpy()
cluster_depth = channel_pos[peak_channels, 1]

In [189]:
''' load data'''
# Get the timing of trials per stimulus condition.
#
# Trial onsets/offsets come from the two text files in the CatGT folder that
# match '*xa_1*' and '*xia_1*'; exactly one of each is expected.
onset_files = glob.glob(os.path.join(catgt_folder, r'*xa_1*'))
offset_files = glob.glob(os.path.join(catgt_folder, r'*xia_1*'))
assert len(onset_files) == 1 and len(offset_files) == 1, (
    f'expected exactly one onset and one offset file, '
    f'found {onset_files} and {offset_files}')

# ndmin=1 keeps a single-trial file as a length-1 array so len() still works.
onsets = np.loadtxt(onset_files[0], ndmin=1)
offsets = np.loadtxt(offset_files[0], ndmin=1)
# Session-specific manual corrections, kept for reference:
# onsets = np.delete(onsets, 25) #npix1, 11/14
# offsets = np.delete(offsets, 25) #npix1, 11/14
# onsets = onsets[np.r_[True, np.diff(onsets) > 10]] #npix1, 11/15
assert len(onsets) == len(offsets), (
    f'{len(onsets)} onsets vs {len(offsets)} offsets')
n_trials_nidaq = len(onsets)

# One behavior .npy file per trial; sorted so that file order follows trial
# order (assumes file names sort chronologically -- TODO confirm).
stimuli = list(mouse.stimuli_params.keys())
npy_files = sorted(glob.glob(os.path.join(behavior_folder, '*.npy')))
n_trials_npy = len(npy_files)
# A real message instead of `print(...)`, which evaluates to None and would
# leave the AssertionError without any text.
assert n_trials_npy == n_trials_nidaq, (
    f'{n_trials_npy} behavior files vs {n_trials_nidaq} NIDAQ trials')

# Assign trials to stimuli by looking for the stimulus name in the FILE NAME
# only, so that a stimulus name occurring in a parent directory cannot match
# every trial.
# NOTE(review): this is still a substring test; a stimulus whose name is
# contained in another stimulus name would claim that stimulus' trials too.
npy_names = [os.path.basename(path) for path in npy_files]
stim_epochs_dict = {}
for stim in stimuli:
    trial_ixs = np.where([stim in name for name in npy_names])[0]
    if len(trial_ixs):
        # Only the onset times are kept per stimulus (offsets are not used
        # downstream in this notebook).
        stim_epochs_dict[stim] = onsets[trial_ixs]
stim_names = list(stim_epochs_dict.keys())
stim_trial_times = list(stim_epochs_dict.values())

for k, v in stim_epochs_dict.items():
    print(f'{k}: {v.shape[0]}')

# Reload the processed bundle so this cell can be run on its own.
package = ioutils.pload(os.path.join(processed_folder, 'package'))
fs = package['fs']
matches = package['matches']
# Divide by fs; presumably converts spike times from samples to seconds.
spike_trains_dict = {k: v/fs for k, v in package['spike_trains_dict_after_rp'].items()}

package.keys()

ISO: 20
PIN: 20
EUY: 20
HEP: 20
US: 20


dict_keys(['fs', 'cluster_ids', 'celltypes_parcellated_dict', 'matches', 'spike_trains_dict_after_rp', 'spike_amp_dict_after_rp', 'spike_trains_dict_before_rp', 'spike_amp_dict_before_rp'])

In [None]:
'''
TODOS:

- raster plot SS rate and CS occurrences

- analyze cross-correlation between SS and CS, compare vs shuffled
'''


In [379]:
raster_b = 25 # raster bin width; treated as milliseconds in the conversions below
filter_length = 0.5 # in seconds
behavior_package = ioutils.pload(os.path.join(processed_folder, 'behavior_package'))
behavior_stim = behavior_package['stim']
licks_per_stim = behavior_package['licks']


def _raster_signals(cs, ss, event):
    """Return ``(y_cs, y_ss)`` for one CS/SS pair and one stimulus.

    Both come from the second return value of ``phy_utils.get_processed_ifr``
    (iterated row by row below; presumably one row per trial -- confirm).
    ``y_cs`` is rescaled from a rate to a per-bin count, ``y_ss`` is smoothed
    along each row with a zero-order Savitzky-Golay filter (a moving average).
    """
    _, y_cs, _, _ = phy_utils.get_processed_ifr(spike_trains_dict[cs], event, raster_b, psthw)
    _, y_ss, _, _ = phy_utils.get_processed_ifr(spike_trains_dict[ss], event, raster_b, psthw)
    y_cs = y_cs * raster_b / 1e3 # convert from hertz to raw
    # Filter window in bins.
    # NOTE(review): this is an even number (20) with the current settings;
    # older SciPy releases reject an even window_length -- confirm the
    # installed version.
    win = int(filter_length / (raster_b * 1e-3))
    y_ss = np.array([savgol_filter(sig, polyorder=0, window_length=win) for sig in y_ss])
    return y_cs, y_ss


def _draw_ss_rate(ax, y_ss):
    """Draw the smoothed SS rate as the grayscale base image of a panel."""
    # Colour limit: maximum rate rounded to the nearest ten. Deliberately not
    # named `max`, which would shadow the builtin for the rest of the session.
    ss_max = np.round(y_ss.max(), -1)
    figutils.imagePanel(
        ax=ax,
        dffs=y_ss,
        height_per_cell=0.5,
        vmin=0,
        vcenter=ss_max // 2,
        vmax=ss_max,
        show=False,
        cmap='gray_r'
    )


def _draw_cs_counts(ax, y_cs):
    """Draw the per-bin CS counts as the grayscale base image of a panel."""
    figutils.imagePanel(
        ax=ax,
        dffs=y_cs,
        height_per_cell=0.5,
        vmin=0,
        vcenter=0.01,
        vmax=1,
        show=False,
        cmap='gray_r'
    )


def _draw_overlay(ax, masked, **imshow_kwargs):
    """Overlay a masked array on ``ax``; masked entries stay transparent."""
    ax.imshow(
        masked,
        vmin=-1,
        vmax=1,
        cmap='bwr',
        interpolation='none',
        origin='upper',
        **imshow_kwargs
    )


def _draw_lick_overlay(ax, i, k):
    """Overlay the licks of stimulus ``k`` (panel index ``i``) on ``ax``."""
    # The behavior package is indexed by position, so its stimulus order must
    # agree with the order of stim_epochs_dict.
    assert behavior_stim[i] == k, (
        f'stimulus order mismatch at index {i}: {behavior_stim[i]} vs {k}')
    data = licks_per_stim[i]
    _draw_overlay(ax, np.ma.masked_where(data < 0.9, data), alpha=0.5)


def _panel_cs_on_ss(ax, i, k, y_cs, y_ss):
    """SS rate with CS occurrences on top (bins without a CS are masked)."""
    _draw_ss_rate(ax, y_ss)
    _draw_overlay(ax, np.ma.masked_where(y_cs == 0, y_cs))


def _panel_licks_on_ss(ax, i, k, y_cs, y_ss):
    """SS rate with licks on top."""
    _draw_ss_rate(ax, y_ss)
    _draw_lick_overlay(ax, i, k)


def _panel_licks_on_cs(ax, i, k, y_cs, y_ss):
    """CS counts with licks on top."""
    _draw_cs_counts(ax, y_cs)
    _draw_lick_overlay(ax, i, k)


def _save_raster_figure(subfolder, figname, signals, draw_panel):
    """Draw one panel per stimulus with ``draw_panel`` and save the figure.

    ``signals`` holds one ``(y_cs, y_ss)`` pair per stimulus, in the order of
    ``stim_epochs_dict``. The figure is written to ``fig_folder/subfolder``
    under ``figname`` and closed afterwards.
    """
    f, axs = figutils.pretty_fig(figsize=(20, 12), rows=1, cols=len(stim_names))
    for i, (k, (y_cs, y_ss)) in enumerate(zip(stim_epochs_dict, signals)):
        ax = axs[i]
        plt.sca(ax)
        draw_panel(ax, i, k, y_cs, y_ss)
        plt.axis('tight')
        sns.despine(bottom=True, left=True, fig=f, ax=ax)
        # Row ticks only on the left-most panel.
        if i == 0:
            plt.yticks(np.arange(len(y_cs)))
        else:
            plt.yticks([])
        plt.tight_layout()
    figutils.save_fig(os.path.join(fig_folder, subfolder), figname, close=True, show=False)


for m, (cs, ss) in enumerate(matches):
    figname = f'{m}_{cs}_{ss}'
    # Compute the signals once per stimulus and reuse them for all three
    # figures of this pair.
    signals = [_raster_signals(cs, ss, event) for event in stim_epochs_dict.values()]

    # SS / CS plot
    _save_raster_figure('raster_cs_ss', figname, signals, _panel_cs_on_ss)
    # SS / lick plot
    _save_raster_figure('raster_ss_lick', figname, signals, _panel_licks_on_ss)
    # CS / lick plot
    _save_raster_figure('raster_cs_lick', figname, signals, _panel_licks_on_cs)

Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_ss\0_437_259
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_ss_lick\0_437_259
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_lick\0_437_259
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_ss\1_246_260
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_ss_lick\1_246_260
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_lick\1_246_260
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_ss\2_445_444
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_ss_lick\2_445_444
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_lick\2_445_444
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_ss\3_522_443
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_ss_lick\3_522_443
Figure saved at: D:\NPIX\NPIX1\2023.11.10\FIGURES\raster_cs_lick\3_522_443


In [184]:
# make heatmaps
# One z-scored, smoothed PSTH per unit and stimulus, shown twice per cell
# group: once ordered by depth, once ordered by response onset.
sf = os.path.join(fig_folder, 'psth_heatmaps')
filter_window = int(0.5 / (psthb * 1e-3))
intervals = {'ss': 1, 'rest': 10}  # label every n-th row on the depth plot

# Lookups that do not depend on the cell group are built once up front.
depth_dict = dict(zip(cluster_ids, cluster_depth))
sort_times = [5, 10]
sort_times_in_bins = (np.array(sort_times) / (psthb * 1e-3)).astype(int)
time_axis_args = {'xticks': times_in_bins, 'xticklabels': xticklabels}

for celltype in ['ss', 'rest']:
    unit_ids = np.array(list(package['celltypes_parcellated_dict'][celltype].keys()))
    trains = [spike_trains_dict[unit_id] for unit_id in unit_ids]

    # Collect a (stimulus, unit, bin) array of smoothed responses.
    per_stim = []
    for event in stim_epochs_dict.values():
        per_cell = []
        for train in trains:
            _, _, y_p, _ = phy_utils.get_processed_ifr(
                train, event, psthb, psthw,
                bsl_subtract=False,
                zscore=True,
                zscoretype='across')
            per_cell.append(savgol_filter(y_p, window_length=filter_window, polyorder=0))
        per_stim.append(np.array(per_cell))
    dat_per_stim_per_cell = np.array(per_stim)
    print(dat_per_stim_per_cell.shape)

    # --- Sort by depth + plot ---
    unit_depths = np.array([depth_dict[unit_id] for unit_id in unit_ids])
    depth_order = np.argsort(unit_depths)[::-1] # from dorsal to ventral
    depths = unit_depths[depth_order]
    unit_ids_sorted = unit_ids[depth_order]

    interval = intervals[celltype]
    yticks = np.arange(len(depths))[::interval]
    yticklabels = [
        f'{d}, {u}'
        for d, u in zip(depths[::interval].astype(int), unit_ids_sorted[::interval])
    ]
    figutils.imagePanes(
        figsize=(8, 6),
        mats=dat_per_stim_per_cell[:, depth_order, :],
        titles=stim_names,
        vmin=-2.5,
        vmax=2.5,
        vcenter=0,
        axargs={**time_axis_args, 'yticks': yticks, 'yticklabels': yticklabels},
    )
    plt.suptitle('Responses by depth')
    figutils.save_fig(sf, f'{celltype}_sort_by_depth', show=False)

    # --- Sort by responses + plot ---
    # Ordering is derived from the first two stimuli only, inside the
    # sort_times window.
    onset_order = sort_utils.sortByOnset(
        dat_per_stim_per_cell[:2],
        on=sort_times_in_bins[0],
        off=sort_times_in_bins[1],
        thres=1,
        n_thres=-1
    )
    figutils.imagePanes(
        figsize=(8, 6),
        mats=dat_per_stim_per_cell[:, onset_order, :],
        titles=stim_names,
        vmin=-2.5,
        vmax=2.5,
        vcenter=0,
        axargs=dict(time_axis_args),
    )
    plt.suptitle('Responses by onset')
    figutils.save_fig(sf, f'{celltype}_sort_by_onset', show=False)

(5, 15, 60)
Figure saved at: D:\NPIX\NPIX1\2023.11.08\FIGURES\psth_heatmaps\ss_sort_by_depth
Figure saved at: D:\NPIX\NPIX1\2023.11.08\FIGURES\psth_heatmaps\ss_sort_by_onset
(5, 115, 60)
Figure saved at: D:\NPIX\NPIX1\2023.11.08\FIGURES\psth_heatmaps\rest_sort_by_depth
Figure saved at: D:\NPIX\NPIX1\2023.11.08\FIGURES\psth_heatmaps\rest_sort_by_onset


In [169]:
# specified celltypes
# One PSTH figure per unit, grouped into a sub-folder per cell type. Each
# cell type is paired with its own bin width (psthb_cs for 'cs').
celltypes = ['ss', 'cs', 'mli']
psthbs = [psthb, psthb_cs, psthb]
for celltype, b in zip(celltypes, psthbs):
    save_folder = os.path.join(fig_folder, 'psth', celltype)
    os.makedirs(save_folder, exist_ok=True)

    # Mapping of unit id -> unit name for this cell type.
    units = package['celltypes_parcellated_dict'][celltype]
    for unit_id, unit_name in units.items():
        # NOTE(review): the train is wrapped in a list here, while the clique
        # and match cells pass the bare train to get_processed_ifr -- confirm
        # that both forms are intended.
        train = [spike_trains_dict[unit_id]]

        # Keys are (row, stimulus index); a single unit occupies row 0.
        psth_dict = {}
        for e, event in enumerate(stim_trial_times):
            x, _, y_p, y_p_var = phy_utils.get_processed_ifr(train, event, b, psthw)
            psth_dict[0, e] = [x, y_p, y_p_var]

        psth_utils.make_psth(
            psth_dict,
            psthw,
            [unit_id],
            stim_names,
            xticks=xticks,
            xticklabels=xticklabels,
            ylabel=ylabel,
            xlabel=xlabel,
            colors=stim_colors,
        )
        figutils.save_fig(save_folder, figname=f'{unit_name}__{unit_id}', close=True, show=False)

Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\1_ss__63
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\1_ss__69
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\1_ss__114
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\1_ss__115
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\2_ss__274
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\2_ss__288
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss\2_ss__290
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\cs\cs__257
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\cs\cs__285
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\mli\1_mli__119
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\mli\1_mli__124


In [170]:
# SS cliques
# One PSTH figure per clique, with one row per member unit.
psth_folder = os.path.join(fig_folder, 'psth', 'ss_cliques')
os.makedirs(psth_folder, exist_ok=True)

# Invert the unit -> clique mapping into clique -> [units].
clique_dict = defaultdict(list)
for member_id, clique_label in package['celltypes_parcellated_dict']['_ss'].items():
    clique_dict[clique_label].append(member_id)

for clique_name, unit_ids in clique_dict.items():
    trains = [spike_trains_dict[unit_id] for unit_id in unit_ids]

    # Keys are (unit row, stimulus index).
    psth_dict = {}
    for t, train in enumerate(trains):
        for e, event in enumerate(stim_trial_times):
            x, _, y_p, y_p_var = phy_utils.get_processed_ifr(train, event, psthb, psthw)
            psth_dict[t, e] = [x, y_p, y_p_var]

    psth_utils.make_psth(
        psth_dict,
        psthw,
        unit_ids,
        stim_names,
        xticks=xticks,
        xticklabels=xticklabels,
        ylabel=ylabel,
        xlabel=xlabel,
        colors=stim_colors,
    )

    figutils.save_fig(
        psth_folder,
        figname=f'{clique_name}_{len(unit_ids)}',
        close=True,
        show=False,
    )

Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss_cliques\1_ss_4
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\ss_cliques\2_ss_3


In [171]:
# matches
# One PSTH figure per matched pair; the first unit of a pair is labelled CS
# and the second SS in the file name.
psth_folder = os.path.join(fig_folder, 'psth', 'matches')
os.makedirs(psth_folder, exist_ok=True)

for match_ix, unit_names in enumerate(matches):
    # Keys are (unit row, stimulus index).
    psth_dict = {}
    for t, unit in enumerate(unit_names):
        train = spike_trains_dict[unit]
        # Row 0 uses the CS bin width, every other row the default one.
        b = psthb_cs if t == 0 else psthb
        for e, event in enumerate(stim_trial_times):
            x, _, y_p, y_p_var = phy_utils.get_processed_ifr(train, event, b, psthw)
            psth_dict[t, e] = [x, y_p, y_p_var]

    psth_utils.make_psth(
        psth_dict,
        psthw,
        unit_names,
        stim_names,
        xticks=xticks,
        xticklabels=xticklabels,
        ylabel=ylabel,
        xlabel=xlabel,
        colors=stim_colors,
    )

    name = f'match_{match_ix}__cs_{unit_names[0]}_ss_{unit_names[1]}'
    figutils.save_fig(psth_folder, figname=name, close=True, show=False)

Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\matches\match_0__cs_257_ss_274
Figure saved at: D:\NPIX\NPIX1\2023.11.15\FIGURES\psth\matches\match_1__cs_285_ss_288


'\nTODOS:\n\n- raster plot SS rate and CS occurrences\n\n- do for all days\n'