In [None]:
# Standard library
import os
import random
from pathlib import Path

# Scientific Python
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pywt
from scipy import signal
from scipy.signal import resample, correlate
from scipy.signal import stft
from scipy.signal import welch
from sklearn.metrics import f1_score
from sklearn.metrics import mutual_info_score
from sklearn.model_selection import GroupShuffleSplit
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from tqdm import tqdm

# PyTorch / PyTorch Geometric
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch.utils.data import Dataset, Subset, random_split
from torch_geometric.data import Data
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, ChebConv, GATConv, SAGEConv
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import GINEConv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.nn.aggr import AttentionalAggregation
from torch_geometric.utils import add_self_loops
from torch_geometric.utils import dropout_edge
from torch_geometric.utils import to_undirected

# Project
from seiz_eeg.dataset import EEGDataset



### Preprocessing Methods

In [None]:
# Butterworth band-pass (design order N=4) passing 0.5-30 Hz for signals sampled
# at 250 Hz, kept as second-order sections for numerically stable filtering.
bp_filter = signal.butter(N=4, Wn=(0.5, 30), btype="bandpass", output="sos", fs=250)


def time_filtering(x: np.ndarray) -> np.ndarray:
    """Zero-phase (forward-backward) band-pass of ``x`` along axis 0 using ``bp_filter``.

    Returns a fresh C-ordered copy of the filtered array.
    """
    filtered = signal.sosfiltfilt(bp_filter, x, axis=0)
    return filtered.copy()


def fft_filtering(x: np.ndarray, fs: float = 250, f_low: float = 0.5, f_high: float = 30) -> np.ndarray:
    """Log-magnitude FFT along the time axis, keeping only the f_low..f_high band.

    Args:
        x: signal with time on axis 0 (any trailing axes, e.g. channels, are kept).
        fs: sampling rate in Hz (default 250, the previously hard-coded value).
        f_low, f_high: band edges in Hz; FFT bin ``k`` corresponds to ``k * fs / len``.

    Returns:
        ``log(|FFT(x)|)`` (magnitudes floored at 1e-8 to avoid log(0)) restricted to
        bins ``int(f_low * N // fs)`` up to (excluding) ``int(f_high * N // fs)``.
    """
    spectrum = np.abs(np.fft.fft(x, axis=0))
    spectrum = np.log(np.where(spectrum > 1e-8, spectrum, 1e-8))

    win_len = spectrum.shape[0]
    start = int(f_low * win_len // fs)
    stop = int(f_high * win_len // fs)
    return spectrum[start:stop]


def stft_transform(x: np.ndarray, fs=250, nperseg=128, noverlap=64) -> np.ndarray:
    '''Per-channel log1p of the STFT power spectrum, averaged over time frames.

    Each column of ``x`` is treated as one channel. Returns an array of shape
    (channels, n_frequencies).
    '''
    def _channel_log_power(channel):
        _, _, spectrum = signal.stft(channel, fs=fs, nperseg=nperseg, noverlap=noverlap)
        frame_power = np.abs(spectrum) ** 2
        return np.log1p(np.mean(frame_power, axis=1))  # average over time frames

    return np.stack([_channel_log_power(channel) for channel in x.T])

def wavelet_energy(x, wavelet='db4', level=4):
    '''Log-energy of every discrete-wavelet sub-band, per channel.

    Each column of ``x`` is decomposed with ``pywt.wavedec`` (``level`` detail
    bands plus one approximation band). Each band is summarised as
    log1p(sum of squared coefficients). Returns shape (channels, level + 1).
    '''
    def _band_log_energy(band):
        return np.log1p(np.sum(np.square(band)))

    per_channel = [
        [_band_log_energy(band) for band in pywt.wavedec(channel, wavelet, level=level)]
        for channel in x.T
    ]
    return np.stack(per_channel)

def bandpower(x, fs=250, bands=((0.5, 4), (4, 8), (8, 12), (12, 30)), nperseg=256):
    '''Log1p of the Welch power integrated over each frequency band, per channel.

    Args:
        x: signal with time on axis 0 and one column per channel.
        fs: sampling rate in Hz.
        bands: (low, high) Hz pairs, half-open [low, high). The default is
            delta, theta, alpha, beta.
        nperseg: Welch segment length.

    Returns:
        Array of shape (channels, len(bands)).
    '''
    # np.trapz is deprecated in NumPy 2.x; scipy's trapezoid is the same rule.
    from scipy.integrate import trapezoid

    result = []
    for ch in x.T:
        f, Pxx = welch(ch, fs=fs, nperseg=nperseg)
        row = []
        for low, high in bands:
            in_band = (f >= low) & (f < high)  # computed once per band
            row.append(np.log1p(trapezoid(Pxx[in_band], f[in_band])))
        result.append(row)
    return np.stack(result)

def combined_transform(x):
    """Band-pass ``x`` along time, then concatenate per-channel features.

    Feature groups, in order: wavelet log-energies, band powers, and STFT log
    power. Each group has one row per channel, so the result is
    (channels, total_features).
    """
    filtered = signal.sosfiltfilt(bp_filter, x, axis=0)

    feature_groups = [
        wavelet_energy(filtered),
        bandpower(filtered),
        stft_transform(filtered),
    ]
    # every group must describe the same set of channels
    assert len({group.shape[0] for group in feature_groups}) == 1

    return np.concatenate(feature_groups, axis=1)

def normalize_features(feat: np.ndarray, axis=0, eps=1e-8) -> np.ndarray:
    """Z-score ``feat`` along ``axis``: subtract the mean, divide by (std + eps)."""
    centred = feat - np.mean(feat, axis=axis, keepdims=True)
    scale = np.std(feat, axis=axis, keepdims=True) + eps
    return centred / scale

def normalized_combined_transform(x):
    """Like ``combined_transform``, but z-scores each feature group per channel.

    After band-passing along time, the wavelet, band-power and STFT features are
    each normalised along the feature axis (``axis=1``, i.e. within every
    channel's row) before concatenation. Result: (channels, total_features).
    """
    filtered = signal.sosfiltfilt(bp_filter, x, axis=0)

    raw_groups = (wavelet_energy(filtered), bandpower(filtered), stft_transform(filtered))
    scaled_groups = [normalize_features(group, axis=1) for group in raw_groups]

    return np.concatenate(scaled_groups, axis=1)


# Feature extractor applied to every clip by the EEGDataset objects built below.
# combined_transform returns an array with one row per channel.
chosen_transform = combined_transform

### Train Validation Split based on Patient ID

In [None]:
# Segment metadata for the training recordings.
clips_tr = pd.read_parquet(Path("train") / "segments.parquet")

def extract_patient_id(path):
    """Patient id = text before the first '_' in the file name (extension stripped)."""
    stem = path.rsplit("/", 1)[-1].split(".", 1)[0]
    patient, *_ = stem.split("_")
    return patient

# Patient-level split: all segments of a patient go to exactly one split, so
# validation measures generalisation to unseen patients.
clips_tr["patient_id"] = clips_tr["signals_path"].map(extract_patient_id)

gss = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_idx, val_idx = next(
    gss.split(clips_tr, clips_tr["label"], groups=clips_tr["patient_id"])
)

train_clips, val_clips = (
    clips_tr.iloc[rows].reset_index(drop=True) for rows in (train_idx, val_idx)
)

for caption, frame in (
    ("Patients total:", clips_tr),
    ("Train patients:", train_clips),
    ("Val   patients:", val_clips),
):
    print(caption, frame["patient_id"].nunique())

# prefetch=True is the heavier option; use prefetch=False if your compute does not allow it.
dataset_tr, dataset_val = (
    EEGDataset(
        clips,
        signals_root="train",
        signal_transform=chosen_transform,
        prefetch=True,
    )
    for clips in (train_clips, val_clips)
)

### Train Validation Split based on Segment

In [None]:
# Segment-level alternative to the patient split above: segments are split at
# random, ignoring patient identity. Segments of one patient can end up in both
# splits, so validation scores are optimistic compared with the patient split.
# Run EITHER this cell OR the patient-split cell, not both.
# (Previously this cell rebuilt dataset_tr from ALL clips without creating a new
# validation set, so the earlier dataset_val leaked into training.)
seg_train_idx, seg_val_idx = train_test_split(
    np.arange(len(clips_tr)),
    test_size=0.1,
    random_state=42,
    stratify=clips_tr["label"],
)

# You can change the signal_transform, or remove it completely
dataset_tr = EEGDataset(
    clips_tr.iloc[seg_train_idx].reset_index(drop=True),
    signals_root="train",
    signal_transform=chosen_transform,
    prefetch=False,  # If your compute does not allow it, you can use `prefetch=False`
)
dataset_val = EEGDataset(
    clips_tr.iloc[seg_val_idx].reset_index(drop=True),
    signals_root="train",
    signal_transform=chosen_transform,
    prefetch=False,
)

### Graph Construction Methods
a) Distance-based graph: electrode pairs closer than a distance threshold are connected, and each edge is weighted by a Gaussian kernel of the distance

In [None]:
from torch_geometric.data import Data
import pandas as pd
import numpy as np
import pandas as pd
import torch
import pandas as pd
import torch

# Distance-based graph (a): connect electrodes closer than `kappa` and weight
# each edge with a Gaussian kernel exp(-d^2 / sigma^2).
dist_df = pd.read_csv("distances_3d.csv")

electrodes = sorted(set(dist_df['from']) | set(dist_df['to']))
node2idx = {node: i for i, node in enumerate(electrodes)}

edges = []
weights = []

sigma = dist_df['distance'].std()  # kernel bandwidth: std of all listed distances (incl. any zero rows)
kappa = 0.9  # connect only pairs closer than this (same units as the CSV)

for f, t, d in zip(dist_df['from'], dist_df['to'], dist_df['distance']):
    if d < kappa and d != 0.0:
        u, v = node2idx[f], node2idx[t]
        weight = np.exp(-np.square(d) / (sigma ** 2))

        edges.append([u, v])
        edges.append([v, u])
        weights.append(weight)
        weights.append(weight)

# reshape(-1, 2) keeps a [2, 0] shape instead of 1-D when no pair passes kappa.
edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
edge_attr = torch.tensor(weights, dtype=torch.float)
# `edges` already holds both directions (and the CSV may list a pair twice).
# to_undirected's default reduce="add" would SUM the duplicates, inflating every
# weight to 2x (or 4x) exp(-d^2/sigma^2), off the scale of the self-loop weight
# 1.0 = exp(0) added below. Averaging keeps each edge at its kernel value.
edge_index, edge_attr = to_undirected(
    edge_index, edge_attr, num_nodes=len(electrodes), reduce="mean"
)
# num_nodes makes electrodes with no neighbour under kappa still get a self-loop.
edge_index, edge_attr = add_self_loops(edge_index, edge_attr=edge_attr,
                                       fill_value=1.0, num_nodes=len(electrodes))

print(f"Edge index shape: {edge_index.shape}")
print("First 5 edges:\n", edge_index[:, :5])
print("First 5 edge weights:\n", edge_attr[:5])

# Sanity-check one transformed sample before converting the whole dataset.
# Graph nodes are electrodes, so one axis of x_np must have len(electrodes) entries
# (combined_transform puts channels on axis 0).
x_np, y_np = dataset_tr[0]
print("Sample x_np shape:", x_np.shape)
print("Electrodes (graph nodes):", len(electrodes))
print("y shape:", torch.tensor([y_np], dtype=torch.float).shape)

def _sample_to_graph(x_np, y_np, n_nodes, edge_index, edge_attr):
    """Wrap one ``(features, label)`` sample as a PyG ``Data`` graph whose nodes are electrodes.

    Args:
        x_np: 2-D feature array for one clip, either (nodes, features), as returned
            by ``combined_transform``, or (features, nodes), e.g. the
            (frequency bins, channels) output of ``fft_filtering``. If both axes
            equal ``n_nodes``, the array is used as-is.
        y_np: scalar label, stored as a float tensor of shape [1].
        n_nodes: number of graph nodes (electrodes).
        edge_index, edge_attr: shared graph structure attached to every sample.

    Raises:
        ValueError: if ``x_np`` is not 2-D or has no axis of size ``n_nodes``.
    """
    features = np.asarray(x_np)
    if features.ndim != 2 or n_nodes not in features.shape:
        raise ValueError(
            f"expected a 2-D array with one axis of size {n_nodes}, got shape {features.shape}"
        )
    if features.shape[0] != n_nodes:
        # (features, nodes) layout: move electrodes onto the node axis.
        # (The old code always transposed, which turned combined_transform's
        # (19, F) output into F "nodes" that the 19-node edge set did not cover.)
        features = features.T
    x = torch.tensor(features, dtype=torch.float)
    y = torch.tensor([y_np], dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)


def _to_graph_list(eeg_dataset, n_nodes, edge_index, edge_attr):
    """Convert every ``(x_np, y_np)`` sample of ``eeg_dataset`` with ``_sample_to_graph``."""
    return [
        _sample_to_graph(x_np, y_np, n_nodes, edge_index, edge_attr)
        for x_np, y_np in tqdm(eeg_dataset, desc="Converting dataset")
    ]


# Edge features as [E, 1] (one scalar weight per edge), shared by all graphs.
# NOTE(review): node i is electrodes[i] (sorted names); confirm the channel order
# of x_np matches, and that the model's input width equals x.shape[1].
edge_attr_2d = edge_attr.unsqueeze(-1)
dataset_tr_dist = _to_graph_list(dataset_tr, len(electrodes), edge_index, edge_attr_2d)
dataset_val_dist = _to_graph_list(dataset_val, len(electrodes), edge_index, edge_attr_2d)

loader_tr = DataLoader(dataset_tr_dist, batch_size=2**8, shuffle=True)
loader_val = DataLoader(dataset_val_dist, batch_size=2**8, shuffle=False)

### Graph Construction Methods
b) Distance Based Graph, edge weights are computed by 1/d

In [None]:
# Distance-based graph (b): connect every electrode pair with a positive
# distance and weight the edge by the inverse distance 1/d.
dist_df = pd.read_csv("distances_3d.csv")

electrodes = sorted(set(dist_df['from']) | set(dist_df['to']))
node2idx = {name: position for position, name in enumerate(electrodes)}

edges = []
weights = []
for src_name, dst_name, dist in zip(dist_df['from'], dist_df['to'], dist_df['distance']):
    if not dist > 0:
        continue  # skip zero-distance (self) and non-positive rows
    a, b = node2idx[src_name], node2idx[dst_name]
    inv_dist = 1.0 / dist
    # store both directions so the graph is undirected
    edges.extend(([a, b], [b, a]))
    weights.extend((inv_dist, inv_dist))

edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(weights, dtype=torch.float)

print(f"Edge index shape: {edge_index.shape}")       # [2, 2 * number of positive-distance rows]
print("First 5 edges:\n", edge_index[:, :5])
print("First 5 edge weights:\n", edge_attr[:5])

### Graph Construction Methods
c) Distance Based Graph, based on k-nearest neighbor

In [None]:
# Distance-based graph (c): connect every electrode to its k nearest neighbours.
dist_df = pd.read_csv("distances_3d.csv")

electrodes = sorted(set(dist_df['from']) | set(dist_df['to']))
node2idx = {node: i for i, node in enumerate(electrodes)}

k = 5

# Drop self-pairs so a node (distance 0 to itself) is never its own nearest neighbour.
pair_df = dist_df[dist_df['from'] != dist_df['to']]

knn_pairs = set()
for node in pair_df['from'].unique():
    sub = pair_df[pair_df['from'] == node].sort_values(by='distance').head(k)
    for f, t in zip(sub['from'], sub['to']):
        u, v = node2idx[f], node2idx[t]
        knn_pairs.add((u, v))
        knn_pairs.add((v, u))  # undirected; the set drops duplicates from mutual neighbours

edges_knn = sorted(knn_pairs)
# Built from edges_knn; the previous code mistakenly used `edges` from an earlier cell.
edge_index_knn = torch.tensor(edges_knn, dtype=torch.long).reshape(-1, 2).t().contiguous()

### Graph Construction Methods
d) Functional graph based on mutual information between channels (not on electrode distances)

In [None]:
def compute_mi_matrix(X, bins=16):
    """
    Compute pairwise, normalised mutual information between EEG channels.

    X: shape [channels, time] (rows are channels)
    bins: number of equal-width amplitude bins used to discretise each row

    Returns an [n, n] symmetric matrix whose (i, j) entry is
    MI(i, j) / (min(MI(i, i), MI(j, j)) + 1e-8), so the diagonal is ~1.
    """
    n = X.shape[0]
    mi_matrix = np.zeros((n, n))

    # Discretise each row into integer bins 0..bins-1 over its own min..max range.
    x_min = X.min(axis=1, keepdims=True)
    x_range = X.max(axis=1, keepdims=True) - x_min
    X_disc = np.floor((X - x_min) / (x_range + 1e-8) * (bins - 1)).astype(int)

    for i in range(n):
        for j in range(i, n):
            mi = mutual_info_score(X_disc[i], X_disc[j])
            mi_matrix[i, j] = mi
            mi_matrix[j, i] = mi  # symmetric

    # Snapshot the diagonal (self-information) before dividing. The old in-place
    # loop overwrote MI(i, i) with ~1 on its first visit, so later entries were
    # divided by an already-normalised value and the matrix became asymmetric.
    self_info = np.diag(mi_matrix).copy()
    return mi_matrix / (np.minimum.outer(self_info, self_info) + 1e-8)

def average_mi_over_dataset(dataset, n_segments=100):
    """Mean ``compute_mi_matrix`` over the first ``n_segments`` samples of ``dataset``.

    Each sample is ``(x_np, _)`` with channels on axis 0 of ``x_np``. If the
    dataset has fewer than ``n_segments`` samples, all of them are averaged.

    Raises:
        ValueError: if no sample would be averaged (empty dataset or n_segments <= 0).
    """
    n_used = min(n_segments, len(dataset))
    if n_used <= 0:
        raise ValueError("average_mi_over_dataset needs at least one segment")

    x0, _ = dataset[0]
    n_channels = x0.shape[0]
    mi_sum = np.zeros((n_channels, n_channels))

    for i in range(n_used):
        x_np, _ = dataset[i]
        mi_sum += compute_mi_matrix(x_np)

    # Divide by the number actually summed; dividing by n_segments biased the
    # mean towards zero for datasets smaller than n_segments.
    return mi_sum / n_used


def mi_to_graph(mi_matrix, k=4):
    """Undirected top-k graph from a (symmetric) mutual-information matrix.

    Each node keeps its ``k`` highest-MI neighbours, excluding itself, and
    ``k`` is capped at ``n - 1``. Every kept pair is stored once per direction,
    weighted by its MI value.

    Returns:
        edge_index: LongTensor [2, E], pairs sorted lexicographically.
        edge_attr: FloatTensor [E].
    """
    mi = np.asarray(mi_matrix, dtype=float)
    n = mi.shape[0]
    k = max(0, min(int(k), n - 1))

    scores = mi.copy()
    # Never select a node as its own neighbour. Previously the diagonal (the
    # maximum) took one of the k slots, and k=0 selected every node.
    np.fill_diagonal(scores, -np.inf)

    weight_by_pair = {}
    if k > 0:
        for i in range(n):
            for j in np.argsort(scores[i])[-k:]:
                j = int(j)
                # a dict de-duplicates pairs chosen from both endpoints
                weight_by_pair[(i, j)] = mi[i, j]
                weight_by_pair[(j, i)] = mi[i, j]

    if not weight_by_pair:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float)

    pairs = sorted(weight_by_pair)
    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor([weight_by_pair[p] for p in pairs], dtype=torch.float)
    return edge_index, edge_attr

# Average the normalised pairwise-MI matrix over the first 100 training segments.
mi_matrix = average_mi_over_dataset(dataset_tr, n_segments=100)

# Keep the k highest-MI neighbours per node. With k=19 on 19 channels every
# other channel is kept, i.e. the MI graph is fully connected.
# (The old comment said "top-8", which did not match k=19.)
edge_index, edge_attr = mi_to_graph(mi_matrix, k=19)

# Final check
print("Edge index shape:", edge_index.shape)
print("Edge attr shape:", edge_attr.shape)

### Graph Construction Methods
e) Correlation Based Graph
Based on the work [here](https://github.com/tsy935/eeg-gnn-ssl), we construct each graph from the normalized cross-correlation between EEG channels. Because the dataset is small, all graphs are precomputed once up front instead of being built on the fly during training.

In [None]:
def comp_xcorr(x, y, mode="valid", normalize=True):
    """
    Compute cross-correlation between 2 1D signals x, y
    Args:
        x: 1D array
        y: 1D array
        mode: 'valid', 'full' or 'same',
            refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html
        normalize: If True, divide by sqrt(sum|x|^2 * sum|y|^2) (MATLAB xcorr
            'coeff' style); skipped when either signal has zero energy
    Returns:
        xcorr: cross-correlation of x and y (float when normalised)
    """
    xcorr = correlate(x, y, mode=mode)
    # the below normalization code refers to matlab xcorr function
    cxx0 = np.sum(np.absolute(x) ** 2)
    cyy0 = np.sum(np.absolute(y) ** 2)
    if normalize and (cxx0 != 0) and (cyy0 != 0):
        scale = (cxx0 * cyy0) ** 0.5
        # Out-of-place division: `xcorr /= scale` raised a casting error
        # whenever the inputs (and so `correlate`'s result) were integers.
        xcorr = xcorr / scale
    return xcorr
    
def keep_topk(adj_mat, top_k=3, directed=True):
    """
    Sparsify an adjacency matrix by keeping each node's ``top_k`` strongest neighbours.

    Self-edges are ignored when ranking neighbours, but the diagonal is always kept.
    Args:
        adj_mat: square adjacency matrix, shape (num_nodes, num_nodes)
        top_k: number of neighbours kept per row
        directed: if False, keeping (i, j) also keeps (j, i)
    Returns:
        adj_mat with every entry outside the kept set zeroed
    """
    num_nodes = adj_mat.shape[0]
    ranked = adj_mat.copy()
    np.fill_diagonal(ranked, 0)
    neighbours = (-ranked).argsort(axis=-1)[:, :top_k]

    keep = np.eye(num_nodes, dtype=bool)
    rows = np.repeat(np.arange(neighbours.shape[0]), neighbours.shape[1])
    cols = neighbours.ravel()
    keep[rows, cols] = True
    if not directed:
        keep[cols, rows] = True  # mirror for an undirected graph

    return keep * adj_mat

def ccor_edge_attr(dataset_tr, with_label):
    """Build one cross-correlation graph per sample of ``dataset_tr``.

    For every sample ``(clip, label)`` (or ``(clip, idx)`` when ``with_label`` is
    False), ``clip`` is reshaped to ``(19, -1)`` node features. For a
    (19, F) sample, as produced by combined_transform, this is (19, F). The
    adjacency is |normalised cross-correlation| between node rows, with 1 on the
    diagonal, sparsified by ``keep_topk(top_k=3, directed=True)``.

    Returns:
        (list of torch_geometric ``Data`` graphs, list of labels or ids)
    """
    num_sensors = 19
    new_dataset = []
    labels = []

    for sample in dataset_tr:
        clip, target = sample  # target is a label or a sample id, depending on with_label
        labels.append(target)

        clip = clip.T
        if clip.ndim == 2:
            clip = clip[:, :, None]
        clip = np.transpose(clip, (1, 0, 2)).reshape((num_sensors, -1))

        adj_mat = np.eye(num_sensors, num_sensors, dtype=np.float32)
        for i in range(num_sensors):
            for j in range(i + 1, num_sensors):
                xcorr = comp_xcorr(clip[i, :], clip[j, :], mode='valid', normalize=True)
                # 'valid' on equal-length rows gives a 1-element array. Extract the
                # scalar explicitly: assigning a size-1 array to a single element
                # is deprecated since NumPy 1.25.
                value = np.asarray(xcorr).item()
                adj_mat[i, j] = value
                adj_mat[j, i] = value
        adj_mat = np.abs(adj_mat)
        W = keep_topk(adj_mat, top_k=3, directed=True)

        src, dst = np.nonzero(W)  # row-major order, same as scanning i then j
        edge_index = torch.tensor(np.stack([src, dst]), dtype=torch.long)
        edge_attr = torch.tensor(W[src, dst], dtype=torch.float).unsqueeze(-1)
        data = torch.tensor(clip, dtype=torch.float)

        if with_label:
            y = torch.tensor([target], dtype=torch.float)
            graph = Data(x=data, edge_index=edge_index, edge_attr=edge_attr, y=y)
        else:
            graph = Data(x=data, edge_index=edge_index, edge_attr=edge_attr, idx=target)
        new_dataset.append(graph)

    return new_dataset, labels

# Pre-compute the correlation graphs for both splits once, then batch them.
dataset_tr_ccor, labels = ccor_edge_attr(dataset_tr, with_label=True)
dataset_val_ccor, labels = ccor_edge_attr(dataset_val, with_label=True)

loader_tr = DataLoader(dataset_tr_ccor, batch_size=256, shuffle=True)
loader_val = DataLoader(dataset_val_ccor, batch_size=256, shuffle=False)

### Graph Construction Methods
f) Undirected graph based on 10-20

In [None]:
# Hand-written adjacency of the 10-20 montage as (src, dst) pairs over 19 nodes.
# Node ids, as used in the per-line comments:
#   0 FP1, 1 FP2, 2 F3, 3 F4, 4 C3, 5 C4, 6 P3, 7 P4, 8 O1, 9 O2,
#   10 F7, 11 F8, 12 T3, 13 T4, 14 T5, 15 T6, 16 FZ, 17 CZ, 18 PZ.
# Every (a, b) also appears as (b, a), so the list is already undirected.
# NOTE(review): assumes the dataset's channel order matches these ids — confirm.
edges = [
    (0, 10), (0, 2), (0, 16), (0, 1),       # FP1 → F7, F3, FZ, FP2
    (1, 0), (1, 16), (1, 3), (1, 11),       # FP2 → FP1, FZ, F4, F8
    (2, 0), (2, 10), (2, 4), (2, 16),       # F3 → FP1, F7, C3, FZ
    (3, 1), (3, 16), (3, 5), (3, 11),       # F4 → FP2, FZ, C4, F8
    (4, 2), (4, 12), (4, 6), (4, 17),       # C3 → F3, T3, P3, CZ
    (5, 3), (5, 17), (5, 7), (5, 13),       # C4 → F4, CZ, P4, T4
    (6, 4), (6, 14), (6, 8), (6, 18),       # P3 → C3, T5, O1, PZ
    (7, 5), (7, 18), (7, 9), (7, 15),       # P4 → C4, PZ, O2, T6
    (8, 14), (8, 6), (8, 18), (8, 9),       # O1 → T5, P3, PZ, O2
    (9, 8), (9, 18), (9, 7), (9, 15),       # O2 → O1, PZ, P4, T6
    (10, 0), (10, 2), (10, 12),             # F7 → FP1, F3, T3
    (11, 1), (11, 3), (11, 13),             # F8 → FP2, F4, T4
    (12, 10), (12, 4), (12, 14),            # T3 → F7, C3, T5
    (13, 11), (13, 5), (13, 15),            # T4 → F8, C4, T6
    (14, 12), (14, 6), (14, 8),             # T5 → T3, P3, O1
    (15, 13), (15, 7), (15, 9),             # T6 → T4, P4, O2
    (16, 0), (16, 1), (16, 2), (16, 3), (16, 17),  # FZ → FP1, FP2, F3, F4, CZ
    (17, 4), (17, 5), (17, 16), (17, 18),   # CZ → C3, C4, FZ, PZ
    (18, 6), (18, 7), (18, 8), (18, 9), (18, 17)  # PZ → P3, P4, O1, O2, CZ
]

class EEGGraphDataset(torch.utils.data.Dataset):
    """Expose each sample of an EEG dataset as a PyG ``Data`` graph over a fixed edge set.

    Args:
        eeg_dataset: indexable dataset yielding ``(x_np, y)`` pairs, or
            ``(x_np, sample_id)`` when ``return_id=True``. ``x_np`` is used
            as-is as the node-feature matrix, so its rows must be nodes.
        edge_index: list of ``(src, dst)`` pairs (like ``edges``) or a
            ``[2, E]`` integer tensor.
        use_supernode: append one extra node holding the mean of all node
            features, with an edge from every original node into it. The
            supernode id equals the number of original nodes.
        return_id: return ``(data, sample_id)`` instead of ``data`` with a long ``y``.
    """

    def __init__(self, eeg_dataset, edge_index, use_supernode=False, return_id=False):
        self.eeg_dataset = eeg_dataset
        self.edge_index = edge_index
        self.use_supernode = use_supernode
        self.return_id = return_id
        # Build the shared [2, E] tensor once instead of on every __getitem__.
        self._edge_index_t = self._to_edge_index_tensor(edge_index)

    @staticmethod
    def _to_edge_index_tensor(edge_index):
        """Return ``edge_index`` as a ``[2, E]`` long tensor (tensors are assumed [2, E])."""
        if torch.is_tensor(edge_index):
            return edge_index.to(dtype=torch.long)
        return torch.tensor(list(edge_index), dtype=torch.long).reshape(-1, 2).T.contiguous()

    def __len__(self):
        return len(self.eeg_dataset)

    def __getitem__(self, idx):
        x_np, target = self.eeg_dataset[idx]  # target: label, or sample id if return_id
        x = torch.tensor(x_np, dtype=torch.float)
        edge_index = self._edge_index_t

        if self.use_supernode:
            n_nodes = x.size(0)  # previously hard-coded as 19
            x = torch.cat([x, x.mean(dim=0, keepdim=True)], dim=0)
            into_super = torch.stack([
                torch.arange(n_nodes, dtype=torch.long),
                torch.full((n_nodes,), n_nodes, dtype=torch.long),
            ])
            edge_index = torch.cat([edge_index, into_super], dim=1)

        data = Data(x=x, edge_index=edge_index)

        if self.return_id:
            return data, target
        data.y = torch.tensor([target], dtype=torch.long)
        return data


# [2, E] tensor form of the 10-20 edge list; the dataset below receives the raw pair list.
edge_index = torch.tensor(edges, dtype=torch.long).t()
graph_dataset = EEGGraphDataset(dataset_tr, edge_index=edges, use_supernode=False)

# Create Graph

In [None]:
def create_graph_data(x_batch, y_batch, edge_index, edge_attr):
    """Build one ``Data`` graph per ``(x, y)`` pair, all sharing the same edges.

    ``x`` becomes the float node-feature tensor; ``y`` is stored as a float
    tensor of shape [1].
    """
    return [
        Data(
            x=torch.tensor(features, dtype=torch.float),
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=torch.tensor([target], dtype=torch.float),
        )
        for features, target in zip(x_batch, y_batch)
    ]

### MODELS
a) GIN
- Architecture with two GIN layers followed by concatenating add, mean and max pooling and a classification head.
- To include edge features into learning GINEConv is used instead of GINConv.
- Added random dropout of edges in training mode to help model generalize better.

In [None]:
class GIN_with_edge(torch.nn.Module):
    """Two GINEConv layers, [add | mean | max] graph pooling, and a two-layer head.

    Args:
        input_dim: node feature width (must equal ``x.shape[1]``).
        hidden_dim: width of both GINE MLPs and of the first head layer.
        dropout: dropout probability in the classification head. It was
            previously ignored in favour of a hard-coded 0.2.
        edge_dim: width of ``edge_attr`` (1 for scalar edge weights).
        edge_dropout: fraction of edges randomly dropped in training mode
            (default 0.2, the previously hard-coded value).

    ``forward`` returns one logit per graph, shape [num_graphs].
    """

    def __init__(self, input_dim=19, hidden_dim=64, dropout=0.2, edge_dim=1, edge_dropout=0.2):
        super().__init__()

        self.conv1 = GINEConv(
            Sequential(
                Linear(input_dim, hidden_dim),
                BatchNorm1d(hidden_dim), ReLU(),
                Linear(hidden_dim, hidden_dim), ReLU()
            ),
            edge_dim=edge_dim,
        )

        self.conv2 = GINEConv(
            Sequential(
                Linear(hidden_dim, hidden_dim),
                BatchNorm1d(hidden_dim), ReLU(),
                Linear(hidden_dim, hidden_dim), ReLU()
            ),
            edge_dim=edge_dim,
        )

        self.lin1 = Linear(hidden_dim * 3, hidden_dim)  # 3 pooled views concatenated
        self.lin2 = Linear(hidden_dim, 1)
        self.dropout = dropout
        self.edge_dropout = edge_dropout

    def forward(self, x, edge_index, edge_attr, batch):
        # Randomly drop edges during training (regularisation).
        if self.training and self.edge_dropout > 0:
            keep = torch.rand(edge_index.size(1), device=edge_index.device) >= self.edge_dropout
            edge_index = edge_index[:, keep]
            edge_attr = edge_attr[keep]

        h = self.conv1(x, edge_index, edge_attr)
        h = self.conv2(h, edge_index, edge_attr)

        pooled = torch.cat(
            [global_add_pool(h, batch), global_mean_pool(h, batch), global_max_pool(h, batch)],
            dim=1,
        )

        out = self.lin1(pooled).relu()
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out).squeeze(1)

### MODELS
b) GCN Graph Convolutional Network

In [None]:
class EEG_GCN(nn.Module):
    """Two GCN layers, mean pooling and a linear head producing one logit per graph ([B, 1])."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.classifier = nn.Linear(out_channels, 1)

    def forward(self, x, edge_index, edge_weight, batch):
        hidden = F.relu(self.conv1(x, edge_index, edge_weight=edge_weight))
        node_repr = self.conv2(hidden, edge_index, edge_weight=edge_weight)
        graph_repr = global_mean_pool(node_repr, batch)
        return self.classifier(graph_repr)

### MODELS
c) GCN v2 Graph Convolutional Network

In [None]:
class EEG_GCN_v2(nn.Module):
    """Two GCN blocks (conv -> BatchNorm -> ReLU -> dropout 0.3), mean pooling, linear head ([B, 1])."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = BatchNorm(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.bn2 = BatchNorm(out_channels)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(out_channels, 1)

    def forward(self, x, edge_index, edge_weight, batch):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.dropout(F.relu(norm(conv(x, edge_index, edge_weight=edge_weight))))
        return self.classifier(global_mean_pool(x, batch))

### MODELS
d) Graph SAGE SAmple and aggreGatE

In [None]:
class EEG_SAGE(nn.Module):
    """Two GraphSAGE layers, mean pooling and a linear head ([B, 1]).

    ``edge_weight`` is accepted for a uniform call signature but not used.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.classifier = nn.Linear(out_channels, 1)

    def forward(self, x, edge_index, edge_weight, batch):
        hidden = F.relu(self.conv1(x, edge_index))
        node_repr = self.conv2(hidden, edge_index)
        return self.classifier(global_mean_pool(node_repr, batch))

### MODELS
e) Graph SAGE v2 SAmple and aggreGatE

In [None]:
class EEG_SAGE_v2(nn.Module):
    """Two SAGE blocks (conv -> BatchNorm -> ReLU -> dropout 0.1), mean pooling, linear head ([B, 1]).

    ``edge_weight`` is accepted for a uniform call signature but not used.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.bn1 = BatchNorm(hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.bn2 = BatchNorm(out_channels)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(out_channels, 1)

    def forward(self, x, edge_index, edge_weight, batch):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.dropout(F.relu(norm(conv(x, edge_index))))
        return self.classifier(global_mean_pool(x, batch))

### MODELS
f) STGNN Spatiotemporal Graph Neural Network

In [None]:
class STGNN_EEG(nn.Module):
    """Spatio-temporal GNN: a shared 1-D CNN encodes each node's sequence, two GCN
    layers mix information across nodes, and mean+max pooling feeds an MLP.

    Args:
        time_steps: accepted but not used by any layer.
        temporal_out: channels produced by the temporal CNN (GCN input width).
        gcn_hidden: width of both GCN layers.

    ``forward(data)`` takes a PyG batch with ``x`` of shape [num_nodes, T] and
    returns logits of shape [num_graphs, 1].
    """

    def __init__(self, time_steps, temporal_out=32, gcn_hidden=64):
        super().__init__()

        self.temporal = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=temporal_out, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        )

        self.gcn1 = GCNConv(temporal_out, gcn_hidden)
        self.gcn2 = GCNConv(gcn_hidden, gcn_hidden)

        self.classifier = nn.Sequential(
            nn.Linear(2 * gcn_hidden, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(self, data):
        edge_index, edge_weight, batch = data.edge_index, data.edge_attr, data.batch

        # [N, T] -> [N, 1, T] -> CNN -> [N, C, T'] -> average over time -> [N, C]
        node_feats = self.temporal(data.x.unsqueeze(1)).mean(dim=2)

        for gcn in (self.gcn1, self.gcn2):
            node_feats = F.relu(gcn(node_feats, edge_index, edge_weight=edge_weight))

        pooled = torch.cat(
            [global_mean_pool(node_feats, batch), global_max_pool(node_feats, batch)],
            dim=1,
        )
        return self.classifier(pooled)

### MODELS
g) ChebNet Chebyshev Graph Convolution

In [None]:
class EEG_ChebNet(nn.Module):
    """Two Chebyshev graph convolutions (order K), each followed by ReLU and
    dropout 0.3, then mean pooling and a linear head with ``out_channels`` outputs.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, K=3):
        super().__init__()
        self.conv1 = ChebConv(in_channels, hidden_channels, K=K)
        self.conv2 = ChebConv(hidden_channels, hidden_channels, K=K)
        self.classifier = nn.Linear(hidden_channels, out_channels)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index, edge_weight, batch):
        h = x
        for conv in (self.conv1, self.conv2):
            h = self.dropout(F.relu(conv(h, edge_index, edge_weight)))
        return self.classifier(global_mean_pool(h, batch))

### MODELS
h) ChebNet Attention Extended version with a global attention mechanism for pooling

In [None]:
class EEG_ChebNet_Attn(nn.Module):
    """ChebNet with learnable attention pooling instead of mean pooling.

    Two Chebyshev convolutions (order K, ReLU, dropout 0.3) are followed by
    gated attention pooling: a small MLP scores each node, the scores are
    softmax-normalised per graph, and the node features are summed with those
    weights. A linear head then produces ``out_channels`` outputs per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, K=3):
        super().__init__()
        self.conv1 = ChebConv(in_channels, hidden_channels, K=K)
        self.conv2 = ChebConv(hidden_channels, hidden_channels, K=K)

        # Global attention pooling with a learnable gate. AttentionalAggregation
        # replaces the deprecated (and never imported) GlobalAttention.
        self.attn_pool = AttentionalAggregation(gate_nn=nn.Sequential(
            nn.Linear(hidden_channels, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ))

        self.classifier = nn.Linear(hidden_channels, out_channels)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index, edge_weight, batch):
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index, edge_weight)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.attn_pool(x, batch)  # batch is the per-node graph index
        return self.classifier(x)


### MODELS
i) ChebNet Residual Extended version having residual connections

In [None]:
class EEG_ChebNet_Res(nn.Module):
    """ChebNet with residual connections around both convolution blocks.

    Each block is conv -> BatchNorm1d -> ReLU(+ shortcut) -> dropout 0.3. The
    first shortcut is linearly projected when ``in_channels != hidden_channels``.
    Mean pooling and a linear head produce ``out_channels`` outputs.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, K=3):
        super().__init__()
        self.conv1 = ChebConv(in_channels, hidden_channels, K=K)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.conv2 = ChebConv(hidden_channels, hidden_channels, K=K)
        self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.classifier = nn.Linear(hidden_channels, out_channels)
        self.dropout = nn.Dropout(0.3)

        # width-matching projection for the first shortcut
        if in_channels != hidden_channels:
            self.res_proj = nn.Linear(in_channels, hidden_channels)
        else:
            self.res_proj = nn.Identity()

    def forward(self, x, edge_index, edge_weight, batch):
        shortcut = self.res_proj(x)
        h = self.dropout(F.relu(self.bn1(self.conv1(x, edge_index, edge_weight)) + shortcut))
        h = self.dropout(F.relu(self.bn2(self.conv2(h, edge_index, edge_weight)) + h))
        return self.classifier(global_mean_pool(h, batch))

### MODELS
j) ChebNet Res Attn

In [None]:
class EEG_ChebNet_ResAttn(nn.Module):
    """Residual ChebNet with gated attention pooling.

    Two residual blocks (conv -> BatchNorm -> ReLU(+ shortcut) -> dropout 0.3;
    the first shortcut is projected when widths differ) are followed by
    attention pooling and a linear head with ``out_channels`` outputs.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, K=3):
        super().__init__()
        self.conv1 = ChebConv(in_channels, hidden_channels, K=K)
        self.bn1 = BatchNorm(hidden_channels)
        self.res1 = nn.Linear(in_channels, hidden_channels) if in_channels != hidden_channels else nn.Identity()

        self.conv2 = ChebConv(hidden_channels, hidden_channels, K=K)
        self.bn2 = BatchNorm(hidden_channels)

        # AttentionalAggregation replaces the deprecated (and never imported) GlobalAttention.
        self.attn_pool = AttentionalAggregation(gate_nn=nn.Sequential(
            nn.Linear(hidden_channels, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ))

        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight, batch):
        # First residual block
        res1 = self.res1(x)
        x = self.conv1(x, edge_index, edge_weight)
        x = self.bn1(x)
        x = F.relu(x + res1)
        x = self.dropout(x)

        # Second residual block
        res2 = x
        x = self.conv2(x, edge_index, edge_weight)
        x = self.bn2(x)
        x = F.relu(x + res2)
        x = self.dropout(x)

        # Attention-based global pooling (batch is the per-node graph index)
        x = self.attn_pool(x, batch)
        return self.classifier(x)

### MODELS
k) GAT

In [None]:
class EEGGAT(torch.nn.Module):
    """Two-layer graph attention network with mean pooling and an MLP head.

    Args:
        in_dim: node feature width.
        hidden_dim: per-head output width of each GAT layer.
        out_dim: outputs per graph (default 1 logit).
        heads: attention heads for (layer 1, layer 2).
        dropout: attention-coefficient dropout inside both GATConv layers.
        use_concat: concatenate head outputs (width hidden_dim * heads) or
            average them (width hidden_dim).

    ``forward`` returns shape [num_graphs] when ``out_dim == 1``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim=1, heads=(4, 4), dropout=0.2, use_concat=True):
        super().__init__()
        self.use_concat = use_concat

        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads[0], concat=use_concat, dropout=dropout)
        # The width is multiplied by `heads` only when concatenating. With
        # concat=False the heads are averaged (the old code assumed concat, so
        # use_concat=False crashed with a shape mismatch).
        gat1_out_dim = hidden_dim * heads[0] if use_concat else hidden_dim

        self.gat2 = GATConv(gat1_out_dim, hidden_dim, heads=heads[1], concat=use_concat, dropout=dropout)
        gat2_out_dim = hidden_dim * heads[1] if use_concat else hidden_dim

        self.head = torch.nn.Linear(gat2_out_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = self.gat1(x, edge_index)
        x = F.elu(x)

        x = self.gat2(x, edge_index)
        x = F.elu(x)

        x = global_mean_pool(x, batch)

        x = F.relu(self.head(x))
        x = F.dropout(x, p=0.3, training=self.training)

        out = self.classifier(x)      # Graph-level output
        return out.squeeze(1)           # Shape: (batch_size,) when out_dim == 1
    

### MODELS
l) GAT with supernode

In [None]:
class EEGGAT_superpool(torch.nn.Module):
    """Two-layer GAT whose graph readout is the embedding of each graph's supernode.

    Assumes that each graph's nodes are contiguous in the batch (standard PyG
    batching) and that the supernode is the LAST node of every graph, as
    ``EEGGraphDataset(use_supernode=True)`` appends it. Arguments are as in
    ``EEGGAT``. ``forward`` returns shape [num_graphs] when ``out_dim == 1``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim=1, heads=(4, 4), dropout=0.2, use_concat=True):
        super().__init__()
        self.use_concat = use_concat

        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads[0], concat=use_concat, dropout=dropout)
        # Heads are averaged (width hidden_dim) when concat=False.
        gat1_out_dim = hidden_dim * heads[0] if use_concat else hidden_dim

        self.gat2 = GATConv(gat1_out_dim, hidden_dim, heads=heads[1], concat=use_concat, dropout=dropout)
        gat2_out_dim = hidden_dim * heads[1] if use_concat else hidden_dim

        self.head = torch.nn.Linear(gat2_out_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch):
        x = self.gat1(x, edge_index)
        x = F.elu(x)

        x = self.gat2(x, edge_index)
        x = F.elu(x)

        # Index of the last node of every graph (its supernode). This equals
        # 20 * g + 19 for 20-node graphs but no longer hard-codes the graph size.
        num_graphs = batch.max().item() + 1
        nodes_per_graph = torch.bincount(batch, minlength=num_graphs)
        supernode_indices = nodes_per_graph.cumsum(0) - 1
        x_super = x[supernode_indices]

        x_super = F.relu(self.head(x_super))
        x_super = F.dropout(x_super, p=0.3, training=self.training)

        out = self.classifier(x_super)

        return out.squeeze(1)           # Shape: (batch_size,) when out_dim == 1