## Interactive demo for spike sorting

Demo uses the extracted waveforms from waveforms.csv within the results directory for a given recording. 


In [2]:
import grnsuite.spike_detection as spikes
import grnsuite.spike_sorting as sorting
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import matplotlib
matplotlib.use("TkAgg")  # Set the backend before creating any figures


# Load waveforms from a results directory
results_dir = "results/20231103-M04-sucr-100-Gal-A1-02"
waveforms = pd.read_csv(os.path.join(results_dir, 'waveforms.csv')).values

ModuleNotFoundError: No module named '_tkinter'

1. Reduce dimensions with SVD and cluster

In [2]:
reduced_waveforms = sorting.reduce_dimensions(waveforms, n_components=3)

# First, estimate the optimal number of clusters
n_clusters = sorting.estimate_clusters(reduced_waveforms)

# Then cluster with the optimal number
labels = sorting.cluster_spikes(reduced_waveforms, n_clusters=n_clusters)

Explained variance with 3 components: 98.37%
Clusters: 2, Silhouette Score: 0.245
Clusters: 3, Silhouette Score: 0.594
Clusters: 4, Silhouette Score: 0.263
Clusters: 5, Silhouette Score: 0.388
Clusters: 6, Silhouette Score: 0.266
Clusters: 7, Silhouette Score: 0.255
Clusters: 8, Silhouette Score: 0.257
Clusters: 9, Silhouette Score: 0.256
Clusters: 10, Silhouette Score: 0.266

Optimal number of clusters: 3
Cluster 0 size: 179
Cluster 1 size: 16
Cluster 2 size: 13


2. Plot and save clustering waveforms figure

In [3]:
sorting.plot_clustering_summary(
    waveforms, 
    reduced_waveforms, 
    labels,
    save_path=os.path.join(results_dir, 'clustering_summary.png'),
    display=True
)

3. Add unit labels to detected spikes and save

In [4]:
spikes_df = pd.read_csv(os.path.join(results_dir, 'detected_spikes.csv'))
unit_labels = [f'unit_{label + 1}' for label in labels]  # Convert cluster numbers to unit labels
spikes_df['unit'] = unit_labels
spikes_df.to_csv(os.path.join(results_dir, 'sorted_spikes.csv'), index=False)

6. Plot processed data with spikes coloured by unit

In [5]:
processed_data = pd.read_csv(os.path.join(results_dir, 'processed_data.csv'))

sorting.plot_sorted_spikes(
    processed_data,
    spikes_df,
    save_path=os.path.join(results_dir, 'sorted_spikes.png'),
    display=True
)