# Spike sorting on RTX 5090

10 min recording

author: laquitainesteeve@gmail.com

Tested on Ubuntu 24 with an NVIDIA RTX 5090 (32 GB VRAM)

Execution time: 16 min

## Setup 

2. Activate the `spikesort_rtx5090` environment, register it as a Jupyter kernel, and select that kernel

    ```bash
    python -m ipykernel install --user --name spikesort_rtx5090 --display-name "spikesort_rtx5090"
    ```

3. Run the notebook.



In [2]:
%%time 
%load_ext autoreload
%autoreload 2

# import python packages
import os
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface as si
print("spikeinterface", si.__version__)

# project path
PROJ_PATH = "/home/steeve/steeve/epfl/code/spikebias/"
os.chdir(PROJ_PATH)

# import spikebias package
from src.nodes.sorting import sort_and_postprocess_10m

# recording parameters
REC_SECS = 600 
RECORDING_PATH = "./dataset/00_raw/recording_npx_spont/"

# setup sorting parameters
SORTER = "kilosort4"

# these are the default parameters
# for spikeinterface 0.100.5, with two exceptions listed below
# note that there are no minFR and minFR_channels in ks4
# - we set batch_size to 10,000 instead of 60,000 due to memory constraints
# - we set dminx to 25.6 um instead of None
SORTER_PARAMS = {
    "batch_size": 10000,
    "nblocks": 1,
    "Th_universal": 9,
    "Th_learned": 8,
    "do_CAR": True,
    "invert_sign": False,
    "nt": 61,
    "artifact_threshold": None,
    "nskip": 25,
    "whitening_range": 32,
    "binning_depth": 5,
    "sig_interp": 20,
    "nt0min": None,
    "dmin": None,
    "dminx": 25.6,
    "min_template_size": 10,
    "template_sizes": 5,
    "nearest_chans": 10,
    "nearest_templates": 100,
    "templates_from_data": True,
    "n_templates": 6,
    "n_pcs": 6,
    "Th_single_ch": 6,
    "acg_threshold": 0.2,
    "ccg_threshold": 0.25,
    "cluster_downsampling": 20,
    "cluster_pcs": 64,
    "duplicate_spike_bins": 15,
    "do_correction": True,
    "keep_good_only": False,
    "save_extra_kwargs": False,
    "skip_kilosort_preprocessing": False,
    "scaleproc": None,
}

# sorting output 
KS4_OUTPUT_PATH = "./temp/SortingKS4"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
spikeinterface 0.100.5
CPU times: user 341 µs, sys: 11 µs, total: 352 µs
Wall time: 313 µs
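
The comments in the cell above name two deviations from the default sorter parameters. A minimal sketch of how such overrides could be listed programmatically follows. `diff_params` is a hypothetical helper (not part of spikeinterface or spikebias), and the `defaults` dictionary is a small illustrative subset built only from the values stated in those comments.

```python
def diff_params(defaults, params):
    """Return {name: (default, value)} for each parameter that differs from its default."""
    return {
        name: (defaults.get(name), value)
        for name, value in params.items()
        if name not in defaults or defaults[name] != value
    }


# illustrative subset only; default values as stated in the notebook comments
defaults = {"batch_size": 60000, "dminx": None, "nblocks": 1}
params = {"batch_size": 10000, "dminx": 25.6, "nblocks": 1}

overrides = diff_params(defaults, params)
for name, (default, value) in overrides.items():
    print(f"{name}: default={default!r} -> used={value!r}")
```

In the notebook itself, the full defaults could be obtained with `ss.get_default_sorter_params(SORTER)` and compared against `SORTER_PARAMS` in the same way.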


## Minimal spike sorting

kilosort==4.0

In [None]:
%%time

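# load the recording and keep the first 60 s
# (use the commented line below to keep the full REC_SECS = 10 min)
Recording = si.load_extractor(RECORDING_PATH)
#Recording = Recording.frame_slice(start_frame=0, end_frame=int(Recording.sampling_frequency*REC_SECS))
# cast to int: frame indices are sample counts, sampling_frequency is a float
Recording = Recording.frame_slice(start_frame=0, end_frame=int(Recording.sampling_frequency*60))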


# spike sort
Sorting = ss.run_sorter(sorter_name = SORTER,
                        recording = Recording,
                        output_folder = KS4_OUTPUT_PATH,
                        remove_existing_folder = True,
                        verbose = True,
                        **SORTER_PARAMS)

print('\nRecording', Recording)
print('\nSorted units:', Sorting.unit_ids)
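
`frame_slice` takes sample indices, so a duration in seconds has to be converted using the sampling rate. A small sketch of that arithmetic follows. `seconds_to_frames` is a hypothetical helper, and the 40 kHz rate is the one reported in the Kilosort log of this notebook.

```python
def seconds_to_frames(sampling_frequency, seconds):
    """Convert a duration in seconds to an integer number of samples."""
    return int(round(sampling_frequency * seconds))


SAMPLING_FREQUENCY = 40000.0  # Hz, as reported in the log

frames_60s = seconds_to_frames(SAMPLING_FREQUENCY, 60)
frames_600s = seconds_to_frames(SAMPLING_FREQUENCY, 600)
print(frames_60s)   # 2400000
print(frames_600s)  # 24000000
```

The 60 s slice gives 2,400,000 samples, which matches the "number of samples" line in the log; the full `REC_SECS = 600` would give 24,000,000.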



2025-07-12 18:13:35,853 - faiss.loader - loader.py - <module> - INFO - Loading faiss with AVX2 support.
2025-07-12 18:13:35,863 - faiss.loader - loader.py - <module> - INFO - Successfully loaded faiss with AVX2 support.
2025-07-12 18:13:35,865 - faiss - __init__.py - <module> - INFO - Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss.


  warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")
NVIDIA GeForce RTX 5090 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_89 sm_90 compute_90.
If you want to use the NVIDIA GeForce RTX 5090 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

  X[:, self.nt : self.nt+nsamp] = torch.from_numpy(data).to(self.device).float()


Loading recording with SpikeInterface...
number of samples: 2400000
number of channels: 384
numbef of segments: 1
sampling rate: 40000.0
dtype: float32
Preprocessing filters computed in  0.68s; total  0.68s

computing drift
Re-computing universal templates from data.
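
The PyTorch warning in the output above says that the installed build ships no kernels for the RTX 5090's `sm_120` architecture. A rough sketch of how that mismatch could be detected before launching a sort follows. `gpu_is_supported` and `check_local_gpu` are hypothetical helpers, and the rule used (a build supports a GPU when it includes at least one `sm_*` target with the same major compute capability) is a simplifying assumption rather than a documented PyTorch guarantee.

```python
import re


def sm_versions(arch_list):
    """Extract integer versions from entries such as 'sm_90'; 'compute_*' entries are ignored."""
    versions = []
    for arch in arch_list:
        match = re.fullmatch(r"sm_(\d+)[a-z]?", arch)
        if match:
            versions.append(int(match.group(1)))
    return versions


def gpu_is_supported(capability, arch_list):
    """Assumed rule: supported if any sm_* target shares the GPU's major capability."""
    major, _minor = capability
    return any(version // 10 == major for version in sm_versions(arch_list))


def check_local_gpu():
    """Apply the check to GPU 0 of this machine (requires PyTorch with CUDA)."""
    import torch
    capability = torch.cuda.get_device_capability(0)
    return gpu_is_supported(capability, torch.cuda.get_arch_list())


# architecture list copied from the warning message above
BUILD_ARCHS = ["sm_50", "sm_60", "sm_61", "sm_70", "sm_75",
               "sm_80", "sm_86", "sm_89", "sm_90", "compute_90"]

print(gpu_is_supported((12, 0), BUILD_ARCHS))  # False: RTX 5090 (sm_120)
print(gpu_is_supported((8, 9), BUILD_ARCHS))   # True: an sm_89 GPU
```

Calling `check_local_gpu()` at the top of the notebook would flag the problem early; if it returns `False`, the installation instructions linked in the warning are the place to look for a compatible build.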


## Full pipeline