In [None]:
import sys
sys.path.insert(0, '..')
from mne_bids import (BIDSPath,read_raw_bids)
import mne_bids
import importlib
import mne
import ccs_eeg_utils

# Download once if the data is not yet available locally:
#ccs_eeg_utils.download_erpcore(task="P3",subject=30,localpath="../local/bids/")

import os

# Root of the ERP CORE BIDS dataset. Override via the ERPCORE_BIDS_ROOT
# environment variable instead of editing the notebook. Other known locations:
#   "../local/bids"                                (local download, see above)
#   "/bigpool/export/users/ehinger/erp-core/bids"  (Bene's server)
bids_root = os.environ.get("ERPCORE_BIDS_ROOT", "/store/data/erp-core/")
subject_id = '030'

bids_path = BIDSPath(subject=subject_id, task="P3",
                     datatype='eeg', suffix='eeg', session="P3",
                     root=bids_root)

raw = read_raw_bids(bids_path)
# Project helper from ccs_eeg_utils; presumably attaches the ERP CORE event
# annotations to `raw` (used later by events_from_annotations) -- see its source.
ccs_eeg_utils.read_annotations_core(bids_path, raw)
# Filtering operates on in-memory data, so load it first.
raw.load_data()
# Band-pass 0.5-50 Hz with a FIR filter.
raw.filter(0.5, 50, fir_design='firwin')

**T:** Go through the dataset using the MNE data browser and clean it. You can use `raw.plot()` for this. If you are working in a Jupyter notebook, use `%matplotlib qt` for better support of the interactive cleaning window. To understand how the tool works, press `help` or type `?` in the window. (Hint: you first have to press `a` to start adding annotations.)

In [None]:
%matplotlib qt
raw.plot(n_channels=len(raw.ch_names))#,scalings =40e-6)
# See below

In [None]:
# Indices of the segments labelled exactly "BAD_" in raw.annotations
# (presumably the label given to segments marked in the raw.plot() window).
bad_ix = [i for i, a in enumerate(raw.annotations) if a['description'] == "BAD_"]
bad_annot_fname = "sub-{}_task-P3_badannotations.csv".format(subject_id)

# Run once after interactive cleaning to persist the manual bad segments;
# the read below fails if this file has never been written.
#raw.annotations[bad_ix].save(bad_annot_fname)

annotations = mne.read_annotations(bad_annot_fname)

# The loaded onsets are relative to annotations.orig_time, while appended onsets
# are interpreted relative to raw.annotations.orig_time. When both reference times
# are known, shift the onsets onto raw's time base so the bad segments land where
# they were marked (no-op when both reference times are identical).
onset_shift = 0.
if annotations.orig_time is not None and raw.annotations.orig_time is not None:
    onset_shift = (annotations.orig_time - raw.annotations.orig_time).total_seconds()

# NOTE: appends in place -- re-running this cell (or running it after marking the
# same segments interactively) adds the segments a second time.
raw.annotations.append(annotations.onset + onset_shift,
                       annotations.duration,
                       annotations.description)



**T:** While going through the dataset, mark the electrodes you observe to be bad. They are stored in `raw.info['bads']`. The channels can be interpolated with `raw.interpolate_bads()` or `epochs.interpolate_bads()`. Compare the interpolated channel and its neighbours before and after interpolation. Did the interpolation succeed? (If you are interested in the mathematical details of spherical spline interpolation, check out https://mne.tools/dev/overview/implementation.html#id26.)
Hint: Interpolation requires channel locations, which you can obtain from the standard 10-20 template via `raw.set_montage('standard_1020', match_case=False)`.

In [None]:
import numpy as np

# No channel looked clearly bad in this recording, so FP2 is marked bad on
# purpose (a deliberate pick, not a random one) to demonstrate interpolation.
raw.info['bads'] = ['FP2']
# Interpolation needs sensor positions. Channel names here are upper-case
# (e.g. 'FP2'), hence match_case=False when applying the template montage.
raw.set_montage('standard_1020', match_case=False)

# interpolate_bads() overwrites the bad channels in place, so keep their
# original data to be able to compare before vs. after.
bad_channels = list(raw.info['bads'])
data_before = raw.get_data(picks=bad_channels)
raw.interpolate_bads()
data_after = raw.get_data(picks=bad_channels)

# How well does the interpolated trace reproduce the recorded one?
# (For a visual check of the channel and its neighbours, use raw.plot().)
for ch_name, before, after in zip(bad_channels, data_before, data_after):
    r_interp = np.corrcoef(before, after)[0, 1]
    print("{}: correlation original vs. interpolated = {:.3f}".format(ch_name, r_interp))

**T:** In the epoching step, we can also specify rejection criteria for peak-to-peak amplitude rejection.

In [None]:
%matplotlib inline
import mne
evts,evts_dict = mne.events_from_annotations(raw)
wanted_keys = [e for e in evts_dict.keys() if "stimulus" in e]
evts_dict_stim=dict((k, evts_dict[k]) for k in wanted_keys if k in evts_dict)

# get epochs with and without rejection
epochs        = mne.Epochs(raw,evts,evts_dict_stim,tmin=-0.1,tmax=1,reject_by_annotation=False)
epochs_manual = mne.Epochs(raw,evts,evts_dict_stim,tmin=-0.1,tmax=1,reject_by_annotation=True)
reject_criteria = dict(eeg=200e-6,       # 100 µV # HAD TO INCREASE IT HERE, 100 was too harsh
                       eog=200e-6)       # 200 µV
epochs_thresh = mne.Epochs(raw,evts,evts_dict_stim,tmin=-0.1,tmax=1,reject=reject_criteria,reject_by_annotation=False)

#from matplotlib import pyplot as plt
# compare
#plt.plot([0,:])
mne.viz.plot_compare_evokeds({'raw':epochs.average(),'clean':epochs_manual.average(),'thresh':epochs_thresh.average()},picks="Cz")


## Bonus Tasks!

In [None]:
from autoreject import AutoReject

# AutoReject draws random numbers internally (see its random_state parameter);
# fix the seed so a fresh Restart & Run All reproduces the same cleaned epochs.
AR_RANDOM_STATE = 42
ar = AutoReject(random_state=AR_RANDOM_STATE, verbose='tqdm')
# autoreject works on in-memory epoch data.
epochs.load_data()
epochs_ar = ar.fit_transform(epochs)

In [None]:
# Compute the reject log on the epochs autoreject was fitted on (the uncleaned
# `epochs`), not on its cleaned output, so that the log describes which of the
# original epochs were dropped and which channels were interpolated.
r = ar.get_reject_log(epochs)

In [None]:
# Image of good / bad / interpolated channels per epoch from the autoreject log.
fig_reject_log = r.plot(orientation="horizontal")

In [None]:

# Uncleaned vs. manually cleaned vs. autoreject-cleaned ERP at electrode Cz.
evokeds_ar = dict(
    raw=epochs.average(),
    clean=epochs_manual.average(),
    ar=epochs_ar.average(),
)
mne.viz.plot_compare_evokeds(evokeds_ar, picks="Cz")


In [None]:
from scipy.stats.mstats import winsorize
import numpy as np
def winsor(d, limits=(0.2, 0.2)):
    """Winsorized mean across epochs, usable as ``Epochs.average(method=...)``.

    Parameters
    ----------
    d : ndarray
        Data aggregated over axis 0; for ``Epochs.average`` this is
        (n_epochs, n_channels, n_times).
    limits : tuple of float
        Fractions of values clipped at the low and high end along axis 0,
        passed to ``scipy.stats.mstats.winsorize``. Defaults to 20 % each side.

    Returns
    -------
    ndarray
        Mean over axis 0 of the winsorized data, shape ``d.shape[1:]``.
    """
    clipped = winsorize(d, axis=0, limits=limits)
    # winsorize returns a numpy.ma.MaskedArray; convert so callers (e.g. MNE's
    # Evoked construction) receive a plain ndarray rather than a masked one.
    return np.asarray(np.mean(clipped, axis=0))
def median(d):
    """Median across epochs, usable as ``Epochs.average(method=...)``.

    Aggregates over axis 0 (the epoch axis for MNE's (n_epochs, n_channels,
    n_times) data) and returns an array of shape ``d.shape[1:]``.
    """
    epoch_axis = 0
    return np.median(d, axis=epoch_axis)

# Make sure the epoch data are in memory before the custom aggregations run.
epochs.load_data()

# Manual cleaning vs. robust averages (winsorized mean, median) of the
# uncleaned epochs at electrode Cz.
evokeds_robust = dict(
    clean=epochs_manual.average(),
    robust=epochs.average(method=winsor),
    median=epochs.average(method=median),
)
mne.viz.plot_compare_evokeds(evokeds_robust, picks="Cz")

In [None]:
# Manual cleaning vs. robust (winsorized) averaging vs. autoreject at Cz.
evokeds_methods = dict(
    clean=epochs_manual.average(),
    robust=epochs.average(method=winsor),
    ar=epochs_ar.average(),
)
mne.viz.plot_compare_evokeds(evokeds_methods, picks="Cz")

In [None]:
%matplotlib inline
ylim = dict(eeg=(-20, 20))
epochs.average().plot(ylim=ylim, spatial_colors=True);
epochs_ar.average().plot(ylim=ylim, spatial_colors=True);
epochs_manual.average().plot(ylim=ylim, spatial_colors=True);