In [1]:
# -*- coding: utf-8 -*-
# -*- authors : Vincent Roduit -*-
# -*- date : 2025-04-24 -*-
# -*- Last revision: 2025-05-02 by janzgraggen -*-
# -*- python version : 3.10.4 -*-
# -*- Description: Notebook that summarizes the main results -*-

# <center> EE-452: Network Machine Learning </center>
## <center> Ecole Polytechnique Fédérale de Lausanne </center>
### <center>Graph-based EEG Analysis </center>
---

In [18]:
#import libraries
import pandas as pd
from pathlib import Path
import sys

from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
from seiz_eeg.dataset import EEGDataset

#import modules
import constants
from transform_func import *
from dataloader import load_data
from utils import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
# Location of the data on disk; the training split lives in DATA_ROOT / "train".
data_path = "../data"
DATA_ROOT = Path(data_path)
train_dir = DATA_ROOT / "train"

# Segment table for the training split.
clips_tr = pd.read_parquet(train_dir / "segments.parquet")

# Wrap the segments in an EEGDataset; signal_transform can be swapped out or left as None.
dataset_tr = EEGDataset(
    clips_tr,
    signals_root=train_dir,
    signal_transform=None,
    prefetch=True,  # If your compute does not allow it, you can use `prefetch=False`
)

In [43]:
import numpy as np
import pywt
from torch_geometric.data import Data
import torch

# ---- Feature-extraction configuration ----
FS: int = 250          # sampling frequency (Hz)
WIN_SEC: float = 2.0   # window length in seconds -> int(WIN_SEC * FS) samples per window
OVERLAP: float = 0.5   # fraction of overlap between consecutive windows (0.5 = 50%)
WAVELET: str = 'db4'   # default `wavelet` arg of features_one_channel (NOTE(review): unused there; its CWT is hard-wired to 'morl')
DWT_LEVEL: int = 5     # default `level` arg of features_one_channel (NOTE(review): unused; no DWT is computed)
EPS: float = 1e-6      # small constant against division by zero (NOTE(review): not referenced in this cell)
def safe_corrcoef(data):
    """Pearson correlation matrix between the rows of a 2-D array.

    Unlike ``np.corrcoef``, rows with zero standard deviation do not produce
    NaN. Their standard deviation is replaced by 1e-10, so their normalized
    values (and hence their whole row/column of the result, diagonal included)
    become 0.

    Args:
        data: 2-D array, one variable per row, observations along axis 1.

    Returns:
        (n_rows, n_rows) array of correlations, computed with population
        (ddof=0) statistics.
    """
    spread = np.std(data, axis=1, keepdims=True)
    spread = np.where(spread == 0, 1e-10, spread)  # flat rows: avoid 0/0
    z = (data - data.mean(axis=1, keepdims=True)) / spread
    return z.dot(z.T) / data.shape[1]

# 1. Sliding windows
def sliding_windows(arr, win_len, step):
    """Lazily yield fixed-length windows along the last axis of a 2-D array.

    Windows start at 0, step, 2*step, ... and only complete windows are
    produced; trailing samples that do not fill a whole window are dropped.
    Each yielded item is a view ``arr[:, start:start + win_len]``.

    Args:
        arr: 2-D array of shape (channels, samples).
        win_len: window length in samples.
        step: hop between consecutive window starts, in samples.
    """
    _, n_samples = arr.shape
    starts = range(0, n_samples - win_len + 1, step)
    yield from (arr[:, s:s + win_len] for s in starts)

# 2. Band-wise statistics
def features_one_channel(sig, fs=FS, wavelet=WAVELET, level=DWT_LEVEL):
    """Hand-crafted feature vector for a single 1-D channel.

    Concatenates, in this order:
      * 5 time-domain statistics: mean, mean absolute value, standard deviation,
        peak-to-peak range, and zero-crossing rate (strict sign changes / len(sig));
      * the first 128 bins of the periodogram |FFT|^2 / len(sig);
      * the mean periodogram power in 5 bands (Hz):
        [0.5, 4), [4, 7), [7, 15), [15, 31), [31, fs // 2);
      * the Morlet CWT energy (sum of squared coefficients) for each scale 2..128 (127 values);
      * the mean CWT energy over the scales whose pseudo-frequency lies in each band.

    A band that contains no frequency bin / scale contributes 0 in both the PSD and
    the CWT band features (previously the PSD branch yielded NaN with a warning).

    For a 500-sample window this is 5 + 128 + 5 + 127 + 5 = 270 values. Signals
    shorter than 254 samples have fewer than 128 rFFT bins, so the length is only
    fixed for len(sig) >= 254.

    Args:
        sig: 1-D signal.
        fs: sampling frequency in Hz, used for the FFT bin frequencies and the
            CWT scale -> pseudo-frequency mapping.
        wavelet: accepted for API compatibility but NOT used; the CWT always
            uses the 'morl' wavelet.
        level: accepted for API compatibility but NOT used; no DWT is computed.

    Returns:
        1-D numpy array of features.
    """
    n = len(sig)
    band_limits = [(0.5, 4), (4, 7), (7, 15), (15, 31), (31, fs // 2)]

    def _band_mean(values, freqs, fmin, fmax):
        # Mean of `values` whose frequency is in [fmin, fmax); 0 for an empty band
        # (avoids np.mean([]) -> NaN + RuntimeWarning).
        mask = (freqs >= fmin) & (freqs < fmax)
        return np.mean(values[mask]) if np.any(mask) else 0

    # --- Time domain ---
    F_time = [
        np.mean(sig),
        np.mean(np.abs(sig)),
        np.std(sig),
        np.ptp(sig),
        np.sum(sig[:-1] * sig[1:] < 0) / n,  # zero-crossing rate (strict sign changes)
    ]

    # --- Periodogram, truncated to the first 128 bins ---
    # (with 500 samples at 250 Hz the bins are 0.5 Hz apart, so this keeps 0-63.5 Hz)
    psd = ((np.abs(np.fft.rfft(sig)) ** 2) / n)[:128]
    freqs = np.fft.rfftfreq(n, d=1 / fs)[:128]
    F_psdband = [_band_mean(psd, freqs, fmin, fmax) for fmin, fmax in band_limits]

    # --- Continuous wavelet transform (Morlet) ---
    scales = np.arange(2, 129)
    cwtmat, _ = pywt.cwt(sig, scales, 'morl')
    energy = np.sum(np.abs(cwtmat) ** 2, axis=1)  # one energy value per scale
    # Pseudo-frequency of each scale: fc * fs / scale, with fc = 0.8125 the
    # centre frequency of pywt's 'morl' wavelet.
    cwt_freqs = fs * 0.8125 / scales
    F_cwtband = [_band_mean(energy, cwt_freqs, fmin, fmax) for fmin, fmax in band_limits]

    return np.concatenate([F_time, psd, F_psdband, energy, F_cwtband])

# 3. Window features
def features_one_window(win):
    """Per-channel feature matrix for one window.

    Applies ``features_one_channel`` (with its default arguments) to every row
    of ``win`` and stacks the resulting vectors row-wise, giving an array of
    shape (n_channels, n_features); (channels, 270) for 500-sample windows.
    """
    per_channel = list(map(features_one_channel, win))
    return np.stack(per_channel)

# 4. Pearson edge construction (from full signal or window)
def compute_edge_index(data, threshold=0.8):
    """Build a thresholded Pearson-correlation graph over the rows (channels) of `data`.

    A directed edge (i, j) with i != j is kept when |corr[i, j]| >= threshold,
    where `corr` is computed by `safe_corrcoef`. A self-loop (i, i) is then
    appended for every channel, regardless of the threshold.

    Args:
        data: 2-D array whose rows are channels (used on the full signal in
            `extract_graph_sequence`, but a single window works as well).
        threshold: inclusive lower bound on the absolute correlation.

    Returns:
        torch.long tensor of shape [2, num_edges] (PyG COO layout). It holds the
        thresholded edges in row-major (i, then j) order, followed by the C
        self-loops. With zero channels the result is an empty [2, 0] tensor.
    """
    C = data.shape[0]
    corr = safe_corrcoef(data)

    # Exclude the diagonal here; self-loops are appended explicitly below.
    keep = (np.abs(corr) >= threshold) & ~np.eye(C, dtype=bool)
    # argwhere scans in row-major order, matching the (i, j) order of a nested loop.
    corr_edges = np.argwhere(keep)                             # (E, 2)
    nodes = np.arange(C)
    self_loops = np.stack([nodes, nodes], axis=1)              # (C, 2)

    edges = np.concatenate([corr_edges, self_loops], axis=0)   # (E + C, 2)
    # Building from an explicit (N, 2) array keeps the shape [2, 0] for an empty graph.
    return torch.tensor(np.ascontiguousarray(edges.T), dtype=torch.long)

# 5. Full pipeline: (19, 3000) → list of torch_geometric Data objects (11 windows)
def extract_graph_sequence(data):
    """Turn a multichannel recording into one PyG graph per sliding window.

    Window length is int(WIN_SEC * FS) samples and the hop is
    int(window_len * (1 - OVERLAP)). Under the default configuration a
    (19, 3000) input therefore yields 11 windows.

    The edge set is computed once from the whole recording (via
    ``compute_edge_index``). The *same* edge_index tensor object is shared by
    every returned Data object. Node features (``features_one_window``) are
    recomputed for each window and stored as float32.

    Args:
        data: 2-D array of shape (channels, samples).

    Returns:
        list of ``torch_geometric.data.Data``, one per window, in temporal order.
    """
    window_len = int(WIN_SEC * FS)
    hop = int(window_len * (1 - OVERLAP))

    shared_edges = compute_edge_index(data)

    return [
        Data(
            x=torch.tensor(features_one_window(segment), dtype=torch.float),
            edge_index=shared_edges,
        )
        for segment in sliding_windows(data, window_len, hop)
    ]

In [46]:
# First training clip. The `.T` puts channels on the rows, as extract_graph_sequence
# expects (NOTE(review): assumes the dataset yields (samples, channels) signals; confirm).
data = dataset_tr[0][0].T

graph_sequence = extract_graph_sequence(data)
first_x = graph_sequence[0].x
print(f"{len(graph_sequence)} timesteps, shape of x: {first_x.shape}")

11 timesteps, shape of x: torch.Size([19, 270])


In [47]:
# Display the per-window PyG graphs (one Data object per sliding window)
graph_sequence

[Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49]),
 Data(x=[19, 270], edge_index=[2, 49])]