In [1]:
#%% 0) Environment check

from pathlib import Path
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import numpy as np
import torch

print("SpikeInterface:", si.__version__)
print("Installed sorters:", ss.installed_sorters())
print("CUDA available:", torch.cuda.is_available())


  from .autonotebook import tqdm as notebook_tqdm


SpikeInterface: 0.103.0
Installed sorters: ['kilosort4', 'simple', 'spykingcircus2', 'tridesclous2']
CUDA available: False


In [2]:
#%% 1) Paths (EDIT data_path to your OE binary session folder)

# Folder that contains structure.oebin (Open Ephys "binary" format)
data_path = r"C:\Users\ryoi\Documents\SpikeSorting\recordings\2025-10-01_15-53-19\Record Node 125\experiment1\recording1"  # <-- EDIT

# Where outputs will be written
OUT_DIR = Path(r"C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)
print("Outputs →", OUT_DIR)


Outputs → C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs


In [3]:
#%% 2) Load Open Ephys (choose neural stream, not ADC)

# If you’re unsure of the exact stream name, try calling read_openephys(...) once
# without stream_name to see the listed streams in the error, then set it here.
recording = se.read_openephys(
    data_path,
    stream_name="Record Node 125#Acquisition_Board-100.Rhythm Data"  # <-- EDIT if your stream differs
)
print(recording)
print("Segments:", recording.get_num_segments(), "| Fs:", recording.get_sampling_frequency(), "| Ch:", recording.get_num_channels())


OpenEphysBinaryRecordingExtractor: 64 channels - 30.0kHz - 1 segments - 105,877,760 samples 
                                   3,529.26s (58.82 minutes) - int16 dtype - 12.62 GiB
Segments: 1 | Fs: 30000.0 | Ch: 64


In [4]:
#%% 3) (Optional) 60 s smoke test

# Uncomment to sort only the first 60 seconds (recommended for a sanity check)
fs = recording.get_sampling_frequency()
recording = recording.frame_slice(0, int(fs * 60))
print("Sliced to first 60 s.")


Sliced to first 60 s.


In [5]:
#%% 4) Tetrode groups from REAL channel IDs (strings like 'CH40')

dev_ids = list(recording.channel_ids)
print("First 16 channel IDs:", dev_ids[:16])
print("Total channels:", len(dev_ids))

if len(dev_ids) % 4 != 0:
    raise ValueError("Channel count isn’t divisible by 4; define custom tetrode groups to match your wiring.")

# Auto-groups: device order in chunks of 4 (adjust if your wiring differs)
groups = [dev_ids[i:i+4] for i in range(0, len(dev_ids), 4)]

# Known bad channels — use SAME string format as dev_ids (e.g., 'CH06'); leave empty if none
BAD = []  # e.g., ['CH06', 'CH14']

# If dropping BAD, optionally slice the recording
if BAD:
    groups = [[ch for ch in g if ch not in BAD] for g in groups]
    groups = [g for g in groups if len(g) > 0]
    keep = [ch for g in groups for ch in g]
    recording = recording.channel_slice(keep_channel_ids=keep)
    dev_ids = list(recording.channel_ids)

print(f"Tetrodes: {len(groups)} | First group: {groups[0]}")


First 16 channel IDs: [np.str_('CH40'), np.str_('CH38'), np.str_('CH36'), np.str_('CH34'), np.str_('CH48'), np.str_('CH46'), np.str_('CH44'), np.str_('CH42'), np.str_('CH56'), np.str_('CH54'), np.str_('CH52'), np.str_('CH50'), np.str_('CH58'), np.str_('CH64'), np.str_('CH62'), np.str_('CH60')]
Total channels: 64
Tetrodes: 16 | First group: [np.str_('CH40'), np.str_('CH38'), np.str_('CH36'), np.str_('CH34')]


In [6]:
#%% 5) (Optional) Attach simple 2×2 geometry per tetrode (for nicer Phy plots)

from probeinterface import Probe

# Layout parameters (µm)
pitch = 20.0           # within tetrode (2x2 square)
dx, dy = 150.0, 150.0  # spacing between tetrodes (grid)

idx_map = {ch: i for i, ch in enumerate(dev_ids)}
pos = np.zeros((len(dev_ids), 2), dtype=float)

for t, g in enumerate(groups):
    base_xy = np.array([[0,0], [pitch,0], [0,pitch], [pitch,pitch]], dtype=float)[:len(g)]
    row, col = divmod(t, 8)  # 8 tetrodes per row; change to taste
    base_xy += np.array([col*dx, row*dy], dtype=float)
    for j, ch in enumerate(g):
        pos[idx_map[ch]] = base_xy[j]

# Ensure unique coordinates (Phy doesn’t like duplicates)
assert np.unique(pos, axis=0).shape[0] == pos.shape[0], "Duplicate positions detected—check groups."

pr = Probe(ndim=2)
pr.set_contacts(positions=pos, shapes='circle', shape_params={'radius': 7})
# Map contacts to current device channel order (0..n-1)
pr.set_device_channel_indices(np.arange(len(dev_ids), dtype=int))

recording = recording.set_probe(pr)
print("Geometry attached. Locations shape:", recording.get_channel_locations().shape)


Geometry attached. Locations shape: (64, 2)


In [7]:
#%% 6) Preprocessing strategy

# Build one common front-end WITHOUT CAR (bandpass + notch).
#   → SC2 will do its own CMR + whitening.
rec_bp = spre.bandpass_filter(recording, freq_min=300, freq_max=6000)
for f0 in (50, 100, 150):  # add 200 if needed
    rec_bp = spre.notch_filter(rec_bp, freq=f0, q=30)

# KS4 branch: add per-tetrode CAR (grouped "global" median)
rec_ks4 = spre.common_reference(rec_bp, reference="global", operator="median", groups=groups)

print("Preprocessing built. KS4 uses CAR branch; SC2 uses no-CAR base.")


Preprocessing built. KS4 uses CAR branch; SC2 uses no-CAR base.


In [8]:
#%% 7) Cache preprocessed recordings for KS4 & SC2 (timestamped version — safer on Windows)

from datetime import datetime
import spikeinterface as si

tag = datetime.now().strftime('%Y%m%d_%H%M%S')

# Timestamped cache folders prevent Windows file-lock errors
cached_sc2 = OUT_DIR / f"cached_sc2_{tag}"  # no-CAR for SC2
cached_ks4 = OUT_DIR / f"cached_ks4_{tag}"  # CAR for KS4

rec_sc2_cached = rec_bp.save(
    folder=cached_sc2,
    format="binary",
    dtype="float32",
    chunk_duration="1s",
    overwrite=True,
)

rec_ks4_cached = rec_ks4.save(
    folder=cached_ks4,
    format="binary",
    dtype="float32",
    chunk_duration="1s",
    overwrite=True,
)

print("SC2 cache →", cached_sc2)
print("KS4 cache →", cached_ks4)


write_binary_recording 
engine=process - n_jobs=1 - samples_per_chunk=30,000 - chunk_memory=3.66 MiB - total_memory=3.66 MiB - chunk_duration=1.00s


write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

write_binary_recording 
engine=process - n_jobs=1 - samples_per_chunk=30,000 - chunk_memory=3.66 MiB - total_memory=3.66 MiB - chunk_duration=1.00s


write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

SC2 cache → C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\cached_sc2_20251023_151925
KS4 cache → C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\cached_ks4_20251023_151925


In [9]:
#%% 8a) Run Kilosort4 (separate cell; CPU)

from datetime import datetime
ks4_params = ss.Kilosort4Sorter.default_params()
ks4_params.update({
    "torch_device": "cpu",   # no CUDA on this machine
    "do_CAR": False,         # CAR done upstream
    "progress_bar": True,
    "bad_channels": (BAD or None),
    # Optional sensitivity tweak:
    # "Th_universal": 6,
})

tag = datetime.now().strftime('%Y%m%d_%H%M%S')
ks4_out = OUT_DIR / f"ks4_run_{tag}"

sorting_ks4 = ss.run_sorter(
    "kilosort4",
    rec_ks4_cached,          # CAR’d input
    folder=ks4_out,          # SI 0.103: argument is folder=
    verbose=True,
    remove_existing_folder=True,
    **ks4_params,
)
print("KS4 units:", sorting_ks4.get_num_units(), "| out →", ks4_out)


kilosort.run_kilosort:  
kilosort.run_kilosort: Computing preprocessing variables.
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: N samples: 1800000
kilosort.run_kilosort: N seconds: 60.0
kilosort.run_kilosort: N batches: 30
kilosort.run_kilosort: Preprocessing filters computed in 0.15s; total 0.16s
kilosort.run_kilosort:  


Skipping common average reference.


kilosort.run_kilosort: Resource usage after preprocessing
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage:    40.00 %
kilosort.run_kilosort: Mem used:     70.50 %     |      10.91 GB
kilosort.run_kilosort: Mem avail:     4.56 / 15.47 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage:    N/A
kilosort.run_kilosort: GPU memory:   N/A
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:  
kilosort.run_kilosort: Computing drift correction.
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
kilosort.spikedetect: Number of universal templates: 1224
kilosort.spikedetect: Detecting spikes...
100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [02:00<00:00,  4.02s/it]
kilosort.run_kilosort: drift com

kilosort4 run time 330.07s
KS4 units: 172 | out → C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\ks4_run_20251023_152027


In [None]:
#%% 8b) Run SpyKING CIRCUS 2 (separate cell)



In [14]:
# KS4: load latest cache → random_spikes → waveforms → templates → Phy
from pathlib import Path
import spikeinterface as si
from spikeinterface import create_sorting_analyzer
from spikeinterface.exporters import export_to_phy
from datetime import datetime

# 1) Use the newest cached_ks4 folder
cached_candidates = sorted(OUT_DIR.glob("cached_ks4*"), key=lambda p: p.stat().st_mtime, reverse=True)
cached_ks4_path = cached_candidates[0] if cached_candidates else (OUT_DIR / "cached_ks4")
print("Using cached KS4 recording from:", cached_ks4_path)

# 2) Load folder-backed recording (new API)
rec_ks4_cached = si.load(cached_ks4_path)

# 3) Fresh analyzer folder
tag = datetime.now().strftime("%Y%m%d_%H%M%S")
an_ks4_folder = OUT_DIR / f"analyzer_ks4_{tag}"

analyzer_ks4 = create_sorting_analyzer(
    sorting=sorting_ks4,
    recording=rec_ks4_cached,
    folder=an_ks4_folder,
    overwrite=True,
)

# 4) REQUIRED first: choose spikes per unit
analyzer_ks4.compute({"random_spikes": {"max_spikes_per_unit": 500, "seed": 42}}, verbose=True)

# 5) Waveforms (SI 0.103: no sparsity/max_spikes kw here)
analyzer_ks4.compute(
    {"waveforms": {"ms_before": 1.5, "ms_after": 2.5, "dtype": "float32"}},
    verbose=True, n_jobs=4, chunk_duration="2s"
)

# 6) templates are REQUIRED for export_to_phy in SI 0.103
analyzer_ks4.compute({"templates": {}}, verbose=True)

# 7) Export to Phy (note: output_folder=..., not folder=)
phy_ks4_folder = OUT_DIR / f"phy_ks4_{tag}"
export_to_phy(analyzer_ks4, output_folder=phy_ks4_folder, remove_if_exists=True)
print("✅ Exported Phy (KS4) →", phy_ks4_folder)


Using cached KS4 recording from: C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\cached_ks4_20251023_151925


estimate_sparsity (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

compute_waveforms 
engine=process - n_jobs=4 - samples_per_chunk=60,000 - chunk_memory=14.65 MiB - total_memory=58.59 MiB - chunk_duration=2.00s


compute_waveforms (workers: 4 processes):   0%|          | 0/30 [00:00<?, ?it/s]

write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

spike_amplitudes (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Fitting PCA:   0%|          | 0/172 [00:00<?, ?it/s]

Projecting waveforms:   0%|          | 0/172 [00:00<?, ?it/s]

extract PCs (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Run:
phy template-gui  C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\phy_ks4_20251023_162519\params.py
✅ Exported Phy (KS4) → C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\phy_ks4_20251023_162519


In [None]:
#%% 9b) Export SC2 to Phy (SortingAnalyzer: sparsity → waveforms → export)



In [None]:
#%% 10) (Optional) Quick quality metrics + sorter agreement

import spikeinterface.qualitymetrics as sqm
from spikeinterface.comparison import compare_multiple_sorters

# Fast metrics (skip PC metrics for speed)
qm_ks4 = sqm.compute_quality_metrics(analyzer_ks4, skip_pc_metrics=True)
qm_sc2 = sqm.compute_quality_metrics(analyzer_sc2, skip_pc_metrics=True)
qm_ks4.to_csv(OUT_DIR / f"quality_metrics_ks4_{tag}.csv", index=True)
qm_sc2.to_csv(OUT_DIR / f"quality_metrics_sc2_{tag}.csv", index=True)

comp = compare_multiple_sorters(
    sorting_list=[sorting_ks4, sorting_sc2],
    name_list=["KS4", "SC2"],
    delta_time=0.001,  # 1 ms tolerance
    min_accuracy=0.1
)
print("Agreement:\n", comp.agreement_scores)
print("Metrics saved to:", OUT_DIR)
