<a href="https://colab.research.google.com/github/takeisika/group-89-eeg-depression/blob/main/eeg_pred_model_dev_share.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only step: mount Google Drive at /content/drive so that later cells can
# read the dataset archive stored under MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/EEG_128channels_resting_lanzhou_2015.zip -d /content/data

Archive:  /content/drive/MyDrive/EEG_128channels_resting_lanzhou_2015.zip
   creating: /content/data/EEG_128channels_resting_lanzhou_2015/
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010013rest 20150703 1333..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010012rest 20150626 1026..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02020022rest 20150707 1452..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/Multivariate Pattern Analysis of EEG-Based Functional Connectivity A Study on the Identification of Depression.pdf  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010019rest 20150716 1440..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02030020_rest 20151230 1416.mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010005rest 20150507 0907..mat  
  inflating: /content/data/EEG_128channels_resting_lanzhou_2015/02010022restnew 20150724 14.

# ① Get all .mat files

In [None]:
import os
import glob
# Root directory the dataset archive was extracted into.
FOLDER = "/content/data"
# Recursively collect every .mat file below FOLDER. glob returns paths in an
# arbitrary, filesystem-dependent order, so sort them to keep any later per-file
# processing order identical from run to run.
mat_files = sorted(glob.glob(os.path.join(FOLDER, "**/*.mat"), recursive=True))
print(f"Found {len(mat_files)} .mat files.")

Found 53 .mat files.


# ② Get subj_id & its label

In [None]:
import os
import re
def get_subj_id_and_label(path):
    """Derive the subject ID and the binary class label from a recording's file name.

    The subject ID is the first run of digits in the base name that starts with
    "02" (e.g. "02010002", "02020008", "02030002").

    Args:
        path: Path to a recording file; only its base name is inspected.

    Returns:
        Tuple ``(subj_id, label)`` where ``subj_id`` is the matched digit string
        and ``label`` is 1 (depressed) when the ID starts with "0201", else 0
        (not depressed).

    Raises:
        ValueError: If the base name contains no "02..." digit run.
    """
    filename = os.path.basename(path)
    match = re.search(r"(02\d+)", filename)  # 02010002_... or 02020008_... or 02030002_...
    if match is None:
        # Fail with a message that names the offending file instead of the opaque
        # "'NoneType' object has no attribute 'group'".
        raise ValueError(f"Cannot find a subject ID (02...) in file name: {filename!r}")
    subj_id = match.group(1)
    label = 1 if subj_id.startswith("0201") else 0  # 0201... -> depressed (1), others -> not depressed (0)
    return subj_id, label

# ③ Preprocessing

In [None]:
import numpy as np
from scipy import signal
import random

# Fixed seed applied to NumPy's global RNG and to Python's `random` module.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Sampling rate; used as the default `sampling_rate` of welch_bandpowers.
# Presumably in Hz and matching the recordings -- TODO confirm against the dataset.
SAMPLING_RATE = 250
# Defaults for preprocess_data: band-pass edges, notch frequency (same frequency
# unit as the sampling rate) and the number of seconds cut from each end.
LOW_CUTOFF = 1.0
HIGH_CUTOFF = 45.0
NOTCH = 50.0
TRIM_SEC = 30

# Defaults for slide_wins: window length in seconds, and the step between window
# starts, which is half a window.
WIN_SEC = 2.0
STEP_SEC = WIN_SEC / 2

def butter_bandpass_filter(data, sampling_rate, low_cutoff, high_cutoff, filter_order=4):
    """Band-pass ``data`` with a Butterworth filter run forward and backward.

    Args:
        data: Array filtered along its last axis.
        sampling_rate: Sampling rate of ``data``.
        low_cutoff: Lower band edge, in the same unit as ``sampling_rate``.
        high_cutoff: Upper band edge, in the same unit as ``sampling_rate``.
        filter_order: Order handed to ``signal.butter``.

    Returns:
        The filtered array as returned by ``signal.filtfilt``.
    """
    half_rate = 0.5 * sampling_rate
    # signal.butter expects the band edges as fractions of the Nyquist frequency.
    band_edges = [low_cutoff / half_rate, high_cutoff / half_rate]
    numerator, denominator = signal.butter(filter_order, band_edges, btype='band')
    return signal.filtfilt(numerator, denominator, data, axis=-1)

def notch_filter(data, sampling_rate, notch, quality_factor=30.0):
    """Suppress a single frequency with an IIR notch run forward and backward.

    Args:
        data: Array filtered along its last axis.
        sampling_rate: Sampling rate of ``data``.
        notch: Frequency to remove, in the same unit as ``sampling_rate``.
        quality_factor: ``Q`` handed to ``signal.iirnotch``.

    Returns:
        The filtered array as returned by ``signal.filtfilt``.
    """
    half_rate = 0.5 * sampling_rate
    # iirnotch takes the notch position as a fraction of the Nyquist frequency.
    numerator, denominator = signal.iirnotch(w0=notch / half_rate, Q=quality_factor)
    return signal.filtfilt(numerator, denominator, data, axis=-1)

def preprocess_data(data, sampling_rate, low_cutoff=LOW_CUTOFF, high_cutoff=HIGH_CUTOFF, notch=NOTCH, trim_sec=TRIM_SEC):
    """Re-centre, band-pass, notch-filter and trim one recording.

    Steps, in order: subtract the mean taken along axis 0, apply
    ``butter_bandpass_filter``, apply ``notch_filter``, then cut ``trim_sec``
    seconds from both ends along axis 1.

    Args:
        data: 2-D array; filtering and trimming run along axis 1.
            Presumably laid out as (channels, samples) -- TODO confirm.
        sampling_rate: Sampling rate of ``data``.
        low_cutoff: Lower band-pass edge.
        high_cutoff: Upper band-pass edge.
        notch: Frequency removed by the notch filter.
        trim_sec: Seconds discarded at the start and at the end.

    Returns:
        The filtered array restricted to the kept columns. The slice end is
        forced to be at least one past the slice start, so a recording that is
        too short for the requested trim yields at most one column.
    """
    # Subtract, at every column, the mean over axis 0.
    centered = data - data.mean(axis=0, keepdims=True)
    filtered = notch_filter(
        butter_bandpass_filter(centered, sampling_rate, low_cutoff, high_cutoff),
        sampling_rate,
        notch,
    )
    trim_cnts = int(trim_sec * sampling_rate)
    first_kept = trim_cnts
    # Never let the slice end fall at or before its start.
    stop_excl = max(filtered.shape[1] - trim_cnts, first_kept + 1)
    return filtered[:, first_kept:stop_excl]

def slide_wins(preprocessed_data, sampling_rate, win_sec=WIN_SEC, step_sec=STEP_SEC):
    """Cut a 2-D array into fixed-length windows along axis 1.

    Args:
        preprocessed_data: 2-D array that is sliced along axis 1.
        sampling_rate: Samples per second, used to convert seconds to samples.
        win_sec: Window length in seconds.
        step_sec: Distance between consecutive window starts, in seconds.

    Returns:
        Tuple ``(wins_array, win_idxs)``. ``wins_array`` stacks the windows on a
        new leading axis; when no full window fits it is an empty array of shape
        ``(0, preprocessed_data.shape[0], window_length)``. ``win_idxs`` lists the
        ``(start, end)`` sample indices of each window, ``end`` exclusive.
    """
    win_len = int(win_sec * sampling_rate)
    hop_len = int(step_sec * sampling_rate)
    total_len = preprocessed_data.shape[1]

    # Only windows that fit completely inside the data are produced.
    win_idxs = [(begin, begin + win_len) for begin in range(0, total_len - win_len + 1, hop_len)]
    if not win_idxs:
        return np.empty((0, preprocessed_data.shape[0], win_len)), win_idxs

    wins_array = np.stack([preprocessed_data[:, begin:stop] for begin, stop in win_idxs], axis=0)
    return wins_array, win_idxs

def discard_noisy_wins(wins_array, z_score_threshold=7.0):
    """Drop every window that contains an extreme value.

    Each window is z-scored with its own mean and standard deviation, both taken
    over axes 1 and 2 together. A window is kept only when its largest absolute
    z-score is below ``z_score_threshold``.

    Args:
        wins_array: 3-D array whose first axis indexes windows.
        z_score_threshold: Exclusive upper bound on the per-window peak |z|.

    Returns:
        The kept windows, in their original order. An empty input is returned
        as is.
    """
    if len(wins_array) == 0:
        return wins_array

    per_win_axes = (1, 2)
    centre = wins_array.mean(axis=per_win_axes, keepdims=True)
    # The small constant avoids a division by zero for a constant window.
    spread = wins_array.std(axis=per_win_axes, keepdims=True) + 1e-6
    peak_abs_z = np.abs((wins_array - centre) / spread).max(axis=per_win_axes)
    keep_mask = peak_abs_z < z_score_threshold
    return wins_array[keep_mask]

# ④ Feature Extraction (Welch's Method)

In [None]:
# Frequency bands as (low, high) pairs. welch_bandpowers keeps, for each pair, the
# PSD bins with low <= f < high. The values share the unit of the sampling rate
# handed to Welch (presumably Hz -- TODO confirm).
FREQ_BANDS = [
    (1, 4),    # Delta (δ)
    (4, 8),    # Theta (θ)
    (8, 13),   # Alpha (α)
    (13, 30),  # Beta (β)
    (30, 45)   # Gamma (γ)
]

def welch_bandpowers(wins_array, sampling_rate=SAMPLING_RATE, freq_bands=FREQ_BANDS):
    """Compute relative band powers per window and channel with Welch's method.

    For every window the PSD of each row is estimated with ``signal.welch``
    (segment length 256) and integrated over each band; each band integral is
    divided by the integral over all returned frequencies.

    Args:
        wins_array: 3-D array indexed as (window, channel, sample).
        sampling_rate: Sampling rate passed to ``signal.welch`` as ``fs``.
        freq_bands: Iterable of ``(low, high)`` pairs; a band keeps the
            frequency bins with ``low <= f < high``.

    Returns:
        ``float32`` array of shape ``(n_windows, n_channels * n_bands)``. Within a
        row the values are grouped by channel, with the bands of one channel
        adjacent. An empty input returns an empty ``(0, 0)`` array.
    """
    if len(wins_array) == 0:
        return np.empty((0, 0), dtype=np.float32)

    # np.trapezoid only exists in NumPy >= 2.0; earlier releases expose the same
    # routine as np.trapz. Resolve it once so the function runs on both.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    # Unpacking also enforces that the input is 3-D.
    win_cnts, _ch_cnts, _sample_cnts = wins_array.shape
    samples_per_seg = 256
    features = []

    for idx in range(win_cnts):
        curr_win = wins_array[idx]
        freqs, psds = signal.welch(curr_win, fs=sampling_rate, nperseg=samples_per_seg, axis=-1)
        # The small constant keeps the division below finite for an all-zero PSD.
        integrated_psd_in_all_freqs = integrate(psds, freqs, axis=-1) + 1e-12

        curr_win_bands = []
        for (low_freq, high_freq) in freq_bands:
            is_in_this_freq_band_tf = (freqs >= low_freq) & (freqs < high_freq)
            integrated_psd_in_this_freq_band = integrate(
                psds[:, is_in_this_freq_band_tf], freqs[is_in_this_freq_band_tf], axis=-1
            )
            curr_win_bands.append(integrated_psd_in_this_freq_band / integrated_psd_in_all_freqs)

        # (channel, band) for this window.
        features.append(np.stack(curr_win_bands, axis=-1))

    # (window, channel, band) -> (window, channel * band)
    features_array = np.stack(features, axis=0)
    features_array = features_array.reshape(win_cnts, -1).astype(np.float32)

    return features_array

# ⑤ Load Target Data from .mat Files

In [None]:
from scipy.io import loadmat
import numpy as np

def load_mat(path, verbose=True):
    """Load the first data variable of a .mat file as a ``float32`` array.

    Args:
        path: Path of the .mat file.
        verbose: Currently unused; kept so existing calls keep working.

    Returns:
        The first variable whose name does not start with "__", converted to
        ``float32``. When its first dimension has length 129, only the first 128
        rows are returned (presumably dropping an extra reference channel --
        TODO confirm).

    Raises:
        ValueError: If the file holds no variable besides the "__"-prefixed
            entries.
    """
    mat = loadmat(path, squeeze_me=True, struct_as_record=False)
    # Skip the "__"-prefixed entries; they are not recording data.
    data_keys = [keyname for keyname in mat.keys() if not keyname.startswith("__")]
    if not data_keys:
        # Name the offending file instead of failing with a bare IndexError.
        raise ValueError(f"No data variable found in .mat file: {path}")
    target_key = data_keys[0]
    target_data = np.asarray(mat[target_key], dtype=np.float32)

    if target_data.shape[0] == 129:
        target_data = target_data[:128, :]

    return target_data