In [1]:
import mne
import numpy as np
import pandas as pd
import autoreject
import os


In [2]:
def clean_montage(montage_path, scale=0.095):
    '''
    Build an MNE digitised montage from a whitespace/tab separated electrode table.

    montage_path: path to the electrodes file; it must have a header row with
                  at least the columns 'name', 'x', 'y' and 'z'. Any additional
                  columns are ignored.
    scale: factor applied to every coordinate (a value of 1 means the
           coordinates are left as they are in the file). The default is based
           on head radius.

    Returns the montage created by mne.channels.make_dig_montage in the 'head'
    coordinate frame. Missing coordinates are replaced by 0 and a 'Resp'
    channel located at the origin is always added.
    '''
    # Raw string so that '\s' reaches the regex engine untouched instead of being
    # treated as an (invalid) string escape sequence.
    table = pd.read_csv(montage_path, sep=r'\s+|\t+', header=0, engine='python')

    # Keep only the coordinate columns: extra columns would otherwise break both
    # the 'Resp' row assignment and the name -> [x, y, z] mapping below.
    coords = table.set_index('name', drop=True)[['x', 'y', 'z']].copy()
    coords.loc['Resp'] = [0, 0, 0]

    # Channels without a recorded position are placed at the origin, then
    # everything is rescaled.
    coords = coords.fillna(0) * scale

    # Mapping from channel name to its [x, y, z] position.
    ch_pos = dict(zip(coords.index, coords.to_numpy().tolist()))
    return mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')


In [7]:
def preprocess(raw_data_path, montage_path):
    '''
    utility function
    Runs the full cleaning pipeline on one recording and returns the cleaned epochs.

    raw_data_path: path to the subject's EEGLAB recording (read with mne.io.read_raw_eeglab)
    montage_path: path to the subject's electrodes table (read with clean_montage)

    Steps: load -> set montage -> drop unwanted channels -> average reference ->
    resample to 256 Hz -> 0.1 Hz high-pass -> epoch around every annotation
    (-1 s to 2.5 s) -> ICA -> autoreject -> interpolate bad channels.
    '''
    rd = mne.io.read_raw_eeglab(raw_data_path, preload=True, verbose=False) #load data
    acti_cap_mon = clean_montage(montage_path) #load montage
    rd.set_montage(acti_cap_mon, on_missing='ignore') #set montage

    chans_to_remove = ['FT9', 'FT10', 'TP9', 'TP10']
    rd.drop_channels([chan for chan in chans_to_remove if chan in rd.ch_names]) #remove unreliable channels
    rd.drop_channels([chan for chan in rd.ch_names if chan not in acti_cap_mon.ch_names]) #remove channels not in montage
    rd.set_eeg_reference(ref_channels='average') #rereference to average
    new_sampling_freq = 256 #new sampling frequency
    rd.resample(new_sampling_freq) #resample

    # The data is preloaded, so filter in place instead of filtering a full copy
    # of the recording (the unfiltered version is never used again).
    rd.filter(l_freq=0.1, h_freq=None) #highpass filter

    # Every annotation type becomes an event id,
    # e.g. {'S  1': 1, 'S  2': 2, 'S  3': 3, 'boundary': 4}
    events, event_dict = mne.events_from_annotations(rd)
    epochs = mne.Epochs(rd, events, tmin=-1, tmax=2.5, event_id=event_dict, preload=True) #epoching

    # Exclude Cz from ICA by marking it bad. Only do so when the channel survived
    # the drops above; marking a channel that is not in the data is an error.
    held_out = 'Cz'
    epochs.info['bads'] = [held_out] if held_out in epochs.ch_names else []
    ica = mne.preprocessing.ICA(n_components=15, random_state=50, max_iter=800) #perform ICA
    ica.fit(epochs) #fit ICA
    ica.apply(epochs) #apply ICA

    del rd #delete raw data to save memory; we only need epochs now

    ar = autoreject.AutoReject(n_interpolate=[1, 2, 3, 4], random_state=11, n_jobs=1, verbose=True) #perform autoreject to remove bad epochs
    # Thresholds are learned from the first 10 epochs only, then applied to all
    # epochs. NOTE(review): presumably done to bound run time - confirm.
    ar.fit(epochs[:10])
    epochs_arr, reject_log = ar.transform(epochs, return_log=True)

    # NOTE(review): interpolate_bads() is called with its defaults, so any channel
    # still listed in info['bads'] (including Cz) is a candidate for interpolation
    # and the bads list may already be empty afterwards - confirm this is intended.
    epochs_arr.interpolate_bads()

    # Make sure Cz is not left flagged as bad. Filtering the list replaces the
    # former bare try/except, which silently swallowed every kind of error.
    epochs_arr.info['bads'] = [chan for chan in epochs_arr.info['bads'] if chan != held_out]

    return epochs_arr

In [4]:
def get_subject_folders(eeg_path):
    '''
    utility function
    Lists the immediate sub-directories of eeg_path (no recursion), each one
    joined onto the walked root. Returns an empty list when nothing can be walked.
    '''
    # Only the first triple produced by os.walk is needed: it describes the top level.
    top_level = next(os.walk(eeg_path), None)
    if top_level is None:
        return []
    parent, child_dirs, _ = top_level
    return [os.path.join(parent, child) for child in child_dirs]

In [5]:
def batch_preprocess(input_data_path, output_data_path, overwrite=False):
    '''
    Preprocess every subject folder found directly under input_data_path and
    save the resulting epochs.

    input_data_path: path to the folder containing the subject folders.
    output_data_path: path to the folder where the preprocessed data will be saved.
    overwrite: passed on to the epochs' save(); set to True to replace an
               existing '-epo.fif' file from an earlier run.

    For a subject folder named <id>, the inputs are expected at
    <id>/eeg/<id>_task-Oddball_eeg.set and <id>/eeg/<id>_task-Oddball_electrodes.tsv,
    and the result is written to <output_data_path>/<id>/<id>-epo.fif.
    Folders that do not contain both input files are skipped with a message.
    '''
    os.makedirs(output_data_path, exist_ok=True)

    for subject in get_subject_folders(input_data_path):
        # The folder name is the file prefix. Taking the basename (rather than
        # the last 7 characters) works for identifiers of any length.
        subject_id = os.path.basename(os.path.normpath(subject))
        eeg_dir = os.path.join(subject, 'eeg')
        raw_data_path = os.path.join(eeg_dir, subject_id + '_task-Oddball_eeg.set')
        montage_path = os.path.join(eeg_dir, subject_id + '_task-Oddball_electrodes.tsv')

        # Top-level folders that are not subjects should not abort the whole batch.
        if not (os.path.isfile(raw_data_path) and os.path.isfile(montage_path)):
            print("\nskipping " + subject + ": recording or electrodes file not found")
            continue

        print("\npreproccessing subject: " + subject + "...")
        out = preprocess(raw_data_path, montage_path)
        out_path = os.path.join(output_data_path, subject_id)
        # exist_ok: a re-run must not fail here after the expensive preprocessing.
        os.makedirs(out_path, exist_ok=True)
        out.save(os.path.join(out_path, subject_id + "-epo.fif"), overwrite=overwrite)
        del out  # release the epochs before loading the next subject
        print("preproccessed subject: " + subject + "!\n")
