In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import scipy

%matplotlib qt5

import torch
import torchaudio
import torchaudio.functional as F

from spikeinterface.preprocessing import common_reference, bandpass_filter, phase_shift
from spikeinterface.extractors import read_binary

import spikesorting_fullpursuit as fbp
from spikesorting_fullpursuit.parallel.spikesorting_parallel import spike_sort_parallel
from spikesorting_fullpursuit.postprocessing import WorkItemSummary

import probes
import upsample
import filter

In [2]:
# ---------------------------------------------------------------------------
# Recording parameters
# ---------------------------------------------------------------------------
channel_num = 32         # Number of channels in the recording
sampling_rate = 30000    # Samples per second of the raw recording
neighbor_distance = 100  # Distance in microns to add channels to look for templates

# ---------------------------------------------------------------------------
# Input / output locations (everything lives next to the raw recording)
# ---------------------------------------------------------------------------
#data_dir = 'C:/Users/wanglab/Data/Licking/dirt/sorting_test_zca/'
data_dir = 'C:/Users/wanglab/Data/PrV/112923_2/'
filename = data_dir + 'amplifier.dat'               # raw recording that gets read in

save_fname = data_dir + 'sorted_data.pickle'        # raw sorter output

neurons_fname = data_dir + 'sorted_neurons.pickle'  # post-processed units

log_dir = data_dir + 'logs'                         # handed to the sorter as 'log_dir'

# NOTE(review): this list is not referenced by any cell visible in this
# notebook; the upsampling cell passes dead_channels=[2] instead. Confirm which
# channel is actually bad and wire a single constant through.
broken_channels = [16]

# Setup the sorting parameters dictionary.
spike_sort_args = {
        'sigma': 4.0, # Threshold based on noise level
        'clip_width': [-15e-4, 15e-4], # Width of clip in seconds
        'p_value_cut_thresh': 0.01,
        # Seconds per segment; infinity keeps the whole recording in one segment.
        # `np.inf` is used because the `np.Inf` alias was removed in NumPy 2.0.
        'segment_duration': np.inf,
        'segment_overlap': 150, #Seconds of overlap
        'do_branch_PCA': True,
        'do_branch_PCA_by_chan': True,
        'do_overlap_recheck': True,
        'filter_band': (300, 5000), #Intan filter bandwidth
        'do_ZCA_transform': True,
        'check_components': 20,
        'max_components': 5,
        'min_firing_rate': 0.1,
        'use_rand_init': True,
        'add_peak_valley': False,
        'max_gpu_memory': .1 * (1024 * 1024 * 1024),
        'save_1_cpu': True,
        'sort_peak_clips_only': True,
        'n_cov_samples': 100000, #Used to produce noise covariance matrix. Seems to plateau at 100,000
        'sigma_bp_noise': 2.326, # Number of noise standard deviations an expected template match must exceed the decision boundary by. Otherwise it is a candidate for deletion or increased threshold. Higher values = lower false positives and higher false negatives
  #      'sigma_bp_CI': 12.0,
        #'bp_chan_snr': 2.0, # SNR required for a template on a given channel to be used for binary pursuit. Channels lower than this are set to zero template signal.
        'absolute_refractory_period': 10e-4,
        'get_adjusted_clips': False,
        'max_binary_pursuit_clip_width_factor': 1.0,
        'wiener_filter': True,
        'verbose': True,
        'test_flag': False,
        'log_dir': log_dir,
        }

# Probe type: both objects come from the project-local `probes` module.
# `probe` is only referenced by the commented-out `set_probe` call in the
# loading cell. `xy_layout` is indexed as a 2-D array; its column 1 is used
# later to order channels (presumably the y / height coordinate - TODO confirm).
probe = probes.create_poly3()
xy_layout = probes.create_poly3_layout()

In [3]:
# Open the raw int16 recording. The sampling rate and channel count come from
# the configuration cell rather than being repeated as literals, so the reader
# cannot silently disagree with the rest of the pipeline.
rec = read_binary(filename,
                  sampling_frequency=sampling_rate,
                  dtype='int16',
                  num_channels=channel_num,
                  time_axis=0,
                  gain_to_uV=0.195,
                  offset_to_uV=0,
                  channel_ids=range(channel_num),
                  )

# Alternative spikeinterface preprocessing chain (currently disabled):
#rec1 = phase_shift(recording=rec,inter_sample_shift=[i/35 for i in range(0,32)])

#rec1 = rec1.set_probe(probe)
#rec1 = common_reference(rec1,operator='median')
#rec1 = bandpass_filter(recording=rec1,freq_min=700,freq_max=8000)

#voltage = rec1.get_traces(return_scaled=True)
#voltage = voltage.transpose()

# Scaled traces (gain_to_uV applied), transposed so that later cells can index
# them as [channel, sample].
voltage_raw = rec.get_traces(return_scaled=True)
voltage_raw = voltage_raw.transpose()

# Channel order obtained by sorting column 1 of the layout in descending order.
sorted_height = np.flip(xy_layout[:,1].argsort())

#xy_layout_sorted = xy_layout[sorted_height,:]
#voltage_sorted = voltage[sorted_height,:]

#Probe = fbp.electrode.DistanceBasedProbe(sampling_rate,channel_num,xy_layout_sorted,100,voltage_array=voltage_sorted)

In [4]:
import gc

upsample_factor = 4

# Upsample the raw traces with the project-local median-subtraction routine.
# NOTE(review): dead_channels=[2] differs from `broken_channels = [16]` in the
# configuration cell - confirm which channel should be excluded.
up_voltage = upsample.upsample_median_subtraction(voltage_raw,
                                    sampling_rate,
                                    upsample_factor,
                                    dead_channels=[2],
                                    window_size=300000)

# Band-pass every channel at the upsampled rate. The rate is derived from
# `sampling_rate` instead of a repeated literal so the two cannot drift apart.
upsampled_rate = upsample_factor * sampling_rate
for i in range(channel_num):
    up_voltage[i,:] = filter.butter_bandpass_filter(up_voltage[i,:],
                                                    700,
                                                    8000,
                                                    upsampled_rate,
                                                    order=4)
    print("Filtered channel " + str(i))

# Collect Python garbage first so tensors that are only kept alive by reference
# cycles are released, then hand the now-unused cached blocks back to the GPU.
gc.collect()
torch.cuda.empty_cache()

Using device: cuda
Total windows: 54
Processed window 0 of 54
Processed window 1 of 54
Processed window 2 of 54
Processed window 3 of 54
Processed window 4 of 54
Processed window 5 of 54
Processed window 6 of 54
Processed window 7 of 54
Processed window 8 of 54
Processed window 9 of 54
Processed window 10 of 54
Processed window 11 of 54
Processed window 12 of 54
Processed window 13 of 54
Processed window 14 of 54
Processed window 15 of 54
Processed window 16 of 54
Processed window 17 of 54
Processed window 18 of 54
Processed window 19 of 54
Processed window 20 of 54
Processed window 21 of 54
Processed window 22 of 54
Processed window 23 of 54
Processed window 24 of 54
Processed window 25 of 54
Processed window 26 of 54
Processed window 27 of 54
Processed window 28 of 54
Processed window 29 of 54
Processed window 30 of 54
Processed window 31 of 54
Processed window 32 of 54
Processed window 33 of 54
Processed window 34 of 54
Processed window 35 of 54
Processed window 36 of 54
Processed w

In [39]:
# Reorder channels by descending value of layout column 1 so that the layout
# rows and the voltage rows share the same ordering.
sorted_height = np.flip(xy_layout[:,1].argsort())

xy_layout_sorted = xy_layout[sorted_height,:]
voltage_sorted = up_voltage[sorted_height,:]

voltage_sorted = voltage_sorted.astype(np.float32)

# Build the probe object the sorter consumes. The voltage has been upsampled,
# so the effective sampling rate is sampling_rate * upsample_factor. The
# neighbourhood radius comes from `neighbor_distance` (configuration cell)
# instead of a hard-coded 100.
Probe = fbp.electrode.DistanceBasedProbe(sampling_rate*upsample_factor,
                                         channel_num,
                                         xy_layout_sorted,
                                         neighbor_distance,
                                         voltage_array=voltage_sorted)

In [28]:
# Visual check: overlay the upsampled + filtered trace on the original trace
# for one channel over a fixed sample window.
channel = sorted_height[30]
#channel = 2

#other_channels = np.arange(0,32)
#other_channels = np.delete(other_channels,sorted_height[21])
#other_channels = np.delete(other_channels,[2])

# Window limits, expressed in samples of the original (non-upsampled) recording.
t_start = 600000
t_end = 1200000

fig, ax = plt.subplots(1,1)

# `voltage` only exists in commented-out code, so the comparison trace is taken
# from `voltage_raw`, which is what the loading cell actually defines.
# NOTE(review): `voltage_raw` is unfiltered; re-enable the spikeinterface chain
# if a filtered reference trace is wanted here.
n_samples = voltage_raw.shape[1]
n_up_samples = up_voltage.shape[1]

# Sample k sits at k / rate. (np.linspace(0, n, n) includes the endpoint and so
# stretches the axis by n / (n - 1).) The upsampled axis uses the real length
# of `up_voltage` and assumes upsampled sample k lies at
# k / (sampling_rate * upsample_factor) - TODO confirm against `upsample`.
t = np.arange(n_samples) / sampling_rate
u = np.arange(n_up_samples) / (sampling_rate * upsample_factor)

up_window = slice(t_start*upsample_factor, t_end*upsample_factor)
raw_window = slice(t_start, t_end)

ax.plot(u[up_window], up_voltage[channel, up_window], label='upsampled + filtered')
ax.plot(t[raw_window], voltage_raw[channel, raw_window], label='raw')

ax.set_xlabel('Time (s)')
ax.set_ylabel('Voltage (uV)')
ax.set_title('Channel ' + str(channel))
ax.legend();

[<matplotlib.lines.Line2D at 0x16e208bccd0>]

In [None]:
# Scratch cell: echo the array produced by rec.get_traces(...).transpose() in
# the loading cell so its repr can be eyeballed.
voltage_raw

In [40]:
# Run the sorter on the probe object using the options in `spike_sort_args`.
# The three returned objects are pickled in the next cell and passed to
# WorkItemSummary during post-processing.
print("Start sorting")
sort_data, work_items, sorter_info = spike_sort_parallel(Probe, **spike_sort_args)

Start sorting
Using  1 segments per channel for sorting.
Doing parallel ZCA transform and thresholding for 1 segments


In [6]:
# Persist the raw sorter output as one pickled 3-tuple so that post-processing
# can be re-run later without sorting again.
print("Saving neurons file as", save_fname)
sorter_outputs = (sort_data, work_items, sorter_info)
with open(save_fname, mode='wb') as pickle_file:
    # A negative protocol number selects the highest protocol available.
    pickle.dump(sorter_outputs, pickle_file, protocol=-1)

Saving neurons file as C:/Users/wanglab/Data/PrV/120523_1/sorted_data.pickle


### Post Processing

In [7]:
# --- Automated post-processing: thresholds for flagging poor units ----------

# Refractory period used to detect potential refractory violations. Same number
# as spike_sort_args['absolute_refractory_period']; presumably in seconds
# (i.e. 1 ms) like 'clip_width' - TODO confirm the unit against the library.
absolute_refractory_period = 10e-4

# Largest allowed ratio between refractory-period violations and the maximal
# ACG bin; units above it are deleted. Any value >= 1. lets every unit through.
max_mua_ratio = 1.

# Lowest SNR a unit may have and still be included in post-processing.
min_snr = 0

# Share of spikes that must have nearly identical times in adjacent segments
# for two units to be combined while stitching.
min_overlapping_spikes = .75

# Build the post-processing helper from the three sorter outputs.
work_summary = WorkItemSummary(
    sort_data,
    work_items,
    sorter_info,
    absolute_refractory_period=absolute_refractory_period,
    max_mua_ratio=max_mua_ratio,
    min_snr=min_snr,
    min_overlapping_spikes=min_overlapping_spikes,
    verbose=False,
)

# Stitch units across time segments. The sorter log for this notebook reports a
# single segment per channel, so this is kept for completeness.
work_summary.stitch_segments()

# Organise the sorted output into per-segment dictionaries.
work_summary.summarize_neurons_by_seg()

# Combine units across channels (merging / dropping spatial duplicates) to end
# up with the final list of sorted "neurons".
neurons = work_summary.summarize_neurons_across_channels(
    overlap_ratio_threshold=np.inf,
    min_segs_per_unit=1,
    remove_clips=False,
)

# Store the final unit list next to the raw recording.
print("Saving neurons file as", neurons_fname)
with open(neurons_fname, mode='wb') as neurons_file:
    pickle.dump(neurons, neurons_file, protocol=-1)



No overlap between segments found. Switching stitch_overlap_only to False.
Least MUA removed was inf on channel None segment None
Maximum SNR removed was -inf on channel None segment None
Start stitching channel 0
Start stitching channel 1
Start stitching channel 2
Start stitching channel 3
Start stitching channel 4
Start stitching channel 5
Start stitching channel 6
Start stitching channel 7
Start stitching channel 8
Start stitching channel 9
Start stitching channel 10
Start stitching channel 11
Start stitching channel 12
Start stitching channel 13
Start stitching channel 14
Start stitching channel 15
Start stitching channel 16
Start stitching channel 17
Start stitching channel 18
Start stitching channel 19
Start stitching channel 20
Start stitching channel 21
Start stitching channel 22
Start stitching channel 23
Start stitching channel 24
Start stitching channel 25
Start stitching channel 26
Start stitching channel 27
Start stitching channel 28
Start stitching channel 29
Start stitch

In [8]:
# Report how many units were found, then print one summary line per unit.
print("Found", len(neurons), "total units with properties:")
fmtL = "Unit: {:.0f} on chans {}; n spikes = {:.0f}; FR = {:.0f}; Dur = {:.0f}; SNR = {:.2f}; MUA = {:.2f}; TolInds = {:.0f}"
for unit_index, unit in enumerate(neurons):
    # Positional fields, in the order the template above expects them.
    summary_line = fmtL.format(
        unit_index,
        unit['channel'],
        unit['spike_indices'].size,
        unit['firing_rate'],
        unit['duration_s'],
        unit['snr']['average'],
        unit['fraction_mua'],
        unit['duplicate_tol_inds'],
    )
    print(summary_line)


Found 47 total units with properties:
Unit: 0 on chans [0]; n spikes = 5096; FR = 15; Dur = 332; SNR = 5.32; MUA = 0.49; TolInds = 7
Unit: 1 on chans [0]; n spikes = 13994; FR = 42; Dur = 334; SNR = 2.56; MUA = 0.94; TolInds = 6
Unit: 2 on chans [1]; n spikes = 5623; FR = 17; Dur = 334; SNR = 3.17; MUA = 0.62; TolInds = 6
Unit: 3 on chans [1]; n spikes = 7370; FR = 22; Dur = 334; SNR = 2.05; MUA = 0.74; TolInds = 7
Unit: 4 on chans [2]; n spikes = 11813; FR = 35; Dur = 334; SNR = 5.27; MUA = 1.00; TolInds = 8
Unit: 5 on chans [2]; n spikes = 1818; FR = 5; Dur = 333; SNR = 3.00; MUA = 0.50; TolInds = 9
Unit: 6 on chans [3]; n spikes = 5779; FR = 17; Dur = 334; SNR = 4.74; MUA = 0.75; TolInds = 5
Unit: 7 on chans [4]; n spikes = 5066; FR = 15; Dur = 334; SNR = 3.00; MUA = 1.00; TolInds = 5
Unit: 8 on chans [5]; n spikes = 3452; FR = 10; Dur = 334; SNR = 2.18; MUA = 0.96; TolInds = 7
Unit: 9 on chans [5]; n spikes = 3440; FR = 10; Dur = 334; SNR = 2.07; MUA = 1.00; TolInds = 7
Unit: 10 on

In [9]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import
# cell at the top so a fresh kernel surfaces missing dependencies early.
from spikesorting_fullpursuit.utils import convert_to_viz
# Convert the pickled neurons file written during post-processing into a
# NeuroViz file (the captured output reports a "*_viz.pkl" next to it).
convert_to_viz.f_neurons_to_viz(neurons_fname,neuroviz_only=True)

Loading neurons file C:/Users/wanglab/Data/PrV/120523_1/sorted_neurons.pickle
No filename specified, using 'default_fname'
Saved NeuroViz file: C:/Users/wanglab/Data/PrV/120523_1/sorted_neurons_viz.pkl


### Save voltage arranged by channel as dat file

In [10]:
# Write the height-sorted voltage array to disk as raw int16.
voltage_rearrange_filename = data_dir + 'voltage_sorted.dat'

# A bare float -> int16 cast truncates toward zero and gives arbitrary values
# for samples outside the int16 range, so round to nearest and clamp first.
# Rows are written one at a time to avoid a second full-size temporary; the
# bytes on disk follow the same C (row-major) order as calling tofile() on the
# whole 2-D array.
int16_limits = np.iinfo(np.int16)
with open(voltage_rearrange_filename, 'wb') as dat_file:
    for channel_trace in voltage_sorted:
        rounded_trace = np.clip(np.rint(channel_trace),
                                int16_limits.min,
                                int16_limits.max)
        rounded_trace.astype(np.int16).tofile(dat_file)

### Load Data

In [None]:
# Reload the sorter output written by the saving cell.
# NOTE(review): pickle.load executes arbitrary code from the file - only open
# pickles produced by this pipeline.
with open(save_fname, mode='rb') as pickle_file:
    sorted_data = pickle.load(pickle_file)

# Split the stored tuple back into the three objects the later cells expect.
sort_data = sorted_data[0]
work_items = sorted_data[1]
sorter_info = sorted_data[2]