In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tsmoothie.smoother import GaussianSmoother
import spikeinterface
import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.postprocessing as sp
import spikeinterface.preprocessing as spre
import spikeinterface.qualitymetrics as qm
import helper_functions as helper

In [3]:
# Load the Maxwell recording, report its basic properties, then band-pass
# filter, re-reference and cut out the chunk used by the rest of the notebook.
local_path = '/mnt/disk15tb/mmpatil/Spikesorting/Data/mandar_div18/Trace_20230410_16_04_40.raw.h5'

recording2 = se.read_maxwell(local_path)
#recording = si.ConcatenateSegmentRecording([recording1,recording2])
channel_ids = recording2.get_channel_ids()
fs = recording2.get_sampling_frequency()
num_chan = recording2.get_num_channels()
num_seg = recording2.get_num_segments()
total_recording = recording2.get_total_duration()
res = recording2.get_probegroup()
print(res)

#print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)
print(f"total_recording: {total_recording} s")

# Band-pass filter between 300 Hz and 6000 Hz.
recording_bp = spre.bandpass_filter(recording2, freq_min=300, freq_max=6000)

# Global median common reference applied to the filtered recording.
recodring_cmr = spre.common_reference(recording_bp, reference='global', operator='median')

# fs is returned as a float, so the frame bounds are cast to int explicitly:
# frame indices are sample counts and must be integers.
chunk_start_frame = int(10 * fs)
chunk_end_frame = int(310 * fs)
recording_chunk = recodring_cmr.frame_slice(start_frame=chunk_start_frame, end_frame=chunk_end_frame)
print(f"chunk duration: {recording_chunk.get_total_duration()} s")

<probeinterface.probegroup.ProbeGroup object at 0x7fd3374b45b0>
Sampling frequency: 20000.0
Number of channels: 724
Number of segments: 1
total_recording: 690.19 s
chunk duration: 300.0 s


In [None]:
# Start from the default Kilosort3 parameters and override selected entries.
default_KS3_params = ss.get_default_sorter_params('kilosort3')
default_KS3_params.update(
    keep_good_only=True,
    detect_threshold=24,
    projection_threshold=[30, 30],
    preclust_threshold=26,
)
print(default_KS3_params)

# Run Kilosort3 on the preprocessed chunk using the given Docker image.
run_sorter = ss.run_kilosort3(
    recording_chunk,
    output_folder="kilosort3_block2_5min",
    docker_image="kilosort3-maxwellcomplib:latest",
    verbose=True,
    **default_KS3_params,
)

In [None]:
# Load a Kilosort3 result from an existing sorter output folder.
sorting_KS3 = ss.Kilosort3Sorter._get_result_from_folder(
    '/home/mmpatil/Documents/spikesorting/MEA_Analysis/Python/kilosort3_block2_5min/sorter_output/'
)
total_units = sorting_KS3.get_unit_ids()
print(total_units)
print(len(total_units))

channel_ids = recording_chunk.get_channel_ids()
print(channel_ids)

# Map each channel id (converted to int) to its positional index in channel_ids.
channel_association_dict = {}
for position, chan_id in enumerate(channel_ids):
    channel_association_dict[int(chan_id)] = position
print(channel_association_dict)

In [None]:
# Parallel-processing options reused by the waveform, PCA and metrics cells.
job_kwargs = {"n_jobs": 64, "chunk_duration": "1s", "progress_bar": True}

# Extract waveforms (1 ms before, 2 ms after each spike) into a folder,
# overwriting any previous content of that folder.
waveforms = si.extract_waveforms(
    recording_chunk,
    sorting_KS3,
    folder="./waveformsblock2_4min",
    overwrite=True,
    ms_before=1.,
    ms_after=2.,
    **job_kwargs,
)
# Alternative: reuse waveforms that already exist on disk.
#waveforms = si.extract_waveforms(recording_chunk,sorting_KS3,folder='./waveformsblock1',load_if_exists=True)
print(waveforms)

In [None]:
# Principal components of the extracted waveforms (3 components requested).
pc = sp.compute_principal_components(
    waveforms,
    n_components=3,
    **job_kwargs,
)

In [None]:
# Quality metrics for the extracted waveforms.
# `qm` is already imported in the notebook's first cell, so it is not re-imported here.
metrics = qm.compute_quality_metrics(waveforms, **job_kwargs)

In [None]:
# Show the quality-metrics result with the notebook's rich display.
display(metrics)

In [None]:
# Extremum channel of each unit's template, using the negative peak.
# `si` is the alias of spikeinterface.full imported in the first cell.
extremum_channels_ids = si.get_template_extremum_channel(
    waveforms,
    peak_sign='neg',
)
print(extremum_channels_ids)



In [None]:


# Look up the key(s) of extremum_channels_ids associated with the value '625'.
# NOTE(review): helper.get_key_by_value is not visible here; presumably it returns
# the unit id(s) whose extremum channel is '625' (note the string type) — confirm.
print(helper.get_key_by_value(extremum_channels_ids,'625'))

In [None]:

# Individual quality metrics computed one at a time on the waveform extractor.

# ISI violations with a 1 ms threshold: per-unit ratio and count.
isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(waveforms, isi_threshold_ms=1.0)
print(isi_violations_ratio)
print(isi_violations_count)

# Refractory-period violations: per-unit contamination and violation values.
rp_contamination,rp_violation = qm.compute_refrac_period_violations(waveforms)
print(rp_contamination)
print(rp_violation)

# Signal-to-noise ratio, considering both peak signs.
snr_ratio = qm.compute_snrs(waveforms,peak_sign="both", peak_mode='at_index')
print(snr_ratio)

firing_rate = qm.compute_firing_rates(waveforms)
print(firing_rate)

In [None]:
# Write the extremum-channel mapping to disk through the project helper.
# `helper` is already imported in the notebook's first cell, so it is not re-imported here.
filename = 'Extremechannels_4min.json'
helper.dumpdicttofile(extremum_channels_ids, filename)


In [None]:
# Units whose ISI-violation ratio is strictly positive.
violated_units = []
for unit, ratio in isi_violations_ratio.items():
    if ratio > 0.0:
        violated_units.append(unit)
print(violated_units)
print(f"isi violated units:{len(violated_units)}")

# Units whose refractory-period contamination is strictly positive.
refrct_violated_units = []
for unit, ratio in rp_contamination.items():
    if ratio > 0.0:
        refrct_violated_units.append(unit)
print(refrct_violated_units)
print(f"refract vio units:{len(refrct_violated_units)}")

In [None]:
print(sorting_KS3)

# Sorting object without the units flagged by the refractory-period check.
clean_sorting = sorting_KS3.remove_units(refrct_violated_units)
print(clean_sorting)

# Ids of the remaining units, in the order they appear in total_units.
good_units = []
for candidate in total_units:
    if candidate in refrct_violated_units:
        continue
    good_units.append(candidate)
print(good_units)

# Waveform extractor restricted to the retained units, written to a new folder.
waveform_good = waveforms.select_units(
    good_units,
    new_folder='waveforms_good_100elec',
)
print(waveform_good)

In [None]:
# Switch matplotlib to the interactive "widget" backend for the figures that follow.
%matplotlib widget

In [None]:
# Estimate one location per unit and overlay the locations on the probe map.
# `sp` and `np` are already imported in the notebook's first cell, so they are
# not re-imported here.
locations = sp.compute_unit_locations(waveforms)
print(type(locations))
#np.savetxt("unitloc_10mins.txt",locations)
ax = plt.subplot(111)
sw.plot_probe_map(recording2,ax=ax,with_channel_ids=False)
# Index the first two coordinates instead of unpacking each row into (x, y):
# unpacking raises if the localisation method returns more than two columns
# (e.g. an additional depth coordinate).
for unit_location in locations:
    ax.scatter(unit_location[0], unit_location[1])

In [None]:
# Print the location of every channel of the analysed chunk, one per line.
channel_locations = recording_chunk.get_channel_locations()
channel_ids = recording_chunk.get_channel_ids()
# A plain loop replaces the list comprehension that was used only for its
# print side effect (it built and discarded a list of None values).
for chan_id, chan_location in zip(channel_ids, channel_locations):
    print(f"{chan_id}: {chan_location}")



In [None]:
# Raster plot of the retained units between 1 s and 100 s of the chunk.
fig, ax1 = plt.subplots(figsize=(15,5))
spike_times = {}

# fs is a float; frame bounds are sample indices, so cast them to int once.
raster_start_frame = int(1 * fs)
raster_end_frame = int(100 * fs)

for idx, unit_id in enumerate(clean_sorting.get_unit_ids()):
    spike_train = clean_sorting.get_unit_spike_train(
        unit_id,
        start_frame=raster_start_frame,
        end_frame=raster_end_frame,
    )
    # Units without spikes in the window are skipped (no row is drawn for them).
    if len(spike_train) > 0:
        # Convert frame indices to seconds.
        spike_times[idx] = spike_train / float(fs)
        ax1.plot(spike_times[idx], idx*np.ones_like(spike_times[idx]),
                 marker='|', mew=1, markersize=3,
                 ls='', color='black')

ax1.set_xlabel("Time (s)")
ax1.set_ylabel("Unit index")
                       

In [None]:
# Build a binary (units x frames) array with a 1 at every spike frame.
t_start = 0
# NOTE(review): this window is 600 s, but recording_chunk was sliced to 300 s
# (frames 10*fs .. 310*fs), so the second half of every row is presumably
# always zero — confirm whether t_end should follow the chunk length.
t_end = int(600*fs)
dt = 1
#initialising the spike train.
units = clean_sorting.get_num_units()
frame_numbers = t_end
spike_array = np.zeros((units, frame_numbers), dtype=int)
for idx, unit_id in enumerate(clean_sorting.get_unit_ids()):
    spike_train = clean_sorting.get_unit_spike_train(unit_id, start_frame=t_start, end_frame=t_end)
    # Vectorised assignment: mark all spike frames of this unit in one step
    # instead of looping over every spike in Python.
    spike_array[idx, spike_train] = 1

print(spike_array)

# Spot check of a single entry (unit row 0, frame 63782).
print(spike_array[0,63782])

In [None]:
# Save the spike array compressed; the explicit keyword spells out the 'arr_0'
# key that numpy assigns to the first positional array.
np.savez_compressed(
    'spike_array_compressed_blockactivity.npz',
    arr_0=spike_array,
)

In [None]:
# Reload the saved archive and check that it matches the in-memory array.
with np.load('spike_array_compressed_blockactivity.npz') as archive:
    decompressed_data = archive['arr_0']

round_trip_ok = np.array_equal(spike_array, decompressed_data)
print(round_trip_ok)


In [None]:
# Extremum channel of each unit's template, using the negative peak.
# `si` is the alias of spikeinterface.full imported in the first cell.
extremum_channels_ids = si.get_template_extremum_channel(
    waveforms,
    peak_sign='neg',
)
print(extremum_channels_ids)



In [None]:
# Overlay the waveforms of units 34 and 49 on channel 594, one colour per unit.
colors = ['Lime', 'Gold', 'Orange', 'Orangered']
fig, ax = plt.subplots()
for unit_id, color in zip([34, 49], colors):
    wf = waveforms.get_waveforms(unit_id)
    # Select the column of channel 594 through the id -> index mapping.
    channel_traces = wf[:, :, channel_association_dict[594]]
    ax.plot(channel_traces.T, color=color, lw=0.3)
# Shape of the waveform array of the last unit plotted.
print(wf.shape)

In [None]:
# Overlay the waveforms of unit 218 on three channels, one colour per channel.
colors = ['Fuchsia', 'Olive', 'Teal']
fig, ax = plt.subplots()
for unit_id in [218]:
    wf = waveform_good.get_waveforms(unit_id)
    for chan, color in zip([902, 613, 663], colors):
        ax.plot(wf[:, :, channel_association_dict[chan]].T, color=color, lw=0.3)
# Shape of the waveform array of the last unit plotted.
print(wf.shape)

In [None]:
# Peak shift on the extremum channel of each retained unit's template, as
# returned by SpikeInterface for the waveform_good extractor.
# NOTE(review): units and sign convention of the shift are not visible here — confirm.
peak_shift=si.get_template_extremum_channel_peak_shift(waveform_good)

print(peak_shift)

In [None]:
# Overlay the templates of units 26, 40 and 46 on channel 780, one colour per unit.
colors = ['Fuchsia', 'Olive', 'Teal']
fig, ax = plt.subplots()
for unit_id, color in zip([26, 40, 46], colors):
    template = waveforms.get_template(unit_id)
    # Select the column of channel 780 through the id -> index mapping.
    channel_trace = template[:, channel_association_dict[780]]
    ax.plot(channel_trace.T, color=color, lw=3)
# Shape of the template of the last unit plotted.
print(template.shape)

In [None]:
# Plot the template of unit 183 from the retained-unit waveforms, with the
# plot_channels option switched off.
w = sw.plot_unit_templates(waveform_good, unit_ids=[183],plot_channels=False )

In [None]:
# Draw the waveforms, templates and probe map of units 2, 4 and 7 in turn.
# `w` ends up holding the widget returned by the last call.
for widget_name in ("plot_unit_waveforms", "plot_unit_templates", "plot_unit_probe_map"):
    w = getattr(sw, widget_name)(waveform_good, unit_ids=[2, 4, 7])