## Spikesorting Pipeline Populator

This notebook demonstrates how to use `spikesorting_pipeline_populator` to efficiently populate the spike sorting pipeline tables once you have determined the sorting parameters appropriate for your dataset.

First, we'll import the relevant spyglass tools.

In [1]:
import os
# Set this variable before importing dask for parallel processing (see below)
os.environ["DASK_DISTRIBUTED__WORKER__DAEMON"] = "False"

from spyglass.spikesorting import spikesorting_pipeline_populator, SpikeSortingPipelineParameters

[2023-09-13 16:02:42,354][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:02:42,403][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)


The `SpikeSortingPipelineParameters` table lets you manually define a set of pipeline parameters, composed of the names of the parameter sets you will use at each stage of the pipeline. This is useful for tracking and reusing the same parameter settings within a project.

In [2]:
SpikeSortingPipelineParameters()

| pipeline_parameters_name | preproc_params_name | artifact_parameters | sorter | sorter_params_name | waveform_params_name | metric_params_name | auto_curation_params_name |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ms_stim_project | franklab_tetrode_hippocampus | ampl_2000_prop_75 | mountainsort4 | franklab_tetrode_hippocampus_30KHz_tmp | default_whitened | peak_offest_num_spikes_2 | mike_noise_03_offset_2_isi_0025_mua |


The function `spikesorting_pipeline_populator` uses these parameters to populate the spike sorting pipeline.
In the simplest use case, you call it with just a `pipeline_parameters_name` referencing an entry in the `SpikeSortingPipelineParameters` table, plus the session and interval you want to sort. Doing so creates sort groups by shank (if not already present for the session), creates a sort interval matching the interval provided, and runs the data through the pipeline. Other parameter options are:

- __team_name :__ str
    Which team to assign the spike sorting to

- __fig_url_repo :__ str, optional
    Where to store the curation figurl json files (e.g.,
    'gh://LorenFrankLab/sorting-curations/main/user/'). Default None to
    skip figurl

- __interval_list_name :__ str
    if sort_interval_name not provided, will create a sort interval for the
    given interval with the same name

- __sort_interval_name :__ str, default None
    if provided, will use the given sort interval, requires making this
    interval yourself

- __pipeline_parameters_name :__ str, optional
    If provided, will lookup pipeline parameters from the
    SpikeSortingPipelineParameters table, supersedes other values provided,
    by default None

- __restrict_probe_type :__ dict, optional
    Restricts analysis to sort groups with matching keys. Can use keys from
    the SortGroup and ElectrodeGroup Tables (e.g. electrode_group_name,
    probe_id, target_hemisphere), by default {}

- __\{\}\_params_name :__ str, optional
    Optionally, you can pass each parameter element individually. If __pipeline_parameters_name__ is provided, entries from the table will override these passed options
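
The precedence rule described in the list above can be sketched in plain Python. This is only an illustration of the override behaviour, not the spyglass implementation: `resolve_params` and the `PIPELINE_TABLE` dictionary standing in for the table are hypothetical names, and the stand-in row holds only a subset of the real columns.

```python
# Hypothetical stand-in for one row of SpikeSortingPipelineParameters
# (values copied from the table shown earlier; only two columns kept).
PIPELINE_TABLE = {
    "ms_stim_project": {
        "preproc_params_name": "franklab_tetrode_hippocampus",
        "sorter": "mountainsort4",
    }
}


def resolve_params(pipeline_parameters_name=None, **passed):
    """Merge individually passed options with a stored parameter set.

    Entries looked up by `pipeline_parameters_name` take priority over
    anything passed individually, mirroring the rule described above.
    """
    resolved = dict(passed)
    if pipeline_parameters_name is not None:
        resolved.update(PIPELINE_TABLE[pipeline_parameters_name])
    return resolved


resolved = resolve_params(
    pipeline_parameters_name="ms_stim_project",
    sorter="some_other_sorter",  # overridden by the stored entry
    waveform_params_name="default_whitened",  # kept: absent from the stand-in row
)
print(resolved)
```

In this sketch the individually passed `sorter` is replaced by the stored value, while options missing from the stored row pass through unchanged.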

In [3]:
nwb_file_name = "SB2spikegadgets20220224_.nwb"
interval_list_name = "02_wtrackBan77mWlockout80mstheta90"

spikesorting_pipeline_populator(nwb_file_name=nwb_file_name,
                                interval_list_name=interval_list_name,
                                pipeline_parameters_name='ms_stim_project',
                                team_name='ms_stim',
                                )

Using pipeline parameters ms_stim_project
Generating sort interval from 02_wtrackBan77mWlockout80mstheta90
Generating spike sorting recording
Running artifact detection
Running spike sorting
Beginning curation
Extracting waveforms
Extracting waveforms...


extract waveforms memmap:   0%|          | 0/5 [00:00<?, ?it/s]

Writing new NWB file SB2spikegadgets20220224_0EFVDBMYXL.nwb




Calculating quality metrics
Computed all metrics: {'snr': {1: 3.5332086, 2: 3.4332783, 3: 3.7708488, 4: 3.4630616, 5: 3.4783084, 6: 3.548221, 7: 3.6467302, 8: 3.4736087, 9: 3.4898095, 10: 3.4654353, 11: 3.5187457, 12: 3.438066, 13: 3.6057005, 14: 3.472549, 15: 3.571222, 16: 3.4618917, 17: 3.5582566, 18: 3.452622, 19: 3.5286677, 20: 3.7574408, 21: 3.5070932, 22: 3.581863, 23: 3.4403317, 24: 3.4768283, 25: 3.4827776, 26: 3.5205245, 27: 3.4651222, 28: 3.599984, 29: 3.4288933, 30: 3.5165195, 31: 3.4402614}, 'isi_violation': {'1': 0.04639226878974513, '2': 0.0400116680441441, '3': 0.05676583118607012, '4': 0.040147833130822194, '5': 0.03874026580195052, '6': 0.037633715685010956, '7': 0.04128487069843209, '8': 0.04050858907546805, '9': 0.03918876604764313, '10': 0.03795885025782407, '11': 0.047917880675960374, '12': 0.03859004343468092, '13': 0.03943620828160374, '14': 0.039824439824439825, '15': 0.042233682440875116, '16': 0.03970011927074459, '17': 0.041910729691685966, '18': 0.0431040515

## Parallelization

This pipeline population is embarrassingly parallel across intervals. To speed up spike sorting on many datasets, we can make parallel calls to the pipeline populator to take advantage of the available computational resources.

Here we will use `dask` to spawn workers, each of which runs the pipeline on a different interval. First, we need to create a list containing the set of arguments for each call to the pipeline. These can come from any number of sessions, intervals, or parameter sets.

In [6]:
from spyglass.common import PositionIntervalMap, IntervalList
from spyglass.spikesorting import spikesorting_pipeline_populator
nwb_file_name = "SB2spikegadgets20220224_.nwb"

# Keep only the intervals whose names start with "0" (e.g. "01_...", "02_...")
intervals = [
    x
    for x in (IntervalList() & {"nwb_file_name": nwb_file_name}).fetch("interval_list_name")
    if x[0] == "0"
]

arguments_list = []
for interval in intervals:
    arguments_list.append(dict(nwb_file_name=nwb_file_name,
                               interval_list_name=interval,
                               pipeline_parameters_name='ms_stim_project',
                               team_name='ms_stim',
                               ))


[2023-09-13 16:07:07,792][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:07:07,849][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
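
Because the argument sets can span several sessions, one way to build the list is with `itertools.product`. The sketch below is a minimal illustration with made-up session and interval names (`session_A_.nwb`, `01_sleep`, and so on are hypothetical), and it assumes every interval exists in every session, which will not hold in general.

```python
from itertools import product

# Hypothetical session and interval names, for illustration only.
example_nwb_file_names = ["session_A_.nwb", "session_B_.nwb"]
example_interval_names = ["01_sleep", "02_run"]

# One argument dict per (session, interval) pair.
multi_session_arguments = [
    dict(
        nwb_file_name=nwb_file_name,
        interval_list_name=interval,
        pipeline_parameters_name="ms_stim_project",
        team_name="ms_stim",
    )
    for nwb_file_name, interval in product(
        example_nwb_file_names, example_interval_names
    )
]
print(len(multi_session_arguments))
```

If the intervals differ between sessions, fetching them per session (as in the cell above) and extending a single list is the safer route.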


We will also need a function for dask to call that takes each set of arguments and feeds it to the pipeline populator. We can also use this function to insulate the run from exceptions raised by the pipeline, returning the error message instead so it can be collected in a list.

In [7]:
def pass_function(arg_dict):
    try:
        spikesorting_pipeline_populator(
            nwb_file_name=arg_dict["nwb_file_name"],
            interval_list_name=arg_dict["interval_list_name"],
            team_name=arg_dict["team_name"],
            pipeline_parameters_name=arg_dict["pipeline_parameters_name"],
        )
        return
    except Exception as e:
        print(e)
        return str(e)
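
A variation on this wrapper returns the full traceback rather than only the exception message, which can make failures easier to locate. The sketch below is generic and does not call spyglass: `run_safely` and `always_fails` are hypothetical names, and `always_fails` merely stands in for a pipeline call that raises.

```python
import traceback


def run_safely(func, arg_dict):
    """Call func(**arg_dict); return None on success or the traceback text."""
    try:
        func(**arg_dict)
        return None
    except Exception:
        # format_exc() returns the traceback of the exception being handled
        return traceback.format_exc()


def always_fails(**kwargs):
    # Hypothetical stand-in for a pipeline call that raises.
    raise ValueError(f"bad interval: {kwargs['interval_list_name']}")


error_text = run_safely(always_fails, {"interval_list_name": "02_run"})
print(error_text)
```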

We can now create our dask client to create and manage workers.

`TODO:` Optimize the number of threads and workers for virga machines

In [4]:
import dask
dask.config.get("distributed.worker.daemon")
from dask.distributed import Client, progress
client = Client(threads_per_worker=4, n_workers=10)
client

False

Now we use dask to map our list of arguments onto our helper function and enjoy using our cores.

In [8]:
results = client.map(pass_function, arguments_list)

[2023-09-13 16:07:21,347][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:07:21,378][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:07:21,583][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:07:21,617][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)


Using pipeline parameters ms_stim_project
Generating sort interval from 01_sleepBan77mWnostim
Generating spike sorting recording
Using pipeline parameters ms_stim_project
Generating sort interval from 02_wtrackBan77mWlockout80mstheta90
Generating spike sorting recording
write_binary_recording with n_jobs = 8 and chunk_size = 299593


write_binary_recording: 100%|██████████| 84/84 [00:46<00:00,  1.79it/s]


Running artifact detection
(1217, 'Cannot delete or update a parent row: a foreign key constraint fails')
using 4 jobs...


detect_artifact_frames:   0%|          | 0/84 [00:00<?, ?it/s][2023-09-13 16:08:26,733][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,733][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,733][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,733][INFO]: Connecting sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,785][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,786][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,787][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
[2023-09-13 16:08:26,787][INFO]: Connected sambray@lmf-db.cin.ucsf.edu:3306
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
  @numba.jit(parallel=False)
detect_artifact_frames: 100%|██████████| 84/84 [00:24<00:00,  3.49it/s]


No artifacts detected.
Running spike sorting
Running spike sorting on {'nwb_file_name': 'SB2spikegadgets20220224_.nwb', 'sort_group_id': 0, 'sort_interval_name': '01_sleepBan77mWnostim', 'preproc_params_name': 'franklab_tetrode_hippocampus', 'team_name': 'ms_stim', 'sorter': 'mountainsort4', 'sorter_params_name': 'franklab_tetrode_hippocampus_30KHz_tmp', 'artifact_removed_interval_list_name': 'SB2spikegadgets20220224_.nwb_01_sleepBan77mWnostim_0_franklab_tetrode_hippocampus_ampl_2000_prop_75_artifact_removed_valid_times'}...
Mountainsort4 use the OLD spikeextractors mapped with NewToOldRecording
Using temporary directory /stelmo/nwb/tmp/spyglass/tmplwbttflq
Using 4 workers.
Using tempdir: /stelmo/nwb/tmp/spyglass/tmplwbttflq/tmph94bfez_
Num. workers = 4
Preparing /stelmo/nwb/tmp/spyglass/tmplwbttflq/tmph94bfez_/timeseries.hdf5...


Finally, we can check for error messages from each argument set.

In [10]:
for args,result in zip(arguments_list,client.gather(results)):
    print(args["nwb_file_name"],args["interval_list_name"])
    if result is not None:
        print(result)

[<Future: pending, key: pass_function-2a28d6b93261542cef0922db6341fe88>,
 <Future: finished, type: str, key: pass_function-d4f1bc3ecf62ea231745c3f30d6e5575>]
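
Once the results are gathered, the argument sets that failed can be collected for inspection or a later rerun. The sketch below uses made-up argument sets and results in place of `arguments_list` and `client.gather(results)`; all names and values in it are hypothetical. It relies on the convention used by the helper above, where `None` means success and a string carries the error message.

```python
# Hypothetical argument sets and gathered results, for illustration only:
# None marks success, a string is the captured error message.
example_arguments = [
    {"nwb_file_name": "session_A_.nwb", "interval_list_name": "01_sleep"},
    {"nwb_file_name": "session_A_.nwb", "interval_list_name": "02_run"},
]
example_results = [None, "example error message"]

# Keep only the argument sets whose call returned an error message.
failed_arguments = [
    args
    for args, result in zip(example_arguments, example_results)
    if result is not None
]
print(failed_arguments)
```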