# Spike sorting with and without drift correction

10 min recording | +/- drift correction | high-pass filt.| - bad channels

author: laquitainesteeve@gmail.com

Tested on Ubuntu 24 with an Nvidia RTX 5090 (32 GB of VRAM); the run takes about 25 GB of VRAM.

Execution time: 16 min

## Setup 

2. Activate `spikesort_rtx5090` environment and select kernel

    ```bash
    python -m ipykernel install --user --name spikesort_rtx5090 --display-name "spikesort_rtx5090"
    ```

3. Run notebook or pipeline:
    ```bash
    # ks4 - npx spontaneous
    nohup python -m src.pipes.sorting.test_params.driftcorr.npx_spont.10m.ks4 \
        --recording-path dataset/00_raw/recording_npx_spont \
            --preprocess-path dataset/01_intermediate/preprocessing/recording_npx_spont \
                --sorting-path-corrected ./temp/npx_spont/SortingKS3_10m_RTX5090_DriftCorr \
                    --sorting-output-path-corrected ./temp/npx_spont/KS3_output_10m_RTX5090_DriftCorr/ \
                        --study-path-corrected ./temp/npx_spont/study_ks3_10m_RTX5090_DriftCorr/ \
                            --sorting-path-not-corrected ./temp/npx_spont/SortingKS3_10m_RTX5090_NoDriftCorr \
                                --sorting-output-path-not-corrected ./temp/npx_spont/KS3_output_10m_RTX5090_NoDriftCorr/ \
                                    --study-path-not-corrected ./temp/npx_spont/study_ks3_10m_RTX5090_NoDriftCorr/
    ```

In [1]:
%%time 
%load_ext autoreload
%autoreload 2

# import python packages
import os
import torch
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface as si
# print the spikeinterface version this notebook runs with
print("spikeinterface", si.__version__)
# ask PyTorch to release its cached GPU memory before sorting starts
torch.cuda.empty_cache()

# project path (machine-specific absolute path; adapt it to the local checkout)
PROJ_PATH = "/home/steeve/steeve/epfl/code/spikebias/"
# work from the project root: the relative dataset/ and temp/ paths used in
# this notebook are resolved against the current working directory
os.chdir(PROJ_PATH)

# import spikebias package
# NOTE(review): presumably relies on the chdir above for `src` to be
# importable - confirm before moving this import
from src.nodes.sorting import sort_and_postprocess_10m

# recording: raw recording location and duration to sort, in seconds (10 min)
RECORDING_PATH = "./dataset/00_raw/recording_npx_spont/"
REC_SECS = 600

# sorter to run
SORTER = "kilosort4"

# Kilosort4 parameters.
# These are the default parameters for spikeinterface 0.100.5
# (note that ks4 has no minFR and minFR_channels), with two exceptions:
# - batch_size is set to 10,000 instead of 60,000 due to memory constraints
# - dminx is set to 25.6 um instead of None
SORTER_PARAMS = dict(
    batch_size=10000,
    nblocks=1,
    Th_universal=9,
    Th_learned=8,
    do_CAR=True,
    invert_sign=False,
    nt=61,
    artifact_threshold=None,
    nskip=25,
    whitening_range=32,
    binning_depth=5,
    sig_interp=20,
    nt0min=None,
    dmin=None,
    dminx=25.6,
    min_template_size=10,
    template_sizes=5,
    nearest_chans=10,
    nearest_templates=100,
    templates_from_data=True,
    n_templates=6,
    n_pcs=6,
    Th_single_ch=6,
    acg_threshold=0.2,
    ccg_threshold=0.25,
    cluster_downsampling=20,
    cluster_pcs=64,
    duplicate_spike_bins=15,
    do_correction=True,
    keep_good_only=False,
    save_extra_kwargs=False,
    skip_kilosort_preprocessing=False,
    scaleproc=None,
)

spikeinterface 0.100.5
CPU times: user 3.52 s, sys: 220 ms, total: 3.74 s
Wall time: 1.5 s


## npx spont w/ drift corr.

In [2]:
# setup configuration


def _make_cfg(tag):
    """Build the pipeline configuration for one drift-correction condition.

    Both conditions share the same raw and preprocessed recording paths;
    only the sorting and study output paths differ, through their suffix.

    Args:
        tag (str): suffix appended to the output paths, "DriftCorr" or
            "NoDriftCorr".

    Returns:
        dict: nested configuration with the 'probe_wiring', 'preprocessing',
        'sorting' and 'postprocessing' sections, keyed by SORTER where a
        sorter name is expected.
    """
    return {
        'probe_wiring': {
            'full': {
                'output': 'dataset/00_raw/recording_npx_spont'
            }
        },
        'preprocessing': {
            'full': {
                'output': {
                    'trace_file_path': 'dataset/01_intermediate/preprocessing/recording_npx_spont'
                }
            }
        },
        'sorting': {
            'sorters': {
                SORTER: {  # sorter name
                    '10m': {
                        # path of the saved sorting
                        'output': f'./temp/npx_spont/SortingKS4_10m_RTX5090_{tag}',
                        # path of the raw sorter output
                        'sort_output': f'./temp/npx_spont/KS4_output/KS4_output_10m_RTX5090_{tag}'
                    }
                }
            }
        },
        'postprocessing': {
            'waveform': {
                'sorted': {
                    'study': {
                        SORTER: {  # sorter name
                            # path of the postprocessing study
                            '10m': f'./temp/npx_spont/study_ks4_10m_RTX5090_{tag}'
                        }
                    }
                }
            }
        }
    }


# WITH CORR.
CFG_CORR = _make_cfg("DriftCorr")

# WITHOUT CORR.
CFG_NO_CORR = _make_cfg("NoDriftCorr")

# + DRIFT CORRECTION
# spike sort only: post-processing and waveform extraction are switched off
sort_and_postprocess_10m(
    CFG_CORR,
    SORTER,
    SORTER_PARAMS,
    duration_sec=REC_SECS,
    is_sort=True,
    is_postpro=False,
    extract_wvf=False,
    copy_binary_recording=True,
    remove_bad_channels=True,
)
# # post-process (disabled)
# sort_and_postprocess_10m(CFG_CORR, SORTER, SORTER_PARAMS, duration_sec=REC_SECS,
#                          is_sort=False, is_postpro=True, extract_wvf=True, copy_binary_recording=True,
#                          remove_bad_channels=False)

# - DRIFT CORRECTION (disabled)
# note: the line below changes SORTER_PARAMS in place, so it must stay after
# the run with drift correction

# SORTER_PARAMS['do_correction'] = False

# # spike sort
# sort_and_postprocess_10m(CFG_NO_CORR, SORTER, SORTER_PARAMS, duration_sec=REC_SECS,
#                          is_sort=True, is_postpro=False, extract_wvf=False, copy_binary_recording=True,
#                          remove_bad_channels=True)
# # post-process
# sort_and_postprocess_10m(CFG_NO_CORR, SORTER, SORTER_PARAMS, duration_sec=REC_SECS,
#                          is_sort=False, is_postpro=True, extract_wvf=True, copy_binary_recording=True,
#                          remove_bad_channels=False)

2025-07-12 18:50:20,595 - root - sorting.py - sort_and_postprocess_10m - INFO - Started sorting 10 minutes recording.
2025-07-12 18:50:20,597 - root - sorting.py - sort - INFO - Removing bad channels...
2025-07-12 18:50:20,600 - root - sorting.py - sort - INFO - Done removing bad channels in: 0.0
2025-07-12 18:50:20,600 - root - sorting.py - sort - INFO - Selected first 10.0 minutes in: 0.0
2025-07-12 18:50:20,600 - root - sorting.py - sort - INFO - Done converting recording as int16 in: 0.0
2025-07-12 18:50:20,601 - root - sorting.py - sort - INFO - Saving int16 binary recording...




write_binary_recording with n_jobs = 32 and chunk_size = 400000


write_binary_recording:   0%|          | 0/60 [00:00<?, ?it/s]

2025-07-12 18:50:34,735 - root - sorting.py - sort - INFO - Done copying int16 binary recording in: 14.1
2025-07-12 18:50:34,737 - root - sorting.py - sort - INFO - Start sorting...
2025-07-12 18:50:35,033 - faiss.loader - loader.py - <module> - INFO - Loading faiss with AVX2 support.
2025-07-12 18:50:35,042 - faiss.loader - loader.py - <module> - INFO - Successfully loaded faiss with AVX2 support.
2025-07-12 18:50:35,044 - faiss - __init__.py - <module> - INFO - Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss.


  warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")
NVIDIA GeForce RTX 5090 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_89 sm_90 compute_90.
If you want to use the NVIDIA GeForce RTX 5090 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

  X[:, self.nt : self.nt+nsamp] = torch.from_numpy(data).to(self.device).float()


Loading recording with SpikeInterface...
number of samples: 24000000
number of channels: 202
numbef of segments: 1
sampling rate: 40000.0
dtype: int16
Preprocessing filters computed in  0.57s; total  0.57s

computing drift
Re-computing universal templates from data.


100%|██████████| 2400/2400 [35:24<00:00,  1.13it/s]


drift computed in  2245.97s; total  2246.54s

Extracting spikes using templates
Re-computing universal templates from data.


100%|██████████| 2400/2400 [35:24<00:00,  1.13it/s]


393204 spikes extracted in  2246.36s; total  4492.91s

First clustering


100%|██████████| 93/93 [00:02<00:00, 35.47it/s]  


95 clusters found, in  2.65s; total  4495.56s

Extracting spikes using cluster waveforms


100%|██████████| 2400/2400 [00:12<00:00, 192.21it/s]


109968 spikes extracted in  12.51s; total  4508.07s

Final clustering


100%|██████████| 93/93 [00:03<00:00, 29.55it/s]  

112 clusters found, in  3.15s; total  4511.22s

Merging clusters
100 units found, in  0.05s; total  4511.27s

Saving to phy and computing refractory periods
27 units found with good refractory periods

Total runtime: 4511.32s = 01:75:11 h:m:s
kilosort4 run time 4511.39s
2025-07-12 20:05:47,028 - root - sorting.py - sort - INFO - Removing empty units...
2025-07-12 20:05:47,033 - root - sorting.py - sort - INFO - Done removing empty units.
2025-07-12 20:05:47,033 - root - sorting.py - sort - INFO - Done sorting: took 4512.3
2025-07-12 20:05:47,033 - root - sorting.py - sort - INFO - Done running kilosort4 in: 4512.3
2025-07-12 20:05:47,034 - root - sorting.py - sort - INFO - Saved sorting metadata.
2025-07-12 20:05:47,041 - root - sorting.py - sort - INFO - Done saving kilosort4 in: 0.0
2025-07-12 20:05:47,041 - root - sorting.py - sort_and_postprocess_10m - INFO - Done sorting with kilosort4 in: 4526.4
2025-07-12 20:05:47,041 - root - sorting.py - sort_and_postprocess_10m - INFO - Skipp




In [7]:
# Compare sorting results: reference sorting vs. the RTX5090 sortings with
# and without drift correction
# NOTE(review): the two RTX5090 sortings are loaded from ./temp/npx_spont/5min/,
# which is not the output path set in CFG_CORR / CFG_NO_CORR - confirm these
# are the intended runs
SortingRef = si.load_extractor("./dataset/01_intermediate/sorting/npx_spont/SortingKS4_10m")
SortingCorr = si.load_extractor("./temp/npx_spont/5min/SortingKS4_10m_RTX5090_DriftCorr")
SortingNoCorr = si.load_extractor("./temp/npx_spont/5min/SortingKS4_10m_RTX5090_NoDriftCorr")

# reported in this order: reference, with correction, without correction
sortings = (SortingRef, SortingCorr, SortingNoCorr)

# total units
print("total units:")
for sorting in sortings:
    print(len(sorting.unit_ids))

# single units: units whose 'KSLabel' property is 'good'
print("\nsingle units:")
for sorting in sortings:
    print(sum(sorting.get_property('KSLabel') == 'good'))

total units:
520
70
62

single units:
184
17
20


In [8]:
# ask PyTorch to release its cached GPU memory once the notebook is done
torch.cuda.empty_cache()