In [None]:
# Import other useful Python packages
import numpy as np
import matplotlib.pyplot as plt
import spikesorting_fullpursuit as fbp
from spikesorting_fullpursuit.parallel.spikesorting_parallel import spike_sort_parallel
import pickle

In [None]:
# ---- Recording and output file locations ------------------------------------
channel_num = 32

data_dir = 'C:/Users/wanglab/Data/PrV/061523_1/'  # paths below are appended directly, so keep the trailing '/'
filename = f'{data_dir}amplifier.dat'               # raw int16 recording
save_fname = f'{data_dir}sorted_data.pickle'        # raw sorter output
neurons_fname = f'{data_dir}sorted_neurons.pickle'  # post-processed units
log_dir = f'{data_dir}logs'

# ---- Sorting parameters -----------------------------------------------------
# Unpacked as keyword arguments into spike_sort_parallel(Probe, **spike_sort_args).
spike_sort_args = dict(
    sigma=4.0,                          # threshold based on noise level
    clip_width=[-10e-4, 10e-4],         # width of clip in seconds
    p_value_cut_thresh=0.01,
    segment_duration=np.inf,            # np.inf: use the entire recording as one segment
    segment_overlap=150,
    do_branch_PCA=True,
    do_branch_PCA_by_chan=True,
    do_overlap_recheck=True,
    filter_band=(300, 5000),            # Intan filter bandwidth
    do_ZCA_transform=True,
    check_components=20,
    max_components=5,
    min_firing_rate=0.1,
    use_rand_init=True,
    add_peak_valley=False,
    max_gpu_memory=.1 * (1024 * 1024 * 1024),  # presumably bytes, i.e. ~0.1 GiB
    save_1_cpu=True,
    sort_peak_clips_only=True,
    n_cov_samples=100000,               # samples for the noise covariance matrix; seems to plateau at 100,000
    sigma_bp_noise=2.326,
    # sigma_bp_CI=12.0,                 # currently disabled
    absolute_refractory_period=10e-4,
    get_adjusted_clips=False,
    max_binary_pursuit_clip_width_factor=1.0,
    verbose=True,
    test_flag=False,
    log_dir=log_dir,
)

In [None]:
# Load the raw amplifier recording as one flat int16 stream. Reshaping with
# channel_num columns assumes the file is sample-interleaved
# (ch0, ch1, ..., ch{N-1}, ch0, ...) — TODO confirm against the acquisition
# settings. After the transpose, row c of `voltage` holds channel c.
raw_samples = np.fromfile(filename, dtype=np.int16)
if raw_samples.size % channel_num != 0:
    raise ValueError(
        f"{filename} holds {raw_samples.size} int16 samples, which is not a "
        f"multiple of channel_num={channel_num}; check channel_num or the file."
    )
# reshape(-1, ...) avoids the float round-trip of int(size / channel_num).
# order='C' makes each channel's samples contiguous in the float32 copy.
voltage = raw_samples.reshape(-1, channel_num).T.astype(np.float32, order='C')
del raw_samples  # release the int16 buffer; only the float32 copy is needed

In [None]:
# xy_layout is a 2-D array with one row per channel (row index == channel
# number) and two columns giving that channel's x, y coordinates in micrometers.
xy_layout = np.array([
    [21.65, 25],    # 0
    [21.65, 725],   # 1
    [21.65, 125],   # 2
    [21.65, 325],   # 3
    [21.65, 375],   # 4
    [21.65, 475],   # 5
    [21.65, 675],   # 6
    [21.65, 775],   # 7
    [21.65, 225],   # 8
    [21.65, 275],   # 9
    [21.65, 575],   # 10
    [21.65, 525],   # 11
    [21.65, 425],   # 12
    [21.65, 175],   # 13
    [21.65, 625],   # 14
    [21.65, 75],    # 15
    [-21.65, 0],    # 16
    [-21.65, 450],  # 17
    [-21.65, 100],  # 18
    [-21.65, 250],  # 19
    [-21.65, 600],  # 20
    [-21.65, 400],  # 21
    [-21.65, 550],  # 22
    [-21.65, 750],  # 23
    [-21.65, 200],  # 24
    [-21.65, 650],  # 25
    [-21.65, 350],  # 26
    [-21.65, 500],  # 27
    [-21.65, 700],  # 28
    [-21.65, 150],  # 29
    [-21.65, 300],  # 30
    [-21.65, 50],   # 31
    ])
# Catch a layout/channel-count mismatch here rather than deep inside the sorter.
assert xy_layout.shape == (channel_num, 2), (
    f"xy_layout has shape {xy_layout.shape}; expected ({channel_num}, 2)")

sampling_rate = 30000
neighbor_distance = 100  # Distance in microns to add channels to look for templates

# Pass neighbor_distance (not a duplicated literal) so that editing the
# constant above actually changes the probe's channel neighborhoods.
Probe = fbp.electrode.DistanceBasedProbe(sampling_rate, channel_num, xy_layout,
                                         neighbor_distance, voltage_array=voltage)

In [None]:
# Optional explicit bandpass filtering of the probe's voltage, left disabled.
# NOTE(review): presumably disabled because 'filter_band' is already passed to
# spike_sort_parallel through spike_sort_args; confirm before re-enabling so
# the data is not filtered twice.
#Probe.bandpass_filter_parallel(spike_sort_args['filter_band'][0], spike_sort_args['filter_band'][1])

In [None]:
# Run the full-pursuit sorter on the probe using the parameters from the
# configuration cell. The result is a 3-tuple of sorter outputs.
print("Start sorting")
sort_result = spike_sort_parallel(Probe, **spike_sort_args)
sort_data, work_items, sorter_info = sort_result

In [None]:
# Persist the raw sorter output so post-processing can be rerun later without
# re-sorting (it is reloaded in the Load Data section). protocol=-1 selects
# the highest available pickle protocol.
# The message previously said "neurons file", but the neurons are saved to
# neurons_fname after post-processing; this cell writes save_fname.
print("Saving sorted data file as", save_fname)
with open(save_fname, 'wb') as fp:
    pickle.dump((sort_data, work_items, sorter_info), fp, protocol=-1)

### Post Processing

In [None]:
from spikesorting_fullpursuit.postprocessing import WorkItemSummary

In [None]:
# First step in automated post-processing: set thresholds that let
# WorkItemSummary detect and drop poorly isolated units.

# Refractory period in seconds (10e-4 s = 1 ms), used to count potential
# refractory violations. Taken from spike_sort_args so the sorter and the
# post-processing always agree on the value.
absolute_refractory_period = spike_sort_args['absolute_refractory_period']
# Max allowable ratio between refractory-period violations and the maximal bin
# of the ACG. Units that violate it are deleted; setting >= 1. allows all units.
max_mua_ratio = 1.
min_snr = 0  # Minimal SNR a unit must have to be included in post-processing
# Fraction (0-1, not a percentage) of spikes that must have nearly identical
# spike times in adjacent segments for them to be combined during stitching.
min_overlapping_spikes = .75

# Create the work_summary postprocessing object
work_summary = WorkItemSummary(sort_data, work_items,
                sorter_info, absolute_refractory_period=absolute_refractory_period,
                max_mua_ratio=max_mua_ratio, min_snr=min_snr,
                min_overlapping_spikes=min_overlapping_spikes, verbose=False)

# segment_duration is np.inf in spike_sort_args, so the recording should be a
# single segment and stitching presumably has nothing to join. It is still
# called so the pipeline stays correct if segment_duration is changed.
work_summary.stitch_segments()

# Summarize the sorted output data into dictionaries by time segment.
work_summary.summarize_neurons_by_seg()

# Finally summarize neurons across channels (combining and removing duplicate
# neurons across space) to get a list of sorted "neurons"
neurons = work_summary.summarize_neurons_across_channels(
                    overlap_ratio_threshold=np.inf,
                    min_segs_per_unit=1,
                    remove_clips=False)

print("Saving neurons file as", neurons_fname)
with open(neurons_fname, 'wb') as fp:
    pickle.dump(neurons, fp, protocol=-1)

In [None]:
# Print one summary line per sorted unit. The firing rate gets two decimals:
# min_firing_rate is 0.1, so whole-number rounding would show low-rate units
# as "FR = 0".
print("Found", len(neurons), "total units with properties:")
unit_fmt = ("Unit: {:.0f} on chans {}; n spikes = {:.0f}; FR = {:.2f}; Dur = {:.0f}; "
            "SNR = {:.2f}; MUA = {:.2f}; TolInds = {:.0f}")
for ind, unit in enumerate(neurons):
    print(unit_fmt.format(
        ind,
        unit['channel'],
        unit['spike_indices'].size,
        unit['firing_rate'],
        unit['duration_s'],
        unit['snr']['average'],
        unit['fraction_mua'],
        unit['duplicate_tol_inds'],
    ))


### Plot Data

In [None]:
# Overlay every sorted unit's spikes on the raw traces of all channels in a
# short window. Channels are stacked vertically with a fixed offset. Each unit
# gets a random color, drawn around each of its spikes on its neighbor channels.
window_start_s = 101.5                            # window start, seconds
start_win = round(window_start_s * sampling_rate)
end_win = start_win + 2000                        # window length, samples
clip_win = 30                                     # half-width of each highlighted spike, samples
channel_offset = 3000                             # vertical spacing between channel traces

fig, ax = plt.subplots()
for chan in range(channel_num):
    ax.plot(voltage[chan, start_win:end_win] + chan * channel_offset,
            linewidth=0.5, color="blue")

for unit in neurons:
    color = tuple(np.random.random(size=3))
    spikes = np.asarray(unit['spike_indices'])
    # Keep only spikes strictly inside the plotted window.
    spikes = spikes[(spikes > start_win) & (spikes < end_win)]
    if spikes.size == 0:
        continue
    # Neighbors depend only on the unit's primary channel, so look them up once.
    channels = Probe.get_neighbors(unit['channel'][0])
    for spike in spikes:
        lo, hi = spike - clip_win, spike + clip_win
        x = np.arange(lo - start_win, hi - start_win)
        for c in channels:
            ax.plot(x, voltage[c, lo:hi] + c * channel_offset, color=color, linewidth=1)

ax.set(xlabel=f"Samples from t = {window_start_s} s",
       ylabel="Voltage + channel offset",
       title="Sorted spikes overlaid on raw traces");

In [None]:
# Zoom in on one channel of interest (noi) and overlay the spikes of selected
# units whose neighborhood includes that channel. Each unit's snippets are
# offset vertically by (unit index * 1000) to keep the units distinguishable.
start_win = round(101.55 * sampling_rate)
end_win = start_win + 1000       # window length, samples
clip_win = 30                    # half-width of each highlighted spike, samples

noi = 25                         # channel of interest
units_of_interest = [32, 33]     # use range(len(neurons)) to inspect every unit

fig, ax = plt.subplots()
ax.plot(voltage[noi, start_win:end_win], linewidth=0.5, color="blue")

for n in units_of_interest:
    if n >= len(neurons):
        # Unit numbering differs between sorts; skip instead of raising IndexError.
        print(f"Skipping unit {n}: only {len(neurons)} units were found")
        continue

    channels = Probe.get_neighbors(neurons[n]['channel'][0])
    if noi not in channels:
        continue

    color = tuple(np.random.random(size=3))
    spikes = np.asarray(neurons[n]['spike_indices'])
    spikes = spikes[(spikes > start_win) & (spikes < end_win)]
    for spike in spikes:
        lo, hi = spike - clip_win, spike + clip_win
        ax.plot(np.arange(lo - start_win, hi - start_win),
                voltage[noi, lo:hi] + n * 1000, color=color, linewidth=1)

ax.set(xlabel="Samples from window start",
       ylabel="Voltage (+ unit offset)",
       title=f"Channel {noi}: selected units");

In [None]:
# Inter-spike-interval (ISI) histogram for one unit. This is a histogram of the
# differences between consecutive spike indices, not an autocorrelation. Bins
# are 30 samples wide (1 ms at the 30 kHz sampling rate).
# Assumes spike_indices are sorted in time — TODO confirm.
unit_id = 32
isi_samples = np.diff(neurons[unit_id]['spike_indices'])
fig, ax = plt.subplots()
ax.hist(isi_samples, bins=np.arange(0, 3000, 30))
ax.set(xlabel="Inter-spike interval (samples)", ylabel="Count",
       title=f"ISI histogram, unit {unit_id}");

In [None]:
# Inter-spike-interval (ISI) histogram for a second unit. This is a histogram of
# the differences between consecutive spike indices, not an autocorrelation.
# Bins are 30 samples wide (1 ms at the 30 kHz sampling rate).
# Assumes spike_indices are sorted in time — TODO confirm.
# The trailing ';' suppresses the (counts, bins, patches) repr.
unit_id = 33
isi_samples = np.diff(neurons[unit_id]['spike_indices'])
fig, ax = plt.subplots()
ax.hist(isi_samples, bins=np.arange(0, 3000, 30))
ax.set(xlabel="Inter-spike interval (samples)", ylabel="Count",
       title=f"ISI histogram, unit {unit_id}");

### Load Data

In [None]:
# Reload a previously saved (sort_data, work_items, sorter_info) tuple, e.g. to
# rerun post-processing without sorting again. Only unpickle files you wrote
# yourself: pickle.load can execute arbitrary code.
with open(save_fname, 'rb') as fh:
    sorted_data = pickle.load(fh)

sort_data = sorted_data[0]
work_items = sorted_data[1]
sorter_info = sorted_data[2]