In [1]:
import datetime
import matplotlib.pyplot as plt
import os
import pandas as pd 
import spikeinterface.core as sc 
import spikeinterface.curation as scu
import spikeinterface.extractors as se  
import spikeinterface.postprocessing as spost
import spikeinterface.widgets as sw
import sys 

from sklearn.metrics.pairwise import cosine_similarity

sys.path.append('src')
from src.facts import *
from src.multiregion_80pin_channels import * 
from src.process import * 
from longterm_multiregion_curation import excluded_segments

# Subject IDs to curate.
subjects = [
    'M16_6', 'M16_7', 'M17_2', 'M17_5',
    'M10_1', 'M10_6',
    'M15_2', 'M15_3', 'M15_5', 'M15_7',
    'M16_1', 'M16_2',
    'M9_7', 'M9_8',
]
# Regions recorded per subject; each is read from its own sub-folder of the subject's data.
regions = ['region1', 'region2']

# Probe geometry attached to every loaded recording.
probe = create_multi_shank_probe()

# These two values select which sorting/waveform outputs are read (they are part of the
# folder names, e.g. "sorting4.5-15min").
sorted_duration = 15  # minutes of each segment that were spike-sorted
threshold = 4.5  # presumably the sorter's detection threshold — TODO confirm

In [2]:
def get_shank(channel_id, channel_indices):
    """Return the index of the first shank whose channel list contains ``channel_id``.

    Parameters
    ----------
    channel_id :
        Channel to look up; tested with ``in`` against each shank's channels.
    channel_indices : iterable of containers
        One container of channels per shank, in shank order.

    Returns
    -------
    int
        Position in ``channel_indices`` of the first shank containing ``channel_id``.

    Raises
    ------
    ValueError
        If no shank contains ``channel_id`` (a subclass of ``Exception``, so callers
        that caught the previous bare ``Exception`` still work).
    """
    for shank, shank_channels in enumerate(channel_indices):
        if channel_id in shank_channels:
            return shank
    raise ValueError(f'channel {channel_id!r} is not on any shank')

def sample_objects(objects, max_n=None, rng=None):
    """Randomly keep at most ``max_n`` items of ``objects`` (without replacement).

    Parameters
    ----------
    objects : numpy.ndarray
        Items to sample from. Indexed with an integer array, so it must support
        NumPy fancy indexing (a plain ``list`` does not).
    max_n : int or None
        Upper bound on the number of returned items; ``None`` disables sampling.
    rng : numpy.random.Generator or None
        Source of randomness. ``None`` uses the legacy global ``np.random`` state
        (unseeded unless the caller seeded it); pass a seeded generator for
        reproducible draws.

    Returns
    -------
    ``objects`` itself when ``max_n`` is None or there are fewer than ``max_n``
    items; otherwise ``max_n`` distinct items in random order.
    """
    if max_n is None or len(objects) < max_n:
        return objects
    sampler = np.random if rng is None else rng
    # choice(n, ...) draws indices directly; no need to materialise np.arange(n).
    picked = sampler.choice(len(objects), max_n, replace=False)
    return objects[picked]

def compute_isi_violation_rate(spike_train_ms, window_ms, bin_ms, isi_threshold_ms):
    """Histogram inter-spike intervals (ISIs) and compute the short-ISI fraction.

    Parameters
    ----------
    spike_train_ms : array-like
        Spike times in milliseconds; assumed sorted ascending (TODO confirm for
        callers other than a sorting's ``get_unit_spike_train``).
    window_ms : float
        Upper edge of the ISI histogram, in ms.
    bin_ms : float
        Histogram bin width, in ms.
    isi_threshold_ms : float
        ISIs strictly shorter than this count as violations.

    Returns
    -------
    xs : numpy.ndarray
        Left edges of the histogram bins; empty when there are no ISIs or every
        ISI is longer than ``window_ms``.
    ys : numpy.ndarray
        ISI counts per bin; empty in the same cases as ``xs``.
    rate : float
        Fraction of all ISIs shorter than ``isi_threshold_ms``; 0 when the train
        has fewer than two spikes.
    """
    isi = np.diff(np.asarray(spike_train_ms, dtype=float))
    if len(isi) == 0:
        return np.array([]), np.array([]), 0

    # The violation fraction does not depend on the plotting window, so it is taken
    # over every ISI (it must not be forced to 0 when all ISIs exceed window_ms).
    rate = (isi < isi_threshold_ms).sum() / len(isi)
    if isi.min() > window_ms:
        return np.array([]), np.array([]), rate

    # Edges must reach window_ms itself; np.arange(0, window_ms, bin_ms) would stop
    # one edge short and silently drop the last bin. Rounding guards against float
    # noise such as 1.1 / 0.1 == 11.000000000000002.
    n_bins = max(1, int(np.ceil(round(window_ms / bin_ms, 9))))
    bins = np.arange(n_bins + 1) * bin_ms
    ys, bin_edges = np.histogram(isi, bins=bins)
    return bin_edges[:-1], ys, rate
    
# Curation thresholds: a (unit, segment) pair is plotted only if it passes all of them.
min_firing_rate = 0.1  # spikes per second of the (sliced) segment
max_symmetry = 0.95  # max cosine similarity between the template's two halves (see below)
max_isi_violation_rate = 0.1  # max fraction of ISIs shorter than isi_threshold_ms
plt.rcParams.update({'font.size': 15})

# Figure grid: one row per segment, eight panel slots per row.
ncols = 8

for subject in subjects:
    subject_folder = f'data/processed/{subject}/all'
    session_info = pd.read_csv(f'{subject_folder}/session_info.csv')
    segment_paths = session_info['segment_path'].unique()
    n_segment = len(segment_paths)
    nrows = n_segment

    init_date = datetime.datetime.strptime(surgery_dates[subject], '%Y%m%d')
    # Weeks elapsed since surgery for each segment. The recording date is the second-to-last
    # '_'-separated field of the segment folder name (YYMMDD). It does not depend on the unit,
    # so it is parsed once here rather than once per unit.
    segment_lapses = []
    for segment_path in segment_paths:
        segment_date = datetime.datetime.strptime(segment_path.split('/')[-1].split('_')[-2], '%y%m%d')
        segment_lapses.append((segment_date - init_date).days / n_day_per_week)

    for region in regions:
        output_folder = f'data/processed/curated{threshold}-th{max_symmetry}/{subject}/{region}'
        os.makedirs(output_folder, exist_ok=True)

        region_folder = f'{subject_folder}/{region}'
        # Manual exclusions for this region: unit_id -> 'all' or a collection of segment indices.
        region_exclusions = excluded_segments[subject][region]

        # Load and preprocess every segment, then keep only the portion that was spike-sorted.
        recordings = [
            sc.load_extractor(f'{region_folder}/recording/segment{segment_index}').set_probe(probe)
            for segment_index in range(n_segment)
        ]
        recordings = [preprocess(segment_recording) for segment_recording in recordings]
        recordings = [
            segment_recording.frame_slice(
                start_frame=0,
                end_frame=min(
                    segment_recording.get_num_frames(),
                    int(sorted_duration * n_s_per_min * segment_recording.sampling_frequency),
                ),
            )
            for segment_recording in recordings
        ]
        recording = sc.concatenate_recordings(recordings).set_probe(probe, in_place=True)

        # The integer frames-per-ms is needed to slice templates; spike times are converted with
        # the exact ratio so a non-integer sampling rate does not stretch them.
        n_frames_per_ms = int(recording.sampling_frequency // n_ms_per_s)
        frames_per_ms = recording.sampling_frequency / n_ms_per_s

        sorting = se.NpzSortingExtractor(f'{region_folder}/sorting{threshold}-{sorted_duration}min/sorter_output/firings.npz')
        # Drop spikes beyond the recording's end; see
        # spikeinterface https://github.com/SpikeInterface/spikeinterface/pull/1378
        sorting = scu.remove_excess_spikes(sorting, recording)
        # Split along the concatenated segments, then take one single-segment sorting per segment.
        sorting = sc.split_sorting(sorting, recordings)
        sortings = [sc.select_segment_sorting(sorting, segment_indices=segment_index) for segment_index in range(n_segment)]

        waveform_extractors = [
            sc.load_waveforms(
                folder=f'{region_folder}/waveform{threshold}-{sorted_duration}min/segment{segment_index}',
                with_recording=False,
                sorting=sortings[segment_index],
            )
            for segment_index in range(n_segment)
        ]

        extremum_channels = []
        for segment_index, segment_recording in enumerate(recordings):
            segment_recording.set_probe(probe, in_place=True)
            waveform_extractors[segment_index].set_recording(segment_recording)
            spost.compute_unit_locations(waveform_extractors[segment_index], load_if_exists=True)
            # Extremum channel of every unit in this segment: computed once per segment instead of
            # once per (unit, segment) pair, since each call covers all units anyway.
            extremum_channels.append(
                sc.get_template_extremum_channel(waveform_extractors[segment_index], peak_sign='neg')
            )

        for unit_id in sorting.unit_ids:
            unit_plot_file = f'{output_folder}/{unit_id}.png'
            if os.path.isfile(unit_plot_file):
                continue  # already rendered on an earlier run
            unit_exclusion = region_exclusions.get(unit_id, [])
            if unit_exclusion == 'all':
                continue  # every segment excluded by hand: nothing would be plotted

            fig = plt.figure(figsize=(ncols * 3, nrows * 3))
            plot_realized = False
            for segment_index, lapse in enumerate(segment_lapses):
                if segment_index in unit_exclusion:
                    continue

                waveform_extractor = waveform_extractors[segment_index]
                segment_sorting = sortings[segment_index]

                # Restrict to the channels of the shank holding the unit's extremum channel.
                extremum_channel = extremum_channels[segment_index][unit_id]
                extremum_channel_indices = channel_indices[get_shank(extremum_channel, channel_indices)]

                # Firing rate from the full spike train: get_waveforms() may hold only a subsample
                # of the spikes, which would under-estimate the rate.
                spike_train = segment_sorting.get_unit_spike_train(unit_id=unit_id)
                segment_firing_rate = len(spike_train) / waveform_extractor.get_total_duration()
                if segment_firing_rate < min_firing_rate:
                    continue

                # Reject near-symmetric templates: compare the template up to the ms_before split
                # point with the time-reversed remainder.
                segment_templates = waveform_extractor.get_template(unit_id)
                extremum_template = segment_templates[:, extremum_channel]
                template_symmetry = cosine_similarity(
                    [extremum_template[:ms_before*n_frames_per_ms]],
                    [extremum_template[ms_before*n_frames_per_ms:][::-1]],
                ).item()
                if template_symmetry > max_symmetry:
                    continue

                spike_train_ms = spike_train / frames_per_ms
                xs, ys, isi_violation_rate = compute_isi_violation_rate(spike_train_ms, window_ms, bin_ms, isi_threshold_ms)
                if isi_violation_rate > max_isi_violation_rate:
                    continue

                color = plt.cm.turbo(lapse / total_week)
                # Shank channels laid end to end: the template is (samples, channels) -> flattened
                # channel by channel; waveforms are (spikes, samples, channels) -> same layout per spike.
                shank_template = segment_templates[:, extremum_channel_indices].T.flatten()
                segment_waveforms = waveform_extractor.get_waveforms(unit_id)[:, :, extremum_channel_indices]
                n_spikes, n_samples, n_channels = segment_waveforms.shape
                flat_waveforms = segment_waveforms.transpose(0, 2, 1).reshape(n_spikes, n_channels * n_samples)

                row = segment_index * ncols
                ax = fig.add_subplot(nrows, ncols, row + 1)
                ax.plot(extremum_template, color='black')
                ax.set_title(f'[{segment_index}]')

                ax = fig.add_subplot(nrows, ncols, row + 2)
                sw.plot_unit_templates(waveform_extractor, unit_ids=[unit_id], axes=[ax], unit_colors={unit_id: color})

                ax = fig.add_subplot(nrows, ncols, row + 3)
                ax.bar(x=xs, height=ys, width=bin_ms, color=color, align="edge")
                ax.set_title(f'ISI rate ({isi_threshold_ms}ms): {isi_violation_rate*100:0.1f}%')
                ax.set_xlabel('time (ms)')

                ax = fig.add_subplot(nrows, ncols, row + 4)
                sw.plot_autocorrelograms(segment_sorting, window_ms=window_ms, bin_ms=bin_ms, unit_ids=[unit_id], axes=[ax], unit_colors={unit_id: color})

                # The last two panels use a 4-column grid, so each spans two of the eight slots
                # (slots 5-6 and 7-8) and is drawn double width.
                ax = fig.add_subplot(nrows, 4, segment_index * 4 + 3)
                ax.plot(shank_template, label=unit_id, color=color)

                ax = fig.add_subplot(nrows, 4, segment_index * 4 + 4)
                ax.plot(flat_waveforms.T, label=unit_id, lw=0.5, color=color)
                ax.set_ylim(-np.abs(shank_template).max() - 10, np.abs(shank_template).max() + 10)
                ax.set_title(f'{lapse:0.2f} weeks')

                plot_realized = True

            if plot_realized:
                fig.tight_layout()
                fig.savefig(unit_plot_file)
            plt.close(fig)  # close this figure explicitly to avoid accumulating figures in the loop