In [16]:
# CUDAメモリ断片化対策（OOM対策・torchより前に設定）
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

from pathlib import Path
import numpy as np
import pprint
import json
import torch
import gc
import traceback
import time
from scipy.io import savemat
from datetime import datetime
from myTools.output_converter import convert_to_matlab_files
from kilosort.run_kilosort import load_sorting
from probeinterface.io import write_probeinterface

from myTools.read_spikeglx import get_exp_path, read_spikeglx_meta
from myTools.set_gain_and_offset import set_gain_and_offset
from myTools.init_run import get_recording, get_probe, get_probe_sorted, set_probe_info, get_stimtime

import spikeinterface.full as si
# import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
# import spikeinterface.postprocessing as spost
# import spikeinterface.qualitymetrics as sqm
# import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
# import spikeinterface.curation as scur
# import spikeinterface.widgets as sw

from kilosort.plots import plot_drift_amount, plot_drift_scatter, plot_diagnostics, plot_spike_positions


In [17]:

### Select experiment ###
dir_info = {
    "root_dir": r"C:\Users\tanaka-users\NeuronData",
    "name": "ge6w2",
    "ep": "005",
    "run": "002",
    "ng": "0",
    "nt": "0",
}

dict_path = get_exp_path(dir_info)

### Setting sorters ###
do_preprocess = False
do_runsort = True
do_export_phy = True
sorter = "dartsort"

# DARTsortのデフォルトパラメータを取得
sort_params = ss.get_default_sorter_params("dartsort")

name_thisparam = "setting13"

sort_params["featurization_radius_um"] = 60

# GMM refinement のメモリ削減（core_features は全スパイク分確保されるため core_radius が重要）
# sort_params["gmm_max_spikes"] = 250_000   # 300k で OOM なら 150k にさらに下げる
# sort_params["gmm_val_proportion"] = 0.1   # デフォルト: 0.25 → 0.1
# TVI/processor の Coo_invsqrt 等は近傍サイズに比例するため、OOM 時は core_radius を下げる
# sort_params["core_radius"] = 15           # デフォルト: "extract" → 12（processor update PCA の GPU OOM 対策）
# sort_params["n_refinement_iters"] = 1    # 1にするとsplitをスキップしGPU OOMを回避
# sort_params["later_steps"] = "merge"     # splitを行わずmergeのみ（メモリ削減）

# name_thisparam = "param2"
# sort_params["gmm_max_spikes"] = 500_000 # デフォルト: 2000000 → 50000に大幅削減（refine用）
# sort_params["gmm_val_proportion"] = 0.1  # デフォルト: 0.25 → 0.1に削減（メモリ削減）

In [18]:
### Fetch meta and bin files ###
meta_ap = read_spikeglx_meta(dict_path["ap"]["meta"])
meta_lf = read_spikeglx_meta(dict_path["lf"]["meta"])
meta_obx = read_spikeglx_meta(dict_path["obx"]["meta"])

recording, sync_recording = get_recording(meta_ap, dict_path["ap"]["bin"])
probe = get_probe(meta_ap)
probe = set_probe_info(probe, meta_ap)
recording = recording.set_probe(probe)
probe = get_probe_sorted(recording, probe)
recording = recording.set_probe(probe)
recording = set_gain_and_offset(meta_ap, recording)


In [19]:
### Preprocess recording ###
if do_preprocess:

    print("\n" + "="*20 + " parameter " + "="*20)
    print(f"\n{'='*5} {sorter} {'='*5}")
    pprint.pprint(sort_params)

    params_file = dict_path["exp"] / sorter / name_thisparam / "params.txt"
    params_file.parent.mkdir(parents=True, exist_ok=True)
    with open(str(params_file), "w", encoding="utf-8") as f:
        json.dump(sort_params, f, indent=4, ensure_ascii=False)

    print("="*5, sorter, "="*5)
    folder = dict_path["exp"] / sorter / name_thisparam
    pp_rec_folder = folder / "pp_rec"

    # DARTsort用前処理（NaN/Inf を避けるためmean common referenceを使用）
    print(f"  バンドパスフィルタ適用中...")
    recording_f = spre.bandpass_filter(recording, freq_min=300, freq_max=3000)
    
    # medianではなくmean common referenceを使用（NaN生成を回避）
    # medianは同じ値が続く場合や極端な値がある場合にNaNを生成する可能性がある
    print(f"  Common reference適用中... (mean)")
    try:
        recording_cmr = spre.common_reference(recording_f, reference="global", operator="mean")
    except Exception as e:
        print(f"  Common reference (mean) でエラー: {e}")
        print(f"  フォールバック: median common referenceを使用")
        recording_cmr = spre.common_reference(recording_f, reference="global", operator="median")
    
    # 前処理recordingを保存（NaN/Infチェックは後で実行）
    recording_preprocessed = recording_cmr.save(format="binary", folder=pp_rec_folder, overwrite=True)
else:
    print("skip preprocessing.")


skip preprocessing.


In [20]:
### Run sorting ###
if do_runsort:
    print(f"{'='*5} {sorter} {'='*5}")
    print(f"  {sorter}実行開始: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    try:
        start_time = time.time()
        
        sorter_output_dir = dict_path["exp"] / sorter / name_thisparam / "sorting"
        sorting = ss.run_sorter(
            sorter_name=sorter,
            folder=sorter_output_dir, 
            remove_existing_folder=True, 
            recording=recording,
            verbose=True,
            **sort_params
            )
            
        elapsed_time = time.time() - start_time
        print(f"  {sorter}実行完了: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} (経過時間: {elapsed_time/60:.1f}分)")

        print(f"  Analyzer作成開始: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format='binary_folder', folder=dict_path["exp"] / sorter / name_thisparam / "analyzer", overwrite=True)
        print(analyzer)
        print("===== Sorting done =====")

    except Exception as e:
        print(f"Error occurred while running {sorter}: {e}")
        traceback.print_exc()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
            print(f"  GPUメモリクリア完了 - 割り当て済み: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
else:
    print("skip sorting.")


===== dartsort =====
  dartsort実行開始: 2026-02-13 06:01:12
write_binary_recording 
engine=process - n_jobs=1 - samples_per_chunk=30,000 - chunk_memory=21.97 MiB - total_memory=21.97 MiB - chunk_duration=1.00s


write_binary_recording (no parallelization):   0%|          | 0/523 [00:00<?, ?it/s]

Using DARTsortUserConfig for DARTsort
Running DARTsort with parameters: {'dredge_only': False, 'matching_iterations': 1, 'n_jobs_cpu': 0, 'n_jobs_gpu': 0, 'device': None, 'executor': 'threading_unless_multigpu', 'chunk_length_samples': 30000, 'work_in_tmpdir': False, 'copy_recording_to_tmpdir': False, 'workdir_copier': 'shutil', 'workdir_follow_symlinks': False, 'tmpdir_parent': None, 'save_intermediates': False, 'save_final_features': True, 'ms_before': 1.4, 'ms_after': 2.6333333333333333, 'alignment_ms': 1.5, 'peak_sign': 'both', 'voltage_threshold': 4.0, 'matching_threshold': 10.0, 'initial_threshold': 12.0, 'temporal_pca_rank': 8, 'feature_ms_before': 0.75, 'feature_ms_after': 1.25, 'subtraction_radius_um': 200.0, 'deduplication_radius_um': 100.0, 'featurization_radius_um': 60, 'fit_radius_um': 75.0, 'localization_radius_um': 100.0, 'amplitude_scaling_stddev': 0.01, 'amplitude_scaling_boundary': 0.333, 'temporal_upsamples': 8, 'do_motion_estimation': True, 'rigid': False, 'probe_bo

Load examples for denoiser fitting:cuda 1.0s/it [spk/it=%%%]:   0%|          | 0/100 [00:00<?, ?it/s]

Got 53893 spikes, enough to stop early.


Load examples for feature fitting:cuda 1.0s/it [spk/it=%%%]:   0%|          | 0/100 [00:00<?, ?it/s]

  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i

Got 50239 spikes, enough to stop early.


Train localizer:   0%|          | 0/100 [00:00<?, ?epoch/s]

Subtraction:cuda 1.0s/it [spk/it=%%%]:   0%|          | 0/523 [00:00<?, ?it/s]

  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i0:i1])
  torch.mm(waveforms_in_probe, proj, out=out[i

Cross correlation:   0%|          | 0/1 [00:00<?, ?it/s]

Interpolating /collisioncleaned_tpca_features:   0%|          | 0/267 [00:00<?, ?it/s]

KDTdens[5]:   0%|          | 0/134 [00:00<?, ?it/s]

Interpolating /collisioncleaned_tpca_features:   0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/5746 [00:00<?, ?it/s]

Interpolating /collisioncleaned_tpca_features:   0%|          | 0/267 [00:00<?, ?it/s]

Interpolating /collisioncleaned_tpca_features:   0%|          | 0/267 [00:00<?, ?it/s]

  _log_warn_or_raise_coverage(uncovered_adj, neighborhood_ids, n_steps, adj)


EM:   0%|          | 0/250 [00:00<?, ?it/s]

Split:   0%|          | 0/756 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
### export to phy ###
if do_export_phy:
    analyzer = si.load_sorting_analyzer(folder= dict_path["exp"] / sorter / name_thisparam / "analyzer")
    extensions = {
        'random_spikes': {},
        'waveforms': {},
        'templates': {},
        'spike_locations': {},
        'principal_components': {},
    }
    analyzer.compute(extensions)
    # PC特徴の計算が時間かかってそうなので、Falseにしておく
    sexp.export_to_phy(
        sorting_analyzer=analyzer, 
        output_folder=dict_path["exp"] / sorter / name_thisparam / "phy", 
        remove_if_exists=True,
        # compute_pc_features=False
        )
else:
    print("skip export to phy.")


AssertionError: This folder does not exists C:\Users\tanaka-users\NeuronData\ge6w2\ge6w2_ep005_002\dartsort\setting13\analyzer

In [None]:
### convert to matlab files ###
convert_to_matlab_files(res_dir=dict_path["exp"] / sorter / name_thisparam, recording=recording)

### export to probe.json ###
write_probeinterface(dict_path["exp"] / sorter / name_thisparam / "probe.json", probe_or_probegroup=probe)

stim_times_rise, stim_times_fall, fs_obx = get_stimtime(meta_obx, dict_path["obx"]["bin"])
savemat(dict_path["exp"] / sorter / name_thisparam / 'matlab' / "stim_times.mat", {'stim_times_rise': stim_times_rise, 'stim_times_fall': stim_times_fall, 'fs': fs_obx})


Error loading C:\Users\tanaka-users\NeuronData\ge6w2\ge6w2_ep005_002\dartsort\setting12\phy\cluster_label.tsv: [Errno 2] No such file or directory: 'C:\\Users\\tanaka-users\\NeuronData\\ge6w2\\ge6w2_ep005_002\\dartsort\\setting12\\phy\\cluster_label.tsv'
Error loading C:\Users\tanaka-users\NeuronData\ge6w2\ge6w2_ep005_002\dartsort\setting12\phy\cluster_contamPct.tsv: [Errno 2] No such file or directory: 'C:\\Users\\tanaka-users\\NeuronData\\ge6w2\\ge6w2_ep005_002\\dartsort\\setting12\\phy\\cluster_contamPct.tsv'
Error loading C:\Users\tanaka-users\NeuronData\ge6w2\ge6w2_ep005_002\dartsort\setting12\phy\cluster_amplitude.tsv: [Errno 2] No such file or directory: 'C:\\Users\\tanaka-users\\NeuronData\\ge6w2\\ge6w2_ep005_002\\dartsort\\setting12\\phy\\cluster_amplitude.tsv'
Error loading C:\Users\tanaka-users\NeuronData\ge6w2\ge6w2_ep005_002\dartsort\setting12\phy\kept_spikes.npy: [Errno 2] No such file or directory: 'C:\\Users\\tanaka-users\\NeuronData\\ge6w2\\ge6w2_ep005_002\\dartsort\\s

In [None]:
# from dartsort.util.data_util import DARTsortSorting

# # .npzファイルからロード（HDF5も自動的に読み込まれる）
# sorting = DARTsortSorting.load(
#     r"C:\Users\tanaka-users\NeuronData\ge6w2\ge6w2_ep005_002\DARTsort\setting3\sorting\sorter_output\dartsort_sorting.npz"
# )

# # データにアクセス
# print(f"スパイク数: {sorting.n_spikes}")
# print(f"ユニット数: {sorting.n_units}")
# print(f"サンプリング周波数: {sorting.sampling_frequency}")

# # 基本データ
# times_samples = sorting.times_samples
# channels = sorting.channels
# labels = sorting.labels

# # 追加特徴量（HDF5から読み込まれたもの）
# if sorting.extra_features:
#     print(f"利用可能な特徴量: {list(sorting.extra_features.keys())}")
#     print(sorting.geom)

In [None]:
# ### export to matlab files for ks4 ###
# conv_ks4_mat(res_dir=dict_path["exp"] / sorter / name_thisparam / "sorting" / "sorter_output", recording=recording)

# ### export to probe.json ###
# write_probeinterface(dict_path["exp"] / sorter / name_thisparam / "probe.json", probe_or_probegroup=probe)

# ### plot drift amount, scatter, diagnostics, spike positions ###
# ks_dir = dict_path["exp"] / sorter / name_thisparam / "sorting" / "sorter_output"
# ops, st, clu, similar_templates, \
#     is_ref, est_contam_rate, kept_spikes, \
#         tF, Wall, full_st, full_clu, full_amp = \
#             load_sorting(ks_dir, device="cuda", load_extra_vars=True)

# plot_drift_amount(ops, ks_dir)
# plot_drift_scatter(full_st, ks_dir)
# plot_diagnostics(Wall, full_clu, ops, ks_dir)
# plot_spike_positions(clu, is_ref, ks_dir)


# stim_times_rise, stim_times_fall, fs_obx = get_stimtime(meta_obx, dict_path["obx"]["bin"])
# savemat(dict_path["exp"] / sorter / name_thisparam / "stim_times.mat", {'stim_times_rise': stim_times_rise, 'stim_times_fall': stim_times_fall, 'fs': fs_obx})
