### Load modules

In [75]:
import spikeinterface as si
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.widgets as sw

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

import warnings
warnings.simplefilter("ignore")

%matplotlib widget

### Define paths

In [None]:
# Root folder of this recording session; most input/output folders below are built from it.
# NOTE: machine-specific absolute path - edit this single line when running elsewhere.
base_folder = Path(r"D:\Vincent\Data\sc011\sc011_0106\sc011_0106_001")
# Derived from base_folder instead of repeating the full literal, so the two cannot drift apart.
# Kept as a plain string, as before.
file_path = str(base_folder / "Record Node 101")

In [None]:
# Reload the recording stored in the "preprocessed" folder and show its summary.
preprocessed_folder = base_folder / "preprocessed"
recording_saved = si.load_extractor(preprocessed_folder)
print(recording_saved)

# Read the sorting output from the 'results_KS2_5' folder and show its summary.
kilosort_folder = base_folder / 'results_KS2_5'
sorting = se.KiloSortSortingExtractor(kilosort_folder)
print(sorting)

A `RecordingExtractor` object extracts information about channel ids, channel locations (if present), the sampling frequency of the recording, and the extracellular traces (when prompted).
Here we retrieve information from the recording using the built-in functions from the RecordingExtractor:

### Extract waveforms

In [None]:
# si.extract_waveforms?

In [None]:
# Shared job options, forwarded with **job_kwargs to the extraction / postprocessing / export calls.
job_kwargs = {
    "n_jobs": 10,
    "chunk_duration": "1s",
    "progress_bar": True,
}

Extracts waveforms from a random subset of 500 spikes (`WaveformExtractor` default).

In [33]:
# Waveforms for `we` live in base_folder/waveforms.
# NOTE(review): with load_if_exists=True and overwrite=False an existing folder is presumably
# reused instead of recomputed - confirm that the folder on disk matches the current sorting.
waveforms_folder = base_folder / "waveforms"
we = si.extract_waveforms(
    recording_saved,
    sorting,
    folder=waveforms_folder,
    load_if_exists=True,
    overwrite=False,
    **job_kwargs,
)
print(we)

WaveformExtractor: 384 channels - 252 units - 1 segments
  before:90 after:120 n_per_units:500


In [None]:
# w = sw.plot_unit_templates(we, radius_um=30, backend="ipywidgets")

In [None]:
# for unit in sorting.get_unit_ids():
#     waveforms = we.get_waveforms(unit_id=unit)
#     spiketrain = sorting.get_unit_spike_train(unit)
#     print(f"Unit {unit} - num waveforms: {waveforms.shape[0]} - num spikes: {len(spiketrain)}")

Extract waveforms for all spikes (set the `max_spikes_per_unit` argument to `None`).

In [11]:
# Same extraction as above but with max_spikes_per_unit=None, written to its own folder.
# overwrite=True is passed, so this cell is expected to redo the extraction on every run.
waveforms_all_folder = base_folder / "waveforms_all"
we_all = si.extract_waveforms(
    recording_saved,
    sorting,
    folder=waveforms_all_folder,
    max_spikes_per_unit=None,
    overwrite=True,
    **job_kwargs,
)

extract waveforms memmap:   0%|          | 0/713 [00:00<?, ?it/s]

In [12]:
# Sanity check: for every unit, compare the number of extracted waveforms with the number of spikes.
for unit in sorting.get_unit_ids():
    waveforms = we_all.get_waveforms(unit_id=unit)
    spiketrain = sorting.get_unit_spike_train(unit)
    num_waveforms, num_spikes = waveforms.shape[0], len(spiketrain)
    print(f"Unit {unit} - num waveforms: {num_waveforms} - num spikes: {num_spikes}")

Unit 0 - num waveforms: 100 - num spikes: 100
Unit 1 - num waveforms: 4120 - num spikes: 4120
Unit 2 - num waveforms: 2762 - num spikes: 2762
Unit 3 - num waveforms: 586 - num spikes: 586
Unit 4 - num waveforms: 2205 - num spikes: 2205
Unit 5 - num waveforms: 978 - num spikes: 978
Unit 6 - num waveforms: 51 - num spikes: 51
Unit 7 - num waveforms: 57731 - num spikes: 57731
Unit 8 - num waveforms: 1028 - num spikes: 1028
Unit 9 - num waveforms: 15530 - num spikes: 15530
Unit 10 - num waveforms: 87 - num spikes: 87
Unit 11 - num waveforms: 31 - num spikes: 31
Unit 12 - num waveforms: 1492 - num spikes: 1492
Unit 13 - num waveforms: 13648 - num spikes: 13648
Unit 14 - num waveforms: 618 - num spikes: 618
Unit 15 - num waveforms: 12156 - num spikes: 12156
Unit 16 - num waveforms: 1437 - num spikes: 1437
Unit 17 - num waveforms: 63 - num spikes: 63
Unit 18 - num waveforms: 6146 - num spikes: 6146
Unit 19 - num waveforms: 808 - num spikes: 808
Unit 20 - num waveforms: 1132 - num spikes: 1132

In [78]:
# Build a second WaveformExtractor step by step: create -> set_params -> run_extract_waveforms.
# It gets a dedicated folder under base_folder: a bare relative 'waveforms' path depends on the
# current working directory and, combined with remove_if_exists=True, could delete the folder
# that backs `we` (base_folder / "waveforms") when the notebook is run from base_folder.
we_KS_folder = base_folder / "waveforms_KS"
we_KS = si.WaveformExtractor.create(recording_saved, sorting, we_KS_folder, remove_if_exists=True)
we_KS.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we_KS.run_extract_waveforms(n_jobs=-1, chunk_size=30000)
print(we_KS)

extract waveforms memmap:   0%|          | 0/713 [00:00<?, ?it/s]

WaveformExtractor: 384 channels - 255 units - 1 segments
  before:90 after:120 n_per_units:500


### Postprocessing

#### Sparsity

Especially when working with silicon high-density probes, or when our probe has multiple groups (e.g. multi-shank, tetrodes), we don't care about waveforms/templates on all channels. To find a subset of channels for each unit, we can use the `compute_sparsity()` function (used below in place of `spost.get_template_channel_sparsity()`):

In [None]:
# spost.get_template_channel_sparsity?

In [39]:
from spikeinterface import compute_sparsity

In [42]:
# Example: radius-based sparsity (radius_um=50), computed from `we`.
# Previously written as: spost.get_template_channel_sparsity(we, method="radius", radius_um=50)
sparsity_radius = si.compute_sparsity(
    we,
    method="radius",
    radius_um=50,
)
print(sparsity_radius)

ChannelSparsity - units: 252 - channels: 384 - ratio: 0.02


In [43]:
# Example: "best_channels" sparsity (num_channels=4), computed from `we`.
# Previously written as: spost.get_template_channel_sparsity(we, method="best_channels", num_channels=4)
sparsity_best = si.compute_sparsity(
    we,
    method="best_channels",
    num_channels=4,
)
print(sparsity_best)

ChannelSparsity - units: 252 - channels: 384 - ratio: 0.01


Most of the plotting and exporting functions accept `sparsity` as an argument. 

In [None]:
# sw.plot_unit_templates(we, sparsity=sparsity_radius, backend="ipywidgets")

In [None]:
# sw.plot_unit_templates(we, sparsity=sparsity_best, backend="ipywidgets")

### PCA scores

PCA scores can be easily computed with the `compute_principal_components()` function. Similarly to `extract_waveforms`, the function returns an object of type `WaveformPrincipalComponent` that allows retrieving all PC scores on demand.

In [None]:
# spost.compute_principal_components?

In [44]:
# Show the sparsity again; it is the one passed to compute_principal_components in the next cell.
print(sparsity_radius)

ChannelSparsity - units: 252 - channels: 384 - ratio: 0.02


In [45]:
# Only n_jobs and progress_bar are taken from job_kwargs for this call (chunk_duration is left out).
pca_job_options = {key: job_kwargs[key] for key in ("n_jobs", "progress_bar")}
# load_if_exists=False is passed, so the scores are expected to be recomputed on every run.
pc = spost.compute_principal_components(
    we,
    n_components=3,
    sparsity=sparsity_radius,
    load_if_exists=False,
    **pca_job_options,
)

Fitting PCA:   0%|          | 0/252 [00:00<?, ?it/s]

Projecting waveforms:   0%|          | 0/252 [00:00<?, ?it/s]

In [46]:
# Fetch the projections of unit 0 and of all units, then report both shapes.
pc0 = pc.get_projections(unit_id=0)
all_labels, all_pcs = pc.get_all_projections()

print(f"PC scores shape: {pc0.shape}")
print(f"All PC scores shape: {all_pcs.shape}")

PC scores shape: (100, 3, 384)
All PC scores shape: (84946, 3, 384)


For pc scores of a single unit, the dimension is (num_spikes, num_components, num_channels). 

## WaveformExtensions

When we compute PCA (or use other postprocessing functions), the computed information is added to the waveform folder. The functions act as `WaveformExtensions`:

In [47]:
# List the extension names reported by the waveform extractor (shown as the cell output).
we.get_available_extension_names()

['template_metrics',
 'similarity',
 'principal_components',
 'spike_amplitudes',
 'correlograms',
 'spike_locations',
 'unit_locations',
 'quality_metrics']

Each `WaveformExtension` is an object that allows us to retrieve the data:

In [48]:
# Load the "principal_components" extension from `we`; this rebinds the name `pc` used above.
pc = we.load_extension("principal_components")
print(pc)

WaveformPrincipalComponent: 384 channels - 1 segments
  mode: by_channel_local n_components: 3 - sparse


In [49]:
# get_data() returns a pair, unpacked here as labels and PC scores; print the scores' shape.
all_labels, all_pcs = pc.get_data()
print(all_pcs.shape)

(84946, 3, 384)


### Spike amplitudes

Spike amplitudes can be computed with the `compute_spike_amplitudes` function.

In [50]:
# outputs="by_unit" requests the amplitudes grouped per unit (see the output of the next cell).
amplitudes = spost.compute_spike_amplitudes(
    we,
    outputs="by_unit",
    load_if_exists=True,
    **job_kwargs,
)

By default, all amplitudes are concatenated in one array with all amplitudes from all spikes. With the `outputs="by_unit"` argument, instead, a dictionary is returned:

In [51]:
# Element 0 of the returned object - per the output below, a dict mapping unit id to an amplitude array.
# NOTE(review): index 0 presumably selects the first (and only) segment - confirm.
amplitudes[0]

{0: array([-59.329224, -86.951965, -69.41298 , -66.68459 , -69.58998 ,
        -69.062325, -44.64085 , -78.12652 , -75.40594 , -66.01581 ,
        -66.59062 , -77.278046, -45.41994 , -87.838875, -81.34638 ,
        -49.149414, -90.72847 , -57.381588, -61.831345, -70.35241 ,
        -37.907978, -59.72202 , -83.671295, -70.822235, -69.369865,
        -75.37647 , -64.431435, -42.243298, -73.74475 , -56.38531 ,
        -50.877014, -25.92337 , -71.94904 , -71.49883 , -62.23607 ,
        -77.28405 , -70.38199 , -68.14468 , -45.26394 , -43.810593,
        -29.263493, -49.511276, -38.211933, -47.412796, -66.76882 ,
        -61.59304 , -52.80741 , -79.04591 , -49.89466 , -64.09306 ,
        -40.185425, -52.4434  , -77.65495 , -54.893814, -70.382774,
        -75.63662 , -53.603996, -71.49326 , -47.18581 , -66.32992 ,
        -71.66963 , -57.32662 , -60.982105, -50.154976, -71.33039 ,
        -81.73359 , -39.969185, -29.938917, -51.807705, -22.383068,
        -61.981506, -46.41142 , -28.987514, -

In [52]:
# sw.plot_amplitudes(we, backend="ipywidgets")

### Compute unit and spike locations

When using silicon probes, we can estimate the unit (or spike) location with triangulation. This can be done either with a simple center of mass or by assuming a monopolar model:

$$V_{ext}(\boldsymbol{r_{ext}}) = \frac{I_n}{4 \pi \sigma |\boldsymbol{r_{ext}} - \boldsymbol{r_{n}}|}$$

where $\boldsymbol{r_{n}}$ is the position of the neuron, and $\boldsymbol{r_{ext}}$ that of the electrode(s).

In [53]:
# Unit locations use the "monopolar_triangulation" method.
unit_locations = spost.compute_unit_locations(
    we,
    method="monopolar_triangulation",
    load_if_exists=True,
)
# Spike locations use the "center_of_mass" method; this call also receives the job options.
spike_locations = spost.compute_spike_locations(
    we,
    method="center_of_mass",
    load_if_exists=True,
    **job_kwargs,
)

In [54]:
# sw.plot_unit_locations(we, backend="ipywidgets")

In [55]:
# sw.plot_spike_locations(we, max_spikes_per_unit=300, backend="ipywidgets")

### Compute correlograms

In [56]:
# Compute correlograms from `we`; the call returns two values, stored as `ccgs` and `bins`.
ccgs, bins = spost.compute_correlograms(we)

In [57]:
# sw.plot_autocorrelograms(we, unit_ids=sorting.unit_ids[:3])
# sw.plot_crosscorrelograms(we, unit_ids=sorting.unit_ids[:3])

### Compute template similarity

In [58]:
# Compute the template similarity from `we` and keep the result in `similarity`.
similarity = spost.compute_template_similarity(we)

### Compute template metrics

Template metrics, or extracellular features, such as peak-to-valley duration or full-width half maximum, are important to classify neurons into putative classes (excitatory - inhibitory). The `postprocessing` module allows one to compute several of these metrics:

In [59]:
# Print the names of the available template metrics.
print(spost.get_template_metric_names())

['peak_to_valley', 'peak_trough_ratio', 'half_width', 'repolarization_slope', 'recovery_slope']


In [60]:
# Use the compute_* entry point, consistent with the other postprocessing calls in this notebook
# (compute_spike_amplitudes, compute_unit_locations, compute_correlograms, ...).
template_metrics = spost.compute_template_metrics(we)
# Rich display of the resulting table (one row per unit, one column per metric, per the output).
display(template_metrics)

Unnamed: 0,peak_to_valley,peak_trough_ratio,half_width,repolarization_slope,recovery_slope
0,0.000293,-0.554656,0.000197,469169.283462,-52823.757632
1,0.00036,-0.211997,0.000207,377229.949645,-26599.570886
2,0.000323,-0.484906,0.00016,593648.0388,-64672.64749
3,0.000133,-0.679876,0.0001,7257964.007012,-392783.036152
4,0.00042,-0.55181,0.000237,532783.294182,-103368.374543
...,...,...,...,...,...
250,0.000197,-0.371662,0.00012,1119072.083715,-71912.758016
251,0.000443,-0.444956,0.000173,409619.027693,-65846.267262
252,0.000153,-0.701695,0.000103,2919791.807888,-199895.883616
253,0.000633,-0.187739,0.000163,223063.857447,-15062.692726


In [27]:
# sw.plot_template_metrics(we, include_metrics=["peak_to_valley", "half_width"], 
#                          backend="ipywidgets")

AppLayout(children=(VBox(children=(Label(value='units:'), SelectMultiple(layout=Layout(height='10cm', width='3…

<spikeinterface.widgets.ipywidgets.template_metrics.TemplateMetricsPlotter at 0x2fb01c010d0>

For more information about these template metrics, we refer to this [documentation](https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms) from the Allen Institute.

# 7. Quality metrics and curation <a class="anchor" id="curation"></a>

The `qualitymetrics` module also provides several functions to compute quality metrics to validate the spike sorting results.

Let's see what metrics are available:

In [61]:
# Print the available quality metrics: first the general list, then the PCA-based list.
print(sqm.get_quality_metric_list())
print(sqm.get_quality_pca_metric_list())

['num_spikes', 'firing_rate', 'presence_ratio', 'snr', 'isi_violations', 'rp_violations', 'amplitude_cutoff']
['isolation_distance', 'l_ratio', 'd_prime', 'nearest_neighbor', 'nn_isolation', 'nn_noise_overlap']


In [None]:
from spikeinterface.qualitymetrics import compute_quality_metrics

In [80]:
# Compute the quality metrics for `we`, restricted to the radius sparsity.
# Called through the `sqm` alias (same function as the directly imported name).
# NOTE(review): earlier drafts passed a `metric_names=[...]` list to limit the computation to a
# few metrics (e.g. snr, amplitude_cutoff) - confirm the accepted names against
# sqm.get_quality_metric_list() before re-enabling that.
qm = sqm.compute_quality_metrics(
    we,
    sparsity=sparsity_radius,
    verbose=True,
    n_jobs=job_kwargs["n_jobs"],
)

Computing num_spikes
Computing firing_rate
Computing presence_ratio
Computing snr
Computing isi_violations
Computing rp_violations
Computing amplitude_cutoff


Computing PCA metrics:   0%|          | 0/252 [00:00<?, ?it/s]

In [81]:
# Rich display of the quality-metrics table computed above.
display(qm)

Unnamed: 0,num_spikes,firing_rate,presence_ratio,snr,isi_violations_ratio,isi_violations_count,rp_contamination,rp_violations,amplitude_cutoff,isolation_distance,l_ratio,d_prime,nn_hit_rate,nn_miss_rate
0,100,0.140363,1.000000,4.579777,0.000000,0,0.000000,0,,98.494005,0.000110,-6.948646,0.952500,0.000000
1,4120,5.782940,1.000000,5.266185,0.000000,0,0.000000,0,0.000375,84.630560,0.000719,-5.969864,0.982500,0.001278
2,2762,3.876816,1.000000,5.323824,0.000000,0,0.000000,0,0.001206,773.881318,0.298357,-2.826627,0.772500,0.161000
3,586,0.822525,0.818182,33.317558,8.990325,13,1.000000,2,0.001141,475.574833,0.082893,-5.334323,0.917000,0.021388
4,2205,3.094996,1.000000,9.597290,0.000000,0,0.000000,0,0.000771,236.035645,0.000000,-10.938807,0.995000,0.001637
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,469,0.658301,1.000000,6.067339,2.159293,2,1.000000,1,,29.596832,0.460845,-3.219254,0.703625,0.043036
251,3835,5.382906,1.000000,7.311440,0.016147,1,0.024521,1,0.002072,37.198577,0.287017,3.779772,0.763000,0.008949
252,3899,5.472738,0.545455,17.996609,0.000000,0,0.000000,0,0.002074,100.953435,0.068867,-3.913120,0.830500,0.071262
253,916,1.285722,0.909091,6.600979,0.566065,2,1.000000,2,0.000984,40.831974,0.079904,4.591321,0.780000,0.027922


In [None]:
# sw.plot_quality_metrics(we, include_metrics=["amplitude_cutoff", "presence_ratio", "isi_violations_ratio", "snr"], 
#                         backend="ipywidgets")

For more information about these waveform features, we refer to the [SpikeInterface documentation](https://spikeinterface.readthedocs.io/en/latest/module_qualitymetrics.html) and to this excellent [documentation](https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html) from the Allen Institute.

## Automatic curation based on quality metrics

A viable option to curate (or at least pre-curate) a spike sorting output is to filter units based on quality metrics. As we have already computed quality metrics a few lines above, we can simply filter the `qm` dataframe based on some thresholds.

Here, we'll only keep units with an ISI violations ratio < 0.2 and an amplitude cutoff < 0.1:

In [82]:
# Curation thresholds; both are interpolated into the query string built in the next cell.
amp_cutoff_thresh = 0.1  # upper bound applied to the `amplitude_cutoff` column
isi_viol_thresh = 0.2    # upper bound applied to the `isi_violations_ratio` column

A straightforward way to filter a pandas dataframe is via the `query` method.
We first define our query (make sure the names match the column names of the dataframe):

In [85]:
# Build the pandas query string from the two thresholds; the column names
# (`amplitude_cutoff`, `isi_violations_ratio`) must match the columns of `qm`.
# NOTE(review): units whose amplitude_cutoff is missing (NaN in the table above) presumably fail
# the `<` comparison and are dropped by this query - confirm that is intended.
our_query = f"amplitude_cutoff < {amp_cutoff_thresh} & isi_violations_ratio < {isi_viol_thresh}"
print(our_query)

amplitude_cutoff < 0.1 & isi_violations_ratio < 0.2


and then we can use the query to select units:

In [86]:
# Rows of `qm` that satisfy the query.
keep_units = qm.query(our_query)
# The dataframe index is used as the list of unit ids to keep (passed to select_units below).
keep_unit_ids = keep_units.index.values

In [87]:
# Restrict the sorting to the units that passed the quality-metric query and report both counts.
sorting_auto_KS25 = sorting.select_units(keep_unit_ids)

n_units_before = len(sorting.get_unit_ids())
n_units_after = len(sorting_auto_KS25.get_unit_ids())
print(f"Number of units before curation: {n_units_before}")
print(f"Number of units after curation: {n_units_after}")

Number of units before curation: 255
Number of units after curation: 107


We can also save all the waveforms and post-processed data for curated units in a separate folder:

In [89]:
# Save the curated subset of `we` into its own folder under base_folder.
curated_folder = base_folder / "waveforms_curated"
we_curated = we.select_units(keep_unit_ids, new_folder=curated_folder)

In [90]:
# Summary of the curated waveform extractor.
print(we_curated)

WaveformExtractor: 384 channels - 107 units - 1 segments
  before:90 after:120 n_per_units:500


In [91]:
# List the extension names reported by the curated extractor (shown as the cell output).
we_curated.get_available_extension_names()

['template_metrics',
 'similarity',
 'principal_components',
 'spike_amplitudes',
 'correlograms',
 'spike_locations',
 'unit_locations',
 'quality_metrics']

### Viewers


#### SpikeInterface GUI

A QT-based GUI built on top of SpikeInterface objects.

Developed by Samuel Garcia, CRNL, Lyon.

In [None]:
#!sigui waveforms/

#### Sorting Summary - SortingView

The `sortingview` backend requires an additional step to configure the transfer of the data to be plotted to the cloud. 

See documentation [here](https://spikeinterface.readthedocs.io/en/latest/module_widgets.html).

Developed by Jeremy Magland and Jeff Soules, Flatiron Institute, NYC

In [None]:
# NOTE(review): `sparsity_radius` was first computed from the full `we`, not from `we_curated`,
# and its value depends on which cells ran last - confirm it is the intended sparsity for the
# curated units.
w = sw.plot_sorting_summary(we_curated, sparsity=sparsity_radius, backend="sortingview")

In [None]:
# Display the `view` attribute of the widget returned by plot_sorting_summary.
w.view

### Exporters 

#### Export to Phy for manual curation

To perform manual curation we can export the data to [Phy](https://github.com/cortex-lab/phy). 

In [None]:
# Help lookup kept commented out, like the other `?` lookups in this notebook, so that
# Restart & Run All does not open the pager.
# sexp.export_to_phy?

In [100]:
# Export the curated units to Phy for manual curation.
# The sparsity is recomputed from `we_curated` and stored under its own name, so the
# `sparsity_radius` computed from `we` earlier in the notebook is not silently overwritten
# (re-running an earlier cell would otherwise pick up a sparsity built for a different unit set).
sparsity_curated = compute_sparsity(we_curated, method="radius", radius_um=50)
sexp.export_to_phy(we_curated, output_folder=base_folder / 'phy_KS25',
                   compute_amplitudes=False, compute_pc_features=False, copy_binary=True,
                   sparsity=sparsity_curated,
                   **job_kwargs)
# Alternatives tried previously:
#   - exporting the uncurated `we` without a sparsity argument
#   - a "best_channels" sparsity: compute_sparsity(we_curated, method="best_channels", num_channels=4)

write_binary_recording with n_jobs = 10 and chunk_size = 30000


write_binary_recording:   0%|          | 0/713 [00:00<?, ?it/s]

Run:
phy template-gui  D:\Vincent\Data\sc011\sc011_0106\sc011_0106_001\phy_KS25\params.py


There is a problem with the latest version of Phy so we need to set an environment variable to make it work properly:

- Python:
```
import os
os.environ["QTWEBENGINE_CHROMIUM_FLAGS"] = "--single-process"
```

- OR terminal:

  - Linux/MacOS:
`export QTWEBENGINE_CHROMIUM_FLAGS="--single-process"`

  - Windows:
`set QTWEBENGINE_CHROMIUM_FLAGS=--single-process` (no quotes: in `cmd.exe` they would become part of the value)

Then we can run the Phy GUI:

In [94]:
import os

# Set the environment variable for this kernel process before the Phy GUI is launched below.
qt_flags_variable = "QTWEBENGINE_CHROMIUM_FLAGS"
os.environ[qt_flags_variable] = "--single-process"

In [95]:
%%capture --no-display
!phy template-gui phy_KS25/params.py

After curating the results we can reload it using the `PhySortingExtractor` and exclude the units that we labeled as `noise`:

In [None]:
# Reload the manually curated sorting from the Phy folder, leaving out clusters in the 'noise' group.
phy_folder = base_folder / 'phy_KS25/'
excluded_groups = ['noise']
sorting_phy_curated = se.PhySortingExtractor(phy_folder, exclude_cluster_groups=excluded_groups)

In [None]:
# Compare the number of units in the original sorting with the number left after Phy curation.
n_units_original = len(sorting.get_unit_ids())
n_units_phy = len(sorting_phy_curated.get_unit_ids())
print(f"Number of units before curation: {n_units_original}")
print(f"Number of units after curation: {n_units_phy}")