# SpikeInterface pipeline for Tank Lab

In [None]:
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw

from nwb_conversion_tools.conversion_tools import save_si_object
from tank_lab_to_nwb import TowersProcessedNWBConverter

from isodate import duration_isoformat
from datetime import timedelta
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
%matplotlib notebook

## 1) Load AP recordings, LF recordings and TTL signals

In [None]:
# Location of the raw SpikeGLX session (switch the commented line for a different machine).
#base_data_path = Path("/Users/abuccino/Documents/Data/catalyst/brody/A256_bank1_2020_09_30_g0")
base_data_path = Path("D:/Neuropixels/Neuropixels/A256_bank1_2020_09_30/A256_bank1_2020_09_30_g0")
session_name = "A256_bank1_2020_09_30_g0_t0"

# The AP and LF binaries of probe imec0 share the same naming scheme, differing only in the band tag.
ap_bin_path, lf_bin_path = (
    base_data_path / f"{session_name}.imec0.{band}.bin" for band in ("ap", "lf")
)

# Derived output folders are created next to the raw binaries.
recording_folder = ap_bin_path.parent

### Make spikeinterface folders

In [None]:
# `spikeinterface` receives saved extractors and post-processing scratch folders;
# `working` is handed to the sorters as their working folder (not created here).
spikeinterface_folder, working_folder = (
    recording_folder / subfolder for subfolder in ("spikeinterface", "working")
)

spikeinterface_folder.mkdir(exist_ok=True, parents=True)

In [None]:
# For testing purposes, shorten the recording: only the first `stub_percent` percent
# of the frames is kept (set to 100 to use the whole session).
stub_percent = 2.5

base_recording = se.SpikeGLXRecordingExtractor(ap_bin_path)
recording_ap = se.SubRecordingExtractor(
    base_recording,
    end_frame=round(base_recording.get_num_frames() * stub_percent / 100),
)

In [None]:
# Same stub applied to the LF stream, using its own frame count.
base_lf = se.SpikeGLXRecordingExtractor(lf_bin_path)
recording_lf = se.SubRecordingExtractor(
    base_lf,
    end_frame=round(base_lf.get_num_frames() * stub_percent / 100),
)

In [None]:
# Report the sampling frequency of each (stubbed) stream.
for band_label, band_recording in (("AP", recording_ap), ("LF", recording_lf)):
    print(f"Sampling frequency {band_label}: {band_recording.get_sampling_frequency()}")

### Load TTL signals

In [None]:
# TTL event frames and their states from the AP stream; the frames whose
# state equals 1 are kept as the rising edges.
ttl, states = recording_ap.get_ttl_events()
rising_times = ttl[(states == 1)]

In [None]:
# The first TTL rising edge defines the session start that both streams are aligned to.
# Fail with an explicit message if the (possibly stubbed) AP segment contains none,
# rather than an opaque IndexError.
if len(rising_times) == 0:
    raise RuntimeError(
        "No TTL rising edges found in the AP recording; cannot determine the session start "
        "(try increasing stub_percent)."
    )
start_time = recording_ap.frame_to_time(rising_times[0])

In [None]:
# Convert the shared start time into a frame index on each stream's own timebase.
start_frame_ap, start_frame_lf = (
    int(band_recording.time_to_frame(start_time))
    for band_recording in (recording_ap, recording_lf)
)
for band_label, band_frame in (("AP", start_frame_ap), ("LF", start_frame_lf)):
    print(f"Start frame {band_label}: {band_frame}")

#### Synchronize recording

In [None]:
# Drop everything before the start frame on each stream, so both begin at the first TTL rising edge.
recording_ap_sync, recording_lf_sync = (
    se.SubRecordingExtractor(band_recording, start_frame=band_start)
    for band_recording, band_start in (
        (recording_ap, start_frame_ap),
        (recording_lf, start_frame_lf),
    )
)

## 2) Pre-processing

In [None]:
# Whether to apply common referencing to the synchronized AP stream (see next cell).
apply_cmr: bool = True

In [None]:
# Optionally common-reference the synchronized AP stream; otherwise use it unchanged.
recording_processed = (
    st.preprocessing.common_reference(recording_ap_sync) if apply_cmr else recording_ap_sync
)

## 3) Run spike sorters

In [None]:
# Sorters to run; uncomment an entry (and its install path below) to enable it.
sorter_list = [
    "ironclust",
    # "hdsort",
    # "kilosort",
    "waveclus",
]

# Register where each enabled sorter's source code is installed.
ss.IronClustSorter.set_ironclust_path("D:/GitHub/ironclust")
# ss.HDSortSorter.set_hdsort_path("D:/GitHub/HDsort")
# ss.KilosortSorter.set_kilosort_path("D:/GitHub/KiloSort")
ss.WaveClusSorter.set_waveclus_path("D:/GitHub/wave_clus")

In [None]:
# Inspect sorter-specific parameters and defaults
for sorter in sorter_list:
    print(f"{sorter} params description:")
    pprint(ss.get_params_description(sorter))
    print("Default params:")
    pprint(ss.get_default_params(sorter))    

In [None]:
# user-specific parameters
sorter_params = dict(
    ironclust=dict(),
    # hdsort=dict(),
    # kilosort=dict(),
    waveclus=dict()
)

In [None]:
# Run every enabled sorter on the synchronized (and optionally common-referenced) recording.
# Waveform extraction, quality metrics and curation below all use `recording_processed`,
# so the sorters must see that same recording. Sorting the raw `recording_ap` would offset
# every spike frame by `start_frame_ap` and include spikes from before the sync point.
# Results are keyed by (recording name, sorter name), e.g. ('rec0', 'ironclust').
sorting_outputs = ss.run_sorters(
    sorter_list=sorter_list,
    recording_dict_or_list=dict(rec0=recording_processed),
    working_folder=working_folder,
    sorter_params=sorter_params,
)

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# Start from spiketoolkit's common post-processing parameters and display them.
postprocessing_params = st.postprocessing.get_common_params()
pprint(
    postprocessing_params
)

### (optional) Manually set postprocessing parameters

In [None]:
# Cap the number of waveforms extracted per unit (with None, all waveforms are extracted).
postprocessing_params.update(max_spikes_per_unit=1000)

### Set quality metrics

In [None]:
# List every quality metric spiketoolkit can compute.
qc_list = st.validation.get_quality_metrics_list()
print("Available quality metrics:", qc_list)

### (optional) define a subset of quality metrics

In [None]:
# (optional) define subset of qc
# Restrict quality metrics to the ones used by the automatic curation step.
qc_list = ["snr", "isi_violation", "firing_rate"]

### Set extracellular features

In [None]:
# List every extracellular template feature spiketoolkit can compute.
ec_list = st.postprocessing.get_template_features_list()
print("Available EC features:", ec_list)

### (optional) define a subset of extracellular features

In [None]:
# (optional) define subset of ec
# Only compute these two template features.
ec_list = ["peak_to_valley", "halfwidth"]

### Postprocess all sorting outputs

In [None]:
# Post-process every sorter's output against the processed recording.
# Results are kept per sorter (keyed by sorter name) instead of being overwritten on
# each iteration. `waveforms`, `templates`, `ec` and `qc` still hold the last sorter's
# results after the loop, as before.
waveforms_by_sorter = {}
templates_by_sorter = {}
ec_by_sorter = {}
qc_by_sorter = {}
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    tmp_folder = spikeinterface_folder / 'tmp' / sorter
    # exist_ok=True lets the cell be re-run without a FileExistsError
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # set local tmp folder
    sorting.set_tmp_folder(tmp_folder)

    # compute waveforms
    waveforms = st.postprocessing.get_unit_waveforms(recording_processed, sorting, **postprocessing_params)

    # compute templates
    templates = st.postprocessing.get_unit_templates(recording_processed, sorting, **postprocessing_params)

    # compute EC features (as a dataframe)
    ec = st.postprocessing.compute_unit_template_features(recording_processed, sorting,
                                                          feature_names=ec_list, as_dataframe=True)
    # compute QCs (as a dataframe)
    qc = st.validation.compute_quality_metrics(sorting, recording=recording_processed,
                                               metric_names=qc_list, as_dataframe=True)

    waveforms_by_sorter[sorter] = waveforms
    templates_by_sorter[sorter] = templates
    ec_by_sorter[sorter] = ec
    qc_by_sorter[sorter] = qc

## 5) Ensemble spike sorting

In [None]:
# retrieve sortings and sorter names
sorting_list = []
sorter_names_comp = []
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    sorting_list.append(sorting)
    sorter_names_comp.append(sorter)

In [None]:
# run multisorting comparison
# Match units across the outputs of all sorters.
mcmp = sc.compare_multiple_sorters(
    sorting_list=sorting_list,
    name_list=sorter_names_comp,
)

In [None]:
# extract ensemble sorting
# Build the ensemble sorting from units agreed on by at least `minimum_agreement_count` sorters.
sorting_ensemble = mcmp.get_agreement_sorting(
    minimum_agreement_count=1,
)

In [None]:
# plot agreement results
# Visualize the multi-sorter agreement; keep the widget handle for later tweaking.
w_agr = sw.plot_multicomp_agreement(
    mcmp
)

## 6) Automatic curation

#### Define thresholds

In [None]:
# Thresholds used by the automatic curation step (ISI violation, SNR, firing rate).
isi_violation_threshold, snr_threshold, firing_rate_threshold = 0.5, 5, 0.1

#### Run curation

In [None]:
# Automatically curate each sorter's output by applying the three thresholds in sequence.
# Each step filters the output of the previous step, so a unit survives only if it passes
# all three. Previously every step started from the uncurated `sorting`, which meant only
# the last (SNR) filter had any effect.
# Per spiketoolkit's `threshold_sign` convention, 'less' removes units whose metric is
# below the threshold and 'greater' removes units whose metric is above it.
sorting_auto_curated = []
sorter_names_curation = []
n_frames_processed = recording_processed.get_num_frames()
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    sorter_names_curation.append(sorter)

    # firing rate threshold
    sorting_curated = st.curation.threshold_firing_rates(sorting,
                                                         duration_in_frames=n_frames_processed,
                                                         threshold=firing_rate_threshold,
                                                         threshold_sign='less')

    # isi violation threshold (applied to the firing-rate-curated sorting)
    sorting_curated = st.curation.threshold_isi_violations(sorting_curated,
                                                           duration_in_frames=n_frames_processed,
                                                           threshold=isi_violation_threshold,
                                                           threshold_sign='greater')

    # snr threshold (applied to the ISI-curated sorting)
    sorting_curated = st.curation.threshold_snrs(sorting_curated,
                                                 recording=recording_processed,
                                                 threshold=snr_threshold,
                                                 threshold_sign='less')
    sorting_auto_curated.append(sorting_curated)

## 7) Save outputs in spikeinterface folder

In [None]:
# Every object is saved with the same options: no raw-data cache, properties kept,
# features skipped.
save_kwargs = dict(cache_raw=False, include_properties=True, include_features=False)

save_si_object("raw", recording_processed, spikeinterface_folder, **save_kwargs)
save_si_object("sorting_ensemble", sorting_ensemble, spikeinterface_folder, **save_kwargs)
save_si_object("sorter1", sorting_list[0], spikeinterface_folder, **save_kwargs)

## 8) Export to phy

In [None]:
# Export the ensemble sorting, together with the processed recording used for
# post-processing, to a phy folder inside the spikeinterface folder for manual curation.
phy_folder = spikeinterface_folder / "phy"
st.postprocessing.export_to_phy(recording_processed, sorting_ensemble, output_folder=phy_folder)