In [None]:
import mne
import sys

import numpy as np
from scipy import stats
import sails

sys.path.append('..')

from _parameters import *
sys.path.remove('..')

from mne.preprocessing import compute_current_source_density
from IPython.display import clear_output

In [None]:
def bad_epochs(epochs, metric, alpha=0.05, p_out=0.1, outlier_side=1):
    """Detect and drop outlier epochs with the generalised ESD test.

    A summary statistic is computed over time for every good EEG channel of
    every epoch. It is then averaged across channels to give one value per
    epoch, and that series is passed to ``sails.utils.gesd``. Epochs it flags
    are dropped from ``epochs`` via ``epochs.drop``.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to clean; modified by ``epochs.drop``.
    metric : str
        ``"std"`` or ``"var"``. Any other value falls back to kurtosis
        (``scipy.stats.kurtosis``), as before.
    alpha, p_out, outlier_side
        Forwarded unchanged to ``sails.utils.gesd``. The defaults reproduce
        the previously hard-coded settings.

    Returns
    -------
    mne.Epochs
        The same ``epochs`` object, after dropping the flagged epochs.

    Raises
    ------
    ValueError
        If ``epochs`` has no good EEG channels to compute the metric on.
    """
    # Explicit names first; anything else keeps the historical kurtosis fallback.
    metric_funcs = {"std": np.std, "var": np.var}
    metric_func = metric_funcs.get(metric, stats.kurtosis)

    ch_index = mne.pick_types(epochs.info, eeg=True, exclude='bads')
    if len(ch_index) == 0:
        # Averaging over zero channels would yield NaNs and feed garbage to gesd.
        raise ValueError("bad_epochs: no good EEG channels to compute the rejection metric on")

    # Indexed as (epoch, channel, time) below: reduce time, then channels.
    data = epochs.get_data(picks=ch_index)
    per_channel = metric_func(data, axis=-1)
    per_epoch = np.mean(per_channel, axis=1)

    # The first gesd output is used as a per-epoch outlier mask (summed to count rejections).
    is_outlier, _ = sails.utils.gesd(per_epoch, alpha=alpha, p_out=p_out,
                                     outlier_side=outlier_side)
    print(f"From EEG - {np.sum(is_outlier)}/{per_epoch.shape[0]} epochs rejected")

    epochs.drop(is_outlier)

    return epochs


In [None]:
def get_epochs(s, moment, event_id, tmin, tmax, apply_ica=False):
    """Epoch one subject's raw recording, reject outlier epochs, apply CSD, and save.

    Parameters
    ----------
    s : int or str
        Subject identifier, formatted into the ``raw_s{s}.fif``,
        ``ica_s{s}.fif`` and ``epoch_{moment}_s{s}.fif`` file names.
    moment : str
        Key into ``triggers`` selecting which event codes to keep; also used
        in the output file name.
    event_id
        Passed straight to ``mne.Epochs``.
    tmin, tmax : float
        Epoch window relative to each event, passed to ``mne.Epochs``.
    apply_ica : bool, optional
        If True, read ``ica_s{s}.fif`` from ``dirs['ica']`` and apply it to
        the raw data before epoching. Defaults to False, which preserves the
        previous behaviour (the ICA file name was built but never used).

    Returns
    -------
    mne.Epochs
        The cleaned, CSD-transformed epochs, also saved under ``dirs['epoch']``.
    """
    raw_fname = f"{dirs['raw']}/raw_s{s}.fif"
    raw = mne.io.read_raw_fif(raw_fname, preload=True)

    # NOTE(review): ICA was never applied by the original code even though its
    # file name was computed. Confirm whether raw_s{s}.fif is already cleaned
    # before turning this on.
    if apply_ica:
        ica_fname = f"{dirs['ica']}/ica_s{s}.fif"
        ica = mne.preprocessing.read_ica(ica_fname)
        ica.apply(raw)  # modifies the preloaded raw in place

    # Keep only the events that belong to this experimental moment.
    events = mne.find_events(raw, stim_channel='Status')
    events = mne.pick_events(events, include=triggers[moment])

    # detrend=1 removes a linear trend per epoch; no baseline correction.
    epochs = mne.Epochs(raw, events, event_id,
                        tmin=tmin, tmax=tmax,
                        detrend=1, preload=True,
                        baseline=None)

    # Reject outlier epochs (GESD on the channel-averaged variance).
    epochs = bad_epochs(epochs, metric='var')

    # Surface Laplacian (current source density).
    # NOTE(review): MNE's CSD is expected to refuse EEG channels listed in
    # info['bads']; drop or interpolate them upstream if this raises.
    epochs = compute_current_source_density(epochs)

    epoch_fname = f"{dirs['epoch']}/epoch_{moment}_s{s}.fif"
    epochs.save(epoch_fname, overwrite=True)

    return epochs

In [None]:
# Process every subject in `subjects`: each call epochs, cleans and saves one dataset.
for s in subjects:
    print(f'Running subject {s}')
    epochs = get_epochs(s, moment='enc1', event_id=event_id, tmin=-1, tmax=4)
    # Clear the cell output so the log does not accumulate across subjects.
    clear_output(wait=False)

In [None]:
# Process a single subject; change `s` to choose which one.
s = 1
epochs = get_epochs(s, moment='enc1', event_id=event_id, tmin=-1, tmax=4)