In [4]:
def create_epochs(raw_filt, detections, tmin=-1, tmax=1, picks_raw=True, no_first_sample=False):
    """Create MNE epochs around spike detections.

    The detections are written as events (id 1, name ``'ICA_det'``) into a stim
    channel named ``'ICA'``. The channel is appended to ``raw_filt`` if it does not
    exist yet, and any events already in it are replaced. Epochs are then cut
    around those events.

    Note: ``raw_filt`` is modified in place (its data is loaded and the ``'ICA'``
    stim channel is added/overwritten).

    Parameters
    ----------
    raw_filt : mne.io.Raw
        Continuous recording to epoch.
    detections : iterable of float
        Spike positions in samples; each is rounded to the nearest sample.
    tmin, tmax : float
        Epoch window in seconds relative to each detection.
    picks_raw : bool | str
        Passed as ``meg=`` to ``mne.pick_types`` (e.g. True, 'grad', 'mag').
    no_first_sample : bool
        If True, ``detections`` are used as-is (they must already include
        ``raw_filt.first_samp``); otherwise ``raw_filt.first_samp`` is added.

    Returns
    -------
    mne.Epochs
        Preloaded epochs without baseline correction.

    Raises
    ------
    ValueError
        If ``detections`` is empty.
    """
    detections = list(detections)
    if not detections:
        # Fail before touching raw_filt: MNE would reject the empty event array only
        # after the stim channel had already been added to the caller's object.
        raise ValueError("create_epochs: 'detections' is empty, nothing to epoch.")

    import mne
    import numpy as np

    eve_id = 1
    eve_name = 'ICA_det'
    ch_name = 'ICA'

    raw_filt.load_data()

    first_samp = 0 if no_first_sample else raw_filt.first_samp
    # MNE event array: one row per event -> [sample, previous value, event id]
    new_events = np.array(
        [[int(round(spike_time + first_samp)), 0, eve_id] for spike_time in detections],
        dtype=int,
    )

    if ch_name not in raw_filt.info['ch_names']:
        # Add an all-zero stim channel that add_events can write into.
        stim_data = np.zeros((1, len(raw_filt.times)))
        info_sp = mne.create_info([ch_name], raw_filt.info['sfreq'], ['stim'])
        stim_sp = mne.io.RawArray(stim_data, info_sp, verbose=False)
        raw_filt.add_channels([stim_sp], force_update_info=True)

    raw_filt.add_events(new_events, stim_channel=ch_name, replace=True)
    events = mne.find_events(raw_filt, stim_channel=ch_name, verbose=False)
    event_id = {eve_name: eve_id}
    picks = mne.pick_types(raw_filt.info, meg=picks_raw, eeg=False, eog=False)
    return mne.Epochs(raw_filt, events, event_id, tmin, tmax, baseline=None,
                      picks=picks, preload=True, verbose=False)
    
def find_nearest(array, value):
    """Return the entry of ``array`` closest to ``value``.

    ``array`` may be any array-like (list, NumPy array, pandas Series); it is
    converted with ``np.asarray``. When several entries are equally close, the
    earliest one is returned.
    """
    import numpy as np
    candidates = np.asarray(array)
    nearest_pos = np.argmin(np.abs(candidates - value))
    return candidates[nearest_pos]

def plot_topomap(temp_evocked, evocked, svg_path, temp_n, time, amplitude, one_spike=False, diff=0):
    """Overlay an evoked response and an SPC template on the sensor layout and save as SVG.

    Parameters
    ----------
    temp_evocked : mne.Evoked
        Template waveform (drawn in red). On a copy, its data is scaled by 10**4.5 and
        its time axis is replaced by 101 points spanning -50..50 ms, shifted by
        ``diff`` seconds. This assumes the template has exactly 101 samples
        (NOTE(review): confirm against the template's N_t).
    evocked : mne.Evoked
        Data compared against the template (drawn in blue). On a copy, its data is
        scaled by 10**13.
    svg_path : str
        Directory the SVG is written to.
    temp_n : int
        Template number, used in the title and the file name.
    time : int | str
        For a single spike: the event sample. Otherwise: a label (e.g. a file-name
        stem) that only goes into the file name.
    amplitude : float
        Goodness-of-fit value shown in the title.
    one_spike : bool
        Use the single-spike title and record ``temp.one_spike_name``.
    diff : float
        Time shift (s) applied to the template's time axis.

    Side effects
    ------------
    Writes ``<svg_path>/spike_<time>_templates_topo_<temp_n>.svg`` and stores its path
    in ``temp.templates_topo_temp_n_path``. When ``one_spike`` is True, it also sets
    ``temp.one_spike_name``. ``temp`` is a notebook-global object that must already
    exist. The objects passed in are NOT modified.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from mne.viz import plot_evoked_topo

    colors = 'blue', 'red'
    # Rescale copies, not the caller's objects: callers reuse ``evocked`` (e.g. for
    # plot_joint) and must still see the original data.
    evocked = evocked.copy()
    temp_evocked = temp_evocked.copy()
    # Fixed factors, presumably chosen to bring data and template to a comparable
    # visual scale -- NOTE(review): confirm.
    evocked.data = evocked.data * 10**13
    temp_evocked.data = temp_evocked.data * 10**4.5
    temp_evocked.times = np.linspace(-0.050, 0.050, 101) + diff

    fig = plot_evoked_topo([evocked, temp_evocked], color=colors, background_color='w', legend=False)
    # The width is derived from the *new* height, so the final size is
    # height = 1.5 * w0 and width = 2.25 * w0.
    fig.set_figheight(fig.get_figwidth() * 1.5)
    fig.set_figwidth(fig.get_figheight() * 1.5)
    if not one_spike:
        fig.suptitle('Template %s\nSpike time: average of %s event(s)\nGoodness of fit (average): %s'
                     % (temp_n, evocked.nave, amplitude), x=0.42)
    else:
        # ``time`` is a sample index; /1000 gives seconds only at 1000 Hz sampling
        # (NOTE(review): confirm sfreq).
        fig.suptitle('Template %s\nSpike time: %ss\nGoodness of fit: %s'
                     % (temp_n, time / 1000, amplitude), x=0.42)
        temp.one_spike_name = "spike_%s_templates_topo_%s" % (time, temp_n)

    fname = "%s/spike_%s_templates_topo_%s.svg" % (svg_path, time, temp_n)
    # ``papertype`` only affects PostScript output, so it is not passed for SVG.
    fig.savefig(fname, format='svg')
    temp.templates_topo_temp_n_path = fname

    fig.clf()
    plt.close(fig)

    
def plot_joint(evocked):
    """Save an MNE joint plot of ``evocked`` as SVG.

    Topomaps are drawn at -200, -100, -50, 0, 50, 100 and 200 ms. The file is written to
    ``<temp.svg_path><temp.temp_n>_evoked_joint_grad.svg``, and its path is stored in
    ``temp.plot_evoked_joint_temp_n_grad_path``. ``temp`` is a notebook-global object
    that must already exist.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import mne

    topo_times = np.array([-0.2, -0.1, -0.05, 0.0, 0.05, 0.1, 0.2])
    fig = mne.viz.plot_evoked_joint(evocked, times=topo_times, show=False)
    fname = "%s%s_evoked_joint_grad.svg" % (temp.svg_path, temp.temp_n)
    # Save this specific figure rather than whatever pyplot considers "current";
    # ``papertype`` is not passed since it only applies to PostScript output.
    fig.savefig(fname, format='svg')
    temp.plot_evoked_joint_temp_n_grad_path = fname

    fig.clf()
    plt.close(fig)

def plot_epoch_image(epochs):
    """Save the MNE epochs image of ``epochs`` as SVG and close the figures it created.

    ``epochs.plot_image()`` returns a list of figures. Only the first one is saved, to
    ``<temp.svg_path><temp.temp_n>_epochs_image.svg``, and its path is stored in
    ``temp.plot_epochs_image_temp_n_path``. ``temp`` is a notebook-global object that
    must already exist.
    """
    import matplotlib.pyplot as plt

    figs = epochs.plot_image()
    fname = "%s%s_epochs_image.svg" % (temp.svg_path, temp.temp_n)
    # ``papertype`` only affects PostScript output, so it is not passed for SVG.
    figs[0].savefig(fname, format='svg')
    temp.plot_epochs_image_temp_n_path = fname

    # Called once per spike in a loop: close every figure so they do not accumulate.
    for fig in figs:
        plt.close(fig)
    
def plot_final_temp_n(temp, one_spike=False, system_type='Mac'):
    """Combine the three per-template SVG panels into one figure and export it to PNG.

    The panels are:
    - A: ``temp.templates_topo_temp_n_path``
    - B: ``temp.plot_evoked_joint_temp_n_grad_path``
    - C: ``temp.plot_epochs_image_temp_n_path``

    B and C are scaled to 70 %. The composite is saved as
    ``<temp.svg_path><temp.temp_n>_fig_final.svg`` and rasterised with Inkscape. The
    function waits for the export to finish.

    Parameters
    ----------
    temp : object
        Provides ``svg_path``, ``png_path``, ``temp_n``, the three panel paths and,
        when ``one_spike`` is True, ``one_spike_name``.
    one_spike : bool
        If True, write ``<png_path><temp_n>_temp/<one_spike_name>.png`` (the folder is
        created if needed). Otherwise write ``<png_path><temp_n>_temp.png``.
    system_type : {'Mac', 'Windows'}
        Selects the Inkscape executable and the output size in pixels.

    Raises
    ------
    ValueError
        If ``system_type`` is not supported. This is checked before any work is done.
    """
    # (executable, width px, height px) per platform
    inkscape_by_system = {
        'Mac': ('/Applications/Inkscape.app/Contents/Resources/bin/inkscape', 3000, 4500),
        'Windows': ('C:/Program Files/Inkscape/inkscape', 2000, 3000),
    }
    if system_type not in inkscape_by_system:
        raise ValueError("system_type must be one of %s, got %r"
                         % (sorted(inkscape_by_system), system_type))
    inkscape, width, height = inkscape_by_system[system_type]

    import os
    import subprocess
    import svgutils.transform as sg

    fig = sg.SVGFigure("22cm", "24cm")
    plot1 = sg.fromfile(temp.templates_topo_temp_n_path).getroot()
    plot1.moveto(0, 0)
    plot2 = sg.fromfile(temp.plot_evoked_joint_temp_n_grad_path).getroot()
    plot2.moveto(30, 650)
    plot2.scale_xy(0.7, 0.7)
    plot3 = sg.fromfile(temp.plot_epochs_image_temp_n_path).getroot()
    plot3.moveto(450, 650)
    plot3.scale_xy(0.7, 0.7)

    labels = [
        sg.TextElement(50, 50, "A", size=12, weight="bold"),
        sg.TextElement(50, 650, "B", size=12, weight="bold"),
        sg.TextElement(450, 650, "C", size=12, weight="bold"),
    ]
    fig.append([plot1, plot2, plot3])
    fig.append(labels)

    your_svg_input = "%s%s_fig_final.svg" % (temp.svg_path, temp.temp_n)
    fig.save(your_svg_input)

    if not one_spike:
        your_png_output = "%s%s_temp.png" % (temp.png_path, temp.temp_n)
    else:
        os.makedirs("{}{}_temp".format(temp.png_path, temp.temp_n), exist_ok=True)
        your_png_output = "{}{}_temp/{}.png".format(temp.png_path, temp.temp_n, temp.one_spike_name)

    # Each option and value is its own argv entry. No shell is involved, so a single
    # '-w3000 -h4500' or '-b white' token would reach Inkscape verbatim.
    # ``--export-png`` is Inkscape 0.92 syntax (1.x uses --export-filename).
    cmd = [inkscape, your_svg_input,
           '--export-png=%s' % your_png_output,
           '-w', str(width), '-h', str(height),
           '-b', 'white']
    # Wait for the export to finish. The same *_fig_final.svg is rewritten on the next
    # call, so an asynchronous export could rasterise the wrong figure. The exit status
    # is not checked, matching the previous best-effort behaviour.
    subprocess.run(cmd, check=False)

## Spike alignment

In [180]:
%%capture
import scipy.io
import mne
import numpy as np
import pandas as pd
# Case-specific inputs: spike detections of the two ASPIRE ICA blocks and the raw recording.
CASE_DIR = '/Users/valery/MEG/Cases/B1C2/'
ASPIRE_RESULTS = CASE_DIR + 'ASPIRE/results/'
# Offset (samples) added to block002 detections to place them on the same time axis as
# block001 -- presumably the length of block001 in samples; NOTE(review): confirm.
BLOCK2_OFFSET = 600_000

block1 = scipy.io.loadmat(ASPIRE_RESULTS + 'sources_B1C2_ii_run1_raw_tsss_mc_art_corr_data_'
                          'block001_ICA_based_grad.mat')
block2 = scipy.io.loadmat(ASPIRE_RESULTS + 'sources_B1C2_ii_run1_raw_tsss_mc_art_corr_data_'
                          'block002_ICA_based_grad.mat')
# Shift as Python numbers: the loaded arrays stay untouched and a narrow integer dtype
# cannot overflow.
spikes = (block1['spikeind'][0, :].tolist()
          + [spike + BLOCK2_OFFSET for spike in block2['spikeind'][0, :].tolist()])

tsss_file = CASE_DIR + 'art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif'
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)

# No peak alignment is done here: the "aligned" spikes are the detections shifted by the
# recording's first_samp (absolute sample numbers, as used by MNE events).
aligned_spikes = [spike + data.first_samp for spike in spikes]
results = pd.DataFrame({'Spikes': spikes, 'Aligned_spikes': aligned_spikes})
results.to_excel(CASE_DIR + 'results.xlsx', index=False)

## Run SPC

In [None]:
import run_circus_ASPIRE as run_circus
import scipy.io
import circus_templates_ASPIRE as circus_templates
import mne
import traceback
import shutil

# Input locations for this case (absolute, machine-specific paths).
dir_case = "/Users/valery/MEG/Cases/B1C2/"
SPC_fname = "B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
nmpy_file = 'B1C2_ii_run1_raw_tsss_mc_art_corr_0.npy'

# Paths derived from the case folder.
dir_SPC = (dir_case
           + "Spyking_circus/B1C2_B1C2_ii_run1_raw_tsss_mc_art_corr/")
epochs_fname = (dir_case
                + "art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr_epochs.fif")

# Passed as N_t to Circus/Templates (presumably the template width in samples -- confirm).
n_t = 100

# Offset (samples) subtracted from the absolute spike samples in results.xlsx to reach
# the Spiketimes axis of the SPC templates -- presumably where the data given to
# Spyking Circus starts; NOTE(review): confirm.
SPC_SAMPLE_OFFSET = 39000
# Max distance (samples) between a detection and an SPC fit to treat them as the same peak.
MATCH_WINDOW = 100

for cc_merge in [0.3]:
    for mad in [7.0]:
        # 1) Run Spyking Circus and load the resulting templates.
        try:
            sc = run_circus.Circus(dir_case, "B1C2", dir_SPC, SPC_fname, N_t=n_t,
                                   cc_merge=cc_merge, MAD=mad, cut_off=3)
            sc.params_iterations(run_spc=True, only_fitting=False, sensors=['grad'])

            for sens in ['grad']:
                # Use the loop's cc_merge so Templates matches the run that was just made.
                temp = circus_templates.Templates(dir_case, "B1C2", SPC_fname, sc.sensors_params[sens],
                                                  sensors=sens, n_sp=3, N_t=n_t, cc_merge=cc_merge, MAD=mad)
        except Exception:
            traceback.print_exc()
            # Later steps read ``temp``; continuing would silently reuse a stale object
            # from a previous run (or fail with NameError).
            continue

        # 2) For each detected spike, select the SPC fit within MATCH_WINDOW samples
        #    whose amplitude (goodness of fit) is closest to 1.
        try:
            import pandas as pd
            results = pd.read_excel("/Users/valery/MEG/Cases/B1C2/results.xlsx")

            for i in range(len(results)):
                spike_spc = results.Aligned_spikes[i] - SPC_SAMPLE_OFFSET
                around_spike = temp.templates[abs(temp.templates.Spiketimes - spike_spc) < MATCH_WINDOW]
                if around_spike.empty:
                    continue
                # Keep the whole best row (first one on ties) so that spike time, template
                # and amplitude all describe the same fit. Re-looking them up by value in
                # the full table could pick a different row when values repeat.
                best = around_spike.iloc[(around_spike.Amplitudes - 1).abs().to_numpy().argmin()]
                results.loc[i, 'Difference'] = best['Spiketimes'] - spike_spc
                results.loc[i, 'Template'] = best['Template']
                results.loc[i, 'Template_n'] = int(best['Template'].split('_')[1])
                results.loc[i, 'Amplitude'] = best['Amplitudes']
        except Exception:
            traceback.print_exc()
            # Plotting needs the matched columns; skip it rather than plot partial results.
            continue

        # 3) Plot individual and averaged spikes of the best-fitting matches, then save the table.
        try:
            import mne
            import numpy as np

            tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
            data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)

            # Keep matches whose |Amplitude - 1| is below its 10th percentile.
            godness_percentile = np.percentile(abs(results.dropna().Amplitude.to_numpy() - 1), 10)
            best_results = results[abs(results.Amplitude - 1) < godness_percentile]

            filter_value = (1.0, 45.0)  # band-pass (Hz) applied before plotting

            for temp_n in best_results.Template_n.unique().tolist():
                results_temp_n = best_results[best_results.Template_n == temp_n]
                aligned_spikes = (results_temp_n.Aligned_spikes + results_temp_n.Difference).tolist()

                epochs = create_epochs(data, aligned_spikes, tmin=-0.5, tmax=0.5,
                                       picks_raw=True, no_first_sample=True)

                temp.data_info = epochs.info
                temp.data_info_sens = epochs[0].load_data().pick_types(
                    meg=temp.sensors, eeg=False, stim=False, eog=False).info
                temp.temp_n = int(temp_n)

                epochs.load_data().pick_types(meg=temp.sensors, eeg=False, stim=False, eog=False).filter(
                    filter_value[0], filter_value[1], fir_design='firwin')
                # NOTE(review): this result is overwritten before use; kept in case the
                # call has side effects on ``temp`` -- confirm and drop.
                temp_evocked = temp.templates_topo_temp_n(return_temp=True)

                # Plot each spike
                for i in range(len(epochs)):
                    ep = epochs[i]
                    time = ep.events[0][0]
                    ev = ep.average()
                    amplitude = results_temp_n.loc[
                        (results_temp_n.Aligned_spikes + results_temp_n.Difference).round() == time,
                        'Amplitude'].values[0]
                    plot_epoch_image(epochs[i])
                    temp_evocked = temp.templates_topo_temp_n(return_temp=True)
                    # plot_topomap may rescale the evoked it receives; pass a copy so
                    # plot_joint below plots the original data.
                    plot_topomap(temp_evocked, ev.copy(), temp.svg_path, temp.temp_n, int(time),
                                 amplitude, one_spike=True)
                    plot_joint(ev)
                    plot_final_temp_n(temp, one_spike=True)

                # Plot average
                temp_evocked = temp.templates_topo_temp_n(return_temp=True)
                file_name = 'average_{}_spikes_filt_{}'.format(len(results_temp_n), filter_value)
                plot_epoch_image(epochs)
                evocked = epochs.average()

                amplitude = results_temp_n.Amplitude.mean()
                plot_topomap(temp_evocked, evocked.copy(), temp.svg_path, temp.temp_n, file_name, amplitude)
                plot_joint(evocked)

                plot_final_temp_n(temp)

            results.to_excel(temp.waveforms_path + '/results.xlsx')
        except Exception:
            traceback.print_exc()
            
            

## Select SPC-fitted spikes that were fitted on the same peak

In [None]:

import pandas as pd
results = pd.read_excel("/Users/valery/MEG/Cases/B1C2/results.xlsx")

# Offset (samples) subtracted from the absolute spike samples in results.xlsx to reach
# the Spiketimes axis of the SPC templates -- presumably where the data given to
# Spyking Circus starts; NOTE(review): confirm.
SPC_SAMPLE_OFFSET = 39000
# Max distance (samples) between a detection and an SPC fit to treat them as the same peak.
MATCH_WINDOW = 100

# For each detected spike, select the nearby SPC fit whose amplitude (goodness of fit)
# is closest to 1 and record its time difference, template name/number and amplitude.
for i in range(len(results)):
    spike_spc = results.Aligned_spikes[i] - SPC_SAMPLE_OFFSET
    around_spike = temp.templates[abs(temp.templates.Spiketimes - spike_spc) < MATCH_WINDOW]
    if around_spike.empty:
        continue
    # Keep the whole best row (first one on ties) so that spike time, template and
    # amplitude all describe the same fit. Re-looking them up by value in the full
    # table could pick a different row when values repeat.
    best = around_spike.iloc[(around_spike.Amplitudes - 1).abs().to_numpy().argmin()]
    results.loc[i, 'Difference'] = best['Spiketimes'] - spike_spc
    results.loc[i, 'Template'] = best['Template']
    results.loc[i, 'Template_n'] = int(best['Template'].split('_')[1])
    results.loc[i, 'Amplitude'] = best['Amplitudes']
        

## Plot average & individual spikes

In [38]:
%%capture
import mne
import numpy as np

# Keep matches whose |Amplitude - 1| is below its 10th percentile (the best-fitting ~10 %).
godness_percentile = np.percentile(abs(results.dropna().Amplitude.to_numpy() - 1), 10)
tsss_file = "/Users/valery/MEG/Cases/B1C2/art_corr/B1C2_ii_run1_raw_tsss_mc_art_corr.fif"
data = mne.io.read_raw_fif(tsss_file, preload=False, verbose=False)
best_results = results[abs(results.Amplitude - 1) < godness_percentile]
filter_value = (1.0, 45.0)  # band-pass (Hz) applied before plotting

for temp_n in best_results.Template_n.unique().tolist():
    results_temp_n = best_results[best_results.Template_n == temp_n]
    aligned_spikes = (results_temp_n.Aligned_spikes + results_temp_n.Difference).tolist()

    epochs = create_epochs(data, aligned_spikes, tmin=-0.5, tmax=0.5, picks_raw=True, no_first_sample=True)

    temp.data_info = epochs.info
    temp.data_info_sens = epochs[0].load_data().pick_types(
        meg=temp.sensors, eeg=False, stim=False, eog=False).info
    temp.temp_n = int(temp_n)

    epochs.load_data().pick_types(meg=temp.sensors, eeg=False, stim=False, eog=False).filter(
        filter_value[0], filter_value[1], fir_design='firwin')
    # NOTE(review): this result is overwritten before use; kept in case the call has
    # side effects on ``temp`` -- confirm and drop.
    temp_evocked = temp.templates_topo_temp_n(return_temp=True)

    ## Plot each spike
    for i in range(len(epochs)):
        ep = epochs[i]
        time = ep.events[0][0]
        ev = ep.average()
        amplitude = results_temp_n.loc[
            (results_temp_n.Aligned_spikes + results_temp_n.Difference).round() == time,
            'Amplitude'].values[0]
        plot_epoch_image(epochs[i])
        temp_evocked = temp.templates_topo_temp_n(return_temp=True)
        # plot_topomap may rescale the evoked it receives; pass a copy so plot_joint
        # below plots the original data.
        plot_topomap(temp_evocked, ev.copy(), temp.svg_path, temp.temp_n, int(time), amplitude, one_spike=True)
        plot_joint(ev)
        plot_final_temp_n(temp, one_spike=True)

    ## Plot average
    temp_evocked = temp.templates_topo_temp_n(return_temp=True)
    file_name = 'average_{}_spikes_filt_{}'.format(len(results_temp_n), filter_value)
    plot_epoch_image(epochs)
    evocked = epochs.average()

    amplitude = results_temp_n.Amplitude.mean()
    plot_topomap(temp_evocked, evocked.copy(), temp.svg_path, temp.temp_n, file_name, amplitude)
    plot_joint(evocked)

    plot_final_temp_n(temp)