# Pre-compute all quality metrics

* author: steeve.laquitaine@epfl.ch

In [1]:
%load_ext autoreload
%autoreload 2
import os
import spikeinterface as si
import spikeinterface.core.template_tools as ttools
from spikeinterface import comparison
from spikeinterface.qualitymetrics import compute_quality_metrics as qm
from spikeinterface import qualitymetrics
import pandas as pd
from cebra import CEBRA
import cebra
import torch
import numpy as np
from matplotlib import pyplot as plt
import sklearn
import seaborn as sns
from sklearn import metrics
import cebra.models
from spikeinterface.postprocessing import compute_principal_components


# Project root. We os.chdir() into it so that the `src.*` imports below
# resolve (presumably because the notebook's working directory is on
# sys.path) and so relative paths in the project config are found.
# Set the SPIKEBIAS_PROJ_PATH environment variable to run from another
# checkout/machine; the default is the original cluster location.
proj_path = os.environ.get(
    "SPIKEBIAS_PROJ_PATH",
    "/gpfs/bbp.cscs.ch/project/proj85/home/laquitai/spikebias/",
)
os.chdir(proj_path)

from src.nodes.utils import get_config
from src.nodes import utils
from src.nodes.utils import euclidean_distance

# --- Neuropixels, spontaneous activity (biophysical simulation) ---
# get_config(...) returns a mapping whose first value is the experiment
# config; unpacking .values() relies on its insertion order (second value unused).
cfg_ns, _ = get_config("silico_neuropixels", "concatenated").values()
# Kilosort4 sorting output ("10m" entry)
KS4_ns_10m = cfg_ns["sorting"]["sorters"]["kilosort4"]["10m"]["output"]
# ground-truth sorting ("10m" entry) — not a KS4 sorting
GT_ns_10m = cfg_ns["sorting"]["simulation"]["ground_truth"]["10m"]["output"]
# WaveformExtractor folder for the KS4 sorting
STUDY_ns = cfg_ns["postprocessing"]["waveform"]["sorted"]["study"]["kilosort4"]["10m"]
# recording after probe wiring
REC_ns = cfg_ns["probe_wiring"]["full"]["output"]

# --- Neuropixels, evoked activity (biophysical simulation) ---
cfg_ne, _ = get_config("silico_neuropixels", "stimulus").values()
# Kilosort4 sorting output ("10m" entry)
KS4_ne_10m = cfg_ne["sorting"]["sorters"]["kilosort4"]["10m"]["output"]
# ground-truth sorting ("10m" entry)
GT_ne_10m = cfg_ne["sorting"]["simulation"]["ground_truth"]["10m"]["output"]
# WaveformExtractor folder for the KS4 sorting
STUDY_ne = cfg_ne["postprocessing"]["waveform"]["sorted"]["study"]["kilosort4"]["10m"]
# Output folder for the single-unit ("good") WaveformExtractor built below.
# TODO: this is a hardcoded literal while its siblings come from the config;
# consider moving it into the "stimulus" config.
STUDY_ne_su = '/gpfs/bbp.cscs.ch/project/proj85/scratch/laquitai/4_preprint_2023/0_silico/4_spikesorting_stimulus_test_neuropixels_8-1-24__8slc_80f_360r_50t_200ms_1_smallest_fiber_gids/0fcb7709-b1e9-4d84-b056-5801f20d55af/postpro/realism/spike/sorted/study_ks4_10m_single_units'
# recording after probe wiring
REC_ne = cfg_ne["probe_wiring"]["full"]["output"]


# job parameters passed as **kwargs to SpikeInterface's chunked computations
job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=True, verbose=True)

# PATHS

# Root of the sorting-quality analysis outputs (shared by every path below,
# so relocating the analysis only requires changing this one line).
SORTING_QUALITY_ROOT = "/gpfs/bbp.cscs.ch/project/proj85/scratch/laquitai/4_preprint_2023/analysis/sorting_quality"

# pre-computed sorted unit quality
quality_path = f"{SORTING_QUALITY_ROOT}/sorting_quality.csv"

# CEBRA model save paths, grouped by sampling frequency.
# NOTE: the 20 kHz paths end with "/" and the 40 kHz path does not — this
# reproduces the original literals exactly in case downstream code
# concatenates onto them.
CEBRA_MODELS_ROOT = f"{SORTING_QUALITY_ROOT}/models/cebra"
MODEL_PATH_40Khz_s2s_pooled_pm1 = f"{CEBRA_MODELS_ROOT}/sf_40Khz/s2s_pooled_pm1"
MODEL_PATH_20Khz = f"{CEBRA_MODELS_ROOT}/sf_20Khz/"
MODEL_PATH_20Khz_e2e_pm1 = f"{MODEL_PATH_20Khz}e2e_pm1/"
MODEL_PATH_20Khz_s2e_pm1 = f"{MODEL_PATH_20Khz}s2e_pm1/"
MODEL_PATH_20Khz_e2s_pm1 = f"{MODEL_PATH_20Khz}e2s_pm1/"
MODEL_PATH_20Khz_e2s_pm2 = f"{MODEL_PATH_20Khz}e2s_pm2/"
MODEL_PATH_20Khz_e2s_pm3 = f"{MODEL_PATH_20Khz}e2s_pm3/"
MODEL_PATH_20Khz_e2s_pm4 = f"{MODEL_PATH_20Khz}e2s_pm4/"
MODEL_PATH_20Khz_e2s_pm5 = f"{MODEL_PATH_20Khz}e2s_pm5/"
MODEL_PATH_20Khz_e2s_pm4_bal = f"{MODEL_PATH_20Khz}e2s_pm4_bal/"
MODEL_PATH_20Khz_e2s_pm4_mixed_dataset = f"{MODEL_PATH_20Khz}e2s_pm4_mixed_dataset/"
MODEL_PATH_20Khz_e2s_pm4_mixed_dataset_2dwave = f"{MODEL_PATH_20Khz}e2s_pm4_mixed_dataset_2dwave/"

# figure layout settings (presumably passed as **kwargs to tight_layout)
tight_layout_cfg = {"pad": 0.001}

2024-09-12 09:17:21,533 - root - utils.py - get_config - INFO - Reading experiment config.
2024-09-12 09:17:21,601 - root - utils.py - get_config - INFO - Reading experiment config. - done
2024-09-12 09:17:21,603 - root - utils.py - get_config - INFO - Reading experiment config.
2024-09-12 09:17:21,674 - root - utils.py - get_config - INFO - Reading experiment config. - done


In [2]:
def get_waveformExtractor_for_single_units(
    KS4_ne_10m, STUDY_ne, save_path: str, n_sites=384
):
    """Build (or reload) a WaveformExtractor restricted to Kilosort "good" units.

    The single-unit extractor is written to ``save_path`` and kept for all
    downstream analyses, which should speed up computations compared to
    working on the full sorting.

    Args:
        KS4_ne_10m: path to the saved Kilosort4 sorting extractor; it must
            carry the ``"KSLabel"`` unit property.
        STUDY_ne: folder of the full WaveformExtractor for that sorting.
        save_path: folder where the single-unit WaveformExtractor is saved.
            If it already exists it is loaded instead of rebuilt, so the
            cell can be re-run. NOTE: a stale folder is not detected; delete
            it if the sorting changed.
        n_sites: number of recording channels, used to size the gain/offset
            properties. Pass ``None`` to read it from the recording.

    Returns:
        The single-unit WaveformExtractor, with ``gain_to_uV`` set to 1 and
        ``offset_to_uV`` set to 0 on every channel of its recording.

    Raises:
        ValueError: if the sorting has no ``"KSLabel"`` property, or no unit
            is labelled ``"good"``.
    """
    if os.path.isdir(save_path):
        # WaveformExtractor.select_units refuses to write into an existing
        # folder, so re-running would fail: reuse the saved single-unit copy.
        WeSuNe = si.WaveformExtractor.load_from_folder(save_path)
    else:
        # get single units (units Kilosort labelled "good")
        SortingNe = si.load_extractor(KS4_ne_10m)
        labels = SortingNe.get_property("KSLabel")
        if labels is None:
            # previously `None == "good"` silently selected zero units
            raise ValueError(f"Sorting at {KS4_ne_10m} has no 'KSLabel' property")
        su_unit_ids = SortingNe.unit_ids[np.asarray(labels) == "good"]
        if len(su_unit_ids) == 0:
            raise ValueError(f"Sorting at {KS4_ne_10m} has no 'good' units")

        # subset the full WaveformExtractor and save it to save_path
        WeNe = si.WaveformExtractor.load_from_folder(STUDY_ne)
        WeSuNe = WeNe.select_units(unit_ids=su_unit_ids, new_folder=save_path)

    if n_sites is None:
        n_sites = WeSuNe.recording.get_num_channels()

    # Set up two properties required to calculate some quality metrics.
    # Unit gain / zero offset presumably means traces are already in uV —
    # NOTE(review): confirm against the recording's preprocessing.
    WeSuNe.recording.set_property("gain_to_uV", np.ones((n_sites,)))
    WeSuNe.recording.set_property("offset_to_uV", np.zeros((n_sites,)))
    return WeSuNe

In [6]:
# (~20 s) Build the WaveformExtractor restricted to single ("good") units of
# the evoked KS4 sorting, saving it under STUDY_ne_su.
WeSuNe = get_waveformExtractor_for_single_units(
    KS4_ne_10m,
    STUDY_ne,
    n_sites=384,
    save_path=STUDY_ne_su,
)

In [7]:
# Compute waveform principal components for the single units (5 components,
# "by_channel_local" mode), using the shared parallel job settings.
# TODO: record the runtime (the original comment left it blank).
pca = compute_principal_components(
    waveform_extractor=WeSuNe, mode="by_channel_local", n_components=5, **job_kwargs
)

Fitting PCA:   0%|          | 0/408 [00:00<?, ?it/s]