In [None]:
# loading the data

In [1]:
from glob import glob
import os
import mne
import numpy as np
import pandas
import matplotlib.pyplot as plt

In [2]:
# using glob to extract all .edf files from the folder

all_file_path = glob('data/*.edf')
print(len(all_file_path))

72


In [3]:
all_file_path[:5]

['data\\Subject00_1.edf',
 'data\\Subject00_2.edf',
 'data\\Subject01_1.edf',
 'data\\Subject01_2.edf',
 'data\\Subject02_1.edf']

In [4]:
# from data set we already know there are total 36 persons 
# and each one has two .edf files subjectid_01 for data before arithmetic task
# and subjectid_02 for data during arithmetic task
# so here i assigned all the files which has 1 after '_' to before_arithmetic_task
# similarly all the files that has 2 after '_' to during arithmetic task

before_arithmetic_task = [i for i in all_file_path if '1' in i.split('_')[1]]
during_arithmetic_task = [i for i in all_file_path if '2' in i.split('_')[1]]

print(len(before_arithmetic_task), len(during_arithmetic_task))

36 36


In [5]:
# function to read the .edf file data using mne package

def read_data(file_path):
    data = mne.io.read_raw_edf(file_path, preload = True)
    data.set_eeg_reference()
    data.filter(l_freq = 0.5, h_freq = 45)
    epochs = mne.make_fixed_length_epochs(data, duration = 5, overlap = 1)
    array = epochs.get_data()
    return array

In [6]:
# i just passed a edf file to read_data function to check weather the function is working or not
# we can see what information is extracted from the .edf file

sample_data = read_data(during_arithmetic_task[0])

Extracting EDF parameters from C:\Users\prave\Desktop\eeg-during-mental-arithmetic-tasks\data\Subject00_2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 30999  =      0.000 ...    61.998 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 3301 samples (6.602 s)

Not setting metadata
15 matching events found
No baseline correction appl

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  21 out of  21 | elapsed:    0.0s finished


In [7]:
# shape of the data that we extracted from .edf file
sample_data.shape # no of epochs, channels, length of signal

(15, 21, 2500)

In [8]:
# the following two lines are needed in order to execute the magic word 'capture'
# it function is to capture the cell output and hides it 
# used to hide the unnecessary info from printing 
# it has no connection to this project

import nest_asyncio
nest_asyncio.apply()

In [9]:
%%capture

# extract all 36 before_arithmetic_task data into before_epochs_array
before_epochs_array = [read_data(i) for i in before_arithmetic_task]

# extract all 36 during_arithmetic_task data into during_epochs_array
during_epochs_array = [read_data(i) for i in during_arithmetic_task]

In [10]:
before_epochs_array[0].shape, during_epochs_array[0].shape # no of epochs, channels, length of signal

((45, 21, 2500), (15, 21, 2500))

In [11]:
# each person data be in the following format
# person 1
#    epoch 1
#    epoch 2
#    epoch 3
#    ....
# person 2
#    epoch 1
#    epoch 2
#    epoch 3
#    ....

# so we need to create label for each epoch under each person not just one label for one person
# the following two lines of code are to create labels for each epoch under each person/each .edf file
# 0 for before arithmetic task and 1 for during arithmetic task data

before_epochs_labels = [len(i)*[0] for i in before_epochs_array]
during_epochs_labels = [len(i)*[1] for i in during_epochs_array]
len(before_epochs_labels), len(during_epochs_labels)

(36, 36)

In [12]:
# data_list is combination of before arithmetic data and during arithmetic data
data_list = before_epochs_array + during_epochs_array

# label_list is combination of before arithmetic labels and during arithmetic labels
label_list = before_epochs_labels + during_epochs_labels

In [13]:
# split the data based on subjects
# not on the basis of epochs or trails
# assign a group to each subject(edf file) and split based on groups
# because we know the data is in the follwing format
# person 1
#    epoch 1
#    epoch 2
#    epoch 3
#    ....
# person 2
#    epoch 1
#    epoch 2
#    epoch 3
#    ....
# if we randomly split it based on epochs person1 epoch1 may fall under training data
# at the same time person1 epoch2 may fall under test data
# now the model would easily predict because it trained on person1 and leads to overfitting

# in order to overcome that we assign a group value for each epoch
# like person1 group value for each epoch is 1 and person2 group value for each epoch is 2 so on
# and then split the data based on groups 
# now all the epochs on person1 may fall either on train or test not on both
group_list = [[i]*len(j) for i,j in enumerate(data_list)]
len(group_list)

72

In [14]:
# convert lists into numpy arrays
data_array = np.vstack(data_list)
label_array = np.hstack(label_list)
group_array = np.hstack(group_list)
print(data_array.shape, label_array.shape, group_array.shape)

(2132, 21, 2500) (2132,) (2132,)


In [15]:
# 2132 no of segments, 21 no of channels, 2500 is the length
# cnn expects the channel at the end so we need to change it to (2132, 2500, 21)
data_array=np.moveaxis(data_array,1,2)
data_array.shape

(2132, 2500, 21)

In [16]:
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler
gkf=GroupKFold()

In [17]:
# creating train and test features for one fold

accuracy=[] # to store accuracy for each fold

for train_index, val_index in gkf.split(data_array, label_array, groups=group_array):
    
    # from gkf.split we can get train, test and validation data
    train_features,train_labels=data_array[train_index],label_array[train_index]
    val_features,val_labels=data_array[val_index],label_array[val_index]
    
    # after the shape of train_features will be like (1710, 2500, 21)
    # for scalling we can only apply standard scaler on 2d data
    # so in order to apply standard we have to convert train_features into 2d array (1710*2500, 21)
    # then apply scalling
    # after scalling is done again reshape the data into its original form like (1710, 2500, 21)
    
    scaler=StandardScaler()
    
    # reshape train_features into 2d apply scalling then reshape it back into original format
    train_features = scaler.fit_transform(train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
    
    # reshape val_features into 2d apply scalling then reshape it back into original format
    val_features = scaler.transform(val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)

    break

In [18]:
train_features.shape, val_features.shape

((1710, 2500, 21), (422, 2500, 21))

In [28]:
# constructiion of chrononet(cnn + gru)
# based on a research paper link: https://link.springer.com/chapter/10.1007/978-3-030-21642-9_8

In [19]:
from tensorflow.keras.layers import Input,Dense,concatenate,Flatten,GRU,Conv1D
from tensorflow.keras.models import Model

In [20]:
def block(input):
    conv1 = Conv1D(32, 2, strides=2,activation='relu',padding="same")(input)
    conv2 = Conv1D(32, 4, strides=2,activation='relu',padding="causal")(input)
    conv3 = Conv1D(32, 8, strides=2,activation='relu',padding="causal")(input)
    x = concatenate([conv1,conv2,conv3],axis=2)
    return x

In [21]:
input= Input(shape=(2500, 21))  # ((1710, 2500, 21), (422, 2500, 21)) required input shape
block1=block(input)
block2=block(block1)
block3=block(block2)

In [22]:
gru_out1 = GRU(32,activation='tanh',return_sequences=True)(block3)
gru_out2 = GRU(32,activation='tanh',return_sequences=True)(gru_out1)
gru_out = concatenate([gru_out1,gru_out2],axis=2)
gru_out3 = GRU(32,activation='tanh',return_sequences=True)(gru_out)
gru_out = concatenate([gru_out1,gru_out2,gru_out3])
gru_out4 = GRU(32,activation='tanh')(gru_out)

In [23]:
predictions = Dense(1,activation='sigmoid')(gru_out4)
model = Model(inputs=input, outputs=predictions)

model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics=['accuracy'])

In [24]:
model.fit(train_features,train_labels,epochs=26,batch_size=128,validation_data=(val_features,val_labels))

Epoch 1/26
Epoch 2/26
Epoch 3/26
Epoch 4/26
Epoch 5/26
Epoch 6/26
Epoch 7/26
Epoch 8/26
Epoch 9/26
Epoch 10/26
Epoch 11/26
Epoch 12/26
Epoch 13/26
Epoch 14/26
Epoch 15/26
Epoch 16/26
Epoch 17/26
Epoch 18/26
Epoch 19/26
Epoch 20/26
Epoch 21/26
Epoch 22/26
Epoch 23/26
Epoch 24/26
Epoch 25/26
Epoch 26/26


<keras.src.callbacks.History at 0x1f53eb0f610>

In [26]:
print('accuracy',round((model.evaluate(val_features,val_labels)[1] * 100), 2))

accuracy 70.14


In [None]:
# accuracy of single fold is 70.14% using chrononet(cnn+gru)

In [None]:
# Applying K fold cross validation

In [27]:
accuracy=[] # to store accuracy for each fold

for train_index, val_index in gkf.split(data_array, label_array, groups=group_array):
    
    # from gkf.split we can get train, test and validation data
    train_features,train_labels=data_array[train_index],label_array[train_index]
    val_features,val_labels=data_array[val_index],label_array[val_index]
    
    # after the shape of train_features will be like (1710, 2500, 21)
    # for scalling we can only apply standard scaler on 2d data
    # so in order to apply standard we have to convert train_features into 2d array (1710*2500, 21)
    # then apply scalling
    # after scalling is done again reshape the data into its original form like (1710, 2500, 21)
    
    scaler=StandardScaler()
    
    # reshape train_features into 2d apply scalling then reshape it back into original format
    train_features = scaler.fit_transform(train_features.reshape(-1, train_features.shape[-1])).reshape(train_features.shape)
    
    # reshape val_features into 2d apply scalling then reshape it back into original format
    val_features = scaler.transform(val_features.reshape(-1, val_features.shape[-1])).reshape(val_features.shape)

    model.fit(train_features,train_labels,epochs=20,batch_size=128,validation_data=(val_features,val_labels))
    
    # evaluate and append the accuracy value of this batch
    accuracy.append(model.evaluate(val_features,val_labels)[1])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20


Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [30]:
round((np.mean(accuracy)*100), 2)

89.6

In [None]:
# after cross validation
# mean accuracy of 89.6% is achieved.