In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import spikeinterface.full as si
import matplotlib.pyplot as plt
# import IO_tools as io
from pathlib import Path
from pprint import pprint
from tools.settings import settings

In [2]:
# record the installed spikeinterface version alongside the notebook outputs
from spikeinterface import __version__ as sivers
print(f'spikeinterface version:  {sivers}')

spikeinterface version:  0.102.3


## setup

#### set paths

In [3]:
# configuration objects provided by the project settings module
paths, experiment = settings.paths, settings.experiment

# project-level location: <drive> / <experiment dir> / <data dir>
raw_drive, expt_folder = paths.drive, experiment.dir
batch_folder = raw_drive / expt_folder / paths.data_dir
print(f'Looking for recordings in:\n\t{batch_folder}')

# sub-directory names used inside each session folder
raw_dir, processed_dir = paths.raw_dir, paths.processed_dir

Looking for recordings in:
	/mnt/array/3_TRAP_ISO/1_Recordings


#### set parallel processing parameters

In [4]:
# CUDA_LAUNCH_BLOCKING=1 is exported before any GPU work starts in this notebook
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# job settings registered as spikeinterface's global defaults
global_job_kwargs = {
    'n_jobs': 6,
    'chunk_duration': '1s',
    'progress_bar': True,
}
si.set_global_job_kwargs(**global_job_kwargs)

# width=2 forces pprint to put every key on its own line
print("\nParallel Job parameters:")
pprint(global_job_kwargs, indent=2, width=2)


Parallel Job parameters:
{ 'chunk_duration': '1s',
  'n_jobs': 6,
  'progress_bar': True}


#### set default kilosort parameters

In [5]:
# start from the sorter defaults
# (inspect them with `si.get_default_sorter_params('kilosort4')`)
ks_params = si.get_default_sorter_params('kilosort4')

# overrides applied to every recording
ks_params.update(
    batch_size=45000,
    nblocks=5,
    Th_universal=9.0,
    Th_learned=8.0,
    save_preprocessed_copy=True,
    save_extra_vars=False,
    # cluster_downsampling = 25, max_cluster_subset=None  # address potential memory issues in KS v4.1.
)

# session-specific sorting params, keyed by recording name
sessions_ksparams = dict(
    # NP01_R1=dict(
    #     Th_universal=5.0, Th_learned=5.0, ...
    # )
)

## sort recordings

#### set subject / recording pairs to process

In [6]:
# recordings to process, taken from the experiment configuration;
# the printed mapping is keyed by recording name with per-recording options
recording_pairs = experiment.recordings
print("Processing the following recordings:")
pprint(recording_pairs, indent=4)

Processing the following recordings:
{'TRP804_R2': {'concatenate': False, 'multiple_shanks': True}}


In [7]:
# Set to a recording name (e.g. "NP02_R1") to spike-sort only that recording
# (including multiple segments), or keep as None to batch process every
# recording listed in recording_pairs.
recording_name = None

In [8]:
# Optionally restrict the batch to the one recording named in `recording_name`.
# When `recording_name` is None, `recording_pairs` is left untouched (batch mode).
if recording_name is not None:
    # fail with the list of valid names instead of a bare KeyError on a typo
    if recording_name not in recording_pairs:
        raise KeyError(
            f'{recording_name!r} is not one of the configured recordings: '
            f'{list(recording_pairs)}'
        )
    properties = recording_pairs[recording_name]
    print('---processing single recording')
    recording_pairs = {recording_name: properties}

In [None]:
from probeinterface.plotting import plot_probe, plot_probegroup
from torch.cuda import empty_cache
from tools.spikesorting import load_recording, process_recording

overwrite = True
for session, properties in recording_pairs.items():
    animal = session.split('_')[0]
    recording_name = session
    concatenate = properties['concatenate']
    print(f'---processing  {recording_name}{", multiple recordings..." if concatenate else ""}')

    # find session folder
    rec_folder = batch_folder / animal / recording_name
    if rec_folder.exists():
        print(f'recording session folder:  {rec_folder}')  # top-level
    else:
        print(f'(!) No recording session folder found for:  {rec_folder}\nSkipping...\n\n')
        continue

    # find raw recording folders
    raw_folder = rec_folder / raw_dir
    raw_folders = None
    assert raw_folder.exists(), f'Raw folder does not exist:\n\t{raw_folder}'
    if not concatenate: # single recording
        if raw_file := next(raw_folder.glob('*.cbin'), None):  # no subfolder, raw files only
            raw_folder = raw_file.parent
            print(f'Raw files found in:\n\t{raw_folder}')
        elif (raw_file := list(raw_folder.rglob('*.cbin'))) != []:  # subfolder with raw files
            if len(raw_file) > 1:
                raw_folders = [f.parent for f in raw_file]
                print(f'Multiple raw folders found:', end='')
                print(*raw_folders, sep='\n\t')
            else:
                raw_folder = raw_file[0].parent
                print(f'Raw folder found:\n\t{raw_folder}')
        else:
            print(f'No recordings found for {recording_name}!\nSkipping...\n\n')
            continue
    else:  # multiple segments for recording
        if raw_files := raw_folder.rglob('*.cbin') != []:
            if len(raw_files) == 1:
                raw_folder = raw_files[0].parent
                print(f'Only one raw file found:\n\t{raw_folder}')
            else:
                raw_folders = [f.parent for f in raw_files]
                print(f'Raw folders found for {recording_name}:\n\t', end='')
                print(*raw_folders, sep='\n\t')

    if not raw_folders:  # single recording segment
        raw_folders = [raw_folder]
        rec_name = recording_name

    processed_folder = rec_folder / processed_dir
    assert processed_folder.exists(), f'Processed folder does not exist:\n\t{processed_folder}'
    print(f'---saving processed outputs to  "{processed_folder}"')

    probe_shanks = properties.get('multiple_shanks', False)
    if not isinstance(probe_shanks, list):  # single probe
        probe_shanks = [probe_shanks]

    raw_folder = rec_folder / raw_dir
    assert raw_folder.exists(), f"(!) No raw data folder found for recording: {rec_folder}\nExpected in: {raw_folder}\nSkipping...\n\n"
    match raw_files := list(raw_folder.rglob(f'{recording_name}*imec*.cbin')):
        case x if len(x) > 1:
            print(f'Found multiple raw files for {recording_name}: {x}')
        case _:
            print(f'Found single raw file for {recording_name}: {raw_files}')

    for probe_num, raw_file in enumerate(raw_files):
        print(f'---processing probe {probe_num} from file: {raw_file.name}')
        empty_cache()  # clear GPU memory between probes
        # TODO: condense to spikesorting functions
        # similar compress_recording(rec_name, rec_folder, target_folder, job_kwargs)
        # spike_sorting(rec_name, rec_folder, job_kwargs)
            # includes load_recording, 

        # load recording
        rec = load_recording(raw_file, concatenate=concatenate) 
        if not rec:
            print(f'(!) No valid recording found for {recording_name}!\nSkipping...\n\n')
            continue
        
        rec_name = f'{recording_name}_probe{probe_num}'  # with probe number
        print(f'\nFinal recording: {rec}\n\t', rec, '\n')

        # save probe channel map
        fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=300)
        plot_probe(
            rec.get_probe(), 
            ax=ax
        )
        probemap_filepath = processed_folder / f"{rec_name}_channel_selection.png"
        print(f'...saving channel selection to:\n\t{probemap_filepath}')
        plt.savefig(probemap_filepath, dpi=400, bbox_inches='tight')

        # set session-specific sorting parameters
        session_params = sessions_ksparams.get(recording_name, {})
        session_params.update({'fs': rec.get_sampling_frequency()})  # add sampling frequency
        ks_params.update(session_params)

        # run sorter
        sorter_output_folder = processed_folder / f'kilosort4_probe{probe_num}'
        multiple_shanks = probe_shanks[probe_num] if probe_num < len(probe_shanks) else False
        if multiple_shanks:
            print(f'\nRunning sorter for multiple shanks...')
            
            if (not overwrite) and sorter_output_folder.exists():
                print(f'...skipping sorting for existing output:\n\t{sorter_output_folder}')
                continue
            print('\n---------Starting sorting---------\n')
            print(f'...saving sorting output to:\n\t{sorter_output_folder}')        
            aggregated_sorting = si.run_sorter_by_property(
                sorter_name='kilosort4',
                recording=rec,
                grouping_property='group',  # defines shank number in recording object
                folder=sorter_output_folder,
                remove_existing_folder=True,
                docker_image=False,
                **ks_params
            )
            # concatenate shank sortings into one folder for postprocessing
            try:
                aggregated_sorting.save_to_zarr(sorter_output_folder / 'aggregated_sorting', overwrite=True, **global_job_kwargs)
            except:
                aggregated_sorting.save_to_zarr(sorter_output_folder / 'aggregated_sorting', overwrite=True)
            print('\n---------Sorting finished---------\n\n')

        else:
            print(f'\nRunning sorter for single shank...')
            print('\n---------Starting sorting---------\n')
            print(f'...saving sorting output to:\n\t{sorter_output_folder}')
            sorting = si.run_sorter(
                sorter_name='kilosort4',
                recording=rec,
                folder=sorter_output_folder,
                remove_existing_folder=True,
                docker_image=False,
                verbose=True,
                **ks_params
            )
            # save separate sorting object for postprocessing
            sorting.save_to_zarr(sorter_output_folder / 'sorting', overwrite=True)
            print('\n---------Sorting finished---------\n\n')

---processing  TRP804_R2
recording session folder:  /mnt/array/3_TRAP_ISO/1_Recordings/TRP804/TRP804_R2
Raw folder found:
	/mnt/array/3_TRAP_ISO/1_Recordings/TRP804/TRP804_R2/0_raw_compressed/TRP804_R2_g0_t0
---saving processed outputs to  "/mnt/array/3_TRAP_ISO/1_Recordings/TRP804/TRP804_R2/2_processed"
Found single raw file for TRP804_R2: [PosixPath('/mnt/array/3_TRAP_ISO/1_Recordings/TRP804/TRP804_R2/0_raw_compressed/TRP804_R2_g0_t0/TRP804_R2_g0_t0.imec0.ap.cbin')]
---processing probe 0 from file: TRP804_R2_g0_t0.imec0.ap.cbin
