In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle

from spikeinterface.preprocessing import common_reference, bandpass_filter, phase_shift
from spikeinterface.extractors import read_binary

import spikesorting_fullpursuit as fbp
from spikesorting_fullpursuit.parallel.spikesorting_parallel import spike_sort_parallel
from spikesorting_fullpursuit.postprocessing import WorkItemSummary

import probes

In [None]:
# --- Recording parameters ---
channel_num = 32  # Number of channels stored in the binary recording
sampling_rate = 30000  # Sampling frequency of the recording (samples per second)
neighbor_distance = 100  # Distance in microns to add channels to look for templates

# --- Input / output locations ---
# NOTE(review): absolute, machine-specific path; edit data_dir to point at the
# recording to sort. It must end with '/' because paths are built by concatenation.
#data_dir = 'C:/Users/wanglab/Data/Licking/dirt/sorting_test_zca/'
data_dir = 'C:/Users/wanglab/Data/PrV/112423_2/'
filename = data_dir + 'amplifier.dat'  # Raw binary recording to read

save_fname = data_dir + 'sorted_data.pickle'  # Raw sorter output (sort_data, work_items, sorter_info)

neurons_fname = data_dir + 'sorted_neurons.pickle'  # Post-processed list of neurons

log_dir = data_dir + 'logs'  # Passed to the sorter as 'log_dir'

# NOTE(review): broken_channels is not referenced anywhere else in the visible
# notebook, so channel 16 is currently NOT excluded from referencing or sorting.
broken_channels = [16]

# Setup the sorting parameters dictionary.
spike_sort_args = {
        'sigma': 4.0, # Threshold based on noise level
        'clip_width': [-15e-4, 15e-4], # Width of clip in seconds
        'p_value_cut_thresh': 0.01,
        # Seconds of duration per segment. np.inf is longer than any recording, so
        # the recording is presumably sorted as one segment. (np.Inf was removed in
        # NumPy 2.0; np.inf is the supported spelling in every NumPy version.)
        'segment_duration': np.inf,
        'segment_overlap': 150, # Seconds of overlap
        'do_branch_PCA': True,
        'do_branch_PCA_by_chan': True,
        'do_overlap_recheck': True,
        'filter_band': (300, 5000), # Intan filter bandwidth
        'do_ZCA_transform': True,
        'check_components': 20,
        'max_components': 5,
        'min_firing_rate': 0.1,
        'use_rand_init': True,
        'add_peak_valley': False,
        'max_gpu_memory': .1 * (1024 * 1024 * 1024), # 0.1 * 2**30 (presumably bytes, i.e. 0.1 GiB)
        'save_1_cpu': True,
        'sort_peak_clips_only': True,
        'n_cov_samples': 100000, # Used to produce noise covariance matrix. Seems to plateau at 100,000
        # Number of noise standard deviations an expected template match must exceed
        # the decision boundary by. Otherwise it is a candidate for deletion or
        # increased threshold. Higher values = lower false positives and higher
        # false negatives.
        'sigma_bp_noise': 2.326,
  #      'sigma_bp_CI': 12.0,
        #'bp_chan_snr': 2.0, # SNR required for a template on a given channel to be used for binary pursuit. Channels lower than this are set to zero template signal.
        'absolute_refractory_period': 10e-4,
        'get_adjusted_clips': False,
        'max_binary_pursuit_clip_width_factor': 1.0,
        'wiener_filter': True,
        'verbose': True,
        'test_flag': False,
        'log_dir': log_dir,
        }

# Probe type: build the probe object and the contact layout from the local `probes` module.
# `probe` is attached to the recording with rec.set_probe(); `xy_layout` is used to order
# channels by its second column and is passed to the sorter's DistanceBasedProbe.
# NOTE(review): presumably xy_layout is an (n_channels, 2) array of contact x/y positions
# in microns, in the same channel order as the recording — confirm against probes.py.
probe = probes.create_poly3()
xy_layout = probes.create_poly3_layout()

In [None]:


# Load the raw int16 binary recording. The sampling rate and channel count come from
# the configuration cell so the reader, the preprocessing and the sorter probe cannot
# silently disagree. gain_to_uV / offset_to_uV are the scaling spikeinterface applies
# when scaled traces are requested below.
rec = read_binary(
    filename,
    sampling_frequency=sampling_rate,
    dtype='int16',
    num_channels=channel_num,
    time_axis=0,
    gain_to_uV=0.195,
    offset_to_uV=0,
    channel_ids=list(range(channel_num)),
)

# Compensate for the per-channel sampling delay: channel i is shifted by i/35 of a
# sample period.
# NOTE(review): 35 is presumably the number of conversion slots per sample cycle on
# the acquisition chip — confirm for this headstage.
rec = phase_shift(
    recording=rec,
    inter_sample_shift=[i / 35 for i in range(channel_num)],
)

rec = rec.set_probe(probe)
# Subtract the across-channel median from every channel.
rec = common_reference(rec, operator='median')
# NOTE(review): this 700-8000 Hz band differs from 'filter_band': (300, 5000) passed
# to the sorter in spike_sort_args — confirm which band the sorter should assume.
rec = bandpass_filter(recording=rec, freq_min=700, freq_max=8000)

# Pull the whole preprocessed recording into memory with the gain/offset applied, then
# transpose (presumably from (samples, channels) to (channels, samples) — TODO confirm).
voltage = rec.get_traces(return_scaled=True)
voltage = voltage.transpose()

# Channel order sorted by the second layout column (presumably depth), descending.
sorted_height = np.flip(xy_layout[:, 1].argsort())

# Apply the same reordering to the layout and to the voltage rows so they stay aligned.
xy_layout_sorted = xy_layout[sorted_height, :]
voltage_sorted = voltage[sorted_height, :]

# Probe object handed to the sorter; neighbor_distance (microns) comes from the
# configuration cell instead of a duplicated literal.
Probe = fbp.electrode.DistanceBasedProbe(
    sampling_rate,
    channel_num,
    xy_layout_sorted,
    neighbor_distance,
    voltage_array=voltage_sorted,
)

In [None]:
# Run the sorter on the probe object built above, using the parameters in spike_sort_args.
# The three returned objects are pickled in the next cell and later passed, in the same
# order, to WorkItemSummary for post-processing.
print("Start sorting")
sort_data, work_items, sorter_info = spike_sort_parallel(Probe, **spike_sort_args)

In [None]:
# Persist the raw sorter output so post-processing can be re-run without re-sorting.
# The message names the sorted-data file: the previous text ("neurons file") was the
# same as the one printed when the post-processed neurons are saved, which made the
# two log lines indistinguishable.
print("Saving sorted data file as", save_fname)
with open(save_fname, 'wb') as fp:
    # HIGHEST_PROTOCOL is what protocol=-1 selects; spelled out for readability.
    pickle.dump((sort_data, work_items, sorter_info), fp, protocol=pickle.HIGHEST_PROTOCOL)

### Post Processing

In [None]:
# First step in automated post-processing
# Set a few variables that can allow easy detection of units that are poor

# Refractory period used to determine potential violations in sorting accuracy.
# Taken from the sorter settings so sorting and post-processing always use the same
# value (10e-4). NOTE(review): the earlier comment said "in ms", but 10e-4 only makes
# sense as seconds (= 1 ms), matching the seconds-based clip_width — confirm.
absolute_refractory_period = spike_sort_args['absolute_refractory_period']
# Max allowable ratio between refractory period violations and maximal bin of ACG.
# Units that violate will be deleted. Setting to >= 1. allows all units
max_mua_ratio = 1.
min_snr = 0 # Minimal SNR a unit must have to be included in post-processing
# Fraction (0-1) of spikes required with nearly identical spike times in adjacent
# segments for them to combine in stitching
min_overlapping_spikes = .75

# Create the work_summary postprocessing object
work_summary = WorkItemSummary(
    sort_data,
    work_items,
    sorter_info,
    absolute_refractory_period=absolute_refractory_period,
    max_mua_ratio=max_mua_ratio,
    min_snr=min_snr,
    min_overlapping_spikes=min_overlapping_spikes,
    verbose=False,
)

# Stitch units across time segments. With segment_duration set to infinity in
# spike_sort_args there is presumably only one segment, so this should have nothing
# to combine; it is kept so the pipeline still works if segmenting is enabled.
work_summary.stitch_segments()

# Summarize the sorted output data into dictionaries by time segment.
work_summary.summarize_neurons_by_seg()

# Finally summarize neurons across channels (combining and removing duplicate
# neurons across space) to get a list of sorted "neurons"
neurons = work_summary.summarize_neurons_across_channels(
    overlap_ratio_threshold=np.inf,
    min_segs_per_unit=1,
    remove_clips=False,
)

# Save the post-processed neurons; this file is read by the viz conversion cell.
print("Saving neurons file as", neurons_fname)
with open(neurons_fname, 'wb') as fp:
    pickle.dump(neurons, fp, protocol=-1)



In [None]:
# Print a one-line summary for every unit produced by post-processing.
print("Found", len(neurons), "total units with properties:")
fmtL = "Unit: {:.0f} on chans {}; n spikes = {:.0f}; FR = {:.0f}; Dur = {:.0f}; SNR = {:.2f}; MUA = {:.2f}; TolInds = {:.0f}"
for unit_index, unit in enumerate(neurons):
    # Arguments are listed in the same order as the placeholders in fmtL.
    summary_line = fmtL.format(
        unit_index,
        unit['channel'],
        unit['spike_indices'].size,
        unit['firing_rate'],
        unit['duration_s'],
        unit['snr']['average'],
        unit['fraction_mua'],
        unit['duplicate_tol_inds'],
    )
    print(summary_line)


In [None]:
# Convert the saved neurons pickle (neurons_fname) with the package's viz helper.
# NOTE(review): presumably this writes a NeuroViz-compatible file next to the input —
# confirm the output location and format in convert_to_viz.
# NOTE(review): this import lives mid-notebook; consider moving it to the import cell
# once it is confirmed to import cleanly on a fresh kernel.
from spikesorting_fullpursuit.utils import convert_to_viz
convert_to_viz.f_neurons_to_viz(neurons_fname,neuroviz_only=True)

### Save voltage arranged by channel as dat file

In [None]:
# Write the channel-reordered voltage to a raw int16 binary file in data_dir.
# NOTE(review): voltage_sorted comes from get_traces(return_scaled=True), so the cast to
# int16 drops the fractional part of each scaled value; out-of-range values would
# not be clipped. Confirm that this precision is acceptable for the file's consumer.
# NOTE(review): voltage_sorted is indexed as [channel, :] above, so tofile() presumably
# writes all samples of one channel before the next (channel-major), unlike the
# time_axis=0 layout of the input file — confirm the reader expects this.
voltage_rearrange_filename = data_dir + 'voltage_sorted.dat'
voltage_sorted.astype('int16').tofile(voltage_rearrange_filename)

### Load Data

In [None]:
# Reload the raw sorter output written by the save cell, so post-processing can be
# run without re-sorting.
# NOTE(review): pickle.load can execute arbitrary code; only open files that this
# notebook (or another trusted source) produced.
with open(save_fname, 'rb') as fp:
    sorted_data = pickle.load(fp)

# The pickle holds (sort_data, work_items, sorter_info) in that order.
sort_data = sorted_data[0]
work_items = sorted_data[1]
sorter_info = sorted_data[2]