In [1]:
from spikegadget2nwb.read_spikegadget import get_ephys_folder
import spikeinterface.extractors as se
import os
import numpy as np
import spikeinterface.preprocessing as spre
import time
import spikeinterface as si
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Identify the session to process: subject, experiment date and experiment time.
subject_id = "CnL14"
exp_date = "20241004"
exp_time = "162638"
# The session folder is named <subject_id>_<exp_date>_<exp_time>.rec
session_description = '_'.join([subject_id, exp_date, exp_time]) + '.rec'

# Root directory that contains the session folders (alternative root kept for reference).
# ephys_folder = Path(r"D:\cl\rf_reconstruction\head_fixed")
ephys_folder = Path(r"D:\cl\rf_reconstruction\freelymoving")
folder = ephys_folder.joinpath(session_description)

# The NWB file lives inside the session folder and carries the session name plus '.nwb'.
nwb_file = folder.joinpath(session_description + '.nwb')
rec = se.NwbRecordingExtractor(nwb_file)
rec

In [3]:
def get_bad_ch_id(rec, folder, load_if_exists=True):
    """Return the bad channel IDs for ``rec``, cached in ``folder / 'bad_ch_id.npy'``.

    Parameters
    ----------
    rec
        Recording passed to ``spre.detect_bad_channels`` when detection runs.
    folder
        Path-like object supporting ``/``; the cache file is read from and
        written to this directory.
    load_if_exists : bool
        When true and the cache file exists, the cached IDs are loaded and no
        detection is run. Otherwise detection runs and the cache file is
        (over)written.

    Returns
    -------
    The bad channel IDs, which are also printed before returning.
    """
    cache_file = folder / 'bad_ch_id.npy'
    cache_usable = load_if_exists and os.path.exists(cache_file)

    if cache_usable:
        bad_ch_id = np.load(cache_file)
    else:
        # detect_bad_channels returns a pair; only its first element is kept.
        bad_ch_id, _unused = spre.detect_bad_channels(
            rec,
            num_random_chunks=400,
            n_neighbors=5,
            dead_channel_threshold=-0.2,
        )
        np.save(cache_file, bad_ch_id)

    print('Bad channel IDs:', bad_ch_id)
    return bad_ch_id

In [4]:
# Band-pass filter (300-6000 Hz) and find the bad channels on the filtered signal.
rec_filtered = spre.bandpass_filter(rec, freq_min=300, freq_max=6000)
bad_ch_id = get_bad_ch_id(rec_filtered, folder)

# Collect, in recording order, every channel ID that was not flagged as bad.
kept_channel_ids = []
for ch in rec.get_channel_ids():
    if ch in bad_ch_id:
        continue
    kept_channel_ids.append(ch)
remaining_ch = np.array(kept_channel_ids)

# Save the remaining channel IDs next to the recording so later steps can reuse them.
np.save(os.path.join(folder, 'remaining_ch.npy'), remaining_ch)
remaining_ch

Bad channel IDs: [  0   2   3   8  11  16  20  24  27  36  39  40  43  47  48  59  63  67
  83  87 103 111 115 118 119 121 122 125 126 127]


array([  1,   4,   5,   6,   7,   9,  10,  12,  13,  14,  15,  17,  18,
        19,  21,  22,  23,  25,  26,  28,  29,  30,  31,  32,  33,  34,
        35,  37,  38,  41,  42,  44,  45,  46,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  60,  61,  62,  64,  65,  66,  68,  69,
        70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,
        84,  85,  86,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 104, 105, 106, 107, 108, 109, 110, 112,
       113, 114, 116, 117, 120, 123, 124])

In [5]:
# Artifact detection: cut the filtered recording into fixed-size chunks and flag
# every chunk whose L2 norm is an outlier on at least one usable channel.
threshold = 7  # a chunk is flagged when its norm exceeds mean + threshold * std
chunk_size = 900  # chunk length in frames
n_timepoints = rec_filtered.get_num_frames()
n_channels = rec_filtered.get_num_channels()
num_chunks = int(np.ceil(n_timepoints / chunk_size))

artifact_file = folder / 'artifact_indices.npy'

if os.path.exists(artifact_file):
    # Reuse the cached result.
    # NOTE: the cache is not keyed on threshold / chunk_size; delete the file
    # after changing either of them, otherwise stale indices are loaded.
    artifact_indices = np.load(artifact_file)
else:
    # norms[i, m] is the L2 norm of chunk i on the channel in column m.
    norms = np.zeros((num_chunks, n_channels))
    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, n_timepoints)  # the last chunk may be shorter
        chunk = rec_filtered.get_traces(start_frame=start, end_frame=end, return_scaled=True)
        norms[i] = np.linalg.norm(chunk, axis=0)

    # bad_ch_id holds channel IDs, whereas the columns of `norms` are positional.
    # Build a positional mask so the skip below stays correct even when the
    # channel IDs are not simply 0..n_channels-1 in order.
    is_bad_channel = np.isin(rec_filtered.get_channel_ids(), bad_ch_id)

    # use_it[i] is False when chunk i (or one of its neighbours) contains an artifact.
    use_it = np.ones(num_chunks, dtype=bool)

    for m in range(n_channels):
        if is_bad_channel[m]:
            continue
        vals = norms[:, m]

        sigma0 = np.std(vals)
        mean0 = np.mean(vals)

        flagged_chunks = np.where(vals > mean0 + threshold * sigma0)[0]

        use_it[flagged_chunks] = False
        # Also drop the neighbouring chunks. The first/last chunk has no
        # previous/next neighbour, so filter before shifting to avoid wrapping
        # to index -1 or indexing past the end.
        use_it[flagged_chunks[flagged_chunks > 0] - 1] = False
        use_it[flagged_chunks[flagged_chunks < num_chunks - 1] + 1] = False

        print("For channel %d: mean=%.2f, stdev=%.2f, chunk size = %d, n_artifacts = %d" % (m, mean0, sigma0, chunk_size, len(flagged_chunks)))

    # Convert rejected chunk numbers into the start frame of each rejected chunk.
    artifact_indices = np.where(~use_it)[0] * chunk_size
    # save artifact indices
    np.save(artifact_file, artifact_indices)


In [6]:
# Length of one chunk expressed in milliseconds.
sampling_frequency = rec.get_sampling_frequency()
chunk_time = chunk_size / sampling_frequency * 1000

# Each artifact index is the first frame of a rejected chunk; it is passed as a
# trigger with a window of one chunk length after it. Without any artifact the
# filtered recording is used unchanged.
rec_rm_artifacts = rec_filtered
if artifact_indices.size > 0:
    rec_rm_artifacts = spre.remove_artifacts(
        rec_filtered,
        list_triggers=artifact_indices,
        ms_before=0,
        ms_after=chunk_time,
    )

# Keep only the channels that were not flagged as bad, then apply a global average reference.
rec_clean = rec_rm_artifacts.channel_slice(remaining_ch)
rec_ref = spre.common_reference(rec_clean, reference='global', operator='average')

In [7]:
import json
import mountainsort5 as ms5

# Recording duration in minutes; it decides which sorting scheme is used below.
experiment_length = rec_ref.get_duration() / 60

recording_whitened = spre.whiten(rec_ref, dtype='float32')

# Detection settings shared by both schemes.
threshold = 5.5
phase1_detect_time_radius_msec = .4

if experiment_length < 5:
    # Recordings shorter than 5 minutes: scheme 1.
    sorting_params = ms5.Scheme1SortingParameters(
        detect_threshold=threshold,
        detect_time_radius_msec=phase1_detect_time_radius_msec,
    )
    run_sorter = ms5.sorting_scheme1
else:
    # Otherwise: scheme 2, with a 5-minute training recording sampled uniformly.
    sorting_params = ms5.Scheme2SortingParameters(
        phase1_detect_channel_radius=100,
        detect_channel_radius=100,
        phase1_detect_threshold=threshold,
        detect_threshold=threshold,
        phase1_detect_time_radius_msec=phase1_detect_time_radius_msec,
        training_duration_sec=5 * 60,
        training_recording_sampling_mode='uniform',
    )
    run_sorter = ms5.sorting_scheme2

sorting = run_sorter(recording=recording_whitened, sorting_parameters=sorting_params)

Using training recording of duration 300 sec with the sampling mode initial
*** MS5 Elapsed time for SCHEME2 get_sampled_recording_for_training: 410.668 seconds ***
Running phase 1 sorting
Number of channels: 98
Number of timepoints: 9000000
Sampling frequency: 30000.0 Hz
Channel 0: [350. 225.]
Channel 1: [350. 150.]
Channel 2: [325. 225.]
Channel 3: [325. 275.]
Channel 4: [325. 100.]
Channel 5: [325. 175.]
Channel 6: [325. 150.]
Channel 7: [350. 100.]
Channel 8: [350. 250.]
Channel 9: [350.  50.]
Channel 10: [325. 375.]
Channel 11: [350. 300.]
Channel 12: [350.  75.]
Channel 13: [325. 350.]
Channel 14: [350.  25.]
Channel 15: [350. 325.]
Channel 16: [  0. 250.]
Channel 17: [350.   0.]
Channel 18: [0. 0.]
Channel 19: [ 0. 25.]
Channel 20: [350. 275.]
Channel 21: [ 25. 225.]
Channel 22: [350. 200.]
Channel 23: [325. 125.]
Channel 24: [325. 200.]
Channel 25: [325.   0.]
Channel 26: [325.  25.]
Channel 27: [  0. 200.]
Channel 28: [325.  50.]
Channel 29: [  0. 175.]
Channel 30: [ 25. 375.]

In [8]:
# Results go into a timestamped sub-folder of the session folder so that
# successive runs do not overwrite each other (timestamp resolution: minutes).
current_time = time.strftime("%Y%m%d_%H%M", time.localtime())
folder_name = 'sorting_results_' + current_time
sort_out_folder = folder / folder_name
# exist_ok=True replaces the separate exists-check, which could race with
# another process creating the folder between the check and the creation.
os.makedirs(sort_out_folder, exist_ok=True)

# Write the sorting parameters to sorting_params.json next to the results.
with open(sort_out_folder / 'sorting_params.json', 'w') as f:
    json.dump(sorting_params.__dict__, f)

In [9]:
# Report how many units the sorter returned.
unit_ids = sorting.get_unit_ids()
print(f'unit number:{len(unit_ids)}')

unit number:1


In [10]:
# Attach the referenced recording to the sorting, then save the sorting into
# the 'sorting' sub-folder of the results folder.
sorting.register_recording(rec_ref)
sorting_save_folder = os.path.join(sort_out_folder, 'sorting')
sorting.save(folder=sorting_save_folder)



## Export sorting results to Phy

In [None]:
from spikeinterface import create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.exporters import export_to_phy


sorting_analyzer_folder = sort_out_folder / 'sorting_analyzer'

# Always end up with a defined `sorting_analyzer`: reuse the one stored in
# sorting_analyzer_folder when that folder exists, otherwise build a new
# in-memory analyzer. (Previously the variable was left undefined whenever the
# folder existed, which made the next line fail with a NameError.)
if os.path.exists(sorting_analyzer_folder):
    sorting_analyzer = load_sorting_analyzer(sorting_analyzer_folder)
else:
    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=rec_ref, format='memory',)

print(sorting_analyzer)
sorting_analyzer.compute("random_spikes")
# sorting_analyzer.compute("waveforms", ms_before=2.0, ms_after=2.0)
sorting_analyzer.compute(["templates"])

# Export into the 'phy' sub-folder of the results folder; an existing export
# is replaced (remove_if_exists=True).
phy_folder = sort_out_folder / 'phy'
phy_folder.mkdir(exist_ok=True)
export_to_phy(
    sorting_analyzer,
    phy_folder,
    verbose=True,
    remove_if_exists=True,
)

estimate_sparsity:   0%|          | 0/1016 [00:00<?, ?it/s]

estimate_sparsity: 100%|##########| 1016/1016 [00:02<00:00, 359.14it/s]


SortingAnalyzer: 98 channels - 1 units - 1 segments - memory - sparse - has recording
Loaded 0 extensions: 


estimate_templates_with_accumulator: 100%|##########| 1016/1016 [00:36<00:00, 27.84it/s]
write_binary_recording:  74%|#######3  | 750/1016 [12:54<04:28,  1.01s/it]

write_binary_recording: 100%|##########| 1016/1016 [17:02<00:00,  1.01s/it]
extract PCs: 100%|##########| 1016/1016 [43:08<00:00,  2.55s/it]

Run:
phy template-gui  D:\cl\rf_reconstruction\head_fixed\CnL14_20241004_153555.rec\sorting_results_20241021_0816\phy\params.py





In [17]:
# Convert the exported channel positions to 2D by keeping only the first two
# columns, and overwrite the file in place. Slicing an array that is already
# 2-column is a no-op, so re-running this cell is safe.
# The path is derived from phy_folder (the export folder used above) instead of
# a hard-coded absolute path, so it always targets the current session's export.
file_path = phy_folder / 'channel_positions.npy'
channel_positions = np.load(file_path)[:, :2]
np.save(file_path, channel_positions)