In [1]:
import glob
import matplotlib.pyplot as plt
import mne
import numpy as np
import pyxdf
import pickle
from joblib import dump, load

In [2]:
baseline_list =glob.glob('data/*/ses-S001/*/*.xdf')
experiment_list = glob.glob('data/*/ses-S002/*/*.xdf')
sfreq = 250
info = mne.create_info(8, sfreq, ["eeg"] * 8)

In [3]:
def read_trails(current_experiment_list: list , baseline_list:list) -> dict:
    baseline_mne_list = []
    for baseline in baseline_list:
        streams, _ = pyxdf.load_xdf(baseline,verbose=False)
        try:
            data = streams[1]["time_series"].T[:8]
        except AttributeError:
            data = streams[0]["time_series"].T[:8]
        raw = mne.io.RawArray(data, info, verbose=False)
        raw = raw.crop(tmin=1, tmax=41)
        baseline_mne_list.append(raw)
    trials = { 
            'ball': list(),
            'ball+number': list(),
            'ball+number+wheel': list(),
            'baseline': baseline_mne_list,
            }

    for experiment in current_experiment_list:
        print("processing: ", experiment)
        streams, _ = pyxdf.load_xdf(experiment)
        marker = -1 
        eeg = -1
        for idx, stream in enumerate(streams):
            if streams[idx]['time_stamps'].shape[0] == 0:
                # ignore empty streams
                continue
            if stream["info"]["name"] == ["markers"]:
                marker = idx
            elif stream["info"]["name"] == ["eeg"]:
                eeg = idx
            else:
                raise Exception(f"stream info name unknown {stream['info']['name']}")

        if marker == -1 or eeg == -1:
            # eeg or marker stream not found
            raise Exception('channels not found')    
        
        for i in range(0, streams[marker]["time_stamps"].shape[0], 2):
            event = streams[marker]["time_series"][i][0]
            if event in ['Starting Complex Eye Tracking Dashboard', 'Starting Simple Eye Tracking Dashboard', '']:
                continue
            start = streams[marker]["time_stamps"][i]
            stop = streams[marker]["time_stamps"][i+1]

            data = []
            for j, stamp in enumerate(streams[eeg]["time_stamps"]):
                if start <= stamp and stamp <= stop: 
                    data.append(streams[eeg]["time_series"][j])
            data = np.array(data).T[:8]

            trial = mne.io.RawArray(data, info, verbose=False)

            if event == 'Starting Cognitive Load Tasks: Balls True, Numbers False, Wheel False':
                key = 'ball'
            elif event == 'Starting Cognitive Load Tasks: Balls True, Numbers True, Wheel False':
                key = 'ball+number'
            elif event == 'Starting Cognitive Load Tasks: Balls True, Numbers True, Wheel True':
                key = 'ball+number+wheel'
            else:
                print(f"key {event} not defined!")

            trials[key].append(trial)
        
    return trials

trails = read_trails(current_experiment_list=experiment_list, baseline_list=baseline_list);

In [4]:
from mne.time_frequency import psd_welch


def eeg_power_band(raw: mne.io.RawArray):
    """EEG relative power band feature extraction.

    This function takes an ``mne.io.RawArray`` object and creates EEG features based
    on relative power in specific frequency bands that are compatible with
    scikit-learn.

    Parameters
    ----------
    raw : RawArray
        The data.

    Returns
    -------
    X : numpy array of shape [n_samples, 5]
        Transformed data.
    """

    # bands taken from slides
    fmin=8
    fmax=60
    FREQ_BANDS = {
        "alpha": [8, 13],
        "beta": [13.5, 30],
        "gamma": [30.5, 60],
    }
    # freqs  = (50, 100)
    # raw_notch = raw.copy().notch_filter(freqs=freqs, verbose=False, notch_widths=0.5)
    psds, freqs = psd_welch(raw, picks="eeg", fmin=fmin, fmax=fmax, verbose=False)

    # Normalize the PSDs
    # Baseline in time dimension not neccessary (0th PSD dimension)
    psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in FREQ_BANDS.values():
        psds_band = psds[:, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
        # TODO: use statistical features (min, max, var)
        # use bins (e.g. 2 Hz range)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=-2).squeeze()  # concatenate all 8 channels * 3 frequences


In [5]:
# create feature and class arrays

def create_ml_data(trails, step=2):
    X = []
    y = []
    for data_class, data_list in trails.items():
        for raw in data_list:

            # apply notch filter
            freqs = (50, 100)
            raw_notch = raw.copy().notch_filter(freqs=freqs, verbose=False, notch_widths=0.5)

            # slice and create psd
            for start in range(0, int(raw.tmax), step):
                raw_crop = raw.copy().crop(tmin=start, tmax=start+step)
                X.append(eeg_power_band(raw_crop))
                y.append(data_class)

    return X, y

X, y = create_ml_data(trails)

In [6]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report , ConfusionMatrixDisplay

pipe = make_pipeline(RandomForestClassifier(n_estimators=100, random_state=42))
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

pipe.fit(X_train, y_train)

y_pred = pipe.predict(X_test)
acc = accuracy_score(y_test, y_pred)

ValueError: With n_samples=0, test_size=0.25 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

In [None]:
cm = confusion_matrix(y_test, y_pred, labels = pipe.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=pipe.classes_)
disp.plot()
plt.show()

In [None]:
def save_model(model, filename="model"):
    dump(model, f'{filename}.joblib')


In [None]:
save_model(model=pipe, filename="model01")