### Load modules

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.widgets as sw

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

import warnings
# Silence ALL warnings for the rest of the kernel session to keep notebook output short.
# NOTE(review): this also hides deprecation warnings from spikeinterface that would flag
# API calls in this notebook scheduled for removal — consider narrowing the filter.
warnings.simplefilter("ignore")

# IPython magic (not plain Python): interactive widget-based matplotlib backend.
%matplotlib widget

### Define paths

In [None]:
# Session folder: the preprocessed recording, sorter output, waveform folders and the
# Phy export used below all live under this directory.
base_folder = Path("/scratch2/weka/wanglab/prevosto/data/sc012/sc012_0123/sc012_0123_001")
# Record-node sub-folder of the raw acquisition (only referenced by the commented `ls`).
file_path = base_folder / "Record Node 101"

In [None]:
# ls "$file_path"
# IPython shell escape: list the session folder to see which derived folders already exist.
!ls {base_folder}

In [None]:
# Display the default parameters of the tridesclous sorter (last expression is rendered).
# NOTE(review): the sorting loaded below is the SpykingCircus output ('results_SC'),
# so these defaults are for reference only.
ss.get_default_params('tridesclous')

In [None]:
# IPython help: show the signature/docstring of the sorting extractor used below.
# se.TridesclousSortingExtractor?
se.SpykingCircusSortingExtractor?

In [None]:
# Reload the preprocessed recording that was saved to disk earlier.
preprocessed_folder = base_folder / "preprocessed"
recording_saved = si.load_extractor(preprocessed_folder)
print(recording_saved)

# Spike-sorting output to analyse (SpykingCircus). Alternative input:
# sorting = se.KiloSortSortingExtractor(base_folder / 'results_KS2_5')
sorter_output_folder = base_folder / "results_SC"
sorting = se.SpykingCircusSortingExtractor(sorter_output_folder)
print(sorting)

A `RecordingExtractor` object extracts information about channel ids, channel locations (if present), the sampling frequency of the recording, and the extracellular traces (when requested).
The `print(recording_saved)` and `print(sorting)` calls above show a short summary of the loaded recording and sorting objects.

### Extract waveforms

In [None]:
# si.extract_waveforms?

In [None]:
# Parallelisation settings forwarded (fully or partly) to the chunked computations below.
job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": True}

Extract waveforms for a random subset of at most 500 spikes per unit (the `WaveformExtractor` default).

In [None]:
# Waveforms for a random subset of spikes per unit; reuse the folder if it already exists.
waveforms_folder = base_folder / "waveforms"
we = si.extract_waveforms(
    recording_saved,
    sorting,
    folder=waveforms_folder,
    load_if_exists=True,
    overwrite=False,
    **job_kwargs,
)
print(we)

In [None]:
# w = sw.plot_unit_templates(we, radius_um=30, backend="ipywidgets")

In [None]:
# for unit in sorting.get_unit_ids():
#     waveforms = we.get_waveforms(unit_id=unit)
#     spiketrain = sorting.get_unit_spike_train(unit)
#     print(f"Unit {unit} - num waveforms: {waveforms.shape[0]} - num spikes: {len(spiketrain)}")

Extract waveforms for all spikes by setting `max_spikes_per_unit=None`.

In [None]:
# Waveforms for every spike (max_spikes_per_unit=None turns off subsampling).
# overwrite=True replaces any existing folder, so this re-extracts on every run.
waveforms_all_folder = base_folder / "waveforms_all"
we_all = si.extract_waveforms(
    recording_saved,
    sorting,
    folder=waveforms_all_folder,
    max_spikes_per_unit=None,
    overwrite=True,
    **job_kwargs,
)

In [None]:
# Per unit, compare the number of extracted waveforms with the number of sorted spikes.
for unit_id in sorting.unit_ids:
    n_waveforms = we_all.get_waveforms(unit_id=unit_id).shape[0]
    n_spikes = len(sorting.get_unit_spike_train(unit_id))
    print(f"Unit {unit_id} - num waveforms: {n_waveforms} - num spikes: {n_spikes}")

In [None]:
# Lower-level alternative to si.extract_waveforms: create the extractor, set the
# parameters (3 ms before / 4 ms after each spike, 500 spikes per unit), then extract.
# The folder is anchored in base_folder under its own name: a bare relative 'waveforms'
# path resolves against the kernel's working directory and, with remove_if_exists=True,
# would delete the folder backing `we` whenever that directory is base_folder.
# NOTE(review): the `_KS` suffix is historical — `sorting` here is the SpykingCircus output.
we_KS = si.WaveformExtractor.create(
    recording_saved, sorting, base_folder / "waveforms_KS", remove_if_exists=True
)
we_KS.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
we_KS.run_extract_waveforms(n_jobs=-1, chunk_size=30000)
print(we_KS)

### Postprocessing

#### Sparsity

Especially when working with high-density silicon probes, or when the probe has multiple groups (e.g. multi-shank probes or tetrodes), we don't need waveforms/templates on all channels. To find a subset of channels for each unit, we can use `compute_sparsity()` (the commented-out `get_template_channel_sparsity()` calls show the older API).

In [None]:
# spost.get_template_channel_sparsity?

In [None]:
from spikeinterface import compute_sparsity

In [None]:
# Example 1: radius-based sparsity (radius_um=50).
# Older API: spost.get_template_channel_sparsity(we, method="radius", radius_um=50)
radius_kwargs = dict(method="radius", radius_um=50)
sparsity_radius = compute_sparsity(we, **radius_kwargs)
print(sparsity_radius)

In [None]:
# Example 2: keep a fixed number of "best" channels per unit (num_channels=4).
# Older API: spost.get_template_channel_sparsity(we, method="best_channels", num_channels=4)
best_kwargs = dict(method="best_channels", num_channels=4)
sparsity_best = compute_sparsity(we, **best_kwargs)
print(sparsity_best)

Most of the plotting and exporting functions accept `sparsity` as an argument. 

In [None]:
# sw.plot_unit_templates(we, sparsity=sparsity_radius, backend="ipywidgets")

In [None]:
# sw.plot_unit_templates(we, sparsity=sparsity_best, backend="ipywidgets")

### PCA scores

PCA scores can be easily computed with the `compute_principal_components()` function. Similarly to `extract_waveforms`, the function returns an object of type `WaveformPrincipalComponent` that allows us to retrieve all PC scores on demand.

In [None]:
# spost.compute_principal_components?

In [None]:
# Show the radius sparsity that is passed to the PCA computation below.
print(sparsity_radius)

In [None]:
# PCA projections (3 components) using the radius sparsity; always recomputed
# (load_if_exists=False). Only n_jobs / progress_bar are taken from job_kwargs.
pca_job_kwargs = {key: job_kwargs[key] for key in ("n_jobs", "progress_bar")}
pc = spost.compute_principal_components(
    we,
    n_components=3,
    sparsity=sparsity_radius,
    load_if_exists=False,
    **pca_job_kwargs,
)

In [None]:
# Inspect PC scores for a single unit and for all spikes together.
# Use the first unit id of the extractor instead of hard-coding 0: unit ids are
# sorter-specific and id 0 is not guaranteed to exist (e.g. after curation).
first_unit_id = we.unit_ids[0]
pc0 = pc.get_projections(unit_id=first_unit_id)
print(f"PC scores shape (unit {first_unit_id}): {pc0.shape}")
all_labels, all_pcs = pc.get_all_projections()
print(f"All PC scores shape: {all_pcs.shape}")

For pc scores of a single unit, the dimension is (num_spikes, num_components, num_channels). 

## WaveformExtensions

When we compute PCA (or use other postprocessing functions), the computed information is added to the waveform folder. The functions act as `WaveformExtensions`:

In [None]:
# List the extensions (e.g. principal_components) saved so far in the folder of `we`.
we.get_available_extension_names()

Each `WaveformExtension` is an object that allows us to retrieve the data:

In [None]:
# Reload the PCA extension stored in the waveform folder.
extension_name = "principal_components"
pc = we.load_extension(extension_name)
print(pc)

In [None]:
# The extension's data is a (labels, projections) pair.
pc_labels_and_scores = pc.get_data()
all_labels, all_pcs = pc_labels_and_scores
print(all_pcs.shape)

### Spike amplitudes

Spike amplitudes can be computed with the `compute_spike_amplitudes` function.

In [None]:
# Per-spike amplitudes, grouped by unit (outputs="by_unit") and reused if already on disk.
amplitudes = spost.compute_spike_amplitudes(
    we,
    load_if_exists=True,
    outputs="by_unit",
    **job_kwargs,
)

By default, the amplitudes of all spikes are concatenated into one array (per segment). With the `outputs="by_unit"` argument, the amplitudes are instead grouped by unit id in a dictionary (one per segment, so `amplitudes[0]` below is the first segment):

In [None]:
# NOTE(review): with outputs="by_unit", spikeinterface appears to return one
# {unit_id: amplitudes} dict per segment, so [0] is the first segment — confirm for your version.
amplitudes[0]

In [None]:
# sw.plot_amplitudes(we, backend="ipywidgets")

### Compute unit and spike locations

When using silicon probes, we can estimate the unit (or spike) location with triangulation. This can be done either with a simple center of mass or by assuming a monopolar model:

$$V_{ext}(\boldsymbol{r_{ext}}) = \frac{I_n}{4 \pi \sigma |\boldsymbol{r_{ext}} - \boldsymbol{r_{n}}|}$$

where $\boldsymbol{r_{n}}$ is the position of the neuron, and $\boldsymbol{r_{ext}}$ is the position of the electrode(s).

In [None]:
# Unit locations from a monopolar-model fit; per-spike locations from the center of mass.
unit_locations = spost.compute_unit_locations(
    we, method="monopolar_triangulation", load_if_exists=True
)
spike_locations = spost.compute_spike_locations(
    we, method="center_of_mass", load_if_exists=True, **job_kwargs
)

In [None]:
# sw.plot_unit_locations(we, backend="ipywidgets")

In [None]:
# sw.plot_spike_locations(we, max_spikes_per_unit=300, backend="ipywidgets")

### Compute correlograms

In [None]:
# Auto-/cross-correlograms together with their bins.
correlograms_and_bins = spost.compute_correlograms(we)
ccgs, bins = correlograms_and_bins

In [None]:
# sw.plot_autocorrelograms(we, unit_ids=sorting.unit_ids[:3])
# sw.plot_crosscorrelograms(we, unit_ids=sorting.unit_ids[:3])

### Compute template similarity

In [None]:
# Pairwise similarity between unit templates (library default method).
similarity = spost.compute_template_similarity(we)

### Compute template metrics

Template metrics, or extracellular features, such as peak to valley duration or full-width half maximum, are important to classify neurons into putative classes (excitatory - inhibitory). The `postprocessing` module allows one to compute several of these metrics:

In [None]:
# Names of the template metrics that calculate_template_metrics can compute.
print(spost.get_template_metric_names())

In [None]:
# Compute the template metrics for all units of `we` and render the resulting table.
template_metrics = spost.calculate_template_metrics(we)
display(template_metrics)

In [None]:
# sw.plot_template_metrics(we, include_metrics=["peak_to_valley", "half_width"], 
#                          backend="ipywidgets")

For more information about these template metrics, we refer to this [documentation](https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms) from the Allen Institute.

# 7. Quality metrics and curation <a class="anchor" id="curation"></a>

The `qualitymetrics` module also provides several functions to compute quality metrics to validate the spike sorting results.

Let's see what metrics are available:

In [None]:
# Available quality metrics; the second list is the PCA-based ones.
print(sqm.get_quality_metric_list())
print(sqm.get_quality_pca_metric_list())

In [None]:
from spikeinterface.qualitymetrics import compute_quality_metrics

In [None]:
# Quality metrics for every unit of `we`, using the radius sparsity.
# `compute_quality_metrics` is the same function as `sqm.compute_quality_metrics`;
# pass metric_names=['snr', 'isi_violation', 'amplitude_cutoff'] to compute a subset.
qm = compute_quality_metrics(
    we,
    sparsity=sparsity_radius,
    verbose=True,
    n_jobs=job_kwargs["n_jobs"],
)

In [None]:
# Render the quality-metrics table (one row per unit, indexed by unit id as used below).
display(qm)

In [None]:
# sw.plot_quality_metrics(we, include_metrics=["amplitude_cutoff", "presence_ratio", "isi_violations_ratio", "snr"], 
#                         backend="ipywidgets")

For more information about these quality metrics, we refer to the [SpikeInterface documentation](https://spikeinterface.readthedocs.io/en/latest/module_qualitymetrics.html) and to this excellent [documentation](https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html) from the Allen Institute.

## Automatic curation based on quality metrics

A viable option to curate (or at least pre-curate) a spike sorting output is to filter units based on quality metrics. As we have already computed quality metrics a few lines above, we can simply filter the `qm` dataframe based on some thresholds.

Here, we'll only keep units with an ISI violations ratio < 0.2 and an amplitude cutoff < 0.1:

In [None]:
# Curation thresholds: ISI-violations ratio < 0.2 and amplitude cutoff < 0.1.
isi_viol_thresh, amp_cutoff_thresh = 0.2, 0.1

A straightforward way to filter a pandas dataframe is via the `query`.
We first define our query (make sure the names match the column names of the dataframe):

In [None]:
# pandas query string; the names must match columns of the `qm` dataframe.
amp_condition = f"amplitude_cutoff < {amp_cutoff_thresh}"
isi_condition = f"isi_violations_ratio < {isi_viol_thresh}"
our_query = f"{amp_condition} & {isi_condition}"
print(our_query)

and then we can use the query to select units:

In [None]:
# Rows of the metrics table that satisfy every threshold; the index holds the unit ids.
keep_units = qm.query(our_query)
keep_unit_ids = keep_units.index.values

In [None]:
# Apply the metric-based selection to the sorting and report the unit counts.
# NOTE(review): the `_KS25` suffix is historical — this is the SpykingCircus sorting.
sorting_auto_KS25 = sorting.select_units(keep_unit_ids)
n_units_before = len(sorting.get_unit_ids())
n_units_after = len(sorting_auto_KS25.get_unit_ids())
print(f"Number of units before curation: {n_units_before}")
print(f"Number of units after curation: {n_units_after}")

We can also save all the waveforms and post-processed data for curated units in a separate folder:

In [None]:
# Copy the waveforms and computed extensions of the kept units into a new folder.
# NOTE(review): re-running this cell likely fails if waveforms_curated already exists;
# remove that folder first when re-curating.
we_curated = we.select_units(keep_unit_ids, new_folder=base_folder / "waveforms_curated")

In [None]:
# Summary of the curated waveform extractor.
print(we_curated)

In [None]:
# Check which extensions were carried over to the curated folder.
we_curated.get_available_extension_names()

### Viewers


#### SpikeInterface GUI

A QT-based GUI built on top of SpikeInterface objects.

Developed by Samuel Garcia, CRNL, Lyon.

In [None]:
#!sigui waveforms/

#### Sorting Summary - SortingView

The `sortingview` backend requires an additional step to configure the transfer of the data to be plotted to the cloud. 

See documentation [here](https://spikeinterface.readthedocs.io/en/latest/module_widgets.html).

Developed by Jeremy Magland and Jeff Soules, Flatiron Institute, NYC

In [None]:
# Sorting summary of the curated units with the sortingview backend (needs the cloud
# configuration described above). The sparsity is computed on we_curated itself:
# `sparsity_radius` was computed on `we` for all units, not for the curated subset.
sparsity_curated = compute_sparsity(we_curated, method="radius", radius_um=50)
w = sw.plot_sorting_summary(we_curated, sparsity=sparsity_curated, backend="sortingview")

In [None]:
# Display the `view` attribute of the widget returned above.
# NOTE(review): for the sortingview backend this presumably holds the figure/link — confirm.
w.view

### Exporters 

#### Export to Phy for manual curation

To perform manual curation we can export the data to [Phy](https://github.com/cortex-lab/phy). 

In [None]:
# IPython help for the Phy exporter used in the next cell.
sexp.export_to_phy?

In [None]:
# Export the curated units to Phy for manual curation.
# The sparsity must match we_curated's units. Keep it under its own name so that
# `sparsity_radius` (computed on all units of `we`) is not replaced, which would break
# re-running the earlier cells that pair it with `we`.
# Alternative: compute_sparsity(we_curated, method="best_channels", num_channels=4)
sparsity_curated = compute_sparsity(we_curated, method="radius", radius_um=50)
sexp.export_to_phy(
    we_curated,
    output_folder=base_folder / 'phy_KS25',
    compute_amplitudes=False,
    compute_pc_features=False,
    copy_binary=True,
    sparsity=sparsity_curated,
    **job_kwargs,
)

There is a problem with the latest version of Phy so we need to set an environment variable to make it work properly:

- Python:
```
import os
os.environ["QTWEBENGINE_CHROMIUM_FLAGS"] = "--single-process"
```

- OR terminal:

  - Linux/MacOS:
`export QTWEBENGINE_CHROMIUM_FLAGS="--single-process"`

  - Windows:
`set QTWEBENGINE_CHROMIUM_FLAGS=--single-process` (in `cmd`, quotes after `=` would become part of the value)

Then we can run the Phy GUI:

In [None]:
# Phy/QtWebEngine workaround; child processes such as the `!phy` call inherit this variable.
import os

os.environ.update(QTWEBENGINE_CHROMIUM_FLAGS="--single-process")

In [None]:
%%capture --no-display
!phy template-gui phy_KS25/params.py

After curating the results we can reload it using the `PhySortingExtractor` and exclude the units that we labeled as `noise`:

In [None]:
# Reload the Phy-curated result, dropping clusters labelled 'noise' in Phy.
phy_folder = base_folder / "phy_KS25"
sorting_phy_curated = se.PhySortingExtractor(phy_folder, exclude_cluster_groups=["noise"])

In [None]:
# Compare the original unit count with the count kept after manual curation in Phy.
n_units_before = len(sorting.get_unit_ids())
n_units_after = len(sorting_phy_curated.get_unit_ids())
print(f"Number of units before curation: {n_units_before}")
print(f"Number of units after curation: {n_units_after}")