<a href="https://colab.research.google.com/github/takeisika/group-89-eeg-depression/blob/main/eeg_pred_model_dev_share.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the zipped dataset under MyDrive can be read (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/EEG_128channels_resting_lanzhou_2015.zip -d /content/data

Archive:  /content/drive/MyDrive/EEG_128channels_resting_lanzhou_2015.zip
   creating: /content/data/EEG_128channels_resting_lanzhou_2015/
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010013rest 20150703 1333..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010012rest 20150626 1026..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02020022rest 20150707 1452..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/Multivariate Pattern Analysis of EEG-Based Functional Connectivity A Study on the Identification of Depression.pdf  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010019rest 20150716 1440..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02030020_rest 20151230 1416.mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010005rest 20150507 0907..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010022restnew 20150724 14.

# ❶ Get all .mat files

In [None]:
import os
import glob
FOLDER = "/content/data"
# Sort so the file order (and everything derived from it) is deterministic;
# glob returns entries in arbitrary, filesystem-dependent order.
mat_files = sorted(glob.glob(os.path.join(FOLDER, "**/*.mat"), recursive=True))
print(f"Found {len(mat_files)} .mat files.")
if not mat_files:
    # Fail here rather than with an opaque np.concatenate error much later.
    raise FileNotFoundError(f"No .mat files found under {FOLDER!r}; did the unzip step run?")

Found 53 .mat files.


# ❷ Extract each file's subject ID and label

In [None]:
import os
import re
def get_subj_id_and_label(path):
    """Extract the subject ID and binary label from a recording's filename.

    The subject ID is the leftmost "02" followed by digits in the file's
    basename (e.g. "02010013rest 20150703 1333..mat" -> "02010013").

    Args:
        path: Path to a .mat recording; only its basename is inspected.

    Returns:
        (subj_id, label): subj_id as a string; label is 1 when the ID starts
        with "0201" (depressed) and 0 otherwise (not depressed).

    Raises:
        ValueError: if the filename contains no "02..." subject ID.
    """
    filename = os.path.basename(path)
    match = re.search(r"(02\d+)", filename)  # 02010002_... or 02020008_... or 02030002_...
    if match is None:
        # Report the offending file instead of an opaque
        # "'NoneType' object has no attribute 'group'".
        raise ValueError(f"Cannot find a subject ID (02...) in filename: {filename!r}")
    subj_id = match.group(1)
    label = 1 if subj_id.startswith("0201") else 0  # 0201...→depressed (1), others→not depressed(0)
    return subj_id, label

# ❸ Preprocessing

In [None]:
import numpy as np
from scipy import signal
import random

# --- Reproducibility ---------------------------------------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# --- Signal parameters (Hz unless noted) -------------------------------------
SAMPLING_RATE = 250                  # samples per second of the recordings
LOW_CUTOFF, HIGH_CUTOFF = 1.0, 45.0  # band-pass edges
NOTCH = 50.0                         # notch centre frequency (presumably mains hum)
TRIM_SEC = 30                        # seconds discarded from each end of a recording

# --- Sliding-window segmentation (seconds) -----------------------------------
WIN_SEC = 2.0
STEP_SEC = 0.5 * WIN_SEC             # hop of half a window -> 50% overlap

def butter_bandpass_filter(data, sampling_rate, low_cutoff, high_cutoff, filter_order=4):
    """Zero-phase Butterworth band-pass filter applied along the last axis.

    Args:
        data: Array whose last axis is time.
        sampling_rate: Sampling frequency in Hz.
        low_cutoff, high_cutoff: Pass-band edges in Hz; scipy requires
            0 < low_cutoff < high_cutoff < sampling_rate / 2.
        filter_order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        Filtered array with the same shape as ``data``.
    """
    # Second-order sections instead of (b, a): a band edge as low as 1 Hz at a
    # 250 Hz sampling rate (0.008 x Nyquist) puts poles very close to z = 1,
    # where the transfer-function form loses numerical precision.
    sos = signal.butter(filter_order, [low_cutoff, high_cutoff], btype='band',
                        fs=sampling_rate, output='sos')
    # Forward-backward filtering cancels the filter's phase shift.
    return signal.sosfiltfilt(sos, data, axis=-1)

def notch_filter(data, sampling_rate, notch, quality_factor=30.0):
    """Suppress a narrow band around `notch` Hz with a zero-phase IIR notch.

    `quality_factor` is the notch's Q (centre frequency / -3 dB bandwidth), so
    larger values give a narrower notch. Filtering runs along the last axis.
    """
    # Handing scipy the sampling rate lets it normalise the centre frequency
    # by Nyquist itself.
    numerator, denominator = signal.iirnotch(notch, quality_factor, fs=sampling_rate)
    return signal.filtfilt(numerator, denominator, data, axis=-1)

def preprocess_data(data, sampling_rate, low_cutoff=LOW_CUTOFF, high_cutoff=HIGH_CUTOFF, notch=NOTCH, trim_sec=TRIM_SEC):
    """Re-reference, filter and edge-trim one recording.

    `data` is treated as (channels, samples): axis 0 is averaged across and
    axis 1 is trimmed.

    Steps:
      1. subtract, at every time point, the mean across channels;
      2. Butterworth band-pass between `low_cutoff` and `high_cutoff` Hz;
      3. notch filter at `notch` Hz;
      4. drop `trim_sec` seconds from both ends.

    When the recording is longer than `trim_sec` seconds but shorter than
    twice that, a single sample is kept; when it is not longer than
    `trim_sec` seconds the returned time axis is empty.
    """
    channel_mean = data.mean(axis=0, keepdims=True)
    filtered = butter_bandpass_filter(data - channel_mean, sampling_rate, low_cutoff, high_cutoff)
    filtered = notch_filter(filtered, sampling_rate, notch)

    trim_samples = int(trim_sec * sampling_rate)
    first = trim_samples
    # Never let the end index fall to or below the start index.
    last = max(filtered.shape[1] - trim_samples, first + 1)
    return filtered[:, first:last]

def slide_wins(preprocessed_data, sampling_rate, win_sec=WIN_SEC, step_sec=STEP_SEC):
    """Cut a (channels, samples) array into fixed-length, possibly overlapping windows.

    Window and hop lengths are `win_sec` / `step_sec` seconds, truncated to
    whole samples. Only windows that fit entirely inside the signal are kept.

    Returns:
        wins_array: (n_windows, channels, win_samples) array; an empty
            (0, channels, win_samples) array when no full window fits.
        win_idxs: list of (start, end) sample indices, one per window.
    """
    win_len = int(win_sec * sampling_rate)
    hop = int(step_sec * sampling_rate)
    n_samples = preprocessed_data.shape[1]

    win_idxs = [(start, start + win_len) for start in range(0, n_samples - win_len + 1, hop)]
    if not win_idxs:
        return np.empty((0, preprocessed_data.shape[0], win_len)), win_idxs

    wins_array = np.stack([preprocessed_data[:, start:end] for start, end in win_idxs], axis=0)
    return wins_array, win_idxs

def discard_noisy_wins(wins_array, z_score_threshold=7.0):
    """Drop windows that contain an extreme value relative to their own statistics.

    Each window (axis 0) is z-scored with a mean and standard deviation pooled
    over all of its channels and samples (axes 1 and 2). A window survives only
    if every |z| is strictly below `z_score_threshold`. Empty input is returned
    unchanged.
    """
    if not len(wins_array):
        return wins_array
    pooled_axes = (1, 2)
    centred = wins_array - wins_array.mean(axis=pooled_axes, keepdims=True)
    scale = wins_array.std(axis=pooled_axes, keepdims=True) + 1e-6  # avoids /0 on flat windows
    peak_abs_z = np.abs(centred / scale).max(axis=pooled_axes)
    return wins_array[peak_abs_z < z_score_threshold]

# ❹ Feature Extraction (Welch's Method)

In [None]:
FREQ_BANDS = [
    (1, 4),    # Delta (δ)
    (4, 8),    # Theta (θ)
    (8, 13),   # Alpha (α)
    (13, 30),  # Beta (β)
    (30, 45)   # Gamma (γ)
]

def welch(wins_array, sampling_rate=SAMPLING_RATE, freq_bands=FREQ_BANDS, samples_per_seg=256):
    """Relative band power per channel for every window, from Welch's PSD estimate.

    Args:
        wins_array: (n_windows, channels, samples) array, time on the last axis.
        sampling_rate: Sampling frequency in Hz.
        freq_bands: Sequence of (low, high) Hz pairs; each band is [low, high).
        samples_per_seg: Welch segment length (``nperseg``). scipy shrinks it
            to the window length, with a warning, if windows are shorter.

    Returns:
        float32 array of shape (n_windows, channels * len(freq_bands)), laid
        out channel-major (for each channel, one value per band). Each value
        is the band's trapezoid-integrated PSD divided by the PSD integrated
        over all Welch frequencies, i.e. a fraction, not a percentage.
        Empty input yields shape (0, channels * len(freq_bands)) when the
        input is 3-D, otherwise (0, 0).

    Raises:
        ValueError: if a non-empty ``wins_array`` is not 3-D.
    """
    # scipy.integrate.trapezoid replaces np.trapz, which NumPy 2.0 deprecated.
    from scipy.integrate import trapezoid

    if len(wins_array) == 0:
        # Match the feature width of non-empty output so results concatenate.
        n_channels = np.shape(wins_array)[1] if np.ndim(wins_array) == 3 else 0
        return np.empty((0, n_channels * len(freq_bands)), dtype=np.float32)
    if wins_array.ndim != 3:
        raise ValueError(f"expected (n_windows, channels, samples), got shape {wins_array.shape}")

    win_cnts = wins_array.shape[0]
    # One batched call over all windows and channels; psds has shape
    # (n_windows, channels, n_freqs).
    freqs, psds = signal.welch(wins_array, fs=sampling_rate, nperseg=samples_per_seg, axis=-1)
    total_power = trapezoid(psds, freqs, axis=-1) + 1e-12  # guards against /0 on flat windows

    band_fractions = []
    for low_freq, high_freq in freq_bands:
        in_band = (freqs >= low_freq) & (freqs < high_freq)
        # NOTE(review): integrating only the in-band bins drops the strip between
        # the last bin of one band and the first bin of the next (bins are
        # sampling_rate / samples_per_seg Hz apart); confirm this is intended.
        band_power = trapezoid(psds[..., in_band], freqs[in_band], axis=-1)
        band_fractions.append(band_power / total_power)

    features_array = np.stack(band_fractions, axis=-1)  # (n_windows, channels, n_bands)
    return features_array.reshape(win_cnts, -1).astype(np.float32)

# ❺ Load Target Data from .mat Files

In [None]:
from scipy.io import loadmat
import numpy as np

def load_mat(path, verbose=True):
    """Load the first data variable of a MATLAB .mat file as a float32 array.

    The first variable whose name does not start with "__" (skipping
    loadmat's metadata entries) is taken as the recording. If it has 129
    rows, the last row is dropped so 128 remain (presumably a reference
    channel; confirm against the dataset documentation).

    Args:
        path: Path to the .mat file.
        verbose: Currently unused; kept for backward compatibility.

    Returns:
        2-D float32 array, presumably (channels, samples).

    Raises:
        ValueError: if the file holds no data variable, or the variable is not 2-D.
    """
    mat = loadmat(path, squeeze_me=True, struct_as_record=False)
    data_keys = [key for key in mat if not key.startswith("__")]
    if not data_keys:
        raise ValueError(f"{path}: no data variables found in .mat file")
    target_key = data_keys[0]
    target_data = np.asarray(mat[target_key], dtype=np.float32)
    if target_data.ndim != 2:
        # Fail here instead of with an IndexError deep inside preprocessing.
        raise ValueError(f"{path}: variable {target_key!r} has shape {target_data.shape}; expected a 2-D array")

    if target_data.shape[0] == 129:
        target_data = target_data[:128, :]

    return target_data

# ❻ Build Dataset for ML

In [None]:
!pip -q uninstall -y tensorflow-decision-forests tensorflow-text tf-keras > /dev/null 2>&1 || true
!pip -q install "tensorflow==2.18.1" numpy scipy scikit-learn h5py > /dev/null 2>&1

In [None]:
import numpy as np

def build_dataset(files):
    """Turn EEG .mat recordings into a window-level feature matrix.

    Each file is loaded, preprocessed, cut into sliding windows, stripped of
    windows with extreme z-scores, and converted to relative band-power
    features. Every window inherits its file's subject ID and label.

    Args:
        files: Iterable of .mat paths; processed in sorted order.

    Returns:
        X_all: (n_windows, n_features) feature matrix.
        y_all: (n_windows,) int64 labels.
        subj_all: (n_windows,) object array of subject-ID strings.
        skipped: list of (path, reason) for files that yielded no windows.

    Raises:
        ValueError: if no file yielded any window.
    """
    X_all, y_all, subj_all = [], [], []
    skipped = []
    for file in sorted(files):
        subj_id, y = get_subj_id_and_label(file)
        data = load_mat(file)
        processed_data = preprocess_data(data, sampling_rate=SAMPLING_RATE)
        wins_array, _ = slide_wins(processed_data, sampling_rate=SAMPLING_RATE)
        wins_array = discard_noisy_wins(wins_array, z_score_threshold=7.0)
        if len(wins_array) == 0:
            skipped.append((file, "no_windows_after_noise_removal"))
            continue
        features_array = welch(wins_array, sampling_rate=SAMPLING_RATE)
        n_wins = features_array.shape[0]
        X_all.append(features_array)
        y_all.append(np.full((n_wins,), y, dtype=np.int64))  # one label per window, e.g. [1, 1, ..., 1]
        subj_all.append(np.full((n_wins,), subj_id, dtype=object))  # one subject ID per window
    if not X_all:
        # np.concatenate([]) would otherwise fail with "need at least one array".
        raise ValueError(f"No file produced any usable window; skipped: {skipped}")
    return np.concatenate(X_all, 0), np.concatenate(y_all, 0), np.concatenate(subj_all, 0), skipped

X_all, y_all, subj_all, skipped = build_dataset(mat_files)
assert len(X_all) == len(y_all) == len(subj_all)  # one label and one subject ID per window
# Note: the depression proportion below is over windows, not subjects.
print("Windows:", X_all.shape, "Proportion of depression data:", y_all.mean(), " Unique subjects:", len(set(subj_all)))
print(f"Skipped files: {len(skipped)}")
# Name the dropped recordings, not just the count, so excluded subjects are visible.
for skipped_path, reason in skipped:
    print(f"  - {os.path.basename(skipped_path)}: {reason}")

# ❼ Split Data by Subject

In [None]:
from sklearn.model_selection import StratifiedGroupKFold

# Split by subject so no subject's windows land in more than one set: windows
# from one recording overlap and share subject-specific traits. GroupShuffleSplit
# ignores the label, and with only ~50 subjects it can leave validation or test
# badly imbalanced (val_auc drives early stopping). StratifiedGroupKFold keeps
# subjects grouped while balancing labels; one of 5 folds ≈ 20% of the data,
# matching the former test_size=0.2.
trval_test = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=SEED)
trval_idxs, test_idxs = next(trval_test.split(X_all, y_all, groups=subj_all))
X_trval, y_trval, subj_trval = X_all[trval_idxs], y_all[trval_idxs], subj_all[trval_idxs]
X_test, y_test, subj_test = X_all[test_idxs], y_all[test_idxs], subj_all[test_idxs]

tr_val = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=SEED)
tr_idxs, val_idxs = next(tr_val.split(X_trval, y_trval, groups=subj_trval))

X_tr, y_tr, subj_tr = X_trval[tr_idxs], y_trval[tr_idxs], subj_trval[tr_idxs]
X_val, y_val, subj_val = X_trval[val_idxs], y_trval[val_idxs], subj_trval[val_idxs]

# Invariants: every subject belongs to exactly one of train / val / test.
assert not set(subj_tr) & set(subj_val)
assert not (set(subj_tr) | set(subj_val)) & set(subj_test)

print(f"Train: {X_tr.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

# ❽ Standardize Features

In [None]:
from sklearn.preprocessing import StandardScaler

# Learn per-feature mean/std from the training split only, then apply that same
# transform to validation and test so no statistics leak from held-out subjects.
# NOTE(review): this cell overwrites X_tr/X_val/X_test; re-running it re-scales
# already-scaled data.
scaler = StandardScaler(with_mean=True, with_std=True)
scaler.fit(X_tr)
X_tr, X_val, X_test = (scaler.transform(split) for split in (X_tr, X_val, X_test))

# ❾ Build MLP

In [None]:
import tensorflow as tf
import numpy as np
from sklearn.utils import class_weight

tf.random.set_seed(SEED)

def build_mlp(input_dim):
    """Build and compile a small MLP for binary classification of feature vectors.

    Architecture: Dense(256, relu) -> Dropout(0.5) -> Dense(128, relu) ->
    Dropout(0.25) -> Dense(1, sigmoid). Compiled with Adam (learning rate 1e-3),
    binary cross-entropy loss, and accuracy plus AUC (logged as "auc") metrics.
    """
    inputs = tf.keras.Input(shape=(input_dim,), name="input_features")
    hidden = inputs
    for units, drop_rate in ((256, 0.5), (128, 0.25)):
        hidden = tf.keras.layers.Dense(units, activation='relu')(hidden)
        hidden = tf.keras.layers.Dropout(drop_rate)(hidden)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
    )
    return model

model = build_mlp(X_tr.shape[1])
model.summary()

callbacks = [
    # Stop when validation AUC stops improving and roll back to the best epoch.
    tf.keras.callbacks.EarlyStopping(monitor='val_auc', mode='max', patience=8, restore_best_weights=True),
    # Halve the learning rate after 4 epochs without improvement, down to 1e-5.
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_auc', mode='max', patience=4, factor=0.5, min_lr=1e-5),
]

# Weight classes inversely to their frequency among the training windows instead
# of a hard-coded {0: 1.0, 1: 3.0}. Stored as `class_weights` so it no longer
# shadows the `sklearn.utils.class_weight` module imported in the cell above.
present_classes = np.unique(y_tr)
balanced = class_weight.compute_class_weight(class_weight='balanced', classes=present_classes, y=y_tr)
# Keras expects a plain {int class index: float weight} dict.
class_weights = {int(c): float(w) for c, w in zip(present_classes, balanced)}

model.fit(X_tr, y_tr, validation_data=(X_val, y_val), epochs=60, batch_size=256, callbacks=callbacks, class_weight=class_weights, verbose=1)

# ➓ Subject-level Accuracy

In [None]:
import numpy as np

def acc_by_subj(model, Xs, ys, subj_ids):
    """Subject-level accuracy from window-level predictions.

    Window probabilities from `model.predict` are averaged per subject. A
    subject is predicted positive (1) when that mean is >= 0.5. Each
    subject's true label is the one on its last window in `ys`. Returns the
    fraction of subjects whose prediction matches their label.
    """
    win_probs = model.predict(Xs, batch_size=1024, verbose=0).ravel()

    probs_per_subj = {}
    true_label = {}
    for subj_id, prob, label in zip(subj_ids, win_probs, ys):
        probs_per_subj.setdefault(subj_id, []).append(prob)
        true_label[subj_id] = label

    ordered_subjs = sorted(probs_per_subj)
    labels = np.array([true_label[s] for s in ordered_subjs])
    preds = np.array([1 if np.mean(probs_per_subj[s]) >= 0.5 else 0 for s in ordered_subjs])
    return (labels == preds).mean()


# Subject-level accuracy: each held-out subject counts once, however many windows
# it contributed. subj_trval[val_idxs] are the validation windows' subject IDs,
# aligned row-for-row with X_val / y_val.
# NOTE(review): the validation set also drove early stopping and LR scheduling,
# so Val Acc is optimistic; Test Acc is the unbiased estimate.
val_acc  = acc_by_subj(model, X_val, y_val, subj_trval[val_idxs])
test_acc = acc_by_subj(model, X_test, y_test, subj_test)
print(f"Subject-level Val Acc: {val_acc:.3f} | Test Acc: {test_acc:.3f}")

# Save Model

In [None]:
import joblib

# NOTE(review): /content is wiped when the Colab runtime is recycled; point these
# paths at /content/drive/... to keep the artifacts.

# Keras' native format, reloadable with tf.keras.models.load_model().
model.save('/content/eeg_depression_model.keras')

# The model was trained on standardized features, so inference on new
# recordings must apply this exact fitted scaler first. Persist it with the model.
scaler_save_path = '/content/eeg_depression_scaler.pkl'
joblib.dump(scaler, scaler_save_path)

# Legacy pickle of the Keras model, kept for existing consumers of this path.
# Only unpickle files from trusted sources (pickle can execute arbitrary code).
model_save_path = '/content/eeg_depression_model.pkl'
joblib.dump(model, model_save_path)