In [None]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf
import glob
from autoreject import AutoReject
from collections import Counter
import pandas as pd

#### Read Raw Data

In [None]:
# Subject identifier: selects the input .vhdr file and names every output file below.
subj = 'looming002'
vhdr_path = f'./EEG/{subj}.vhdr'

# Load the BrainVision recording lazily (data is loaded later, before adding events).
raw = mne.io.read_raw_brainvision(vhdr_path, verbose=False)
raw

#### Rename Channels

In [None]:
# Map the numeric channel labels of the recording onto 10-10 electrode names, by position.
# NOTE(review): zip() stops at the shorter list, so if the recording has fewer channels
# than the 64 names below, the trailing names (e.g. 'Cz') are silently unused. 'Cz' is
# added later as a reference channel, which would likely clash if it were assigned here;
# confirm the recording's channel count.
channel_names_old = raw.ch_names
channel_names_new = [
    'Fp1', 'Fz', 'F3', 'F7', 'FT9', 'FC5', 'FC1', 'C3',
    'T7', 'TP9', 'CP5', 'CP1', 'Pz', 'P3', 'P7', 'O1',
    'Oz', 'O2', 'P4', 'P8', 'TP10', 'CP6', 'CP2', 'C4',
    'T8', 'FT10', 'FC6', 'FC2', 'F4', 'F8', 'Fp2', 'AF7',
    'AF3', 'AFz', 'F1', 'F5', 'FT7', 'FC3', 'C1', 'C5',
    'TP7', 'CP3', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4',
    'PO8', 'P6', 'P2', 'CPz', 'CP4', 'TP8', 'C6', 'C2',
    'FC4', 'FT8', 'F6', 'AF8', 'AF4', 'F2', 'FCz', 'Cz',
]
channel_dict = {old_name: new_name for old_name, new_name in zip(channel_names_old, channel_names_new)}
mne.rename_channels(raw.info, mapping=channel_dict)

#### Set Montage 

In [None]:
# Attach the standard easycap-M1 electrode positions to the renamed channels.
easycap_montage = mne.channels.make_standard_montage('easycap-M1')
raw.set_montage(easycap_montage)

# Visual sanity check: sensor layout with channel names.
fig = raw.plot_sensors(show_names=True)

raw.info

#### Get Events

In [None]:
# Function to fix events
def fix_events(events, s_code=2, r_codes=(1001, 1002, 1003), drop_codes=(99999, 1011)):
    """Clean an MNE events array and locate its S (stimulus) and R markers.

    Cleaning steps, in order:
      1. Drop every event whose code is in ``drop_codes`` (by default the
         R11 markers, code 1011, and 'New Segment' markers, code 99999).
      2. If the first remaining event is an S marker, drop it.
      3. Collapse each run of consecutive S markers, keeping only the first.

    Parameters
    ----------
    events : np.ndarray, shape (n_events, 3)
        Events array as returned by ``mne.events_from_annotations``; the third
        column holds the event code. The input array is not modified.
    s_code : int
        Event code of the S marker. Default 2.
    r_codes : sequence of int
        Event codes treated as R markers. Default (1001, 1002, 1003).
    drop_codes : sequence of int
        Event codes removed before any other step. Default (99999, 1011).

    Returns
    -------
    events : np.ndarray
        The cleaned events array (a new array).
    r_events : np.ndarray
        The rows of the cleaned ``events`` whose code is in ``r_codes``, in order.
        Row ``i`` corresponds exactly to ``r_marker_indices[i]``.
    r_marker_indices : list of int
        Row indices of the R markers in the cleaned ``events``.
    s_marker_indices : list of int
        Row indices of the S markers in the cleaned ``events``.
    """
    events = np.asarray(events)

    # 1. Remove marker types that carry no trial information.
    events = events[~np.isin(events[:, 2], drop_codes)]

    # 2. A leading S marker has no preceding R marker to take a label from.
    #    Guard against an empty array (everything filtered out above).
    if len(events) and events[0, 2] == s_code:
        events = events[1:]

    # 3. Keep only the first S marker of each run of consecutive S markers.
    codes = events[:, 2]
    repeated_s = np.zeros(len(codes), dtype=bool)
    repeated_s[1:] = (codes[1:] == s_code) & (codes[:-1] == s_code)
    events = events[~repeated_s]

    codes = events[:, 2]
    r_mask = np.isin(codes, r_codes)
    s_mask = codes == s_code
    r_marker_indices = np.flatnonzero(r_mask).tolist()
    s_marker_indices = np.flatnonzero(s_mask).tolist()

    # Build r_events from the same mask as r_marker_indices so the two can never
    # disagree (callers write r_events[i] back to events[r_marker_indices[i]]).
    r_events = events[r_mask]

    return events, r_events, r_marker_indices, s_marker_indices

In [None]:
# Turn the BrainVision annotations into an MNE events array, then clean it and
# split out the positions of the R and S markers.
events_from_annot, event_dict = mne.events_from_annotations(raw)
cleaned_events = fix_events(events_from_annot)
events, r_events, r_marker_indices, s_marker_indices = cleaned_events

In [None]:
# Relabel flat-deviant trials (stim_marker_code == 4 in the behavioural log) as R4 = 1004.
# NOTE(review): the CSV filename is hard-coded for subject 2 and is not derived from
# `subj`; keep the two in sync when switching subjects.
df = pd.read_csv('./new_csv/new_csv/results_subj2_condition_looming_220404_11.06.csv')
print(df['stim_marker_code'].value_counts())

# Log rows are matched to R markers purely by position, so the trial counts must
# agree. Otherwise the 1004 labels would silently land on the wrong trials.
assert len(df) == len(r_events), (
    f'{len(df)} trials in the CSV but {len(r_events)} R markers in the EEG'
)

r4_indices = np.where(df['stim_marker_code'] == 4)
r_events[r4_indices[0], 2] = 1004

print(Counter(r_events[:, 2]))
print('Number of total markers after cleaning: ', len(r_events))

In [None]:
# Copy the (possibly relabelled, e.g. R4 = 1004) R markers back into the full events
# array at their original row positions. Check the counts first, so a mismatch fails
# before `events` is partially modified.
assert len(r_events) == len(r_marker_indices), (
    f'{len(r_events)} R events but {len(r_marker_indices)} R marker positions'
)
events[r_marker_indices] = r_events

In [None]:
# Give each S marker (code 2) the code of the non-S event directly before it, then
# drop the R markers, so only the labelled S markers remain as events.
# The scan runs left to right and reads the previous row after any update to it,
# so a relabelled S marker can itself pass its label on (cascading in place).
for row in range(1, len(events)):
    previous_code = events[row - 1, 2]
    if previous_code != 2 and events[row, 2] == 2:
        events[row, 2] = previous_code

events = np.delete(events, r_marker_indices, 0)
print(len(events))

In [None]:
# Number of remaining events per event code.
event_code_counts = Counter(events[:, 2])
print(event_code_counts)

#### Create Events Channel in Raw Data

In [None]:
# The event data is written into the recording itself, so load it into memory first.
raw.load_data()

# Build an all-zero 'STI' stimulus channel matching raw's length and sampling rate,
# then append it to `raw`.
stim_info = mne.create_info(['STI'], raw.info['sfreq'], ['stim'])
stim_raw = mne.io.RawArray(np.zeros((1, len(raw.times))), stim_info)
raw.add_channels([stim_raw], force_update_info=True)

# Write the cleaned events into the new stimulus channel.
raw.add_events(events, stim_channel='STI')
print(raw.info)

#### Filtering

In [None]:
# Soft band-pass Butterworth filter (0.1-30 Hz).
# construct_iir_filter designs coefficients for the sampling rate it is given, and the
# precomputed 'sos' in iir_params is what raw.filter() then applies. So the design must
# use the recording's real rate rather than a hard-coded value.
l_freq, h_freq = 0.1, 30
iir_params = dict(order=2, ftype='butter', output='sos')
iir_params = mne.filter.construct_iir_filter(
    iir_params,
    f_pass=[l_freq, h_freq],
    f_stop=None,
    sfreq=raw.info['sfreq'],
    btype='bandpass',
    return_copy=False,
)
raw.filter(l_freq, h_freq, method='iir', iir_params=iir_params)

# Notch filter at 50 Hz and its harmonics up to 250 Hz (requires sfreq > 500 Hz).
# NOTE(review): after the 30 Hz low-pass these frequencies are already strongly
# attenuated, so this step is probably redundant; kept as in the original pipeline.
raw.notch_filter(
    freqs=np.arange(50, 251, 50),
    method='fir',
    fir_design='firwin2',
)

#### Add Reference Channel (Cz)

In [None]:
# Re-insert Cz (presumably the online recording reference, absent from the data) as an
# all-zero channel, modifying `raw` directly so it can join the average reference.
mne.add_reference_channels(inst=raw, ref_channels=['Cz'], copy=False)

#### Re-reference Electrodes to the Common Average

In [None]:
# Add an average-reference projector to `raw` itself, then apply it.
# copy=False is essential: mne.set_eeg_reference defaults to copy=True. With that
# default, the projector is added to a discarded copy and apply_proj() has nothing to apply.
mne.set_eeg_reference(raw, ref_channels='average', projection=True, copy=False)
raw.apply_proj()

#### Reset Montage after adding Cz

In [None]:
# The re-added Cz channel has no position yet; re-applying the montage assigns one.
raw.set_montage(montage=easycap_montage)

#### Epoching

In [None]:
# Cut epochs from -100 ms to +700 ms around each event. Neither baseline correction
# nor rejection is applied here: baseline correction must wait until after ICA.
epochs = mne.Epochs(
    raw,
    events,
    tmin=-0.1,
    tmax=0.7,
    baseline=None,
    reject=None,
    preload=True,
)

#### Global Rejection Threshold (autoreject) on All Epochs

In [None]:
# Estimate global peak-to-peak rejection thresholds with autoreject.
# The threshold search is randomised, so seed it for reproducible re-runs.
# (A local AutoReject fit/transform is the alternative; it is currently disabled.)
from autoreject import get_rejection_threshold

reject = get_rejection_threshold(epochs, random_state=4)
print(reject)

# Apply the thresholds to a copy of the epochs already extracted above instead of
# re-epoching `raw`. This keeps the epoch window defined in a single place.
epochs_clean = epochs.copy().drop_bad(reject=reject)

print(f'Number of epochs before rejection: {len(epochs)}\n Number of epochs after rejection: {len(epochs_clean)}')

#### ICA

In [None]:
# ICA on the rejection-cleaned epochs; the fixed random_state keeps the
# decomposition reproducible between runs.
ica = mne.preprocessing.ICA(random_state=99, verbose=False)
ica.fit(epochs_clean)

# No EOG channel was recorded, so Fp1 (the channel with the strongest eye-blink
# artifacts) serves as the EOG proxy for scoring components.
eog_indices, eog_scores = ica.find_bads_eog(epochs_clean, ch_name='Fp1')
print(f'**************** Automatically found EOG artifact ICA components: {eog_indices} ****************')

# Muscle-artifact detection (ica.find_bads_muscle) is intentionally skipped because
# of a known bug in find_bads_muscle().

# Exclude the EOG-like components, show the before/after overlay on the average
# response, then remove them from epochs_clean in place.
ica.exclude = eog_indices
evoked_for_overlay = epochs_clean.average()
ica.plot_overlay(evoked_for_overlay, exclude=ica.exclude)
ica.apply(epochs_clean, exclude=ica.exclude)

#### Baseline Correction

In [None]:
# Baseline-correct with the whole pre-stimulus interval (-100 ms to 0 ms), now that ICA is done.
epochs_clean.apply_baseline((-0.1, 0), verbose=False)

#### Local AutoReject on All Epochs (currently disabled)

In [None]:
# ar = AutoReject(n_interpolate=[1, 2, 3, 32], random_state=11, n_jobs=3, verbose=True)
# ar.fit(epochs) 
# epochs_ar, reject_log = ar.transform(epochs, return_log=True)

#### Save epochs to file

In [None]:
# Make sure the output directory exists (it is not created anywhere else), then save
# the cleaned epochs. Later cells write the evoked file and the plot PDF there too.
os.makedirs('./analysis_global_autoreject', exist_ok=True)
epochs_clean.save(f'./analysis_global_autoreject/{subj}-epo.fif', overwrite=True)

#### Create Evoked objects and save to file

In [None]:
# Average the cleaned epochs of each condition into one Evoked object per condition.
condition_event_ids = {
    'standard': '1001',
    'looming': '1002',
    'receding': '1003',
    'deviant': '1004',
}
evokeds = {
    condition: epochs_clean[event_id].average()
    for condition, event_id in condition_event_ids.items()
}
evoked_standard, evoked_looming, evoked_receding, evoked_deviant = evokeds.values()

# Save the four averages (standard, looming, receding, deviant) to a single file.
mne.write_evokeds(f'./analysis_global_autoreject/{subj}-ave.fif', list(evokeds.values()), overwrite=True)

#### Plot ERPs

In [None]:
def plot_channel_by_condition(channels=None, evokeds_by_condition=None, pdf_path=None):
    """Plot each channel's ERP for all conditions and save one PDF page per channel.

    Parameters
    ----------
    channels : list of str, optional
        Channel names to plot, one figure each. Defaults to no channels.
    evokeds_by_condition : dict, optional
        Mapping condition name -> Evoked. Defaults to the notebook-level ``evokeds``.
        Its keys should be 'standard', 'looming', 'receding' and 'deviant', which the
        colour and line-style settings below refer to.
    pdf_path : str, optional
        Output PDF path. Defaults to './analysis_global_autoreject/<subj>-plots.pdf'.
    """
    # None defaults avoid a shared mutable default and keep the old call working.
    if channels is None:
        channels = []
    if evokeds_by_condition is None:
        evokeds_by_condition = evokeds
    if pdf_path is None:
        pdf_path = './analysis_global_autoreject/' + subj + '-plots.pdf'

    colors = dict(standard='black', looming='blue', receding='green', deviant='red')
    styles = {condition: {'linewidth': 1} for condition in colors}

    with matplotlib.backends.backend_pdf.PdfPages(pdf_path) as pdf:
        for channel in channels:
            # invert_y=True plots negative voltages upward.
            figs = mne.viz.plot_compare_evokeds(
                evokeds_by_condition, picks=channel, combine=None, time_unit='ms',
                ylim=dict(eeg=[-10, 10]), invert_y=True, colors=colors, styles=styles,
            )
            # Save the first returned figure (presumably the only one for a single
            # channel), then close every returned figure explicitly, rather than relying
            # on plt.close() hitting the right "current" figure.
            pdf.savefig(figs[0])
            for fig in figs:
                plt.close(fig)

# Plot channels Fz, Pz, Oz, AFz, POz, CPz, FCz, Cz
plot_channel_by_condition(channels=['Fz', 'Pz', 'Oz', 'AFz', 'POz', 'CPz', 'FCz', 'Cz'])