# SpikeInterface v0.101.2 - Adapted by Rodrigo Noseda - October 2024

This notebook uses SpikeInterface to analyze a multichannel dataset recorded with Cambridge NeuroTech probes.
The data were acquired with an Open Ephys DAQ and Bonsai-rx and saved as raw binary (.bin) files.
Event timestamps still need some work.

# 0. Preparation <a class="anchor" id="preparation"></a>

In [None]:
# Import the full SpikeInterface API under the alias `si` and report the installed
# version (this notebook was adapted for SpikeInterface v0.101.2).
import spikeinterface.full as si
print(f"SpikeInterface Version: {si.__version__}")

In [2]:
# Generic scientific / plotting / file utilities used throughout the notebook.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import os
import csv
import glob
from datetime import datetime

import warnings
# NOTE(review): this silences ALL warnings for the whole session, including
# deprecation warnings about renamed SpikeInterface arguments; consider a narrower
# filter while maintaining the pipeline.
warnings.simplefilter("ignore")
# Interactive matplotlib backend (IPython magic); use the inline backend for static figures.
%matplotlib widget
#%matplotlib inline

# 1. Loading recording and probe information <a class="anchor" id="loading"></a>

In [None]:
# Setting file paths and basic parameters
base_folder = Path('D:/Ephys_C2DRG/')
data_folder = Path("D:/Ephys_C2DRG/2023_9_19/")  # session folder holding the RawEphysData*.bin files

# Collect this session's raw recording files. os.listdir() returns entries in an
# arbitrary order, and the files are later read and concatenated in list order,
# so sort them explicitly. The sort is lexicographic: it matches acquisition order
# only if the file names embed a sortable (e.g. zero-padded) timestamp -- verify.
recording_paths_list = sorted(
    data_folder / filename
    for filename in os.listdir(data_folder)
    if filename.startswith('RawEphysData') and filename.endswith('.bin')
)
# Fail here with a clear message instead of deep inside read_binary().
if not recording_paths_list:
    raise FileNotFoundError(f"No 'RawEphysData*.bin' files found in {data_folder}")
print('Recording Files List:')
print(recording_paths_list)

# parameters associated to the recording in bin format
num_channels = 64  # must be known a priori; keep consistent with the probe channel mapping.
fs = 30000  # sampling frequency passed to read_binary
gain_to_uV = 0.195
offset_to_uV = 0
rec_dtype = "float32"
time_axis = 0
time_format = "%H:%M:%S.%f"
n_jobs = -1  # -1: use as many jobs as there are cores.
job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)

In [4]:
# Read all raw .bin files with one read_binary call (presumably one segment per
# file), then concatenate those segments into a single continuous segment,
# because Kilosort is run on a single segment.
# `num_channels` replaces `num_chan`, which is deprecated in recent SpikeInterface
# releases (the deprecation warning is hidden by the global warning filter).
rec = si.read_binary(recording_paths_list, num_channels=num_channels, sampling_frequency=fs,
                     dtype=rec_dtype, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV,
                     time_axis=time_axis, is_filtered=False)
recordings_list = [rec]  # a single (multi-segment) recording
recording = si.concatenate_recordings(recordings_list)  # ConcatenateSegmentRecording

# 2. Preprocessing <a class="anchor" id="preprocessing"></a>

All preprocessing modules return new `RecordingExtractor` objects that apply the underlying preprocessing function. This allows users to access the preprocessed data in the same way as the raw data. Later in the notebook, once the probe is attached, we focus only on the first shank (group `0`).

In [6]:
# Band-pass filter (300-6000 Hz) the concatenated raw recording, then apply a
# global median common reference on top of the filtered traces.
recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
# Raw / filtered / filtered+CMR stages, compared side by side in the next cell.
# NOTE(review): artifact silencing and sorting below start from `recording_f`, not
# `recording_cmr`, so CMR is only used for this visual comparison -- confirm intended.
recording_layers = [recording, recording_f, recording_cmr]

In [None]:
# Interactive view of the three processing stages over the [390, 420] time window
# on four example channels (scaled values).
si.plot_traces(recording_layers, time_range=[390, 420], channel_ids=[8, 39, 45, 63],
                    return_scaled=True, show_channel_ids=True, backend="ipywidgets")

In [10]:
# Manually identified artifact windows. The `artifacts` values look like
# [start, stop] times in seconds (art1 matches the window plotted above).
# NOTE(review): confirm the units. `list_labels` and `artifacts` are not referenced
# anywhere else in this notebook; only `list_periods` is used.
list_labels = ['art1', 'art2', 'art3', 'art4']
artifacts = {'art1': [390.0, 420.7], 'art2': [5 , 45], 'art3': [569.2, 599.6], 'art4': [90, 91]}
# Period to silence, in sample frames: (start_frame, end_frame).
# At fs = 30000 these frames correspond to ~569.12-599.22 s: close to, but not exactly,
# 'art3' (569.2-599.6 s). NOTE(review): check whether they should be derived from
# artifacts['art3'] * fs.
list_periods = [(17073622, 17976622)]
# silence_periods documents "one list per segment of (start_frame, end_frame) tuples".
# NOTE(review): this is a flat list of tuples for the single concatenated segment;
# verify silence_periods accepts this form for a one-segment recording.

In [82]:

# Replace the frames in `list_periods` with zeros on the band-passed recording
# (`recording_f`, i.e. without common referencing).
recording_clean = si.silence_periods(recording_f, list_periods=list_periods, seed=0, mode='zeros')
# Alternative (disabled): remove artifacts around trigger frames with linear interpolation.
# NOTE(review): `triggers_in_frames` is not defined anywhere in this notebook.
#recording_clean = si.remove_artifacts(recording_f, list_triggers=triggers_in_frames, ms_before=14, ms_after=1, mode='linear')

In [None]:
# Write the cleaned (filtered + silenced) recording to disk in binary format.
# Keep the returned disk-backed recording: later cells refer to `recording_f_clean`,
# and reusing it avoids recomputing the filter/silencing chain.
recording_f_clean = recording_clean.save(format='binary', folder=data_folder / "recording_f_clean", **job_kwargs)

In [None]:
# Load the Cambridge NeuroTech probe from the probeinterface library and wire its
# contacts to the acquisition (Intan) channels.
import probeinterface as pi
from probeinterface.plotting import plot_probe
print(f"ProbeInterface version: {pi.__version__}")
manufacturer = 'cambridgeneurotech'
probe_name = 'ASSY-158-H10'  # other probes used before: 'ASSY-158-F', 'ASSY-158-H6'
# The library probe already comes with contact_ids and shank_ids.
probeH10 = pi.get_probe(manufacturer, probe_name)

# Intan channel order for one 32-channel bank. Every further bank repeats the same
# pattern shifted by 32 channels, so two banks give the 64-channel map
# (a 128-channel setup would use four banks, i.e. offsets +64 and +96).
_intan_bank_order = [24, 23, 25, 22, 26, 21, 27, 20, 28, 19, 29, 18, 30, 17, 31, 16,
                     0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8]
device_channel_indices = [
    channel + 32 * bank
    for bank in range(2)  # Modify accordingly.
    for channel in _intan_bank_order
]
# Assign the Intan (RHD-2132/2164) channels to the probe contacts.
probeH10.set_device_channel_indices(device_channel_indices)

In [None]:
# Alternative (disabled): reload cleaned traces that were cached to disk earlier.
#filename = 'traces_cached_seg0.raw'
#recording_f_clean = data_folder / 'recording_f_clean2' / filename
#recording_loaded = si.read_binary(recording_f_clean, num_chan=num_channels,sampling_frequency=fs,
#                           dtype=rec_dtype, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV, 
#                           time_axis=time_axis, is_filtered=True)

# Attach the wired probe and assign one channel group per shank.
recording_clean_prb = recording_clean.set_probe(probeH10, group_mode="by_shank")
# One sub-recording per group; keep only the first shank (group 0).
recordings_by_group = recording_clean_prb.split_by("group")
recording_to_process = recordings_by_group[0]
# Sort only the first 600 s of that shank.
recording_to_process_short = recording_to_process.time_slice(start_time=0, end_time=600)
print(recording_to_process_short)

In [None]:
# Inspect the cleaned shank-0 traces over the whole [0, 600] slice on three channels.
# NOTE(review): the channel ids must exist in the shank-0 sub-recording; confirm 8/45/63 are on shank 0.
si.plot_traces(recording_to_process_short, time_range=[0, 600], channel_ids=[8, 45, 63],
                    return_scaled=True, show_channel_ids=True, backend="ipywidgets")

# 3. Spike sorting <a class="anchor" id="spike-sorting"></a>

In [None]:
from pprint import pprint
# Start from SpikeInterface's Kilosort4 defaults and override selected entries.
default_KS4_params = si.get_default_sorter_params('kilosort4')
default_KS4_params.update(
    batch_size=60000,  # 2 s at 30 kHz
    nblocks=0,
    nearest_templates=32,
    skip_kilosort_preprocessing=True,
    do_correction=False,
)
# Other overrides that have been tried (currently disabled):
#   Th_universal=8, Th_learned=6, nearest_chans=8, artifact_threshold=20,
#   dmin=30, dminx=30, min_template_size=15, scale=5, duplicate_spike_ms=0.25
pprint(default_KS4_params)

In [None]:
# Run Kilosort4 (inside its docker image) on the first 600 s of shank 0.
# The parameter dict prepared in the previous cell must be unpacked here.
# Without it the sorter runs with its defaults, and the overrides (batch_size,
# nblocks, skip_kilosort_preprocessing, do_correction, ...) are silently ignored.
sorting_KS4_s0 = si.run_sorter('kilosort4', recording_to_process_short, folder=data_folder / 'sorting_KS4_shank0d',
                               docker_image=True, verbose=True, **default_KS4_params)

In [None]:
print(sorting_KS4_s0)  # summary of the Kilosort4 sorting result

In [None]:
w_rs = si.plot_rasters(sorting_KS4_s0, time_range=(0, 600), backend='matplotlib')  # raster of all units over the sorted 0-600 span

# 6. Postprocessing: SortingAnalyzer <a class="anchor" id="sortinganalyzer"></a>

The core module uses `SortingAnalyzer` for postprocessing computation from paired recording-sorting objects. It retrieves waveforms, templates, spike amplitudes, etc.

In [None]:
# Estimate a per-unit channel sparsity mask from a subsample of spikes
# ("radius" method, radius_um=100, negative peaks).
# The recording must be the one the sorter ran on: `recording_to_process_short`
# (shank 0, first 600 s, with the probe attached, which the distance-based radius
# method needs). Keyword arguments keep the call independent of the
# (sorting, recording) argument order.
sparsity = si.estimate_sparsity(sorting=sorting_KS4_s0, recording=recording_to_process_short,
                                num_spikes_for_sparsity=200, method="radius",
                                radius_um=100, peak_sign="neg", amplitude_mode="extremum")
print(sparsity)
#for unit_id in sparsity.unit_ids[::30]:
#    print(unit_id, list(sparsity.unit_id_to_channel_ids[unit_id]))
# Most plotting, computation and export functions use 'sparsity' in the background.

In [44]:
# Pair the sorting with the exact recording it was computed from (shank 0, first
# 600 s) and store the analyzer as a binary folder on disk, so that extensions
# computed later are saved there as well.
sa = si.create_sorting_analyzer(sorting=sorting_KS4_s0, recording=recording_to_process_short,
                                folder=data_folder / "sorting_analyzer_KS4_s0",
                                format="binary_folder", sparsity=sparsity, overwrite=True, **job_kwargs)
#Saving Analyzer in specific format and loading it from saved
#sa.save_as(format="zarr",folder=data_folder / "sorting_analyzer")

#### Computing Extensions: PCA, waveforms, templates, spike amplitude, correlograms, etc.

Let's move on to explore the capabilities of the `postprocessing` module. With a `SortingAnalyzer`, the `compute` method calculates each extension on demand.

In [None]:
# List the extension names that can be computed on this analyzer.
all_computable_extensions = sa.get_computable_extensions()
print(all_computable_extensions)

In [None]:
#SortingAnalyzer computations: each call recomputes and overwrites that extension's previous result.
# Each sa.compute(...) returns the extension object. The order matters because several
# extensions build on earlier ones (see the per-line notes).
rand = sa.compute("random_spikes", method="uniform", max_spikes_per_unit=500)#subsample of spikes (<=500 per unit) used to build templates
wf = sa.compute("waveforms", ms_before=1, ms_after=2, **job_kwargs)#1 ms before / 2 ms after each sampled spike
templ =sa.compute("templates", operators=["average", "median", "std"])#from raw waveforms or random_spikes
spk_amp = sa.compute("spike_amplitudes", peak_sign="neg")#based on templates
noise = sa.compute("noise_levels")
amp_scal = sa.compute("amplitude_scalings")#NOTE(review): original note said "per channel" -- confirm what the scaling is computed over
pca = sa.compute("principal_components", n_components=3, mode="by_channel_local")
corr = sa.compute("correlograms", window_ms=50.0, bin_ms=1.0, method="auto")
isi = sa.compute("isi_histograms", window_ms=50.0, bin_ms=1.0, method="auto")
spk_loc = sa.compute("spike_locations", method="center_of_mass")#needed for drift metrics (drift_ptp, drift_std, drift_mad)
templ_sim = sa.compute("template_similarity")#needed for spikeinterface_gui
u_loc = sa.compute("unit_locations", method="monopolar_triangulation")
templ_metric = sa.compute("template_metrics")
qm = sa.compute("quality_metrics")

In [None]:
# Recompute a sparsity mask from the analyzer's templates (radius method, 100 um,
# negative peaks), e.g. to compare with the spike-based estimate used above.
# compute_sparsity's second positional parameter is `noise_levels` (only relevant
# to SNR-based methods), so a recording must not be passed positionally there.
sparsity2 = si.compute_sparsity(sa, method="radius", radius_um=100, peak_sign="neg")
print(sparsity2)

Extensions are generally saved in two ways: 

sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")

sorting_analyzer.save_as(folder="my_sorting_analyzer")
sorting_analyzer.compute("random_spikes", save=True)

Here the random_spikes extension is not saved. The sorting_analyzer is still saved in memory. The save_as method only made a snapshot of the sorting analyzer which is saved in a folder. This is useful when trying out different parameters and initially setting up your pipeline. If we wanted to save the extension we should have started with a non-memory sorting analyzer:

sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="binary_folder", folder="my_sorting_analyzer")
sorting_analyzer.compute("random_spikes", save=True)

NOTE: We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until you’d like to save. A mixture can lead to unexpected behavior.

# 7. Quality Metrics <a class="anchor" id="qualitymetrics"></a>

#### Metrics for Spikes

In [None]:
# Show the default quality-metric parameters and the available metric names.
# A notebook cell only displays its last expression, so print both explicitly.
from pprint import pprint
pprint(si.get_default_qm_params())
print(si.get_quality_metric_list())

In [21]:
#Amplitude cutoff (approximate fraction of missing spikes)
#Needs "spike_amplitudes"
fraction_missing = si.compute_amplitude_cutoffs(sa, peak_sign="neg")

#Amplitude CV (coefficient of variation)
#Needs "spike_amplitudes" or "amplitude_scalings" pre-computed.
amplitude_cv_median, amplitude_cv_range = si.compute_amplitude_cv_metrics(sa)
#dicts: unit ids as keys, amplitude_cv metrics as values.

#Drift metrics
#Needs "spike_locations"
drift_ptps, drift_stds, drift_mads = si.compute_drift_metrics(sa)
#dicts: unit ids as keys, drift metrics as values.

#Firing range (values outside the physiological range may indicate noise contamination)
firing_range = si.compute_firing_ranges(sa)
#dict: unit IDs as keys, firing_range as values (in Hz).

#Firing rate (average number of spikes per second over the recording)
firing_rate = si.compute_firing_rates(sa)
#dict or float: unit IDs as keys, firing rates across segments as values (in Hz).

#Inter-spike-interval (ISI) violations (rate of refractory-period violations; 1.5 ms threshold here)
isi_violations_ratio, isi_violations_count = si.compute_isi_violations(sa, isi_threshold_ms=1.5) 
#dicts: unit ids as keys; ISI violation ratio and number of violations as values.

#Presence ratio (proportion of discrete time bins in which at least one spike occurred)
presence_ratio = si.compute_presence_ratios(sa)
#dict: unit IDs as keys, presence ratio (between 0 and 1) as values.
#Close to or > 0.9 = complete units.
#Close to 0 = incompleteness (type II error) or a highly selective firing pattern.

#Standard deviation (SD) ratio
sd_ratio = si.compute_sd_ratio(sa, censored_period_ms=4.0)
#Close to 1 = unit from a single neuron.

#Signal-to-noise ratio (SNR)
SNRs = si.compute_snrs(sa)
#dict: unit IDs as keys, SNRs as values.
#High SNR = likely to correspond to a neuron. Low SNR = contaminated unit.

#Synchrony metrics (characterize synchronous events within the same spike train and across different spike trains)
synchrony = si.compute_synchrony_metrics(sa, synchrony_sizes=(2, 4, 8))
#tuple of dicts with the synchrony metrics for each unit.

#### Metrics for Clusters

In [None]:
si.get_quality_pca_metric_list()  # names of the PCA-based (cluster) quality metrics

In [None]:
# PCA-based cluster metrics for ONE unit, using the low-level functions of
# spikeinterface.qualitymetrics.pca_metrics. Their inputs (all_pcs, all_labels,
# this_unit_id, max_spikes, n_neighbors) were never defined in this notebook, so
# build them here from the "principal_components" extension (`pca`).
from spikeinterface.qualitymetrics.pca_metrics import (
    lda_metrics,
    mahalanobis_metrics,
    nearest_neighbors_isolation,
    nearest_neighbors_metrics,
    nearest_neighbors_noise_overlap,
    simplified_silhouette_score,
)

this_unit_id = sa.unit_ids[0]
max_spikes = 10000  # cap on spikes used by the nearest-neighbour search
n_neighbors = 5  # neighbours considered per spike

# Projections of all sampled spikes restricted to this unit's sparse channels.
# Assumes get_all_projections returns (num_spikes, n_components, num_channels);
# flatten to a 2D (spikes x features) matrix -- TODO confirm.
# NOTE(review): compute_quality_metrics also restricts to neighbouring units, so the
# values may differ from those in the "quality_metrics" extension.
unit_channel_ids = sa.sparsity.unit_id_to_channel_ids[this_unit_id]
all_labels, all_pcs = pca.get_all_projections(channel_ids=unit_channel_ids)
all_pcs = all_pcs.reshape(all_pcs.shape[0], -1)

#Isolation Distance (distance from a cluster to the nearest other cluster) and L-ratio
iso_distance, l_ratio = mahalanobis_metrics(all_pcs, all_labels, this_unit_id)

#Nearest Neighbor Metrics (evaluate unit quality)
#Unit contamination based on NearestNeighbors search in PCA space.
nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_neighbors)
#Unit isolation based on NearestNeighbors search in PCA space (needs the unit id as well).
nn_isolation, nn_unit_id = nearest_neighbors_isolation(sa, this_unit_id)
#Unit noise overlap based on NearestNeighbors search in PCA space.
nn_noise_overlap = nearest_neighbors_noise_overlap(sa, this_unit_id)

#D-prime (estimate the classification accuracy between two units)
d_prime = lda_metrics(all_pcs, all_labels, this_unit_id)
#returns a float (larger in well separated clusters)

#Silhouette score (ratio between the cohesiveness of a cluster and its separation from other clusters)
simple_sil_score = simplified_silhouette_score(all_pcs, all_labels, this_unit_id)
#Close to 1 = good clustering. Close to -1 = poorly isolated cluster.

A straightforward way to filter a pandas DataFrame is with its `query` method.
We first define the query (make sure the names match the DataFrame's column names)
and then use it to select units:

In [None]:
#Automatic curation based on parameters
isi_viol_thresh = 0.5
amp_cutoff_thresh = 0.1

# Quality metrics computed on the analyzer ("quality_metrics" extension), as a
# DataFrame indexed by unit id. `df` was previously never defined in this notebook.
df = sa.get_extension("quality_metrics").get_data()

#our_query = f"amplitude_cutoff < {amp_cutoff_thresh} & isi_violations_ratio < {isi_viol_thresh}"
our_query = f"isi_violations_ratio < {isi_viol_thresh}"
print(our_query)

keep_units = df.query(our_query)
keep_unit_ids = keep_units.index.values  # the index holds the unit ids
keep_unit_ids

In [None]:
# Keep only the units that passed the automatic curation query and report the
# unit counts before and after.
sorting_auto_KS4 = sorting_KS4_s0.select_units(keep_unit_ids)
n_units_before = len(sorting_KS4_s0.get_unit_ids())
n_units_after = len(sorting_auto_KS4.get_unit_ids())
print(f"Number of units before curation: {n_units_before}")
print(f"Number of units after curation: {n_units_after}")

In [None]:
# Restrict the analyzer to the units kept by the automatic curation.
sa_curated = sa.select_units(keep_unit_ids)

# Save a copy in zarr format; save_as returns the saved analyzer, which is printed.
sa_curated_saved = sa_curated.save_as(format="zarr", folder=data_folder / "sorting_analyzer_curated.zarr")
print(sa_curated_saved)

# 8. Viewers <a class="anchor" id="viewers"></a>
### SpikeInterface GUI
Can be run directly in a terminal with: 
sigui /path/to/analyzer

In [None]:
# Enable the Qt event loop integration (IPython magic), then open spikeinterface-gui
# on the full analyzer with curation enabled.
%gui qt
si.plot_sorting_summary(sorting_analyzer=sa, curation=True, backend='spikeinterface_gui')

In [None]:
# Views of the computed extensions; the trailing comment names the extension
# variable each plot relies on.
# NOTE(review): unit_ids=[0, 1, 2] assumes these unit ids exist in `sa`.
templates = si.plot_unit_templates(sa, backend="ipywidgets")#templ
waveforms = si.plot_unit_waveforms(sa, backend="ipywidgets")#wf
unit_locations = si.plot_unit_locations(sa, backend="ipywidgets")#u_loc
spk_locations = si.plot_spike_locations(sa, backend="ipywidgets")#spk_loc
spk_amplitude = si.plot_amplitudes(sa, backend="ipywidgets")#spk_amp
template_similarity = si.plot_template_similarity(sa)#, backend="ipywidgets")#templ_sim
autocorr = si.plot_autocorrelograms(sa, unit_ids=[0, 1 ,2])#, backend="ipywidgets")#corr
crosscorr = si.plot_crosscorrelograms(sa, unit_ids=[0, 1, 2])#, backend="ipywidgets")#corr

### SortingView
Web-based, shareable (with link), sorter visualizer.

In [None]:
# One-time initialization (alternate method) of the kachery-cloud client used by
# the sortingview backend.
import kachery_cloud as kcl
kcl.init()

# Follow the instructions to associate the client with your GitHub user on the kachery-cloud network

In [None]:
sv = si.plot_sorting_summary(sa, curation=True, backend='sortingview')  # web-based summary with curation

# 9. Exporters <a class="anchor" id="exporters"></a>
#### Export to Phy for manual curation [Phy](https://github.com/cortex-lab/phy). 

In [None]:
# Export the analyzer to a Phy folder for manual curation. This copies the traces
# (as float32), computes PC features and amplitudes, and includes quality and
# template metrics. template_mode='median' relies on the "median" template
# operator computed earlier.
si.export_to_phy(sa, output_folder=data_folder / 'phy_KS4_RN_s0_clean5_full', compute_pc_features=True,
                   copy_binary=True, dtype='float32', compute_amplitudes=True,
                   sparsity=sparsity, add_quality_metrics=True, add_template_metrics=True, 
                   template_mode='median', verbose=True,**job_kwargs)

After curating the results we can reload it using the `PhySortingExtractor` and exclude the units that we labeled as `noise`:

In [None]:
# Reload the Phy-curated sorting, dropping clusters labeled "noise" in Phy.
# Read from the same folder the Phy export wrote to (this notebook never writes
# 'phy_KS4_RN/'), and compare against the Kilosort4 result `sorting_KS4_s0`
# (`sorting_KS4` is not defined in this notebook).
phy_folder = data_folder / 'phy_KS4_RN_s0_clean5_full'
sorting_phy_curated = si.read_phy(phy_folder, exclude_cluster_groups=['noise'])
print(f"Number of units before curation: {len(sorting_KS4_s0.get_unit_ids())}")
print(f"Number of units after curation: {len(sorting_phy_curated.get_unit_ids())}")
# TODO: save the loaded curated Phy result as a SpikeInterface object.
#si.export_report(sa)

In [None]:
# Alternative launch of spikeinterface-gui without the %gui magic: create the Qt
# application, open the main window on the analyzer (curation enabled), and run
# the Qt event loop.
import spikeinterface_gui
app = spikeinterface_gui.mkQApp() 
win = spikeinterface_gui.MainWindow(sa, curation=True)
win.show()
app.exec_()