# SpikeInterface v0.101.1 - Adapted by Rodrigo Noseda - October 2024

SpikeInterface to analyze a multichannel dataset from Cambridge Neurotech Probes. 
The dataset is extracted using open-ephys DAQ and Bonsai-rx (in .bin).
Event_timestamps need some work.

# 0. Preparation <a class="anchor" id="preparation"></a>

In [11]:
import spikeinterface.full as si
print(f"SpikeInterface Version: {si.__version__}")

SpikeInterface Version: 0.101.1


In [2]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os
import csv
from datetime import datetime

import warnings
warnings.simplefilter("ignore")
%matplotlib widget
#%matplotlib inline

# 1. Loading recording and probe information <a class="anchor" id="loading"></a>

In [3]:
# file paths
base_folder = Path('D:/Ephys_C2DRG/')
data_folder = Path("D:/Ephys_C2DRG/2023_9_19/")
#Pasted directly from explorer "C:\Users\rodri\Documents\Bonsai-RN\Bonsai_DataRN\2023_3_21\"

recording_paths_list = []
for filename in os.listdir(data_folder):
    if filename.startswith('RawEphysData') and filename.endswith('.bin'):
        recording_paths_list.append(data_folder / filename)

print('Recording Files List:')
print(recording_paths_list)
n_files = len(recording_paths_list)  
#Dinamically create 'recording_n' variables
for i in range(n_files):
    globals()[f'recording{i}'] = f"{recording_paths_list[i]}" 

# parameters associated to the bin format
num_channels = 64 #must know apriori; modify in probe below accordingly.
fs = 30000
gain_to_uV = 0.195
offset_to_uV = 0
time_format = "%H:%M:%S.%f"
dtype = "float32"
time_axis = 0     

Recording Files List:
[WindowsPath('D:/Ephys_C2DRG/2023_9_19/RawEphysData_32Ch_ProbeF_Broken_0.bin'), WindowsPath('D:/Ephys_C2DRG/2023_9_19/RawEphysData_32Ch_ProbeF_Broken_1.bin')]


In [None]:
recordings_list = []
rec = si.read_binary(recording_paths_list, num_chan=num_channels,sampling_frequency=fs,
                           dtype=dtype, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV, 
                           time_axis=time_axis, is_filtered=False)
recordings_list.append(rec)
#Appending segments of recordings (Better because concatenate breaks timeline!)
recording = si.append_recordings(recordings_list)
print(recording)
for i in range(recording.get_num_segments()):
    s = recording.get_num_samples(segment_index=i)
    print(f"Segment {i}: num_samples {s}")

In [34]:
basetime_paths_list = []
for filename_t in os.listdir(data_folder):
    if filename_t.startswith('Timestamps') and filename_t.endswith('.csv'):
        basetime_paths_list.append(data_folder / filename_t)

print('BaseTime Files List:')
print(basetime_paths_list)  

n_files = len(basetime_paths_list)  
#Dinamically create 'basetime_n' variables
for i in range(n_files):
    csv_file = Path(f"basetime{i}")
print(csv_file)

BaseTime Files List:
[WindowsPath('D:/Ephys_C2DRG/2023_9_19/TimestampsEphys_0.csv'), WindowsPath('D:/Ephys_C2DRG/2023_9_19/TimestampsEphys_1.csv')]
basetime1


In [None]:
import probeinterface as pi
from probeinterface.plotting import plot_probe
print(f"ProbeInterface version: {pi.__version__}")
manufacturer = 'cambridgeneurotech'
probe_name = 'ASSY-158-H10' #probe_name = 'ASSY-158-F' #probe_name = 'ASSY-158-H6'

In [None]:
#probe object from library comes with contact_ids and shank_ids info.
probeH10 = pi.get_probe(manufacturer, probe_name)
#Intan mapping 64 channels
device_channel_indices = [24,23,25,22,26,21,27,20,28,19,29,18,30,17,31,16,0,15,1,14,2,13,3,12,4,11,5,10,6,9,7,8,
                56,55,57,54,58,53,59,52,60,51,61,50,62,49,63,48,32,47,33,46,34,45,35,44,36,43,37,42,38,41,39,40] #Modify accordingly.
#                88,87,89,86,90,85,91,84,92,83,93,82,94,81,95,80,64,79,65,78,66,77,67,76,68,75,69,74,70,73,71,72,
#                120,119,121,118,122,117,123,116,124,115,125,114,126,113,127,112,96,111,97,110,98,109,99,108,100,107,101,106,102,105,103,104]
#setting intan channels to probe
probeH10.set_device_channel_indices(device_channel_indices) #print(probeH10.device_channel_indices)
fig, ax = plt.subplots(figsize=(5, 10))
plot_probe(probeH10, ax=ax, with_contact_id=True, with_device_index=True,)
ax.set_xlim(-20, 200)
ax.set_ylim(-60, 300)

probeH10.to_dataframe(complete=True).loc[:, ["contact_ids", "shank_ids", "device_channel_indices"]]
#probeF64.to_dataframe(complete=True).loc[:, ["contact_ids", "shank_ids", "device_channel_indices"]]

The probe now is loaded with contact_ids, device_ids and shank_id.
A probe (prb) or `probeinterface` object can be loaded directly to a SI recording object. A group can also be formed from each shank - 'by_shank'

In [None]:
recording_prb = recording.set_probe(probeH10, group_mode="by_probe")
print(recording_prb)

In [None]:
A = recording_prb.get_channel_ids()
recording_slice = recording_prb.channel_slice(A[0:7]) #channel_ids = list(range(0, 4))
si.plot_traces(recording_slice, segment_index=0, channel_ids=None,
                          time_range=(0, 1.25), mode='line', backend='ipywidgets', 
                          show_channel_ids= True, clim=None)

# 2. Preprocessing <a class="anchor" id="preprocessing"></a>

All preprocessing modules return new `RecordingExtractor` objects that apply the underlying preprocessing function. This allows users to access the preprocessed data in the same way as the raw data. We will focus only on the first shank (group `0`) for now.

In [None]:
recordings_by_group = recording_prb.split_by("group")
recording_to_process = recordings_by_group[0]
recording_f = si.bandpass_filter(recording_to_process, freq_min=300, freq_max=6000)
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
#recording_removeart = si.RemoveArtifactsRecording(recording_prb, list_triggers=triggers, ms_before=0.5, ms_after=3)
recording_to_process

In [None]:
#events0 = np.array([10.11, 10.17], [10.12, 10.18])
#events = list(np.ndarray([[10, 17], [2, 20]]))#, dtype=float)]
events = np.array([(10, 11), (12, 13)], dtype='f4, f4')
#[[10.11, 10.17], [10.12, 10.18]])
#events = event_list.tolist()
channel_ids = list(range(0, 4))
w = si.plot_traces({"filtered": recording_f, "common": recording_cmr}, mode='line',
                   segment_index=0, channel_ids=channel_ids, show_channel_ids=True, events=events,
                   time_range=[10, 11], backend='ipywidgets')

In [151]:
times = recording_slice.get_times(segment_index=0)
last_time = times[-1]

In [None]:
# Step 1: Load start times from timestamps in CSV files
for filename_st in os.listdir(data_folder):
    if filename_st.startswith('Timestamps') and filename_st.endswith('.csv'):
        csv_file = data_folder / filename_st
time_format = "%H:%M:%S.%f"
with open(csv_file, 'r') as ft:
        reader = csv.reader(ft)
        time_str = next(reader)
        start_time = "".join(time_str)
        start_time = start_time[:15]
        start_time = datetime.strptime(start_time, time_format)
        
#good for structured ttl, but start time get only from second file        
# Step 2: Load events times from the CSV file, convert into seconds and structured np.array
def load_event_times(csv_file):
    event_times = []
    with open(csv_file, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            time_str = row[0].strip()  # Ensure no leading/trailing spaces
            event_times.append(time_str)
    return event_times

def convert_to_seconds(event_times):
    #base_time = datetime.strptime(start_time, time_format)  # First event
    events_seconds_array = []
    for time_str in event_times:
        # Limit to microseconds since datetime only supports up to 6 decimal places
        current_time = datetime.strptime(time_str[:15], time_format)
        delta_seconds = (current_time - start_time).total_seconds()
        events_seconds_array.append(delta_seconds)
    return np.array(events_seconds_array)


def create_structured_array(events_seconds_array):
    dtype = np.dtype([('time', np.float64)])  # Structured dtype
    events_structured_array = np.zeros(len(events_seconds_array), dtype=dtype)
    events_structured_array['time'] = events_seconds_array
    return events_structured_array

# Step 3: Load TTL times from the CSV file, convert into seconds and structured np.array
def load_ttl_times(csv_file):
    ttl_times = []
    with open(csv_file, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            time_str = row[0].strip()  # Ensure no leading/trailing spaces
            ttl_times.append(time_str)
    return ttl_times

def convert_to_seconds(ttl_times):
    #base_time = datetime.strptime(ttl_times[0][:15], time_format)  # First event
    ttl_seconds_array = []
    for time_str in ttl_times:
        # Limit to microseconds since datetime only supports up to 6 decimal places
        current_time = datetime.strptime(time_str[:15], time_format)
        delta_seconds = (current_time - start_time).total_seconds()
        ttl_seconds_array.append(delta_seconds)
    return np.array(ttl_seconds_array)

def create_structured_array(ttl_seconds_array):
    dtype = np.dtype([('time', np.float64)])  # Structured dtype
    ttl_structured_array = np.zeros(len(ttl_seconds_array), dtype=dtype)
    ttl_structured_array['time'] = ttl_seconds_array
    return ttl_structured_array

# Example usage
for filename_ev in os.listdir(data_folder):
    if filename_ev.startswith('Events') and filename_ev.endswith('.csv'):
        csv_file = data_folder / filename_ev
event_times = load_event_times(csv_file)  # Load time events

for filename_ttl in os.listdir(data_folder):
    if filename_ttl.startswith('TTL') and filename_ttl.endswith('.csv'):
        csv_file = data_folder / filename_ttl
ttl_times = load_ttl_times(csv_file)  # Load time TTL events

#start_seconds_array = convert_to_seconds(start_times)  # Convert to seconds
events_seconds_array = convert_to_seconds(event_times)  # Convert to seconds
ttl_seconds_array = convert_to_seconds(ttl_times)  # Convert to seconds

#start_structured_array = create_structured_array(start_seconds_array)  # Create structured array
events_structured_array = create_structured_array(events_seconds_array)  # Create structured array
ttl_structured_array = create_structured_array(ttl_seconds_array)  # Create structured array

# Display the result
#print(start_structured_array)
print(events_structured_array)
print(ttl_structured_array)


In [87]:
import os
import glob
import numpy as np
import pandas as pd
from datetime import datetime
#good for start time.
# Define the directory where your CSV files are located
csv_directory = data_folder  # Change this to the correct path
time_format = "%H:%M:%S.%f"
# Helper function to extract the first row from a CSV and convert to datetime
def extract_start_time(file_path):
    first_row = pd.read_csv(file_path, header=None).iloc[0, 0].strip()
    return datetime.strptime(first_row[:15], time_format).time()

# Find all timestamp files and extract their start times
timestamp_files = sorted(glob.glob(os.path.join(csv_directory, "Timestamps*.csv")))
start_times = [extract_start_time(file) for file in timestamp_files]

In [89]:
def time_to_datetime(time_str):
        return datetime.strptime(time_str[:15], '%H:%M:%S.%f')

start_times = []
ttl_times = []
event_times = []

# Process Timestamps files
for file in os.listdir(csv_directory):
    if file.startswith('Timestamps') and file.endswith('.csv'):
        df = pd.read_csv(os.path.join(csv_directory, file), header=None)
        first_row = df.iloc[0, 0]
        start_time = time_to_datetime(first_row)
        start_times.append(start_time)

# Process TTL files
for file in os.listdir(csv_directory):
    if file.startswith('TTL') and file.endswith('.csv'):
        df = pd.read_csv(os.path.join(csv_directory, file), header=None)
        ttl_time_list = df[0].apply(time_to_datetime).tolist()
        ttl_times.append(ttl_time_list)

In [95]:
# Helper function to extract ttl time events from a CSV and convert to datetime
def extract_ttl_time(file_path):
    ttl_times = pd.read_csv(file_path, header=None).iloc[0, 0].strip()
    return datetime.strptime(ttl_times[:15], time_format).time()
#Good but extract only first row
# Find all TTL files and extract their times
ttl_files = sorted(glob.glob(os.path.join(csv_directory, "TTL*.csv")))
ttl_times = [extract_ttl_time(file) for file in ttl_files]

In [None]:
# Function to calculate time differences and store in structured numpy array
def process_event_file(file_path, start_time):
    # Load the time events from the file
    time_events = pd.read_csv(file_path, header=None)[0].str.strip()
    
    # Convert the time events to datetime objects with error handling
    def parse_time(t):
        try:
            return datetime.strptime(t, "%H:%M:%S.%f")
        except ValueError:
            # Handle case where fractional seconds are missing
            return datetime.strptime(t, "%H:%M:%S")

    event_times = time_events.apply(parse_time)
    
    # Calculate the time differences in seconds
    time_diffs = (event_times - start_time).dt.total_seconds()
    
    # Create a structured array to store the time differences
    structured_array = np.zeros(len(time_diffs), dtype=[('time', 'f8')])
    structured_array['time'] = time_diffs.values
    return structured_array

# Process Event and TTL files
event_files = sorted(glob.glob(os.path.join(csv_directory, "Events*.csv")))
ttl_files = sorted(glob.glob(os.path.join(csv_directory, "TTL*.csv")))

# Ensure the number of timestamp files matches event and ttl files
assert len(timestamp_files) == len(event_files) == len(ttl_files), \
    "Mismatch in the number of timestamp, event, and TTL files."

# Store structured arrays for events and TTLs
events_structured_array = []
ttl_structured_array = []

for i in range(len(timestamp_files)):
    # Process events and TTL files with corresponding start times
    events_structured_array.append(process_event_file(event_files[i], start_times[i]))
    ttl_structured_array.append(process_event_file(ttl_files[i], start_times[i]))

# Example output (you can inspect the arrays or use them further)
print("Events Structured Array:")
for arr in events_structured_array:
    print(arr)

print("\nTTL Structured Array:")
for arr in ttl_structured_array:
    print(arr)


In [96]:
import os
import pandas as pd
from datetime import datetime

def extract_and_convert_csv_data(directory):
    # Initialize lists to store datetime objects
    start_times = []
    ttl_times = []
    event_times = []

    # Helper function to clean and convert time strings
    def convert_to_datetime(time_str):
        time_str = time_str.strip()  # Remove leading/trailing whitespace
        try:
            return datetime.strptime(time_str, "%H:%M:%S.%f")
        except ValueError:
            # If nanoseconds are truncated, try without %f
            return datetime.strptime(time_str, "%H:%M:%S")

    # Loop through all files in the directory
    for filename in os.listdir(directory):
        if filename.startswith("Timestamps"):
            # Extract the first row as a datetime object
            df = pd.read_csv(os.path.join(directory, filename), header=None)
            first_time = convert_to_datetime(df.iloc[0, 0])
            start_times.append(first_time)

        elif filename.startswith("TTL"):
            # Extract all rows as datetime objects
            df = pd.read_csv(os.path.join(directory, filename), header=None)
            times = [convert_to_datetime(time) for time in df[0]]
            ttl_times.append(times)

        elif filename.startswith("Events"):
            # Extract both columns as datetime objects
            df = pd.read_csv(os.path.join(directory, filename), header=None)
            event_time_pairs = [
                (convert_to_datetime(row[0]), convert_to_datetime(row[1]))
                for row in df.values
            ]
            event_times.append(event_time_pairs)

    return start_times, ttl_times, event_times

# Usage example
directory_path = data_folder  # Update with your directory path
start_times, ttl_times, event_times = extract_and_convert_csv_data(directory_path)

# Output the extracted data
print("Start Times:", start_times)
print("TTL Times:", ttl_times)
print("Event Times:", event_times)


ValueError: unconverted data remains: .5943552

In [None]:
import os
import pandas as pd
import numpy as np
from datetime import datetime

# Define the directory containing your CSV files
directory = data_folder

# Function to convert time strings to datetime objects, handling extra digits
def time_to_datetime(time_str):
    try:
        # Try parsing with standard microsecond precision
        return datetime.strptime(time_str, '%H:%M:%S.%f')
    except ValueError:
        # Handle extra digits by truncating to 6 decimal places
        truncated_time_str = time_str[:time_str.index('.') + 7]
        return datetime.strptime(truncated_time_str, '%H:%M:%S.%f')

# Initialize lists for storing data
start_times = []
ttl_times = []
event_times = []

# Process Timestamps files
for file in os.listdir(directory):
    if file.startswith('Timestamps') and file.endswith('.csv'):
        df = pd.read_csv(os.path.join(directory, file), header=None)
        first_row = df.iloc[0, 0]
        start_time = time_to_datetime(first_row)
        start_times.append(start_time)

# Process TTL files
for file in os.listdir(directory):
    if file.startswith('TTL') and file.endswith('.csv'):
        df = pd.read_csv(os.path.join(directory, file), header=None)
        ttl_time_list = df[0].apply(time_to_datetime).tolist()
        ttl_times.extend(ttl_time_list)

# Process Events files
for file in os.listdir(directory):
    if file.startswith('Events') and file.endswith('.csv'):
        df = pd.read_csv(os.path.join(directory, file), header=None)
        event_time_list = df.applymap(time_to_datetime).values.flatten().tolist()
        event_times.extend(event_time_list)

# Calculate time differences for TTL and Events relative to start times
ttl_diff_in_seconds = [
    (ttl_time - start_times[i % len(start_times)]).total_seconds()
    for i, ttl_time in enumerate(ttl_times)
]

event_diff_in_seconds = [
    (event_time - start_times[i % len(start_times)]).total_seconds()
    for i, event_time in enumerate(event_times)
]

# Create structured numpy arrays
ttl_structured_array = np.array(
    [(time,) for time in ttl_diff_in_seconds], dtype=[('time', 'f8')]
)

events_structured_array = np.array(
    [(time,) for time in event_diff_in_seconds], dtype=[('time', 'f8')]
)

# Optional: Display results for verification
print("Start Times:", start_times)
print("TTL Structured Array:")
print(ttl_structured_array)
print("Events Structured Array:")
print(events_structured_array)


In [None]:
channel_ids = si.event.channel_ids
num_channels = event.get_num_channels()
# get structured dtype for the first channel
event_dtype = event.get_dtype(channel_ids[0])
print(event_dtype)
# >>> dtype([('time', '<f8'), ('duration', '<f8'), ('label', '<U100')])

# retrieve events (with structured dtype)
events = event.get_events(channel_id=channel_ids[0], segment_index=0)
# retrieve event times
event_times = event.get_event_times(channel_id=channel_ids[0], segment_index=0)

## Take only 5 min. for demo

Since we are going to spike sort the data, let's first cut out a 5-minute recording, to speed up computations.

We can easily do so with the `frame_slice()` function:

In [None]:
recording_sub = recording_cmr.frame_slice(start_frame=0*fs, end_frame=300*fs)
print(recording_sub)

# 3. Saving and loading SpikeInterface objects <a class="anchor" id="save-load"></a>

All operations in SpikeInterface are *lazy*, meaning that they are not performed if not needed. This is why the creation of our filter recording was almost instantaneous. However, to speed up further processing, we might want to **save** it to a file and perform those operations (eg. filters, CMR, etc.) at once. 

Note: you can use the si.set_global_job_kwargs() to set job_kwargs globally for the entire session!

In [16]:
n_cpus = os.cpu_count()
n_jobs = n_cpus - 2 #n_jobs = -1 :equal to the number of cores.
job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)
#global_job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)
#si.set_global_job_kwargs(global_job_kwargs)

In [None]:
if (data_folder / "preprocessed").is_dir():
    recording_saved = si.load_extractor(data_folder / "preprocessed")
else:
    recording_saved = recording_cmr.save(folder=data_folder / "preprocessed", **job_kwargs)
    
print(recording_saved)
print(f'Cached channels ids:\n {recording_saved.get_channel_ids()}')
print(f'Channel groups after caching:\n {recording_saved.get_channel_groups()}')

The `traces_cached_seg0.raw` contains the processed raw data, while the `.json` files include information on how to reload the binary file. The `provenance.json` includes the information of the recording before saving it to a binary file, and the `probe.json` represents the probe object. The `save` returns a new *cached* recording that has all the previously loaded information: 

After saving the SI object, we can easily load it back in a new session:

In [None]:
recording_loaded = si.load_extractor(data_folder/"preprocessed")
print(f'Loaded channels ids: {recording_loaded.get_channel_ids()}')
print(f'Channel groups after loading: {recording_loaded.get_channel_groups()}')

We can double check that the traces are exactly the same as the `recording_saved` that we saved:

In [None]:
fig, axs = plt.subplots(ncols=2)
w_saved = si.plot_timeseries(recording_saved, ax=axs[0])
w_loaded = si.plot_timeseries(recording_loaded, ax=axs[1])
axs[0].set_title("Saved")
axs[1].set_title("Loaded")

**IMPORTANT**: the same saving mechanisms are available also for all SortingExtractor

# 5. Spike sorting <a class="anchor" id="spike-sorting"></a>

We can now run spike sorting on the above recording. We will use different spike sorters for this demonstration, to show how easy SpikeInterface makes it easy to interchengably run different sorters :)

Let's first check the installed sorters in `SpikeInterface` to see if `tridesclous` is available. Then we can then check the `tridesclous` default parameters.
We will sort the bandpass cached filtered recording the `recording_saved` object.

In [None]:
si.installed_sorters()

In [None]:
from pprint import pprint
default_KS4_params = si.get_default_sorter_params('kilosort4')
# Parameters can be changed by single arguments: 
#default_KS4_params['Th_universal'] = 9
#sorter_params = {'do_correction': False} #??
pprint(default_KS4_params)

In [None]:
si.run_sorter?

In [None]:
# run spike sorting on recording
#sorter_params = {'do_correction': False}
sorting_KS4 = si.run_sorter('kilosort4', recording_sub, 
                            output_folder=data_folder / 'results_KS4',
                            docker_image=True, verbose=True)#, **sorter_params, **job_kwargs)

In [None]:
sorting_KS4

In [None]:
sorting_saved_KS4 = sorting_KS4.save(folder=data_folder / "sorting_KS4")

In [None]:
sorting_loaded_KS4 = si.load_extractor(data_folder / "sorting_KS4")
sorting_loaded_KS4
#sorting_KS4 = si.read_sorter_folder(data_folder/"results_KS4")

We can use `spikewidgets` functions for some quick visualizations:

In [None]:
w_rs = si.plot_rasters(sorting_KS4, time_range=(0, 60), backend='matplotlib')

# 6. SortingAnalyzer <a class="anchor" id="sortinganalyzer"></a>

The core module uses `SortingAnalyzer` for postprocessing computation from paired recording-sorting objects. It retrieves waveforms, templates, spike amplitudes, etc.

In [None]:
si.create_sorting_analyzer?

In [None]:
sa = si.create_sorting_analyzer(sorting_KS4, recording_sub, folder=data_folder / "sorting_analyzer_3m", 
                              format="binary_folder", sparse=True, overwrite=True, **job_kwargs)

In [None]:
#Saving Analyzer in specific format and loading it from saved
#sa.save_as(format="zarr",folder=data_folder / "sorting_analyzer_3m")
#sa_bin = si.load_sorting_analyzer(folder=data_folder / "sorting_analyzer_3m")
#sa_zarr = si.load_sorting_analyzer(folder=data_folder / "sorting_analyzer_3m.zarr")

# 7. Postprocessing <a class="anchor" id="postprocessing"></a>

### Computing Extensions: PCA, waveforms, templates, spike amplitude, correlograms, etc.

Let's move on to explore the postprocessing capabilities of the `postprocessing` module. Similarly to the `SortingAnalizer` object, the method 'compute` retrieve info on demand.

In [None]:
all_computable_extensions = sa.get_computable_extensions()
print(all_computable_extensions)

In [None]:
#each call will recompute and overwrite previous computations
sa.compute("random_spikes")#subsample to create a template
wf = sa.compute("waveforms", ms_before=1.5, ms_after=2.5)
sa.compute("templates")#from raw waveforms or random_spikes
sa.compute("spike_amplitudes", peak_sign="neg")#based on templates
sa.compute("noise_levels")#per channel
sa.compute("principal_components", n_components=3, mode="by_channel_local")
sa.compute("correlograms", window_ms=50.0, bin_ms=1.0, method="auto")
sa.compute("isi_histograms", window_ms=50.0, bin_ms=1.0, method="auto")
sa.compute("spike_locations")#need for drift metrics (drift_ptp, drift_std, drift_mad)
sa.compute("unit_locations", "template_metrics", "quality_metrics")

Extensions are generally saved in two ways: 

sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")

sorting_analyzer.save_as(folder="my_sorting_analyzer")
sorting_analyzer.compute("random_spikes", save=True)

Here the random_spikes extension is not saved. The sorting_analyzer is still saved in memory. The save_as method only made a snapshot of the sorting analyzer which is saved in a folder. This is useful when trying out different parameters and initially setting up your pipeline. If we wanted to save the extension we should have started with a non-memory sorting analyzer:

sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="binary_folder", folder="my_sorting_analyzer")
sorting_analyzer.compute("random_spikes", save=True)

NOTE: We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until you’d like to save. A mixture can lead to unexpected behavior.

# 8. Quality Metrics <a class="anchor" id="qualitymetrics"></a>

#### Metrics for Spikes

In [None]:
#Amplitud cutoff (calculate the approximate fraction of missing spikes)
#Need "spike_amplitudes"
fraction_missing = si.compute_amplitude_cutoffs(sa, peak_sign="neg")

#Amplitud CV (coefficient of variation)
#Need "spike_amplitudes" or "amplitude_scalings" pre-computed.
amplitude_cv_median, amplitude_cv_range = si.compute_amplitude_cv_metrics(sa)
#dicts: unit ids as keys, and amplitude_cv metrics as values.

#Drift metrics
#Need "spike_locations"
drift_ptps, drift_stds, drift_mads = si.compute_drift_metrics(sa)
#dicts: unit ids as keys, and drifts metrics as values.

#Firing Range (outside of physiological range, might indicate noise contamination)
firing_range = si.compute_firing_ranges(sa)
#dict: unit IDs as keys, firing_range as values (in Hz).

#Firing Rate (average number of spikes/sec within the recording)
firing_rate = si.compute_firing_rates(sa)
#dict or floats: unit IDs as keys, firing rates across segments as values (in Hz).

#Inter-spike-interval (ISI) Violations (rate of refractory period violations)
isi_violations_ratio, isi_violations_count = si.compute_isi_violations(sa, isi_threshold_ms=1.0) 
#dicts: unit ids as keys, and isi ratio viol and number of viol as values.

#Presence Ratio (proportion of discrete time bins in which at least one spike occurred)
presence_ratio = si.compute_presence_ratios(sa)
#dict: unit IDs as keys, presence ratio (between 0 and 1) as values.
#Close or > 0.9 = complete units.
#Close to 0 = incompleteness (type II error) or highly selective firing pattern.

#Standard Deviation (SD) ratio
sd_ratio = si.compute_sd_ratio(sa, censored_period_ms=4.0)
#Close to 1 = unit from single neuron.

#Signal-to-noise ratio (SNR)
SNRs = si.compute_snrs(sa)
#dict: unit IDs as keys and their SNRs as values.
#High SNR = likely to correspond to a neuron. Low SNR = unit contaminated.

#Synchrony Metrics (characterize synchronous events within the same spike train and across different spike trains)
synchrony = si.compute_synchrony_metrics(sa, synchrony_sizes=(2, 4, 8))
#tuple of dicts with the synchrony metrics for each unit.

#### Metrics for Clusters

In [None]:
#Isolation Distance (distance from a cluster to the nearest other cluster)
iso_distance = si.pca_metrics.mahalanobis_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0)
#returns floats: iso_distance, l_ratio.

#Nearest Neighbor Metrics (evaluate unit quality)
si.pca_metrics.nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_neighbors)
#Calculate unit contamination based on NearestNeighbors search in PCA space.
si.pca_metrics.nearest_neighbors_isolation(sa)
#Calculate unit isolation based on NearestNeighbors search in PCA space.
si.pca_metrics.nearest_neighbors_noise_overlap(sa)
#Calculate unit noise overlap based on NearestNeighbors search in PCA space.

#D-prime (estimate the classification accuracy between two units)
d_prime = si.lda_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0)
#returns a float (larger in well separated clusters)

#Silhouette score (ratio between the cohesiveness of a cluster and its separation from other clusters)
simple_sil_score = si.simplified_silhouette_score(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0)
#Close to 1 = good clustering. Close to -1 = poorly isolated cluster.



A straightforward way to filter a pandas dataframe is via the `query`.
We first define our query (make sure the names match the column names of the dataframe):

In [None]:
our_query = f"amplitude_cutoff < {amp_cutoff_thresh} & isi_violations_ratio < {isi_viol_thresh}"
print(our_query)

and then we can use the query to select units:

In [None]:
keep_units = qm.query(our_query)
keep_unit_ids = keep_units.index.values

In [None]:
sorting_auto = sorting_SC2.select_units(keep_unit_ids)
print(f"Number of units before curation: {len(sorting_SC2.get_unit_ids())}")
print(f"Number of units after curation: {len(sorting_auto.get_unit_ids())}")

# 9. Viewers <a class="anchor" id="viewers"></a>

Let's check put the `spikeinterface-gui` to explore our spike sorting results:

### SpikeInterface GUI

In [None]:
!sigui waveforms

In [None]:
from ipywidgets import widgets
sw.plot_unit_locations(we_all, backend="ipywidgets")
sw.plot_spike_locations(we_all, backend="ipywidgets")
sw.plot_amplitudes(we_all, backend="ipywidgets")
sw.plot_autocorrelograms(we_all, unit_ids=sorting_SC2.unit_ids[:4])
sw.plot_crosscorrelograms(we_all, unit_ids=sorting_SC2.unit_ids[:4])
sw.plot_unit_templates(we_all, backend="matplotlib")

### Sorting Summary - SortingView

The `sortingview` backend requires an additional step to configure the transfer of the data to be plotted to the cloud. 

See documentation [here](https://spikeinterface.readthedocs.io/en/latest/module_widgets.html): 

# 11. Exporters <a class="anchor" id="exporters"></a>

## Export to Phy for manual curation

To perform manual curation we can export the data to [Phy](https://github.com/cortex-lab/phy). 

In [None]:
sexp.export_to_phy(we_all, output_folder=base_folder / 'phy_SC2_RN', compute_pc_features=True,
                   copy_binary=True, dtype='float32', compute_amplitudes=True, template_mode='median', verbose=True,**job_kwargs)

In [None]:
#sexp.export_to_phy(we_all, output_folder=base_folder / 'phy_SC2c', 
#                   **job_kwargs)

In [None]:
#%%capture --no-display
!phy template-gui phy_SC2_RN/params.py

After curating the results we can reload it using the `PhySortingExtractor` and exclude the units that we labeled as `noise`:

In [None]:
sorting_phy_curated = se.PhySortingExtractor(base_folder / 'phy_SC2_RN/', exclude_cluster_groups=['noise'])

In [None]:
print(f"Number of units before curation: {len(sorting_SC2.get_unit_ids())}")
print(f"Number of units after curation: {len(sorting_phy_curated.get_unit_ids())}")