In [14]:
import numpy as np
from scipy import stats
from scipy.stats import entropy
from pywt import wavedec
from scipy.signal import welch


from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from mne_features.univariate import compute_spect_entropy, compute_wavelet_coef_energy
# https://mne.tools/mne-features/api.html

import mne
from mne.io import concatenate_raws, read_raw_edf
import glob


from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, ShuffleSplit, cross_val_score
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from mne.decoding import CSP, SPoC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from mne.decoding import (
    SlidingEstimator,
    GeneralizingEstimator,
    Scaler,
    cross_val_multiscore,
    LinearModel,
    get_coef,
    Vectorizer,
    CSP,
)
import numpy as np
from mne.preprocessing import ICA

from lightgbm import LGBMClassifier
from xgboost.sklearn import XGBClassifier
from sklearn.tree import DecisionTreeClassifier
from utils import preprocess_data


import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Conv1D, MaxPooling1D, Dropout, Flatten, BatchNormalization, Conv2D, DepthwiseConv2D, AveragePooling2D, Activation, SeparableConv2D, SpatialDropout1D
from tensorflow.keras.utils import to_categorical
import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
import glob
import numpy as np
from utils import preprocess_data
from mne.preprocessing import ICA

https://www.researchgate.net/publication/309577320_Motor_Imagery_Classification_Using_Mu_and_Beta_Rhythms_of_EEG_with_Strong_Uncorrelating_Transform_Based_Complex_Common_Spatial_Patterns

https://arxiv.org/ftp/arxiv/papers/1312/1312.2877.pdf

In [15]:
'''
=========  ===================================
run        task
=========  ===================================
1          Baseline, eyes open
2          Baseline, eyes closed
3, 7, 11   Motor execution: left vs right hand
4, 8, 12   Motor imagery: left vs right hand
5, 9, 13   Motor execution: hands vs feet
6, 10, 14  Motor imagery: hands vs feet
=========  ===================================
'''
raws_train = []
raws_test = []
for ii in range(1, 20):
    subject = f'S{ii:03d}'
    files = glob.glob(f'../../files/{subject}/*.edf')
    for i in [5, 9, 13, 3, 7, 11]:
        current_file = files[i]
        r = read_raw_edf(current_file, preload=True, stim_channel='auto')
        events, _ = mne.events_from_annotations(r)
        if i in [5, 9, 13]:
            new_labels_events = {1:'rest', 2:'T1', 3:'T2'} # action
        elif i in [3, 7, 11]:
            new_labels_events = {1:'rest', 2:'T3', 3:'T4'}
        new_annot = mne.annotations_from_events(events=events, event_desc=new_labels_events, sfreq=r.info['sfreq'], orig_time=r.info['meas_date'])
        r.set_annotations(new_annot)
        if ii < 4:
            raws_test.append(r)
        else:
            raws_train.append(r)
    
raws_train_obj = concatenate_raws(raws_train)
raw_test_obj = concatenate_raws(raws_test)

Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S001/S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S001/S001R13.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S001/S001R09.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/fil

In [16]:
def preprocess_raw(raw):
    # "Fc5.","Fc6.","Fc3.","Fc4.","Fc1.","Fc2.","C5..","C6..","C3..","C4..","C1..","C2..","Cp5.","Cp6.","Cp3.","Cp4.","Cp1.","Cp2."
    # raw = raw.pick_channels(['Fc5.', 'Fc6.', 'Fc3.', 'Fc4.', 'Fc1.', 'Fc2.', 'C5..', 'C6..', 'C3..', 'C4..', 'C1..', 'C2..', 'Cp5.', 'Cp6.', 'Cp3.', 'Cp4.', 'Cp1.', 'Cp2.'])

    # filters
    notch_freq = 60
    raw.notch_filter(notch_freq, fir_design='firwin')

    low_cutoff = 8
    high_cutoff = 40
    raw.filter(low_cutoff, high_cutoff, fir_design='firwin')

    events, event_dict = mne.events_from_annotations(raw)
    print(raw.info)
    print(event_dict)
    picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False, eog=False, exclude='bads')

    event_id = {'T1': 2, 'T2': 3, 'T3': 4, 'T4': 5}
    events, event_dict = mne.events_from_annotations(raw, event_id=event_id)
    tmin = -0.2  # Time before event in seconds
    tmax = 0.8  # Time after event in seconds
    epochs = mne.Epochs(raw, events, event_dict, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)

    # raw, events, event_dict, picks, epochs
    return raw, events, event_dict, picks, epochs

In [17]:

raw_train = raws_train_obj.copy()
raw_test = raw_test_obj.copy()

raw_train, events_train, event_dict_train, picks_train, epochs_train = preprocess_raw(raw_train)
raw_test, events_test, event_dict_test, picks_test, epochs_test = preprocess_raw(raw_test)

Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 59.10 Hz)
- Upper passband edge: 60.65 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 60.90 Hz)
- Filter length: 1057 samples (6.606 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    8.9s finished


Filtering raw data in 96 contiguous segments
Setting up band-pass filter from 8 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 265 samples (1.656 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    0.1s finished


Used Annotations descriptions: ['T1', 'T2', 'T3', 'T4', 'rest']
<Info | 7 non-empty values
 bads: []
 ch_names: Fc5., Fc3., Fc1., Fcz., Fc2., Fc4., Fc6., C5.., C3.., C1.., ...
 chs: 64 EEG
 custom_ref_applied: False
 highpass: 8.0 Hz
 lowpass: 40.0 Hz
 meas_date: 2009-08-12 16:15:00 UTC
 nchan: 64
 projs: []
 sfreq: 160.0 Hz
>
{'T1': 1, 'T2': 2, 'T3': 3, 'T4': 4, 'rest': 5}
Used Annotations descriptions: ['T1', 'T2', 'T3', 'T4']
Not setting metadata
1275 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1275 events and 161 original time points ...
0 bad epochs dropped
Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequ

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s


Filtering raw data in 18 contiguous segments
Setting up band-pass filter from 8 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 265 samples (1.656 sec)



[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    1.2s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    0.1s finished


Used Annotations descriptions: ['T1', 'T2', 'T3', 'T4', 'rest']
<Info | 7 non-empty values
 bads: []
 ch_names: Fc5., Fc3., Fc1., Fcz., Fc2., Fc4., Fc6., C5.., C3.., C1.., ...
 chs: 64 EEG
 custom_ref_applied: False
 highpass: 8.0 Hz
 lowpass: 40.0 Hz
 meas_date: 2009-08-12 16:15:00 UTC
 nchan: 64
 projs: []
 sfreq: 160.0 Hz
>
{'T1': 1, 'T2': 2, 'T3': 3, 'T4': 4, 'rest': 5}
Used Annotations descriptions: ['T1', 'T2', 'T3', 'T4']
Not setting metadata
255 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 255 events and 161 original time points ...
0 bad epochs dropped


In [18]:
def get_X_y(epochs):
    import pywt

    # Define the wavelet decomposition parameters
    wavelet = 'db10'  # Choose a wavelet (e.g., 'morl')
    level = 15  # Adjust the decomposition level as needed

    # Perform wavelet decomposition on EEG data
    coeffs = pywt.wavedec(epochs.get_data(), wavelet, level=level)
    mu_rhythm = coeffs[level] # Mu rhythm (8-13 Hz)
    beta_rhythm = coeffs[level]  # Beta rhythm (13-30 Hz)
    X = np.concatenate((mu_rhythm, beta_rhythm), axis=1)
    y = epochs.events[:, -1] - 2
    return X, to_categorical(y)

In [19]:
X_train, y_train = get_X_y(epochs_train)
X_test, y_test = get_X_y(epochs_test)



In [20]:
X_train.shape, y_train.shape

((1275, 128, 90), (1275, 4))

In [21]:
# cnn
n_channels = X_train.shape[1]
input_window_size = X_train.shape[2]
input_shape = (1, n_channels, input_window_size)
X_train = X_train.reshape(X_train.shape[0], 1, n_channels, input_window_size)
X_test = X_test.reshape(X_test.shape[0], 1, n_channels, input_window_size)
print(X_train.shape, X_test.shape)

(1275, 1, 128, 90) (255, 1, 128, 90)


In [22]:
model = Sequential()
# Block 1: Temporal Convolution
model.add(Conv2D(8, (1, n_channels), strides=(1, 1), padding='same', use_bias=False, input_shape=input_shape))
model.add(BatchNormalization())
model.add(Dropout(0.5))

# Block 2: Spacial Convolution
model.add(Conv2D(int(n_channels/2), (1, n_channels), strides=(1, 1), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(Activation('elu'))
# model.add(AveragePooling2D(pool_size=(1, 4), strides=(1, 4)))
model.add(Dropout(0.25))

# Block 3: Separable Convolution
model.add(Conv2D(8, (1, 1), strides=(1, 1), padding='same', use_bias=False))
model.add(Conv2D(16, (1, 1), strides=(1, 1), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(Activation('elu'))
# model.add(AveragePooling2D(pool_size=(1, 4), strides=(1, 4)))
model.add(Dropout(0.25))
model.add(Flatten())

# Classifier
model.add(Dense(4, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_4 (Conv2D)           (None, 1, 128, 8)         92160     
                                                                 
 batch_normalization_3 (Batc  (None, 1, 128, 8)        32        
 hNormalization)                                                 
                                                                 
 dropout_3 (Dropout)         (None, 1, 128, 8)         0         
                                                                 
 conv2d_5 (Conv2D)           (None, 1, 128, 64)        65536     
                                                                 
 batch_normalization_4 (Batc  (None, 1, 128, 64)       256       
 hNormalization)                                                 
                                                                 
 activation_2 (Activation)   (None, 1, 128, 64)       

In [23]:
model.fit(X_train, y_train, epochs=250, batch_size=15, validation_data=(X_test, y_test), verbose=1, shuffle=False)
loss, accuracy = model.evaluate(X_test, y_test, verbose=1)
print(f'Accuracy: {accuracy}')
print(f'Loss: {loss}')

Epoch 1/250
Epoch 2/250
Epoch 3/250
Epoch 4/250
Epoch 5/250
Epoch 6/250
Epoch 7/250
Epoch 8/250
Epoch 9/250
Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
Epoch 17/250
Epoch 18/250
Epoch 19/250
Epoch 20/250
Epoch 21/250
Epoch 22/250
Epoch 23/250
Epoch 24/250
Epoch 25/250
Epoch 26/250
Epoch 27/250
Epoch 28/250
Epoch 29/250
Epoch 30/250
Epoch 31/250
Epoch 32/250
Epoch 33/250
Epoch 34/250
Epoch 35/250
Epoch 36/250
Epoch 37/250
Epoch 38/250
Epoch 39/250
Epoch 40/250
Epoch 41/250
Epoch 42/250
Epoch 43/250
Epoch 44/250
Epoch 45/250
Epoch 46/250
Epoch 47/250
Epoch 48/250
Epoch 49/250
Epoch 50/250
Epoch 51/250
Epoch 52/250
Epoch 53/250
Epoch 54/250
Epoch 55/250
Epoch 56/250
Epoch 57/250
Epoch 58/250
Epoch 59/250
Epoch 60/250
Epoch 61/250
Epoch 62/250
Epoch 63/250
Epoch 64/250
Epoch 65/250
Epoch 66/250
Epoch 67/250
Epoch 68/250
Epoch 69/250
Epoch 70/250
Epoch 71/250
Epoch 72/250
Epoch 73/250
Epoch 74/250
Epoch 75/250
Epoch 76/250
Epoch 77/250
Epoch 78