In [11]:
#%% 0) Imports, deterministic settings, user config

import os, random, time, shutil
from pathlib import Path
from datetime import datetime
import numpy as np


TEST_SECONDS = 60
# ---- user config (edit these two) ----
ROOT_DIR = Path(r"C:/Users/ryoi/Documents/SpikeSorting/recordings")
SESSION_SUBPATH = Path(r"2025-10-01_15-53-19/Record Node 125/experiment1/recording1")
# Optional: force a specific stream name, else auto-pick
STREAM_NAME = None  # e.g. "Record Node 125#Acquisition_Board-100.Rhythm Data"

# ---- determinism-ish (helps reduce run-to-run jitter) ----
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
random.seed(0)
np.random.seed(0)


# Outputs live here
BASE_OUT = Path(r"C:/Users/ryoi/Documents/SpikeSorting")
KS4_OUT = BASE_OUT / "ks4_outputs"
SC2_OUT = BASE_OUT / "sc2_outputs"
for d in (KS4_OUT, SC2_OUT): 
    d.mkdir(parents=True, exist_ok=True)

# Quick test slice (seconds) â€” set to None for full session

# Known bad channels (use the same ID format as the recording)
BAD = []  # e.g. ["CH12","CH59"]
# Geometry settings for nicer Phy plots (tetrode 2Ã—2 layout on a grid)
ATTACH_GEOMETRY = True
TETRODES_PER_ROW = 4  # 64 ch = 16 tetrodes â†’ 4Ã—4 layout looks tidy
# Sorter toggles
RUN_KS4 = True
RUN_SC2 = True
# KS4 device ('cpu' now; switch to 'cuda' when you have GPU PyTorch)
KS4_TORCH_DEVICE = "cpu"
print("Recording root:", ROOT_DIR)
print("KS4 out:", KS4_OUT)
print("SC2 out:", SC2_OUT)


Recording root: C:\Users\ryoi\Documents\SpikeSorting\recordings
KS4 out: C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs
SC2 out: C:\Users\ryoi\Documents\SpikeSorting\sc2_outputs


In [12]:
TEST_SECONDS

60

In [13]:
#%% 1) Load Open Ephys & select stream

#%% 1) Load Open Ephys & select stream (robust to multi-stream error)

from pathlib import Path
import re, ast
import spikeinterface.extractors as se

DATA_PATH = (ROOT_DIR / SESSION_SUBPATH)
print("Using OE folder:", DATA_PATH)

def discover_oe_stream_names(folder: Path) -> list[str]:
    """
    Return list of OE stream names for this folder.
    Works even when SI raises on multi-stream by parsing the error message.
    """
    try:
        rec = se.read_openephys(folder)  # may raise if multiple streams
        # If we got here, there was only one stream or SI allowed opening without specifying.
        # Try public accessor; fall back to _annotations.
        get_ann = getattr(rec, "get_annotation", None)
        if callable(get_ann):
            names = get_ann("streams_names") or get_ann("stream_names") or []
        else:
            ann = getattr(rec, "_annotations", {}) or {}
            names = ann.get("streams_names") or ann.get("stream_names") or []
        # Some SI versions return empty here for single-stream; then get from rec.
        if not names:
            # As a last resort, print repr and try to infer; otherwise assume single.
            names = [getattr(rec, "stream_name", None)] if hasattr(rec, "stream_name") else []
            names = [n for n in names if n]
        return names
    except ValueError as e:
        msg = str(e)
        # Typical message contains a Python list right after `stream_names`:
        # `stream_names`: ['Record Node ... Rhythm Data', 'Record Node ... Rhythm Data_ADC']
        m = re.search(r"`stream_names`:\s*(\[[^\]]+\])", msg)
        if m:
            try:
                names = ast.literal_eval(m.group(1))
                return names
            except Exception:
                pass
        # If we can't parse, re-raise with context.
        raise

# --- Discover stream names safely ---
stream_names = []
try:
    stream_names = discover_oe_stream_names(DATA_PATH)
except Exception as e:
    raise RuntimeError(f"Could not discover Open Ephys streams. Set STREAM_NAME manually. Original error: {e}")

if STREAM_NAME is None:
    if not stream_names:
        raise RuntimeError("Could not discover OE stream names; please set STREAM_NAME explicitly.")
    print("Available streams:")
    for i, s in enumerate(stream_names):
        print(f"  [{i}] {s}")
    # Heuristic: pick neural data (avoid *_ADC, SYNC)
    candidates = [s for s in stream_names
                  if ("Rhythm" in s or "Data" in s) and ("ADC" not in s) and ("SYNC" not in s)]
    STREAM_NAME = candidates[0] if candidates else stream_names[0]

print("Using stream:", STREAM_NAME)

# --- Open the chosen stream ---
recording = se.read_openephys(DATA_PATH, stream_name=STREAM_NAME)
print(recording)
print("Segments:", recording.get_num_segments(),
      "| Fs:", recording.get_sampling_frequency(),
      "| Ch:", recording.get_num_channels())

# KS4 requires single segment; keep segment 0 if multiple
if recording.get_num_segments() > 1:
    recording = recording.select_segments([0])
    print("Selected segment 0 (single-segment).")

# Optional quick slice
if TEST_SECONDS is not None:
    fs = recording.get_sampling_frequency()
    # Cap end_frame to available frames in segment 0 (handle SI API differences)
    n_frames = (recording.get_num_samples(0)
                if hasattr(recording, "get_num_samples")
                else recording.get_num_frames(0))
    end_frame = min(int(fs * TEST_SECONDS), n_frames)
    recording = recording.frame_slice(0, end_frame)
    print(f"Recording sliced to first {TEST_SECONDS}s (end_frame={end_frame}).")


Using OE folder: C:\Users\ryoi\Documents\SpikeSorting\recordings\2025-10-01_15-53-19\Record Node 125\experiment1\recording1
Available streams:
  [0] Record Node 125#Acquisition_Board-100.Rhythm Data
  [1] Record Node 125#Acquisition_Board-100.Rhythm Data_ADC
Using stream: Record Node 125#Acquisition_Board-100.Rhythm Data
OpenEphysBinaryRecordingExtractor: 64 channels - 30.0kHz - 1 segments - 105,877,760 samples 
                                   3,529.26s (58.82 minutes) - int16 dtype - 12.62 GiB
Segments: 1 | Fs: 30000.0 | Ch: 64
Recording sliced to first 60s (end_frame=1800000).


In [14]:
#%% 2) Build tetrode groups (4 channels each) & remove BAD

DEV_IDS = list(recording.channel_ids)
print("Total channels:", len(DEV_IDS))
if len(DEV_IDS) % 4 != 0:
    raise ValueError("Channel count not divisible by 4; adjust grouping to match your wiring.")

groups = [DEV_IDS[i:i+4] for i in range(0, len(DEV_IDS), 4)]

if BAD:
    groups = [[ch for ch in g if ch not in BAD] for g in groups]
    groups = [g for g in groups if len(g) > 0]
    keep = [ch for g in groups for ch in g]
    recording = recording.channel_slice(keep_channel_ids=keep)
    DEV_IDS = list(recording.channel_ids)
    print("Removed BAD channels; new channel count:", len(DEV_IDS))

print(f"Tetrodes: {len(groups)}; first group: {groups[0]}")


Total channels: 64
Tetrodes: 16; first group: [np.str_('CH40'), np.str_('CH38'), np.str_('CH36'), np.str_('CH34')]


In [15]:
#%% 3) (Optional) Attach 2Ã—2 tetrode geometry

if ATTACH_GEOMETRY:
    from probeinterface import Probe
    import numpy as np

    idx_map = {ch: i for i, ch in enumerate(recording.channel_ids)}
    pos = np.zeros((len(idx_map), 2), dtype=float)

    pitch = 20.0     # Âµm inside a tetrode
    dx, dy = 150.0, 150.0  # Âµm between tetrode blocks

    for t, g in enumerate(groups):
        base_xy = np.array([[0,0],[pitch,0],[0,pitch],[pitch,pitch]], dtype=float)[:len(g)]
        row, col = divmod(t, TETRODES_PER_ROW)
        offset = np.array([col*dx, row*dy], dtype=float)
        for j, ch in enumerate(g):
            pos[idx_map[ch]] = base_xy[j] + offset

    pr = Probe(ndim=2)
    pr.set_contacts(
        positions=pos,
        shapes='circle',
        shape_params={'radius': 7}
    )
    # VERY IMPORTANT for SI: map probe contacts to device channel indices
    device_inds = np.array([idx_map[ch] for ch in recording.channel_ids], dtype=int)
    pr.set_device_channel_indices(device_inds)

    recording = recording.set_probe(pr)
    print("Geometry attached. Locations:", recording.get_channel_locations().shape)
else:
    print("Geometry attachment skipped (ATTACH_GEOMETRY=False).")


Geometry attached. Locations: (64, 2)


In [16]:
# %% 4) Preprocess (bandpass + notch); per-tetrode CAR for KS4; no CAR for SC2

import spikeinterface.preprocessing as spre

# --- Bandpass first ---
rec_bp = spre.bandpass_filter(recording, freq_min=300, freq_max=6000)

# --- Optional: notch out 50 Hz and harmonics (for mains interference) ---
for f0 in (50, 100, 150):  # adjust to your mains frequency
    rec_bp = spre.notch_filter(rec_bp, freq=f0, q=30)

# --- KS4 branch: per-tetrode CAR (grouped median reference) ---
rec_ks4 = spre.common_reference(
    rec_bp,
    reference="global",     # group-wise CAR across tetrodes
    operator="median",
    groups=groups
)

# --- SC2 branch: no CAR (does internal whitening) ---
rec_sc2 = rec_bp

print("Preprocessing prepared: bandpass + notch; KS4 with CAR, SC2 without.")


Preprocessing prepared: bandpass + notch; KS4 with CAR, SC2 without.


In [17]:
print(rec_bp)
print("KS4 chans:", rec_ks4.get_num_channels(), "| SC2 chans:", rec_sc2.get_num_channels())
import numpy as np
import numpy as _np
for name, rec in [("rec_bp", rec_bp), ("rec_ks4", rec_ks4), ("rec_sc2", rec_sc2)]:
    s0 = rec.get_traces(start_frame=0, end_frame=min(int(rec.get_sampling_frequency()), rec.get_num_samples()))
    print(name, "NaNs?", _np.isnan(s0).any(), "dtype:", s0.dtype)


NotchFilterRecording: 64 channels - 30.0kHz - 1 segments - 1,800,000 samples 
                      60.00s (1.00 minutes) - int16 dtype - 219.73 MiB
KS4 chans: 64 | SC2 chans: 64
rec_bp NaNs? False dtype: int16
rec_ks4 NaNs? False dtype: int16
rec_sc2 NaNs? False dtype: int16


In [18]:


def safe_rmtree(path: Path, retries: int = 6, wait_s: float = 0.5):
    for _ in range(retries):
        try:
            if path.exists():
                shutil.rmtree(path)
            return
        except PermissionError:
            time.sleep(wait_s)
    # last try â€” raise if still present/locked
    if path.exists():
        shutil.rmtree(path)

ts = datetime.now().strftime("%Y%m%d_%H%M%S")
ks4_cache = KS4_OUT / f"cached_ks4_{ts}"
sc2_cache = SC2_OUT / f"cached_sc2_{ts}"

# clean then save
for cache_dir in (ks4_cache, sc2_cache):
    safe_rmtree(cache_dir)

rec_ks4_cached = rec_ks4.save(folder=ks4_cache, format="binary", dtype="float32",
                              chunk_duration="1s", overwrite=True)
print("Cached KS4 â†’", ks4_cache)

rec_sc2_cached = rec_sc2.save(folder=sc2_cache, format="binary", dtype="float32",
                              chunk_duration="1s", overwrite=True)
print("Cached SC2 â†’", sc2_cache)


write_binary_recording 
engine=process - n_jobs=1 - samples_per_chunk=30,000 - chunk_memory=3.66 MiB - total_memory=3.66 MiB - chunk_duration=1.00s


write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Cached KS4 â†’ C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\cached_ks4_20251025_141602
write_binary_recording 
engine=process - n_jobs=1 - samples_per_chunk=30,000 - chunk_memory=3.66 MiB - total_memory=3.66 MiB - chunk_duration=1.00s


write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Cached SC2 â†’ C:\Users\ryoi\Documents\SpikeSorting\sc2_outputs\cached_sc2_20251025_141602


In [19]:
#%% 6) Run Kilosort4

import spikeinterface.sorters as ss

ks4_params = ss.Kilosort4Sorter.default_params()
ks4_params.update({
    "torch_device": KS4_TORCH_DEVICE,   # "cpu" now; switch to "cuda" on GPU
    "do_CAR": False,                    # already did per-tetrode CAR
    "progress_bar": True,
    "bad_channels": (BAD or None),
    # Optional knobs to pin sensitivity if you like:
    # "Th_universal": 6,
})

tag = datetime.now().strftime("%Y%m%d_%H%M%S")
ks4_run = KS4_OUT / f"ks4_run_{tag}"
safe_rmtree(ks4_run)

sorting_ks4 = ss.run_sorter(
    "kilosort4",
    rec_ks4_cached,
    folder=ks4_run,                  # SI 0.103 uses folder=
    verbose=True,
    remove_existing_folder=True,
    **ks4_params,
)
print("KS4 units:", sorting_ks4.get_num_units(), "| out â†’", ks4_run)


kilosort.run_kilosort:  
kilosort.run_kilosort: Computing preprocessing variables.
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: N samples: 1800000
kilosort.run_kilosort: N seconds: 60.0
kilosort.run_kilosort: N batches: 30


Skipping common average reference.


kilosort.run_kilosort: Preprocessing filters computed in 0.60s; total 0.60s
kilosort.run_kilosort:  
kilosort.run_kilosort: Resource usage after preprocessing
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage:    22.40 %
kilosort.run_kilosort: Mem used:     87.70 %     |      13.56 GB
kilosort.run_kilosort: Mem avail:     1.91 / 15.47 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage:    N/A
kilosort.run_kilosort: GPU memory:   N/A
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:  
kilosort.run_kilosort: Computing drift correction.
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
kilosort.spikedetect: Number of universal templates: 1410
kilosort.spikedetect: Detecting spikes...
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ

kilosort4 run time 1339.31s
KS4 units: 153 | out â†’ C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\ks4_run_20251025_141724


In [20]:
#%% 7) Run SpyKING Circus 2

# Make sure hdbscan is installed in this env (you already did)
# pip install hdbscan scikit-learn

sc2_params = ss.Spykingcircus2Sorter.default_params()
# defaults are generally good; SC2 does CAR/whitening internally

sc2_run = SC2_OUT / f"sc2_run_{tag}"
safe_rmtree(sc2_run)

sorting_sc2 = ss.run_sorter(
    "spykingcircus2",
    rec_sc2_cached,
    folder=sc2_run,
    verbose=True,
    remove_existing_folder=True,
    **sc2_params,
)
print("SC2 units:", sorting_sc2.get_num_units(), "| out â†’", sc2_run)


Preprocessing the recording (bandpass filtering + CMR + whitening)
Geometry of the probe does not allow 1D drift correction


noise_level (no parallelization):   0%|          | 0/20 [00:00<?, ?it/s]

write_memory_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

detect peaks using locally_exclusive + 1 node (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

detect peaks using matched_filtering (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Kept 47354 peaks for clustering


Transform peaks svd (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

split_clusters with local_feature_clustering:   0%|          | 0/64 [00:00<?, ?it/s]

Kept 100 raw clusters
Kept 97 clean clusters


find spikes (circus-omp-svd) (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Found 85269 spikes




Kept 94 units after final merging
spykingcircus2 run time 157.57s
SC2 units: 94 | out â†’ C:\Users\ryoi\Documents\SpikeSorting\sc2_outputs\sc2_run_20251025_141724


In [21]:
#%% 8) Create analyzers (SI 0.103 requires this path)

from spikeinterface.core import create_sorting_analyzer

an_ks4 = create_sorting_analyzer(
    sorting=sorting_ks4,
    recording=rec_ks4_cached,
    folder=KS4_OUT / f"analyzer_ks4_{tag}",
    overwrite=True,
)

an_sc2 = create_sorting_analyzer(
    sorting=sorting_sc2,
    recording=rec_sc2_cached,
    folder=SC2_OUT / f"analyzer_sc2_{tag}",
    overwrite=True,
)

print("Analyzers created.")


estimate_sparsity (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

estimate_sparsity (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Analyzers created.


In [24]:
#%% 9) Compute (random_spikes â†’ waveforms â†’ templates â†’ amplitudes â†’ PCs)

# KS4 analyzer
an_ks4.compute({"random_spikes": {"max_spikes_per_unit": 500, "seed": 42}}, verbose=True)
an_ks4.compute("waveforms", ms_before=1.5, ms_after=2.5, dtype="float32",
               verbose=True, n_jobs=4, chunk_duration="2s")
an_ks4.compute("templates")
#an_ks4.compute("amplitudes")
an_ks4.compute("principal_components")
print("KS4 analyzer computed.")

# SC2 analyzer (same sequence)
an_sc2.compute({"random_spikes": {"max_spikes_per_unit": 500, "seed": 42}}, verbose=True)
an_sc2.compute("waveforms", ms_before=1.5, ms_after=2.5, dtype="float32",
               verbose=True, n_jobs=4, chunk_duration="2s")
an_sc2.compute("templates")
#an_sc2.compute("amplitudes")
an_sc2.compute("principal_components")
print("SC2 analyzer computed.")


compute_waveforms 
engine=process - n_jobs=4 - samples_per_chunk=60,000 - chunk_memory=14.65 MiB - total_memory=58.59 MiB - chunk_duration=2.00s


compute_waveforms (workers: 4 processes):   0%|          | 0/30 [00:00<?, ?it/s]

Fitting PCA:   0%|          | 0/153 [00:00<?, ?it/s]

Projecting waveforms:   0%|          | 0/153 [00:00<?, ?it/s]

KS4 analyzer computed.
compute_waveforms 
engine=process - n_jobs=4 - samples_per_chunk=60,000 - chunk_memory=14.65 MiB - total_memory=58.59 MiB - chunk_duration=2.00s


compute_waveforms (workers: 4 processes):   0%|          | 0/30 [00:00<?, ?it/s]

Fitting PCA:   0%|          | 0/94 [00:00<?, ?it/s]

Projecting waveforms:   0%|          | 0/94 [00:00<?, ?it/s]

SC2 analyzer computed.


In [25]:
#%% 10) Export to Phy

from spikeinterface.exporters import export_to_phy

phy_ks4 = KS4_OUT / f"phy_ks4_{tag}"
export_to_phy(an_ks4, output_folder=phy_ks4, remove_if_exists=True)
print("âœ… Exported Phy (KS4) â†’", phy_ks4)

phy_sc2 = SC2_OUT / f"phy_sc2_{tag}"
export_to_phy(an_sc2, output_folder=phy_sc2, remove_if_exists=True)
print("âœ… Exported Phy (SC2) â†’", phy_sc2)

print("ðŸŽ‰ All done â€” ready for Phy curation.")


write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

spike_amplitudes (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

extract PCs (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Run:
phy template-gui  C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\phy_ks4_20251025_141724\params.py
âœ… Exported Phy (KS4) â†’ C:\Users\ryoi\Documents\SpikeSorting\ks4_outputs\phy_ks4_20251025_141724


write_binary_recording (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

spike_amplitudes (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

extract PCs (no parallelization):   0%|          | 0/60 [00:00<?, ?it/s]

Run:
phy template-gui  C:\Users\ryoi\Documents\SpikeSorting\sc2_outputs\phy_sc2_20251025_141724\params.py
âœ… Exported Phy (SC2) â†’ C:\Users\ryoi\Documents\SpikeSorting\sc2_outputs\phy_sc2_20251025_141724
ðŸŽ‰ All done â€” ready for Phy curation.


In [26]:
phy_sc2

WindowsPath('C:/Users/ryoi/Documents/SpikeSorting/sc2_outputs/phy_sc2_20251025_141724')