# PhysioNet EEG: Train All Baseline Models

This notebook trains 7 different models on the PhysioNet Motor Imagery dataset:
1. FBCSP (Filter Bank Common Spatial Patterns)
2. CNN-SAE (CNN with Spatial Attention)
3. EEGNet
4. ACS-SE-CNN (Adaptive Channel Selection SE-CNN)
5. G-CARM (Graph Channel Active Reasoning Module)
6. Baseline EEG-ARNN (without adaptive gating)
7. Adaptive Gating EEG-ARNN (our proposed method)

**Expected Runtime**: 10-12 hours on Kaggle GPU

**Input**: preprocessed EEG epochs (`S###_R##.fif` files) in the folder given by `CONFIG['data_path']` (default: `/kaggle/input/eeg-preprocessed-data/derived`)

**Output**: 
- Trained models in `models/` folder
- Results CSV files in `results/` folder

## Configuration

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import pickle
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

mne.set_log_level('ERROR')

In [None]:
# Configuration
# Single source of truth for paths, evaluation protocol, optimisation and data
# settings; the cells below read everything from this dict.
CONFIG = {
    # --- Locations ---
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',  # Change this for local testing
    'output_dir': './',
    'results_dir': './results',
    'models_dir': './models',
    'figures_dir': './figures',

    # --- Evaluation protocol / reproducibility ---
    'n_folds': 3,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # --- Training hyperparameters ---
    'batch_size': 64,
    'epochs': 20,
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'patience': 10,

    # --- Data parameters ---
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'n_timepoints': 513,  # 4 seconds at 128 Hz + 1

    # --- FBCSP parameters ---
    # Nine contiguous 4 Hz-wide bands: (4, 8), (8, 12), ..., (36, 40).
    'fbcsp_bands': [(low, low + 4) for low in range(4, 40, 4)],
    'fbcsp_n_components': 4,
}

# Make sure every output folder exists before anything tries to write to it.
for _dir_key in ('results_dir', 'models_dir', 'figures_dir'):
    os.makedirs(CONFIG[_dir_key], exist_ok=True)

# Seed NumPy and PyTorch (CPU, plus the current CUDA device when present).
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])

print(f"Device: {CONFIG['device']}")
print(f"Data path: {CONFIG['data_path']}")

## Data Loading Utilities

In [None]:
def load_physionet_data(data_path, subject_ids=None):
    """
    Load preprocessed PhysioNet data from derived folder.

    Files are expected to be named ``S<subject:03d>_R<run:02d>.fif``. For each
    subject, runs 3, 4, 7, 8, 11 and 12 are read with ``mne.read_epochs`` and
    concatenated; missing or unreadable run files are skipped.

    Args:
        data_path: Folder containing the ``.fif`` epoch files.
        subject_ids: Optional sequence of integer subject IDs. When ``None``,
            the IDs are discovered from the ``.fif`` file names in ``data_path``.

    Returns:
        X: numpy array (n_trials, n_channels, n_timepoints)
        y: numpy array (n_trials,) with labels 0=T1 (left fist), 1=T2 (right fist)
        subject_labels: numpy array (n_trials,) with subject IDs

    Raises:
        ValueError: If no subjects are found or no T1/T2 trial could be loaded.
    """
    run_ids = (3, 4, 7, 8, 11, 12)
    label_by_event_name = {'T1': 0, 'T2': 1}

    # Get list of subjects
    if subject_ids is None:
        discovered = set()
        for fname in os.listdir(data_path):
            if not fname.endswith('.fif'):
                continue
            try:
                discovered.add(int(fname.split('_')[0][1:]))
            except ValueError:
                # Not an S###_R##.fif file; ignore instead of aborting the load.
                continue
        subject_ids = sorted(discovered)

    if len(subject_ids) == 0:
        raise ValueError(f"No subject files (S###_R##.fif) found in {data_path!r}")

    print(f"Loading data from {len(subject_ids)} subjects...")

    all_X = []
    all_y = []
    all_subjects = []

    for subject_id in tqdm(subject_ids):
        # Load runs for this subject (runs 3, 4, 7, 8, 11, 12)
        subject_runs = []
        for run_id in run_ids:
            filename = f"S{subject_id:03d}_R{run_id:02d}.fif"
            filepath = os.path.join(data_path, filename)

            if not os.path.exists(filepath):
                continue

            # Best-effort loading: one corrupt run must not abort the dataset.
            try:
                subject_runs.append(mne.read_epochs(filepath, preload=True, verbose=False))
            except Exception as e:
                print(f"Error loading {filename}: {e}")
                continue

        if not subject_runs:
            continue

        # Concatenate all runs for this subject
        epochs = mne.concatenate_epochs(subject_runs)

        # Extract data
        X = epochs.get_data()  # (n_trials, n_channels, n_timepoints)

        # Map event codes to 0/1 labels using the codes stored in the file
        # itself (epochs.event_id: name -> code in epochs.events[:, -1]).
        # Hard-coding the codes would mislabel or drop trials whenever the
        # preprocessing assigned different integers to T1/T2.
        code_to_label = {
            epochs.event_id[name]: label
            for name, label in label_by_event_name.items()
            if name in epochs.event_id
        }
        event_codes = epochs.events[:, -1]
        y = np.full(len(event_codes), -1, dtype=int)
        for code, label in code_to_label.items():
            y[event_codes == code] = label

        # Filter out any trials with unknown labels
        valid_mask = y != -1
        X = X[valid_mask]
        y = y[valid_mask]

        if len(X) == 0:
            continue

        all_X.append(X)
        all_y.append(y)
        all_subjects.append(np.full(len(y), subject_id))

    if not all_X:
        raise ValueError(f"No T1/T2 trials could be loaded from {data_path!r}")

    # Concatenate all subjects
    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subject_labels = np.concatenate(all_subjects, axis=0)

    print(f"Loaded {len(X)} trials from {len(np.unique(subject_labels))} subjects")
    print(f"Data shape: {X.shape}")
    print(f"Label distribution: {np.bincount(y)}")

    return X, y, subject_labels


class EEGDataset(Dataset):
    """In-memory PyTorch dataset pairing EEG trials with integer class labels.

    ``X`` is converted once to a float32 tensor and ``y`` to an int64 tensor,
    so indexing returns ready-to-use ``(trial, label)`` tensor pairs.
    """

    def __init__(self, X, y):
        # Convert up front so __getitem__ is a plain tensor index.
        self.X, self.y = torch.FloatTensor(X), torch.LongTensor(y)

    def __len__(self):
        # Number of trials = size of the first dimension of X.
        return self.X.size(0)

    def __getitem__(self, idx):
        trial = self.X[idx]
        label = self.y[idx]
        return trial, label

## Model Architectures

In [None]:
# Model 1: FBCSP
class FBCSP:
    """Filter Bank Common Spatial Patterns with LDA classifier.

    Each frequency band gets its own band-pass filter and its own MNE ``CSP``
    decomposition; the log-variance CSP features of all bands are concatenated
    and classified with linear discriminant analysis.

    Args:
        freq_bands: Iterable of ``(low_hz, high_hz)`` band edges.
        n_components: Number of CSP components kept per band.
        sfreq: Sampling frequency of the data in Hz.
    """
    def __init__(self, freq_bands, n_components=4, sfreq=128):
        self.freq_bands = freq_bands
        self.n_components = n_components
        self.sfreq = sfreq
        self.csp_list = []       # one fitted CSP per band, in freq_bands order
        self.classifier = None   # fitted LDA, set by fit()

    def fit(self, X, y):
        """Fit one CSP per band and the LDA on the stacked features.

        X: (n_trials, n_channels, n_timepoints), y: (n_trials,)
        Returns ``self``.
        """
        from mne.decoding import CSP

        # Reset so that refitting never leaves CSPs from a previous fit at the
        # front of the list (predict() pairs bands and CSPs by position).
        self.csp_list = []
        band_features = []

        for low, high in self.freq_bands:
            X_filtered = self._bandpass_filter(X, low, high)

            csp = CSP(n_components=self.n_components, reg=None, log=True, norm_trace=False)
            band_features.append(csp.fit_transform(X_filtered, y))
            self.csp_list.append(csp)

        # Concatenate features from all bands
        all_features = np.concatenate(band_features, axis=1)

        # Train LDA classifier
        self.classifier = LinearDiscriminantAnalysis()
        self.classifier.fit(all_features, y)

        return self

    def predict(self, X):
        """Return predicted labels for X (n_trials, n_channels, n_timepoints).

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.classifier is None:
            raise RuntimeError("FBCSP must be fitted before calling predict()")

        band_features = [
            csp.transform(self._bandpass_filter(X, low, high))
            for (low, high), csp in zip(self.freq_bands, self.csp_list)
        ]
        return self.classifier.predict(np.concatenate(band_features, axis=1))

    def score(self, X, y):
        """Return the fraction of trials in X whose prediction equals y."""
        return np.mean(self.predict(X) == y)

    def _bandpass_filter(self, X, low, high):
        """Zero-phase 4th-order Butterworth band-pass along the last axis.

        The input is converted to float64 first, so integer input is not
        truncated, and the result is a C-contiguous float64 array with the
        same shape as ``X``.
        """
        from scipy.signal import butter, filtfilt

        nyq = self.sfreq / 2
        b, a = butter(4, [low / nyq, high / nyq], btype='band')

        # One vectorised call over all trials and channels; filtfilt pads and
        # filters each 1-D slice along `axis` independently.
        filtered = filtfilt(b, a, np.asarray(X, dtype=float), axis=-1)
        return np.ascontiguousarray(filtered)

In [None]:
# Model 2: CNN-SAE (CNN with Spatial Attention)
class SpatialAttention(nn.Module):
    """Channel-wise attention that rescales every channel by a weight in (0, 1).

    The weights are produced by a bottleneck MLP (``n_channels`` ->
    ``n_channels // 4`` -> ``n_channels``, sigmoid output) applied to the
    per-channel mean over time.
    """

    def __init__(self, n_channels):
        super().__init__()
        bottleneck = n_channels // 4
        layers = [
            nn.Linear(n_channels, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, n_channels),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        """Reweight ``x`` of shape (batch, channels, time); shape is preserved."""
        channel_summary = x.mean(dim=2)                 # (batch, channels)
        channel_scale = self.attention(channel_summary)  # (batch, channels)
        # Add a trailing time axis so the scale broadcasts over all samples.
        return x * channel_scale[:, :, None]

class CNNSAE(nn.Module):
    """1-D CNN preceded by spatial (channel) attention.

    Pipeline: SpatialAttention -> three Conv1d/BatchNorm/ReLU/MaxPool stages
    (64, 128, 256 filters) -> dropout -> two fully connected layers.

    Args:
        n_channels: Number of input EEG channels.
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; only used to size the first FC layer.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()

        self.spatial_attention = SpatialAttention(n_channels)

        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(2)

        self.dropout = nn.Dropout(0.5)

        flattened_size = self._infer_flattened_size(n_channels, n_timepoints)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _infer_flattened_size(self, n_channels, n_timepoints):
        """Return the per-sample feature count produced by _forward_features.

        The probe runs in eval mode without autograd so that the all-zero
        dummy input does not update the BatchNorm running statistics.
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, n_timepoints)
            n_features = torch.flatten(self._forward_features(probe), 1).size(1)
        self.train(was_training)
        return n_features

    def _forward_features(self, x):
        """Attention + convolutional stack; x is (batch, channels, time)."""
        x = self.spatial_attention(x)

        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))

        return x

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Model 3: EEGNet
class EEGNet(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, F1=8, D=2, F2=16):
        super().__init__()
        
        # Temporal convolution
        self.conv1 = nn.Conv2d(1, F1, (1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        
        # Depthwise convolution
        self.conv2 = nn.Conv2d(F1, F1 * D, (n_channels, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)
        self.pool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout(0.5)
        
        # Separable convolution
        self.conv3 = nn.Conv2d(F1 * D, F2, (1, 16), padding=(0, 8), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)
        self.pool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout(0.5)
        
        # Calculate flattened size
        test_input = torch.zeros(1, 1, n_channels, n_timepoints)
        test_output = self._forward_features(test_input)
        flattened_size = test_output.view(1, -1).size(1)
        
        self.fc = nn.Linear(flattened_size, n_classes)
    
    def _forward_features(self, x):
        x = self.bn1(self.conv1(x))
        x = self.dropout1(self.pool1(torch.elu(self.bn2(self.conv2(x)))))
        x = self.dropout2(self.pool2(torch.elu(self.bn3(self.conv3(x)))))
        return x
    
    def forward(self, x):
        # Input: (batch, channels, time)
        x = x.unsqueeze(1)  # (batch, 1, channels, time)
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
# Model 4: ACS-SE-CNN (Adaptive Channel Selection with Squeeze-Excitation CNN)
class SEBlock(nn.Module):
    """Squeeze-and-Excitation: rescale feature channels by learned gates.

    Args:
        channels: Number of feature channels of the input.
        reduction: Divisor giving the width of the hidden layer
            (``channels // reduction``).
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        """Gate ``x`` of shape (batch, channels, time); shape is preserved."""
        # Squeeze: average over time -> one descriptor per channel.
        descriptor = x.mean(dim=2)
        # Excite: bottleneck MLP ending in a sigmoid gate per channel.
        gate = self.fc2(self.fc1(descriptor).relu()).sigmoid()
        return x * gate[..., None]

class ACSECNN(nn.Module):
    """CNN with per-channel adaptive weighting and Squeeze-Excitation blocks.

    A small MLP shared by all channels maps each channel's full time course
    to a weight in (0, 1). The reweighted signal then passes through three
    SE + Conv1d/BatchNorm/ReLU/MaxPool stages and two fully connected layers.

    Args:
        n_channels: Number of input EEG channels.
        n_classes: Number of output logits.
        n_timepoints: Samples per trial. Fixes the input width of the channel
            MLP, so inputs must have exactly this many time points.

    Attributes:
        channel_weights: Detached (batch, n_channels) weights from the most
            recent forward pass, or ``None`` before the first one.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()

        # Channel selection module
        self.channel_attention = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # SE blocks
        self.se1 = SEBlock(n_channels)
        self.se2 = SEBlock(128)
        self.se3 = SEBlock(256)

        # Convolutional layers
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)

        self.dropout = nn.Dropout(0.5)

        flattened_size = self._infer_flattened_size(n_channels, n_timepoints)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

        # Discard the weights recorded while probing the output shape.
        self.channel_weights = None

    def _infer_flattened_size(self, n_channels, n_timepoints):
        """Return the per-sample feature count produced by _forward_features.

        The probe runs in eval mode without autograd so that the all-zero
        dummy input does not update the BatchNorm running statistics.
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, n_timepoints)
            n_features = torch.flatten(self._forward_features(probe), 1).size(1)
        self.train(was_training)
        return n_features

    def _forward_features(self, x):
        """Channel weighting + SE-CNN stack; x is (batch, n_channels, time)."""
        # Adaptive channel selection. nn.Linear acts on the last (time) axis,
        # so one batched call scores every channel with the shared MLP
        # instead of looping over channels in Python.
        channel_weights = self.channel_attention(x).squeeze(-1)  # (batch, n_channels)
        self.channel_weights = channel_weights.detach()

        x = x * channel_weights.unsqueeze(2)  # (batch, n_channels, time)

        # SE-CNN
        x = self.se1(x)
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))

        x = self.se2(x)
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))

        x = self.se3(x)
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))

        return x

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Model 5: G-CARM (Graph Channel Active Reasoning Module)
class CARMBlock(nn.Module):
    """Graph-convolution step that mixes channels through a learned adjacency.

    The raw adjacency ``A`` is a free parameter; a row-wise softmax turns it
    into mixing weights, so every output channel is a convex combination of
    the input channels.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels

        # Raw (pre-softmax) adjacency, initialised close to zero.
        self.A = nn.Parameter(0.01 * torch.randn(n_channels, n_channels))

        # Not applied in forward(); still registered, so it is part of the
        # module's parameters and state_dict.
        self.norm = nn.LayerNorm(n_channels)

    def forward(self, x):
        """Mix the channels of ``x`` (batch, channels, time); shape is preserved."""
        _n_batch, _n_nodes, _n_steps = x.shape  # unpacking rejects non-3-D input

        mixing = torch.softmax(self.A, dim=1)      # each row sums to 1
        per_step = x.transpose(1, 2)                # (batch, time, channels)
        mixed = per_step @ mixing.t()               # (batch, time, channels)
        return mixed.transpose(1, 2)                # (batch, channels, time)

    def get_adjacency_matrix(self):
        """Return the row-softmaxed adjacency, detached from autograd."""
        return self.A.softmax(dim=1).detach()

class GCARM(nn.Module):
    """Two stacked CARM graph blocks followed by a 1-D CNN classifier.

    Args:
        n_channels: Number of input EEG channels (graph nodes).
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; only used to size the first FC layer.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()

        # CARM blocks
        self.carm1 = CARMBlock(n_channels)
        self.carm2 = CARMBlock(n_channels)

        # Convolutional layers
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)

        self.dropout = nn.Dropout(0.5)

        flattened_size = self._infer_flattened_size(n_channels, n_timepoints)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _infer_flattened_size(self, n_channels, n_timepoints):
        """Return the per-sample feature count produced by _forward_features.

        The probe runs in eval mode without autograd so that the all-zero
        dummy input does not update the BatchNorm running statistics.
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, n_timepoints)
            n_features = torch.flatten(self._forward_features(probe), 1).size(1)
        self.train(was_training)
        return n_features

    def _forward_features(self, x):
        """CARM blocks + convolutional stack; x is (batch, channels, time)."""
        # Apply CARM blocks
        x = self.carm1(x)
        x = self.carm2(x)

        # CNN layers
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))

        return x

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def get_channel_importance_edge(self):
        """Edge Selection: Sum of outgoing edge weights.

        Returns a NumPy vector with one score per channel, computed from the
        average of the two CARM adjacency matrices.

        The adjacency is softmax-normalised along dim=1, so every row sums to
        exactly 1 and a row sum carries no information. Entry [i, j] is the
        weight with which input channel j feeds output channel i; the
        outgoing weight of channel j is therefore the column sum (dim=0).
        """
        A1 = self.carm1.get_adjacency_matrix()
        A2 = self.carm2.get_adjacency_matrix()
        A_combined = (A1 + A2) / 2
        return torch.sum(A_combined, dim=0).cpu().numpy()

In [None]:
# Model 6 & 7: EEG-ARNN (Baseline and Adaptive Gating versions)
class TFEMBlock(nn.Module):
    """Temporal-Frequency-Enhanced Module.

    A Conv1d (kernel 5, same-length padding) lifts the input from
    ``n_channels`` to ``hidden_dim`` feature maps, followed by BatchNorm and
    ReLU; an adaptive average pool then resamples the time axis to a fixed
    length of 64.
    """

    def __init__(self, n_channels, hidden_dim=128):
        super().__init__()

        # Temporal feature extraction.
        self.temporal_conv = nn.Conv1d(n_channels, hidden_dim, kernel_size=5, padding=2)
        self.temporal_bn = nn.BatchNorm1d(hidden_dim)

        # Fixed-length summary along the last (time) axis.
        self.freq_pool = nn.AdaptiveAvgPool1d(64)

    def forward(self, x):
        """Map (batch, channels, time) to (batch, hidden_dim, 64)."""
        activated = self.temporal_bn(self.temporal_conv(x)).relu()
        return self.freq_pool(activated)

class BaselineEEGARNN(nn.Module):
    """Baseline EEG-ARNN without adaptive gating.

    Pipeline: TFEM (conv features, time pooled to 64 steps) -> CARM graph
    mixing over the ``hidden_dim`` feature maps -> bidirectional LSTM ->
    two-layer classification head.

    Args:
        n_channels: Number of input EEG channels.
        n_classes: Number of output logits.
        n_timepoints: Accepted for a uniform constructor signature; not used.
        hidden_dim: Width of the TFEM features and of the LSTM hidden state.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()

        # TFEM
        self.tfem = TFEMBlock(n_channels, hidden_dim)

        # CARM
        self.carm = CARMBlock(hidden_dim)

        # Bi-LSTM
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Classification head
        self.fc1 = nn.Linear(hidden_dim * 2, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return class logits (batch, n_classes) for x of shape (batch, channels, time)."""
        # TFEM
        x = self.tfem(x)  # (batch, hidden_dim, 64)

        # CARM
        x = self.carm(x)  # (batch, hidden_dim, 64)

        # Bi-LSTM
        x = x.permute(0, 2, 1)  # (batch, 64, hidden_dim)
        x, _ = self.lstm(x)  # (batch, 64, hidden_dim*2)
        # NOTE(review): at the last step the backward direction has processed
        # only one input; confirm that summarising with the last step (rather
        # than the final hidden states of both directions) is intended.
        x = x[:, -1, :]  # Take last timestep (batch, hidden_dim*2)

        # Classification
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def get_channel_importance_edge(self):
        """Edge Selection: Sum of outgoing edge weights.

        The CARM adjacency is softmax-normalised along dim=1, so every row
        sums to exactly 1 and a row sum carries no information. Entry [i, j]
        weights input node j for output node i, so the outgoing weight of
        node j is the column sum (dim=0).

        The CARM block is built with ``hidden_dim`` nodes, so the returned
        NumPy vector has ``hidden_dim`` entries (TFEM feature maps), not one
        entry per EEG electrode.
        """
        A = self.carm.get_adjacency_matrix()
        return torch.sum(A, dim=0).cpu().numpy()

class AdaptiveGatingEEGARNN(nn.Module):
    """EEG-ARNN with Adaptive Gating (our proposed method).

    A small MLP shared by all channels maps each channel's full time course
    to a gate in (0, 1); the gated signal then follows the same
    TFEM -> CARM -> Bi-LSTM -> classifier pipeline as the baseline.

    Args:
        n_channels: Number of input EEG channels.
        n_classes: Number of output logits.
        n_timepoints: Samples per trial. Fixes the input width of the gate
            MLP, so inputs must have exactly this many time points.
        hidden_dim: Width of the TFEM features and of the LSTM hidden state.
        gate_init: Constant written into the bias of every Linear layer of
            the gate MLP.

    Attributes:
        gate_values: Detached (batch, n_channels) gates from the most recent
            forward pass, or ``None`` before the first one.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__()
        self.n_channels = n_channels

        # Adaptive gating module
        self.gate_net = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Initialize gate to favor most channels initially
        # NOTE(review): this sets the bias of BOTH linear layers to gate_init,
        # so the initial gate is sigmoid(gate_init + ...), not gate_init
        # itself. Confirm whether logit(gate_init) on the last layer only was
        # intended.
        for layer in self.gate_net:
            if isinstance(layer, nn.Linear):
                nn.init.constant_(layer.bias, gate_init)

        # TFEM
        self.tfem = TFEMBlock(n_channels, hidden_dim)

        # CARM
        self.carm = CARMBlock(hidden_dim)

        # Bi-LSTM
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Classification head
        self.fc1 = nn.Linear(hidden_dim * 2, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, n_classes)

        self.gate_values = None

    def forward(self, x):
        """Return class logits (batch, n_classes) for x of shape (batch, n_channels, time)."""
        # Adaptive gating. nn.Linear acts on the last (time) axis, so one
        # batched call gates every channel with the shared MLP instead of
        # looping over channels in Python.
        gates = self.gate_net(x).squeeze(-1)  # (batch, n_channels)
        self.gate_values = gates.detach()

        x = x * gates.unsqueeze(2)  # (batch, n_channels, time)

        # TFEM
        x = self.tfem(x)  # (batch, hidden_dim, 64)

        # CARM
        x = self.carm(x)  # (batch, hidden_dim, 64)

        # Bi-LSTM
        x = x.permute(0, 2, 1)  # (batch, 64, hidden_dim)
        x, _ = self.lstm(x)  # (batch, 64, hidden_dim*2)
        # NOTE(review): at the last step the backward direction has processed
        # only one input; confirm that this summary is intended.
        x = x[:, -1, :]  # Take last timestep (batch, hidden_dim*2)

        # Classification
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def get_channel_importance_gate(self):
        """Gate Selection: Average gate values.

        Averages the gates stored by the most recent forward pass over that
        batch only. Returns ``None`` if no forward pass has run yet.
        """
        if self.gate_values is None:
            return None
        return torch.mean(self.gate_values, dim=0).cpu().numpy()

    def get_channel_importance_edge(self):
        """Edge Selection: Sum of outgoing edge weights.

        The CARM adjacency is softmax-normalised along dim=1, so every row
        sums to exactly 1 and a row sum carries no information. Entry [i, j]
        weights input node j for output node i, so the outgoing weight of
        node j is the column sum (dim=0).

        The CARM block is built with ``hidden_dim`` nodes, so the returned
        NumPy vector has ``hidden_dim`` entries (TFEM feature maps), not one
        entry per EEG electrode.
        """
        A = self.carm.get_adjacency_matrix()
        return torch.sum(A, dim=0).cpu().numpy()

## Training Utilities

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch.

    Args:
        model: Network to optimise; switched to training mode.
        dataloader: Iterable of ``(X_batch, y_batch)`` tensor pairs.
        criterion: Loss callable. The reported loss assumes it returns the
            mean over the batch (the default for ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)``
        when the dataloader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for X_batch, y_batch in dataloader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        n_samples = y_batch.size(0)
        # Weight each batch mean by its size so that a shorter final batch
        # does not count as much as a full one.
        loss_sum += loss.item() * n_samples
        total += n_samples
        correct += (outputs.argmax(dim=1) == y_batch).sum().item()

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

def evaluate(model, dataloader, criterion, device):
    """Evaluate model.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(X_batch, y_batch)`` tensor pairs.
        criterion: Loss callable. The reported loss assumes it returns the
            mean over the batch (the default for ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)``
        when the dataloader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            n_samples = y_batch.size(0)
            # Weight each batch mean by its size so that a shorter final
            # batch does not count as much as a full one.
            loss_sum += loss.item() * n_samples
            total += n_samples
            correct += (outputs.argmax(dim=1) == y_batch).sum().item()

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

def train_pytorch_model(model, train_loader, val_loader, config, model_name):
    """Train a PyTorch model with early stopping.

    Optimises with Adam and cross-entropy for up to ``config['epochs']``
    epochs and stops once validation accuracy has failed to improve for
    ``config['patience']`` consecutive epochs.

    Args:
        model: Network to train; moved to ``config['device']``.
        train_loader: Batches used for optimisation.
        val_loader: Batches used for model selection / early stopping.
        config: Dict providing 'device', 'learning_rate', 'weight_decay',
            'epochs' and 'patience'.
        model_name: Accepted for interface symmetry; not used here.

    Returns:
        ``(model, best_val_acc)``. The returned model carries the weights of
        the epoch that reached ``best_val_acc``, so the caller saves the
        model that matches the reported accuracy. If no epoch improves on the
        initial value of 0, the last-epoch weights are kept.
    """
    import copy

    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                          weight_decay=config['weight_decay'])

    best_val_acc = 0
    best_state = None
    patience_counter = 0

    for epoch in range(config['epochs']):
        train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = evaluate(model, val_loader, criterion, device)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to live tensors; deep-copy so
            # that later optimizer steps cannot overwrite the snapshot.
            best_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= config['patience']:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Roll back to the best epoch instead of returning last-epoch weights.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, best_val_acc

## Main Training Loop

In [None]:
# Load data
# Read every subject found under the configured data folder.
print("Loading PhysioNet data...")
X, y, subject_labels = load_physionet_data(CONFIG['data_path'])

# One-shot summary of what was loaded (one message per line).
print(
    f"\nData loaded successfully!",
    f"Total trials: {len(X)}",
    f"Data shape: {X.shape}",
    f"Labels: {np.unique(y, return_counts=True)}",
    sep="\n",
)

In [None]:
# Define models to train
# FBCSP is the only non-PyTorch pipeline; everything else shares the PyTorch
# training loop. List order is the order in which models are trained.
models_to_train = [
    {'name': name, 'type': 'sklearn' if name == 'FBCSP' else 'pytorch'}
    for name in (
        'FBCSP',
        'CNN-SAE',
        'EEGNet',
        'ACS-SE-CNN',
        'G-CARM',
        'Baseline-EEG-ARNN',
        'Adaptive-Gating-EEG-ARNN',
    )
]

# Initialize results storage (one summary dict per model)
all_results = []

# Cross-validation: folds are stratified on the class label only and shuffled
# with a fixed seed, so every model is evaluated on the same splits.
# NOTE(review): subject_labels is not used for splitting, so trials of one
# subject can appear in both train and validation folds - confirm intended.
skf = StratifiedKFold(
    n_splits=CONFIG['n_folds'],
    shuffle=True,
    random_state=CONFIG['random_seed'],
)

print(f"\nStarting {CONFIG['n_folds']}-fold cross-validation...\n")

In [None]:
# Train all models
# Constructors of the PyTorch models, keyed by the names used in
# models_to_train. A lookup table (instead of an if/elif chain without else)
# guarantees an unknown name fails loudly rather than silently reusing the
# `model` object left over from a previous iteration.
PYTORCH_MODEL_CLASSES = {
    'CNN-SAE': CNNSAE,
    'EEGNet': EEGNet,
    'ACS-SE-CNN': ACSECNN,
    'G-CARM': GCARM,
    'Baseline-EEG-ARNN': BaselineEEGARNN,
    'Adaptive-Gating-EEG-ARNN': AdaptiveGatingEEGARNN,
}

for model_info in models_to_train:
    model_name = model_info['name']
    model_type = model_info['type']

    # Validate before spending any time on training.
    if model_type != 'sklearn' and model_name not in PYTORCH_MODEL_CLASSES:
        raise ValueError(f"Unknown PyTorch model name: {model_name!r}")

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}\n")

    fold_accuracies = []

    # skf uses a fixed random_state, so every model sees identical folds.
    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"\nFold {fold + 1}/{CONFIG['n_folds']}")

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        if model_type == 'sklearn':
            # FBCSP
            model = FBCSP(freq_bands=CONFIG['fbcsp_bands'],
                         n_components=CONFIG['fbcsp_n_components'],
                         sfreq=CONFIG['sfreq'])
            model.fit(X_train, y_train)
            val_acc = model.score(X_val, y_val)

            # Save model
            model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pkl")
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)

        else:
            # PyTorch models
            train_dataset = EEGDataset(X_train, y_train)
            val_dataset = EEGDataset(X_val, y_val)

            train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                     shuffle=True, num_workers=0)
            val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                                   shuffle=False, num_workers=0)

            # Create model (all constructors share the same three arguments)
            model = PYTORCH_MODEL_CLASSES[model_name](
                n_channels=CONFIG['n_channels'],
                n_classes=CONFIG['n_classes'],
                n_timepoints=CONFIG['n_timepoints'],
            )

            # Train model
            model, val_acc = train_pytorch_model(model, train_loader, val_loader,
                                                CONFIG, model_name)

            # Save model
            model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
            torch.save(model.state_dict(), model_path)

        fold_accuracies.append(val_acc)
        print(f"Fold {fold + 1} Validation Accuracy: {val_acc:.4f}")

    # Compute statistics (np.std default: population standard deviation)
    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)

    print(f"\n{model_name} Results:")
    print(f"Mean Accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")

    # Store results
    all_results.append({
        'model': model_name,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

print(f"\n\n{'='*60}")
print("All models trained successfully!")
print(f"{'='*60}")

In [None]:
# Save results
# One row per model, best mean accuracy first, written as CSV without the index.
results_df = pd.DataFrame(all_results).sort_values('mean_accuracy', ascending=False)
results_df.to_csv(
    os.path.join(CONFIG['results_dir'], 'summary_all_models.csv'),
    index=False,
)

# Show the ranking without the per-fold accuracy column.
print("\nFinal Results:")
print(results_df[['model', 'mean_accuracy', 'std_accuracy']])

print(
    f"\nResults saved to {CONFIG['results_dir']}/summary_all_models.csv",
    f"Models saved to {CONFIG['models_dir']}/",
    sep="\n",
)

## Verification

Check whether Adaptive-Gating-EEG-ARNN achieved the highest mean validation accuracy (i.e. is the first row of the sorted results table):

In [None]:
# Verify winner
# results_df is sorted by mean accuracy (descending), so row 0 is the best model.
winner = results_df.iloc[0]
print(f"\nWinner: {winner['model']}")
print(f"Accuracy: {winner['mean_accuracy']:.4f} +/- {winner['std_accuracy']:.4f}")

if winner['model'] != 'Adaptive-Gating-EEG-ARNN':
    # A different model came out on top: report it instead of failing.
    print(f"\nWARNING: Expected Adaptive-Gating-EEG-ARNN to win, but {winner['model']} won instead.")
    print("This may indicate a hyperparameter tuning issue.")
else:
    print("\nSUCCESS: Adaptive-Gating-EEG-ARNN is the winner!")