In [1]:
%load_ext autoreload
%autoreload 2

import os
import glob
import sys
import scipy.stats
import matplotlib.pyplot as plt
import numpy as np
import phylib.io.model
import seaborn as sns

import configs
from utils import styleutils, ioutils, figutils

# Reset seaborn / matplotlib rc settings to library defaults, then apply the
# project's plotting style (utils.styleutils; see update_mpl_params for exactly
# which rcParams the fontsize / linewidth arguments control).
# NOTE(review): on some IPython / matplotlib-inline versions, activating the inline
# backend applies InlineBackend.rc (which has included 'font.size'); if fonts look
# smaller than expected, run %matplotlib inline before the style call — verify.
sns.reset_defaults()
styleutils.update_mpl_params(fontsize=14, linewidth=1)
%matplotlib inline

In [2]:
mouse = configs.NPIX1
mouse_path = os.path.join(configs.BASE_PATH, mouse.path)
dates = list(mouse.exp_params.keys())
trial_duration = mouse.duration

# For every recording date, locate the session folder, its CatGT output folder and
# the Kilosort2 output folder ('imec0_ks2', loaded with phylib below).
# The three lists stay index-aligned: a date is appended to all of them or to none,
# so base_folders[i], catgt_folders[i] and phy_folders[i] describe the same session.
# Dates without the expected folders are skipped with a message instead of silently.
base_folders = []
catgt_folders = []
phy_folders = []
for date in dates:
    date_folder = os.path.join(mouse_path, date)

    # glob order depends on the filesystem; sort so the choice is deterministic
    # when several folders match.
    catgt_matches = sorted(glob.glob(os.path.join(date_folder, 'catgt*')))
    if not catgt_matches:
        print(f'Skipping {date}: no catgt* folder in {date_folder}')
        continue
    catgt_folder = catgt_matches[0]

    imec_matches = sorted(glob.glob(os.path.join(catgt_folder, '*imec0')))
    if not imec_matches:
        print(f'Skipping {date}: no *imec0 folder in {catgt_folder}')
        continue

    # Append only once every path for this date has been resolved.
    base_folders.append(date_folder)
    catgt_folders.append(catgt_folder)
    phy_folders.append(os.path.join(imec_matches[0], 'imec0_ks2'))

print(catgt_folders)

['D:\\NPIX\\NPIX1\\2023.11.12\\catgt_2023_11_12_all_g1']


In [3]:
# load data
# Work on the first discovered session: its folders, the processed package and
# the Kilosort2 / phy outputs.
catgt_folder, base_folder, phy_folder = catgt_folders[0], base_folders[0], phy_folders[0]
processed_folder = os.path.join(base_folder, 'processed')
fig_folder = os.path.join(base_folder, 'FIGURES')
print(phy_folder)

# Preprocessed results saved by utils.ioutils; the spike trains are the
# post-refractory-period versions, judging by the key name.
package = ioutils.pload(os.path.join(processed_folder, 'package'))
fs = package['fs']
spike_trains_dict = dict(package['spike_trains_dict_after_rp'].items())  # shallow copy

channel_pos = np.load(os.path.join(phy_folder, 'channel_positions.npy'))
channel_map = np.load(os.path.join(phy_folder, 'channel_map.npy')).ravel()
# Require an identity channel map (0..N-1), presumably so that channel ids can be
# used directly as row indices into channel_pos in the cells below.
assert np.array_equal(channel_map, np.arange(channel_map.size))

model = phylib.io.model.load_model(os.path.join(phy_folder, 'params.py'))

D:\NPIX\NPIX1\2023.11.12\catgt_2023_11_12_all_g1\2023_11_12_all_g1_imec0\imec0_ks2


In [4]:
# Manual confirmation that the refractory-period requirement (and the other
# spike-train modifications made upstream) make sense: for every complex-spike (CS)
# unit, plot its waveforms before and after spike filtering.
#
# Panel layout (2 x 2): rows = before / after filtering;
# columns = mean +/- 1 SD across spikes / a subset of individual spikes.
# Each channel's trace is drawn at that channel's probe position.

n_waveforms_to_plot = 20  # maximum number of individual spikes drawn per channel
xscale = .3   # x-units per waveform sample
yscale = .04  # y-units per unit of waveform amplitude
xg = 2        # extra x-margin added on each side of every panel

sf = os.path.join(fig_folder, 'controls', 'cs_waveforms')
os.makedirs(sf, exist_ok=True)

cs_units = package['celltypes_parcellated_dict']['cs']
for cluster_id in cs_units:
    waveforms_before = model.get_cluster_spike_waveforms(cluster_id)

    # Keep only the spikes that survived filtering. Spikes are matched by time, which
    # assumes spike_trains_dict stores sample indices like model.spike_samples — TODO confirm.
    spikes_before = model.spike_samples[model.spike_clusters == cluster_id]
    spikes_after = spike_trains_dict[cluster_id]
    filter_mask = np.isin(spikes_before, spikes_after)
    if len(waveforms_before) != len(filter_mask):
        raise ValueError(
            f'Cluster {cluster_id}: got {len(waveforms_before)} waveforms for '
            f'{len(filter_mask)} spikes, so the spike filter mask cannot be applied.')
    waveforms_after = waveforms_before[filter_mask]
    n_total = len(filter_mask)
    # NaN instead of a division by zero when the cluster has no spikes.
    frac_filtered = 1 - filter_mask.sum() / n_total if n_total else float('nan')

    channel_ids = model.get_cluster_channels(cluster_id)
    xy_locs = channel_pos[channel_ids]

    _fig, axs = figutils.pretty_fig(figsize=(16, 10), rows=2, cols=2, sharex=True, sharey=True)
    for row, waveforms in enumerate([waveforms_before, waveforms_after]):
        n_spikes, n_samples, n_channels_loc = waveforms.shape
        # Evenly spaced subset of spikes for the "individual" panel, capped at n_spikes
        # so small clusters are not drawn with duplicated indices and empty ones
        # do not index out of range.
        waveform_ixs = np.linspace(0, n_spikes - 1, min(n_waveforms_to_plot, n_spikes)).astype(int)
        x_base = np.arange(n_samples) * xscale

        for col, individual in enumerate([False, True]):
            ax = axs[row, col]
            for ch in range(n_channels_loc):
                pos = xy_locs[ch]
                if n_spikes > 0:
                    if individual:
                        y = waveforms[waveform_ixs, :, ch].T * yscale + pos[1]
                        ax.plot(x_base + pos[0], y, color='r', alpha=4/n_waveforms_to_plot)
                    else:
                        y = waveforms[:, :, ch] * yscale + pos[1]
                        y_mean = y.mean(axis=0)
                        y_sd = y.std(axis=0)  # shaded band is +/- 1 SD (not SEM)
                        x = x_base + pos[0]
                        ax.plot(x, y_mean, color='r', alpha=1)
                        ax.fill_between(x, y_mean - y_sd, y_mean + y_sd, color='r', alpha=0.5)
                ax.text(pos[0] - 5, pos[1], f'{channel_ids[ch]}')
            ax.axis('off')
            sns.despine(ax=ax, bottom=True, left=True)

    # Widen every panel's x-limits by xg only after all drawing is done, reading all
    # limits before setting any. With shared x-axes, set_xlim on one panel updates its
    # siblings and turns off their autoscaling, so widening panel-by-panel while drawing
    # would compound the margin and freeze the limits before later panels are drawn.
    panels = np.ravel(axs)
    x_limits = [ax.get_xlim() for ax in panels]
    for ax, (lo, hi) in zip(panels, x_limits):
        ax.set_xlim(lo - xg, hi + xg)

    plt.suptitle(f'CS: {cluster_id}, Fraction filtered: {frac_filtered:0.3f}')
    plt.tight_layout()
    figutils.save_fig(sf, f'CS_{cluster_id}', show=False, close=True)

Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_231
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_242
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_243
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_244
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_415
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_423
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_473
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_476
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_477
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\cs_waveforms\CS_478


In [5]:
# Confirm that high correlations between CSs and SSs are not due to similarities
# between their waveforms: for every matched (CS, SS) pair, plot the two units'
# waveforms (after spike filtering) side by side.
#
# Panel layout (2 x 2): columns = CS (red) / SS (black);
# rows = mean +/- 1 SD across spikes / a subset of individual spikes.
#
# TODO: come up with some metric to assess similarity of waveforms vs similarity
# of CS / SS responses.

n_waveforms_to_plot = 20  # maximum number of individual spikes drawn per channel
xscale = .3   # x-units per waveform sample
yscale = .04  # y-units per unit of waveform amplitude
xg = 2        # extra x-margin added on each side of every panel

sf = os.path.join(fig_folder, 'controls', 'match_waveforms')
os.makedirs(sf, exist_ok=True)

matches = package['matches']
for m, match in enumerate(matches):
    _fig, axs = figutils.pretty_fig(figsize=(16, 10), rows=2, cols=2, sharex=True, sharey=True)
    for col, cluster_id in enumerate(match):
        all_waveforms = model.get_cluster_spike_waveforms(cluster_id)

        # Keep only the spikes that survived filtering. Spikes are matched by time, which
        # assumes spike_trains_dict stores sample indices like model.spike_samples — TODO confirm.
        spikes_before = model.spike_samples[model.spike_clusters == cluster_id]
        spikes_after = spike_trains_dict[cluster_id]
        filter_mask = np.isin(spikes_before, spikes_after)
        if len(all_waveforms) != len(filter_mask):
            raise ValueError(
                f'Cluster {cluster_id}: got {len(all_waveforms)} waveforms for '
                f'{len(filter_mask)} spikes, so the spike filter mask cannot be applied.')
        waveforms = all_waveforms[filter_mask]

        channel_ids = model.get_cluster_channels(cluster_id)
        xy_locs = channel_pos[channel_ids]
        n_spikes, n_samples, n_channels_loc = waveforms.shape
        color = 'r' if col == 0 else 'k'  # CS in red, SS in black
        # Evenly spaced subset of spikes, capped at n_spikes so small clusters are not
        # drawn with duplicated indices and empty ones do not index out of range.
        waveform_ixs = np.linspace(0, n_spikes - 1, min(n_waveforms_to_plot, n_spikes)).astype(int)
        x_base = np.arange(n_samples) * xscale

        for row, individual in enumerate([False, True]):
            ax = axs[row, col]
            for ch in range(n_channels_loc):
                pos = xy_locs[ch]
                if n_spikes > 0:
                    if individual:
                        y = waveforms[waveform_ixs, :, ch].T * yscale + pos[1]
                        ax.plot(x_base + pos[0], y, color=color, alpha=4/n_waveforms_to_plot)
                    else:
                        y = waveforms[:, :, ch] * yscale + pos[1]
                        y_mean = y.mean(axis=0)
                        y_sd = y.std(axis=0)  # shaded band is +/- 1 SD (not SEM)
                        x = x_base + pos[0]
                        ax.plot(x, y_mean, color=color, alpha=1)
                        ax.fill_between(x, y_mean - y_sd, y_mean + y_sd, color=color, alpha=0.5)
                ax.text(pos[0] - 5, pos[1], f'{channel_ids[ch]}')
            ax.axis('off')
            sns.despine(ax=ax, bottom=True, left=True)

    # Widen every panel's x-limits by xg only after both units are drawn, reading all
    # limits before setting any. With shared x-axes, set_xlim on one panel updates its
    # siblings and turns off their autoscaling, so doing it while drawing would compound
    # the margin and freeze the limits to the CS channels (possibly clipping the SS).
    panels = np.ravel(axs)
    x_limits = [ax.get_xlim() for ax in panels]
    for ax, (lo, hi) in zip(panels, x_limits):
        ax.set_xlim(lo - xg, hi + xg)

    plt.suptitle(f'Match: {m}, CS: {match[0]}, SS: {match[1]}')
    plt.tight_layout()
    figutils.save_fig(sf, f'match_{m}_cs_{match[0]}_ss_{match[1]}', show=False, close=True)

Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\match_waveforms\match_0_cs_423_ss_250
Figure saved at: D:\NPIX\NPIX1\2023.11.12\FIGURES\controls\match_waveforms\match_1_cs_473_ss_426


In [6]:
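# Reference notes (not executed): how phylib's TemplateModel picks the channels that
# model.get_cluster_channels returns in the cells above (copied from phylib; verify
# against the installed version):
#     spike_ids = self.get_cluster_spikes(cluster_id)
#     return self._get_template_from_spikes(spike_ids).channel_ids
# Related helper: _find_best_channels(self, template, amplitude_threshold=None)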
