# EEG Signal Classification using GRU Models

In this notebook, we explore the application of Gated Recurrent Unit (GRU) models to classify EEG signals. We will work with two types of datasets: INTRA and CROSS. Our objective is to build, train, and evaluate standard and advanced GRU models, the latter incorporating attention mechanisms, to determine their efficacy in EEG signal classification tasks:)))



In [1]:
# Importing Necessary Libraries
import h5py
import numpy as np
import torch
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense, Dropout
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.model_selection import KFold
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Layer
import tensorflow.keras.backend as K
from tensorflow.keras.layers import GRU, Bidirectional, Dense, Dropout, TimeDistributed


data_dir = Path("./data/")
intra_dir = data_dir / "Intra"
cross_dir = data_dir / "Cross"

## Data loading & pre-processing

In [2]:
def load_h5(path):
    with h5py.File(path, 'r') as f:
        keys = list(f.keys())
        assert len(keys) == 1, "Only one key per file is expected"
        matrix = f[keys[0]][()]
    return matrix

def load_labels(path: Path) -> np.ndarray:
    *task, subject_identifier, chunk = path.stem.split("_")
    if "rest" in task:
        y = 0
    elif 'math' in task:
        y = 1
    elif 'working' in task:
        y = 2
    elif 'motor' in task:
        y = 3
    else:
        assert False, 'unknown task'
    return np.array([y, int(subject_identifier), int(chunk)])

In [3]:
def downsample(data, old_freq, new_freq):
    # Calculate the downsampling factor
    downsample_factor = int(np.round(old_freq / new_freq))
    # Ensure that timesteps are divisible by the downsampling factor
    data = data[:,:,:data.shape[2]//downsample_factor*downsample_factor]
    # Reshape
    reshaped_data = data.reshape(data.shape[0], data.shape[1], -1, downsample_factor)
    # Take the mean along the last axis
    downsampled_data = reshaped_data.mean(axis=-1)
    return downsampled_data

def z_score_normalize(data):
    # Convert to PyTorch tensor
    data_tensor = torch.tensor(data, dtype=torch.float32)
    # Calculate mean and std along the timesteps
    mean = torch.mean(data_tensor, dim=2, keepdim=True)
    std = torch.std(data_tensor, dim=2, keepdim=True)
    # Perform z-score norm
    normalized_data = (data_tensor - mean) / std
    return normalized_data

### INTRA data

In [4]:
## INTRA Data Loading and Preprocessing

intra_train_glob = list((intra_dir / "train").glob("*.h5"))
intra_test_glob = list((intra_dir / "test").glob("*.h5"))
intra_train_X = np.stack([load_h5(path) for path in intra_train_glob])
intra_train_labels = np.array([load_labels(path)[0] for path in intra_train_glob])
intra_test_X = np.stack([load_h5(path) for path in intra_test_glob])
intra_test_labels = np.array([load_labels(path)[0] for path in intra_test_glob])

In [5]:
intra_train_X_downsampled = downsample(intra_train_X, 2034, 125)
intra_train_X_norm = z_score_normalize(intra_train_X_downsampled)

intra_test_X_downsampled = downsample(intra_test_X, 2034, 125)
intra_test_X_norm = z_score_normalize(intra_test_X_downsampled)

In [6]:
intra_train_X_preprocessed = intra_train_X_norm.numpy()
intra_train_X_preprocessed.shape
intra_test_X_preprocessed = intra_test_X_norm.numpy()
intra_test_X_preprocessed.shape

(8, 248, 2226)

In [7]:
num_classes = 4  
intra_train_labels_one_hot = to_categorical(intra_train_labels, num_classes)
intra_test_labels_one_hot = to_categorical(intra_test_labels, num_classes)

### CROSS data

In [8]:
cross_train_glob = list((cross_dir / "train").glob("*.h5")) + list((cross_dir / "test1").glob("*.h5")) + list((cross_dir / "test2").glob("*.h5"))
cross_test_glob = list((cross_dir / "test3").glob("*.h5"))

In [9]:
cross_train_X = np.stack([load_h5(path) for path in cross_train_glob])
cross_train_labels = np.array([load_labels(path)[0] for path in cross_train_glob])

cross_test_X = np.stack([load_h5(path) for path in cross_test_glob])
cross_test_labels = np.array([load_labels(path)[0] for path in cross_test_glob])

In [10]:
# Preprocess Cross data
cross_train_X_downsampled = downsample(cross_train_X, 2034, 125)
cross_train_X_norm = z_score_normalize(cross_train_X_downsampled)

cross_test_X_downsampled = downsample(cross_test_X, 2034, 125)
cross_test_X_norm = z_score_normalize(cross_test_X_downsampled)

In [11]:
cross_train_X_norm = cross_train_X_norm.numpy()
cross_test_X_norm = cross_test_X_norm.numpy()

In [12]:
# Convert labels to categorical
cross_train_labels_cat = to_categorical(cross_train_labels)
cross_test_labels_cat = to_categorical(cross_test_labels)

## Model Building

We will define two GRU-based models: a standard GRU model and an advanced GRU model with attention.


In [13]:
def build_gru_model(input_shape, num_classes):
    model = Sequential()
    model.add(GRU(128, input_shape=input_shape, return_sequences=True))
    model.add(GRU(64))
    model.add(Dense(4, activation='softmax'))  # 4 classes
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

In [14]:
class AttentionLayer(Layer):
    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.W = self.add_weight(name='attention_weight', 
                                 shape=(input_shape[-1], 1),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight(name='attention_bias',
                                 shape=(input_shape[1], 1),
                                 initializer='zeros',
                                 trainable=True)
        super(AttentionLayer, self).build(input_shape)

    def call(self, x):
        # Alignment scores. Shape: [batch_size, time_steps]
        e = K.tanh(K.dot(x, self.W) + self.b)
        e = K.squeeze(e, axis=-1)

        # Softmax over alignment scores to get attention weights
        alpha = K.softmax(e)

        # Context vector is the weighted sum of the inputs
        context = x * K.expand_dims(alpha, -1)
        context = K.sum(context, axis=1)
        return context

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

In [15]:
def build_advanced_gru_model_with_attention(input_shape, num_classes=4):
    model = Sequential([
        Bidirectional(GRU(128, return_sequences=True), input_shape=input_shape),
        Dropout(0.5),
        Bidirectional(GRU(64, return_sequences=True)),  # Keep return_sequences=True
        AttentionLayer(),  # Custom attention layer
        Dense(32, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


## Model traning

### Intra traing with cross validation

In [16]:
# Initialize KFold
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

In [17]:
fold_no = 1
for train, val in kfold.split(intra_train_X_preprocessed, intra_train_labels_one_hot):
    # Build models for each fold
    model_gru = build_gru_model(intra_train_X_preprocessed.shape[1:], 4)
    model_advanced_gru = build_advanced_gru_model_with_attention(intra_train_X_preprocessed.shape[1:], 4)
    
    # Training
    print(f'Training for fold {fold_no} ...')
    history_gru = model_gru.fit(intra_train_X_preprocessed[train], intra_train_labels_one_hot[train], epochs=10, batch_size=32, validation_data=(intra_train_X_preprocessed[val], intra_train_labels_one_hot[val]))
    history_advanced_gru = model_advanced_gru.fit(intra_train_X_preprocessed[train], intra_train_labels_one_hot[train], epochs=10, batch_size=32, validation_data=(intra_train_X_preprocessed[val], intra_train_labels_one_hot[val]))
    
    fold_no += 1

Training for fold 1 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Training for fold 2 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Training for fold 3 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10


Epoch 10/10
Training for fold 4 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Training for fold 5 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [18]:
# Evaluate models
performance_gru = model_gru.evaluate(intra_test_X_preprocessed, intra_test_labels_one_hot)
performance_advanced_gru = model_advanced_gru.evaluate(intra_test_X_preprocessed, intra_test_labels_one_hot)

print("Standard GRU Model Performance:", performance_gru)
print("Advanced GRU Model with Attention Performance:", performance_advanced_gru)

Standard GRU Model Performance: [1.6172165870666504, 0.375]
Advanced GRU Model with Attention Performance: [1.8650151491165161, 0.0]


### Cross training with cross-validation

In [19]:
# Parameters
num_folds = 5
kfold = KFold(n_splits=num_folds, shuffle=True)

# K-Fold Cross-validation
fold_no = 1
for train, val in kfold.split(cross_train_X_norm, cross_train_labels_cat):
    # Build models for each fold
    model_gru_cross = build_gru_model(cross_train_X_norm.shape[1:], 4)
    model_advanced_gru_cross = build_advanced_gru_model_with_attention(cross_train_X_norm.shape[1:], 4)
    
    # Training
    print(f'Training for fold {fold_no} ...')
    history_gru_cross = model_gru_cross.fit(cross_train_X_norm[train], cross_train_labels_cat[train], epochs=10, batch_size=32, validation_data=(cross_train_X_norm[val], cross_train_labels_cat[val]))
    history_advanced_gru_cross = model_advanced_gru_cross.fit(cross_train_X_norm[train], cross_train_labels_cat[train], epochs=10, batch_size=32, validation_data=(cross_train_X_norm[val], cross_train_labels_cat[val]))
    
    fold_no += 1

Training for fold 1 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Training for fold 2 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Training for fold 3 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


Training for fold 4 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Training for fold 5 ...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [20]:
# Evaluate models
performance_gru = model_gru_cross.evaluate(cross_test_X_norm, cross_test_labels_cat)
performance_advanced_gru = model_advanced_gru_cross.evaluate(cross_test_X_norm, cross_test_labels_cat)

print("Standard GRU Model Performance:", performance_gru)
print("Advanced GRU Model with Attention Performance:", performance_advanced_gru)

Standard GRU Model Performance: [3.005998373031616, 0.25]
Advanced GRU Model with Attention Performance: [1.3895056247711182, 0.5]
