In [1]:
import os
import shutil

import spikeinterface as si
# import spikeinterface.extractors as se 
# import spikeinterface.preprocessing as spre
# import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
# import spikeinterface.exporters as sexp
import spikeinterface.widgets as sw

In [2]:
# print(f"SpikeInterface version: {si.__version__}")

In [5]:
# import matplotlib.pyplot as plt
# import numpy as np
from pathlib import Path

# import warnings
# warnings.simplefilter("ignore")

# %matplotlib widget

In [6]:
base_folder = Path(r"/scratch2/weka/wanglab/prevosto/data/sc012/sc012_0123/sc012_0123_001")
file_path = base_folder.joinpath("Record Node 101")

### Extract waveforms

In [None]:
recording_saved = si.load_extractor(base_folder / "preprocessed")
sorting = sorting_KS25
print(sorting)

In [None]:
si.extract_waveforms?

In [None]:
we = si.extract_waveforms(recording_saved, sorting, folder=base_folder / "waveforms", 
                          load_if_exists=False, overwrite=True, **job_kwargs)
print(we)

In [None]:
waveforms0 = we.get_waveforms(unit_id=0)
print(f"Waveforms shape: {waveforms0.shape}")
template0 = we.get_template(unit_id=0)
print(f"Template shape: {template0.shape}")
all_templates = we.get_all_templates()
print(f"All templates shape: {all_templates.shape}")

In [None]:
w = sw.plot_unit_templates(we, radius_um=30, backend="ipywidgets")

In [None]:
for unit in sorting.get_unit_ids():
    waveforms = we.get_waveforms(unit_id=unit)
    spiketrain = sorting.get_unit_spike_train(unit)
    print(f"Unit {unit} - num waveforms: {waveforms.shape[0]} - num spikes: {len(spiketrain)}")

In [None]:
we_all = si.extract_waveforms(recording_saved, sorting, folder=base_folder / "waveforms_all", 
                              max_spikes_per_unit=None,
                              overwrite=True,
                              **job_kwargs)

In [None]:
for unit in sorting.get_unit_ids():
    waveforms = we_all.get_waveforms(unit_id=unit)
    spiketrain = sorting.get_unit_spike_train(unit)
    print(f"Unit {unit} - num waveforms: {waveforms.shape[0]} - num spikes: {len(spiketrain)}")

### Postprocessing

#### Sparsity

In [None]:
spost.get_template_channel_sparsity?

In [None]:
# example: radius
sparsity_radius = spost.get_template_channel_sparsity(we, method="radius", radius_um=50)
print(sparsity_radius[sorting.unit_ids[0]])

In [None]:
# example: best
sparsity_best = spost.get_template_channel_sparsity(we, method="best_channels", num_channels=4)
print(sparsity_best[sorting.unit_ids[0]])

In [None]:
sw.plot_unit_templates(we, sparsity=sparsity_radius, backend="ipywidgets")

In [None]:
sw.plot_unit_templates(we, sparsity=sparsity_best, backend="ipywidgets")

#### PCA scores

In [None]:
spost.compute_principal_components?

In [None]:
pc = spost.compute_principal_components(we, n_components=3,
                                        sparsity=sparsity_radius, 
                                        load_if_exists=False,
                                        n_jobs=job_kwargs["n_jobs"], 
                                        progress_bar=job_kwargs["progress_bar"])

In [None]:
pc0 = pc.get_projections(unit_id=0)
print(f"PC scores shape: {pc0.shape}")
all_labels, all_pcs = pc.get_all_projections()
print(f"All PC scores shape: {all_pcs.shape}")

#### WaveformExtensions

In [None]:
we.get_available_extension_names()

In [None]:
pc = we.load_extension("principal_components")
print(pc)

In [None]:
all_labels, all_pcs = pc.get_data()
print(all_pcs.shape)

#### Spike amplitudes

In [None]:
amplitudes = spost.compute_spike_amplitudes(we, outputs="by_unit", load_if_exists=True, 
                                            **job_kwargs)

In [None]:
amplitudes[0]

In [None]:
sw.plot_amplitudes(we, backend="ipywidgets")

#### Compute unit and spike locations

In [None]:
unit_locations = spost.compute_unit_locations(we, method="monopolar_triangulation", load_if_exists=True)
spike_locations = spost.compute_spike_locations(we, method="center_of_mass", load_if_exists=True,
                                                **job_kwargs)

In [None]:
sw.plot_unit_locations(we, backend="ipywidgets")

In [None]:
sw.plot_spike_locations(we, max_spikes_per_unit=300, backend="ipywidgets")

#### Compute correlograms

In [None]:
ccgs, bins = spost.compute_correlograms(we)

In [None]:
sw.plot_autocorrelograms(we, unit_ids=sorting.unit_ids[:3])
sw.plot_crosscorrelograms(we, unit_ids=sorting.unit_ids[:3])

#### Compute template similarity

In [None]:
similarity = spost.compute_template_similarity(we)

#### Compute template metrics
For more information about these template metrics, we refer to this [documentation](https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms) from the Allen Institute.

In [None]:
print(spost.get_template_metric_names())

In [None]:
template_metrics = spost.calculate_template_metrics(we)
display(template_metrics)

In [None]:
sw.plot_template_metrics(we, include_metrics=["peak_to_valley", "half_width"], 
                         backend="ipywidgets")

In [None]:
!ls waveforms

#### Quality metrics and curation
For more information about these quality metrics, we refer to the [SpikeInterface documentation](https://spikeinterface.readthedocs.io/en/latest/module_qualitymetrics.html) and to this excellent [documentation](https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html) from the Allen Institute.

In [None]:
print(sqm.get_quality_metric_list())
print(sqm.get_quality_pca_metric_list())

In [None]:
qm = sqm.compute_quality_metrics(we, sparsity=sparsity_radius, verbose=True, 
                                 n_jobs=job_kwargs["n_jobs"])

In [None]:
display(qm)

In [None]:
sw.plot_quality_metrics(we, include_metrics=["amplitude_cutoff", "presence_ratio", "isi_violations_ratio", "snr"], 
                        backend="ipywidgets")

#### Automatic curation based on quality metrics

In [None]:
isi_viol_thresh = 0.2
amp_cutoff_thresh = 0.1

In [None]:
our_query = f"amplitude_cutoff < {amp_cutoff_thresh} & isi_violations_rate < {isi_viol_thresh}"
print(our_query)

In [None]:
keep_units = qm.query(our_query)
keep_unit_ids = keep_units.index.values

In [None]:
sorting_auto_KS25 = sorting.select_units(keep_unit_ids)
print(f"Number of units before curation: {len(sorting.get_unit_ids())}")
print(f"Number of units after curation: {len(sorting_auto_KS25.get_unit_ids())}")

In [None]:
we.select_units?

In [None]:
we_curated = we.select_units(keep_unit_ids, new_folder="waveforms_cur")
# we_curated = we.select_units(keep_unit_ids, new_folder=None, use_relative_path="waveforms_curated")

In [None]:
print(we_curated)

In [None]:
we_curated.get_available_extension_names()

# 8. Viewers <a class="anchor" id="viewers"></a>


### Spike sorting comparison

#### Compare two sorters

In [None]:
comp_KS2_KS25 = sc.compare_two_sorters(sorting_KS2, sorting_KS25, 'KS2', 'KS25')

In [None]:
sw.plot_agreement_matrix(comp_KS2_KS25)

In [None]:
comp_KS2_KS25auto = sc.compare_two_sorters(sorting_KS2, sorting_auto_KS25, 'KS2', 'KS25_auto')

In [None]:
sw.plot_agreement_matrix(comp_KS2_KS25auto)

#### Compare multiple sorters

In [None]:
mcmp = sc.compare_multiple_sorters([sorting_KS2, sorting_KS25], ['KS2', 'KS25'], 
                                   spiketrain_mode='union', verbose=True)

In [None]:
w = sw.plot_multicomp_agreement(mcmp)
w = sw.plot_multicomp_agreement_by_sorter(mcmp)

In [None]:
sw.plot_multicomp_graph(mcmp, draw_labels=False)

In [None]:
agreement_sorting = mcmp.get_agreement_sorting(minimum_agreement_count=2)
print(agreement_sorting)

In [None]:
sw.plot_rasters(agreement_sorting)

In [None]:
# compare consensus and auto
comp_agr_auto = sc.compare_two_sorters(agreement_sorting, sorting_auto_KS25, 'AGR', 'AUTO')

In [None]:
sw.plot_agreement_matrix(comp_agr_auto)