In [None]:
# Select the non-interactive Agg backend before matplotlib.pyplot is imported
# further down this cell; figures produced by this notebook are written to files
# (e.g. probe.png and the per-unit PNGs) rather than shown interactively.
import matplotlib
matplotlib.use('Agg') # disable interactive matplotlib to save RAM

# Mouse IDs to process; each one must match the <mouse> folder level under
# {si_folder}/{experiment}/ searched by the session-discovery cell.
mice = ['1_5']

# Per-mouse session folders to skip, written exactly as the session-discovery
# cell reports them. Example:
#     exclude_sessions = {
#         '6_2': ['data/spikeinterface-0_98_2/Behavior/6_2/20231208/session_1'],
#     }
exclude_sessions = dict()

# NSx file variant (presumably the Blackrock .ns4 stream -- confirm); it selects
# the 'LongTerm-<nsx>' experiment folder that sessions are read from.
nsx = 'ns4'
experiments = ['LongTerm-' + nsx]

import datetime
import glob
import os
import shutil
import sys

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy
import scipy.signal
import spikeinterface.core as sc
import spikeinterface.curation as scu
import spikeinterface.extractors as se
import spikeinterface.sorters as ss

from tqdm.auto import tqdm

sys.path.append('src')

from utils import *

# Outputs of this run go into a date-stamped folder (YYYYMMDD).
today = datetime.date.today().strftime('%Y%m%d')

# Target rate every session is resampled to before concatenation.
# n_ms_per_s comes from utils (presumably 1000, i.e. 10 kHz -- confirm).
sampling_frequency = 10 * n_ms_per_s

# Input: spikeinterface-saved raw sessions; output: sorted results for today.
data_root = 'data'
si_folder = os.path.join(data_root, 'spikeinterface-0_98_2')
sorted_folder = os.path.join(data_root, 'sorted', today)

print(f'Saving to {sorted_folder}')

In [None]:
# Discover every (experiment, session_folder) pair for each mouse.
# Layout searched: {si_folder}/{experiment}/{mouse}/<date>/<session>; the next
# cell reads <date> back as the second-to-last path component.
# Without recursive=True, glob treats '**' exactly like '*', so the pattern is
# spelled with '*' to make the fixed two-level depth explicit.
recording_sessions = {}
for mouse in mice:
    # Normalise exclusion entries so paths written with '/' still match the
    # os.sep-joined paths glob returns (matters on Windows).
    excluded = {os.path.normpath(path) for path in exclude_sessions.get(mouse, [])}
    found_sessions = []
    for experiment in experiments:
        pattern = os.path.join(si_folder, experiment, mouse, '*', '*')
        for session in glob.glob(pattern):
            # Only directories can contain the `raw` extractor loaded later;
            # stray files at this depth would crash sc.load_extractor.
            if os.path.isdir(session) and os.path.normpath(session) not in excluded:
                found_sessions.append((experiment, session))
    recording_sessions[mouse] = sorted(found_sessions)
recording_sessions

In [None]:
# Per mouse:
#   1. concatenate + preprocess every session into one recording,
#   2. spike-sort it per electrode 'group',
#   3. extract waveforms,
#   4. plot each unit.
# Each stage is skipped when its output file already exists, so re-running the
# cell resumes after an interruption.
# NOTE(review): n_s_per_min, min_recording_duration, intan_channel_indices,
# blackrock_channel_indices, preprocess_recording, create_probegroup,
# memory_limit, sorter_parameters, ms_before, ms_after, read_sorted_results,
# plot_unit and surgery_dates are presumably provided by `from utils import *`.
for mouse in (pbar := tqdm(mice)):
    mouse_sorted_folder = f'{sorted_folder}{os.sep}{mouse}'
    os.makedirs(mouse_sorted_folder, exist_ok=True)
    pbar.set_description(mouse)

    mouse_recording_si_path = f'{mouse_sorted_folder}{os.sep}processed'
    mouse_sorting_si_path = f'{mouse_sorted_folder}{os.sep}sorting-by-group'
    mouse_waveforms_si_path = f'{mouse_sorted_folder}{os.sep}waveforms-by-group'
    mouse_units_si_path = f'{mouse_sorted_folder}{os.sep}units-by-group'
    mouse_sessions_file = f'{mouse_sorted_folder}{os.sep}sessions.csv'
    mouse_firings_file = f'{mouse_sorting_si_path}{os.sep}sorter_output{os.sep}firings.npz'

    # Stage 1: build the concatenated, preprocessed recording. sessions.csv is
    # written last, so its presence marks this stage as complete.
    if not os.path.isfile(mouse_sessions_file):
        mouse_traces, mouse_sessions = [], []
        cumulative_samples = 0
        for (experiment, session) in recording_sessions[mouse]:
            pbar.set_description(f'{mouse} -> {experiment} -> {session}')

            recording = sc.load_extractor(f'{session}{os.sep}raw')
            recording_sampling_frequency = recording.get_sampling_frequency()
            # Check the duration from the sample count *before* materialising the
            # traces, so sessions that will be discarded are never loaded into RAM.
            recording_samples = recording.get_num_samples()
            recording_duration = recording_samples / recording_sampling_frequency / n_s_per_min
            if recording_duration < min_recording_duration:
                print(f'[duration {recording_duration:0.0f}min] discarding {session}')
                continue

            traces = recording.get_traces()
            # Bring every session to the common rate so they can be stacked.
            if recording_sampling_frequency != sampling_frequency:
                new_recording_samples = int(traces.shape[0] / (recording_sampling_frequency / sampling_frequency))
                traces = scipy.signal.resample(traces, new_recording_samples, axis=0)

            if experiment == 'Behavior':
                # Reorder channel columns into the Blackrock channel layout
                # (undo the Intan ordering, then apply the Blackrock ordering).
                traces = traces[:, intan_channel_indices.flatten().argsort()][:, blackrock_channel_indices.flatten()]
                print(f'    [Converted to blackrock layout] {session}')
            elif experiment != f'LongTerm-{nsx}':
                raise Exception(f'Unrecognized experiment {experiment}')

            mouse_traces.append(traces)
            # Record where this session starts inside the concatenated recording.
            mouse_sessions.append({
                'mouse': mouse,
                'date': session.split(os.sep)[-2],
                'session': session,
                'session_start': cumulative_samples,
                'sampling_frequency': sampling_frequency,
                'session_length': traces.shape[0],
            })
            cumulative_samples += traces.shape[0]

        if not mouse_traces:
            # np.vstack([]) would raise an opaque ValueError; nothing to sort.
            print(f'[no usable sessions] skipping {mouse}')
            continue
        mouse_traces = np.vstack(mouse_traces)

        recording = se.NumpyRecording(traces_list=mouse_traces, sampling_frequency=sampling_frequency)
        recording_processed = preprocess_recording(recording)

        probegroup = create_probegroup(blackrock_channel_indices, f'{mouse_sorted_folder}{os.sep}probe.png')
        recording_processed.set_probegroup(probegroup, in_place=True)

        # Remove any partial output from an interrupted earlier run before saving.
        shutil.rmtree(mouse_recording_si_path, ignore_errors=True)
        recording_processed.save(folder=mouse_recording_si_path, memory=memory_limit)

        mouse_sessions = pd.json_normalize(mouse_sessions)
        mouse_sessions.to_csv(mouse_sessions_file, index=False)
    pbar.set_description(f'{mouse} preprocessed')

    # Stage 2: sort each electrode group separately; firings.npz marks completion.
    if not os.path.isfile(mouse_firings_file):
        recording_processed = sc.load_extractor(mouse_recording_si_path)
        sorting = ss.run_sorter_by_property(
            sorter_name='mountainsort4',
            recording=recording_processed,
            grouping_property='group',
            working_folder=mouse_sorting_si_path,
            mode_if_folder_exists='overwrite',
            **sorter_parameters,
        )
        os.makedirs(f'{mouse_sorting_si_path}{os.sep}sorter_output', exist_ok=True)
        se.NpzSortingExtractor.write_sorting(sorting, mouse_firings_file)
    pbar.set_description(f'{mouse} sorted')

    # Stage 3: waveform extraction; templates_average.npy marks completion.
    if not os.path.isfile(f'{mouse_waveforms_si_path}{os.sep}templates_average.npy'):
        recording_processed = sc.load_extractor(mouse_recording_si_path)
        sorting = se.NpzSortingExtractor(mouse_firings_file)
        sorting = scu.remove_excess_spikes(sorting, recording_processed) # spikeinterface https://github.com/SpikeInterface/spikeinterface/pull/1378
        sc.extract_waveforms(
            recording_processed, sorting,
            folder=mouse_waveforms_si_path,
            ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=None,
            return_scaled=False,
            overwrite=True,
            use_relative_path=True,
        )
    pbar.set_description(f'{mouse} extracted')

    recording_processed, sorting, waveform_extractor, extremum_channels, mouse_sessions = read_sorted_results(mouse_sorted_folder, read_sessions=True)

    # Stage 4: one PNG per unit; existing plots are kept.
    os.makedirs(mouse_units_si_path, exist_ok=True)
    n_units = len(sorting.unit_ids)
    for unit_index, unit_id in enumerate(sorting.unit_ids, start=1):
        pbar.set_description(f'{mouse} Plotting [unit {unit_id} ({unit_index} / {n_units})]')
        unit_plot_file = f'{mouse_units_si_path}{os.sep}{unit_id}.png'
        if not os.path.isfile(unit_plot_file):
            plot_unit(waveform_extractor, extremum_channels, sorting, unit_id, blackrock_channel_indices, initdate=surgery_dates[mouse], savepath=unit_plot_file, sessions=mouse_sessions)