In [2]:
from spikegadget2nwb.read_spikegadget import get_ephys_folder
import spikeinterface.extractors as se
import os
import numpy as np
import spikeinterface.preprocessing as spre
import time
import spikeinterface as si
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
subject_id = "CnL14"
exp_date = "20241004"
exp_time = "153555"
session_description = subject_id + '_' + exp_date + '_' + exp_time + '.rec'
ephys_folder = Path(r"D:\cl\rf_reconstruction\head_fixed")
# ephys_folder = Path(r"D:\cl\rf_reconstruction\freelymoving")
# ephys_folder = get_ephys_folder()
folder = ephys_folder / session_description

nwb_file = folder / (session_description + '.nwb')
rec = se.NwbRecordingExtractor(nwb_file)
rec

In [4]:
def get_bad_ch_id(rec, folder, load_if_exists=True):
    if load_if_exists and os.path.exists(folder / 'bad_ch_id.npy'):
        bad_ch_id = np.load(folder / 'bad_ch_id.npy')
    else:
        bad_ch_id, _ = spre.detect_bad_channels(
            rec, num_random_chunks=400, n_neighbors=5, dead_channel_threshold=-0.2)

        np.save(folder / 'bad_ch_id.npy', bad_ch_id)

    print('Bad channel IDs:', bad_ch_id)
    return bad_ch_id

In [5]:
rec_filtered = spre.bandpass_filter(rec, freq_min=300, freq_max=6000)
bad_ch_id = get_bad_ch_id(rec_filtered, folder)
remaining_ch = np.array([ch for ch in rec.get_channel_ids() if ch not in bad_ch_id])

# export remaining channel ids to a npy file
np.save(os.path.join(folder, 'remaining_ch.npy'), remaining_ch)
remaining_ch

Bad channel IDs: [  0   2   3   8  11  16  20  24  27  36  39  40  43  47  48  59  63  67
  83  87 103 111 115 118 119 121 122 125 126 127]


array([  1,   4,   5,   6,   7,   9,  10,  12,  13,  14,  15,  17,  18,
        19,  21,  22,  23,  25,  26,  28,  29,  30,  31,  32,  33,  34,
        35,  37,  38,  41,  42,  44,  45,  46,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  60,  61,  62,  64,  65,  66,  68,  69,
        70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,
        84,  85,  86,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 104, 105, 106, 107, 108, 109, 110, 112,
       113, 114, 116, 117, 120, 123, 124])

In [6]:
threshold = 7
chunk_size = 900
n_timepoints = rec_filtered.get_num_frames()
n_channels = rec_filtered.get_num_channels()
num_chunks = int(np.ceil(n_timepoints / chunk_size))

# load artifact indices if exists
if os.path.exists(folder / 'artifact_indices.npy'):
    artifact_indices = np.load(folder / 'artifact_indices.npy')
else:
# mask artifacts
    norms = np.zeros((num_chunks, n_channels))
    for i in range(num_chunks):
        start = int(i * chunk_size)
        end = int(np.minimum((i + 1) * chunk_size, n_timepoints))
        chunk = rec_filtered.get_traces(start_frame=start, end_frame=end, return_scaled=True)

        norms[i] = np.linalg.norm(chunk, axis=0)

    
    use_it = np.ones(num_chunks, dtype=bool)
# if detect artifacts in a chunk, don't use it and the two neighboring chunks

    for m in range(n_channels):
        if m in bad_ch_id:
            continue
        vals = norms[:, m]

        sigma0 = np.std(vals)
        mean0 = np.mean(vals)

        artifact_indices = np.where(vals > mean0 + threshold * sigma0)[0]

        # check if the first chunk is above threshold, ensure that we don't use negative indices later
        negIndBool = np.where(artifact_indices > 0)[0]

        # check if the last chunk is above threshold to avoid a IndexError
        maxIndBool = np.where(artifact_indices < num_chunks - 1)[0]

        use_it[artifact_indices] = 0
        use_it[artifact_indices[negIndBool] - 1] = 0  # don't use the neighbor chunks either
        use_it[artifact_indices[maxIndBool] + 1] = 0  # don't use the neighbor chunks either

        print("For channel %d: mean=%.2f, stdev=%.2f, chunk size = %d, n_artifacts = %d" % (m, mean0, sigma0, chunk_size, len(artifact_indices)))


    artifact_indices = np.where(use_it == 0)[0]
    artifact_indices = artifact_indices * chunk_size
    # save artifact indices
    np.save(folder / 'artifact_indices.npy', artifact_indices)


In [7]:
chunk_time = chunk_size / rec.get_sampling_frequency()*1000

if artifact_indices.size > 0:
    rec_rm_artifacts = spre.remove_artifacts(rec_filtered, list_triggers=artifact_indices, ms_before=0, ms_after=chunk_time)

else:
    rec_rm_artifacts = rec_filtered


rec_clean = rec_rm_artifacts.channel_slice(remaining_ch)
rec_ref = spre.common_reference(rec_clean, reference='global', operator='average')

In [8]:
idx_shank0 = np.where(rec_ref.get_channel_groups()=='shank0')
sh0_ch = rec_ref.get_channel_ids()[idx_shank0]
rec_sh0 = rec_ref.channel_slice(sh0_ch)

In [9]:
rec_sh0

In [10]:
rec_sh0.get_traces(start_frame=0, end_frame=100000).shape

(100000, 23)

In [None]:
import mountainsort5 as ms5
import json
from tempfile import TemporaryDirectory
from mountainsort5.util import create_cached_recording

experiment_length = rec_sh0.get_duration() / 60  # in minutes
recording_whitened = spre.whiten(rec_sh0, dtype='float32')

threshold = 5.5
phase1_detect_time_radius_msec = .4

with TemporaryDirectory() as tmpdir:
    # recording_cached = create_cached_recording(recording_whitened, folder=tmpdir)
    recording_cached = recording_whitened


    if experiment_length < 25:
        sorting_params = ms5.Scheme1SortingParameters(
            detect_time_radius_msec=phase1_detect_time_radius_msec, detect_threshold=threshold, detect_channel_radius=80,
            )
        sorting = ms5.sorting_scheme1(
            recording_cached, sorting_parameters=sorting_params)
        
        assert isinstance(sorting, si.BaseSorting)
    else:
        sorting_params = ms5.Scheme2SortingParameters(
            phase1_detect_threshold=threshold, detect_threshold=threshold,
            phase1_detect_channel_radius=100, detect_channel_radius=100, phase1_detect_time_radius_msec=phase1_detect_time_radius_msec, training_duration_sec=5*60,
            training_recording_sampling_mode='uniform')
        sorting = ms5.sorting_scheme2(
            recording=recording_whitened, sorting_parameters=sorting_params)

Number of channels: 23
Number of timepoints: 30473680
Sampling frequency: 30000.0 Hz
Channel 0: [  0. 250.]
Channel 1: [0. 0.]
Channel 2: [ 0. 25.]
Channel 3: [ 25. 225.]
Channel 4: [  0. 200.]
Channel 5: [  0. 175.]
Channel 6: [ 25. 375.]
Channel 7: [  0. 375.]
Channel 8: [  0. 150.]
Channel 9: [ 25. 350.]
Channel 10: [  0. 350.]
Channel 11: [25. 75.]
Channel 12: [ 25. 325.]
Channel 13: [25. 50.]
Channel 14: [  0. 325.]
Channel 15: [25. 25.]
Channel 16: [ 25. 300.]
Channel 17: [25.  0.]
Channel 18: [  0. 300.]
Channel 19: [  0. 125.]
Channel 20: [  0. 100.]
Channel 21: [  0. 275.]
Channel 22: [ 0. 75.]
Loading traces


In [8]:
current_time = time.strftime("%Y%m%d_%H%M", time.localtime())
folder_name = 'sorting_results_' + current_time
sort_out_folder = folder / folder_name
if not os.path.exists(sort_out_folder):
    os.makedirs(sort_out_folder)

# write a into json file: sorting_params.json
with open(sort_out_folder / 'sorting_params.json', 'w') as f:
    json.dump(sorting_params.__dict__, f)

In [9]:
print(f'unit number:{len(sorting.get_unit_ids())}')

unit number:1


In [10]:
sorting.register_recording(rec_ref)
sorting.save(folder = os.path.join(sort_out_folder, 'sorting'))



## export sorting result to phy

In [23]:
from spikeinterface import create_sorting_analyzer
from spikeinterface.exporters import export_to_phy


sorting_analyzer_folder = sort_out_folder / 'sorting_analyzer' 

if not os.path.exists(sorting_analyzer_folder):
    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=rec_ref, format='memory',)
    
print(sorting_analyzer)
sorting_analyzer.compute("random_spikes")
# sorting_analyzer.compute("waveforms", ms_before=2.0, ms_after=2.0)
sorting_analyzer.compute(["templates"])

phy_folder = sort_out_folder / 'phy'
if not phy_folder.exists():
    phy_folder.mkdir()
export_to_phy(
    sorting_analyzer,
    phy_folder,
    verbose=True,
    remove_if_exists=True,
)

estimate_sparsity:   0%|          | 0/1016 [00:00<?, ?it/s]

estimate_sparsity: 100%|##########| 1016/1016 [00:02<00:00, 359.14it/s]


SortingAnalyzer: 98 channels - 1 units - 1 segments - memory - sparse - has recording
Loaded 0 extensions: 


estimate_templates_with_accumulator: 100%|##########| 1016/1016 [00:36<00:00, 27.84it/s]
write_binary_recording: 100%|##########| 1016/1016 [17:19<00:00,  1.02s/it]
spike_amplitudes: 100%|##########| 1016/1016 [17:25<00:00,  1.03s/it]


AssertionError: Extension principal_components requires waveforms to be computed first

write_binary_recording: 100%|##########| 1016/1016 [17:02<00:00,  1.01s/it]
extract PCs: 100%|##########| 1016/1016 [43:08<00:00,  2.55s/it]

Run:
phy template-gui  D:\cl\rf_reconstruction\head_fixed\CnL14_20241004_153555.rec\sorting_results_20241021_0816\phy\params.py





In [24]:
# convert channel position to 2d
file_path = r"D:\cl\rf_reconstruction\head_fixed\CnL14_20241004_153555.rec\sorting_results_20241021_0816\phy\channel_positions.npy"
channel_positions = np.load(file_path)
channel_positions = channel_positions[:, :2]
np.save(file_path, channel_positions)