# PhysioNet EEG: Train All Baseline Models

This notebook trains 7 different models on the PhysioNet Motor Imagery dataset:
1. FBCSP (Filter Bank Common Spatial Patterns)
2. CNN-SAE (CNN with Spatial Attention)
3. EEGNet
4. ACS-SE-CNN (Adaptive Channel Selection SE-CNN)
5. G-CARM (Graph Channel Active Reasoning Module)
6. Baseline EEG-ARNN (without adaptive gating)
7. Adaptive Gating EEG-ARNN (our proposed method)

**Expected Runtime**: 10-12 hours on Kaggle GPU

**Input**: `/kaggle/input/physionet-preprocessed/derived/` (preprocessed EEG data)

**Output**: 
- Trained models in `models/` folder
- Results CSV files in `results/` folder

## Configuration

In [1]:
import os
import math
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import pickle
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

mne.set_log_level('ERROR')


In [2]:
# Configuration
CONFIG = {
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',  # Change this for local testing
    'output_dir': './',
    'results_dir': './results',
    'models_dir': './models',
    'figures_dir': './figures',

    'n_folds': 3,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Training hyperparameters
    'batch_size': 64,
    'epochs': 20,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'patience': 10,
    'scheduler_patience': 3,
    'use_early_stopping': False,

    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'n_timepoints': 513,  # 4 seconds at 128 Hz + 1
    'hidden_dim': 128,
    'mi_runs': [7, 8, 11, 12],

    # FBCSP parameters
    'fbcsp_bands': [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40)],
    'fbcsp_n_components': 4,

    # Gating regularization
    'gating': {
        'gate_init': 0.9,
        'l1_lambda': 1e-3,
    },
}

# Create output directories
os.makedirs(CONFIG['results_dir'], exist_ok=True)
os.makedirs(CONFIG['models_dir'], exist_ok=True)
os.makedirs(CONFIG['figures_dir'], exist_ok=True)

# Set random seeds
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])

print(f"Device: {CONFIG['device']}")
print(f"Data path: {CONFIG['data_path']}")


Device: cuda
Data path: /kaggle/input/eeg-preprocessed-data/derived


## Data Loading Utilities

In [3]:
def load_physionet_data(data_path, subject_ids=None):
    """
    Load preprocessed PhysioNet motor imagery data from the derived folder.
    Supports both the newer folder structure (derived/preprocessed/S***/S***R**_preproc_raw.fif)
    and the legacy flat directory containing epoch files.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    config = globals().get('CONFIG', {})
    tmin = float(config.get('tmin', 0.0))
    tmax = float(config.get('tmax', 4.0))
    mi_runs = [int(r) for r in config.get('mi_runs', [7, 8, 11, 12])]
    event_id = {'T1': 1, 'T2': 2}

    def normalize_subject(value):
        if value is None:
            return None
        if isinstance(value, str) and value.upper().startswith('S'):
            value = value[1:]
        try:
            return int(value)
        except Exception:
            return None

    subject_filter = None
    if subject_ids is not None:
        subject_filter = set()
        for sid in subject_ids:
            norm = normalize_subject(sid)
            if norm is not None:
                subject_filter.add(norm)
        if not subject_filter:
            subject_filter = None

    def aggregate_results(blocks_X, blocks_y, blocks_subjects):
        X = np.concatenate(blocks_X, axis=0)
        y = np.concatenate(blocks_y, axis=0)
        subjects = np.concatenate(blocks_subjects, axis=0)
        print(f"Loaded {len(X)} trials from {len(np.unique(subjects))} subjects")
        print(f"Data shape: {X.shape}")
        print(f"Label distribution: {np.bincount(y)}")
        return X, y, subjects

    subject_root = data_root
    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        subject_root = preprocessed_dir
    subject_dirs = [d for d in sorted(os.listdir(subject_root))
                    if os.path.isdir(os.path.join(subject_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []
    if subject_dirs:
        print(f"Detected {len(subject_dirs)} preprocessed subject folders under {subject_root}")
        label_map = {event_id['T1']: 0, event_id['T2']: 1}
        for subject_dir in subject_dirs:
            subject_numeric = normalize_subject(subject_dir)
            if subject_filter and subject_numeric not in subject_filter:
                continue
            subject_path = os.path.join(subject_root, subject_dir)
            for run_id in mi_runs:
                candidate_names = [
                    f"{subject_dir}R{run_id:02d}_preproc_raw.fif",
                    f"{subject_dir}R{run_id:02d}_raw.fif",
                    f"{subject_dir}R{run_id:02d}.fif",
                    f"{subject_dir}_R{run_id:02d}.fif",
                ]
                run_path = None
                for name in candidate_names:
                    candidate = os.path.join(subject_path, name)
                    if os.path.exists(candidate):
                        run_path = candidate
                        break
                if run_path is None:
                    continue
                try:
                    raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                except Exception as e:
                    print(f"Error loading {run_path}: {e}")
                    continue
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue
                try:
                    events, _ = mne.events_from_annotations(raw, event_id=event_id)
                except Exception as e:
                    print(f"Error parsing annotations for {run_path}: {e}")
                    continue
                if len(events) == 0:
                    continue
                try:
                    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                                        baseline=None, preload=True, picks=picks, verbose=False)
                except Exception as e:
                    print(f"Error epoching {run_path}: {e}")
                    continue
                data = epochs.get_data()
                labels = epochs.events[:, 2]
                mapped = np.array([label_map.get(lbl, -1) for lbl in labels])
                valid_mask = mapped >= 0
                if not np.any(valid_mask):
                    continue
                all_X.append(data[valid_mask])
                all_y.append(mapped[valid_mask])
                subj_label = subject_numeric if subject_numeric is not None else -1
                all_subjects.append(np.full(np.sum(valid_mask), subj_label))
        if all_X:
            return aggregate_results(all_X, all_y, all_subjects)
        print("No data loaded from preprocessed folders, falling back to legacy format...")

    # Legacy format fallback (flat directory with epoch files)
    legacy_files = [f for f in os.listdir(data_root) if f.endswith('.fif')]
    if not legacy_files:
        raise ValueError(
            "No valid PhysioNet files found. Ensure the derived folder contains either "
            "preprocessed subject subfolders or .fif epoch files."
        )
    if subject_filter:
        filtered = []
        for fname in legacy_files:
            parts = fname.split('_')
            if not parts:
                continue
            subj = normalize_subject(parts[0])
            if subj is not None and subj in subject_filter:
                filtered.append(fname)
        legacy_files = filtered
        if not legacy_files:
            raise ValueError("No files matched the requested subject IDs in legacy format.")

    print(f"Found {len(legacy_files)} legacy epoch files. Loading...")
    for fname in legacy_files:
        filepath = os.path.join(data_root, fname)
        try:
            epochs = mne.read_epochs(filepath, preload=True, verbose=False)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
            continue
        current_event_id = epochs.event_id
        if not current_event_id:
            continue
        label_lookup = {}
        if 'T1' in current_event_id:
            label_lookup[current_event_id['T1']] = 0
        if 'T2' in current_event_id:
            label_lookup[current_event_id['T2']] = 1
        if not label_lookup:
            continue
        labels = np.array([label_lookup.get(epochs.events[i, -1], -1) for i in range(len(epochs))])
        valid = labels >= 0
        if not np.any(valid):
            continue
        data = epochs.get_data()[valid]
        labels = labels[valid]
        subj = normalize_subject(fname.split('_')[0])
        subj_arr = np.full(len(labels), subj if subj is not None else -1)
        all_X.append(data)
        all_y.append(labels)
        all_subjects.append(subj_arr)
    if not all_X:
        raise ValueError("No valid trials were loaded from the provided PhysioNet files.")
    return aggregate_results(all_X, all_y, all_subjects)


## Model Architectures

In [4]:
# Model 1: FBCSP
class FBCSP:
    """Filter Bank Common Spatial Patterns with LDA classifier."""
    def __init__(self, freq_bands, n_components=4, sfreq=128):
        self.freq_bands = freq_bands
        self.n_components = n_components
        self.sfreq = sfreq
        self.csp_list = []
        self.classifier = None
        
    def fit(self, X, y):
        """X: (n_trials, n_channels, n_timepoints), y: (n_trials,)"""
        from mne.decoding import CSP
        
        all_features = []
        
        for low, high in self.freq_bands:
            # Filter data
            X_filtered = self._bandpass_filter(X, low, high)
            
            # Apply CSP
            csp = CSP(n_components=self.n_components, reg=None, log=True, norm_trace=False)
            features = csp.fit_transform(X_filtered, y)
            
            self.csp_list.append(csp)
            all_features.append(features)
        
        # Concatenate features from all bands
        all_features = np.concatenate(all_features, axis=1)
        
        # Train LDA classifier
        self.classifier = LinearDiscriminantAnalysis()
        self.classifier.fit(all_features, y)
        
        return self
    
    def predict(self, X):
        all_features = []
        
        for idx, (low, high) in enumerate(self.freq_bands):
            X_filtered = self._bandpass_filter(X, low, high)
            features = self.csp_list[idx].transform(X_filtered)
            all_features.append(features)
        
        all_features = np.concatenate(all_features, axis=1)
        return self.classifier.predict(all_features)
    
    def score(self, X, y):
        predictions = self.predict(X)
        return np.mean(predictions == y)
    
    def _bandpass_filter(self, X, low, high):
        """Apply bandpass filter to data."""
        from scipy.signal import butter, filtfilt
        
        nyq = self.sfreq / 2
        low_norm = low / nyq
        high_norm = high / nyq
        
        b, a = butter(4, [low_norm, high_norm], btype='band')
        
        X_filtered = np.zeros_like(X)
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                X_filtered[i, j, :] = filtfilt(b, a, X[i, j, :])
        
        return X_filtered

In [5]:
# Model 2: CNN-SAE (CNN with Spatial Attention)
class SpatialAttention(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(n_channels, n_channels // 4),
            nn.ReLU(),
            nn.Linear(n_channels // 4, n_channels),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        # x: (batch, channels, time)
        pooled = torch.mean(x, dim=2)  # (batch, channels)
        weights = self.attention(pooled)  # (batch, channels)
        return x * weights.unsqueeze(2)  # (batch, channels, time)

class CNNSAE(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()
        
        self.spatial_attention = SpatialAttention(n_channels)
        
        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)
        
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)
        
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(2)
        
        self.dropout = nn.Dropout(0.5)
        
        # Calculate flattened size
        test_input = torch.zeros(1, n_channels, n_timepoints)
        test_output = self._forward_features(test_input)
        flattened_size = test_output.view(1, -1).size(1)
        
        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)
    
    def _forward_features(self, x):
        x = self.spatial_attention(x)
        
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        return x
    
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

**Updated EEGNet**

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EEGNet(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, F1=8, D=2, F2=16):
        super().__init__()

        # Temporal convolution
        self.conv1 = nn.Conv2d(1, F1, (1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)

        # Depthwise convolution
        self.conv2 = nn.Conv2d(F1, F1 * D, (n_channels, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)
        self.pool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout(0.5)

        # Separable convolution
        self.conv3 = nn.Conv2d(F1 * D, F2, (1, 16), padding=(0, 8), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)
        self.pool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout(0.5)

        # Calculate flattened size
        test_input = torch.zeros(1, 1, n_channels, n_timepoints)
        test_output = self._forward_features(test_input)
        flattened_size = test_output.view(1, -1).size(1)

        self.fc = nn.Linear(flattened_size, n_classes)

    def _forward_features(self, x):
        x = self.bn1(self.conv1(x))
        # REPLACED torch.elu(...) with F.elu(...)
        x = self.dropout1(self.pool1(F.elu(self.bn2(self.conv2(x)))))
        x = self.dropout2(self.pool2(F.elu(self.bn3(self.conv3(x)))))
        return x

    def forward(self, x):
        # Input: (batch, channels, time)
        x = x.unsqueeze(1)  # (batch, 1, channels, time)
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [30]:
m = EEGNet(n_channels=64, n_classes=2, n_timepoints=513)
print("Instantiated EEGNet OK. Params:", sum(p.numel() for p in m.parameters()))


Instantiated EEGNet OK. Params: 6226


In [7]:
# Model 4: ACS-SE-CNN (Adaptive Channel Selection with Squeeze-Excitation CNN)
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention."""
    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)
    
    def forward(self, x):
        # x: (batch, channels, time)xx
        squeeze = torch.mean(x, dim=2)  # (batch, channels)
        excitation = torch.relu(self.fc1(squeeze))
        excitation = torch.sigmoid(self.fc2(excitation))  # (batch, channels)
        return x * excitation.unsqueeze(2)

class ACSECNN(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()
        
        # Channel selection module
        self.channel_attention = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        # SE blocks
        self.se1 = SEBlock(n_channels)
        self.se2 = SEBlock(128)
        self.se3 = SEBlock(256)
        
        # Convolutional layers
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)
        
        self.dropout = nn.Dropout(0.5)
        
        # Calculate flattened size
        test_input = torch.zeros(1, n_channels, n_timepoints)
        test_output = self._forward_features(test_input)
        flattened_size = test_output.view(1, -1).size(1)
        
        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)
        
        self.channel_weights = None
    
    def _forward_features(self, x):
        # Adaptive channel selection
        channel_weights = []
        for i in range(x.size(1)):
            w = self.channel_attention(x[:, i, :])  # (batch, 1)
            channel_weights.append(w)
        channel_weights = torch.cat(channel_weights, dim=1)  # (batch, n_channels)
        self.channel_weights = channel_weights.detach()
        
        x = x * channel_weights.unsqueeze(2)  # (batch, n_channels, time)
        
        # SE-CNN
        x = self.se1(x)
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        
        x = self.se2(x)
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        
        x = self.se3(x)
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        return x
    
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [8]:
# Model 5: G-CARM (Graph Channel Active Reasoning Module)
class CARMBlock(nn.Module):
    """Channel Active Reasoning Module with graph convolution."""
    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        
        # Learnable adjacency matrix
        self.A = nn.Parameter(torch.randn(n_channels, n_channels) * 0.01)
        
        # Layer normalization
        self.norm = nn.LayerNorm(n_channels)
    
    def forward(self, x):
        # x: (batch, channels, time)
        batch_size, n_channels, n_time = x.shape
        
        # Normalize adjacency matrix
        A_norm = torch.softmax(self.A, dim=1)
        
        # Apply graph convolution
        x_reshaped = x.permute(0, 2, 1)  # (batch, time, channels)
        x_graph = torch.matmul(x_reshaped, A_norm.t())  # (batch, time, channels)
        x_graph = x_graph.permute(0, 2, 1)  # (batch, channels, time)
        
        return x_graph
    
    def get_adjacency_matrix(self):
        """Return normalized adjacency matrix for channel selection."""
        return torch.softmax(self.A, dim=1).detach()

class GCARM(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()
        
        # CARM blocks
        self.carm1 = CARMBlock(n_channels)
        self.carm2 = CARMBlock(n_channels)
        
        # Convolutional layers
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)
        
        self.dropout = nn.Dropout(0.5)
        
        # Calculate flattened size
        test_input = torch.zeros(1, n_channels, n_timepoints)
        test_output = self._forward_features(test_input)
        flattened_size = test_output.view(1, -1).size(1)
        
        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)
    
    def _forward_features(self, x):
        # Apply CARM blocks
        x = self.carm1(x)
        x = self.carm2(x)
        
        # CNN layers
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        return x
    
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
    
    def get_channel_importance_edge(self):
        """Edge Selection: Sum of outgoing edge weights."""
        A1 = self.carm1.get_adjacency_matrix()
        A2 = self.carm2.get_adjacency_matrix()
        A_combined = (A1 + A2) / 2
        return torch.sum(A_combined, dim=1).cpu().numpy()

In [9]:
# Model 6 & 7: EEG-ARNN (Baseline and Adaptive Gating versions)
class GraphConvLayer(nn.Module):
    """Graph convolution with learned symmetric adjacency."""
    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def forward(self, x):
        B, H, C, T = x.shape
        A = torch.sigmoid(self.A)
        A = 0.5 * (A + A.t())
        I = torch.eye(C, device=A.device)
        A_hat = A + I
        D = torch.diag(torch.pow(A_hat.sum(1).clamp_min(1e-6), -0.5))
        A_norm = D @ A_hat @ D

        x_perm = x.permute(0, 3, 2, 1).contiguous().view(B * T, C, H)
        x_g = A_norm @ x_perm
        x_g = self.theta(x_g)
        x_g = x_g.view(B, T, C, H).permute(0, 3, 2, 1)
        x_out = self.bn(x_g)
        return self.act(x_out)

    def get_adjacency(self):
        with torch.no_grad():
            A = torch.sigmoid(self.A)
            A = 0.5 * (A + A.t())
            return A.cpu().numpy()


class TemporalConv(nn.Module):
    """Temporal convolution operating independently per channel."""
    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        self.pool = pool
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, kernel_size),
                              padding=(0, kernel_size // 2), bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2)) if pool else None

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        if self.pool_layer is not None:
            x = self.pool_layer(x)
        return x


class BaselineEEGARNN(nn.Module):
    """Baseline EEG-ARNN with temporal conv + adaptive graph reasoning."""
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.use_gate_regularizer = False
        self.gate_penalty_tensor = None
        self.latest_gate_values = None

        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        with torch.no_grad():
            dummy = torch.zeros(1, n_channels, n_timepoints)
            feat = self._forward_features(self._prepare_input(dummy))
            self.feature_dim = feat.view(1, -1).size(1)

        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, n_classes)

    def _prepare_input(self, x):
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return x

    def _forward_features(self, x):
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def _forward_from_prepared(self, x):
        features = self._forward_features(x)
        x = features.view(features.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def forward(self, x):
        prepared = self._prepare_input(x)
        self.gate_penalty_tensor = None
        self.latest_gate_values = None
        return self._forward_from_prepared(prepared)

    def get_final_adjacency(self):
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        adjacency = self.get_final_adjacency()
        return np.sum(adjacency, axis=1)


class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """EEG-ARNN with adaptive data-dependent channel gates."""
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True
        self.gate_net = nn.Sequential(
            nn.Linear(n_channels * 2, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid()
        )
        init_value = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        init_bias = math.log(init_value / (1.0 - init_value))
        with torch.no_grad():
            self.gate_net[-2].bias.fill_(init_bias)
        self.latest_gate_values = None

    def compute_gates(self, x):
        x_s = x.squeeze(1)
        ch_mean = x_s.mean(dim=2)
        ch_std = x_s.std(dim=2)
        stats = torch.cat([ch_mean, ch_std], dim=1)
        return self.gate_net(stats)

    def forward(self, x):
        prepared = self._prepare_input(x)
        gates = self.compute_gates(prepared)
        self.gate_penalty_tensor = gates
        self.latest_gate_values = gates.detach()
        gated = prepared * gates.view(gates.size(0), 1, gates.size(1), 1)
        return self._forward_from_prepared(gated)

    def get_channel_importance_gate(self):
        if self.latest_gate_values is None:
            return None
        return self.latest_gate_values.mean(dim=0).cpu().numpy()


## Training Utilities

In [10]:
def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Train for one epoch with optional gating regularization."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for X_batch, y_batch in dataloader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        gate_penalty = getattr(model, 'gate_penalty_tensor', None)
        if l1_lambda > 0 and gate_penalty is not None:
            loss = loss + l1_lambda * gate_penalty.abs().mean()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += y_batch.size(0)
        correct += (predicted == y_batch).sum().item()

    denom = max(1, len(dataloader))
    return total_loss / denom, correct / max(1, total)


def evaluate(model, dataloader, criterion, device):
    """Evaluate model."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

    denom = max(1, len(dataloader))
    return total_loss / denom, correct / max(1, total)


def train_pytorch_model(model, train_loader, val_loader, config, model_name=''):
    """Train a PyTorch model with scheduler + best checkpoint tracking."""
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], 
                          weight_decay=config['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 3), verbose=False
    )

    l1_lambda = config.get('gating', {}).get('l1_lambda', 0.0) if getattr(model, 'use_gate_regularizer', False) else 0.0
    use_early_stopping = config.get('use_early_stopping', False) and config.get('patience') is not None
    max_patience = config.get('patience', 0)
    patience_counter = 0

    best_state = deepcopy(model.state_dict())
    best_val_acc = 0.0
    best_val_loss = float('inf')

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        improved = val_acc > best_val_acc or (val_acc == best_val_acc and val_loss < best_val_loss)
        if improved:
            best_state = deepcopy(model.state_dict())
            best_val_acc = val_acc
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        prefix = model_name if model_name else 'Model'
        print(f"[{prefix}] Epoch {epoch + 1}/{config['epochs']} - Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

        if use_early_stopping and patience_counter >= max_patience:
            print(f"Early stopping triggered for {prefix} at epoch {epoch + 1}")
            break

    model.load_state_dict(best_state)
    return best_state, best_val_acc


## Main Training Loop

In [11]:
# Load data
print("Loading PhysioNet data...")
X, y, subject_labels = load_physionet_data(CONFIG['data_path'])

print(f"\nData loaded successfully!")
print(f"Total trials: {len(X)}")
print(f"Data shape: {X.shape}")
print(f"Labels: {np.unique(y, return_counts=True)}")

Loading PhysioNet data...
Detected 51 preprocessed subject folders under /kaggle/input/eeg-preprocessed-data/derived/preprocessed
Loaded 2966 trials from 51 subjects
Data shape: (2966, 64, 513)
Label distribution: [1489 1477]

Data loaded successfully!
Total trials: 2966
Data shape: (2966, 64, 513)
Labels: (array([0, 1]), array([1489, 1477]))


In [12]:
# Define models to train
models_to_train = [
    {'name': 'FBCSP', 'type': 'sklearn'},
    {'name': 'CNN-SAE', 'type': 'pytorch'},
    {'name': 'EEGNet', 'type': 'pytorch'},
    {'name': 'ACS-SE-CNN', 'type': 'pytorch'},
    {'name': 'G-CARM', 'type': 'pytorch'},
    {'name': 'Baseline-EEG-ARNN', 'type': 'pytorch'},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'type': 'pytorch'},
]

# Initialize results storage
all_results = []

# Cross-validation
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])

print(f"\nStarting {CONFIG['n_folds']}-fold cross-validation...\n")


Starting 3-fold cross-validation...



In [16]:
from torch.utils.data import Dataset

class EEGDataset(Dataset):
    def __init__(self, X, y):
        # (samples, channels, timepoints)
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [18]:
X.shape 


(2966, 64, 513)

In [19]:
try:
    EEGDataset
    print("EEGDataset is already defined.")
except NameError:
    print("EEGDataset is NOT defined.")


EEGDataset is already defined.


In [21]:
import os,fnmatch,sys
root = os.getcwd()   # change if your code is in another dir

bad = ['torch.F','torch.tanh(', 'torch.sigmoid(', 'torch.relu(']  # we only need torch.elu really
matches=[]
for path,dirs,files in os.walk(root):
    for f in files:
        if f.endswith('.py') or f.endswith('.ipynb'):
            p=os.path.join(path,f)
            try:
                with open(p,'r',errors='ignore') as fh:
                    txt=fh.read()
                for token in bad:
                    if token in txt:
                        matches.append((p, token))
            except:
                pass

if not matches:
    print("No suspicious tokens found in files under", root)
else:
    for p,t in matches:
        print("Found",t,"in",p)


No suspicious tokens found in files under /kaggle/working


In [22]:
try:
    F
    print("✔ F is loaded and available.")
except NameError:
    print("❌ F is NOT loaded. You must run the import cell.")


In [24]:
from torch.utils.data import Dataset
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [25]:
try:
    F
    print("✔ F is loaded and available.")
except NameError:
    print(" F is NOT loaded. You must run the import cell.")


✔ F is loaded and available.


In [32]:
import torch, gc
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
print("Cache cleared. Free mem (MiB):", torch.cuda.get_device_properties(0).total_memory//(1024**2) - torch.cuda.memory_allocated(0)//(1024**2))


Cache cleared. Free mem (MiB): 1286


In [31]:

# Train all models
for model_info in models_to_train:
    model_name = model_info['name']
    model_type = model_info['type']

    print(f"{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")

    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"Fold {fold + 1}/{CONFIG['n_folds']}")

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        if model_type == 'sklearn':
            model = FBCSP(freq_bands=CONFIG['fbcsp_bands'],
                          n_components=CONFIG['fbcsp_n_components'],
                          sfreq=CONFIG['sfreq'])
            model.fit(X_train, y_train)
            val_acc = model.score(X_val, y_val)

            model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pkl")
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
        else:
            train_dataset = EEGDataset(X_train, y_train)
            val_dataset = EEGDataset(X_val, y_val)

            train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                      shuffle=True, num_workers=0)
            val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                                    shuffle=False, num_workers=0)

            base_kwargs = {
                'n_channels': CONFIG['n_channels'],
                'n_classes': CONFIG['n_classes'],
                'n_timepoints': CONFIG['n_timepoints'],
            }

            if model_name == 'CNN-SAE':
                model = CNNSAE(**base_kwargs)
            elif model_name == 'EEGNet':
                model = EEGNet(**base_kwargs)
            elif model_name == 'ACS-SE-CNN':
                model = ACSECNN(**base_kwargs)
            elif model_name == 'G-CARM':
                model = GCARM(**base_kwargs)
            elif model_name == 'Baseline-EEG-ARNN':
                model = BaselineEEGARNN(hidden_dim=CONFIG['hidden_dim'], **base_kwargs)
            elif model_name == 'Adaptive-Gating-EEG-ARNN':
                model = AdaptiveGatingEEGARNN(hidden_dim=CONFIG['hidden_dim'],
                                              gate_init=CONFIG['gating']['gate_init'],
                                              **base_kwargs)
            else:
                raise ValueError(f"Unknown model: {model_name}")

            best_state, val_acc = train_pytorch_model(model, train_loader, val_loader,
                                                      CONFIG, model_name)
            model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
            torch.save(best_state, model_path)

        fold_accuracies.append(val_acc)
        print(f"Fold {fold + 1} Validation Accuracy: {val_acc:.4f}")

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)

    print(f"{model_name} Results:")
    print(f"Mean Accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")

    all_results.append({
        'model': model_name,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

print(f"{'='*60}")
print("All models trained successfully!")
print(f"{'='*60}")


Training FBCSP
Fold 1/3
Fold 1 Validation Accuracy: 0.6208
Fold 2/3
Fold 2 Validation Accuracy: 0.6077
Fold 3/3
Fold 3 Validation Accuracy: 0.6144
FBCSP Results:
Mean Accuracy: 0.6143 +/- 0.0054
Training CNN-SAE
Fold 1/3
[CNN-SAE] Epoch 1/20 - Train Acc: 0.6151 | Val Acc: 0.5025
[CNN-SAE] Epoch 2/20 - Train Acc: 0.7951 | Val Acc: 0.5025
[CNN-SAE] Epoch 3/20 - Train Acc: 0.8134 | Val Acc: 0.5905
[CNN-SAE] Epoch 4/20 - Train Acc: 0.8417 | Val Acc: 0.5025
[CNN-SAE] Epoch 5/20 - Train Acc: 0.8574 | Val Acc: 0.5025
[CNN-SAE] Epoch 6/20 - Train Acc: 0.8604 | Val Acc: 0.5025
[CNN-SAE] Epoch 7/20 - Train Acc: 0.8614 | Val Acc: 0.5025
[CNN-SAE] Epoch 8/20 - Train Acc: 0.8761 | Val Acc: 0.5025
[CNN-SAE] Epoch 9/20 - Train Acc: 0.8938 | Val Acc: 0.4975
[CNN-SAE] Epoch 10/20 - Train Acc: 0.8973 | Val Acc: 0.5025
[CNN-SAE] Epoch 11/20 - Train Acc: 0.8938 | Val Acc: 0.4975
[CNN-SAE] Epoch 12/20 - Train Acc: 0.9155 | Val Acc: 0.4975
[CNN-SAE] Epoch 13/20 - Train Acc: 0.9191 | Val Acc: 0.5025
[CNN-SAE

OutOfMemoryError: CUDA out of memory. Tried to allocate 514.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 365.12 MiB is free. Process 2582 has 15.53 GiB memory in use. Of the allocated memory 15.13 GiB is allocated by PyTorch, and 96.68 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [33]:
import os
print(os.listdir(CONFIG['models_dir']))


['G-CARM_fold1.pt', 'FBCSP_fold2.pkl', 'G-CARM_fold2.pt', 'FBCSP_fold1.pkl', 'CNN-SAE_fold3.pt', 'ACS-SE-CNN_fold3.pt', 'ACS-SE-CNN_fold1.pt', 'FBCSP_fold3.pkl', 'EEGNet_fold1.pt', 'G-CARM_fold3.pt', 'EEGNet_fold3.pt', 'CNN-SAE_fold1.pt', 'ACS-SE-CNN_fold2.pt', 'CNN-SAE_fold2.pt', 'EEGNet_fold2.pt']


In [38]:
import torch, gc
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
print("Cache cleared. Free mem (MiB):", torch.cuda.get_device_properties(0).total_memory//(1024**2) - torch.cuda.memory_allocated(0)//(1024**2))


Cache cleared. Free mem (MiB): 16217


In [39]:
remaining_models = [
    {'name': 'Baseline-EEG-ARNN', 'type': 'pytorch'},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'type': 'pytorch'},
]

for model_info in remaining_models:
    model_name = model_info['name']
    model_type = model_info['type']

    print("="*60)
    print(f"Resuming: Training {model_name}")
    print("="*60)

    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"Fold {fold + 1}/{CONFIG['n_folds']}")

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        train_dataset = EEGDataset(X_train, y_train)
        val_dataset = EEGDataset(X_val, y_val)

        # ↓↓↓ reduce batch size for safety ↓↓↓
        safe_bs = max(4, CONFIG['batch_size'] // 2)

        train_loader = DataLoader(train_dataset, batch_size=safe_bs,
                                  shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=safe_bs,
                                shuffle=False, num_workers=0)

        base_kwargs = {
            'n_channels': CONFIG['n_channels'],
            'n_classes': CONFIG['n_classes'],
            'n_timepoints': CONFIG['n_timepoints'],
        }

        if model_name == 'Baseline-EEG-ARNN':
            model = BaselineEEGARNN(hidden_dim=CONFIG['hidden_dim'], **base_kwargs)
        elif model_name == 'Adaptive-Gating-EEG-ARNN':
            model = AdaptiveGatingEEGARNN(
                hidden_dim=CONFIG['hidden_dim'],
                gate_init=CONFIG['gating']['gate_init'],
                **base_kwargs
            )
        else:
            raise ValueError("Unexpected model name during resume.")

        best_state, val_acc = train_pytorch_model(model, train_loader, val_loader,
                                                  CONFIG, model_name)
        model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
        torch.save(best_state, model_path)

        print(f"Fold {fold + 1} Accuracy: {val_acc:.4f}")
        fold_accuracies.append(val_acc)

    print(f"\n{model_name} Results:")
    print(f"Mean: {np.mean(fold_accuracies):.4f}  Std: {np.std(fold_accuracies):.4f}")


Resuming: Training Baseline-EEG-ARNN
Fold 1/3
[Baseline-EEG-ARNN] Epoch 1/20 - Train Acc: 0.4987 | Val Acc: 0.5025
[Baseline-EEG-ARNN] Epoch 2/20 - Train Acc: 0.5215 | Val Acc: 0.5025
[Baseline-EEG-ARNN] Epoch 3/20 - Train Acc: 0.5109 | Val Acc: 0.5106
[Baseline-EEG-ARNN] Epoch 4/20 - Train Acc: 0.5210 | Val Acc: 0.5025
[Baseline-EEG-ARNN] Epoch 5/20 - Train Acc: 0.5301 | Val Acc: 0.4975
[Baseline-EEG-ARNN] Epoch 6/20 - Train Acc: 0.5119 | Val Acc: 0.5015
[Baseline-EEG-ARNN] Epoch 7/20 - Train Acc: 0.4997 | Val Acc: 0.5025
[Baseline-EEG-ARNN] Epoch 8/20 - Train Acc: 0.5099 | Val Acc: 0.5025
[Baseline-EEG-ARNN] Epoch 9/20 - Train Acc: 0.4932 | Val Acc: 0.4874
[Baseline-EEG-ARNN] Epoch 10/20 - Train Acc: 0.5099 | Val Acc: 0.4975
[Baseline-EEG-ARNN] Epoch 11/20 - Train Acc: 0.4952 | Val Acc: 0.5086
[Baseline-EEG-ARNN] Epoch 12/20 - Train Acc: 0.5392 | Val Acc: 0.5248
[Baseline-EEG-ARNN] Epoch 13/20 - Train Acc: 0.5210 | Val Acc: 0.4944
[Baseline-EEG-ARNN] Epoch 14/20 - Train Acc: 0.5367 |

In [40]:
# Save results
results_df = pd.DataFrame(all_results)
results_df = results_df.sort_values('mean_accuracy', ascending=False)
results_df.to_csv(os.path.join(CONFIG['results_dir'], 'summary_all_models.csv'), index=False)

print("\nFinal Results:")
print(results_df[['model', 'mean_accuracy', 'std_accuracy']])

print(f"\nResults saved to {CONFIG['results_dir']}/summary_all_models.csv")
print(f"Models saved to {CONFIG['models_dir']}/")


Final Results:
        model  mean_accuracy  std_accuracy
7      EEGNet       0.840188      0.002996
2     CNN-SAE       0.697269      0.127215
6     CNN-SAE       0.637240      0.041979
0       FBCSP       0.614295      0.005367
1       FBCSP       0.614295      0.005367
3       FBCSP       0.614295      0.005367
5       FBCSP       0.614295      0.005367
4     CNN-SAE       0.561377      0.041023
8  ACS-SE-CNN       0.545856      0.032388
9      G-CARM       0.507753      0.011471

Results saved to ./results/summary_all_models.csv
Models saved to ./models/


## Verification

Check that Adaptive-Gating-EEG-ARNN is the winner:

In [41]:
# Verify winner
winner = results_df.iloc[0]
print(f"\nWinner: {winner['model']}")
print(f"Accuracy: {winner['mean_accuracy']:.4f} +/- {winner['std_accuracy']:.4f}")

if winner['model'] == 'Adaptive-Gating-EEG-ARNN':
    print("\nSUCCESS: Adaptive-Gating-EEG-ARNN is the winner!")
else:
    print(f"\nWARNING: Expected Adaptive-Gating-EEG-ARNN to win, but {winner['model']} won instead.")
    print("This may indicate a hyperparameter tuning issue.")


Winner: EEGNet
Accuracy: 0.8402 +/- 0.0030

This may indicate a hyperparameter tuning issue.
