In [None]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf
import glob
from autoreject import AutoReject
from collections import Counter

#### Read Raw Data

In [None]:
# Subject identifier; it selects the input recording and prefixes the output files
subj = 'Revcor0006'

# Load the subject's BrainVision header file into the 'raw' object
vhdr_path = f'./EEG/{subj}.vhdr'
raw = mne.io.read_raw_brainvision(vhdr_path, verbose=False)

# Leave 'raw' as the last expression so the notebook renders its summary
raw

#### Rename Channels

In [None]:
# Swap the numeric channel names stored in the file for electrode labels.
# The labels are listed in recording order and are paired positionally with
# the existing channel names.
channel_names_old = raw.ch_names
channel_names_new = [
    'Fp1', 'Fz', 'F3', 'F7', 'FT9', 'FC5', 'FC1', 'C3',
    'T7', 'TP9', 'CP5', 'CP1', 'Pz', 'P3', 'P7', 'O1',
    'Oz', 'O2', 'P4', 'P8', 'TP10', 'CP6', 'CP2', 'C4',
    'T8', 'FT10', 'FC6', 'FC2', 'F4', 'F8', 'Fp2', 'AF7',
    'AF3', 'AFz', 'F1', 'F5', 'FT7', 'FC3', 'C1', 'C5',
    'TP7', 'CP3', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4',
    'PO8', 'P6', 'P2', 'CPz', 'CP4', 'TP8', 'C6', 'C2',
    'FC4', 'FT8', 'F6', 'AF8', 'AF4', 'F2', 'FCz', 'Cz',
]

# zip stops at the shorter sequence, so surplus labels are simply not used
channel_dict = {old: new for old, new in zip(channel_names_old, channel_names_new)}
mne.rename_channels(raw.info, mapping=channel_dict)

#### Set Montage 

In [None]:
# Build the standard 'easycap-M1' montage and attach it to the recording
montage_name = 'easycap-M1'
easycap_montage = mne.channels.make_standard_montage(montage_name)
raw.set_montage(easycap_montage)

# Draw the sensor layout with channel labels
fig = raw.plot_sensors(show_names=True)

# Last expression: display the updated measurement info
raw.info

#### Get Events

In [None]:
# Convert the recording's annotations into an events array plus the
# description -> event id mapping
events_from_annot, event_dict = mne.events_from_annotations(raw)

# Tally how often each event id (third column) occurs
event_ids = events_from_annot[:, 2]
Counter(event_ids)

In [None]:
# Delete New Segment (99999) markers.
# Rows are selected by event id rather than by position: blindly deleting
# row 0 removed one more (real) marker every time this cell was re-run,
# whereas filtering on the id gives the same result on every run.
NEW_SEGMENT_ID = 99999
events_from_annot = events_from_annot[events_from_annot[:, 2] != NEW_SEGMENT_ID]

In [None]:
# Number of event rows currently held in events_from_annot
n_markers_total = len(events_from_annot)
print('Total number of markers: ', n_markers_total)

In [None]:
# Collapse runs of back-to-back S2 markers, keeping only the first of each run
s_event_indices = np.where(events_from_annot[:, 2] == 2)

# An S2 row located exactly one row after another S2 row is a repeat:
# np.diff(...) == 1 flags those pairs and the later row of each pair is dropped
s2_rows = s_event_indices[0]
indices_to_remove = s2_rows[1:][np.diff(s2_rows) == 1].tolist()

events = np.delete(events_from_annot, indices_to_remove, 0)

In [None]:
# Report marker counts around the repeated-S2 clean-up.
# s_event_indices was computed on events_from_annot, i.e. before the repeated
# S2 rows were removed, so this first count still includes the repeats.
print('Number of S2 markers: ', len(s_event_indices[0]))
# 'events' is the array obtained after deleting the repeated S2 rows.
print('Number of total markers: ', len(events))

In [None]:
# Skip the first event WITHOUT overwriting 'events': reassigning 'events' here
# made the cell strip one additional row on every re-run.
events_trimmed = events[1:]

# Split the events after every S2 marker, so that each block ends with its S2
# row (the original notebook expects 3000 blocks of 6 rows each; this is not
# checked here).
s2_rows = np.flatnonzero(events_trimmed[:, 2] == 2)
marker_blocks = np.split(events_trimmed, s2_rows + 1)

# Whatever follows the last S2 (possibly nothing) does not end with an S2
del marker_blocks[-1]

# Show only the first few blocks instead of dumping every block into the output
print(marker_blocks[:3])
print('Number of marker blocks: ', len(marker_blocks))

In [None]:
# Lookup table from event code to the text fragment it contributes to a marker id.
# Code 1015 stands for the digit '0'; codes 1001-1009 stand for the digits '1'-'9'.
marker_dict = {1015: '0'}
marker_dict.update({1000 + digit: str(digit) for digit in range(1, 10)})

# Non-digit codes and their labels
marker_dict.update({1012: 'std_', 1013: 'dev_'})

In [None]:
# Event codes that contribute no digit to the marker id and are dropped from
# every marker block
NON_DIGIT_CODES = (1011, 1012, 1013)

markers = []

for mb in marker_blocks:
    # Remove 1011, 1012 and 1013 markers from the block. Boolean indexing
    # returns a copy, so the arrays stored in marker_blocks are left untouched.
    mb = mb[~np.isin(mb[:, 2], NON_DIGIT_CODES)]

    # Every row except the last one (the S2 marker) encodes one id digit
    digit_codes = mb[:-1, 2]
    unknown_codes = [int(code) for code in digit_codes if code not in marker_dict]
    if unknown_codes:
        # Previously this surfaced as an opaque "str + NoneType" TypeError
        raise ValueError(
            f'Unexpected event code(s) {unknown_codes} in marker block '
            f'ending at sample {int(mb[-1, 0])}'
        )

    marker_id_str = ''.join(marker_dict[code] for code in digit_codes)
    if not marker_id_str:
        raise ValueError(
            f'Marker block ending at sample {int(mb[-1, 0])} has no id digits'
        )

    # Replace the S2 event id (2) with the decoded marker id. The conversion
    # to int is explicit instead of relying on numpy to parse a string
    # assigned into an integer array.
    mb[-1, 2] = int(marker_id_str)
    markers.append(mb[-1])

In [None]:
# One marker row is kept per marker block
n_markers_clean = len(markers)
print('Number of total markers after cleaning: ', n_markers_clean)

#### Create Events Channel in Raw Data

In [None]:
# Channel data must be in memory before channels or events can be added
raw.load_data()

# Create the stimulus channel only once: adding a second 'STI' channel on a
# re-run of this cell fails because the channel name already exists.
if 'STI' not in raw.ch_names:
    # Create event channel (all zeros, one value per sample)
    stim_data = np.zeros((1, len(raw.times)))

    # Add stimulus channel in 'raw' object's info class
    info = mne.create_info(['STI'], raw.info['sfreq'], ['stim'])
    stim_raw = mne.io.RawArray(stim_data, info)
    raw.add_channels([stim_raw], force_update_info=True)

# Add the cleaned markers to the stimulus channel. replace=True resets the
# channel first, so re-running the cell does not add the event codes on top
# of the ones written by a previous run.
raw.add_events(np.asarray(markers), stim_channel='STI', replace=True)
print(raw.info)

#### Crop Raw Data to Remove Useless Parts

In [None]:
# Get event times (sample indices of the first and last event on the stim
# channel); find_events is called once and its result reused
stim_events = mne.find_events(raw)
first_event_time = stim_events[0][0]
last_event_time = stim_events[-1][0]
print(first_event_time, ' ', last_event_time)

# Remove useless parts of raw EEG
epoch_limits = [-0.1, 0.6]

# Convert sample indices to seconds with the recording's own sampling rate
# (previously a hard-coded 1000 Hz). find_events reports samples including
# raw.first_samp, while raw.crop expects times relative to the start of 'raw'.
sfreq = raw.info['sfreq']
first_event_sec = (first_event_time - raw.first_samp) / sfreq
last_event_sec = (last_event_time - raw.first_samp) / sfreq

# Keep 5x the epoch limit on each side of the events. The original used a
# factor of 500 (i.e. 0.5x at 1000 Hz) before the first event but 5000 (5x)
# after the last one; 0.5x leaves only 50 ms before the first event, which is
# shorter than the 100 ms pre-stimulus window of the epochs.
# NOTE(review): 500 is assumed to be a typo for 5000 - confirm.
pad_factor = 5
part_to_remove_from_beginning = max(
    0.0, first_event_sec - pad_factor * abs(epoch_limits[0])
)
# Never ask for data beyond the end of the recording (raw.crop would raise)
part_to_remove_from_end = min(
    raw.times[-1], last_event_sec + pad_factor * abs(epoch_limits[1])
)
raw.crop(part_to_remove_from_beginning, part_to_remove_from_end)

#### Filtering

In [None]:
# Design the filters for the recording's actual sampling rate instead of a
# hard-coded 1000 Hz, which would silently yield wrong cut-offs for data
# recorded at any other rate.
sfreq = raw.info['sfreq']
band_edges = [0.1, 30]

# Soft bandpass Butterworth filter (order 2, second-order sections)
iir_params = dict(order=2, ftype='butter', output='sos')
iir_params = mne.filter.construct_iir_filter(
    iir_params,
    f_pass=band_edges,
    f_stop=None,
    sfreq=sfreq,
    btype='bandpass',
    return_copy=False,
)
raw.filter(band_edges[0], band_edges[1], method='iir', iir_params=iir_params)

# Notch filter at 50 Hz and its harmonics up to 250 Hz; only frequencies
# below the Nyquist frequency are kept (all of them at 1000 Hz)
notch_freqs = np.arange(50, 251, 50)
notch_freqs = notch_freqs[notch_freqs < sfreq / 2]
raw.notch_filter(freqs=notch_freqs, method='fir', fir_design='firwin2')

#### Add Reference Channel (Cz)

In [None]:
# Append 'Cz' as a reference channel, modifying 'raw' in place (copy=False)
mne.add_reference_channels(raw, ref_channels='Cz', copy=False)

#### Re-reference electrodes

In [None]:
# Add the average-reference projector to 'raw' itself. copy=False is
# required: mne.set_eeg_reference works on a copy by default, and because its
# return value was discarded the projector never reached 'raw', leaving
# apply_proj() below with nothing to apply.
mne.set_eeg_reference(raw, ref_channels='average', projection=True, copy=False)

# Apply the projector to the data now
raw.apply_proj()

#### Reset Montage after adding Cz

In [None]:
# Re-apply the montage now that 'Cz' is part of the channel list
raw.set_montage(montage=easycap_montage)

#### Epoching

In [None]:
# Cut epochs from -0.1 s to 0.6 s around every cleaned marker.
# baseline=None: baseline correction is deliberately postponed until after ICA.
epoch_window = dict(tmin=-0.1, tmax=0.6)
epochs = mne.Epochs(raw, markers, baseline=None, preload=True, **epoch_window)

#### Autoreject on All Epochs (before ICA)

In [None]:
# Fit autoreject on all epochs; the reject log produced here is what the ICA
# step uses to leave bad epochs out of the decomposition
interpolation_candidates = [1, 2, 3, 4]
ar = AutoReject(
    n_interpolate=interpolation_candidates,
    random_state=11,
    n_jobs=1,
    verbose=True,
)
ar.fit(epochs)
epochs_ar, reject_log = ar.transform(epochs, return_log=True)

#### ICA

In [None]:
# Epochs that autoreject did not flag as bad; used for fitting and EOG detection
clean_epochs = epochs[~reject_log.bad_epochs]

ica = mne.preprocessing.ICA(random_state=99, verbose=False)
ica.fit(clean_epochs)

# Find which ICs match the EOG pattern, using Fp1 as the EOG proxy channel
eog_indices, eog_scores = ica.find_bads_eog(clean_epochs, ch_name='Fp1')
print(f'**************** Automatically found EOG artifact ICA components: {eog_indices} ****************')

# Optional (disabled): find which ICs match the EMG pattern
# muscle_idx_auto, scores = ica.find_bads_muscle(clean_epochs)
# print(f'**************** Automatically found muscle artifact ICA components: {muscle_idx_auto} ****************')

# Mark the EOG components for removal
ica.exclude = eog_indices

# Compare the average before/after removal, then clean ALL epochs in place
ica.plot_overlay(epochs.average(), exclude=ica.exclude)
ica.apply(epochs, exclude=ica.exclude)

#### Baseline Correction

In [None]:
# Baseline-correct with the 100 ms pre-stimulus interval, now that ICA is done
baseline_window = (-0.1, 0)
epochs.apply_baseline(baseline=baseline_window, verbose=False)

#### Autoreject on All Epochs (after ICA and baseline correction)

In [None]:
# Run autoreject again, this time on the ICA-cleaned, baseline-corrected epochs.
# This overwrites 'ar', 'epochs_ar' and 'reject_log' from the earlier pass.
interpolation_candidates = [1, 2, 3, 4]
ar = AutoReject(
    n_interpolate=interpolation_candidates,
    random_state=11,
    n_jobs=1,
    verbose=True,
)
ar.fit(epochs)
epochs_ar, reject_log = ar.transform(epochs, return_log=True)

#### Save epochs to file

In [None]:
# Write the cleaned epochs to disk, replacing any earlier file for this subject
epochs_path = f'./analysis/{subj}-epo.fif'
epochs_ar.save(epochs_path, overwrite=True)

#### Create Evoked objects and save to file

In [None]:
# Event id (as used to index the epochs) for each experimental condition
condition_event_ids = {
    'standard': '1001',
    'neutral': '1002',
    'rise': '1003',
    'fall': '1004',
}

# Average the epochs of each condition into an Evoked object
evokeds = {
    condition: epochs_ar[event_id].average()
    for condition, event_id in condition_event_ids.items()
}

# Individual names kept for convenience
evoked_standard = evokeds['standard']
evoked_neutral = evokeds['neutral']
evoked_rise = evokeds['rise']
evoked_fall = evokeds['fall']

# Save the four averages together, in the order standard, neutral, rise, fall
mne.write_evokeds('./analysis/'+subj+'-ave.fif', list(evokeds.values()), overwrite=True)

#### Plot ERPs

In [None]:
def plot_channel_by_condition(channels=()):
    """Plot the per-condition ERPs of each channel into one PDF file.

    For every entry of ``channels`` the four conditions held in the
    notebook-level ``evokeds`` mapping are overlaid in one figure, which is
    appended to ``./analysis/<subj>-plots.pdf`` (``subj`` is the notebook-level
    subject identifier).

    Parameters
    ----------
    channels : sequence of str
        Channel names to plot, one figure per channel. Defaults to an empty
        sequence (an immutable default instead of a shared mutable list).
    """
    line_style = {'linewidth': 1}

    # Create PDF file in which to save all plots
    with matplotlib.backends.backend_pdf.PdfPages('./analysis/'+subj+'-plots.pdf') as pdf:

        for channel in channels:
            figures = mne.viz.plot_compare_evokeds(
                evokeds,
                picks=channel,
                combine=None,
                time_unit='ms',
                ylim=dict(eeg=[-10, 10]),
                invert_y=True,
                colors=dict(standard='black', neutral='red', rise='blue', fall='green'),
                styles={
                    'standard': dict(line_style),
                    'neutral': dict(line_style),
                    'rise': dict(line_style),
                    'fall': dict(line_style),
                },
            )

            # plot_compare_evokeds returns a list of figures: save each one and
            # close it explicitly, rather than saving only the first and
            # closing whichever figure happens to be current.
            for figure in figures:
                pdf.savefig(figure)
                plt.close(figure)

# Channels whose ERPs are written to the PDF, one page per channel
channels_to_plot = ['Fz', 'Pz', 'Oz', 'AFz', 'POz', 'CPz', 'FCz', 'Cz']
plot_channel_by_condition(channels=channels_to_plot)