In [None]:
import matplotlib
matplotlib.use('Agg') # disable interactive matplotlib to save RAM

import glob
import os 
import shutil
import spikeinterface.core as sc
import spikeinterface.curation as scu
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import sys 

from tqdm.auto import tqdm

sys.path.append('src')

from utils import memory_limit, sorter_parameters, ms_before, ms_after, preprocess_recording, plot_unit, blackrock_channel_indices

# Root of the on-disk data tree, and the NSx stream (file extension) to process.
data_root = 'data'
nsx = 'ns4'
# Per-stream spikeinterface workspace; every session folder lives below this path.
si_folder = os.path.join(data_root, 'spikeinterface-0_98_2', f'LongTerm-{nsx}')

In [None]:
# For every `raw` recording folder under `si_folder`, run the cached pipeline
# preprocess -> sort (MountainSort4) -> extract waveforms -> plot units.
# Each stage is skipped when its output marker file already exists, so the cell
# can be re-run after an interruption and only redoes the missing stages.
recording_paths = sorted(glob.glob(os.path.join(si_folder, '**', 'raw'), recursive=True))

for recording_path in (pbar := tqdm(recording_paths)):
    # All outputs for a session are written next to its `raw` folder.
    recording_stem = os.path.dirname(recording_path)
    pbar.set_description(recording_path)

    recording_si_path = os.path.join(recording_stem, 'processed')
    sorting_si_path = os.path.join(recording_stem, 'sorting')
    waveforms_si_path = os.path.join(recording_stem, 'waveforms')
    units_si_path = os.path.join(recording_stem, 'units')
    firings_path = os.path.join(sorting_si_path, 'sorter_output', 'firings.npz')

    # Stage 1: preprocessing. A `processed` folder without binary.json is treated
    # as incomplete and is rebuilt from scratch.
    if not os.path.isfile(os.path.join(recording_si_path, 'binary.json')):
        shutil.rmtree(recording_si_path, ignore_errors=True)
        recording = preprocess_recording(sc.load_extractor(recording_path))
        # NOTE(review): confirm `memory` is a keyword this spikeinterface version's
        # `save()` actually consumes; unrecognised keys may be silently ignored.
        recording.save(folder=recording_si_path, memory=memory_limit)
    pbar.set_description(f'{recording_stem} preprocessed')

    # Every later stage reads the saved preprocessed recording, so load it once.
    recording_processed = sc.load_extractor(recording_si_path)

    # Stage 2: spike sorting; the presence of firings.npz means it already ran.
    if not os.path.isfile(firings_path):
        ss.run_sorter(
            sorter_name='mountainsort4',
            recording=recording_processed,
            output_folder=sorting_si_path,
            remove_existing_folder=True,
            with_output=True,
            **sorter_parameters,
        )
    pbar.set_description(f'{recording_stem} sorted')

    # Load the sorter output once; both waveform extraction and plotting use it.
    sorting = se.NpzSortingExtractor(firings_path)
    # spikeinterface https://github.com/SpikeInterface/spikeinterface/pull/1378
    sorting = scu.remove_excess_spikes(sorting, recording_processed)

    # Stage 3: waveform extraction; templates_average.npy means it already ran.
    if not os.path.isfile(os.path.join(waveforms_si_path, 'templates_average.npy')):
        sc.extract_waveforms(
            recording_processed, sorting,
            folder=waveforms_si_path,
            ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=None,
            return_scaled=False,
            overwrite=True,
            use_relative_path=True,
        )
    pbar.set_description(f'{recording_stem} extracted')

    # Stage 4: one PNG per unit; only units without an existing plot are drawn.
    os.makedirs(units_si_path, exist_ok=True)
    unit_ids = sorting.unit_ids
    pending_units = []
    for position, unit_id in enumerate(unit_ids, start=1):
        unit_plot_file = os.path.join(units_si_path, f'{unit_id}.png')
        if not os.path.isfile(unit_plot_file):
            pending_units.append((position, unit_id, unit_plot_file))
    if not pending_units:
        # Every unit is already plotted: skip loading waveforms and templates.
        continue

    waveform_extractor = sc.load_waveforms(
        folder=waveforms_si_path, with_recording=True, sorting=sorting
    )
    extremum_channels = sc.get_template_extremum_channel(waveform_extractor, peak_sign='neg')

    for position, unit_id, unit_plot_file in pending_units:
        # Show the unit's position in the list; unit ids are not necessarily 1..N.
        pbar.set_description(
            f'{recording_stem} Plotting [unit {position} / {len(unit_ids)}] (id {unit_id})'
        )
        plot_unit(waveform_extractor, extremum_channels, sorting, unit_id, blackrock_channel_indices, unit_plot_file)