In [4]:
def create_epochs(raw_filt, detections, tmin=-1, tmax=1, picks_raw=True, no_first_sample=False):
    """Create MNE epochs centred on detected events.

    Every detection becomes an event (id 1, name 'ICA_det') in a stim channel
    called 'ICA'. That channel is added to ``raw_filt`` if it does not exist
    yet and its content is replaced on every call. Epochs are then cut around
    those events.

    Parameters
    ----------
    raw_filt : mne.io.Raw
        Recording to epoch. Its data is loaded into memory and the object is
        modified in place (stim channel added / overwritten).
    detections : iterable of float
        Event positions in samples; each one is rounded to the nearest sample.
    tmin, tmax : float
        Epoch window in seconds relative to each event.
    picks_raw : bool | str
        Passed as ``meg=`` to ``mne.pick_types`` (e.g. True, 'grad', 'mag').
    no_first_sample : bool
        If False (default), ``raw_filt.first_samp`` is added to every
        detection. If True, the detections are used unchanged.

    Returns
    -------
    mne.Epochs
        Preloaded epochs, no baseline correction.

    Raises
    ------
    ValueError
        If ``detections`` is empty.
    """
    import mne
    import numpy as np

    eve_id = 1
    eve_name = 'ICA_det'
    ch_name = 'ICA'

    raw_filt.load_data()

    first_samp = 0 if no_first_sample else raw_filt.first_samp
    new_events = [[int(round(spike_time + first_samp)), 0, eve_id] for spike_time in detections]
    if not new_events:
        raise ValueError("create_epochs: 'detections' is empty, there is nothing to epoch")

    # Add an empty stim channel once; later calls reuse it (add_events(..., replace=True) below).
    if ch_name not in raw_filt.info['ch_names']:
        stim_data = np.zeros((1, len(raw_filt.times)))
        info_sp = mne.create_info([ch_name], raw_filt.info['sfreq'], ['stim'])
        stim_sp = mne.io.RawArray(stim_data, info_sp, verbose=False)
        raw_filt.add_channels([stim_sp], force_update_info=True)

    raw_filt.add_events(new_events, stim_channel=ch_name, replace=True)
    events = mne.find_events(raw_filt, stim_channel=ch_name, verbose=False)
    picks = mne.pick_types(raw_filt.info, meg=picks_raw, eeg=False, eog=False)
    return mne.Epochs(raw_filt, events, {eve_name: eve_id}, tmin, tmax, baseline=None,
                      picks=picks, preload=True, verbose=False)
    
def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value`` (the first one on ties)."""
    import numpy as np
    values = np.asarray(array)
    distances = np.abs(values - value)
    return values[np.argmin(distances)]

def plot_topomap(temp_evocked, evocked, svg_path, temp_n, time, amplitude, one_spike=False, diff=0):
    """Overlay a data evoked (blue) and a template evoked (red) with ``plot_evoked_topo`` and save as SVG.

    Both evokeds are copied before rescaling, so the caller's objects are not
    modified. The rescaling used to be done in place, and it leaked into
    later plots of the same object, e.g. ``plot_joint(ev)`` right after
    this call.

    Parameters
    ----------
    temp_evocked : mne.Evoked
        Template waveform. Its time axis is replaced by 101 points over
        -50..50 ms shifted by ``diff``, so it is assumed to have 101 samples
        (TODO confirm against the template length).
    evocked : mne.Evoked
        Data evoked.
    svg_path : str
        Output directory.
    temp_n : int
        Template number, used in the title and the file name.
    time : int | str
        Spike sample (``one_spike=True``) or a free-form label, used in the
        file name. For single spikes the title shows ``time / 1000`` seconds
        (assumes 1 kHz sampling — TODO confirm).
    amplitude : float
        Goodness-of-fit value shown in the title.
    one_spike : bool
        Title/naming mode: single spike instead of an average.
    diff : float
        Shift (s) applied to the template time axis.

    Side effects: sets ``temp.templates_topo_temp_n_path`` (and
    ``temp.one_spike_name`` when ``one_spike``) on the module-level ``temp``.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from mne.viz import plot_evoked_topo

    colors = 'blue', 'red'
    # Rescale copies only; the caller keeps using the originals afterwards.
    evocked = evocked.copy()
    temp_evocked = temp_evocked.copy()
    evocked.data = evocked.data * 10**13
    temp_evocked.data = temp_evocked.data * 10**4.5
    temp_evocked.times = np.linspace(-0.050, 0.050, 101) + diff
    fig = plot_evoked_topo([evocked, temp_evocked], color=colors, background_color='w', legend=False)
    fig.set_figheight(fig.get_figwidth() * 1.5)
    fig.set_figwidth(fig.get_figheight() * 1.5)
    if not one_spike:
        fig.suptitle('Template %s\nSpike time: average of %s event(s)\nGoodness of fit (average): %s'
                     % (temp_n, evocked.nave, amplitude), x=0.42)
    else:
        fig.suptitle('Template %s\nSpike time: %ss\nGoodness of fit: %s'
                     % (temp_n, time / 1000, amplitude), x=0.42)
        temp.one_spike_name = "spike_%s_templates_topo_%s" % (time, temp_n)

    fname = "%s/spike_%s_templates_topo_%s.svg" % (svg_path, time, temp_n)
    # `papertype` only affects PostScript output, so it is not passed for SVG.
    fig.savefig(fname, format='svg')

    temp.templates_topo_temp_n_path = fname

    fig.clf()
    plt.close(fig)

    
def plot_joint(evocked):
    """Save an ``mne.viz.plot_evoked_joint`` figure of ``evocked`` as SVG.

    Topomaps are drawn at -200, -100, -50, 0, 50, 100 and 200 ms. The file is
    written to ``{temp.svg_path}{temp.temp_n}_evoked_joint_grad.svg``, and the
    path is stored in ``temp.plot_evoked_joint_temp_n_grad_path`` on the
    module-level ``temp``. The "_grad" suffix is fixed, whatever the sensors.
    """
    import matplotlib.pyplot as plt
    import mne
    import numpy as np

    topo_times = np.array([-0.2, -0.1, -0.05, 0.0, 0.05, 0.1, 0.2])
    fig = mne.viz.plot_evoked_joint(evocked, times=topo_times, show=False)
    fname = "%s%s_evoked_joint_grad.svg" % (temp.svg_path, temp.temp_n)
    # Save the figure created above explicitly instead of pyplot's "current" figure.
    # `papertype` only affects PostScript output, so it is not passed for SVG.
    fig.savefig(fname, format='svg')

    temp.plot_evoked_joint_temp_n_grad_path = fname

    fig.clf()
    plt.close(fig)

def plot_epoch_image(epochs):
    """Save the first figure of ``epochs.plot_image()`` as SVG, then close all returned figures.

    The file is ``{temp.svg_path}{temp.temp_n}_epochs_image.svg``, and its path
    is stored in ``temp.plot_epochs_image_temp_n_path`` on the module-level ``temp``.
    """
    import matplotlib.pyplot as plt

    figs = epochs.plot_image()
    fname = "%s%s_epochs_image.svg" % (temp.svg_path, temp.temp_n)
    # Only the first figure is saved (presumably one figure per channel type; the callers pick a single sensor type).
    figs[0].savefig(fname, format='svg')
    temp.plot_epochs_image_temp_n_path = fname
    # This is called once per spike in the plotting loops; close the figures so they do not pile up in memory.
    for f in figs:
        plt.close(f)
    
def plot_final_temp_n(temp, one_spike=False, system_type='Mac'):
    """Compose the topo (A), joint (B) and epochs-image (C) SVGs into one figure and export it to PNG.

    The SVG paths stored on ``temp`` by the plotting helpers are read, the
    composition is written to ``{svg_path}{temp_n}_fig_final.svg``, and Inkscape
    converts it to ``{png_path}{temp_n}_temp.png`` (average). For a single
    spike the output is ``{png_path}{temp_n}_temp/{one_spike_name}.png``.

    Parameters
    ----------
    temp : object
        Must provide ``svg_path``, ``png_path``, ``temp_n``,
        ``templates_topo_temp_n_path``, ``plot_evoked_joint_temp_n_grad_path``,
        ``plot_epochs_image_temp_n_path`` (and ``one_spike_name`` when
        ``one_spike``).
    one_spike : bool
        Output naming for a single spike instead of the average.
    system_type : {'Mac', 'Windows'}
        Selects the Inkscape executable.

    Raises
    ------
    ValueError
        If ``system_type`` is not 'Mac' or 'Windows'. Previously nothing was
        exported in that case, silently.
    """
    import os
    import warnings
    from subprocess import Popen

    import svgutils.transform as sg

    # Inkscape executable and size option per platform. NOTE(review): the size option is passed as one
    # argv token (e.g. '-w3000 -h4500'), exactly as before. Inkscape probably reads it as "-w" with the
    # value "3000 -h4500", so the height may be ignored; split into separate arguments if it matters.
    # '--export-png' is Inkscape 0.92 syntax (1.x uses --export-filename) — confirm the installed version.
    inkscape = {
        'Mac': ('/Applications/Inkscape.app/Contents/Resources/bin/inkscape', '-w3000 -h4500'),
        'Windows': ('C:/Program Files/Inkscape/inkscape', '-w2000 -h3000'),
    }
    if system_type not in inkscape:
        raise ValueError("system_type must be 'Mac' or 'Windows', got %r" % (system_type,))

    fig = sg.SVGFigure("22cm", "24cm")
    # A: topo plot at the top; B: evoked joint bottom-left; C: epochs image bottom-right.
    plot1 = sg.fromfile(temp.templates_topo_temp_n_path).getroot()
    plot1.moveto(0, 0)
    plot2 = sg.fromfile(temp.plot_evoked_joint_temp_n_grad_path).getroot()
    plot2.moveto(30, 650)
    plot2.scale_xy(0.7, 0.7)
    plot3 = sg.fromfile(temp.plot_epochs_image_temp_n_path).getroot()
    plot3.moveto(450, 650)
    plot3.scale_xy(0.7, 0.7)

    labels = [
        sg.TextElement(50, 50, "A", size=12, weight="bold"),
        sg.TextElement(50, 650, "B", size=12, weight="bold"),
        sg.TextElement(450, 650, "C", size=12, weight="bold"),
    ]
    fig.append([plot1, plot2, plot3])
    fig.append(labels)

    your_svg_input = "%s%s_fig_final.svg" % (temp.svg_path, temp.temp_n)
    fig.save(your_svg_input)

    if not one_spike:
        your_png_output = "%s%s_temp.png" % (temp.png_path, temp.temp_n)
    else:
        os.makedirs("{}{}_temp".format(temp.png_path, temp.temp_n), exist_ok=True)
        your_png_output = "{}{}_temp/{}.png".format(temp.png_path, temp.temp_n, temp.one_spike_name)

    exe, size_opt = inkscape[system_type]
    proc = Popen([exe, your_svg_input, '--export-png=%s' % your_png_output, size_opt, '-b white'])
    # Wait for the export: the same *_fig_final.svg is rewritten by the next call
    # (e.g. the next spike of this template) and must not change while Inkscape reads it.
    returncode = proc.wait()
    if returncode != 0:
        warnings.warn("Inkscape exited with code %s while exporting %s" % (returncode, your_png_output))

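## Alignment spikes

Load the ASPIRE ICA-based detections of both blocks and save them to `results.xlsx`. The peak re-alignment step is currently disabled (commented out), so `Aligned_spikes` are simply the detections shifted by `first_samp`.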

In [180]:
%%capture
# Collect the ASPIRE (ICA-based) spike detections of both blocks and save them to results.xlsx.
import scipy.io
import mne
import numpy as np
import pandas as pd
block1 = scipy.io.loadmat('/Users/valery/MEG/Cases/B1C2/ASPIRE/results/\
sources_B1C2_ii_run1_raw_tsss_mc_art_corr_data_\
block001_ICA_based_grad.mat')

block2 = scipy.io.loadmat('/Users/valery/MEG/Cases/B1C2/ASPIRE/results/\
sources_B1C2_ii_run1_raw_tsss_mc_art_corr_data_\
block002_ICA_based_grad.mat')
# Shift block-2 indices by 600 000 samples (presumably the ASPIRE block length — TODO confirm).
# NOTE(review): this is an in-place add on the loaded array; if the .mat stores a small integer type it could overflow.
block2['spikeind'][0,:] += 600_000
spikes = block1['spikeind'][0,:].tolist() + block2['spikeind'][0,:].tolist()

tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
# The recording is opened here only for its first_samp.
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)
# Disabled: re-align each detection to the peak of its 9-100 Hz filtered grad average.
#epochs = create_epochs(data, spikes, tmin=-0.2, tmax=0.2, picks_raw=True)

#aligned_spikes = [epochs[i].events.tolist()[0][0]+\
#                  epochs[i].load_data().pick_types(meg='grad').filter(9.,100.).average().get_peak()[1]\
#                  for i in range(len(epochs))]
# With alignment disabled, "aligned" spikes are just the detections shifted by first_samp.
aligned_spikes = [i+data.first_samp for i in spikes]
results = pd.DataFrame(data=zip(spikes,aligned_spikes),columns=["Spikes","Aligned_spikes"])
results.to_excel("/Users/valery/MEG/Cases/B1C2/results.xlsx",index=False)

## Run SPC (Spyking Circus)

In [None]:
import run_circus_ASPIRE as run_circus
import scipy.io
import circus_templates_ASPIRE as circus_templates
import mne
import numpy as np
import pandas as pd
import traceback
import shutil

# Run Spyking Circus (SPC) for every parameter combination, match its fitted spikes to the
# ASPIRE detections in results.xlsx, then plot the best-fitting spikes of every template.
# Each stage is best-effort: a failure is printed and the remaining stages of that
# combination are skipped, so they never run on a stale or undefined `temp`/`results`.

dir_case = "/Users/valery/MEG/Cases/B1C2/"
dir_SPC = dir_case + "Spyking_circus/B1C2_B1C2_ii_run1_raw_tsss_mc_art_corr/"
SPC_fname = "B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
nmpy_file = 'B1C2_ii_run1_raw_tsss_mc_art_corr_0.npy'
epochs_fname = dir_case +  "art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr_epochs.fif"
tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
n_t = 100
# Offset between results.Aligned_spikes and SPC spike times (presumably data.first_samp — TODO confirm).
first_samp_offset = 39000
# Maximum distance (samples) between an ASPIRE detection and a matching SPC fit.
match_window = 100
# Keep detections whose |Amplitude - 1| lies below this percentile.
goodness_pct = 10
# Band-pass (Hz) applied to the epochs before plotting.
filter_value = (1.0, 45.0)

for cc_merge in [0.3]:
    for mad in [7.0]:
        ## 1) Run SPC and load its templates
        try:
            #shutil.copy(dir_SPC +'full_case/' + nmpy_file, dir_SPC)
            sc = run_circus.Circus(dir_case, "B1C2", dir_SPC, SPC_fname, N_t=n_t, cc_merge=cc_merge, MAD=mad, cut_off=3)
            sc.params_iterations(run_spc=True, only_fitting=False, sensors=['grad'])

            #shutil.copy(dir_SPC +'only_epochs/' + nmpy_file, dir_SPC)
            #sc = run_circus.Circus(dir_case, "B1C2", dir_SPC, SPC_fname, N_t=n_t, cc_merge=0.97, MAD=mad, cut_off=9)
            #sc.params_iterations(run_spc=True, only_fitting = True, sensors=['grad'])

            for sens in ['grad']:
                # NOTE(review): Templates gets cc_merge=0.3, not the loop value; confirm this is intended
                # before extending the cc_merge list.
                temp = circus_templates.Templates(dir_case, "B1C2", SPC_fname, sc.sensors_params[sens], sensors=sens, n_sp=3, N_t=n_t, cc_merge=0.3, MAD=mad)
                #temp.plot_all_templates(epochs_fname, 'Mac')
        except Exception:
            traceback.print_exc()
            continue

        ## 2) Select fitted spikes from SPC that were fitted on the same peak
        try:
            results = pd.read_excel("/Users/valery/MEG/Cases/B1C2/results.xlsx")

            for i in range(len(results)):
                spike = results.Aligned_spikes[i] - first_samp_offset
                around_spike = temp.templates[abs(temp.templates.Spiketimes - spike) < match_window]
                if around_spike.empty:
                    continue
                # Best fit = amplitude closest to 1 (first one on ties). Use that row directly:
                # looking it up again by amplitude value / spike time in the whole table could hit another template.
                best = around_spike.iloc[int(np.argmin(np.abs(around_spike.Amplitudes.to_numpy() - 1)))]
                results.loc[i, 'Difference'] = best.Spiketimes - spike
                results.loc[i, 'Template'] = best.Template
                results.loc[i, 'Template_n'] = int(best.Template.split('_')[1])
                results.loc[i, 'Amplitude'] = best.Amplitudes
        except Exception:
            traceback.print_exc()
            continue

        ## 3) Plot average & individual spikes
        try:
            data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)

            goodness_percentile = np.percentile(abs(results.dropna().Amplitude.to_numpy() - 1), goodness_pct)
            best_results = results[abs(results.Amplitude - 1) < goodness_percentile]

            for temp_n in best_results.Template_n.unique().tolist():
                results_temp_n = best_results[best_results.Template_n == temp_n]
                # Spike positions already include first_samp, hence no_first_sample=True.
                aligned_spikes = (results_temp_n.Aligned_spikes + results_temp_n.Difference).tolist()

                epochs = create_epochs(data, aligned_spikes, tmin=-0.5, tmax=0.5, picks_raw=True, no_first_sample=True)

                temp.data_info = epochs.info
                temp.data_info_sens = epochs[0].load_data().pick_types(meg=temp.sensors, eeg=False, stim=False, eog=False).info
                temp.temp_n = int(temp_n)

                epochs.load_data().pick_types(meg=temp.sensors, eeg=False, stim=False, eog=False).filter(filter_value[0], filter_value[1], fir_design='firwin')
                temp_evocked = temp.templates_topo_temp_n(return_temp=True)

                ## Plot each spike
                for i in range(len(epochs)):
                    ep = epochs[i]
                    time = ep.events[0][0]
                    ev = ep.average()
                    amplitude = results_temp_n.loc[(results_temp_n.Aligned_spikes + results_temp_n.Difference).round() == time, 'Amplitude'].values[0]
                    plot_epoch_image(epochs[i])
                    temp_evocked = temp.templates_topo_temp_n(return_temp=True)
                    plot_topomap(temp_evocked, ev, temp.svg_path, temp.temp_n, int(time), amplitude, one_spike=True)
                    plot_joint(ev)
                    plot_final_temp_n(temp, one_spike=True)

                ## Plot average
                temp_evocked = temp.templates_topo_temp_n(return_temp=True)
                file_name = 'average_{}_spikes_filt_{}'.format(len(results_temp_n), filter_value)
                plot_epoch_image(epochs)
                evocked = epochs.average()

                amplitude = results_temp_n.Amplitude.mean()
                plot_topomap(temp_evocked, evocked, temp.svg_path, temp.temp_n, file_name, amplitude)
                plot_joint(evocked)

                plot_final_temp_n(temp)

            results.to_excel(temp.waveforms_path + '/results.xlsx')
        except Exception:
            traceback.print_exc()
            
            

## Select fitted spikes from SPC that was fitted on the same peak

In [None]:

import numpy as np
import pandas as pd

# For every ASPIRE detection, keep the SPC fit within ±100 samples whose amplitude is closest to 1.
results = pd.read_excel("/Users/valery/MEG/Cases/B1C2/results.xlsx")
# Offset between Aligned_spikes and SPC spike times (presumably data.first_samp — TODO confirm).
first_samp_offset = 39000

for i in range(len(results)):
    spike = results.Aligned_spikes[i] - first_samp_offset
    around_spike = temp.templates[abs(temp.templates.Spiketimes - spike) < 100]
    if around_spike.empty:
        continue
    # Use the best row directly (first one on ties). Looking it up again by amplitude value or
    # spike time in the whole table could return a row of another template.
    best = around_spike.iloc[int(np.argmin(np.abs(around_spike.Amplitudes.to_numpy() - 1)))]
    results.loc[i, 'Difference'] = best.Spiketimes - spike
    results.loc[i, 'Template'] = best.Template
    results.loc[i, 'Template_n'] = int(best.Template.split('_')[1])
    results.loc[i, 'Amplitude'] = best.Amplitudes
        

## Plot average & individual spikes

In [38]:
%%capture
# Plot, for every template, each of its best-fitting spikes and their average.
# Uses `results` from the matching cell above and the Spyking Circus `temp` object.
import mne
import numpy as np

# Goodness of fit is |Amplitude - 1|; keep detections below its 10th percentile.
godness_percentile = np.percentile(abs(results.dropna().Amplitude.to_numpy()-1),10) #10%
tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)
best_results = results[abs(results.Amplitude-1)<godness_percentile]
# Band-pass (Hz) applied to the epochs before plotting.
filter_value = (1.0,45.0)

for temp_n in best_results.Template_n.unique().tolist():
    results_temp_n = best_results[best_results.Template_n == temp_n]
    # Spike position = ASPIRE detection (already shifted by first_samp) + offset of the matched SPC fit.
    aligned_spikes = (results_temp_n.Aligned_spikes + results_temp_n.Difference).tolist()
    
    # Positions already include first_samp, hence no_first_sample=True.
    # create_epochs loads `data` and adds/overwrites its 'ICA' stim channel.
    epochs = create_epochs(data, aligned_spikes, tmin=-0.5, tmax=0.5, picks_raw=True,no_first_sample=True)
    
    # Channel info stored on `temp` before the template evoked is built
    # (presumably read by templates_topo_temp_n — TODO confirm).
    temp.data_info = epochs.info
    temp.data_info_sens = epochs[0].load_data().pick_types(meg=temp.sensors, eeg=False,stim=False, eog=False).info
    temp.temp_n = int(temp_n)
    
    # Keep only the template's sensor type, then band-pass filter (operates on `epochs` itself).
    epochs.load_data().pick_types(meg=temp.sensors, eeg=False,stim=False, eog=False).filter(filter_value[0], filter_value[1], fir_design='firwin')
    temp_evocked = temp.templates_topo_temp_n(return_temp=True)
    
    ## Plot each spike
    for i in range(len(epochs)):
        ep = epochs[i]
        time = ep.events[0][0]
        ev = ep.average()
        # Find this epoch's row through its spike sample (create_epochs also rounds positions to whole samples).
        amplitude = results_temp_n.loc[(results_temp_n.Aligned_spikes+results_temp_n.Difference).round()==time,'Amplitude'].values[0]
        plot_epoch_image(epochs[i])
        temp_evocked = temp.templates_topo_temp_n(return_temp=True)
        plot_topomap(temp_evocked, ev, temp.svg_path, temp.temp_n, int(time), amplitude, one_spike=True)
        plot_joint(ev)
        plot_final_temp_n(temp, one_spike=True)
        
    ## Plot average
    temp_evocked = temp.templates_topo_temp_n(return_temp=True)
    file_name = 'average_{}_spikes_filt_{}'.format(len(results_temp_n),filter_value)
    plot_epoch_image(epochs)
    evocked = epochs.average()
    
    # The average plot's title shows the mean goodness of fit of this template's spikes.
    amplitude = results_temp_n.Amplitude.mean()   
    plot_topomap(temp_evocked, evocked, temp.svg_path, temp.temp_n, file_name, amplitude)
    plot_joint(evocked)
    
    plot_final_temp_n(temp)

## Load saved results and browse detections

In [15]:
%%capture
# Reload the inputs for reviewing detections: the matching results saved by the SPC run and the raw recording.
import scipy.io
import mne
import numpy as np
import pandas as pd
#T_spikes = scipy.io.loadmat('/Users/valery/MEG/Cases/B1C2/tommaso_visual_marking_B1C2_events_block001.mat')
results = pd.read_excel('/Users/valery/MEG/Cases/B1C2/Spyking_circus/B1C2_B1C2_ii_run1_raw_tsss_mc_art_corr/cut_off_7_spike_thresh_7.0_N_t_100_grad_(B1C2_ii_run1_raw_tsss_mc_art_corr_0)_cc_merge_0.3_sensitivity_3/Templates_waveforms/results.xlsx')
tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)

In [16]:
# Visually marked spike times in seconds (used directly as annotation onsets further down).
# NOTE(review): an identical copy of this list is defined again later in the notebook; keep both in sync.
T_spikes = [ 39.264,  40.731,  41.631,  72.827,  73.303,  74.078,  74.564,
         75.142,  75.99 ,  76.476,  76.827,  77.396,  87.175,  88.157,
         88.704,  89.139,  90.131, 101.719, 102.949, 103.881, 104.439,
        104.811, 105.266, 105.295, 106.754, 111.861, 111.902, 128.441,
        129.051, 129.505, 130.053, 130.446, 130.963, 137.589, 137.972,
        149.239, 156.04 , 157.973, 158.791, 169.406, 173.241, 173.614,
        217.463, 218.424, 220.336, 224.936, 232.244, 232.73 , 234.498,
        251.067, 251.646, 252.132, 252.535, 253.135, 253.527, 265.095,
        265.456, 266.077, 277.385, 278.243, 279.256, 290.038, 291.04 ,
        291.536, 295.806, 300.654, 301.512, 301.977, 324.822, 326.114,
        327.292, 333.566, 334.838, 335.676, 336.223, 341.133, 341.536,
        341.877, 393.707, 403.517, 422.66 , 460.111, 460.928, 463.481,
        463.791, 464.101, 470.314, 470.831, 480.185, 480.557, 483.959,
        486.315, 495.111, 495.536, 495.556, 503.143, 503.671, 504.146,
        504.663, 505.149, 505.614, 506.038, 542.096, 542.613, 543.894,
        616.532, 616.542, 619.282, 621.669]

In [21]:
# Keep the matched detections whose goodness of fit |Amplitude - 1| is within the given percentile.
# 100 keeps every matched detection. That needs `<=`: with `<` the worst detection was always dropped.
goodness_pct = 100
goodness_percentile = np.percentile(abs(results.dropna().Amplitude.to_numpy() - 1), goodness_pct)
best_results = results[abs(results.Amplitude - 1) <= goodness_percentile]
ICA_spikes = results.Aligned_spikes.tolist()
# SPC spike positions = ASPIRE detection + offset of the matched SPC fit.
V_spikes = (best_results.Aligned_spikes + best_results.Difference).tolist()
#scipy.io.savemat('/Users/valery/MEG/Cases/B1C2/SPC_B1C2_full.mat',{'SPC':V_spikes})

In [22]:
#print(V_spikes)

In [23]:
%%capture
%matplotlib qt5
%matplotlib qt5
# Browse the raw data in an interactive Qt window with three annotation tracks:
# 'T' = visual marks, 'I' = ASPIRE ICA detections, template number = matched SPC fits.
# NOTE(review): the backend magic is repeated, presumably as a workaround for the first switch
# not taking effect — confirm it is still needed.


plot_sensors = 'grad' # 'mag', 'grad' or True (all sensors)
# NOTE(review): plot_sensors is not used in this cell.
window = 10 #in seconds
start = 988.0 #start point in seconds

tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)

# Visual marks are already in seconds.
onset = [i for i in T_spikes]
data.annotations.append(onset, np.repeat(0.00001, len(onset)), ['T']*len(onset))

# Sample indices -> seconds; assumes 1 kHz sampling (TODO confirm sfreq).
onset = [i/1000 for i in ICA_spikes]
data.annotations.append(onset, np.repeat(0.00001, len(onset)), ['I']*len(onset))

onset = [i/1000 for i in V_spikes]
# Each SPC fit is labelled with its template number.
labels = [str(i) for i in best_results.Template_n]
data.annotations.append(onset, np.repeat(0.00001, len(onset)), labels)

# Display-only band-pass 1-70 Hz.
data.plot(duration=window, group_by='position', start=start, lowpass=70, highpass=1)



## Make ROC

In [1]:
import scipy.io
import mne
import numpy as np
import pandas as pd

# Load the block-1 ASPIRE detections (ICA-based and Spyking-Circus-based) with their ValMax,
# then attach the Spyking Circus fit amplitude to every SPC detection.
# Offset added to every sample index (presumably the recording's first_samp — TODO confirm).
first_samp_offset = 39000

path_res_ASPIRE = '/Users/valery/MEG/Cases/B1C2/ASPIRE/results/sources__befor_thershold_B1C2_ii_run1_raw_tsss_mc_art_corr_data_block00'
detections = {}
block1 = scipy.io.loadmat(path_res_ASPIRE + '1_ICA_based_grad.mat')
#block2 = scipy.io.loadmat(path_res_ASPIRE + '2_ICA_based_grad.mat')
#block2['spike_ind'][0,:] += 600_000
#detections['ICA'] = pd.DataFrame(data=np.array([block1['spike_ind'][0,:].tolist()+block2['spike_ind'][0,:].tolist(),
#                            block1['ValMax'][0,:].tolist()+block2['ValMax'][0,:].tolist()]).T, columns= ['spike_ind','ValMax'])
detections['ICA'] = pd.DataFrame(data=np.array([block1['spike_ind'][0,:].tolist(), block1['ValMax'][0,:].tolist()]).T,
                                 columns=['spike_ind', 'ValMax'])
detections['ICA']['spike_ind'] += first_samp_offset

block1 = scipy.io.loadmat(path_res_ASPIRE + '1_SpyCir_based_grad.mat')
#block2 = scipy.io.loadmat(path_res_ASPIRE + '2_SpyCir_based_grad.mat')
#block2['spike_ind'][:,0] += 600_000
#detections['SPC'] = pd.DataFrame(data=np.array([block1['spike_ind'][:,0].tolist()+block2['spike_ind'][:,0].tolist(),
#                            block1['ValMax'][0,:].tolist()+block2['ValMax'][0,:].tolist()]).T, columns= ['spike_ind','ValMax'])
detections['SPC'] = pd.DataFrame(data=np.array([block1['spike_ind'][:,0].tolist(), block1['ValMax'][0,:].tolist()]).T,
                                 columns=['spike_ind', 'ValMax'])
detections['SPC']['spike_ind'] += first_samp_offset

res_csv_spc = pd.read_csv('/Users/valery/MEG/Cases/B1C2/ASPIRE/detections/Templates_B1C2_ii_run1_raw_tsss_mc_grad.csv').sort_values('Spiketimes')
res_csv_spc['Spiketimes'] += first_samp_offset

# Amplitude per exact spike time; on duplicated spike times the first row after sorting wins (as before).
amp_by_time = res_csv_spc.drop_duplicates('Spiketimes').set_index('Spiketimes')['Amplitudes']
# spike_ind comes out of a float array above, so compare on a float index.
amp_by_time.index = amp_by_time.index.astype(float)
unmatched = ~detections['SPC']['spike_ind'].isin(amp_by_time.index)
if unmatched.any():
    raise KeyError('{} SPC detections have no matching Spiketimes in the CSV, e.g. {}'.format(
        int(unmatched.sum()), detections['SPC'].loc[unmatched, 'spike_ind'].tolist()[:5]))
detections['SPC']['Amplitudes'] = detections['SPC']['spike_ind'].map(amp_by_time)

In [2]:
# Visually marked spike times in seconds.
# NOTE(review): identical to the list defined earlier in the notebook; keep both copies in sync.
T_spikes = [ 39.264,  40.731,  41.631,  72.827,  73.303,  74.078,  74.564,
         75.142,  75.99 ,  76.476,  76.827,  77.396,  87.175,  88.157,
         88.704,  89.139,  90.131, 101.719, 102.949, 103.881, 104.439,
        104.811, 105.266, 105.295, 106.754, 111.861, 111.902, 128.441,
        129.051, 129.505, 130.053, 130.446, 130.963, 137.589, 137.972,
        149.239, 156.04 , 157.973, 158.791, 169.406, 173.241, 173.614,
        217.463, 218.424, 220.336, 224.936, 232.244, 232.73 , 234.498,
        251.067, 251.646, 252.132, 252.535, 253.135, 253.527, 265.095,
        265.456, 266.077, 277.385, 278.243, 279.256, 290.038, 291.04 ,
        291.536, 295.806, 300.654, 301.512, 301.977, 324.822, 326.114,
        327.292, 333.566, 334.838, 335.676, 336.223, 341.133, 341.536,
        341.877, 393.707, 403.517, 422.66 , 460.111, 460.928, 463.481,
        463.791, 464.101, 470.314, 470.831, 480.185, 480.557, 483.959,
        486.315, 495.111, 495.536, 495.556, 503.143, 503.671, 504.146,
        504.663, 505.149, 505.614, 506.038, 542.096, 542.613, 543.894,
        616.532, 616.542, 619.282, 621.669]

# Visual marks in ms (= samples, assuming 1 kHz sampling — TODO confirm) so they compare with spike_ind.
visual_track = pd.DataFrame(data=[i*1000 for i in T_spikes],columns=['Visual'])
#ROC_results
#ROC_results

## Sensitivity & precision

In [3]:
def find_nearest(array, value):
    """Nearest element of ``array`` to ``value``; ties resolve to the earliest element."""
    import numpy as np
    flat = np.asarray(array)
    return flat[abs(flat - value).argmin()]


def sensitivity_precision(visual_track, det, detections, win=100):
    """Score one detection set against the visual (manual) marks.

    A visual mark is a true positive when some detection lies within ``win``
    of it (inclusive); otherwise it is a false negative. A detection with no
    visual mark within ``win`` is a false positive.

    Side effect: writes a 0/1 float column named ``det`` into ``visual_track``
    (1 = mark detected), as before; later cells may inspect those columns.

    Parameters
    ----------
    visual_track : pandas.DataFrame
        Must have a 'Visual' column of mark positions, in the same units as
        ``spike_ind``.
    det : str
        Key into ``detections``; also the name of the column written.
    detections : dict of str -> pandas.DataFrame
        Each frame has a 'spike_ind' column.
    win : float
        Matching tolerance (default 100, i.e. the previous hard-coded value).

    Returns
    -------
    (sensitivity, precision) : tuple
        TP/(TP+FN) and TP/(TP+FP); ``(0, 0)`` when there is no true positive,
        including when either side is empty (previously a crash).
    """
    import numpy as np

    marks = np.asarray(visual_track['Visual'], dtype=float)
    spikes = np.asarray(detections[det]['spike_ind'], dtype=float)

    if marks.size and spikes.size:
        # Pairwise |mark - detection|: rows are marks, columns are detections.
        dist = np.abs(marks[:, None] - spikes[None, :])
        hit = dist.min(axis=1) <= win
        FP = int(np.count_nonzero(dist.min(axis=0) > win))
    else:
        # No detections: no mark is hit. No marks: every detection is unconfirmed.
        hit = np.zeros(marks.size, dtype=bool)
        FP = int(spikes.size)
    visual_track[det] = hit.astype(float)

    TP = int(hit.sum())
    FN = int(hit.size - TP)
    if TP != 0:
        return TP / (TP + FN), TP / (TP + FP)
    return 0, 0

def apply_threshold(detections, det, thr_name, percentiles=range(0, 100, 5)):
    """Store percentile-thresholded copies of ``detections[det]`` back into ``detections``.

    For every percentile ``p`` a new entry ``'{det}_{thr_name}_{p}'`` holds the
    surviving rows, with the index reset:

    * ``thr_name != 'Amplitudes'``: rows whose ``thr_name`` is >= its p-th
      percentile (the largest values are kept);
    * ``thr_name == 'Amplitudes'``: rows whose score ``-|1 - Amplitudes|`` is
      >= its p-th percentile (the fits closest to 1 are kept).

    So ``p = 20`` keeps roughly the best 80 % of the detections.

    Parameters
    ----------
    detections : dict of str -> pandas.DataFrame
        Modified in place (new keys added).
    det : str
        Key of the source frame.
    thr_name : str
        Column to threshold on.
    percentiles : iterable of int
        Percentile grid (default 0, 5, ..., 95 as before).
    """
    import numpy as np

    source = detections[det]
    for percentile in percentiles:
        if thr_name != 'Amplitudes':
            score = source[thr_name]
        else:
            # Goodness of fit: 0 for an amplitude of exactly 1, more negative further away.
            score = -abs(1 - source[thr_name])
        threshold = np.percentile(score.to_numpy(), percentile)
        detections['{}_{}_{}'.format(det, thr_name, percentile)] = source[score >= threshold].reset_index(drop=True)

def plot_roc(visual_track, det_list, detections, thr_name, percentiles=range(0, 100, 5)):
    """Tabulate sensitivity and precision for every percentile threshold (nothing is drawn here).

    Expects the thresholded sets created by ``apply_threshold`` under keys
    ``'{det}_{thr_name}_{p}'``. Each (det, p) pair yields two rows:
    'Sensitivity <det>' and 'Precision <det>'. ``sensitivity_precision``
    adds one column per key to ``visual_track``.

    Parameters
    ----------
    visual_track : pandas.DataFrame
        Visual marks ('Visual' column).
    det_list : list of str
        Base detection names.
    detections : dict of str -> pandas.DataFrame
        Detection sets, including the thresholded ones.
    thr_name : str
        Threshold column used by ``apply_threshold``.
    percentiles : iterable of int
        Must match the grid given to ``apply_threshold`` (default 0, 5, ..., 95).

    Returns
    -------
    pandas.DataFrame
        Columns: Percentile, ROC type, Sensitivity and Precision, N_spikes
        (numeric columns as float, like the previous row-by-row frame).
    """
    import pandas as pd

    columns = ['Percentile', 'ROC type', 'Sensitivity and Precision', 'N_spikes']
    records = []
    for det in det_list:
        for perc in percentiles:
            key = '{}_{}_{}'.format(det, thr_name, perc)
            sens, prec = sensitivity_precision(visual_track, key, detections)
            n_spikes = len(detections[key])
            records.append({'Percentile': perc, 'ROC type': 'Sensitivity ' + det,
                            'Sensitivity and Precision': sens, 'N_spikes': n_spikes})
            records.append({'Percentile': perc, 'ROC type': 'Precision ' + det,
                            'Sensitivity and Precision': prec, 'N_spikes': n_spikes})
    # Build the frame once instead of growing it cell by cell with .loc.
    return pd.DataFrame(records, columns=columns).astype(
        {'Percentile': float, 'Sensitivity and Precision': float, 'N_spikes': float})


## Raw detections

In [83]:
# Score the raw (unthresholded) SPC detections; also adds an 'SPC' hit column to visual_track.
sensitivity_precision(visual_track, det='SPC', detections=detections)
# Missed marks can be listed with: visual_track[(visual_track.ICA == 0)|(visual_track.SPC == 0)]

(1.0, 0.017151848937844216)

In [4]:
# Percentile thresholds on the ASPIRE fitting value for both raw detection sets.
apply_threshold(detections, det='ICA', thr_name='ValMax')
apply_threshold(detections, det='SPC', thr_name='ValMax')
res_fitting_thesholds = plot_roc(visual_track, det_list=['ICA', 'SPC'], detections=detections, thr_name='ValMax')

In [5]:
# Percentile thresholds on the SPC goodness of fit -|1 - Amplitudes|.
apply_threshold(detections, det='SPC', thr_name='Amplitudes')
res_SPC_amplitudes = plot_roc(visual_track, det_list=['SPC'], detections=detections, thr_name='Amplitudes')

In [6]:
import numpy as np
import pandas as pd

# Overlap: for every ICA detection, keep the best-fitting SPC detection of the thresholded set
# lying within 20 samples of it (presumably 20 ms at 1 kHz — TODO confirm).
sel_thr = '_Amplitudes_20'
spc_sel = detections['SPC' + sel_thr]
overlap_rows = []
for i in detections['ICA'].spike_ind:
    near = spc_sel[abs(spc_sel.spike_ind - i) < 20]
    if near.empty:
        continue
    # "Best" = amplitude closest to 1, the goodness criterion used elsewhere in this notebook.
    # Previously `sort_values('Amplitudes').index[0]` took the *smallest* amplitude
    # (despite the variable being called ind_max).
    best = near.iloc[int(np.argmin(np.abs(near.Amplitudes.to_numpy() - 1)))]
    overlap_rows.append({'spike_ind': best.spike_ind, 'ValMax': best.ValMax, 'Amplitudes': best.Amplitudes})
# Explicit columns so an empty overlap still works with apply_threshold.
detections['Overlap'] = pd.DataFrame(overlap_rows, columns=['spike_ind', 'ValMax', 'Amplitudes'])
n = len(overlap_rows)

In [7]:
# Same two threshold sweeps, applied to the ICA/SPC overlap set.
apply_threshold(detections, det='Overlap', thr_name='ValMax')
res_Overlap_fitting_thesholds_best_spc = plot_roc(visual_track, det_list=['Overlap'], detections=detections, thr_name='ValMax')
apply_threshold(detections, det='Overlap', thr_name='Amplitudes')
res_Overlap_amplitudes_thesholds_best_spc = plot_roc(visual_track, det_list=['Overlap'], detections=detections, thr_name='Amplitudes')

In [8]:
%matplotlib inline
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="darkgrid")

fig, axs = plt.subplots(2, 2,figsize=(13,13))
sns.lineplot(x="Percentile", y='Sensitivity and Precision', hue='ROC type', data=res_fitting_thesholds, ax=axs[0, 0])
axs[0, 0].set_title('Fitting threshold (raw detections)')


sns.lineplot(x="Percentile", y='Sensitivity and Precision', hue='ROC type', data=res_SPC_amplitudes, ax=axs[0, 1])
axs[0, 1].set_title('Amplitudes threshold (SPC detections)')

sns.lineplot(x="Percentile", y='Sensitivity and Precision', hue='ROC type', data=res_Overlap_fitting_thesholds_best_spc, ax=axs[1, 0])
axs[1, 0].set_title('Fitting threshold (overlap between ICA and SPC)\nonly the best 20% SPC detections were used')

sns.lineplot(x="Percentile", y='Sensitivity and Precision', hue='ROC type', data=res_Overlap_amplitudes_thesholds_best_spc, ax=axs[1, 1])
axs[1, 1].set_title('Amplitudes threshold (overlap between ICA and SPC)\nonly the best 20% SPC detections were used')

plt.suptitle('Sensitivity = TP/(TP + FN)\nPrecision = TP/(TP + FP)\nParameters for SPC: filter 9-100Hz, spike threshold 5MAD, template width 80ms, cc_merge = 0.9\nThe number of the manual spikes 109')
plt.savefig('/Users/valery/MEG/Cases/B1C2/ROC_thresholds.png', dpi=400)
plt.close()


In [9]:
%%capture
%matplotlib qt5
%matplotlib qt5
# Browse the raw data in an interactive Qt window with the visual marks ('T') and the
# overlap detections surviving the ValMax 50th-percentile threshold ('O').
# NOTE(review): the backend magic is repeated, presumably as a workaround for the first switch
# not taking effect — confirm it is still needed.
sns.set(style="white")

plot_sensors = 'grad' # 'mag', 'grad' or True (all sensors)
# NOTE(review): plot_sensors is not used in this cell.
window = 10 #in seconds
start = 988.0 #start point in seconds

tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)

# Visual marks are already in seconds.
onset = [i for i in T_spikes]
data.annotations.append(onset, np.repeat(0.00001, len(onset)), ['T']*len(onset))

# Sample indices -> seconds; assumes 1 kHz sampling (TODO confirm sfreq).
onset = [i/1000 for i in detections['Overlap_ValMax_50'].spike_ind]
data.annotations.append(onset, np.repeat(0.00001, len(onset)), ['O']*len(onset))

#onset = [i/1000 for i in V_spikes]
#labels = [str(i) for i in best_results.Template_n]
#data.annotations.append(onset, np.repeat(0.00001, len(onset)), labels)

# Display-only band-pass 1-70 Hz.
data.plot(duration=window, group_by='position', start=start, lowpass=70, highpass=1)

## Scratch checks of the thresholds

In [424]:
# 5th percentile of the SPC goodness score -|1 - Amplitudes| (the cut apply_threshold uses for p=5).
np.percentile(-np.abs(1 - detections['SPC']['Amplitudes'].to_numpy()), 5)

-0.6299416550000001

In [388]:
# Scratch check: ICA detections at or below a hand-pasted ValMax threshold
# (presumably the 5th percentile of ICA ValMax — TODO compute it with np.percentile instead of pasting).
detections['ICA'][detections['ICA']['ValMax']<=0.9240047691429929]['ValMax']

50      0.918396
58      0.923617
61      0.923371
63      0.917635
64      0.919504
          ...   
2139    0.914062
2141    0.922433
2146    0.910942
2253    0.921052
2258    0.912924
Name: ValMax, Length: 114, dtype: float64

In [401]:
# Smallest ValMax kept by the 95th-percentile ICA threshold (i.e. the threshold itself, up to ties).
detections['ICA_ValMax_95'].ValMax.min()

0.9731173930080393