# SpikeInterface Tutorial - NWB User Days Workshop - September 2020


In this tutorial, we will cover the basics of using SpikeInterface for extracellular analysis and spike sorting comparison. We will be using the `spikeinterface` package from the SpikeInterface GitHub organization.

`spikeinterface` wraps 5 subpackages: `spikeextractors`, `spikesorters`, `spiketoolkit`, `spikecomparison`, and `spikewidgets`.

For this analysis, we will be using a real dataset recorded from the CA1 region of the hippocampus (recording from [CINPLA](https://www.mn.uio.no/ibv/english/research/sections/fyscell/cinpla/)). We will show how to:

- load the data with the `spikeextractors` package
- load a probe file
- preprocess the signals
- run a popular spike sorting algorithm with different parameters
- curate the spike sorting output using 1) [Phy](https://github.com/cortex-lab/phy) (manual), 2) quality metrics (automatic), and 3) consensus-based curation
- save the results to NWB!


We recommend creating a new `spiketutorial` conda environment using:

`conda env create -f environment.yml`

In addition, if you use the conda environment, you need to install [Phy](https://github.com/cortex-lab/phy) for the manual curation step:

`pip install phy --pre --upgrade`


Alternatively, you can install the requirements listed in the `requirements.txt` file in this directory by running the command:

`pip install -r requirements.txt`

(in this case Phy should be automatically installed)


### Downloading the recording

First, we need to download the recording. Feel free to use your own recordings as well later on. 
From this Zenodo [link](https://doi.org/10.5281/zenodo.3825284), you can download the dataset mentioned above (`open-ephys-dataset.zip`). Move the dataset into the current folder and unzip it.
The recording was performed with two microdrives, each with 4 tetrodes (32 channels in total).


### Importing the modules

Let's now import the `spikeinterface` modules that we need.

In [5]:
import spikeinterface
import spikeinterface.extractors as se 
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook

## Loading recording and probe information

In [10]:
recording_folder = '../../../spike_data/open-ephys-dataset/open-ephys-dataset/'
recording = se.OpenEphysRecordingExtractor(recording_folder)

Loading Open-Ephys: reading settings...
Decoding data from  binary  format
Reading oebin file


In [11]:
se.recording_extractor_full_list

[spikeextractors.extractors.mdaextractors.mdaextractors.MdaRecordingExtractor,
 spikeextractors.extractors.mearecextractors.mearecextractors.MEArecRecordingExtractor,
 spikeextractors.extractors.biocamrecordingextractor.biocamrecordingextractor.BiocamRecordingExtractor,
 spikeextractors.extractors.exdirextractors.exdirextractors.ExdirRecordingExtractor,
 spikeextractors.extractors.openephysextractors.openephysextractors.OpenEphysRecordingExtractor,
 spikeextractors.extractors.intanrecordingextractor.intanrecordingextractor.IntanRecordingExtractor,
 spikeextractors.extractors.bindatrecordingextractor.bindatrecordingextractor.BinDatRecordingExtractor,
 spikeextractors.extractors.klustaextractors.klustaextractors.KlustaRecordingExtractor,
 spikeextractors.extractors.kilosortextractors.kilosortextractors.KiloSortRecordingExtractor,
 spikeextractors.extractors.spykingcircusextractors.spykingcircusextractors.SpykingCircusRecordingExtractor,
 spikeextractors.extractors.spikeglxrecordingextrac

In [12]:
se.installed_recording_extractor_list

[spikeextractors.extractors.mdaextractors.mdaextractors.MdaRecordingExtractor,
 spikeextractors.extractors.biocamrecordingextractor.biocamrecordingextractor.BiocamRecordingExtractor,
 spikeextractors.extractors.openephysextractors.openephysextractors.OpenEphysRecordingExtractor,
 spikeextractors.extractors.bindatrecordingextractor.bindatrecordingextractor.BinDatRecordingExtractor,
 spikeextractors.extractors.klustaextractors.klustaextractors.KlustaRecordingExtractor,
 spikeextractors.extractors.kilosortextractors.kilosortextractors.KiloSortRecordingExtractor,
 spikeextractors.extractors.spykingcircusextractors.spykingcircusextractors.SpykingCircusRecordingExtractor,
 spikeextractors.extractors.spikeglxrecordingextractor.spikeglxrecordingextractor.SpikeGLXRecordingExtractor,
 spikeextractors.extractors.phyextractors.phyextractors.PhyRecordingExtractor,
 spikeextractors.extractors.maxoneextractors.maxoneextractors.MaxOneRecordingExtractor,
 spikeextractors.extractors.mea1kextractors.mea1

A `RecordingExtractor` object extracts information about channel ids, channel locations (if present), the sampling frequency of the recording, and the extracellular traces (when prompted). The `OpenEphysRecordingExtractor` is designed specifically for Open Ephys datasets.

Here we load information from the recording using the built-in functions of the `RecordingExtractor`:

In [16]:
type(recording)

spikeextractors.extractors.openephysextractors.openephysextractors.OpenEphysRecordingExtractor

In [13]:
channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()

print(f'Channel ids: {channel_ids}')
print(f'Sampling frequency: {fs}')
print(f'Number of channels: {num_chan}')

Channel ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Sampling frequency: 30000.0
Number of channels: 32


The `get_traces()` function returns an NxT NumPy array, where N is the number of channel ids passed in (all channel ids are passed in by default) and T is the number of frames (determined by `start_frame` and `end_frame`).
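
To make the NxT convention concrete, the NumPy sketch below mimics what `get_traces()` returns by slicing a synthetic array. The array, its size, and the helper `seconds_to_frame` are hypothetical stand-ins rather than SpikeInterface code; the main point is that times in seconds must be converted to integer frame indices before slicing.

```python
import numpy as np

# Hypothetical stand-in for a recording: 32 channels of noise at 30 kHz, 3 s long.
# In the tutorial, the real array comes from recording.get_traces().
fs = 30000.0
num_channels = 32
num_frames = int(3 * fs)
rng = np.random.default_rng(seed=0)
traces = rng.normal(size=(num_channels, num_frames)).astype(np.float32)


def seconds_to_frame(t_s, sampling_frequency):
    """Convert a time in seconds to an integer frame index."""
    return int(round(t_s * sampling_frequency))


# Analogous to get_traces(start_frame=0, end_frame=2 s): rows are channels, columns are frames
start_frame = seconds_to_frame(0.0, fs)
end_frame = seconds_to_frame(2.0, fs)
snippet = traces[:, start_frame:end_frame]
print(snippet.shape)  # (32, 60000)

# Asking for fewer channels reduces N but leaves T unchanged
snippet_sub = traces[[0, 1, 2], start_frame:end_frame]
print(snippet_sub.shape)  # (3, 60000)
```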

In [20]:
for name in dir(recording):
    print(name)

__abstractmethods__
__class__
__del__
__delattr__
__dict__
__dir__
__doc__
__eq__
__format__
__ge__
__getattribute__
__gt__
__hash__
__init__
__init_subclass__
__le__
__lt__
__module__
__ne__
__new__
__reduce__
__reduce_ex__
__repr__
__setattr__
__sizeof__
__slots__
__str__
__subclasshook__
__weakref__
_abc_impl
_cast_start_end_frame
_default_filename
_dtype
_epochs
_features
_get_file_path
_key_properties
_kwargs
_memmap_files
_properties
_recording
_recording_file
_tmp_folder
add_epoch
allocate_array
check_if_dumpable
clear_channel_property
clear_channels_property
copy_channel_properties
copy_epochs
del_memmap_file
dump_to_dict
dump_to_json
dump_to_pickle
extractor_name
frame_to_time
get_channel_gains
get_channel_groups
get_channel_ids
get_channel_locations
get_channel_property
get_channel_property_names
get_dtype
get_epoch
get_epoch_info
get_epoch_names
get_num_channels
get_num_frames
get_sampling_frequency
get_shared_channel_property_names
get_snippets
get_sub_extractors_by_propert

In [14]:
# first 2 seconds of the recording (frame indices must be integers)
trace_snippet = recording.get_traces(start_frame=0, end_frame=int(2 * fs))

AttributeError: 'AnalogSignal' object has no attribute 'gain'

In [None]:
print('Traces shape:', trace_snippet.shape)

The `spikewidgets` module includes several convenient plotting functions that can be used to explore the data:

In [None]:
w_ts = sw.plot_timeseries(recording)

In [None]:
w_sp = sw.plot_spectrum(recording, channels=[0, 1, 2])

We can see that the spikes mainly appear separately on different tetrodes. Each tetrode belongs to a different `group`. We can load the `group` information in two ways:

- using the `set_channel_groups()` method of your RecordingExtractor (manually loading group information)
- loading a probe file with the `load_probe_file()` method of the RecordingExtractor (automatically loading group information)

Let's use the second option. Probe files (`.prb`) also enable users to change the channel map (reorder the channels) and add channel grouping properties and locations. In this case, our probe file orders the channels in reverse and splits them into 4 groups, representing the 4 tetrodes. It also adds locations to separate the different tetrodes.
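
As a rough idea of what such a file contains, the sketch below builds a probe description in the PRB convention, which (as far as we know) is a Python file defining a `channel_groups` dictionary whose entries list each group's `channels` and their `geometry` (x, y locations). The 16-channel layout, the spacings, and the reversed order are hypothetical values chosen for illustration; the `!cat tetrode_32.prb` cell shows the actual file used here.

```python
# Hypothetical PRB-style probe description: 16 channels arranged as 4 tetrodes,
# with the channel order reversed. The real tetrode_32.prb may differ.
num_tetrodes = 4
channels_per_tetrode = 4
tetrode_spacing_um = 100.0  # assumed distance between tetrodes
wire_offsets_um = [(0.0, 0.0), (10.0, 0.0), (0.0, 10.0), (10.0, 10.0)]  # assumed wire layout

# Reversed channel map: [15, 14, ..., 0]
channel_map = list(range(num_tetrodes * channels_per_tetrode))[::-1]

channel_groups = {}
for group in range(num_tetrodes):
    chans = channel_map[group * channels_per_tetrode:(group + 1) * channels_per_tetrode]
    x0 = group * tetrode_spacing_um
    geometry = {ch: (x0 + dx, dy) for ch, (dx, dy) in zip(chans, wire_offsets_um)}
    channel_groups[group] = {"channels": chans, "geometry": geometry}

# A .prb file is Python source that defines `channel_groups`
prb_text = f"channel_groups = {channel_groups!r}\n"
print(prb_text)

# Round trip: executing the text recreates the same dictionary
namespace = {}
exec(prb_text, namespace)
print(namespace["channel_groups"][0]["channels"])  # [15, 14, 13, 12]
```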

In [None]:
# only works on Linux and macOS. On Windows, open the file using a text editor
!cat tetrode_32.prb

In [None]:
recording_prb = recording.load_probe_file('tetrode_32.prb')

In [None]:
print(f'Original channels: {recording.get_channel_ids()}')
print(f'Channels after loading the probe file: {recording_prb.get_channel_ids()}')
print(f'Channel groups after loading the probe file: {recording_prb.get_channel_groups()}')

In [None]:
w_elec = sw.plot_electrode_geometry(recording_prb)

### Properties (and features)

So far we have seen that the `RecordingExtractor` can have `group` and `location` *properties*. These are special properties that are very important for spike sorting. Anything related to a channel can be saved as a property.

Similarly, for `SortingExtractor` objects, anything related to a unit can be stored as a property. In addition, for `SortingExtractor` objects we can also store anything related to spikes as *features* (e.g. waveforms, as we'll see later).

We can check which properties are in the extractor as follows:

In [None]:
print(recording.get_shared_channel_property_names())
print(recording_prb.get_shared_channel_property_names())

Let's add a new property! The first 16 channels are in the left hemisphere, the second 16 are in the right one:

In [None]:
for ch in recording_prb.get_channel_ids():
    if ch < 16:
        recording_prb.set_channel_property(ch, property_name='hemisphere', value='left')
    else:
        recording_prb.set_channel_property(ch, property_name='hemisphere', value='right')

In [None]:
print(recording_prb.get_shared_channel_property_names())

## Preprocessing recordings


Now that the probe information is loaded we can do some preprocessing using `spiketoolkit`.

We can filter the recordings, rereference the signals to remove noise, discard noisy channels, whiten the data, remove stimulation artifacts, etc. (more info [here](https://spiketoolkit.readthedocs.io/en/latest/preprocessing_example.html)).

For this notebook, let's filter the recordings, remove a noisy channel, and apply common median reference (CMR). All preprocessing modules return new `RecordingExtractor` objects that apply the underlying preprocessing function. This allows users to access the preprocessed data in the same way as the raw data.
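
For intuition, common median referencing can be sketched in a few lines of NumPy: at every sample, the median across channels is taken as an estimate of the noise shared by all channels and subtracted from each of them. This is an illustrative sketch on synthetic data, not the implementation behind `st.preprocessing.common_reference`.

```python
import numpy as np


def common_median_reference(traces):
    """Subtract, at each time sample, the median across channels.

    traces: array of shape (num_channels, num_frames)
    """
    reference = np.median(traces, axis=0, keepdims=True)  # shape (1, num_frames)
    return traces - reference


# Synthetic data: 4 channels of independent noise plus a large shared artifact
rng = np.random.default_rng(seed=1)
num_channels, num_frames = 4, 1000
shared = 50.0 * np.sin(np.linspace(0, 20 * np.pi, num_frames))  # identical on every channel
traces = rng.normal(scale=5.0, size=(num_channels, num_frames)) + shared

traces_cmr = common_median_reference(traces)
print(f"std before CMR: {traces.std():.1f}, after CMR: {traces_cmr.std():.1f}")
```

Using the median rather than the mean makes the reference less sensitive to a single channel carrying a large spike or artifact.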

Below, we bandpass filter the recording, select the channels of the first drive, remove a noisy channel (channel 2), and apply common median reference.

In [None]:
recording_f = st.preprocessing.bandpass_filter(recording_prb, freq_min=300, freq_max=6000)

w = sw.plot_timeseries(recording_f, color_groups=True)

We can see that the first drive is quite active, while the second one is not. For the sake of time, we can just focus on the first drive (channels 0-15). We can easily select these channels and get a new extractor using the `SubRecordingExtractor`:

In [None]:
recording_1 = se.SubRecordingExtractor(recording_f, channel_ids=range(16))

In [None]:
print(f'Sub channels: {recording_1.get_channel_ids()}')
print(f'Channel groups after SubRecordingExtractor: {recording_1.get_channel_groups()}')
w = sw.plot_timeseries(recording_1, color_groups=True)

As we can see from the previous plots, channel 2 seems to be a bit noisy. We can remove it using the `remove_bad_channels` function:

In [None]:
recording_rm_noise = st.preprocessing.remove_bad_channels(recording_1, bad_channel_ids=[2])
print(f'Channel ids after removing bad channel: {recording_rm_noise.get_channel_ids()}')
print(f'Channel groups after removing bad channel: {recording_rm_noise.get_channel_groups()}')

In [None]:
recording_cmr = st.preprocessing.common_reference(recording_rm_noise, reference='median')

We can plot the traces after removing the bad channel and applying CMR:

In [None]:
w = sw.plot_timeseries(recording_cmr, color_groups=True)

In [None]:
print(f'Channel ids for CMR recordings: {recording_cmr.get_channel_ids()}')
print(f'Channel groups for CMR recordings: {recording_cmr.get_channel_groups()}')

Since we are going to spike sort the data, let's first cut out a 2-minute segment (from 200 s to 320 s) to speed up computations.

In [None]:
fs = recording_cmr.get_sampling_frequency()
# keep frames from 200 s to 320 s (frame indices must be integers)
recording_sub = se.SubRecordingExtractor(recording_cmr, start_frame=int(200 * fs), end_frame=int(320 * fs))

## Caching 

All operations in SpikeInterface are *lazy*, meaning that they are not performed until they are needed. This is why the creation of our filtered recording was almost instantaneous. However, to speed up further processing, we might want to **cache** it to a file and perform those operations (e.g. filtering, CMR, etc.) at once. This is particularly important if we are going to extract waveforms, templates, PCA scores, or in general *post-process* the results.
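
The idea of lazy preprocessing followed by caching can be illustrated with a toy model in plain Python. The classes `LazyRecording` and `CachedRecording` below are hypothetical and much simpler than SpikeInterface's extractors (for instance, the toy cache keeps results in memory rather than in a binary file); they only show that chained steps do no work until traces are requested, and that a cache runs the chain once.

```python
class LazyRecording:
    """Toy lazy recording: remembers its parent and a transform, computes only on request."""

    def __init__(self, parent, transform=None):
        self.parent = parent        # another LazyRecording, or raw samples (a list)
        self.transform = transform  # function applied to each sample, or None
        self.num_evaluations = 0    # how many times this step actually ran

    def get_traces(self):
        if isinstance(self.parent, LazyRecording):
            data = self.parent.get_traces()  # pull data through the whole chain
        else:
            data = self.parent
        if self.transform is None:
            return list(data)
        self.num_evaluations += 1
        return [self.transform(x) for x in data]


class CachedRecording(LazyRecording):
    """Toy cache: runs the whole chain once, eagerly, and keeps the result."""

    def __init__(self, parent):
        super().__init__(parent)
        self._cache = parent.get_traces()

    def get_traces(self):
        return self._cache


raw = LazyRecording([1.0, -2.0, 3.0])
scaled = LazyRecording(raw, transform=lambda x: 2 * x)      # stands in for a filter
shifted = LazyRecording(scaled, transform=lambda x: x - 1)  # stands in for a re-reference
print(scaled.num_evaluations)  # 0: building the chain computed nothing

cached = CachedRecording(shifted)  # the chain runs once here
for _ in range(3):
    cached.get_traces()            # later reads come from the cache
print(scaled.num_evaluations)      # 1
print(cached.get_traces())         # [1.0, -5.0, 5.0]
```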

In [None]:
recording_cache = se.CacheRecordingExtractor(recording_sub) 

The cached recording has all the previously loaded information:

In [None]:
print(f'Cached channel ids: {recording_cache.get_channel_ids()}')
print(f'Channel groups after caching: {recording_cache.get_channel_groups()}')

Under the hood, this convenient function retrieves all the traces (in chunks, to save memory), applies the preprocessing steps, and dumps them to a temporary binary file:

In [None]:
recording_cache.filename
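
The chunked caching described above can be sketched with NumPy: traces are requested chunk by chunk, written to a raw binary file, and then memory-mapped so that reading them back does not load the whole file into RAM. The function name `cache_to_binary`, the frames-by-channels layout, and the chunk size are assumptions for illustration, not a description of the internals of `CacheRecordingExtractor`.

```python
import os
import tempfile

import numpy as np


def cache_to_binary(get_chunk, num_channels, num_frames, path, chunk_size=10000, dtype="float32"):
    """Write traces to a raw binary file chunk by chunk and return a read-only memmap.

    get_chunk(start, end) must return an array of shape (num_channels, end - start).
    The file is laid out as frames x channels (all channels of one frame are contiguous).
    """
    with open(path, "wb") as f:
        for start in range(0, num_frames, chunk_size):
            end = min(start + chunk_size, num_frames)
            chunk = np.asarray(get_chunk(start, end), dtype=dtype)
            chunk.T.tofile(f)  # tofile always writes in C order, here frames x channels
    return np.memmap(path, dtype=dtype, mode="r", shape=(num_frames, num_channels))


# Hypothetical "preprocessed" traces: 4 channels, 25000 frames
rng = np.random.default_rng(seed=2)
source = rng.normal(size=(4, 25000))
path = os.path.join(tempfile.mkdtemp(), "cached_traces.dat")

cached = cache_to_binary(lambda s, e: source[:, s:e], num_channels=4, num_frames=25000, path=path)
print(cached.shape, os.path.getsize(path))  # (25000, 4) 400000
```

In a real pipeline, `get_chunk` would apply the preprocessing chain to each chunk, which is what keeps peak memory usage bounded by the chunk size.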

Note that all extractors (including sorting extractors) have an associated temporary folder, which enables SpikeInterface to cache several kinds of data (including waveforms) and be gentle on RAM usage:

In [None]:
recording_cache.get_tmp_folder()

The files in the tmp folder are *temporary*: they will be deleted when the Python session is closed (or the object is destroyed). To prevent this, we can simply move the binary file to a custom location:

In [None]:
recording_cache.move_to('filtered_data.dat') 
print(recording_cache.filename)

Alternatively, we could have passed the `save_path` argument to the `se.CacheRecordingExtractor` directly.

## Dumping

If we now closed the Python session, we would have a nice `.dat` file, but no information on how to open it! 
In order to save the state of an extractor, we can use the **dumping** mechanism.
Each extractor can be converted to a dictionary, which holds the path to the data file and all relevant information:

In [None]:
recording_cache.dump_to_dict()

We can now dump our extractor object, so it can be loaded in a future session. We can dump either to `.json` or to `.pkl`. Dumping to pickle also allows us to store properties (other than group and locations) and features (for `SortingExtractor` objects).

In [None]:
recording_cache.dump_to_pickle('recording.pkl')

In another session, we can pick up where we left off by loading the extractor from the pickle file:

In [None]:
recording_loaded = se.load_extractor_from_pickle('recording.pkl')

In [None]:
w = sw.plot_timeseries(recording_loaded, color_groups=True)

We can double check that the traces are exactly the same as those of `recording_sub`, the recording we cached and dumped:

In [None]:
w = sw.plot_timeseries(recording_sub, color_groups=True)

**IMPORTANT**: the same caching/dumping mechanisms are also available for all `SortingExtractor` objects.

# Spike sorting

We can now run spike sorting on the above recording. We will use `klusta` and `ironclust` for this demonstration, to show how easy SpikeInterface makes it to interchangeably run different sorters :)

Let's first check the installed sorters in `spikesorters` to see if klusta is available. Then we can check the `klusta` default parameters.
We will sort the cached, preprocessed recording (the `recording_cache` object).

In [None]:
ss.installed_sorters()

We can retrieve the parameters associated with any sorter with the `get_default_params()` function from the `spikesorters` module:

In [None]:
ss.get_default_params('klusta')

In [None]:
ss.get_params_description('klusta')

In [None]:
ss.run_sorter?

In [None]:
ss.run_klusta?

We will set the `adjacency_radius` to 50 microns as electrodes belonging to the same tetrode are within this distance.
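
Geometrically, an adjacency radius defines which channels are considered neighbors. The plain-Python sketch below lists, for each channel, the channels within a given distance; with a 50 micron radius, each tetrode forms its own neighborhood. The channel locations are hypothetical, and this helper is not how klusta or SpikeInterface compute adjacency internally.

```python
import math

# Hypothetical channel locations (x, y) in microns:
# two tetrodes whose wires are 10 um apart, with the tetrodes 100 um apart
locations = {
    0: (0.0, 0.0), 1: (10.0, 0.0), 2: (0.0, 10.0), 3: (10.0, 10.0),
    4: (100.0, 0.0), 5: (110.0, 0.0), 6: (100.0, 10.0), 7: (110.0, 10.0),
}


def adjacency(locations, radius):
    """For each channel, list the channels within `radius` (the channel itself included)."""
    return {
        ch: sorted(other for other, other_loc in locations.items()
                   if math.dist(loc, other_loc) <= radius)
        for ch, loc in locations.items()
    }


adj = adjacency(locations, radius=50.0)
print(adj[0])  # [0, 1, 2, 3]
print(adj[5])  # [4, 5, 6, 7]
```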

In [None]:
# run spike sorting on entire recording
sorting_KL_all = ss.run_klusta(recording_cache, output_folder='results_all_klusta', adjacency_radius=50, verbose=True)
print('Found', len(sorting_KL_all.get_unit_ids()), 'units')

SpikeInterface ensures full provenance of the spike sorting pipeline. Upon running a spike sorter, a `spikeinterface_params.json` file is saved in the `output_folder`. This contains a `.json` version of the recording and all the input parameters. 

### Spike sorting by group

Since we have 4 tetrodes and we know that they are physically apart, we would like to sort them separately.

Here is how it's done in SpikeInterface:

![](sort_by_group.png)
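
Conceptually, sorting by group splits the channels according to their `group` property, runs the sorter on each subset independently, and combines the per-group results into a single output. The plain-Python sketch below shows that split-and-merge pattern with a hypothetical `fake_sorter` that returns made-up spike trains; renumbering units so their ids stay unique, and remembering each unit's group, is our assumption about how such a merge is typically done, not a description of SpikeInterface's internals.

```python
from collections import defaultdict


def sort_by_group(channel_groups, sort_fn):
    """Run `sort_fn` on each channel group separately and merge the results.

    channel_groups: dict mapping channel id -> group id
    sort_fn: callable taking a list of channel ids and returning {unit_id: spike_train}
    """
    groups = defaultdict(list)
    for ch, group in channel_groups.items():
        groups[group].append(ch)

    merged, unit_group = {}, {}
    next_unit_id = 0
    for group in sorted(groups):
        for _, spike_train in sorted(sort_fn(groups[group]).items()):
            merged[next_unit_id] = spike_train  # renumber so unit ids are unique overall
            unit_group[next_unit_id] = group    # remember which group the unit came from
            next_unit_id += 1
    return merged, unit_group


def fake_sorter(channel_ids):
    """Hypothetical sorter: one unit per channel, with made-up spike frames."""
    return {i: [10 * ch, 10 * ch + 5] for i, ch in enumerate(channel_ids)}


channel_groups = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1}
units, unit_group = sort_by_group(channel_groups, fake_sorter)
print(units)       # {0: [0, 5], 1: [10, 15], 2: [20, 25], 3: [30, 35], 4: [40, 45]}
print(unit_group)  # {0: 0, 1: 0, 2: 1, 3: 1, 4: 1}
```

Because each group is processed independently, the per-group runs can also be executed in parallel, which is what the `parallel=True` argument below requests.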

In [None]:
# run spike sorting by group 
sorting_KL = ss.run_klusta(recording_cache, adjacency_radius=50,
                           output_folder='results_split_klusta',
                           grouping_property='group', parallel=True)
print(f'Klusta found {len(sorting_KL.get_unit_ids())} units')

In [None]:
print(type(sorting_KL))

In [None]:
sorting_KL.sortings

### Installing IronClust (requires MATLAB)

For MATLAB-based sorters, all you need to do is clone the sorter repo and point SpikeInterface to it.

Let's clone ironclust into the current directory:

In [None]:
!git clone https://github.com/flatironinstitute/ironclust

Now all we have to do is tell the `IronClustSorter` class where the ironclust repo is:

In [None]:
ss.IronClustSorter.set_ironclust_path('./ironclust')

Note that we can also set a global environment variable called `IRONCLUST_PATH`. In that case we don't need to set the path in each session because the sorter class looks for this environment variable.

Now IronClust should be installed. Let's check the path and the list of installed sorters before running it:

In [None]:
ss.IronClustSorter.ironclust_path

In [None]:
!echo $IRONCLUST_PATH

In [None]:
ss.installed_sorters()

In [None]:
# run spike sorting by group
sorting_IC = ss.run_ironclust(recording_cache, 
                              output_folder='results_split_ic', 
                              grouping_property='group', parallel=True, verbose=True)
print(f'IronClust found {len(sorting_IC.get_unit_ids())} units')

The spike sorting returns a `SortingExtractor` object. Let's see some of its functions:

In [None]:
print(f'Klusta unit ids: {sorting_KL.get_unit_ids()}')

In [None]:
print(f'Spike train of a unit: {sorting_KL.get_unit_spike_train(13)}')

We can use `spikewidgets` functions to quickly visualize some unit features:

In [None]:
w_rs = sw.plot_rasters(sorting_IC, trange=[0,10])

Later in this tutorial, we will perform automatic curation by thresholding low-SNR units on the split sorting result.

### Loading a spike sorting output from a spike sorting folder

If a spike sorter has been run, you can reload the output as a `SortingExtractor` using the corresponding `spikeextractors` class. Note that if sorting by group/property, single groups must be loaded separately:

In [None]:
sorting_KL_0 = se.KlustaSortingExtractor("results_split_klusta/0")

In [None]:
print(f'Klusta unit ids group 0: {sorting_KL_0.get_unit_ids()}')

## Postprocessing

The `postprocessing` submodule of `spiketoolkit` allows us to extract information from the combination of the recording and sorting extractors. For example, we can extract waveforms, templates, maximum channels, and PCA scores. In addition, we can compute waveform features that can be used for further processing, e.g. classifying neurons as excitatory or inhibitory.

To extract the waveforms, we can run:

In [None]:
waveforms = st.postprocessing.get_unit_waveforms(recording_cache, sorting_IC, verbose=True)

In [None]:
waveforms[0].shape

Similarly, we can get templates, maximum channels, and PCA scores.
When these are computed, they are automatically stored in the `SortingExtractor` object, so that they don't need to be recomputed.

Each waveform is associated with a specific spike, so waveforms are saved as spike *features*.


You may have noticed that 300 waveforms were extracted from the spike train of the first unit. However, it has more spikes:

In [None]:
len(sorting_IC.get_unit_spike_train(0))

It can be convenient to compute only a subset of waveforms to speed up the calculation. The `waveforms_idxs` spike feature contains the indexes of the spikes associated with the extracted waveforms.
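
Extracting waveforms from a random subset of spikes can be sketched in NumPy as below. The window lengths (30 samples before and 60 after, i.e. 1 ms and 2 ms at 30 kHz), the random selection, and the function name `extract_waveforms` are assumptions for illustration; only the cap of 300 spikes per unit mirrors the number observed above, and spiketoolkit's actual implementation may differ.

```python
import numpy as np


def extract_waveforms(traces, spike_frames, n_before=30, n_after=60,
                      max_spikes_per_unit=300, seed=0):
    """Cut fixed-length snippets around a random subset of spikes.

    traces: (num_channels, num_frames); spike_frames: sorted 1D array of spike frames.
    Returns (waveforms, waveform_idxs): waveforms has shape
    (num_selected, num_channels, n_before + n_after) and waveform_idxs are the
    positions (in spike_frames) of the selected spikes.
    """
    num_frames = traces.shape[1]
    # keep only spikes whose window lies entirely inside the recording
    valid = np.flatnonzero((spike_frames >= n_before) & (spike_frames + n_after <= num_frames))
    if max_spikes_per_unit is not None and len(valid) > max_spikes_per_unit:
        rng = np.random.default_rng(seed)
        valid = np.sort(rng.choice(valid, size=max_spikes_per_unit, replace=False))
    waveforms = np.stack([traces[:, f - n_before:f + n_after] for f in spike_frames[valid]])
    return waveforms, valid


# Hypothetical data: 4 channels, 10 s at 30 kHz, one spike every 250 frames
traces = np.zeros((4, 300000), dtype=np.float32)
spike_frames = np.arange(1000, 290000, 250)
waveforms, waveform_idxs = extract_waveforms(traces, spike_frames)
print(len(spike_frames), waveforms.shape)  # 1156 (300, 4, 90)
```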

In [None]:
sorting_IC.get_shared_unit_spike_feature_names()

In [None]:
print(sorting_IC.get_unit_spike_features(0, 'waveforms_idxs'))

Since the waveforms are already computed, the next time we (or another function, e.g. `get_unit_templates()`) call `get_unit_waveforms()`, it will simply return the stored waveforms.

In [None]:
waveforms = st.postprocessing.get_unit_waveforms(recording_cache, sorting_IC, verbose=True)

Lightning fast!

If we want to recompute the waveforms, for example to extract them separately for each group, we can use the `recompute_info` argument (available for all `postprocessing`, `validation`, and `curation` functions), as we do a couple of cells below.

Where are waveforms stored? We have seen above that each `Extractor` object has an associated tmp folder. Waveforms (and other features, e.g. PCA scores) are stored in this folder as raw binary files:

In [None]:
tmp_folder = sorting_IC.get_tmp_folder()
print(tmp_folder)
print([(p.name) for p in tmp_folder.iterdir()])

In [None]:
waveforms_group = st.postprocessing.get_unit_waveforms(recording_cache, sorting_IC, max_spikes_per_unit=None, 
                                                       grouping_property='group', recompute_info=True,
                                                       verbose=True)

In [None]:
sorting_IC.get_shared_unit_property_names()

In [None]:
for wf in waveforms_group:
    print(wf.shape)

We can use `spikewidgets` to quickly inspect the spike sorting output:

In [None]:
w_wf = sw.plot_unit_templates(sorting=sorting_IC, recording=recording_cache)

In [None]:
w_acc = sw.plot_autocorrelograms(sorting_IC, unit_ids=[0,1,2,3])

### Compute extracellular features

Extracellular features, such as peak-to-valley duration or full width at half maximum, are important for classifying neurons into putative classes (e.g. excitatory vs. inhibitory). The `postprocessing` module of `spiketoolkit` can compute several of these features:

In [None]:
st.postprocessing.get_template_features_list()

In [None]:
features = st.postprocessing.compute_unit_template_features(recording_cache, sorting_IC, as_dataframe=True, 
                                                            upsampling_factor=10)
display(features)

For more information about these waveform features, we refer to this [documentation](https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms) from the Allen Institute.
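
To give an idea of what two of these features measure, here is a NumPy sketch computing a peak-to-valley duration (time from the trough to the following peak) and a peak-to-trough amplitude ratio on a synthetic template. These definitions follow our reading of the Allen Institute documentation linked above; the exact definitions, upsampling, and channel selection used by `compute_unit_template_features` may differ.

```python
import numpy as np


def peak_to_valley_ms(template, fs):
    """Time (ms) from the trough (minimum) to the following peak (maximum)."""
    trough_idx = int(np.argmin(template))
    peak_idx = trough_idx + int(np.argmax(template[trough_idx:]))
    return (peak_idx - trough_idx) / fs * 1000.0


def peak_trough_ratio(template):
    """Ratio of the peak (max) amplitude to the trough (min) amplitude."""
    return float(np.max(template) / np.min(template))


# Hypothetical single-channel template at 30 kHz: a trough at sample 30, a peak at sample 45
fs = 30000.0
t = np.arange(90)
template = -100.0 * np.exp(-((t - 30) / 3.0) ** 2) + 30.0 * np.exp(-((t - 45) / 6.0) ** 2)

print(f"peak-to-valley: {peak_to_valley_ms(template, fs):.2f} ms")  # 0.50 ms
print(f"peak/trough ratio: {peak_trough_ratio(template):.2f}")      # -0.30
```

At 30 kHz one sample lasts about 0.033 ms, so sample-level estimates like this one are coarse; this is presumably why the call above upsamples the templates (`upsampling_factor=10`).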

## Validation

The `spiketoolkit` package also provides several functions to compute quality metrics about the spike sorting results through the `validation` module.

Let's see what metrics are available:



In [None]:
st.validation.get_quality_metrics_list()

We can either compute one metric at a time, or compute a subset of metrics using the `compute_quality_metrics` function:

In [None]:
# sorting_IC was computed on recording_cache, so use its duration (in frames)
duration = recording_cache.get_num_frames()
isi_violations = st.validation.compute_isi_violations(sorting_IC, duration_in_frames=duration)
print('ISI violations:', isi_violations)

snrs = st.validation.compute_snrs(sorting_IC, recording_cache)
print('SNRs:', snrs)

In [None]:
quality_metrics = st.validation.compute_quality_metrics(sorting_IC, recording_cache, 
                                                        metric_names=['firing_rate', 'isi_violation', 'snr'], 
                                                        as_dataframe=True)
display(quality_metrics)

For more information about these quality metrics, we refer to this [documentation](https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html) from the Allen Institute.
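
Two of the metrics above can be sketched in NumPy to show what they measure. The ISI-violation estimate below follows our reading of the formula in the Allen Institute documentation linked above (the number of refractory-period violations, normalized by the spike count and the firing rate); the 1.5 ms threshold and the synthetic spike trains are assumptions, and `st.validation.compute_isi_violations` may differ in details such as units and handling of duplicate spikes.

```python
import numpy as np


def firing_rate(spike_times_s, duration_s):
    """Mean firing rate in Hz."""
    return len(spike_times_s) / duration_s


def isi_violation_rate(spike_times_s, duration_s, isi_threshold_s=0.0015, min_isi_s=0.0):
    """Estimated contamination from refractory-period violations.

    (num_violations / (2 * N * (isi_threshold - min_isi))) / firing_rate,
    i.e. num_violations * duration / (2 * N**2 * (isi_threshold - min_isi)).
    """
    spike_times_s = np.sort(np.asarray(spike_times_s, dtype=float))
    num_spikes = len(spike_times_s)
    num_violations = int(np.sum(np.diff(spike_times_s) < isi_threshold_s))
    violation_time = 2 * num_spikes * (isi_threshold_s - min_isi_s)
    return (num_violations / violation_time) / firing_rate(spike_times_s, duration_s)


# Hypothetical unit: a perfectly regular 10 Hz train over 100 s
duration_s = 100.0
clean = np.arange(1000) * 0.1
# The same unit plus 5 extra spikes, each 1 ms after an existing one (refractory violations)
contaminated = np.concatenate([clean, clean[:5] + 0.001])

print(firing_rate(clean, duration_s))                # 10.0
print(isi_violation_rate(clean, duration_s))         # 0.0
print(isi_violation_rate(contaminated, duration_s))  # ~0.165
```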

## Curation

### 1) Manual curation using Phy

To perform manual curation we will export the data to [Phy](https://github.com/cortex-lab/phy). 

In [None]:
st.postprocessing.export_to_phy(recording_cache, 
                                sorting_IC, output_folder='phy_IC',
                                grouping_property='group', verbose=True, recompute_info=True)

In [None]:
%%capture --no-display
!phy template-gui phy_IC/params.py

After curating the results, we can reload them using the `PhySortingExtractor`:

In [None]:
sorting_IC_phy_curated = se.PhySortingExtractor('phy_IC/', exclude_cluster_groups=['noise'])

In [None]:
print(len(sorting_IC_phy_curated.get_unit_ids()))
print(f"Unit ids after manual curation: {sorting_IC_phy_curated.get_unit_ids()}")

In [None]:
# We can do the same for the klusta output.
# st.postprocessing.export_to_phy(recording_cache, 
#                                 sorting_KL, output_folder='phy_KL',
#                                 grouping_property='group', verbose=True, recompute_info=True)

In [None]:
# %%capture --no-display
# !phy template-gui phy_KL/params.py

### 2) Automatic curation based on quality metrics

In [None]:
snr_thresh = 5
isi_viol_thresh = 0.5

In [None]:
# remove units whose ISI violation rate is greater than isi_viol_thresh
sorting_auto = st.curation.threshold_isi_violations(sorting_KL, isi_viol_thresh, 'greater', duration)

In [None]:
len(sorting_auto.get_unit_ids())

In [None]:
# remove units whose SNR is less than snr_thresh
sorting_auto = st.curation.threshold_snrs(sorting_auto, recording_cache, snr_thresh, 'less')

In [None]:
len(sorting_auto.get_unit_ids())
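
The two thresholding steps above can be mimicked on plain dictionaries to show how the `'greater'`/`'less'` signs are read in this tutorial: units whose metric is greater (or less) than the threshold are discarded, and the second step is applied to the output of the first. The helper `threshold_units` and the metric values are hypothetical; spiketoolkit's curation functions operate on extractors and compute the metrics themselves.

```python
def threshold_units(metrics, threshold, threshold_sign):
    """Return the ids of the units that survive a metric threshold.

    metrics: dict mapping unit id -> metric value
    threshold_sign: 'greater' discards units whose metric is > threshold,
                    'less' discards units whose metric is < threshold
    """
    if threshold_sign == "greater":
        return [unit for unit, value in metrics.items() if not value > threshold]
    if threshold_sign == "less":
        return [unit for unit, value in metrics.items() if not value < threshold]
    raise ValueError("threshold_sign must be 'greater' or 'less'")


# Hypothetical per-unit metrics
isi_violations = {0: 0.1, 1: 0.8, 2: 0.0, 3: 2.3}
snrs = {0: 7.2, 1: 9.0, 2: 3.1, 3: 12.0}

kept = threshold_units(isi_violations, 0.5, "greater")             # [0, 2]
kept = [u for u in threshold_units(snrs, 5, "less") if u in kept]  # chain the second step
print(kept)  # [0]
```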

### 3) Consensus-based curation 

Can we combine the output of multiple sorters to curate the spike sorting output?

To answer this question we can use the `comparison` module.
We first compare and match the output spike trains of the different sorters, and we can then extract a new `SortingExtractor` with only the units in agreement.
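
The matching step can be illustrated for a single pair of units. A common way to quantify agreement, and to our understanding the one used by SpikeInterface's comparison tools, is the number of matched spikes divided by the number of spikes in the union of the two trains, n_match / (n_a + n_b - n_match). The greedy two-pointer matching and the 0.4 ms tolerance below are simplifying assumptions; the actual matching of spikes and units in the `comparison` module is more elaborate.

```python
def count_matches(train_a, train_b, delta):
    """Greedily pair spikes of two sorted trains that are within +/- delta frames."""
    i = j = matches = 0
    while i < len(train_a) and j < len(train_b):
        diff = train_a[i] - train_b[j]
        if abs(diff) <= delta:
            matches += 1  # pair these two spikes and move on in both trains
            i += 1
            j += 1
        elif diff < 0:
            i += 1        # spike in train_a is too early to match: skip it
        else:
            j += 1        # spike in train_b is too early to match: skip it
    return matches


def agreement_score(train_a, train_b, delta):
    """Matched spikes divided by the number of spikes in the union of both trains."""
    n_match = count_matches(sorted(train_a), sorted(train_b), delta)
    return n_match / (len(train_a) + len(train_b) - n_match)


# Hypothetical spike frames from two sorters; delta = 12 frames (0.4 ms at 30 kHz, assumed)
unit_kl = [100, 500, 900, 1300, 1700]
unit_ic = [105, 495, 910, 2000]
print(count_matches(unit_kl, unit_ic, 12))    # 3
print(agreement_score(unit_kl, unit_ic, 12))  # 0.5
```

Units from different sorters whose agreement score is high enough are considered the same neuron; `minimum_agreement_count=2` below then keeps only units found by both sorters.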

In [None]:
mcmp = sc.compare_multiple_sorters([sorting_KL, sorting_IC], ['KL', 'IC'], spiketrain_mode='union',
                                   verbose=True)

In [None]:
w = sw.plot_multicomp_agreement(mcmp)
w = sw.plot_multicomp_agreement_by_sorter(mcmp)

In [None]:
agreement_sorting = mcmp.get_agreement_sorting(minimum_agreement_count=2)

In [None]:
agreement_sorting.get_unit_ids()

In [None]:
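# use the manually curated sorting as the reference ("ground truth") and compare the agreement sorting to it
cmp_manual_agr = sc.compare_sorter_to_ground_truth(sorting_IC_phy_curated, agreement_sorting)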

In [None]:
w_agr = sw.plot_agreement_matrix(cmp_manual_agr)

In [None]:
st.postprocessing.export_to_phy(recording_cache, 
                                agreement_sorting, output_folder='phy_AGR',
                                grouping_property='group', verbose=True, recompute_info=True)

In [None]:
%%capture --no-display
!phy template-gui phy_AGR/params.py

## Save to / load from NWB

In [None]:
metadata = {'Ecephys': {'Device': [{'name': 'open-ephys',
                                    'description': 'Open Ephys acquisition board'}]}}

In [None]:
se.NwbRecordingExtractor.write_recording(recording_cache, 'si_tutorial.nwb', metadata=metadata)

In [None]:
se.NwbSortingExtractor.write_sorting(sorting_IC, 'si_tutorial.nwb')

In [None]:
recording_nwb = se.NwbRecordingExtractor('si_tutorial.nwb')
sorting_nwb = se.NwbSortingExtractor('si_tutorial.nwb')