# Exploring the AIND-ephys outputs using SpikeInterface

### Notebook usage:
- This notebook assumes some comfort with spike sorting and electrophysiology data. In addition, a basic understanding of [SpikeInterface](https://spikeinterface.readthedocs.io/en/latest/index.html) is helpful.

- This notebook compares units based on QC metrics and template similarity and provides a list of units to explore further. Candidate pairs are the units that lie closest to each other in Euclidean distance between their standardized QC metrics (UMAP is used to visualize this space). These candidates are then thresholded based on probe location and template similarity.


#### Requirements:
- processed AINDS Neuropixels data
- SpikeInterface: if it is not installed, install it with the following command. The notebook also uses `umap-learn`, `bokeh` and `scikit-learn`:
```bash
pip install "spikeinterface[full,widgets]"
pip install umap-learn bokeh scikit-learn
```

**Note**: This notebook is based on the latest version of SpikeInterface (`spikeinterface==0.101.0`), which is under development. The API may change in the future and *is* different from the version used in the AINDS pipeline (`spikeinterface==0.100.8`). We adapted the notebook to the latest version because its API and functionality are significantly improved.
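Because the API differs between SpikeInterface versions, it can help to check which version is installed before running the cells. The following is a minimal sketch using only the standard library's `importlib.metadata`; the helper names `installed_version` and `matches_minor` are hypothetical, and the `0.101` series is taken from the note above:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package="spikeinterface"):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def matches_minor(version_string, expected="0.101"):
    """True if `version_string` belongs to the `expected` major.minor series (e.g. 0.101.x)."""
    if version_string is None:
        return False
    major_minor = ".".join(version_string.split(".")[:2])
    return major_minor == expected


v = installed_version()
if not matches_minor(v):
    print(f"Warning: this notebook targets spikeinterface 0.101.x, found {v}")
```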

In [1]:
#import packages
import os
import matplotlib.pyplot as plt
import pandas as pd
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.postprocessing as spost
import spikeinterface.widgets as sw

In [7]:
#Fetch data directories

raw_rec = 'path/to/raw/recording'
baseFolder = r"C:\Users\janet\Documents\Ally_AINDS_output\2022_10_27_AEG26_g0_imec0" #edit this to the location of your data
experiment = 'block0_imec0.ap_recording1' #edit this to the name of your experiment folder

#os.path.join uses the correct path separator for your operating system
preProcessed = os.path.join(baseFolder, 'preprocessed')
postProcessed = os.path.join(baseFolder, 'postprocessed')
spikes = os.path.join(baseFolder, 'spikesorted')
curated = os.path.join(baseFolder, 'curated')
preJSON = os.path.join(preProcessed, experiment + '.json')
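The folders loaded later can be checked before anything is read from disk. This sketch assumes the `<baseFolder>/<stage>/<experiment>` layout used in the cell above (the `preprocessed` stage is left out because it holds `<experiment>.json` rather than a folder); `build_paths` and `missing_paths` are hypothetical helpers:

```python
from pathlib import Path

# Stage folders that the notebook loads as <stage>/<experiment>; adjust if your layout differs.
SUBFOLDERS = ("postprocessed", "spikesorted", "curated")


def build_paths(base_folder, experiment):
    """Map each pipeline stage to <base_folder>/<stage>/<experiment>."""
    base = Path(base_folder)
    return {name: base / name / experiment for name in SUBFOLDERS}


def missing_paths(paths):
    """Return the stages whose folder does not exist on disk."""
    return [name for name, path in paths.items() if not path.exists()]
```

For example, `missing_paths(build_paths(baseFolder, experiment))` should return an empty list when all three folders are present.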

In [None]:
#Select the data to explore

data_load = curated
print(f'Set path: {data_load}')

## First, let's load the waveform extractor. We'll explore the postprocessed units, which are stored in the `postprocessed` folder and have been processed to include the following:
* removal of duplicate units
* computed amplitudes
* spike/unit locations 
* PCA
* correlograms
* template similarity
* template metrics
* QC metrics

## The `curated` folder includes units that *have been* automatically curated by:
* ISI violation ratio
* presence ratio
* amplitude cutoff
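The automatic curation above amounts to keeping units whose metrics pass a threshold on each criterion. The sketch below shows that logic on a toy quality-metrics table using SpikeInterface's column names (`isi_violations_ratio`, `presence_ratio`, `amplitude_cutoff`). The threshold values are placeholders for illustration and are not the values used by the AINDS pipeline:

```python
import pandas as pd

# Placeholder thresholds for illustration only (NOT necessarily the pipeline's values).
THRESHOLDS = {
    "isi_violations_ratio": ("<", 0.5),
    "presence_ratio": (">", 0.8),
    "amplitude_cutoff": ("<", 0.1),
}


def passes_qc(metrics, thresholds=THRESHOLDS):
    """Boolean Series: True for units that satisfy every threshold.

    A NaN in any thresholded metric makes the unit fail, because comparisons with NaN are False.
    """
    keep = pd.Series(True, index=metrics.index)
    for column, (op, value) in thresholds.items():
        if op == "<":
            keep &= metrics[column] < value
        else:
            keep &= metrics[column] > value
    return keep


# Toy table: one row per unit, indexed by unit ID.
qm_example = pd.DataFrame(
    {
        "isi_violations_ratio": [0.1, 0.9, 0.2],
        "presence_ratio": [0.95, 0.99, 0.5],
        "amplitude_cutoff": [0.05, 0.01, 0.02],
    },
    index=[0, 1, 2],
)
print(passes_qc(qm_example).tolist())  # [True, False, False]
```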

### Load the waveforms and the curated sorting extractor
*Note: we use the backward-compatible version of the waveform extractor, the `MockWaveformExtractor`, which `si.load_waveforms` returns in the latest version of SpikeInterface.*

In [None]:
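#load the postprocessed waveforms (returned as a MockWaveformExtractor)
we = si.load_waveforms(folder=os.path.join(postProcessed, experiment))
#load the curated sorting, which carries the decoder labels
sorting_curated = si.load_extractor(os.path.join(data_load, experiment))
we, sorting_curated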

### The available extensions within the waveform extractor will be printed below:

In [None]:
avail_extensions = we.get_available_extension_names()
print(f"Extensions available in MockWaveformExtractor {avail_extensions}")

### Access the sorting analyzer and fetch quality metrics and unit information

In [11]:
sorting_analyzer = we.sorting_analyzer

#quality metrics
qm=sorting_analyzer.get_extension(extension_name='quality_metrics').get_data()
#fetch decoder labels (e.g. SUA, MUA, noise)
labels = sorting_curated.get_property('decoder_label')
#fetch unit ids and locations
unit_ids = sorting_curated.get_unit_ids()
unit_locations = sorting_analyzer.get_extension("unit_locations").get_data()
#keep only the y coordinate (depth along the probe)
unit_locations = unit_locations[:,1]
#isi_histograms = sorting_analyzer.get_extension("isi_histograms").get_data()
#fetch template similarity for each unit
template_similarity = sorting_analyzer.get_extension(extension_name='template_similarity').get_data()


In [12]:
#change to dataframe for easier manipulation
template_sim = pd.DataFrame(template_similarity, columns=unit_ids, index=unit_ids)

In [13]:
#create dataframe of all the quality metrics
df = pd.DataFrame(qm)
df['unit_ids'] = unit_ids
df['labels'] = labels
df['unit_locations'] = unit_locations

### Drop features that are not needed for the analysis (edit this list based on your preferences/needs)

In [14]:
features_to_drop = ['sync_spike_2', 'sync_spike_4', 'sync_spike_8', 'amplitude_cv_range', 'drift_mad']
df = df.drop(features_to_drop, axis=1)

### Number of units in the selected recording, by label

In [None]:
#sum up sua, mua, noise
print('Sum of all labels/units:', len(df))
print('Total SUAs:', len(df[df['labels']=='sua']))
print('Total MUAs:', len(df[df['labels']=='mua']))
print('Total noise:', len(df[df['labels']=='noise']))
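The same counts can be obtained in one call with pandas `value_counts`. In this sketch, `label_counts` is a hypothetical helper that also reports zero for a label that never occurs (any label outside the three categories is ignored):

```python
import pandas as pd


def label_counts(labels, categories=("sua", "mua", "noise")):
    """Count units per decoder label, reporting 0 for categories that never occur."""
    counts = pd.Series(labels).value_counts()
    return counts.reindex(list(categories), fill_value=0)


example = label_counts(["sua", "mua", "sua", "noise", "sua"])
print(example.to_dict())  # {'sua': 3, 'mua': 1, 'noise': 1}
```

In the notebook this would be called as `label_counts(df['labels'])`.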

## Now, we'll explore the units based on QC metrics and reduce the dimensionality of those metrics using UMAP. We'll use this to compare the units based on similarity and location.

In [None]:
#umap of all quality metrics
import umap
from sklearn.preprocessing import StandardScaler
import numpy as np

reducer = umap.UMAP()
features = df.drop(columns=['unit_ids', 'labels'])
#fill missing values with the median of each column (0 if a column is entirely empty)
features = features.fillna(features.median()).fillna(0)
#standardize each metric so that metrics with large ranges do not dominate the embedding
scaled_df = StandardScaler().fit_transform(features)
embedding = reducer.fit_transform(scaled_df)
print(f"UMAP shape: {embedding.shape}")
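The median imputation and standardization steps can also be bundled into a scikit-learn pipeline so the same preprocessing is applied consistently. Below is a minimal sketch on a toy array; note that `SimpleImputer` drops columns that are entirely missing by default. Because `umap.UMAP` follows the scikit-learn estimator interface, it can be appended as a final step (`make_pipeline(..., umap.UMAP())`); it is left out here so the example runs without `umap-learn`:

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Median-impute missing metrics, then z-score every column.
preprocess = make_pipeline(SimpleImputer(strategy="median"), StandardScaler())

# Toy data: 3 units x 2 metrics, with one missing value.
X = np.array([
    [1.0, 100.0],
    [2.0, np.nan],
    [3.0, 300.0],
])
X_scaled = preprocess.fit_transform(X)
```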

### We will now create an interactive UMAP plot using Bokeh to explore the units based on similarity and location. Only non-noise units with a presence ratio above 0.8 are plotted. You can pan, box-zoom into a cluster, and reset the view to explore the units in more detail.

In [None]:
#make interactive plot with bokeh
import bokeh
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10

#keep non-noise units with presence ratio > 0.8 and apply the same mask to the UMAP embedding,
#so that each embedding row still matches its row in df_filtered
mask = (df['labels'] != 'noise') & (df['presence_ratio']>0.8)
df_filtered = df[mask]
embedding_filtered = embedding[mask.to_numpy()]
df_filtered_copy = df_filtered.copy() 
unique_labels = list(df_filtered['labels'].unique())
color_mapper = CategoricalColorMapper(factors=unique_labels, palette=Spectral10)

output_notebook()

# Create a ColumnDataSource from df_filtered. These will be displayed in the plot when we hover over the data points.
source = ColumnDataSource(data=dict(
    x=embedding_filtered[:, 0],
    y=embedding_filtered[:, 1],
    unit_ids=df_filtered['unit_ids'],
    unit_locations=df_filtered['unit_locations'],
    amplitude_median=df_filtered['amplitude_median'],
    firing_range=df_filtered['firing_range'],
    snr=df_filtered['snr'],
    d_prime=df_filtered['d_prime'],
    labels=df_filtered['labels']
))

# Create a HoverTool
hover = HoverTool(tooltips=[
    ("unit_ids", "@unit_ids"),
    ("unit_locations", "@unit_locations"),
    ("amplitude_median", "@amplitude_median"),
    ("firing_range", "@firing_range"),
    ("snr", "@snr"),
    ("d_prime", "@d_prime"),
    ("labels", "@labels")
])

# Create a figure
p = figure(width=600, height=600, tools=[hover, 'pan', 'reset', 'box_zoom'], title='UMAP of select quality metrics')

# Add circle glyphs to the figure p
p.scatter('x', 'y', source=source, color=dict(field='labels', transform=color_mapper), legend_field='labels', size=12)

# Set the legend.location attribute of the plot to 'top_right'
p.legend.location = 'top_right'

# Show the plot
show(p)


### After exploring the units, we will threshold the units based on probe location and template similarity. We will then provide a list of units to further explore for potential merging.

In [18]:
#compute pairwise distances between units in standardized QC-metric space
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import StandardScaler

#standardize the data
scaler = StandardScaler()
scaled_df = scaler.fit_transform(df_filtered_copy.drop(columns=['unit_ids', 'labels']).fillna(0))

#calculate euclidean distances
distances = euclidean_distances(scaled_df, scaled_df)
distances_df = pd.DataFrame(distances, columns=df_filtered_copy['unit_ids'], index=df_filtered_copy['unit_ids'])
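For reference, the distance between units $i$ and $j$ is $d_{ij} = \sqrt{\sum_k (z_{ik} - z_{jk})^2}$ over the standardized metrics $z$. The sketch below computes it with NumPy broadcasting on random example data and checks it against `euclidean_distances`:

```python
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

rng = np.random.default_rng(0)
Z = rng.normal(size=(5, 3))  # random example data: 5 units x 3 standardized metrics

# d[i, j] = sqrt(sum_k (Z[i, k] - Z[j, k])**2), computed by broadcasting to shape (5, 5, 3)
diff = Z[:, None, :] - Z[None, :, :]
D_manual = np.sqrt((diff ** 2).sum(axis=-1))
D_sklearn = euclidean_distances(Z, Z)
```

Because every metric is standardized, two units that differ by exactly one standard deviation in each of $k$ metrics are $\sqrt{k}$ apart. The cutoff of 2 used in the next cell therefore sits at, for example, a one-standard-deviation difference in 4 metrics.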

In [None]:
#find pairs of units that are close together
close_units = []
for i in range(len(distances_df)):
    for j in range(i+1, len(distances_df)):
        if distances_df.iloc[i, j] < 2:
            close_units.append((distances_df.index[i], distances_df.columns[j]))

#close units to dataframe, map labels
df_close_units = pd.DataFrame(close_units, columns=['unit1', 'unit2'])
df_close_units['label1'] = df_close_units['unit1'].map(df_filtered.set_index('unit_ids')['labels'])
df_close_units['label2'] = df_close_units['unit2'].map(df_filtered.set_index('unit_ids')['labels'])
#keep only pairs whose unit IDs differ by less than 5
df_close_units = df_close_units[(df_close_units['unit2'] - df_close_units['unit1']).abs() < 5]
df_close_units
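The nested loop above visits every pair once; the same search can be vectorized with `np.triu_indices`. This sketch also shows one way to threshold by probe location directly, using each unit's depth (the y coordinate kept in `unit_locations`, in µm) instead of the unit-ID difference. The function name `close_pairs` and the 50 µm depth threshold are hypothetical choices for illustration:

```python
import numpy as np
import pandas as pd


def close_pairs(distances, unit_ids, depths_um, max_distance=2.0, max_depth_um=50.0):
    """Return unit pairs (i < j) closer than `max_distance` in metric space
    and within `max_depth_um` of each other along the probe.

    max_depth_um is a hypothetical threshold chosen for illustration.
    """
    distances = np.asarray(distances)
    depths_um = np.asarray(depths_um)
    unit_ids = np.asarray(unit_ids)
    rows, cols = np.triu_indices(len(unit_ids), k=1)  # each unordered pair once
    keep = (distances[rows, cols] < max_distance) & (
        np.abs(depths_um[rows] - depths_um[cols]) < max_depth_um
    )
    return pd.DataFrame({"unit1": unit_ids[rows[keep]], "unit2": unit_ids[cols[keep]]})


D = np.array([[0.0, 1.0, 1.5],
              [1.0, 0.0, 3.0],
              [1.5, 3.0, 0.0]])
pairs = close_pairs(D, unit_ids=[10, 11, 12], depths_um=[100.0, 120.0, 400.0])
print(pairs)  # one pair: units 10 and 11
```

On the notebook's data this would be called as `close_pairs(distances_df.to_numpy(), distances_df.index, df_filtered_copy['unit_locations'])`, since `distances_df` rows follow the order of `df_filtered_copy`.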

In [20]:
sim_df = pd.DataFrame()
sim_df['unit_ids']=df_close_units[['unit1', 'unit2']].values.flatten()
#remove duplicates
sim_df = sim_df.drop_duplicates()

if not set(sim_df['unit_ids']).issubset(set(distances_df.columns) & set(template_sim.columns)):
    raise ValueError("Unit IDs in sim_df must be present in both distances_df and template_sim.")

# Create a new DataFrame with the unit IDs from sim_df
result_df = pd.DataFrame(index=sim_df['unit_ids'])

# Iterate over the unit IDs in sim_df
for unit_id in sim_df['unit_ids']:
    # Extract the corresponding rows from distances_df and template_sim
    distance_row = distances_df[unit_id]
    similarity_row = template_sim[unit_id]

    # Calculate the ratio of similarity to distance for each element
    ratio = similarity_row / distance_row

    # Add the ratio to the result DataFrame
    result_df[unit_id] = ratio

In [21]:
#unit IDs that appear in at least one candidate pair
close_ids = sim_df['unit_ids'].tolist()

# Restrict the ratio, distance and template-similarity matrices to those units
filtered_df = result_df.loc[close_ids, close_ids]
filtered_distances_df = distances_df.loc[close_ids, close_ids]
filtered_template_sim = template_sim.loc[close_ids, close_ids]

# Look up the distance and template similarity of each candidate pair (unit1, unit2).
# Indexing the pairs directly avoids masking the upper triangle of the matrices,
# which can drop pairs whose units do not appear in ascending order in close_ids.
pair_index = pd.MultiIndex.from_arrays([df_close_units['unit1'].to_numpy(), df_close_units['unit2'].to_numpy()])
stacked_distances_df = pd.DataFrame(
    {'distance': [filtered_distances_df.loc[u1, u2] for u1, u2 in pair_index]},
    index=pair_index,
)
stacked_template_sim_df = pd.DataFrame(
    {'similarity': [filtered_template_sim.loc[u1, u2] for u1, u2 in pair_index]},
    index=pair_index,
)

In [22]:
#renaming for clarity
stacked_template_sim_df.rename(columns={0: 'similarity'}, inplace=True)
stacked_template_sim_df['unit1'] = stacked_template_sim_df.index.get_level_values(0)
stacked_template_sim_df['unit2'] = stacked_template_sim_df.index.get_level_values(1)

## Candidates for merging based on distance. We will then filter based on template similarity and assess their cross-correlograms.

In [None]:
stacked_template_sim_df

## Let's look at the cross-correlograms for the candidates. Here the template similarity threshold is set to > 0.6; adjust it to your preference.

In [None]:
#plot crosscorrelograms for pairs of units
from spikeinterface.widgets import plot_crosscorrelograms

# Filter unit pairs based on similarity threshold
sim_threshold = 0.6
filtered_units = stacked_template_sim_df[stacked_template_sim_df['similarity'] > sim_threshold]
#note: this overwrites the earlier unit_ids with an (n_pairs, 2) array of candidate pairs
unit_ids = filtered_units[['unit1', 'unit2']].values
print(f"Number of unit pairs with similarity > {sim_threshold}: {len(unit_ids)}")

for i, (unit1, unit2) in enumerate(unit_ids):
    plot_crosscorrelograms(sorting_curated, unit_ids=[unit1, unit2])
    plt.show()

### Because spike counts can vary widely between units, we overlay the auto-correlograms of each candidate pair to compare their shapes along with their spike counts.

In [None]:
from spikeinterface.postprocessing import compute_correlograms

#recompute correlograms from the spike-sorted extractor (as plot_crosscorrelograms does) instead of using the stored ones
spike_extractor = si.load_extractor(os.path.join(spikes, experiment))
correlograms, bins = compute_correlograms(spike_extractor, window_ms=50.0, bin_ms=1.0) #same params as workflow

rp_t = 1.0 #in ms (refractory period threshold)
bar_width = bins[1] - bins[0]
bar_centers = bins[:-1] + bar_width / 2

for unit1, unit2 in unit_ids:
    isi_violations = df[(df['unit_ids'] == unit1) | (df['unit_ids'] == unit2)]['isi_violations_count']
    #the correlogram array is indexed by unit position, not by unit ID
    idx1 = spike_extractor.id_to_index(unit1)
    idx2 = spike_extractor.id_to_index(unit2)

    plt.figure(figsize=(8, 6))
    plt.bar(bar_centers, correlograms[idx1, idx1, :], width=bar_width, color="CornflowerBlue", alpha=0.7, label=f"unit {unit1}")
    plt.bar(bar_centers, correlograms[idx2, idx2, :], width=bar_width, color="Thistle", alpha=0.7, label=f"unit {unit2}")

    # Add vertical lines at the refractory period threshold
    plt.axvline(x=-rp_t, color="Crimson", linestyle="dashed", linewidth=1, label="RP Threshold (-)")
    plt.axvline(x=rp_t, color="Crimson", linestyle="dashed", linewidth=1, label="RP Threshold (+)")

    # Customize plot
    plt.text(0.05, 0.95, f'ISI violations: {isi_violations.values}', fontsize=12, transform=plt.gca().transAxes, verticalalignment='top')
    plt.xlabel("Lag (ms)")
    plt.ylabel("Counts (spike matches/bin)")
    plt.title(f"Auto-correlograms of units {unit1} and {unit2}")
    plt.legend()

    plt.tight_layout()
    plt.show()
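To make the quantities in these plots concrete, the sketch below computes an autocorrelogram directly from spike times in milliseconds: the histogram of lags between every pair of distinct spikes of the same unit, within ±window/2. It also counts inter-spike intervals shorter than a refractory threshold. This is a brute-force illustration, not SpikeInterface's implementation, which works on sample indices and may differ in bin-edge conventions; `autocorrelogram` and `refractory_violations` are hypothetical names:

```python
import numpy as np


def autocorrelogram(spike_times_ms, window_ms=50.0, bin_ms=1.0):
    """Histogram of all pairwise lags t_j - t_i (i != j) within +/- window_ms / 2.

    Returns (counts, bin_edges_ms). Brute force O(n^2), fine for small examples.
    """
    t = np.asarray(spike_times_ms, dtype=float)
    lags = t[None, :] - t[:, None]            # lags[i, j] = t[j] - t[i]
    lags = lags[~np.eye(len(t), dtype=bool)]  # drop each spike's zero lag with itself
    n_half = int(round((window_ms / 2) / bin_ms))
    edges = np.arange(-n_half, n_half + 1) * bin_ms
    counts, _ = np.histogram(lags, bins=edges)  # lags outside the window are ignored
    return counts, edges


def refractory_violations(spike_times_ms, rp_ms=1.0):
    """Number of consecutive inter-spike intervals shorter than rp_ms."""
    isi = np.diff(np.sort(np.asarray(spike_times_ms, dtype=float)))
    return int((isi < rp_ms).sum())


counts, edges = autocorrelogram([0.0, 0.5, 10.0, 30.0])
```

With `rp_ms` set to the notebook's `rp_t`, `refractory_violations` counts the spikes that would fall between the two dashed lines in the plots above.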

## We can now compare the candidate pairs found from QC-metric similarity with those returned by the `get_potential_auto_merge` function in SpikeInterface, which proposes potential merge candidates based on the chosen preset and parameters. The two lists are unlikely to agree, since we have thresholded our candidates based on probe location and template similarity.

In [None]:
## test get_potential_auto_merge from spikeinterface
from spikeinterface.curation import get_potential_auto_merge
merge_unit_pairs = get_potential_auto_merge(
    sorting_analyzer,
    preset="similarity_correlograms", #others include: x_contaminations, temporal_splits, feature_neighbors
    resolve_graph=True,
    corr_diff_thresh=0.5,
)
print(f'Potential total unit pairs to merge from SI: {len(merge_unit_pairs)}')
print(f'Potential unit ids to merge from SI: {merge_unit_pairs}')
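One way to quantify how much the two approaches agree is to compare their candidates as unordered pairs. With `resolve_graph=True`, SpikeInterface may return merge groups of more than two units, so this sketch expands every group into all of its pairs first. It assumes both lists use the same unit-ID type; `as_pair_set` and `compare_candidates` are hypothetical helpers:

```python
from itertools import combinations


def as_pair_set(groups):
    """Expand each group of unit IDs (a pair or a larger merge group)
    into unordered pairs, returned as a set of frozensets."""
    pairs = set()
    for group in groups:
        for a, b in combinations(group, 2):
            pairs.add(frozenset((a, b)))
    return pairs


def compare_candidates(ours, theirs):
    """Split two candidate lists into shared and list-specific unordered pairs."""
    a, b = as_pair_set(ours), as_pair_set(theirs)
    return {"both": a & b, "only_ours": a - b, "only_theirs": b - a}


result = compare_candidates([(3, 5), (7, 8)], [(5, 3), (8, 9, 10)])
```

In the notebook this would be called as `compare_candidates(unit_ids, merge_unit_pairs)`.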

## Units can be merged easily in the web app. The available figurl links are stored in the `visualization_output.json` file, which contains a timeseries summary and a sorting summary.

In [None]:
#pull the web address from the JSON file
import json

json_path = os.path.join(baseFolder, 'visualization_output.json')
with open(json_path, 'r') as f:
    data = json.load(f)
current_data_url = data[experiment]['sorting_summary']

print(f'Open the sorting summary dashboard at: {current_data_url}')
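If the experiment name or key is wrong, the lookup above fails with a bare `KeyError`. The following small sketch reports the available entries instead. `sorting_summary_url` is a hypothetical helper, and because the key for the timeseries summary is not shown above, inspect `sorted(data[experiment])` to find it:

```python
import json


def sorting_summary_url(json_path, experiment, key="sorting_summary"):
    """Return data[experiment][key] from visualization_output.json,
    raising a KeyError that lists the available entries if either is missing."""
    with open(json_path, "r") as f:
        data = json.load(f)
    if experiment not in data:
        raise KeyError(f"{experiment!r} not found; available: {sorted(data)}")
    entry = data[experiment]
    if key not in entry:
        raise KeyError(f"{key!r} not found for {experiment!r}; available: {sorted(entry)}")
    return entry[key]
```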

## `plot_unit_summary` provides a unit-by-unit summary throughout the recording. The following will plot:
- the unit's location
- the unit's waveform across channels
- the unit's waveform
- the unit's autocorrelogram
- the unit's amplitude across the recording


In [None]:
#plot unit summary for pairs of units
from spikeinterface.widgets import plot_unit_summary

for i, (unit1, unit2) in enumerate(unit_ids):
    plot_unit_summary(sorting_analyzer, unit_id=unit1)
    plot_unit_summary(sorting_analyzer, unit_id=unit2)
    plt.show()
