In [43]:
import os, sys, time
import numpy as np
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score, train_test_split
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle

from mne import Epochs, pick_types, annotations_from_events, events_from_annotations, set_log_level, read_epochs
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP 
from mne.viz import plot_events, plot_montage

import joblib

In [44]:
# Local data root: a `goinfre` folder inside the user's home directory.
# os.path.expanduser("~") falls back to the password database when $HOME is
# unset, whereas `os.getenv('HOME') + '/goinfre'` would raise TypeError on None.
# NOTE(review): `path` is not referenced by the cells shown here.
path = os.path.join(os.path.expanduser("~"), "goinfre")

# Experiment catalogue. Each entry has:
#   "description": human-readable summary of the task,
#   "runs":        run numbers passed to eegbci.load_data (PhysioNet EEGBCI),
#   "mapping":     event code -> label; codes 1 and 2 correspond to the T1/T2
#                  cues extracted in the loading cell (code 0 is not extracted there).
# The last two entries pool the real and imagined runs of the first four.
experiments = [
    {
        "description": "open and close left or right fist",
        "runs": [3, 7, 11],
        "mapping": {0: "rest", 1: "left fist", 2: "right fist"},
    },
    {
        "description": "imagine opening and closing left or right fist",
        "runs": [4, 8, 12],
        "mapping": {0: "rest", 1: "imagine left fist", 2: "imagine right fist"},
    },
    {
        "description": "open and close both fists or both feet",
        "runs": [5, 9, 13],
        "mapping": {0: "rest", 1: "both fists", 2: "both feets"},
    },
    {
        "description": "imagine opening and closing both fists or both feet",
        "runs": [6, 10, 14],
        "mapping": {0: "rest", 1: "imagine both fists", 2: "imagine both feets"},
    },
    {
        "description": "movement (real or imagine) of fists",
        "runs": [3, 7, 11, 4, 8, 12],
        "mapping": {0: "rest", 1: "left fist", 2: "right fist"},
    },
    {
        "description": "movement (real or imagine) of both fists or both feet",
        "runs": [5, 9, 13, 6, 10, 14],
        "mapping": {0: "rest", 1: "both fists", 2: "both feets"},
    },
]

In [45]:
# Run configuration: experiment index into `experiments`, EEGBCI subject number,
# and the epoch window (seconds) around each cue event.
exp_set = 4
subject_nb = 100
experiment = experiments[exp_set]
tmin, tmax = -1.0, 4.0

subject_raws = []  # NOTE(review): unused in this cell; kept in case later cells rely on it

# Load and concatenate all runs of the selected experiment for this subject.
raw_fnames = eegbci.load_data(subject_nb, experiment["runs"])
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])

# Extract only the T1/T2 cues and relabel them with the experiment's descriptions.
events, _ = events_from_annotations(raw, event_id=dict(T1=1, T2=2))
annot_from_events = annotations_from_events(
    events=events,
    event_desc=experiment["mapping"],
    sfreq=raw.info["sfreq"],
    # Use the same time reference as the existing annotations so the two can be
    # added together below. Assumes first_samp == 0 (the read logs show
    # "Reading 0 ..."); revisit if that ever changes.
    orig_time=raw.annotations.orig_time,
)
# concatenate_raws marks run joins with "BAD boundary"/"EDGE boundary"
# annotations. Replacing all annotations would drop them, so the filter's
# skip_by_annotation="edge" would have nothing to skip and epochs spanning a
# join would not be rejected. Keep the boundaries; swap everything else for
# the relabelled cues.
boundary_annot = raw.annotations.copy()
boundary_annot.delete([
    idx for idx, desc in enumerate(boundary_annot.description)
    if not desc.lower().startswith(("bad", "edge"))
])
raw.set_annotations(boundary_annot + annot_from_events)

eegbci.standardize(raw)  # set channel names
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)

# Keep only the fronto-central / central / centro-parietal channels listed here.
channels = raw.info["ch_names"]
good_channels = ["FC5", "FC3", "FC1", "FCz", "FC2", "FC4", "FC6",
                 "C5",  "C3",  "C1",  "Cz",  "C2",  "C4",  "C6",
                 "CP5", "CP3", "CP1", "CPz", "CP2", "CP4", "CP6"]
bad_channels = [x for x in channels if x not in good_channels]
raw.drop_channels(bad_channels)

# Notch out 60 Hz, then band-pass 7-32 Hz without filtering across run edges.
raw.notch_filter(60, method="iir")
raw.filter(7.0, 32.0, fir_design="firwin", skip_by_annotation="edge")

# Read epochs. Pass an explicit event_id so label numbers follow the experiment's
# mapping (1 -> mapping[1], 2 -> mapping[2]). Automatic numbering sorts the
# descriptions alphabetically, which swaps codes for e.g. "both feets"/"both fists".
label_ids = {experiment["mapping"][code]: code for code in (1, 2)}
events, event_id = events_from_annotations(raw, event_id=label_ids)
picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)
labels = epochs.events[:, -1]
assert len(epochs) > 0, "no epochs left after annotation-based rejection"
assert set(np.unique(labels)) <= set(label_ids.values()), "unexpected label codes"

# Training window: 1.0-4.0 s relative to each event (earlier samples are dropped).
epochs_train = epochs.copy().crop(tmin=1.0, tmax=4.0).get_data()

# Seeded so cross-validation splits are reproducible after a kernel restart.
cv = ShuffleSplit(10, test_size=0.2, random_state=42)

Downloading file 'S100/S100R03.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S100/S100R03.edf' to '/mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0'.
Downloading file 'S100/S100R07.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S100/S100R07.edf' to '/mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0'.
Downloading file 'S100/S100R11.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S100/S100R11.edf' to '/mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0'.
Downloading file 'S100/S100R04.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S100/S100R04.edf' to '/mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0'.
Downloading file 'S100/S100R08.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S100/S100R08.edf' to '/mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0'.
Downloading file 'S100/S100R12.edf' from 'https://physionet.org/files/eegmmidb/1.0.0/S100/S100R12.edf' to '/mnt/nfs/home

Extracting EDF parameters from /mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0/S100/S100R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 15743  =      0.000 ...   122.992 secs...
Extracting EDF parameters from /mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0/S100/S100R07.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 15743  =      0.000 ...   122.992 secs...
Extracting EDF parameters from /mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0/S100/S100R11.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 15743  =      0.000 ...   122.992 secs...
Extracting EDF parameters from /mnt/nfs/homes/clorin/goinfre/MNE-eegbci-data/files/eegmmidb/1.0.0/S100/S100R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 15743  =      

  raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
  raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
  raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
  raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
  raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
  raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  21 out of  21 | elapsed:    0.0s finished


In [48]:
# Split the cropped epochs into train/test parts (scikit-learn's default test
# fraction, fixed seed for repeatability).
X_train, X_test, y_train, y_test = train_test_split(
    epochs_train,
    labels,
    random_state=0,
)

# Assemble a classifier: 6 CSP spatial-filter components feeding an LDA.
csp = CSP(n_components=6)
lda = LinearDiscriminantAnalysis()
clf = Pipeline(steps=[("CSP", csp), ("LDA", lda)])

# Fit the whole pipeline on the training part only; the fitted pipeline is
# the cell's displayed output.
clf.fit(X_train, y_train)



Computing rank from data with rank=None
    Using tolerance 3.3e-05 (2.2e-16 eps * 21 dim * 7.1e+09  max singular value)
    Estimated rank (mag): 21
    MAG: rank 21 computed from 21 data channels with 0 projectors
Reducing data rank from 21 -> 21
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 3.2e-05 (2.2e-16 eps * 21 dim * 6.8e+09  max singular value)
    Estimated rank (mag): 21
    MAG: rank 21 computed from 21 data channels with 0 projectors
Reducing data rank from 21 -> 21
Estimating covariance using EMPIRICAL
Done.


In [47]:
# Predict the held-out epochs and print them one per line (with a short pause
# between lines), then report the held-out accuracy.
predictions = clf.predict(X_test)
print('epoch nb: [prediction] [truth] equal?')
for i, (prediction, truth) in enumerate(zip(predictions, y_test)):
    print(f'epoch {i:02d}: [{prediction}] [{truth}] {prediction == truth}')
    time.sleep(0.05)  # display pacing only; does not affect the result

# accuracy_score expects (y_true, y_pred).
score_subject = accuracy_score(y_test, predictions)
# This is the accuracy for one subject and one experiment on the test split,
# not an average over experiments.
print(f'accuracy for subject {subject_nb}, experiment {exp_set}: {score_subject}')

epoch nb: [prediction] [truth] equal?
epoch 00: [1] [1] True
epoch 01: [2] [2] True
epoch 02: [1] [2] False
epoch 03: [1] [2] False
epoch 04: [2] [1] False
epoch 05: [2] [2] True
epoch 06: [2] [1] False
epoch 07: [1] [2] False
epoch 08: [1] [1] True
epoch 09: [2] [2] True
epoch 10: [1] [2] False
epoch 11: [1] [2] False
epoch 12: [1] [2] False
epoch 13: [1] [1] True
epoch 14: [1] [1] True
epoch 15: [2] [2] True
epoch 16: [1] [1] True
epoch 17: [1] [1] True
mean accuracy for all experiments:0.5555555555555556
