Concatenating ICMS files

In [None]:
from pathlib import Path
import re
import numpy as np
import spikeinterface as si
import spikeinterface.extractors as se

# =========================
# CONFIG
# =========================
# Date folder that holds one sub-folder per SpikeGLX run.
BASE_DIR = Path(r"C:/storage/Reza/rat data/08-08-2025")
# Glob pattern (relative to BASE_DIR) selecting the run folders.
SESSION_GLOB = "ICMSstim_icmsstim_*_g0"
# Output folder for the concatenated recordings.
TARGET_DIR = Path(r"C:/storage/Reza/rat data/08-08-2025/ICMSstim_icmsstim_All_Concat_g0")

# Job keyword arguments forwarded to the recording .save(...) calls; adjust as needed.
JOB_KW = {
    "n_jobs": 4,
    "progress_bar": True,
    "chunk_duration": "5s",
}


# =========================
# HELPERS
# =========================
def list_runs(base_dir: Path, pattern: str) -> list[Path]:
    """Find run folders such as ..._2_g0, ..._3_g0, ..._14_g0 and sort them numerically.

    Args:
        base_dir: Directory whose entries are matched against ``pattern``.
        pattern: Glob pattern for the run folder names.

    Returns:
        The matching directories ordered by the integer that precedes the
        trailing ``_g<N>`` suffix, so ``_10_g0`` comes after ``_9_g0``.
        Entries whose name contains ``"All_Concat_g0"`` (the concatenation
        target) and entries that are not directories are skipped. Names
        without a numeric run index get index -1 and therefore sort first.
    """
    indexed = []
    for entry in base_dir.glob(pattern):
        # Skip the concatenation target and anything that is not a folder.
        if "All_Concat_g0" in entry.name or not entry.is_dir():
            continue
        match = re.search(r"_(\d+)_g\d+$", entry.name)  # run index just before _g#
        run_index = int(match.group(1)) if match else -1
        indexed.append((run_index, entry.name, entry))
    # Tie-break on the folder name so the result does not depend on the
    # order in which the filesystem happens to list the entries.
    indexed.sort(key=lambda item: item[:2])
    return [entry for _, _, entry in indexed]


def read_spikeglx_run(run_dir: Path):
    """Load the imec0.ap and nidq streams of one SpikeGLX run folder.

    The AP stream is read from the ``<run name>_imec0`` sub-folder and the
    nidq stream from the run folder itself.

    Args:
        run_dir: Run folder, e.g. ``.../ICMSstim_icmsstim_2_g0``.

    Returns:
        Tuple ``(rec_ap, nidq)`` of the two objects returned by
        ``se.read_spikeglx``.

    Raises:
        FileNotFoundError: If the ``<run name>_imec0`` sub-folder does not exist.
    """
    probe_dir = run_dir / (run_dir.name + "_imec0")
    if probe_dir.exists():
        ap_stream = se.read_spikeglx(str(probe_dir), stream_id="imec0.ap")
        ni_stream = se.read_spikeglx(str(run_dir), stream_id="nidq")
        return ap_stream, ni_stream
    raise FileNotFoundError(f"Missing imec folder: {probe_dir}")


def check_fs_compatible(recordings, tol=1e-9, label=""):
    """Verify that every recording has the same sampling rate as the first one.

    Args:
        recordings: Non-empty sequence of objects exposing ``sampling_frequency``.
        tol: Largest absolute difference (in the unit of ``sampling_frequency``)
            still treated as equal.
        label: Stream name used as a prefix in error messages.

    Raises:
        ValueError: If ``recordings`` is empty, or if any recording's sampling
            rate differs from the first one's by more than ``tol``.
    """
    # Fail with a clear message instead of an IndexError on recordings[0].
    if len(recordings) == 0:
        raise ValueError(f"[{label}] No recordings given; cannot check sampling rates.")

    fs0 = recordings[0].sampling_frequency
    for i, r in enumerate(recordings[1:], start=1):
        if abs(r.sampling_frequency - fs0) > tol:
            raise ValueError(f"[{label}] Sampling rate mismatch: {fs0} vs {r.sampling_frequency} (run idx {i})")


def subset_channels(rec, channel_ids):
    """
    Select ``channel_ids`` from a Recording regardless of the SpikeInterface version.

    The first available mechanism is used, in this order:
      1) recording.select_channels(...)   (newer SI)
      2) recording.channel_slice(...)     (older SI)
      3) ChannelSliceRecording(...)       (lowest-level fallback)

    Raises AttributeError if none of the three works.
    """
    # Prefer the instance methods, newest API first.
    for method_name in ("select_channels", "channel_slice"):
        if hasattr(rec, method_name):
            return getattr(rec, method_name)(channel_ids=channel_ids)

    # Last resort: build the slicing wrapper class directly.
    try:
        from spikeinterface import ChannelSliceRecording  # fallback (older core class)
        return ChannelSliceRecording(rec, channel_ids=channel_ids)
    except Exception as err:
        message = (
            "Could not subset channels with select_channels/channel_slice/ChannelSliceRecording. "
            "Please update SpikeInterface."
        )
        raise AttributeError(message) from err


def align_to_common_channels(recordings, label=""):
    """Restrict every run to the channels that all runs share.

    The sampling rates are checked first. The shared channel ids keep the
    order they have in the first run.

    Args:
        recordings: Sequence of recordings, one per run.
        label: Stream name used in warnings and error messages.

    Returns:
        Tuple ``(aligned, common_ids)``: the per-run recordings subset to the
        shared channels, and the list of shared channel ids.

    Raises:
        ValueError: If the runs have no channel id in common (or if the
            sampling-rate check fails).
    """
    check_fs_compatible(recordings, label=label)

    first_ids = np.asarray(recordings[0].get_channel_ids())
    # Intersect the first run's ids with the ids of every other run.
    shared = set(first_ids.tolist()).intersection(
        *(np.asarray(rec.get_channel_ids()).tolist() for rec in recordings[1:])
    )

    if len(shared) == 0:
        raise ValueError(f"[{label}] No common channels across runs. Check per-run channel maps.")

    # Walk the first run's ids so the result preserves that run's ordering.
    kept_ids = [cid for cid in first_ids if cid in shared]
    n_kept, n_first = len(kept_ids), len(first_ids)
    if n_kept < n_first:
        print(f"[warn/{label}] Reducing to {n_kept} common channels (from {n_first}).")

    return [subset_channels(rec, kept_ids) for rec in recordings], kept_ids


def concatenate_with_alignment(recordings, label=""):
    """Align the runs to their shared channels, then join them end to end in time.

    Args:
        recordings: Sequence of recordings, one per run, in the desired order.
        label: Stream name forwarded to the alignment step for its messages.

    Returns:
        Tuple ``(rec_concat, common_ids)``: the result of
        ``si.concatenate_recordings`` on the aligned runs, and the list of
        channel ids kept.
    """
    per_run, common_ids = align_to_common_channels(recordings, label=label)
    # single long segment (time-concatenated)
    return si.concatenate_recordings(per_run), common_ids


# =========================
# MAIN
# =========================
def main():
    """Concatenate all matching runs and save the AP and NiDq streams as binary folders.

    Steps: list the run folders under BASE_DIR, read both streams of each
    run, reduce each stream to the channels shared by all runs and
    concatenate it in time, save the results under TARGET_DIR, and print a
    short summary.

    Raises:
        RuntimeError: If no run folder under BASE_DIR matches SESSION_GLOB.
    """
    run_dirs = list_runs(BASE_DIR, SESSION_GLOB)
    if not run_dirs:
        raise RuntimeError(f"No runs matching '{SESSION_GLOB}' in {BASE_DIR}")

    print("Found runs (in order):")
    for run_dir in run_dirs:
        print("  -", run_dir)

    # Read every run, then split the (AP, NiDq) pairs into per-stream lists.
    stream_pairs = [read_spikeglx_run(run_dir) for run_dir in run_dirs]
    ap_list = [ap for ap, _ in stream_pairs]
    nidq_list = [ni for _, ni in stream_pairs]

    # Align channels and concatenate in time, one stream at a time.
    rec_ap_concat, ap_common = concatenate_with_alignment(ap_list, label="AP")
    nidq_concat, nidq_common = concatenate_with_alignment(nidq_list, label="NiDq")

    # Zarr alternative (previously used, currently disabled): save with
    # format="zarr" into "ap_zarr" / "nidq_zarr" and re-open with si.read_zarr(...).

    # Save to binary via the instance method .save(...)
    TARGET_DIR.mkdir(parents=True, exist_ok=True)
    ap_out = TARGET_DIR / "ap_binary"
    nidq_out = TARGET_DIR / "nidq_binary"

    # pick a dtype; int16 keeps files small and matches SpikeGLX raw
    save_kwargs = dict(format="binary", dtype="int16", **JOB_KW)

    print("\nSaving AP (binary)...")
    rec_ap_saved = rec_ap_concat.save(folder=str(ap_out), **save_kwargs)

    print("Saving NiDq (binary)...")
    nidq_saved = nidq_concat.save(folder=str(nidq_out), **save_kwargs)

    def summary_line(saved, n_common):
        # One indented line describing a saved recording.
        return (
            f"   fs={saved.sampling_frequency} Hz | nch={saved.get_num_channels()} | "
            f"duration={saved.get_total_duration():.2f} s | common_channels={n_common}"
        )

    # Info
    print("\n=== SUMMARY ===")
    print(f"AP  -> {ap_out}")
    print(summary_line(rec_ap_saved, len(ap_common)))
    print(f"NiDq-> {nidq_out}")
    print(summary_line(nidq_saved, len(nidq_common)))

    # NOTE(review): confirm that read_binary_folder is exported from
    # spikeinterface.extractors in the installed version before relying on this hint.
    print("\nRe-open later with:")
    print(f"  rec_ap = spikeinterface.extractors.read_binary_folder(r'{ap_out}')")
    print(f"  nidq   = spikeinterface.extractors.read_binary_folder(r'{nidq_out}')")


# Entry point: run the concatenation only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Found runs (in order):
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_2_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_3_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_4_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_5_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_7_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_9_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_10_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_12_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_13_g0
  - C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_14_g0

Saving AP (binary)...
write_binary_recording 
engine=process - n_jobs=4 - samples_per_chunk=150,000 - chunk_memory=109.86 MiB - total_memory=439.45 MiB - chunk_duration=5.00s


write_binary_recording (workers: 4 processes):   0%|          | 0/256 [00:00<?, ?it/s]

Saving NiDq (binary)...
write_binary_recording 
engine=process - n_jobs=4 - samples_per_chunk=52,966 - chunk_memory=931.04 KiB - total_memory=3.64 MiB - chunk_duration=5.00s


write_binary_recording (workers: 4 processes):   0%|          | 0/256 [00:00<?, ?it/s]


=== SUMMARY ===
AP  -> C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_All_Concat_g0\ap_binary
   fs=30000.0 Hz | nch=384 | duration=1278.11 s | common_channels=384
NiDq-> C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_All_Concat_g0\nidq_binary
   fs=10593.220339 Hz | nch=9 | duration=1278.11 s | common_channels=9

Re-open later with:
  rec_ap = spikeinterface.extractors.read_binary_folder(r'C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_All_Concat_g0\ap_binary')
  nidq   = spikeinterface.extractors.read_binary_folder(r'C:\storage\Reza\rat data\08-08-2025\ICMSstim_icmsstim_All_Concat_g0\nidq_binary')
