In [16]:
import os
import numpy as np
import pandas as pd
import mne

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, Activation, MaxPooling2D,
    Reshape, Bidirectional, LSTM, Dense, Dropout, Permute, TimeDistributed
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

In [None]:


# Participant metadata: one row per subject with a diagnostic `Group` code.
metadata = pd.read_csv('Dataset/participants.tsv', sep='\t')
print(metadata.head())

# Integer class labels. NOTE(review): the codes presumably stand for
# A = Alzheimer's disease, F = frontotemporal dementia, C = control — confirm
# against the dataset's participants.json.
group_mapping = {'A': 0, 'F': 1, 'C': 2}
metadata['label'] = metadata['Group'].map(group_mapping)

# A code missing from `group_mapping` maps to NaN and would later turn into a
# bogus label, so fail loudly here instead.
unknown_groups = metadata.loc[metadata['label'].isna(), 'Group'].unique()
if len(unknown_groups) > 0:
    raise ValueError(f"Unmapped Group codes in participants.tsv: {list(unknown_groups)}")
print(metadata[['participant_id', 'Group', 'label']].head())

subject_labels = dict(zip(metadata['participant_id'], metadata['label']))

# One folder per subject (`sub-...`). Sorted so that epoch order, and therefore
# the seeded train/test split further down, does not depend on the arbitrary
# order returned by os.listdir(); non-directories are ignored.
derivatives_path = os.path.join('Dataset', 'derivatives')
subject_folders = sorted(
    os.path.join(derivatives_path, d)
    for d in os.listdir(derivatives_path)
    if d.startswith('sub-') and os.path.isdir(os.path.join(derivatives_path, d))
)

# Accumulators filled by the feature-extraction loop below.
all_features = []           # one feature array per subject
all_epoch_subject_ids = []  # subject id of every epoch, in the same order as the features

# EEG frequency bands (Hz) whose relative power is used as features.
freq_bands = {
    "delta": (0.5, 4),
    "theta": (4, 8),
    "alpha": (8, 13),
    "beta": (13, 25),
    "gamma": (25, 45),
}

def _relative_band_power(set_file_path, bands):
    """Return per-epoch relative band power for one EEGLAB ``.set`` recording.

    The recording is band-pass filtered to 0.5-45 Hz and cut into 2 s epochs
    with 1 s overlap. A Welch PSD is computed per epoch and channel, and then
    averaged inside each ``(fmin, fmax)`` band of ``bands`` (both edges
    inclusive). Each band's power is divided by the sum over all bands.

    Args:
        set_file_path: path to the ``.set`` file.
        bands: mapping of band name -> (fmin, fmax) in Hz; dict order sets the
            order of the band axis.

    Returns:
        ndarray of shape (n_epochs, n_channels, n_bands, 1). The trailing
        singleton axis is the input-channel axis expected by Conv2D.
    """
    raw = mne.io.read_raw_eeglab(set_file_path, preload=True)
    raw.filter(0.5, 45, fir_design='firwin')

    # Fixed-length epochs: 2 seconds duration, 1 second overlap.
    epochs = mne.make_fixed_length_epochs(raw, duration=2.0, overlap=1, preload=True)
    psd = epochs.compute_psd(method="welch", fmin=0.5, fmax=45)
    psds, freqs = psd.get_data(return_freqs=True)

    band_powers = []
    for fmin, fmax in bands.values():
        in_band = (freqs >= fmin) & (freqs <= fmax)
        band_powers.append(psds[:, :, in_band].mean(axis=-1))
    bp_abs = np.stack(band_powers, axis=-1)
    rbp_relative = bp_abs / bp_abs.sum(axis=-1, keepdims=True)
    return rbp_relative[..., np.newaxis]


for subject_folder in subject_folders:
    eeg_folder = os.path.join(subject_folder, 'eeg')
    if not os.path.isdir(eeg_folder):
        print("Skipping (no eeg folder):", subject_folder)
        continue

    # Sorted so the chosen recording is deterministic; only the first one is used.
    set_files = sorted(f for f in os.listdir(eeg_folder) if f.endswith('.set'))
    if not set_files:
        print("Skipping (no .set file):", subject_folder)
        continue

    set_file_path = os.path.join(eeg_folder, set_files[0])
    print("Loading:", set_file_path)
    features = _relative_band_power(set_file_path, freq_bands)

    all_features.append(features)
    subject_id = os.path.basename(subject_folder)
    all_epoch_subject_ids.extend([subject_id] * features.shape[0])

# Assemble the epoch-level dataset, split it by *subject*, and standardize features.
if not all_features:
    raise RuntimeError("No EEG recordings were loaded; check the Dataset/derivatives folder.")

X = np.concatenate(all_features, axis=0)

# X is (n_epochs, n_channels, n_bands, 1). Despite its name, `n_times` is the
# trailing singleton axis; the name is kept for compatibility.
n_channels, n_freqs, n_times = X.shape[1:]

y = np.array([subject_labels[pid] for pid in all_epoch_subject_ids])
epoch_subjects = np.array(all_epoch_subject_ids)

X_reshaped = X.reshape(X.shape[0], -1)

# Split at the subject level. Epochs overlap by 1 s, so epochs from the same
# recording are strongly correlated. Letting one subject contribute to both
# train and test leaks information and inflates the test accuracy.
subject_ids = np.unique(epoch_subjects)
subject_y = np.array([subject_labels[sid] for sid in subject_ids])
_, subject_class_counts = np.unique(subject_y, return_counts=True)
train_subjects, test_subjects = train_test_split(
    subject_ids,
    test_size=0.2,
    random_state=42,
    # Stratify by diagnosis when every class has enough subjects to do so.
    stratify=subject_y if subject_class_counts.min() >= 2 else None,
)
assert not set(train_subjects) & set(test_subjects)

is_train = np.isin(epoch_subjects, train_subjects)
X_train, X_test = X_reshaped[is_train], X_reshaped[~is_train]
y_train, y_test = y[is_train], y[~is_train]

# Fit the scaler on training subjects only, then apply it to both sets.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = X_train.reshape(X_train.shape[0], n_channels, n_freqs, n_times)
X_test = X_test.reshape(X_test.shape[0], n_channels, n_freqs, n_times)


<h3> Prepare for LSTM </h3>
<p><strong>CNN Output:</strong> (batch_size, reduced_channels, reduced_freqs, filters)</p>

<p><strong>LSTM Input:</strong> (batch_size, timesteps, features)</p>

<p><strong>Permute Dimensions:</strong></p>
<ul>
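  <li><code>Permute((2, 1, 3))</code> swaps the channel and frequency axes: (batch, reduced_channels, reduced_freqs, filters) → (batch, reduced_freqs, reduced_channels, filters). The pooled frequency bands thus become the LSTM timesteps.</li>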
</ul>

<p><strong>Combine Features:</strong></p>
<ul>
  <li><code>Reshape</code> flattens the last two axes (reduced_channels and filters) into a single feature vector for each timestep.</li>
  <li>New shape: (batch, reduced_freqs, reduced_channels × filters)</li>
</ul>


In [None]:

def create_cnn_bilstm_model(input_shape, nb_classes, lstm_units=64, l2_reg=1e-4, dropout_rate=0.5):
    """Build an (uncompiled) CNN -> BiLSTM classifier.

    Two Conv2D/BatchNorm/ReLU blocks extract local patterns from the
    (channel x frequency-band) map. The pooled frequency axis is then treated
    as a sequence and summarised by a bidirectional LSTM, followed by a
    dropout-regularised dense head.

    Args:
        input_shape: ``(n_channels, n_freqs, 1)`` shape of one sample.
        nb_classes: width of the softmax output layer.
        lstm_units: hidden units in each LSTM direction.
        l2_reg: L2 factor for the Conv2D and hidden Dense kernels.
        dropout_rate: Dropout rate in the dense head. The LSTM uses half of
            this value for both its input and recurrent dropout.

    Returns:
        The assembled Keras ``Model``, not yet compiled.
    """
    inputs = Input(shape=input_shape)

    # Block 1: 32 filters, then halve only the channel axis (frequency axis kept).
    x = Conv2D(32, (3, 3), padding='same', kernel_initializer='he_normal',
               kernel_regularizer=l2(l2_reg))(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 1))(x)

    # Block 2: 64 filters, then halve both the channel and frequency axes.
    x = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal',
               kernel_regularizer=l2(l2_reg))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    # Turn the feature map into a sequence:
    #   (batch, ch', fr', filters) -> (batch, fr', ch', filters) -> (batch, fr', ch' * filters)
    # so that the pooled frequency axis becomes the LSTM time axis.
    # `x.shape` is available in both Keras 2 and Keras 3, whereas
    # `tf.keras.backend.int_shape` is a legacy backend helper.
    _, reduced_channels, reduced_freqs, n_filters = x.shape
    x = Permute((2, 1, 3))(x)
    x = Reshape((reduced_freqs, reduced_channels * n_filters))(x)

    # return_sequences=False: only the final bidirectional state feeds the classifier.
    # NOTE: per the tf.keras LSTM docs, a non-zero recurrent_dropout prevents use
    # of the fused cuDNN kernel, so training on GPU is slower.
    x = Bidirectional(LSTM(lstm_units, return_sequences=False,
                           dropout=dropout_rate * 0.5,
                           recurrent_dropout=dropout_rate * 0.5))(x)

    # Classification head
    x = Dropout(dropout_rate)(x)
    x = Dense(128, activation='relu', kernel_initializer='he_normal',
              kernel_regularizer=l2(l2_reg))(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(nb_classes, activation='softmax')(x)

    return Model(inputs, outputs)

In [None]:
# Training configuration
SEED = 42
LSTM_UNITS = 64
L2_REG = 1e-4
DROPOUT_RATE = 0.4
LEARNING_RATE = 0.001
BATCH_SIZE = 128
MAX_EPOCHS = 100
VAL_FRACTION = 0.15  # share of the training set held out for early stopping / LR schedule

# Seed Python, NumPy and TensorFlow RNGs so weight init, dropout and shuffling repeat.
tf.keras.utils.set_random_seed(SEED)

input_shape = (X_train.shape[1], X_train.shape[2], X_train.shape[3])  # (channels, freq_bands, 1)

# sparse_categorical_crossentropy uses labels directly as indices, so the output
# layer needs max(label) + 1 units, even if some class is absent from y_train.
nb_classes = int(max(y_train.max(), y_test.max())) + 1

# Validation data for EarlyStopping/ReduceLROnPlateau comes from the training
# set, so the test set is never used for model selection. NOTE: this split is
# epoch-level, so one subject's epochs may appear in both X_fit and X_val.
# A subject-grouped split would give a more honest val_loss.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=VAL_FRACTION, random_state=SEED, stratify=y_train
)

model = create_cnn_bilstm_model(
    input_shape,
    nb_classes,
    lstm_units=LSTM_UNITS,
    l2_reg=L2_REG,
    dropout_rate=DROPOUT_RATE
)

model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Key weights by the actual label values (not by position), so they stay aligned
# with their classes if a label is missing from the fitting data.
fit_classes = np.unique(y_fit)
class_weights = compute_class_weight('balanced', classes=fit_classes, y=y_fit)
class_weight_dict = {int(c): float(w) for c, w in zip(fit_classes, class_weights)}

callbacks = [
    EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=7, min_lr=1e-6, verbose=1)
]

history = model.fit(
    X_fit, y_fit,
    validation_data=(X_val, y_val),
    epochs=MAX_EPOCHS,
    batch_size=BATCH_SIZE,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1
)

# The test set was not seen during training or model selection.
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_acc:.4f} ({test_acc:.2%})")
print(f"Test Loss: {test_loss:.4f}")

# NOTE(review): newer Keras versions treat HDF5 (.h5) as a legacy format and
# recommend the native `.keras` format.
model_filename = 'alzheimer_eeg_cnn_bilstm_model.h5'
model.save(model_filename)
