# Preprocessing and Spike Sorting Tutorial

- In this introductory example, you will see how to use :code:`spikeinterface` to perform a full electrophysiology analysis.
- We will first create some simulated data, and we will then perform some pre-processing, run a couple of spike sorting algorithms, inspect and validate the results, export to Phy, and compare spike sorters.


In [1]:
import os
import pickle
import glob
import warnings
import git
import imp
import spikeinterface
import time
import json
import spikeinterface.core
import numpy as np
import pandas as pd
import scipy.signal
import _pickle as cPickle
import matplotlib.pyplot as plt
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.preprocessing as sp
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.full as si
import mountainsort5 as ms5
from collections import defaultdict
from datetime import datetime
from matplotlib.pyplot import cm
from spikeinterface.exporters import export_to_phy
from probeinterface import get_probe
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface import write_prb, read_prb
from pathlib import Path
import gradio as gr

# Changing the figure size
from matplotlib.pyplot import figure
# Opens a new 8x6 inch figure at 80 dpi right away; nothing is drawn on it in
# this cell, which is why the cell output shows an empty 640x480 figure.
figure(figsize=(8, 6), dpi=80)
# Sets the rcParams default figure size to 10x6 inches
# (presumably picked up by figures created after this point).
plt.rcParams["figure.figsize"] = (10,6)

  import imp


<Figure size 640x480 with 0 Axes>

In [2]:
def spikesort(pwd=r"C:\Users\Padilla-Coreano\Desktop\GITHUB_REPOS\diff_fam_social_memory_ephys"):
    """Spike sort every ``*merged.rec`` recording found under ``<pwd>/data``.

    The probe layout is read from
    ``<pwd>/data/nancyprobe_linearprobelargespace.prb``. Each recording is
    band-pass filtered (300-6000 Hz), whitened, and sorted with MountainSort5
    scheme 2. Per recording, the sorting, a raster plot, the extracted
    waveforms and a Phy export are written to
    ``<pwd>/proc1/<recording file name>/``.

    Parameters
    ----------
    pwd : str or path-like, optional
        Project root containing the ``data`` folder. Defaults to the path that
        used to be hard-coded in the function body, so a bare ``spikesort()``
        call behaves as before.

    Returns
    -------
    str
        The status message ``"SPIKES ARE SORTED! :)"``.
    """
    print(pwd)
    project_root = Path(pwd)

    prb_file_path = project_root / "data" / "nancyprobe_linearprobelargespace.prb"
    probe_object = read_prb(prb_file_path)
    probe_df = probe_object.to_dataframe()
    print(probe_df)

    recording_filepath_glob = str(project_root / "data" / "**" / "*merged.rec")
    all_recording_files = glob.glob(recording_filepath_glob, recursive=True)
    if not all_recording_files:
        # Make an empty search visible instead of silently reporting success.
        warnings.warn(f"No recordings matched {recording_filepath_glob}; nothing was sorted.")

    for recording_file in all_recording_files:
        trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")
        trodes_recording = trodes_recording.set_probes(probe_object)

        # One output folder per recording, named after the recording file.
        recording_basename = os.path.basename(recording_file)
        recording_output_directory = str(project_root / "proc1" / recording_basename)
        os.makedirs(recording_output_directory, exist_ok=True)
        child_spikesorting_output_directory = os.path.join(recording_output_directory, "ss_output")

        # Make sure the recording is preprocessed appropriately
        # lazy preprocessing
        recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
        recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')

        # Announce the sort before it runs (the message used to be printed
        # only after sorting had already finished).
        print("STARTING SORTING...")
        spike_sorted_object = ms5.sorting_scheme2(
            recording=recording_preprocessed,
            sorting_parameters=ms5.Scheme2SortingParameters(
                detect_sign=0,
                phase1_detect_channel_radius=700,
                detect_channel_radius=700,
                # other parameters...
            ),
        )
        spike_sorted_object.save(folder=child_spikesorting_output_directory)

        # Raster plot saved next to the sorting output, then closed so that
        # figures do not accumulate across recordings.
        sw.plot_rasters(spike_sorted_object)
        plt.title(recording_basename)
        plt.ylabel("Unit IDs")
        plt.savefig(os.path.join(recording_output_directory, f"{recording_basename}_raster_plot.png"))
        plt.close()

        waveform_output_directory = os.path.join(recording_output_directory, "waveforms")
        we_spike_sorted = si.extract_waveforms(
            recording=recording_preprocessed,
            sorting=spike_sorted_object,
            folder=waveform_output_directory,
            ms_before=1,
            ms_after=1,
            progress_bar=True,
            n_jobs=8,
            total_memory="1G",
            overwrite=True,
            max_spikes_per_unit=2000,
        )

        phy_output_directory = os.path.join(recording_output_directory, "phy")
        print("Saving PHY2 output...")
        # NOTE(review): remove_if_exists=False here while extract_waveforms uses
        # overwrite=True -- a re-run into an existing phy folder presumably
        # fails; confirm this asymmetry is intended.
        export_to_phy(
            we_spike_sorted,
            phy_output_directory,
            compute_pc_features=True,
            compute_amplitudes=True,
            remove_if_exists=False,
        )
        print("PHY2 output Saved!")

    return "SPIKES ARE SORTED! :)"

In [3]:
# Run the spike-sorting pipeline on the project's recordings.
spikesort()

C:\Users\Padilla-Coreano\Desktop\GITHUB_REPOS\diff_fam_social_memory_ephys
    probe_index     x      y contact_shapes  radius shank_ids contact_ids
0             0   0.0    0.0         circle     5.0                      
1             0   5.0   20.0         circle     5.0                      
2             0  -7.0   40.0         circle     5.0                      
3             0   9.0   60.0         circle     5.0                      
4             0 -11.0   80.0         circle     5.0                      
5             0  13.0  100.0         circle     5.0                      
6             0 -15.0  120.0         circle     5.0                      
7             0  17.0  140.0         circle     5.0                      
8             0 -19.0  160.0         circle     5.0                      
9             0  21.0  180.0         circle     5.0                      
10            0 -23.0  200.0         circle     5.0                      
11            0  25.0  220.0         

KeyboardInterrupt: 

In [None]:
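# # gradio app example (disabled draft)
# # NOTE: `parent_spikesorting_output_directory` is used below but never defined in this draft;
# # it presumably should be `recording_output_directory`.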

# def spikesort(data_dir):
#     pwd = os.getcwd()
#     prb_file_path = Path(pwd +"/nancyprobe_linearprobelargespace.prb")
#     probe_object = read_prb(prb_file_path)
#     probe_df = probe_object.to_dataframe()
    
#     recording_filepath_glob = str(Path(data_dir + "/**/*merged.rec"))
#     all_recording_files = glob.glob(recording_filepath_glob, recursive=True)
    
#     for recording_file in all_recording_files:
#         trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")       
#         trodes_recording = trodes_recording.set_probes(probe_object)
#         recording_basename = os.path.basename(recording_file)
#         recording_output_directory = f"./proc1/{recording_basename}"
#         os.makedirs(recording_output_directory, exist_ok=True)
#         child_spikesorting_output_directory = os.path.join(recording_output_directory,"ss_output")

#         # Make sure the recording is preprocessed appropriately
#         # lazy preprocessing
#         recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000)
#         recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype='float32')
#         spike_sorted_object = ms5.sorting_scheme2(
#         recording=recording_preprocessed,
#         sorting_parameters=ms5.Scheme2SortingParameters(
#             detect_sign=0,
#             phase1_detect_channel_radius=700,
#             detect_channel_radius=700,
#             # other parameters...
#             )
#                 )
#         spike_sorted_object.save(folder=child_spikesorting_output_directory)

#         sw.plot_rasters(spike_sorted_object)
#         plt.title(recording_basename)
#         plt.ylabel("Unit IDs")

#         plt.savefig(os.path.join(recording_output_directory, f"{recording_basename}_raster_plot.png"))
#         plt.close()

#         waveform_output_directory = os.path.join(parent_spikesorting_output_directory, "waveforms")

#         we_spike_sorted = si.extract_waveforms(recording=recording_preprocessed, 
#                                        sorting=spike_sorted_object, folder=waveform_output_directory,
#                                       ms_before=1, ms_after=1, progress_bar=True,
#                                       n_jobs=8, total_memory="1G", overwrite=True,
#                                        max_spikes_per_unit=2000)

#         phy_output_directory = os.path.join(parent_spikesorting_output_directory, "phy")

#         export_to_phy(we_spike_sorted, phy_output_directory,
#               compute_pc_features=True, compute_amplitudes=True, remove_if_exists=False)
        
#     return("SPIKES ARE SORTED! :)")

# input_text = gr.inputs.Textbox(label="Enter folder path")
# output_text = gr.outputs.Textbox(label="Status")
# interface = gr.Interface(fn=spikesort, inputs=input_text, outputs= output_text)
# interface.launch()

----------------------

# Inputs

* PRB file
* recording files path
* output_dir (otherwise use default and send to download folder)

# Processing

* Spikesorting
* LFP

# Outputs

* probe_df displayed as a table?
* phy exports (export_to_phy())

----------------------