In [None]:
%matplotlib inline

# Preprocessing and Spike Sorting Tutorial

- In this introductory example, you will see how to use :code:`spikeinterface` to perform a full electrophysiology analysis.
- We will first create some simulated data, and we will then perform some pre-processing, run a couple of spike sorting algorithms, inspect and validate the results, export to Phy, and compare spike sorters.


In [None]:
import os
import pickle
import _pickle as cPickle
import glob
import warnings
import imp
import shutil

In [None]:
from collections import defaultdict
import time
import json
from datetime import datetime

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
import pandas as pd
import scipy.signal

In [None]:
# Open a new 8x6 inch figure at 80 dpi.
# NOTE(review): this creates one (empty) figure; it does not change the default
# size of figures created later. The default is set separately through
# plt.rcParams["figure.figsize"] further down in this notebook.
from matplotlib.pyplot import figure
figure(figsize=(8, 6), dpi=80)

The spikeinterface module by itself imports only the spikeinterface.core submodule,
which is not useful for the end user.



In [None]:
import spikeinterface

We need to import the different submodules one by one (preferred).
There are 5 modules:

- :code:`extractors` : file IO
- :code:`toolkit` : processing toolkit for pre-, post-processing, validation, and automatic curation
- :code:`sorters` : Python wrappers of spike sorters
- :code:`comparison` : comparison of spike sorting output
- :code:`widgets` : visualization



In [None]:
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.preprocessing as sp

import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
from spikeinterface.exporters import export_to_phy

In [None]:
import spikeinterface.core

In [None]:
from probeinterface import get_probe
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface import write_prb, read_prb

In [None]:
import mountainsort5 as ms5

In [None]:
from tempfile import TemporaryDirectory
from mountainsort5.util import create_cached_recording

We can also import all submodules at once with the import below,
which internally imports core+extractors+toolkit+sorters+comparison+widgets+exporters.

This is useful for notebooks, but it is a heavier import because many more dependencies
are imported internally (scipy/sklearn/networkx/matplotlib/h5py...).



In [None]:
import spikeinterface.full as si

In [None]:
# Make every figure created from here on 10x6 inches by default.
plt.rcParams.update({"figure.figsize": (10, 6)})

# Part 0: Loading in the Probe

- Reading in the probe information into Spike interface and plotting the probe

In [None]:
# Load the probe description from the PRB file in the current working
# directory. The result is inspected in the following cells and attached to
# every recording (via set_probes) before sorting.
probe_object = read_prb("./linear_probe_with_large_spaces.prb")

In [None]:
# Tabular view of the probe description (rendered as the cell output).
probe_object.to_dataframe()

In [None]:
# Contact ids of the loaded probe description, as reported by probeinterface.
probe_object.get_global_contact_ids()

In [None]:
# Device channel indices of the loaded probe description, as reported by
# probeinterface.
probe_object.get_global_device_channel_indices()

- Creating a dictionary of all the variables in the probe file

In [None]:
# Collect every non-dunder attribute of `probe_parameters` into `probe_dict`.
# NOTE(review): `probe_parameters` is not defined anywhere in the visible part
# of this notebook, so this cell does nothing unless an earlier session or
# cell created it - confirm whether it is still needed.
if "probe_parameters" in locals():
    probe_dict = defaultdict(dict)
    # Skip built-in attributes such as __doc__ or __name__.
    public_names = (name for name in dir(probe_parameters) if not name.startswith("__"))
    for attribute in public_names:
        probe_dict[attribute] = getattr(probe_parameters, attribute)

In [None]:
# Print one "name: value" line per collected probe attribute; nothing is
# printed when `probe_dict` was never created.
if "probe_dict" in locals():
    for key, value in probe_dict.items():
        print(f"{key}: {value}")

# Part 1: Importing Data

## Loading in the Electrophysiology Recording

- We are inputting the electrophysiology recording data with probe information. This should have been created in the previous notebook in a directory created by Spike Interface. If you have already read in your own electrophysiology recording data with probe information in a different way, then follow these instructions.
    - If you want to use a different directory, then you must either:
        - Change `glob.glob({./path/to/with/*/recording_raw})` to point at the directories that were created by Spikeinterface. You can use a wildcard if you have multiple folders. You would replace `{./path/to/with/*/recording_raw}` with the path to either the parent directory or the actual directory containing the electrophysiology recording data read into Spikeinterface.
        - Or change `(file_or_folder_or_dict={./path/to/recording_raw})`. You would replace `{./path/to/recording_raw}` with the path to either the parent directory or the actual directory containing the electrophysiology recording data read into Spikeinterface.

In [None]:
# Glob pattern for the recordings to sort: any sub-directory, then a "*.rec"
# directory, then a file ending in "merged.rec".
# NOTE(review): this is an absolute, machine-specific path - point it at your
# own data before running.
recording_filepath_glob = "/scratch/back_up/reward_competition_extention/data/rce_cohort_3/*/*.rec/*merged.rec"

In [None]:
# Expand the pattern into a list of recording paths.
# NOTE: recursive=True only changes how "**" components are matched; the
# pattern defined above contains none, so the flag has no effect here.
all_recording_files = glob.glob(recording_filepath_glob, recursive=True)

In [None]:
# Keep only the paths that do not contain "copies" anywhere in them
# (presumably duplicated recordings - confirm).
all_recording_files = list(filter(lambda path: "copies" not in path, all_recording_files))

In [None]:
# Show the recordings that will be sorted.
all_recording_files

# Part 2: Sorting

In [None]:
# Preprocess, spike sort and export every recording in `all_recording_files`.
#
# For each recording this cell:
#   1. skips it when its basename is already listed in successful_files.txt,
#   2. reads it with the SpikeGadgets reader and attaches the probe,
#   3. band-pass filters (300-6000 Hz) and whitens it (both steps are lazy),
#   4. caches the preprocessed recording in a temporary directory,
#   5. runs MountainSort5 (scheme 2) and writes the sorting, a raster plot,
#      the extracted waveforms and a Phy export to the output directory,
#   6. appends the basename to successful_files.txt.
#
# An exception raised while processing one recording is reported and the
# recording is added to `failed_files`, so the loop moves on to the next one.
successful_files = []
failed_files = []

# Log of recordings finished in earlier runs, one basename per line. The path
# is relative to the working directory of the notebook.
success_log_path = 'successful_files.txt'

for recording_file in all_recording_files:
    print(recording_file)

    recording_basename = os.path.basename(recording_file)
    recording_output_directory = "/scratch/back_up/reward_competition_extention/proc/spike_sorting/{}".format(recording_basename)
    os.makedirs(recording_output_directory, exist_ok=True)

    print("Output directory: {}".format(recording_output_directory))

    # Re-read the log on every iteration so recordings finished earlier in this
    # same run are taken into account as well. Entries are compared as whole
    # lines, so a basename that is merely contained in another one does not
    # count as already sorted.
    try:
        with open(success_log_path, "r") as success_log:
            already_sorted = {line.strip() for line in success_log if line.strip()}
    except FileNotFoundError:
        # First run: nothing has been sorted yet.
        already_sorted = set()
    except OSError as error:
        # Stay best-effort (sort the recording anyway) but say why the log
        # could not be used instead of hiding the problem.
        warnings.warn("Could not read {}: {}".format(success_log_path, error))
        already_sorted = set()

    if recording_basename in already_sorted:
        warnings.warn("""{} is already listed in {}.
                      Either continue on if you are satisfied with the previous run
                      or remove its line from that file and run this cell again""".format(recording_basename, success_log_path))
        continue

    try:
        trodes_recording = se.read_spikegadgets(recording_file, stream_id="trodes")
        trodes_recording = trodes_recording.set_probes(probe_object)

        child_spikesorting_output_directory = os.path.join(recording_output_directory, "ss_output")

        start = time.time()
        # Make sure the recording is preprocessed appropriately
        # lazy preprocessing
        print("Running bandpass filter")
        recording_filtered = sp.bandpass_filter(trodes_recording, freq_min=300, freq_max=6000, dtype=np.float32)

        print("Running whitening")
        recording_preprocessed: si.BaseRecording = sp.whiten(recording_filtered, dtype=np.float32)

        with TemporaryDirectory() as tmpdir:
            # cache the recording to a temporary directory for efficient reading
            print("Caching the recording")
            recording_cached = create_cached_recording(recording_preprocessed, folder=tmpdir)
            recording_cached = recording_cached.set_probes(probe_object)
            recording_cached.annotate(is_filtered=True)

            print("Spike sorting")
            spike_sorted_object = ms5.sorting_scheme2(
                recording=recording_cached,
                sorting_parameters=ms5.Scheme2SortingParameters(
                    detect_sign=0,
                    phase1_detect_channel_radius=700,
                    detect_channel_radius=700,
                # other parameters...
                )
            )

            assert isinstance(spike_sorted_object, si.BaseSorting)

            # Replace any sorting output left over from a previous attempt.
            shutil.rmtree(child_spikesorting_output_directory, ignore_errors=True)
            spike_sorted_object.save(folder=child_spikesorting_output_directory)

            sw.plot_rasters(spike_sorted_object)
            plt.title(recording_basename)
            plt.ylabel("Unit IDs")

            plt.savefig(os.path.join(recording_output_directory, "{}_raster_plot.png".format(recording_basename)))
            # Close the figure so figures do not pile up across recordings.
            plt.close()

            print("Exporting waveforms")
            waveform_output_directory = os.path.join(recording_output_directory, "waveforms")
            we_spike_sorted = si.extract_waveforms(
                recording=recording_cached,
                sorting=spike_sorted_object,
                folder=waveform_output_directory,
                ms_before=1,
                ms_after=1,
                progress_bar=True,
                n_jobs=-1,
                total_memory="8G",
                overwrite=True,
                max_spikes_per_unit=2000,
                sparse=False)

            print("we_spike_sorted is sparse: {}".format(we_spike_sorted.is_sparse()))

            print("Saving to phy")
            phy_output_directory = os.path.join(recording_output_directory, "phy")
            export_to_phy(we_spike_sorted,
                            phy_output_directory,
                            compute_pc_features=True,
                            compute_amplitudes=True,
                            remove_if_exists=True)

            # Edit the params.py file so that it contains the correct relative
            # path to the exported data.
            # NOTE(review): assumes the first line of params.py is the
            # `dat_path = ...` entry - confirm against the Phy exporter.
            params_path = os.path.join(phy_output_directory, "params.py")
            with open(params_path, 'r') as file:
                lines = file.readlines()
            lines[0] = "dat_path = r'./recording.dat'\n"
            with open(params_path, 'w') as file:
                file.writelines(lines)

            print("Finished {} in {:.1f} s".format(recording_basename, time.time() - start))

            successful_files.append(recording_file)
            with open(success_log_path, "a+") as fd:
                fd.write(f'\n{recording_basename}')

    except Exception as e:
        # Report which recording failed and what kind of error it was, then
        # carry on with the remaining recordings.
        print("Failed to process {}: {!r}".format(recording_file, e))
        failed_files.append(recording_file)


In [None]:
# NOTE(review): this cell always raises. It looks like a deliberate stop so
# that "Run All" halts here - confirm, and consider removing it or adding a
# message explaining why execution stops.
raise ValueError()