# Preprocessing and Spike Sorting Tutorial

- In this introductory example, you will see how to use `spikeinterface` to perform a full electrophysiology analysis.
- We will first create some simulated data, and we will then perform some pre-processing, run a couple of spike sorting algorithms, inspect and validate the results, export to Phy, and compare spike sorters.


In [1]:
import os
import pickle
import glob
import warnings
import git
import imp
import spikeinterface
import time
import json
import spikeinterface.core
import numpy as np
import pandas as pd
import scipy.signal
import _pickle as cPickle
import matplotlib.pyplot as plt
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.preprocessing as sp
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.full as si
import mountainsort5 as ms5
from collections import defaultdict
from datetime import datetime
from matplotlib.pyplot import cm
from spikeinterface.exporters import export_to_phy
from probeinterface import get_probe
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface import write_prb, read_prb
from pathlib import Path
import gradio as gr

# Set the default size used by every matplotlib figure created from here on.
# Note: calling figure(...) here would NOT change the default; it would only
# open a stray, empty figure that nothing ever draws on.
from matplotlib.pyplot import figure
plt.rcParams["figure.figsize"] = (10, 6)

  import imp


<Figure size 640x480 with 0 Axes>

In [2]:
def spikesort(data_dir, prb_file_path=None, output_dir="./proc1"):
    """Spike-sort every ``*merged.rec`` recording found under *data_dir*.

    For each SpikeGadgets recording found recursively under *data_dir*:
    attach the probe geometry from a ``.prb`` file, lazily band-pass filter
    (300-6000 Hz) and whiten the traces, run MountainSort5 scheme 2, save the
    sorting, save a raster plot, extract waveforms, and export to Phy.

    Outputs for a recording named ``<name>`` are written under
    ``<output_dir>/<name>/``: ``ss_output/`` (saved sorting),
    ``<name>_raster_plot.png``, ``waveforms/`` and ``phy/``.

    Args:
        data_dir: Directory searched recursively for ``*merged.rec`` files.
        prb_file_path: Path to the probe ``.prb`` file. Defaults to
            ``nancyprobe_linearprobelargespace.prb`` in the current working
            directory.
        output_dir: Root directory for the per-recording output folders.

    Returns:
        A status message string (displayed as the Gradio output).
    """
    if prb_file_path is None:
        prb_file_path = Path(os.getcwd()) / "nancyprobe_linearprobelargespace.prb"

    recording_filepath_glob = str(Path(data_dir) / "**" / "*merged.rec")
    all_recording_files = sorted(glob.glob(recording_filepath_glob, recursive=True))
    if not all_recording_files:
        # Nothing to sort: report it instead of failing further down.
        return f"No '*merged.rec' files found under {data_dir}"

    probe_object = read_prb(prb_file_path)

    for recording_file in all_recording_files:
        recording_basename = os.path.basename(recording_file)
        # Folder holding every artefact produced from this recording.
        parent_spikesorting_output_directory = os.path.join(output_dir, recording_basename)
        os.makedirs(parent_spikesorting_output_directory, exist_ok=True)
        child_spikesorting_output_directory = os.path.join(
            parent_spikesorting_output_directory, "ss_output"
        )

        trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")
        trodes_recording = trodes_recording.set_probes(probe_object)

        # Lazy preprocessing: filtering/whitening run only when traces are read.
        recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
        recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')

        spike_sorted_object = ms5.sorting_scheme2(
            recording=recording_preprocessed,
            sorting_parameters=ms5.Scheme2SortingParameters(
                detect_sign=0,
                phase1_detect_channel_radius=700,
                detect_channel_radius=700,
                # other MountainSort5 parameters can be added here
            ),
        )
        spike_sorted_object.save(folder=child_spikesorting_output_directory)

        sw.plot_rasters(spike_sorted_object)
        plt.title(recording_basename)
        plt.ylabel("Unit IDs")
        plt.savefig(os.path.join(parent_spikesorting_output_directory,
                                 f"{recording_basename}_raster_plot.png"))
        plt.close()

        waveform_output_directory = os.path.join(parent_spikesorting_output_directory, "waveforms")
        we_spike_sorted = si.extract_waveforms(recording=recording_preprocessed,
                                               sorting=spike_sorted_object,
                                               folder=waveform_output_directory,
                                               ms_before=1, ms_after=1, progress_bar=True,
                                               n_jobs=8, total_memory="1G", overwrite=True,
                                               max_spikes_per_unit=2000)

        phy_output_directory = os.path.join(parent_spikesorting_output_directory, "phy")
        export_to_phy(we_spike_sorted, phy_output_directory,
                      compute_pc_features=True, compute_amplitudes=True,
                      remove_if_exists=False)

    return "SPIKES ARE SORTED! :)"

# input_text = gr.inputs.Textbox(label="Enter folder path")
# output_text = gr.outputs.Textbox(label="Status")
# interface = gr.Interface(fn=spikesort, inputs=input_text, outputs= output_text)
# interface.launch()

  if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
  if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [4]:
import gradio as gr

def spikesort(data_dir, prb_file_path=None, output_dir="./proc1"):
    """Spike-sort every ``*merged.rec`` recording found under *data_dir*.

    For each SpikeGadgets recording found recursively under *data_dir*:
    attach the probe geometry from a ``.prb`` file, lazily band-pass filter
    (300-6000 Hz) and whiten the traces, run MountainSort5 scheme 2, save the
    sorting, save a raster plot, extract waveforms, and export to Phy.

    Outputs for a recording named ``<name>`` are written under
    ``<output_dir>/<name>/``: ``ss_output/`` (saved sorting),
    ``<name>_raster_plot.png``, ``waveforms/`` and ``phy/``.

    Args:
        data_dir: Directory searched recursively for ``*merged.rec`` files.
        prb_file_path: Path to the probe ``.prb`` file. Defaults to
            ``nancyprobe_linearprobelargespace.prb`` in the current working
            directory.
        output_dir: Root directory for the per-recording output folders.

    Returns:
        A status message string (displayed as the Gradio output).
    """
    if prb_file_path is None:
        prb_file_path = Path(os.getcwd()) / "nancyprobe_linearprobelargespace.prb"

    recording_filepath_glob = str(Path(data_dir) / "**" / "*merged.rec")
    all_recording_files = sorted(glob.glob(recording_filepath_glob, recursive=True))
    if not all_recording_files:
        # Nothing to sort: report it instead of failing further down.
        return f"No '*merged.rec' files found under {data_dir}"

    probe_object = read_prb(prb_file_path)

    for recording_file in all_recording_files:
        recording_basename = os.path.basename(recording_file)
        # Folder holding every artefact produced from this recording.
        parent_spikesorting_output_directory = os.path.join(output_dir, recording_basename)
        os.makedirs(parent_spikesorting_output_directory, exist_ok=True)
        child_spikesorting_output_directory = os.path.join(
            parent_spikesorting_output_directory, "ss_output"
        )

        trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")
        trodes_recording = trodes_recording.set_probes(probe_object)

        # Lazy preprocessing: filtering/whitening run only when traces are read.
        recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
        recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')

        spike_sorted_object = ms5.sorting_scheme2(
            recording=recording_preprocessed,
            sorting_parameters=ms5.Scheme2SortingParameters(
                detect_sign=0,
                phase1_detect_channel_radius=700,
                detect_channel_radius=700,
                # other MountainSort5 parameters can be added here
            ),
        )
        spike_sorted_object.save(folder=child_spikesorting_output_directory)

        sw.plot_rasters(spike_sorted_object)
        plt.title(recording_basename)
        plt.ylabel("Unit IDs")
        plt.savefig(os.path.join(parent_spikesorting_output_directory,
                                 f"{recording_basename}_raster_plot.png"))
        plt.close()

        waveform_output_directory = os.path.join(parent_spikesorting_output_directory, "waveforms")
        we_spike_sorted = si.extract_waveforms(recording=recording_preprocessed,
                                               sorting=spike_sorted_object,
                                               folder=waveform_output_directory,
                                               ms_before=1, ms_after=1, progress_bar=True,
                                               n_jobs=8, total_memory="1G", overwrite=True,
                                               max_spikes_per_unit=2000)

        phy_output_directory = os.path.join(parent_spikesorting_output_directory, "phy")
        export_to_phy(we_spike_sorted, phy_output_directory,
                      compute_pc_features=True, compute_amplitudes=True,
                      remove_if_exists=False)

    return "SPIKES ARE SORTED! :)"

# Minimal web UI: a folder-path textbox in, a status textbox out.
# gr.Textbox replaces the deprecated gr.inputs/gr.outputs namespaces
# (removed in Gradio 4) and also works on Gradio 3.x.
input_text = gr.Textbox(label="Enter folder path")
output_text = gr.Textbox(label="Status")
interface = gr.Interface(fn=spikesort, inputs=input_text, outputs=output_text)
interface.launch()

  if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
  if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):


Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.




Number of channels: 32
Number of timepoints: 24011226
Sampling frequency: 20000.0 Hz
Channel 0: [0. 0.]
Channel 1: [ 5. 20.]
Channel 2: [-7. 40.]
Channel 3: [ 9. 60.]
Channel 4: [-11.  80.]
Channel 5: [ 13. 100.]
Channel 6: [-15. 120.]
Channel 7: [ 17. 140.]
Channel 8: [-19. 160.]
Channel 9: [ 21. 180.]
Channel 10: [-23. 200.]
Channel 11: [ 25. 220.]
Channel 12: [-27. 240.]
Channel 13: [ 29. 260.]
Channel 14: [-31. 280.]
Channel 15: [ 33. 300.]
Channel 16: [-35. 320.]
Channel 17: [ 37. 340.]
Channel 18: [-39. 360.]
Channel 19: [ 41. 380.]
Channel 20: [-43. 400.]
Channel 21: [ 45. 420.]
Channel 22: [-47. 440.]
Channel 23: [ 49. 460.]
Channel 24: [-51. 480.]
Channel 25: [ 53. 500.]
Channel 26: [-55. 520.]
Channel 27: [ 57. 540.]
Channel 28: [-59. 560.]
Channel 29: [ 61. 580.]
Channel 30: [-63. 600.]
Channel 31: [ 65. 620.]
Loading traces
Detecting spikes

Adjacency for detect spikes with channel radius 700
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22

----------------------

# Inputs

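* PRB file (probe geometry)
* Path to the folder containing the recording files (`*merged.rec`)
* `output_dir` (optional; otherwise use a default location and send the results to the Downloads folder)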

# Processing

In [None]:
def spikesort(data_dir, prb_file_path=None, output_dir="./proc1"):
    """Spike-sort every ``*merged.rec`` recording found under *data_dir*.

    For each SpikeGadgets recording found recursively under *data_dir*:
    attach the probe geometry from a ``.prb`` file, lazily band-pass filter
    (300-6000 Hz) and whiten the traces, run MountainSort5 scheme 2, save the
    sorting, save a raster plot, extract waveforms, and export to Phy.

    Outputs for a recording named ``<name>`` are written under
    ``<output_dir>/<name>/``: ``ss_output/`` (saved sorting),
    ``<name>_raster_plot.png``, ``waveforms/`` and ``phy/``.

    Args:
        data_dir: Directory searched recursively for ``*merged.rec`` files.
        prb_file_path: Path to the probe ``.prb`` file. Defaults to
            ``nancyprobe_linearprobelargespace.prb`` in the current working
            directory.
        output_dir: Root directory for the per-recording output folders.

    Returns:
        A status message string (displayed as the Gradio output).
    """
    if prb_file_path is None:
        prb_file_path = Path(os.getcwd()) / "nancyprobe_linearprobelargespace.prb"

    recording_filepath_glob = str(Path(data_dir) / "**" / "*merged.rec")
    all_recording_files = sorted(glob.glob(recording_filepath_glob, recursive=True))
    if not all_recording_files:
        # Nothing to sort: report it instead of failing further down.
        return f"No '*merged.rec' files found under {data_dir}"

    probe_object = read_prb(prb_file_path)

    for recording_file in all_recording_files:
        recording_basename = os.path.basename(recording_file)
        # Folder holding every artefact produced from this recording.
        parent_spikesorting_output_directory = os.path.join(output_dir, recording_basename)
        os.makedirs(parent_spikesorting_output_directory, exist_ok=True)
        child_spikesorting_output_directory = os.path.join(
            parent_spikesorting_output_directory, "ss_output"
        )

        trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")
        trodes_recording = trodes_recording.set_probes(probe_object)

        # Lazy preprocessing: filtering/whitening run only when traces are read.
        recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
        recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')

        spike_sorted_object = ms5.sorting_scheme2(
            recording=recording_preprocessed,
            sorting_parameters=ms5.Scheme2SortingParameters(
                detect_sign=0,
                phase1_detect_channel_radius=700,
                detect_channel_radius=700,
                # other MountainSort5 parameters can be added here
            ),
        )
        spike_sorted_object.save(folder=child_spikesorting_output_directory)

        sw.plot_rasters(spike_sorted_object)
        plt.title(recording_basename)
        plt.ylabel("Unit IDs")
        plt.savefig(os.path.join(parent_spikesorting_output_directory,
                                 f"{recording_basename}_raster_plot.png"))
        plt.close()

        waveform_output_directory = os.path.join(parent_spikesorting_output_directory, "waveforms")
        we_spike_sorted = si.extract_waveforms(recording=recording_preprocessed,
                                               sorting=spike_sorted_object,
                                               folder=waveform_output_directory,
                                               ms_before=1, ms_after=1, progress_bar=True,
                                               n_jobs=8, total_memory="1G", overwrite=True,
                                               max_spikes_per_unit=2000)

        phy_output_directory = os.path.join(parent_spikesorting_output_directory, "phy")
        export_to_phy(we_spike_sorted, phy_output_directory,
                      compute_pc_features=True, compute_amplitudes=True,
                      remove_if_exists=False)

    return "SPIKES ARE SORTED! :)"

In [None]:
# pwd = os.getcwd()
# prb_file_path = Path(pwd +"/nancyprobe_linearprobelargespace.prb")
# probe_object = read_prb(prb_file_path)
# probe_df = probe_object.to_dataframe()
# data_dir = r"C:\Users\Padilla-Coreano\Desktop\GITHUB_REPOS\diff_fam_social_memory_ephys\data"
# recording_filepath_glob = str(Path(data_dir + "/**/*merged.rec"))
# all_recording_files = glob.glob(recording_filepath_glob, recursive=True) # get all *merged.rec files in path 


# ####################
# for recording_file in all_recording_files:
#     trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")       
#     trodes_recording = trodes_recording.set_probes(probe_object)
#     recording_basename = os.path.basename(recording_file)
#     recording_output_directory = f"./proc1/{recording_basename}"
#     os.makedirs(recording_output_directory, exist_ok=True)
#     child_spikesorting_output_directory = os.path.join(recording_output_directory,"ss_output")
    
#     # Make sure the recording is preprocessed appropriately
#     # lazy preprocessing
#     recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
#     recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')
#     spike_sorted_object = ms5.sorting_scheme2(
#     recording=recording_preprocessed,
#     sorting_parameters=ms5.Scheme2SortingParameters(
#         detect_sign=0,
#         phase1_detect_channel_radius=700,
#         detect_channel_radius=700,
#         # other parameters...
#         )
#             )
#     spike_sorted_object.save(folder=child_spikesorting_output_directory)

#     sw.plot_rasters(spike_sorted_object)
#     plt.title(recording_basename)
#     plt.ylabel("Unit IDs")

#     plt.savefig(os.path.join(recording_output_directory, f"{recording_basename}_raster_plot.png"))
#     plt.close()

#     waveform_output_directory = os.path.join(parent_spikesorting_output_directory, "waveforms")

#     we_spike_sorted = si.extract_waveforms(recording=recording_preprocessed, 
#                                    sorting=spike_sorted_object, folder=waveform_output_directory,
#                                   ms_before=1, ms_after=1, progress_bar=True,
#                                   n_jobs=8, total_memory="1G", overwrite=True,
#                                    max_spikes_per_unit=2000)

#     phy_output_directory = os.path.join(parent_spikesorting_output_directory, "phy")

#     export_to_phy(we_spike_sorted, phy_output_directory,
#           compute_pc_features=True, compute_amplitudes=True, remove_if_exists=False)

    
# # Make sure the recording is preprocessed appropriately
# # lazy preprocessing
# recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
# recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')
# spike_sorted_object = ms5.sorting_scheme2(
# recording=recording_preprocessed,
# sorting_parameters=ms5.Scheme2SortingParameters(
#     detect_sign=0,
#     phase1_detect_channel_radius=700,
#     detect_channel_radius=700,
#     # other parameters...
#     )
#         )
# spike_sorted_object.save(folder=child_spikesorting_output_directory)

# sw.plot_rasters(spike_sorted_object)
# plt.title('plot_title')
# plt.ylabel("Unit IDs")

# plt.savefig(os.path.join(recording_output_directory, f"{recording_basename}_raster_plot.png"))
# plt.close()

# waveform_output_directory = os.path.join(parent_spikesorting_output_directory, "waveforms")

# we_spike_sorted = si.extract_waveforms(recording=recording_preprocessed, 
#                                sorting=spike_sorted_object, folder=waveform_output_directory,
#                               ms_before=1, ms_after=1, progress_bar=True,
#                               n_jobs=8, total_memory="1G", overwrite=True,
#                                max_spikes_per_unit=2000)

# phy_output_directory = os.path.join(parent_spikesorting_output_directory, "phy")

# export_to_phy(we_spike_sorted, phy_output_directory,
#       compute_pc_features=True, compute_amplitudes=True, remove_if_exists=False)

# Outputs

* Probe layout (`probe_df`) displayed as a table (to be decided)
* Phy exports (via `export_to_phy()`)

----------------------