In [17]:
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Conv1D, MaxPooling1D, Dropout, Flatten, BatchNormalization, Conv2D, DepthwiseConv2D, AveragePooling2D, Activation, SeparableConv2D, SpatialDropout1D, MaxPooling2D, Dot, Input, TimeDistributed, Bidirectional, GlobalMaxPooling1D, MaxPool2D
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.utils import normalize
import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
import glob
import numpy as np
from utils import preprocess_data
from mne.preprocessing import ICA

In [18]:

'''
=========  ===================================
run        task
=========  ===================================
1          Baseline, eyes open
2          Baseline, eyes closed
3, 7, 11   Motor execution: left vs right hand
4, 8, 12   Motor imagery: left vs right hand
5, 9, 13   Motor execution: hands vs feet
6, 10, 14  Motor imagery: hands vs feet
=========  ===================================
'''
raws_train = []
raws_test = []
for ii in range(1, 60):
    subject = f'S{ii:03d}'
    files = glob.glob(f'../../files/{subject}/*.edf')
    for i in [3, 7, 11]:
        current_file = files[i]
        r = read_raw_edf(current_file, preload=True, stim_channel='auto')
        if ii < 10:
            raws_test.append(r)
        else:
            raws_train.append(r)
    
raws_train_obj = concatenate_raws(raws_train)
raw_test_obj = concatenate_raws(raws_test)

Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S001/S001R10.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S001/S001R14.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/owalid/42/post_intership/total-perspective-vortex/files/S002/S002R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.

In [19]:
def preprocess_raw(raw):
    # "Fc5.","Fc6.","Fc3.","Fc4.","Fc1.","Fc2.","C5..","C6..","C3..","C4..","C1..","C2..","Cp5.","Cp6.","Cp3.","Cp4.","Cp1.","Cp2."
    # raw = raw.pick_channels(['Fc5.', 'Fc6.', 'Fc3.', 'Fc4.', 'Fc1.', 'Fc2.', 'C5..', 'C6..', 'C3..', 'C4..', 'C1..', 'C2..', 'Cp5.', 'Cp6.', 'Cp3.', 'Cp4.', 'Cp1.', 'Cp2.'])

    # filters
    notch_freq = 60
    raw.notch_filter(notch_freq, fir_design='firwin')

    low_cutoff = 8
    high_cutoff = 40
    raw.filter(low_cutoff, high_cutoff, fir_design='firwin')

    events, event_dict = mne.events_from_annotations(raw)
    print(raw.info)
    print(event_dict)
    picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False, eog=False, exclude='bads')

    # event_id = {'T1': 2, 'T2': 3, 'T3': 4, 'T4': 5}
    event_id = {'T1': 1, 'T2': 2}
    events, event_dict = mne.events_from_annotations(raw, event_id=event_id)
    tmin = -0.5  # Time before event in seconds
    tmax = 4.5  # Time after event in seconds
    epochs = mne.Epochs(raw, events, event_dict, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)

    # raw, events, event_dict, picks, epochs
    return raw, events, event_dict, picks, epochs

In [20]:
raw_train = raws_train_obj.copy()
raw_test = raw_test_obj.copy()

raw_train, events_train, event_dict_train, picks_train, epochs_train = preprocess_raw(raw_train)
raw_test, events_test, event_dict_test, picks_test, epochs_test = preprocess_raw(raw_test)

Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 59.10 Hz)
- Upper passband edge: 60.65 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 60.90 Hz)
- Filter length: 1057 samples (6.606 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    5.7s finished


Filtering raw data in 150 contiguous segments
Setting up band-pass filter from 8 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 265 samples (1.656 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    0.1s finished


Used Annotations descriptions: ['T0', 'T1', 'T2']
<Info | 7 non-empty values
 bads: []
 ch_names: Fc5., Fc3., Fc1., Fcz., Fc2., Fc4., Fc6., C5.., C3.., C1.., ...
 chs: 64 EEG
 custom_ref_applied: False
 highpass: 8.0 Hz
 lowpass: 40.0 Hz
 meas_date: 2009-08-12 16:15:00 UTC
 nchan: 64
 projs: []
 sfreq: 160.0 Hz
>
{'T0': 1, 'T1': 2, 'T2': 3}
Used Annotations descriptions: ['T1', 'T2']
Not setting metadata
1995 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1995 events and 801 original time points ...
105 bad epochs dropped
Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 59.10 Hz)
- Upper passband edge: 60.65

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s


Filtering raw data in 27 contiguous segments
Setting up band-pass filter from 8 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 265 samples (1.656 sec)



[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    1.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    0.1s finished


Used Annotations descriptions: ['T0', 'T1', 'T2']
<Info | 7 non-empty values
 bads: []
 ch_names: Fc5., Fc3., Fc1., Fcz., Fc2., Fc4., Fc6., C5.., C3.., C1.., ...
 chs: 64 EEG
 custom_ref_applied: False
 highpass: 8.0 Hz
 lowpass: 40.0 Hz
 meas_date: 2009-08-12 16:15:00 UTC
 nchan: 64
 projs: []
 sfreq: 160.0 Hz
>
{'T0': 1, 'T1': 2, 'T2': 3}
Used Annotations descriptions: ['T1', 'T2']
Not setting metadata
345 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 345 events and 801 original time points ...
15 bad epochs dropped


In [21]:
def getX_y(epochs, n_channels, input_window_size):
    X = epochs.get_data()
    X = X.transpose(0,2,1)
    X = X.reshape(X.shape[0], input_window_size, n_channels, 1)
    X = normalize(X, axis=1, order=0)

    y = epochs.events[:, -1] - 2
    y = to_categorical(y)

    return X, y

In [22]:
_, n_channels, input_window_size = epochs_train.get_data().shape
X_train, y_train = getX_y(epochs_train, n_channels, input_window_size)
X_test, y_test = getX_y(epochs_test, n_channels, input_window_size)

In [23]:
input_shape = (input_window_size, n_channels, 1)

In [24]:
print(X_train.shape, X_test.shape, input_shape)

(1890, 801, 64, 1) (330, 801, 64, 1) (801, 64, 1)


In [25]:
model = Sequential()

    # keras.layers.Conv2D(filters=25, kernel_size=(15,1),dilation_rate=(2, 1),strides=(1, 1), padding='valid', activation='elu', input_shape=(data_lenght,n_ch,1)),
model.add(Conv2D(25, (15, 1), strides=(1, 1), padding='valid', activation='elu', input_shape=input_shape))
#conv_pool_block_1
model.add(Conv2D(filters=25, kernel_size=(15,1),dilation_rate=(2, 1),strides=(1, 1), padding='valid', activation='elu', input_shape=(input_window_size,n_channels,1)))
model.add(BatchNormalization())
model.add(Dropout(0.3))

model.add(Conv2D(filters=25, kernel_size=(1,n_channels),strides=(1, 1), padding='valid', activation='elu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(MaxPool2D(pool_size=(3,1)))

#conv_pool_block_2
model.add(Conv2D(filters=50, kernel_size=(10,1),strides=(1, 1), padding='valid', activation='elu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(MaxPool2D(pool_size=(3,1)))

#conv_pool_block_3
model.add(Conv2D(filters=100, kernel_size=(10,1),strides=(1, 1), padding='valid', activation='elu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(MaxPool2D(pool_size=(3,1)))

#conv_pool_block_4
model.add(Conv2D(filters=200, kernel_size=(10,1),strides=(1, 1), padding='valid', activation='elu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(MaxPool2D(pool_size=(3,1)))

#classification Layer
model.add(Flatten())
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

# Take a look at the model summary
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_6 (Conv2D)           (None, 787, 64, 25)       400       
                                                                 
 conv2d_7 (Conv2D)           (None, 759, 64, 25)       9400      
                                                                 
 batch_normalization_6 (Batc  (None, 759, 64, 25)      100       
 hNormalization)                                                 
                                                                 
 dropout_6 (Dropout)         (None, 759, 64, 25)       0         
                                                                 
 conv2d_8 (Conv2D)           (None, 759, 1, 25)        40025     
                                                                 
 batch_normalization_7 (Batc  (None, 759, 1, 25)       100       
 hNormalization)                                      

In [26]:
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=5, batch_size=15, validation_data=(X_test, y_test), verbose=1, shuffle=False)
loss, accuracy = model.evaluate(X_test, y_test, verbose=1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [27]:
print(f'Accuracy: {accuracy}')
print(f'Loss: {loss}')

Accuracy: 1.0
Loss: 9.474502462580858e-07
