In [1]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf
import glob
from autoreject import AutoReject, get_rejection_threshold

In [5]:
# Sorted so subjects are processed (and their outputs written) in a
# reproducible order -- glob returns matches in arbitrary filesystem order.
subjs_all = sorted(glob.glob('./EEG/Revcor*.vhdr'))

In [6]:
# Display the BrainVision header (.vhdr) files found, one per recording.
subjs_all

['./EEG\\Revcor0006.vhdr',
 './EEG\\Revcor0007.vhdr',
 './EEG\\Revcor0011.vhdr',
 './EEG\\Revcor0012.vhdr',
 './EEG\\Revcor0015.vhdr',
 './EEG\\Revcor0016.vhdr',
 './EEG\\Revcor0017.vhdr',
 './EEG\\Revcor0018.vhdr',
 './EEG\\Revcor0019.vhdr',
 './EEG\\Revcor0020.vhdr']

In [None]:
def get_raw(subj):
    """Open a single BrainVision recording.

    Parameters
    ----------
    subj : str
        Path to the recording's ``.vhdr`` header file.

    Returns
    -------
    The Raw object produced by ``mne.io.read_raw_brainvision``, read quietly
    (``verbose=False``) with MNE's default preload setting.
    """
    return mne.io.read_raw_brainvision(subj, verbose=False)

In [None]:
def rename_channels(raw, channel_names):
    """Rename the channels of ``raw`` positionally, in place.

    The i-th channel stored in ``raw`` receives ``channel_names[i]``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channels are renamed.
    channel_names : sequence of str
        New names, in the same order as the channels in ``raw``.

    Notes
    -----
    ``zip`` stops at the shorter sequence, so surplus names (or surplus
    channels) are ignored.  A warning is emitted when the lengths differ, so
    a mismatched name list does not go unnoticed.
    """
    import warnings

    channel_names_old = list(raw.ch_names)
    if len(channel_names_old) != len(channel_names):
        warnings.warn(
            f'rename_channels: recording has {len(channel_names_old)} channels '
            f'but {len(channel_names)} names were supplied; only the first '
            f'{min(len(channel_names_old), len(channel_names))} are renamed.'
        )

    channel_dict = dict(zip(channel_names_old, channel_names))
    # Use the Raw method rather than mne.rename_channels(raw.info, ...): it
    # renames in raw.info and also updates Raw-level channel bookkeeping
    # (e.g. the original-units record) so it stays consistent.
    raw.rename_channels(channel_dict)

In [None]:
def make_montage(raw, montage, show_sensors=True):
    """Attach electrode positions to ``raw`` and optionally plot them.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose montage is set in place.
    montage : mne.channels.DigMontage
        Montage built once in the config cell and reused for every subject.
    show_sensors : bool
        If True (default, the original behaviour) draw a sensor map with
        channel names for a visual check of the layout.  Pass False inside
        long subject loops to avoid accumulating open figures.

    Returns
    -------
    The figure returned by ``raw.plot_sensors``, or None when
    ``show_sensors`` is False.
    """
    raw.set_montage(montage)
    if not show_sensors:
        return None
    return raw.plot_sensors(show_names=True)

In [None]:
def get_events(raw, drop_codes=(99999, 1011, 2, 6), code_map=None):
    """Extract events from annotations, drop unused markers, and collapse the
    remaining codes onto the four condition IDs.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose annotations hold the stimulus/response markers.
    drop_codes : iterable of int
        Event codes to discard entirely.  The default removes 99999 (the
        BrainVision "New Segment" marker), 1011 (response marker R11) and
        codes 2 and 6 (NOTE(review): meaning of 2 and 6 is not documented
        here -- confirm).
    code_map : dict[int, int] | None
        Old code -> new code.  ``None`` uses the default grouping
        1004/1008 -> 1001, 1012 -> 1002, 1024 -> 1003, 1028 -> 1004, i.e. the
        IDs later averaged as 'standard', 'neutral', 'rise' and 'fall'.
        Codes not listed are left unchanged.

    Returns
    -------
    events : ndarray
        MNE events array, one row per kept event; column 2 is the event code.
    """
    if code_map is None:
        code_map = {1004: 1001, 1008: 1001, 1012: 1002, 1024: 1003, 1028: 1004}

    events_from_annot, _ = mne.events_from_annotations(raw)

    keep = ~np.isin(events_from_annot[:, 2], list(drop_codes))
    events = events_from_annot[keep]

    # Look replacements up against the *original* codes so each event is
    # remapped at most once (e.g. 1028 -> 1004 must not then become 1001).
    original_codes = events[:, 2].copy()
    for old_code, new_code in code_map.items():
        events[original_codes == old_code, 2] = new_code

    return events

In [None]:
def create_stim_channel(raw, events):
    """Add a zero-filled stim channel named 'STI' to ``raw`` and write ``events`` into it.

    ``raw`` is loaded into memory first and then modified in place; nothing
    is returned.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording that receives the new channel.
    events : ndarray
        Events (as returned by ``get_events``) written into 'STI'.
    """
    raw.load_data()

    # One all-zero trace spanning every sample of the recording, typed 'stim'
    # and sampled at the recording's own rate.
    n_samples = raw.times.shape[0]
    sti_info = mne.create_info(['STI'], raw.info['sfreq'], ['stim'])
    sti_raw = mne.io.RawArray(np.zeros((1, n_samples)), sti_info)
    raw.add_channels([sti_raw], force_update_info=True)

    # Stamp each event's code onto the new channel at its sample position.
    raw.add_events(events, stim_channel='STI')

In [None]:
def crop_data(raw, t_from, t_to, events=None):
    """Crop ``raw`` in place to the span needed by the epochs.

    Keeps data from ``|t_from|`` seconds before the first event to ``|t_to|``
    seconds after the last one, so every epoch later cut with limits
    ``(t_from, t_to)`` still lies inside the cropped recording.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to crop.  Must contain a stim channel (see
        ``create_stim_channel``) unless ``events`` is supplied.
    t_from, t_to : float
        Epoch limits in seconds (e.g. -0.1, 0.6); only magnitudes are used.
    events : ndarray | None
        Events whose column 0 is an absolute sample index (the layout
        returned by ``mne.find_events``).  If None they are read from the
        stim channel with ``mne.find_events``.

    Raises
    ------
    ValueError
        If there are no events to crop around.
    """
    if events is None:
        events = mne.find_events(raw)
    if len(events) == 0:
        raise ValueError('crop_data: no events found, nothing to crop around.')
    events = np.asarray(events)

    first_event_sample = events[:, 0].min()
    last_event_sample = events[:, 0].max()
    print(first_event_sample, ' ', last_event_sample)

    # Event samples are absolute (they include raw.first_samp) while crop()
    # takes seconds relative to the start of the data; convert with the
    # recording's real sampling rate and use the same scale for both margins.
    sfreq = raw.info['sfreq']
    start = (first_event_sample - raw.first_samp) / sfreq - abs(t_from)
    stop = (last_event_sample - raw.first_samp) / sfreq + abs(t_to)

    # crop() rejects limits outside the recording, so clamp to what exists.
    tmin = max(0.0, start)
    tmax = min(raw.times[-1], stop)
    raw.crop(tmin, tmax)

In [None]:
def apply_filter(raw, l_freq=0.1, h_freq=30, line_freq=50, max_notch_freq=250):
    """Band-pass and notch-filter ``raw`` in place.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording to filter.
    l_freq, h_freq : float
        Pass-band edges in Hz of a soft (order-2) Butterworth band-pass.
    line_freq : float
        Mains frequency in Hz; it and its harmonics are notched.
    max_notch_freq : float
        Highest harmonic considered for notching (inclusive).  Harmonics at
        or above Nyquist are skipped.
    """
    # Hand MNE only the filter *design* so it builds the Butterworth
    # coefficients from the recording's own sampling rate.  Pre-building them
    # with a hard-coded sfreq=1000 yields wrong cut-offs at any other rate,
    # because MNE uses pre-computed coefficients as-is.
    iir_params = dict(order=2, ftype='butter', output='sos')
    raw.filter(l_freq, h_freq, method='iir', iir_params=iir_params)

    # Notch the line frequency and its harmonics.  Frequencies at or above
    # Nyquist cannot be notched (e.g. 250 Hz at 500 Hz sampling), so drop them.
    nyquist = raw.info['sfreq'] / 2.0
    freqs = np.arange(line_freq, max_notch_freq + 1, line_freq)
    freqs = freqs[freqs < nyquist]
    if freqs.size:
        raw.notch_filter(freqs=freqs, method='fir', fir_design='firwin2')

In [None]:
def add_ref_ch(raw, ref_channel):
    """Re-insert the reference electrode into ``raw`` as an explicit channel, in place.

    MNE fills the added channel with zeros; it becomes informative only after
    re-referencing.

    NOTE(review): the config cell's ``channel_names`` also contains 'Cz'.  If
    the recording has 64 channels, 'Cz' already exists and MNE will refuse
    to add it -- confirm the recording has 63 channels.
    """
    # copy=False: modify ``raw`` itself instead of returning a new object.
    mne.add_reference_channels(inst=raw, ref_channels=ref_channel, copy=False)

In [None]:
def re_reference(raw, ref_method):
    """Re-reference the EEG channels of ``raw`` in place.

    The reference is added as a projector and applied immediately, so the
    data in ``raw`` end up re-referenced.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to re-reference (modified in place).
    ref_method : str | list of str
        Passed as ``ref_channels`` to ``set_eeg_reference``, e.g. 'average'.
    """
    # mne.set_eeg_reference(raw, ...) defaults to copy=True, so calling the
    # module-level function re-referenced a throwaway copy and left ``raw``
    # untouched.  The Raw method always operates on ``raw`` itself.
    raw.set_eeg_reference(ref_channels=ref_method, projection=True)
    raw.apply_proj()

In [None]:
def create_epochs(raw, events, epoch_limits):
    """Cut ``raw`` into preloaded epochs around ``events``.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous, filtered and re-referenced recording.
    events : ndarray
        Events (e.g. from ``get_events``) to epoch around.
    epoch_limits : sequence of float
        ``(tmin, tmax)`` in seconds relative to each event.

    Returns
    -------
    mne.Epochs
    """
    tmin, tmax = epoch_limits[0], epoch_limits[1]
    # baseline=None on purpose: baseline correction happens only after ICA.
    return mne.Epochs(raw, events, tmin=tmin, tmax=tmax, preload=True, baseline=None)

In [None]:
def run_autoreject(epochs):
    """Fit AutoReject on ``epochs`` and apply it.

    Candidate interpolation counts are 1-4 channels; ``random_state`` is
    fixed for reproducibility.

    Returns
    -------
    tuple
        ``(cleaned_epochs, reject_log)`` as produced by
        ``AutoReject.transform(..., return_log=True)``.
    """
    rejector = AutoReject(n_interpolate=[1, 2, 3, 4], random_state=11, n_jobs=1, verbose=True)
    rejector.fit(epochs)

    cleaned, log = rejector.transform(epochs, return_log=True)
    return cleaned, log

In [None]:
def run_ica(epochs, reject_log, eog_proxy):
    """Remove EOG-like ICA components from ``epochs`` in place.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to clean; modified in place by ``ica.apply``.
    reject_log : autoreject RejectLog
        Log from a previous AutoReject pass; its ``bad_epochs`` mask selects
        the epochs the ICA is fitted and scored on.
    eog_proxy : str
        Name of an EEG channel correlated with ICs in place of a dedicated
        EOG channel (the pipeline passes 'Fp1').

    Returns
    -------
    ica : mne.preprocessing.ICA
        The fitted ICA with ``exclude`` set, returned for inspection; callers
        that ignore the return value are unaffected.
    """
    # Fit and score only on epochs AutoReject kept, so bad epochs don't
    # drive the decomposition.  Selected once and reused.
    good_epochs = epochs[~reject_log.bad_epochs]

    ica = mne.preprocessing.ICA(random_state=99)
    ica.fit(good_epochs)

    # Find which ICs match the EOG pattern.
    eog_indices, eog_scores = ica.find_bads_eog(good_epochs, ch_name=eog_proxy)
    print(f'**************** Automatically found EOG artifact ICA components: {eog_indices} ****************')

    # Muscle components could be flagged the same way with
    # ica.find_bads_muscle(good_epochs); this is currently not used.
    ica.exclude = eog_indices

    ica.plot_overlay(epochs.average(), exclude=ica.exclude)
    # Remove the excluded components from *all* epochs, bad ones included.
    ica.apply(epochs, exclude=ica.exclude)
    return ica

In [None]:
def baseline_correction(epochs, baseline):
    """Baseline-correct ``epochs`` in place.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to correct (after ICA in this pipeline).
    baseline : tuple of float
        ``(start, stop)`` window in seconds, e.g. ``(-0.1, 0)``.
    """
    epochs.apply_baseline(baseline)

In [None]:
def save_epochs(epochs, subj, output_dir='./analysis/'):
    """Save ``epochs`` as ``<output_dir>/<recording name>-epo.fif``, overwriting.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to write.
    subj : str
        Path of the source ``.vhdr`` file; its base name without extension
        (e.g. 'Revcor0006') names the output file.
    output_dir : str
        Destination directory (must exist).
    """
    # Derive the subject ID from the file name instead of fixed character
    # positions, which break whenever the directory prefix changes.
    subject_id = os.path.splitext(os.path.basename(subj))[0]
    epochs.save(os.path.join(output_dir, subject_id + '-epo.fif'), overwrite=True)

In [None]:
def create_evokeds(epochs, subj, output_dir='./analysis/'):
    """Average each condition, save the evokeds to disk, and return them by name.

    Parameters
    ----------
    epochs : mne.Epochs
        Cleaned epochs whose event IDs include '1001'-'1004' (as produced by
        ``get_events``).
    subj : str
        Path of the source ``.vhdr`` file; its base name without extension
        names the output file ``<name>-ave.fif``.
    output_dir : str
        Destination directory (must exist).

    Returns
    -------
    dict
        ``{'standard', 'neutral', 'rise', 'fall'}`` -> ``Evoked``, in that order.
    """
    # Condition label -> event ID used when selecting epochs.
    condition_codes = {'standard': '1001', 'neutral': '1002', 'rise': '1003', 'fall': '1004'}
    evokeds = {name: epochs[code].average() for name, code in condition_codes.items()}

    # Name the file after the recording (e.g. 'Revcor0006').  A fixed slice of
    # the path picked up the 'vh' of '.vhdr', so every subject overwrote the
    # same 'vh-ave.fif'.
    subject_id = os.path.splitext(os.path.basename(subj))[0]
    mne.write_evokeds(os.path.join(output_dir, subject_id + '-ave.fif'),
                      list(evokeds.values()), overwrite=True)

    return evokeds

In [None]:
def plot_erps(evokeds, subj, channels, output_dir='./analysis/'):
    """Write one condition-comparison ERP plot per channel into a per-subject PDF.

    Parameters
    ----------
    evokeds : dict
        Condition name -> Evoked, with keys 'standard', 'neutral', 'rise',
        'fall' (as returned by ``create_evokeds``).
    subj : str
        Path of the source ``.vhdr`` file; its base name names the PDF
        ``<name>-plots.pdf``.
    channels : iterable of str
        Channels to plot, one page each.
    output_dir : str
        Destination directory (must exist).
    """
    subject_id = os.path.splitext(os.path.basename(subj))[0]
    pdf_path = os.path.join(output_dir, subject_id + '-plots.pdf')

    with matplotlib.backends.backend_pdf.PdfPages(pdf_path) as pdf:
        for channel in channels:
            colors = dict(standard='black', neutral='red', rise='blue', fall='green')
            # Same thin line for every condition.
            styles = {condition: {'linewidth': 1} for condition in colors}
            figs = mne.viz.plot_compare_evokeds(
                evokeds, picks=channel, combine=None, time_unit='ms',
                ylim=dict(eeg=[-10, 10]), invert_y=True,
                colors=colors, styles=styles,
            )
            fig = figs[0]
            pdf.savefig(fig)
            # Close this specific figure (not merely "the current one") so
            # figures do not accumulate across channels and subjects.
            plt.close(fig)

In [None]:
# NOTE(review): data_dir is not referenced by any visible cell -- recordings
# are globbed from './EEG/' instead. Confirm before relying on it.
data_dir = './eeg_data/rise'
# New channel names, applied positionally by rename_channels() (the i-th
# recorded channel gets the i-th name).
# NOTE(review): this list ends with 'Cz', which add_ref_ch() later adds as the
# reference channel; that only works if the recording has 63 channels, so the
# final 'Cz' here is never assigned -- confirm.
channel_names = [
                    'Fp1','Fz','F3','F7','FT9','FC5','FC1','C3','T7','TP9','CP5','CP1','Pz','P3','P7','O1','Oz','O2','P4','P8','TP10','CP6',
                    'CP2','C4','T8','FT10','FC6','FC2','F4','F8','Fp2', 'AF7','AF3','AFz','F1','F5','FT7','FC3','C1','C5','TP7','CP3','P1','P5',
                    'PO7','PO3','POz','PO4','PO8','P6','P2','CPz','CP4','TP8','C6','C2','FC4','FT8','F6','AF8','AF4','F2','FCz', 'Cz'
                ]
montage = mne.channels.make_standard_montage('easycap-M1')
# Epoch window around each event, in seconds.
epoch_limits = [-0.1, 0.6]
# Baseline window in seconds, applied after ICA.
baseline = (-0.1, 0)
# Electrode re-added as an explicit channel before re-referencing
# (presumably the online recording reference).
ref_channel = 'Cz'
reref = 'average'
# Channels drawn into each subject's ERP PDF.
channels_to_vis = ['Fz', 'Pz', 'Oz', 'AFz', 'POz', 'CPz', 'FCz', 'Cz']
output_dir = './analysis/'
# Make sure the output folder exists: saving epochs, evokeds and PDFs into a
# missing directory fails on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

In [None]:
for subj in subjs_all:
    raw = get_raw(subj)

    rename_channels(raw, channel_names)

    make_montage(raw, montage)

    events = get_events(raw)
    create_stim_channel(raw, events)

    crop_data(raw, epoch_limits[0], epoch_limits[1])

    apply_filter(raw)

    add_ref_ch(raw, ref_channel)
    re_reference(raw, reref)

    # Re-apply the montage now that the reference channel has been added.
    make_montage(raw, montage)

    epochs = create_epochs(raw, events, epoch_limits)

    # First AutoReject pass: only its reject log is used, to choose which
    # epochs the ICA is fitted on.
    _, reject_log = run_autoreject(epochs)
    run_ica(epochs, reject_log, 'Fp1')
    baseline_correction(epochs, baseline)

    # Second pass on the ICA-cleaned, baseline-corrected epochs.
    # run_autoreject returns (epochs, reject_log); unpack it instead of
    # binding the whole tuple to epochs_ar.
    epochs_ar, reject_log = run_autoreject(epochs)

    save_epochs(epochs_ar, subj)

    evokeds = create_evokeds(epochs_ar, subj)

    # plot_erps requires the list of channels to draw.
    plot_erps(evokeds, subj, channels_to_vis)