In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tsmoothie.smoother import GaussianSmoother
import spikeinterface
import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.postprocessing as sp
import spikeinterface.preprocessing as spre
import spikeinterface.qualitymetrics as qm
import helper_functions as helper

In [2]:
# Read the MaxWell recording, keep a fixed time window, band-pass filter it and
# apply a global median common reference before spike sorting.
local_path = '/home/mmpatil/Documents/spikesorting/Data/Trace_20230317_12_45_44_1000elec.raw.h5'  # network data from chip 16848

recording = se.read_maxwell(local_path)

channel_ids = recording.get_channel_ids()
fs = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()
num_seg = recording.get_num_segments()
total_recording = recording.get_total_duration()

print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)
print(f"total_recording: {total_recording} s")

# Analysis window in seconds (30 s -> 270 s, i.e. a 4 min chunk).
CHUNK_START_S = 30
CHUNK_END_S = 270
# frame_slice takes frame indices; fs is a float (e.g. 20000.0), so cast explicitly.
recording_chunk = recording.frame_slice(start_frame=int(CHUNK_START_S * fs), end_frame=int(CHUNK_END_S * fs))

recording_chunk = spre.bandpass_filter(recording_chunk, freq_min=300, freq_max=6000)

# Bug fix: the common-reference result used to be assigned to the misspelled
# name `recodring_chunk`, so the CAR step was silently discarded.
recording_chunk = spre.common_reference(recording_chunk, reference='global', operator='median')




inside get_reader
Inside NeoBaseExtractor
Sampling frequency: 20000.0
Number of channels: 724
Number of segments: 1
total_recording: 616.87 s


In [3]:
# Kilosort3 parameters: start from the sorter defaults and override the
# detection / clustering thresholds.
default_KS3_params = ss.get_default_sorter_params('kilosort3')
default_KS3_params['keep_good_only'] = True
default_KS3_params['detect_threshold'] = 24
default_KS3_params['projection_threshold'] = [30, 30]
default_KS3_params['preclust_threshold'] = 26
print(default_KS3_params)
# The chunk spans 30-270 s (4 min). The folder was previously named "2min",
# which did not match the data nor the "kilosort3_4min_17march_block" folder
# the results are loaded from afterwards.
# NOTE(review): re-running may replace an existing sorting in this folder.
run_sorter = ss.run_kilosort3(
    recording_chunk,
    output_folder="kilosort3_4min_17march_block",
    docker_image="kilosort3-maxwellcomplib:latest",
    verbose=True,
    **default_KS3_params,
)

{'detect_threshold': 24, 'projection_threshold': [30, 30], 'preclust_threshold': 26, 'car': True, 'minFR': 0.2, 'minfr_goodchannels': 0.2, 'nblocks': 5, 'sig': 20, 'freq_min': 300, 'sigmaMask': 30, 'nPCs': 3, 'ntbuff': 64, 'nfilt_factor': 4, 'do_correction': True, 'NT': None, 'wave_length': 61, 'keep_good_only': True, 'n_jobs': 48, 'chunk_duration': '1s', 'progress_bar': True}
Starting container
Installing spikeinterface==0.97.0 in kilosort3-maxwellcomplib:latest


KeyboardInterrupt: 

In [4]:
# Load the Kilosort3 output of the 4 min chunk.
# NOTE(review): `_get_result_from_folder` is a private SpikeInterface API.
sorting_KS3 = ss.Kilosort3Sorter._get_result_from_folder('/home/mmpatil/Documents/spikesorting/MEA_Analysis/Python/kilosort3_4min_17march_block/sorter_output/')
total_units = sorting_KS3.get_unit_ids()
print(total_units, len(total_units), sep="\n")

# Lookup from integer channel id to its column index in the recording.
channel_ids = recording_chunk.get_channel_ids()
print(channel_ids)
channel_association_dict = dict((int(chan_id), col) for col, chan_id in enumerate(channel_ids))
print(channel_association_dict)

[  0   3   4   5   6   8  12  13  15  16  18  22  26  27  30  31  33  34
  35  39  40  42  43  46  47  49  50  51  53  57  58  62  66  67  68  69
  73  74  75  77  78  87  89  92  94  95  98 101 102 103 104 106 108]
53
['0' '1' '2' '3' '4' '5' '6' '7' '8' '9' '10' '11' '12' '13' '14' '15'
 '16' '17' '18' '19' '20' '21' '22' '23' '24' '25' '26' '27' '28' '29'
 '30' '31' '32' '33' '34' '35' '36' '37' '38' '39' '40' '41' '42' '43'
 '44' '45' '46' '47' '48' '49' '50' '51' '52' '53' '54' '55' '56' '57'
 '58' '59' '60' '61' '62' '63' '64' '65' '66' '67' '68' '69' '70' '71'
 '72' '73' '74' '75' '76' '77' '78' '79' '80' '81' '82' '83' '84' '85'
 '86' '88' '89' '90' '91' '92' '93' '94' '95' '97' '98' '99' '100' '101'
 '102' '104' '105' '106' '107' '108' '109' '110' '111' '112' '113' '114'
 '116' '117' '118' '119' '120' '121' '122' '124' '125' '126' '129' '133'
 '137' '141' '145' '149' '153' '157' '161' '165' '169' '173' '176' '177'
 '180' '181' '185' '189' '193' '196' '197' '198' '201' '203' '2

In [5]:
# Parallel-processing options shared by the cells that compute on `waveforms`.
job_kwargs = {"n_jobs": 64, "chunk_duration": "1s", "progress_bar": True}
# Reuse the waveforms stored in ./waveforms4min when that folder already exists.
waveforms = si.extract_waveforms(
    recording_chunk,
    sorting_KS3,
    folder='./waveforms4min',
    load_if_exists=True,
)
print(waveforms)

inside get_reader
Inside NeoBaseExtractor
WaveformExtractor: 724 channels - 53 units - 1 segments
  before:20 after:40 n_per_units:500


  waveforms = si.extract_waveforms(recording_chunk,sorting_KS3,folder='./waveforms4min',load_if_exists=True)


In [None]:
# Three-component PCA projections of the waveforms (used by PCA-based metrics).
pc = sp.compute_principal_components(waveforms, n_components=3, **job_kwargs)

In [None]:
#projections = pc.get_projections(unit_id=1)

In [6]:
# Default quality-metric set for every unit (`qm` is imported in the first cell).
metrics = qm.compute_quality_metrics(waveforms, **job_kwargs)

[[ 0.         0.         0.        ...  0.         6.2942505  0.       ]
 [ 0.         0.         0.        ...  0.         0.         0.       ]
 [ 6.2942505 -6.2942505  0.        ...  0.         0.         0.       ]
 ...
 [ 0.         0.        -6.2942505 ... -6.2942505  0.         0.       ]
 [-6.2942505  0.         0.        ... -6.2942505  0.        -6.2942505]
 [ 0.         0.         0.        ...  0.         0.        -6.2942505]]




Computing PCA metrics:   0%|          | 0/53 [00:00<?, ?it/s]

: 

: 

In [None]:
# Show the quality-metrics result with rich (HTML) display.
display(metrics)

In [None]:
# Channel with the most negative template peak for each unit (unit id -> channel id).
extremum_channels_ids = si.get_template_extremum_channel(waveforms, peak_sign='neg')
print(extremum_channels_ids)

In [None]:
# Distinct extremum channels across all units. A plain set of strings iterates
# in a hash-randomised order that changes between kernel sessions, so sort the
# ids numerically (channel ids are numeric strings) for reproducible output.
unique_channels = sorted(set(extremum_channels_ids.values()), key=int)
print(unique_channels)

In [None]:


# Look up the unit id(s) whose extremum channel is '914'.
# NOTE(review): whether this returns the first matching key or all keys depends
# on helper.get_key_by_value — confirm.
print(helper.get_key_by_value(extremum_channels_ids,'914'))

In [None]:

# Individual quality metrics, each printed immediately after it is computed.
isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(waveforms, isi_threshold_ms=1.0)
print(isi_violations_ratio, isi_violations_count, sep="\n")

rp_contamination, rp_violation = qm.compute_refrac_period_violations(waveforms)
print(rp_contamination, rp_violation, sep="\n")

snr_ratio = qm.compute_snrs(waveforms, peak_sign="both", peak_mode='at_index')
print(snr_ratio)

firing_rate = qm.compute_firing_rates(waveforms)
print(firing_rate)

In [None]:
# Persist the unit -> extremum-channel mapping (`helper` is imported in the first cell).
filename = 'Extremechannels_4min.json'
helper.dumpdicttofile(extremum_channels_ids, filename)


In [None]:
# Units with a non-zero ISI-violation ratio.
violated_units = [uid for uid, isi_ratio in isi_violations_ratio.items() if isi_ratio > 0.0]
print(violated_units, f"isi violated units:{len(violated_units)}", sep="\n")

# Units with non-zero refractory-period contamination.
refrct_violated_units = [uid for uid, contamination in rp_contamination.items() if contamination > 0.0]
print(refrct_violated_units, f"refract vio units:{len(refrct_violated_units)}", sep="\n")

In [None]:
print(sorting_KS3)

# Drop every unit with refractory-period contamination; keep the survivors'
# ids in their original order.
clean_sorting = sorting_KS3.remove_units(refrct_violated_units)
print(clean_sorting)
good_units = list(filter(lambda unit_id: unit_id not in refrct_violated_units, total_units))
print(good_units)

# Waveform extractor restricted to the surviving units, saved to its own folder.
# NOTE(review): the folder name says "100elec" although this recording has more channels.
waveform_good = waveforms.select_units(good_units, new_folder='waveforms_good_100elec')
print(waveform_good)

In [None]:
# Switch to the interactive ipympl (widget) backend for the figures that follow.
%matplotlib widget

In [None]:
# Estimate a location for every unit and overlay the units on the probe map.
# (`sp`, `np`, `plt` and `sw` are imported in the first cell.)
locations = sp.compute_unit_locations(waveforms)
print(locations)

# A fresh figure: with the widget backend, plt.subplot(111) would draw into
# whatever figure is still current from an earlier cell.
fig, ax = plt.subplots()
sw.plot_probe_map(recording, ax=ax, with_channel_ids=True)
# Index the x / y columns explicitly so an extra (e.g. z) column returned by
# some localisation methods does not break the unpacking.
for unit_x, unit_y in zip(locations[:, 0], locations[:, 1]):
    ax.scatter(unit_x, unit_y)  # one call per unit -> one colour per unit
ax.set_xlabel('x (um)')
ax.set_ylabel('y (um)')
ax.set_title('Estimated unit locations')


In [None]:
# Map every channel id to its probe location (x, y).
channel_locations = recording_chunk.get_channel_locations()
channel_ids = recording_chunk.get_channel_ids()
channel_locations_mappings = dict(zip(channel_ids, channel_locations))
print(channel_locations_mappings)



In [None]:
# Assumption: the units' extremum channels cover every electrode of interest.
# Each such channel's location is converted to integer electrode-grid
# coordinates [x_index, y_index] by dividing by 17.5 and truncating.
required_electrodes = {
    chan_id: [int(loc[0] / 17.5), int(loc[1] / 17.5)]
    for chan_id, loc in channel_locations_mappings.items()
    if chan_id in unique_channels
}
print(required_electrodes)

In [None]:
import json

# Persist the channel -> [x_index, y_index] electrode coordinates.
with open('electrode_locations_4min.json', 'w') as fileptr:
    json.dump(obj=required_electrodes, fp=fileptr)



In [None]:
# Reload the saved channel -> electrode-coordinate mapping.
with open('electrode_locations_4min.json', 'r') as fileptr:
    data = json.load(fileptr)
print(data)

In [None]:
# Row-major linear electrode index on the 220-column grid: 220 * y_index + x_index.
selected_electrodes = [loc[0] + 220 * loc[1] for loc in data.values()]
print(selected_electrodes)

In [None]:
# Query electrode metadata for the chunk.
# Bug fix: the method was only referenced, never called, so `print` showed the
# bound-method repr instead of any data.
# NOTE(review): confirm `get_electrode_info` exists on this recording object;
# if it does not, `recording_chunk.get_property_keys()` lists what is available.
val = recording_chunk.get_electrode_info()
print(val)

In [None]:
def electrode_rectangle_indices(xmin, ymin, xmax, ymax, n_cols=220, n_rows=120):
    """Return linear electrode indices inside an inclusive grid rectangle.

    Electrodes are numbered row-major on an ``n_cols`` x ``n_rows`` grid
    (``index = n_cols * y + x``). The rectangle ``[xmin, xmax] x [ymin, ymax]``
    is given in electrode units, is inclusive on both ends, and is clipped to
    the grid, so out-of-range bounds are allowed.

    Args:
        xmin, ymin, xmax, ymax: Inclusive rectangle corners in electrode units.
        n_cols: Number of electrode columns (default 220).
        n_rows: Number of electrode rows (default 120).

    Returns:
        list[int]: Indices ordered row by row, then column by column; empty if
        the rectangle does not overlap the grid.
    """
    rows = range(max(ymin, 0), min(ymax + 1, n_rows))
    cols = range(max(xmin, 0), min(xmax + 1, n_cols))
    return [n_cols * y + x for y in rows for x in cols]


def electrode_rectangle_um(xmin, ymin, xmax, ymax, pitch_um=17.5):
    """Like ``electrode_rectangle_indices`` but with corners given in micrometres.

    Coordinates are converted to electrode units by dividing by ``pitch_um``
    and flooring. Flooring (rather than ``int()`` truncation toward zero) keeps
    negative coordinates outside the grid: previously a rectangle lying
    entirely at negative x/y still selected column/row 0. For non-negative
    inputs the result is unchanged.

    Args:
        xmin, ymin, xmax, ymax: Inclusive rectangle corners in micrometres.
        pitch_um: Electrode spacing in micrometres (default 17.5).

    Returns:
        list[int]: Linear electrode indices, as ``electrode_rectangle_indices``.
    """
    import math

    def to_grid(value_um):
        return math.floor(value_um / pitch_um)

    return electrode_rectangle_indices(to_grid(xmin), to_grid(ymin), to_grid(xmax), to_grid(ymax))

# Tile the 220 x 120 electrode grid into 5 x 6 blocks of 44 x 20 electrodes and
# print the electrode indices of each block, one row of tiles at a time.
XMULTIPLIER = 44
YMULTIPLIER = 20

for tile_row in range(6):
    y0 = YMULTIPLIER * tile_row
    for tile_col in range(5):
        x0 = XMULTIPLIER * tile_col
        print(electrode_rectangle_indices(x0, y0, x0 + XMULTIPLIER - 1, y0 + YMULTIPLIER - 1))

In [None]:
import datetime

# Current local time formatted as YYYYMMDD_HH_MM_SS.
now = datetime.datetime.now()
strfmt = f"{now:%Y%m%d_%H_%M_%S}"
print(strfmt)

In [None]:
# Start saving through the MaxLab Python API (presumably needs a running MaxLab
# server/hardware — confirm).
# NOTE(review): no output directory or file name is passed and no matching
# stop call appears in this notebook — check the maxlab.saving documentation
# that `start()` alone is the intended usage.
import maxlab.saving

obj = maxlab.saving.Saving()

obj.start()

In [None]:
# Spike raster of the units kept in `clean_sorting`, between 1 s and 100 s of
# the sorted recording. Row `idx` of the raster is the idx-th unit id.
RASTER_START_S = 1
RASTER_END_S = 100

fig, ax1 = plt.subplots(figsize=(15, 5))
spike_times = {}
for idx, unit_id in enumerate(clean_sorting.get_unit_ids()):
    # Frame bounds are frame indices; fs is a float, so cast explicitly.
    spike_train = clean_sorting.get_unit_spike_train(
        unit_id,
        start_frame=int(RASTER_START_S * fs),
        end_frame=int(RASTER_END_S * fs),
    )
    if len(spike_train) > 0:
        spike_times[idx] = spike_train / fs  # frames -> seconds
        ax1.plot(spike_times[idx], idx * np.ones_like(spike_times[idx]),
                 marker='|', mew=1, markersize=3, ls='', color='black')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Unit index')
ax1.set_title('Spike raster')
                       

In [None]:
# Binary spike matrix: spike_array[i, f] == 1 when the i-th unit of
# `clean_sorting` has a spike at frame f, for frames in [t_start, t_end).
t_start = 0
# NOTE(review): 600 s is longer than the 240 s chunk that was sorted, so the
# trailing columns are always zero; kept to preserve the saved array's shape.
t_end = int(600 * fs)
units = clean_sorting.get_num_units()
# uint8 instead of the default int64: entries are only 0/1, and at fs = 20 kHz
# an int64 (units x 12e6) matrix needs several GB of RAM.
spike_array = np.zeros((units, t_end), dtype=np.uint8)
for idx, unit_id in enumerate(clean_sorting.get_unit_ids()):
    spike_train = clean_sorting.get_unit_spike_train(unit_id, start_frame=t_start, end_frame=t_end)
    # Vectorised fancy-index assignment replaces the per-spike Python loop.
    spike_array[idx, spike_train] = 1

print(spike_array)

In [None]:
# Save the spike matrix compressed; it is stored in the archive under 'arr_0'.
np.savez_compressed('spike_array_compressed_blockactivity.npz', arr_0=spike_array)

In [None]:
# Reload the compressed spike matrix and check that it round-trips unchanged.
# The NpzFile gets its own name so it no longer overwrites `data` (the JSON
# electrode-location dict), which broke re-running the cells that use `data`.
with np.load('spike_array_compressed_blockactivity.npz') as npz_file:
    decompressed_data = npz_file['arr_0']

roundtrip_ok = np.array_equal(spike_array, decompressed_data)
print(roundtrip_ok)
assert roundtrip_ok, "spike_array changed after the save/load round trip"


In [None]:
# Recompute the per-unit extremum channels (most negative template peak).
extremum_channels_ids = si.get_template_extremum_channel(waveforms, peak_sign='neg')
print(extremum_channels_ids)



In [None]:
# Overlay the extracted waveforms of units 34 and 49 on channel 594, one colour per unit.
colors = ['Lime', 'Gold', 'Orange', 'Orangered']
fig, ax = plt.subplots()
wf = []
for unit_id, color in zip([34, 49], colors):
    wf = waveforms.get_waveforms(unit_id)
    ax.plot(wf[:, :, channel_association_dict[594]].T, color=color, lw=0.3)
print(wf.shape)

In [None]:
# Overlay one unit's waveforms on three channels, one colour per channel.
colors = ['Fuchsia', 'Olive', 'Teal']
unit_id = 218
channels_to_plot = [902, 613, 663]

# Unit ids belong to one specific sorting run; fail with a clear message
# rather than an opaque lookup error when the id is stale.
available_units = list(waveform_good.sorting.get_unit_ids())
assert unit_id in available_units, f"unit {unit_id} is not in waveform_good; available: {available_units}"

fig, ax = plt.subplots()
wf = waveform_good.get_waveforms(unit_id)
for channel, color in zip(channels_to_plot, colors):
    ax.plot(wf[:, :, channel_association_dict[channel]].T, color=color, lw=0.3)
ax.set_xlabel('Sample')
ax.set_ylabel('Amplitude')
ax.set_title(f'Unit {unit_id} waveforms on channels {channels_to_plot}')
print(wf.shape)

In [None]:
# Per-unit peak shift of the template on its extremum channel.
peak_shift = si.get_template_extremum_channel_peak_shift(waveform_good)
print(peak_shift)

In [None]:
# Compare the templates of units 26, 40 and 46 on channel 780.
colors = ['Fuchsia', 'Olive', 'Teal']
fig, ax = plt.subplots()
for unit_id, color in zip([26, 40, 46], colors):
    template = waveforms.get_template(unit_id)
    # The single-channel slice is 1-D, so no transpose is needed.
    ax.plot(template[:, channel_association_dict[780]], color=color, lw=3)
print(template.shape)

In [None]:
# Template plot for selected units. Unit ids belong to one specific sorting
# run, so keep only ids present in `waveform_good` and fail clearly otherwise.
requested_units = [183]
available_units = set(waveform_good.sorting.get_unit_ids())
units_to_plot = [u for u in requested_units if u in available_units]
assert units_to_plot, f"none of {requested_units} are in waveform_good (stale unit ids?)"
w = sw.plot_unit_templates(waveform_good, unit_ids=units_to_plot, plot_channels=False)

In [None]:
# Waveforms, templates and probe maps for selected units. Unit ids belong to
# one specific sorting run, so plot only ids present in `waveform_good` and
# report the ones that are skipped.
requested_units = [2, 4, 7]
available_units = set(waveform_good.sorting.get_unit_ids())
units_to_plot = [u for u in requested_units if u in available_units]
skipped_units = [u for u in requested_units if u not in available_units]
if skipped_units:
    print(f"skipping unit ids not in waveform_good: {skipped_units}")
assert units_to_plot, f"none of {requested_units} are in waveform_good (stale unit ids?)"

w = sw.plot_unit_waveforms(waveform_good, unit_ids=units_to_plot)
w = sw.plot_unit_templates(waveform_good, unit_ids=units_to_plot)
w = sw.plot_unit_probe_map(waveform_good, unit_ids=units_to_plot)