# PhysioNet EEG: Complete Pipeline - Training, Channel Selection & Results Generation

This unified notebook contains the complete experimental pipeline:
1. Train all 7 baseline models with improved hyperparameters
2. Evaluate channel selection methods (Edge, Aggregation, Gate)
3. Generate all paper-ready results, tables, and figures

**Improved Hyperparameters for Faster Convergence:**
- Epochs: 25 (with early stopping patience=5)
- Learning rate: 0.002 (2x original)
- Adaptive LR scheduling with ReduceLROnPlateau
- Early stopping enabled for faster training

**Expected Runtime**: up to ~32 hours on Kaggle GPU at ~15 min per training run (less when early stopping ends runs early)

**Estimated Breakdown:**
- Initial training (7 models × 3 folds × ~15 min): ~5 hours
- Channel selection (2 models × 3 methods × 5 k-values × 3 folds × ~15 min): ~22.5 hours
- Retention (6 k-values × 3 folds × ~15 min): ~4.5 hours

**Input**: `/kaggle/input/eeg-preprocessed-data/derived/`

**Output**:
- `models/` - Trained model checkpoints
- `results/` - CSV files with all results
- `figures/` - Publication-ready figures

## 1. Configuration and Imports

In [None]:
import os
import gc
import math
import json
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler
import pickle
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

mne.set_log_level('ERROR')

print("All imports successful!")

In [None]:
# Configuration with improved hyperparameters.
# Single source of truth for paths, cross-validation, optimisation and
# channel-selection settings used by every later cell.
CONFIG = {
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    'output_dir': './',
    'results_dir': './results',
    'models_dir': './models',
    'figures_dir': './figures',

    'n_folds': 3,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # IMPROVED Training hyperparameters for faster convergence
    'batch_size': 64,
    'epochs': 25,  # Balanced: faster than 50, better convergence than 20
    'learning_rate': 0.002,  # 2x original for faster convergence
    'weight_decay': 1e-4,
    'patience': 5,  # Early stopping patience
    'scheduler_patience': 2,  # LR reduction patience
    'scheduler_factor': 0.5,
    'use_early_stopping': True,  # Enabled
    'min_lr': 1e-6,

    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'n_timepoints': 513,
    'hidden_dim': 128,
    'mi_runs': [7, 8, 11, 12],

    # FBCSP parameters
    'fbcsp_bands': [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40)],
    'fbcsp_n_components': 4,

    # Gating regularization
    'gating': {
        'gate_init': 0.9,
        'l1_lambda': 1e-3,
    },

    # Channel selection k values (REDUCED FOR FASTER RUNTIME)
    'k_values': [10, 20, 30, 40, 50],  # 5 values instead of more
    'retention_k_values': [10, 15, 20, 25, 30, 35],  # 6 values for retention curve
}

# Create output directories
os.makedirs(CONFIG['results_dir'], exist_ok=True)
os.makedirs(CONFIG['models_dir'], exist_ok=True)
os.makedirs(CONFIG['figures_dir'], exist_ok=True)

# Set random seeds
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])
    torch.backends.cudnn.deterministic = True
    # Keep cuDNN from auto-tuning kernels per input shape; the tuner may pick
    # different algorithms between runs, which defeats the seeding above.
    torch.backends.cudnn.benchmark = False

print(f"Device: {CONFIG['device']}")
print(f"Data path: {CONFIG['data_path']}")
print(f"Epochs: {CONFIG['epochs']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"Early stopping: {CONFIG['use_early_stopping']} (patience={CONFIG['patience']})")
print(f"\nChannel selection k-values: {CONFIG['k_values']}")
print(f"Retention analysis k-values: {CONFIG['retention_k_values']}")

# Calculate estimated runtime.
# The fold count is read from CONFIG so the estimate follows 'n_folds'
# instead of silently assuming 3.
baseline_runs = 7 * CONFIG['n_folds']  # 7 models × n_folds
# 2 models × 3 methods × k-values × n_folds.
# NOTE(review): the evaluation plan later in the notebook lists only 2 methods
# for the baseline model, so this is probably an upper bound - confirm.
cs_runs = 2 * 3 * len(CONFIG['k_values']) * CONFIG['n_folds']
retention_runs = len(CONFIG['retention_k_values']) * CONFIG['n_folds']  # k-values × n_folds
total_runs = baseline_runs + cs_runs + retention_runs

print(f"\nEstimated training runs:")
print(f"  - Baseline models: {baseline_runs}")
print(f"  - Channel selection: {cs_runs}")
print(f"  - Retention analysis: {retention_runs}")
print(f"  - TOTAL: {total_runs} training runs")
print(f"\nEstimated runtime (15 min/run): ~{total_runs * 15 / 60:.1f} hours")

## 2. Data Loading Utilities

In [None]:
def load_physionet_data(data_path, subject_ids=None):
    """
    Load preprocessed PhysioNet motor imagery data.

    Two on-disk layouts are supported:

    * per-subject folders (``S001/``, ``S002/`` ...) containing one raw
      ``.fif`` file per run, optionally nested under ``preprocessed/``;
    * a legacy flat directory of epoch ``.fif`` files, used as a fallback when
      the folder layout yields no trials.

    Epoch window (``tmin``/``tmax``) and the run numbers to read (``mi_runs``)
    come from the module-level ``CONFIG`` when it exists, otherwise from the
    defaults 0.0 / 4.0 / [7, 8, 11, 12].

    Args:
        data_path: Directory holding the data.
        subject_ids: Optional iterable of subject identifiers (ints or strings
            such as ``"S001"``). When given, only these subjects are loaded,
            in both layouts. Entries that cannot be parsed are ignored; if
            none can be parsed, no filtering is applied.

    Returns:
        Tuple ``(X, y, subjects)`` of NumPy arrays concatenated over all
        loaded files: trial data, labels (0 for event ``T1``, 1 for ``T2``)
        and the numeric subject id of every trial (-1 when unknown).

    Raises:
        FileNotFoundError: ``data_path`` is not an existing directory.
        ValueError: No usable files or no valid trials were found.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    config = globals().get('CONFIG', {})
    tmin = float(config.get('tmin', 0.0))
    tmax = float(config.get('tmax', 4.0))
    mi_runs = [int(r) for r in config.get('mi_runs', [7, 8, 11, 12])]
    event_id = {'T1': 1, 'T2': 2}

    def normalize_subject(value):
        """Turn 'S012' / '12' / 12 into the int 12; return None if unparseable."""
        if value is None:
            return None
        if isinstance(value, str) and value.upper().startswith('S'):
            value = value[1:]
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    subject_filter = None
    if subject_ids is not None:
        subject_filter = {normalize_subject(sid) for sid in subject_ids}
        subject_filter.discard(None)
        if not subject_filter:
            subject_filter = None

    def aggregate_results(blocks_X, blocks_y, blocks_subjects):
        """Concatenate per-file blocks and print a short summary."""
        X = np.concatenate(blocks_X, axis=0)
        y = np.concatenate(blocks_y, axis=0)
        subjects = np.concatenate(blocks_subjects, axis=0)
        print(f"Loaded {len(X)} trials from {len(np.unique(subjects))} subjects")
        print(f"Data shape: {X.shape}")
        print(f"Label distribution: {np.bincount(y)}")
        return X, y, subjects

    subject_root = data_root
    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        subject_root = preprocessed_dir

    subject_dirs = [d for d in sorted(os.listdir(subject_root))
                    if os.path.isdir(os.path.join(subject_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []

    if subject_dirs:
        print(f"Detected {len(subject_dirs)} preprocessed subject folders")
        label_map = {event_id['T1']: 0, event_id['T2']: 1}

        for subject_dir in subject_dirs:
            subject_numeric = normalize_subject(subject_dir)
            if subject_filter and subject_numeric not in subject_filter:
                continue

            subject_path = os.path.join(subject_root, subject_dir)

            for run_id in mi_runs:
                stem = f"{subject_dir}R{run_id:02d}"
                candidate_names = [
                    f"{stem}_preproc_raw.fif",
                    f"{stem}_raw.fif",
                    f"{stem}.fif",
                    f"{subject_dir}_R{run_id:02d}.fif",
                ]

                # First naming variant that exists on disk wins.
                run_path = None
                for name in candidate_names:
                    candidate = os.path.join(subject_path, name)
                    if os.path.exists(candidate):
                        run_path = candidate
                        break

                if run_path is None:
                    continue

                # Loading stays best-effort (a bad file must not abort the
                # whole run), but every skip is reported instead of hidden.
                try:
                    raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                except Exception as exc:
                    print(f"Skipping {run_path}: could not read raw file ({exc})")
                    continue

                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue

                try:
                    events, _ = mne.events_from_annotations(raw, event_id=event_id)
                except Exception as exc:
                    print(f"Skipping {run_path}: could not extract events ({exc})")
                    continue

                if len(events) == 0:
                    continue

                try:
                    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                                        baseline=None, preload=True, picks=picks, verbose=False)
                except Exception as exc:
                    print(f"Skipping {run_path}: could not epoch the data ({exc})")
                    continue

                data = epochs.get_data()
                labels = epochs.events[:, 2]
                mapped = np.array([label_map.get(lbl, -1) for lbl in labels])
                valid_mask = mapped >= 0

                if not np.any(valid_mask):
                    continue

                all_X.append(data[valid_mask])
                all_y.append(mapped[valid_mask])
                subj_label = subject_numeric if subject_numeric is not None else -1
                all_subjects.append(np.full(np.sum(valid_mask), subj_label))

        if all_X:
            return aggregate_results(all_X, all_y, all_subjects)

    # Legacy format fallback.
    # Sorted so the trial order (and therefore any seeded split built on top
    # of it) does not depend on the file system's listing order.
    legacy_files = sorted(f for f in os.listdir(data_root) if f.endswith('.fif'))
    if not legacy_files:
        raise ValueError("No valid PhysioNet files found")

    print(f"Found {len(legacy_files)} legacy epoch files. Loading...")
    for fname in legacy_files:
        subj = normalize_subject(fname.split('_')[0])
        # Honour subject_ids here too, exactly as the folder layout does.
        if subject_filter and subj not in subject_filter:
            continue

        filepath = os.path.join(data_root, fname)
        try:
            epochs = mne.read_epochs(filepath, preload=True, verbose=False)
        except Exception as exc:
            print(f"Skipping {filepath}: could not read epochs ({exc})")
            continue

        current_event_id = epochs.event_id
        if not current_event_id:
            continue

        label_lookup = {}
        if 'T1' in current_event_id:
            label_lookup[current_event_id['T1']] = 0
        if 'T2' in current_event_id:
            label_lookup[current_event_id['T2']] = 1

        if not label_lookup:
            continue

        labels = np.array([label_lookup.get(code, -1) for code in epochs.events[:, -1]])
        valid = labels >= 0

        if not np.any(valid):
            continue

        data = epochs.get_data()[valid]
        labels = labels[valid]
        subj_arr = np.full(len(labels), subj if subj is not None else -1)

        all_X.append(data)
        all_y.append(labels)
        all_subjects.append(subj_arr)

    if not all_X:
        raise ValueError("No valid trials were loaded")

    return aggregate_results(all_X, all_y, all_subjects)


class EEGDataset(Dataset):
    """Tensor-backed dataset pairing EEG trials with their class labels.

    ``X`` is converted once to a float32 tensor and ``y`` to an int64 tensor;
    indexing returns the matching ``(trial, label)`` pair.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X = features
        self.y = targets

    def __len__(self):
        # The trial count is the leading dimension of the feature tensor.
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

## 3. Model Architectures

In [None]:
# Model 1: FBCSP
class FBCSP:
    """Filter Bank Common Spatial Patterns with LDA classifier.

    Each frequency band gets its own band-pass filter and CSP decomposition;
    the per-band CSP features are concatenated and classified with a linear
    discriminant.

    Args:
        freq_bands: Iterable of ``(low_hz, high_hz)`` pass bands.
        n_components: Number of CSP components kept per band.
        sfreq: Sampling frequency of the data in Hz.
    """
    def __init__(self, freq_bands, n_components=4, sfreq=128):
        self.freq_bands = freq_bands
        self.n_components = n_components
        self.sfreq = sfreq
        self.csp_list = []
        self.classifier = None

    def fit(self, X, y):
        """Fit one CSP per band plus the LDA on the stacked features.

        Calling ``fit`` again discards the previously fitted CSP filters, so a
        refit never mixes filters from different training sets.

        Returns:
            ``self``.
        """
        from mne.decoding import CSP

        fitted_csps = []
        all_features = []

        for low, high in self.freq_bands:
            X_filtered = self._bandpass_filter(X, low, high)
            csp = CSP(n_components=self.n_components, reg=None, log=True, norm_trace=False)
            all_features.append(csp.fit_transform(X_filtered, y))
            fitted_csps.append(csp)

        # Replace (not extend) the stored filters: predict() pairs them with
        # freq_bands by position.
        self.csp_list = fitted_csps

        all_features = np.concatenate(all_features, axis=1)
        self.classifier = LinearDiscriminantAnalysis()
        self.classifier.fit(all_features, y)

        return self

    def predict(self, X):
        """Return predicted class labels for the trials in ``X``.

        Raises:
            RuntimeError: The model has not been fitted yet.
        """
        if self.classifier is None:
            raise RuntimeError("FBCSP.predict called before fit")

        all_features = [
            csp.transform(self._bandpass_filter(X, low, high))
            for (low, high), csp in zip(self.freq_bands, self.csp_list)
        ]

        all_features = np.concatenate(all_features, axis=1)
        return self.classifier.predict(all_features)

    def score(self, X, y):
        """Return the fraction of trials in ``X`` classified as ``y``."""
        predictions = self.predict(X)
        return np.mean(predictions == y)

    def _bandpass_filter(self, X, low, high):
        """Zero-phase 4th-order Butterworth band-pass along the last axis.

        The filter is applied to all trials and channels in one vectorised
        call. Floating-point input keeps its dtype; other dtypes are returned
        as floats rather than being truncated back to integers.
        """
        from scipy.signal import butter, filtfilt

        nyq = self.sfreq / 2
        b, a = butter(4, [low / nyq, high / nyq], btype='band')

        X = np.asarray(X)
        X_filtered = filtfilt(b, a, X, axis=-1)
        if np.issubdtype(X.dtype, np.floating):
            X_filtered = X_filtered.astype(X.dtype, copy=False)

        return X_filtered

In [None]:
# Model 2: CNN-SAE
class SpatialAttention(nn.Module):
    """Per-channel attention computed from the time-averaged signal.

    The mean over time of every channel is passed through a bottleneck MLP
    ending in a sigmoid; the resulting weight in (0, 1) rescales that channel.

    Args:
        n_channels: Number of input channels (size of dimension 1 of the input).
    """
    def __init__(self, n_channels):
        super().__init__()
        # Never let the bottleneck collapse to zero units: with fewer than 4
        # channels ``n_channels // 4`` is 0 and the weights would no longer
        # depend on the input at all.
        hidden = max(1, n_channels // 4)
        self.attention = nn.Sequential(
            nn.Linear(n_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale ``x`` (indexed as [batch, channel, time]) by its channel weights."""
        pooled = torch.mean(x, dim=2)
        weights = self.attention(pooled)
        return x * weights.unsqueeze(2)


class CNNSAE(nn.Module):
    """1-D CNN with a spatial-attention front end.

    Pipeline: spatial attention -> three Conv1d/BatchNorm/ReLU/MaxPool stages
    -> dropout -> two fully connected layers producing class logits.

    Args:
        n_channels: Number of EEG channels.
        n_classes: Number of output classes.
        n_timepoints: Number of samples per trial; fixes the size of ``fc1``.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()

        self.spatial_attention = SpatialAttention(n_channels)

        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(2)

        self.dropout = nn.Dropout(0.5)

        flattened_size = self._probe_feature_size(n_channels, n_timepoints)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _probe_feature_size(self, n_channels, n_timepoints):
        """Return the flattened feature width for one trial.

        The probe runs in eval mode without autograd so that an all-zero dummy
        batch does not update the BatchNorm running statistics before training
        starts; the previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, n_channels, n_timepoints)
                return self._forward_features(probe).flatten(1).size(1)
        finally:
            self.train(was_training)

    def _forward_features(self, x):
        """Apply attention and the three convolutional stages."""
        x = self.spatial_attention(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        """Return class logits for a batch indexed as [batch, channel, time]."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Model 3: EEGNet
class EEGNet(nn.Module):
    """Compact EEGNet-style convolutional classifier.

    Temporal convolution -> depthwise spatial convolution across all channels
    -> a further temporal convolution, each followed by batch normalisation;
    the last two stages add ELU, average pooling and dropout. A single linear
    layer maps the flattened features to class logits.

    Args:
        n_channels: Number of EEG channels (height of the spatial kernel).
        n_classes: Number of output classes.
        n_timepoints: Number of samples per trial; fixes the size of ``fc``.
        F1: Number of temporal filters.
        D: Depth multiplier of the spatial convolution.
        F2: Number of filters in the last convolution.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, F1=8, D=2, F2=16):
        super().__init__()

        self.conv1 = nn.Conv2d(1, F1, (1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)

        self.conv2 = nn.Conv2d(F1, F1 * D, (n_channels, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)
        self.pool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout(0.5)

        self.conv3 = nn.Conv2d(F1 * D, F2, (1, 16), padding=(0, 8), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)
        self.pool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout(0.5)

        flattened_size = self._probe_feature_size(n_channels, n_timepoints)

        self.fc = nn.Linear(flattened_size, n_classes)

    def _probe_feature_size(self, n_channels, n_timepoints):
        """Return the flattened feature width for one trial.

        Runs in eval mode without autograd: a train-mode probe would update
        the BatchNorm running statistics with an all-zero batch and draw
        dropout masks from the global RNG before training even starts. The
        previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 1, n_channels, n_timepoints)
                return self._forward_features(probe).flatten(1).size(1)
        finally:
            self.train(was_training)

    def _forward_features(self, x):
        """Apply the convolutional stages to a [batch, 1, channel, time] input."""
        x = self.bn1(self.conv1(x))
        x = self.dropout1(self.pool1(F.elu(self.bn2(self.conv2(x)))))
        x = self.dropout2(self.pool2(F.elu(self.bn3(self.conv3(x)))))
        return x

    def forward(self, x):
        """Return class logits for a batch indexed as [batch, channel, time]."""
        x = x.unsqueeze(1)  # add the singleton "image channel" axis Conv2d expects
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
# Model 4: ACS-SE-CNN
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over dimension 1 of a 3-D input.

    The input is averaged over its last dimension ("squeeze"), passed through
    a two-layer bottleneck ending in a sigmoid ("excitation"), and the result
    rescales the input along dimension 1.

    Args:
        channels: Size of dimension 1 of the input.
        reduction: Bottleneck ratio; the hidden width is ``channels // reduction``
            but never less than 1.
    """
    def __init__(self, channels, reduction=4):
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)

    def forward(self, x):
        # Squeeze: one summary value per (sample, channel).
        summary = x.mean(dim=2)
        # Excitation: bottleneck MLP producing a gate in (0, 1) per channel.
        hidden = F.relu(self.fc1(summary))
        gate = torch.sigmoid(self.fc2(hidden))
        return x * gate[:, :, None]


class ACSECNN(nn.Module):
    """CNN with learned per-channel attention and squeeze-excitation blocks.

    A small MLP maps every channel's time series to one weight in (0, 1);
    the weighted signal then passes through three SE + Conv1d/BatchNorm/
    ReLU/MaxPool stages, dropout, and two fully connected layers.

    The weights computed in the most recent forward pass are kept (detached)
    in ``channel_weights``; it is ``None`` until the first real forward pass.

    Args:
        n_channels: Number of EEG channels.
        n_classes: Number of output classes.
        n_timepoints: Number of samples per trial; this is the input width of
            the attention MLP and fixes the size of ``fc1``.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        self.se1 = SEBlock(n_channels)
        self.se2 = SEBlock(128)
        self.se3 = SEBlock(256)

        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)

        self.dropout = nn.Dropout(0.5)

        flattened_size = self._probe_feature_size(n_channels, n_timepoints)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

        # Drop the weights recorded by the shape probe above.
        self.channel_weights = None

    def _probe_feature_size(self, n_channels, n_timepoints):
        """Return the flattened feature width for one trial.

        Runs in eval mode without autograd so the all-zero dummy batch does
        not update the BatchNorm running statistics; the previous train/eval
        mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, n_channels, n_timepoints)
                return self._forward_features(probe).flatten(1).size(1)
        finally:
            self.train(was_training)

    def _forward_features(self, x):
        """Weight the channels, then apply the three SE + conv stages."""
        # nn.Linear acts on the last dimension, so the attention MLP can score
        # all channels of a [batch, channel, time] input in a single call
        # instead of one Python-level call per channel.
        channel_weights = self.channel_attention(x).squeeze(-1)
        self.channel_weights = channel_weights.detach()

        x = x * channel_weights.unsqueeze(2)

        x = self.se1(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))

        x = self.se2(x)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))

        x = self.se3(x)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        return x

    def forward(self, x):
        """Return class logits for a batch indexed as [batch, channel, time]."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Model 5: G-CARM
class CARMBlock(nn.Module):
    """Learned channel-mixing layer.

    Holds a trainable ``n_channels x n_channels`` matrix ``A``. Each row is
    turned into a probability distribution with a softmax, and every output
    channel becomes the corresponding weighted combination of input channels.
    """
    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        self.A = nn.Parameter(torch.randn(n_channels, n_channels) * 0.01)
        # Not applied in forward(); it still contributes entries to state_dict.
        self.norm = nn.LayerNorm(n_channels)

    def forward(self, x):
        # Unpacking doubles as a check that the input is 3-dimensional.
        _batch, _chan, _steps = x.shape
        mixing = torch.softmax(self.A, dim=1)
        time_major = x.permute(0, 2, 1)
        mixed = torch.matmul(time_major, mixing.t())
        return mixed.permute(0, 2, 1)

    def get_adjacency_matrix(self):
        """Return the row-softmaxed mixing matrix, detached from autograd."""
        row_normalised = torch.softmax(self.A, dim=1)
        return row_normalised.detach()


class GCARM(nn.Module):
    """CNN preceded by two learned channel-mixing (CARM) layers.

    Pipeline: two ``CARMBlock`` layers -> three Conv1d/BatchNorm/ReLU/MaxPool
    stages -> dropout -> two fully connected layers producing class logits.

    Args:
        n_channels: Number of EEG channels.
        n_classes: Number of output classes.
        n_timepoints: Number of samples per trial; fixes the size of ``fc1``.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()

        self.carm1 = CARMBlock(n_channels)
        self.carm2 = CARMBlock(n_channels)

        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)

        self.dropout = nn.Dropout(0.5)

        flattened_size = self._probe_feature_size(n_channels, n_timepoints)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _probe_feature_size(self, n_channels, n_timepoints):
        """Return the flattened feature width for one trial.

        Runs in eval mode without autograd so the all-zero dummy batch does
        not update the BatchNorm running statistics; the previous train/eval
        mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, n_channels, n_timepoints)
                return self._forward_features(probe).flatten(1).size(1)
        finally:
            self.train(was_training)

    def _forward_features(self, x):
        """Mix the channels twice, then apply the three convolutional stages."""
        x = self.carm1(x)
        x = self.carm2(x)

        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        return x

    def forward(self, x):
        """Return class logits for a batch indexed as [batch, channel, time]."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def get_channel_importance_edge(self):
        """Return one importance score per input channel as a NumPy array.

        The two mixing matrices are averaged. Their rows are softmax-
        normalised, so every row sums to exactly 1 and a row sum carries no
        information; the score is therefore the column sum, i.e. the total
        weight with which a channel feeds into all mixed output channels.
        """
        A1 = self.carm1.get_adjacency_matrix()
        A2 = self.carm2.get_adjacency_matrix()
        A_combined = (A1 + A2) / 2
        return torch.sum(A_combined, dim=0).cpu().numpy()

In [None]:
# Models 6 & 7: Baseline EEG-ARNN and Adaptive Gating EEG-ARNN
class GraphConvLayer(nn.Module):
    """Graph convolution over the channel axis with a learned adjacency.

    The trainable matrix ``A`` is squashed with a sigmoid and symmetrised.
    In ``forward`` self-loops are added and the matrix is normalised as
    ``D^-1/2 (A + I) D^-1/2`` before mixing the channels; a bias-free linear
    map over the feature axis, batch normalisation and ELU follow.
    """
    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _symmetric_adjacency(self):
        """Sigmoid-squashed adjacency averaged with its transpose."""
        squashed = torch.sigmoid(self.A)
        return 0.5 * (squashed + squashed.t())

    def forward(self, x):
        # x is unpacked as [batch, feature, channel, time].
        B, H, C, T = x.shape

        adjacency = self._symmetric_adjacency()
        with_self_loops = adjacency + torch.eye(C, device=adjacency.device)
        inv_sqrt_degree = torch.pow(with_self_loops.sum(1).clamp_min(1e-6), -0.5)
        D = torch.diag(inv_sqrt_degree)
        normalised = D @ with_self_loops @ D

        # Fold batch and time together so every time step is one
        # [channel, feature] matrix that the adjacency can left-multiply.
        per_step = x.permute(0, 3, 2, 1).contiguous().view(B * T, C, H)
        mixed = self.theta(normalised @ per_step)
        restored = mixed.view(B, T, C, H).permute(0, 3, 2, 1)
        return self.act(self.bn(restored))

    def get_adjacency(self):
        """Return the symmetrised adjacency (no self-loops, unnormalised) as NumPy."""
        with torch.no_grad():
            return self._symmetric_adjacency().cpu().numpy()


class TemporalConv(nn.Module):
    """Convolution along the time axis with BatchNorm, ELU and optional pooling.

    The kernel is ``(1, kernel_size)``, so the third input dimension is left
    untouched. When ``pool`` is true the time axis is average-pooled by a
    factor of 2 afterwards.
    """
    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        self.pool = pool
        half_width = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, half_width),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pool_layer = None

    def forward(self, x):
        activated = self.act(self.bn(self.conv(x)))
        if self.pool_layer is None:
            return activated
        return self.pool_layer(activated)


class BaselineEEGARNN(nn.Module):
    """EEG-ARNN baseline: alternating temporal and graph convolutions.

    Three (TemporalConv -> GraphConvLayer) pairs are followed by a two-layer
    classifier. The gate-related attributes exist so the training utilities
    can treat this class and its gated subclass uniformly; here they stay
    ``False`` / ``None``.

    Args:
        n_channels: Number of EEG channels (graph nodes).
        n_classes: Number of output classes.
        n_timepoints: Number of samples per trial; fixes ``feature_dim``.
        hidden_dim: Feature width used by every temporal/graph layer.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.use_gate_regularizer = False
        self.gate_penalty_tensor = None
        self.latest_gate_values = None

        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        # Shape probe. no_grad alone is not enough: in train mode BatchNorm
        # would still fold the all-zero dummy batch into its running
        # statistics, so the probe runs in eval mode and the previous mode is
        # restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, n_channels, n_timepoints)
                feat = self._forward_features(self._prepare_input(dummy))
                self.feature_dim = feat.flatten(1).size(1)
        finally:
            self.train(was_training)

        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, n_classes)

    def _prepare_input(self, x):
        """Insert a singleton feature axis at position 1 when ``x`` is 3-D."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return x

    def _forward_features(self, x):
        """Run the three temporal/graph stages on a prepared 4-D input."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def _forward_from_prepared(self, x):
        """Feature extractor plus classifier head; returns class logits."""
        features = self._forward_features(x)
        x = features.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def forward(self, x):
        """Return class logits; clears the gate attributes (no gating here)."""
        prepared = self._prepare_input(x)
        self.gate_penalty_tensor = None
        self.latest_gate_values = None
        return self._forward_from_prepared(prepared)

    def get_final_adjacency(self):
        """Return the learned adjacency of the last graph layer as NumPy."""
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        """Return per-channel importance: row sums of the final adjacency."""
        adjacency = self.get_final_adjacency()
        return np.sum(adjacency, axis=1)


class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """EEG-ARNN with an input-dependent gate on every channel.

    A small MLP maps each trial's per-channel mean and standard deviation to
    one gate in (0, 1) per channel. The gated signal is fed to the baseline
    network. The gates of the latest forward pass are exposed twice: attached
    to the graph in ``gate_penalty_tensor`` (for the L1 regulariser) and
    detached in ``latest_gate_values`` (for inspection).
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True

        gate_layers = [
            nn.Linear(n_channels * 2, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid(),
        ]
        self.gate_net = nn.Sequential(*gate_layers)

        # Set the bias of the last linear layer to logit(gate_init), clipping
        # gate_init away from 0 and 1 so the logit stays finite.
        target = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        logit = math.log(target / (1.0 - target))
        with torch.no_grad():
            self.gate_net[-2].bias.fill_(logit)

        self.latest_gate_values = None

    def compute_gates(self, x):
        """Return one gate per (sample, channel) for a prepared 4-D input."""
        signal = x.squeeze(1)
        per_channel_mean = signal.mean(dim=2)
        per_channel_std = signal.std(dim=2)
        summary = torch.cat([per_channel_mean, per_channel_std], dim=1)
        return self.gate_net(summary)

    def forward(self, x):
        prepared = self._prepare_input(x)
        gates = self.compute_gates(prepared)
        self.gate_penalty_tensor = gates
        self.latest_gate_values = gates.detach()
        # Broadcast the gates over the feature and time axes.
        broadcastable = gates[:, None, :, None]
        return self._forward_from_prepared(prepared * broadcastable)

    def get_channel_importance_gate(self):
        """Mean gate per channel over the latest batch, or None before any forward."""
        latest = self.latest_gate_values
        if latest is None:
            return None
        return latest.mean(dim=0).cpu().numpy()

## 4. Training Utilities

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Train for one epoch.

    Args:
        model: Network to optimise; put into training mode here.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(outputs, targets)``.
        optimizer: Optimiser stepping the model parameters.
        device: Device the batches are moved to.
        l1_lambda: Weight of the gate penalty. Applied only when positive and
            the model exposes a non-None ``gate_penalty_tensor`` after its
            forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. The loss is weighted
        by batch size, so a shorter final batch does not skew the average.
        Both are 0.0 for an empty dataloader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for X_batch, y_batch in dataloader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        gate_penalty = getattr(model, 'gate_penalty_tensor', None)
        if l1_lambda > 0 and gate_penalty is not None:
            loss = loss + l1_lambda * gate_penalty.abs().mean()

        loss.backward()
        optimizer.step()

        batch_size = y_batch.size(0)
        # criterion returns a per-batch mean; re-weight it by the batch size so
        # the epoch figure is a true per-sample mean.
        total_loss += loss.item() * batch_size
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == y_batch).sum().item()

    denom = max(1, total)
    return total_loss / denom, correct / denom


def evaluate(model, dataloader, criterion, device):
    """Evaluate model.

    Puts the model into eval mode and runs it over the dataloader without
    tracking gradients.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples. The loss is weighted by
        batch size, so it does not depend on how the data happen to be split
        into batches. Both are 0.0 for an empty dataloader.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            batch_size = y_batch.size(0)
            # Per-batch mean -> per-sample sum, averaged once at the end.
            total_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == y_batch).sum().item()

    denom = max(1, total)
    return total_loss / denom, correct / denom


def train_pytorch_model(model, train_loader, val_loader, config, model_name=''):
    """Train a PyTorch model with improved scheduler and early stopping.

    Uses Adam with cross-entropy loss and a ``ReduceLROnPlateau`` schedule
    driven by the validation loss. The best epoch is the one with the highest
    validation accuracy, ties broken by lower validation loss.

    Args:
        model: Network to train; moved to ``config['device']``.
        train_loader: Training batches.
        val_loader: Validation batches.
        config: Dict with required keys ``device``, ``learning_rate``,
            ``weight_decay`` and ``epochs``; optional keys
            ``scheduler_factor``, ``scheduler_patience``, ``min_lr``,
            ``use_early_stopping``, ``patience`` and ``gating.l1_lambda``.
        model_name: Label used in progress messages ('Model' when empty).

    Returns:
        ``(best_state, best_val_acc)``: a copy of the best ``state_dict`` and
        its validation accuracy. The model itself is left holding that state.
    """
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])

    # 'verbose' is intentionally not passed: False was its default, and the
    # argument is deprecated in recent PyTorch releases.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config.get('scheduler_factor', 0.5),
        patience=config.get('scheduler_patience', 3),
        min_lr=config.get('min_lr', 1e-6),
    )

    # The gate penalty only applies to models that opt in.
    if getattr(model, 'use_gate_regularizer', False):
        l1_lambda = config.get('gating', {}).get('l1_lambda', 0.0)
    else:
        l1_lambda = 0.0
    use_early_stopping = config.get('use_early_stopping', False)
    max_patience = config.get('patience', 10)
    patience_counter = 0
    prefix = model_name if model_name else 'Model'

    best_state = deepcopy(model.state_dict())
    best_val_acc = 0.0
    best_val_loss = float('inf')

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        improved = val_acc > best_val_acc or (val_acc == best_val_acc and val_loss < best_val_loss)
        if improved:
            best_state = deepcopy(model.state_dict())
            best_val_acc = val_acc
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 5 == 0 or improved:
            print(f"[{prefix}] Epoch {epoch + 1}/{config['epochs']} - "
                  f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_val_acc:.4f}")

        if use_early_stopping and patience_counter >= max_patience:
            print(f"Early stopping triggered for {prefix} at epoch {epoch + 1}")
            break

    model.load_state_dict(best_state)
    return best_state, best_val_acc

## 5. Load Data

In [None]:
# Load every available trial once; all later cells reuse X, y and
# subject_labels as module-level arrays.
print("Loading PhysioNet data...")
X, y, subject_labels = load_physionet_data(CONFIG['data_path'])

# Quick sanity summary: trial count, array shape and per-class counts.
print(f"\nData loaded successfully!")
print(f"Total trials: {len(X)}")
print(f"Data shape: {X.shape}")
print(f"Labels: {np.unique(y, return_counts=True)}")

## 6. Train All Models

In [None]:
# Define models to train.
# 'type' selects the training path in the loop below: 'sklearn' entries are
# fitted directly with fit()/score(), 'pytorch' entries go through
# train_pytorch_model().
models_to_train = [
    {'name': 'FBCSP', 'type': 'sklearn'},
    {'name': 'CNN-SAE', 'type': 'pytorch'},
    {'name': 'EEGNet', 'type': 'pytorch'},
    {'name': 'ACS-SE-CNN', 'type': 'pytorch'},
    {'name': 'G-CARM', 'type': 'pytorch'},
    {'name': 'Baseline-EEG-ARNN', 'type': 'pytorch'},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'type': 'pytorch'},
]

# One result dict per model is appended here by the training loop.
all_results = []
# Seeded splitter, so every model sees the same folds.
# NOTE(review): the splitter is used on (X, y) only, i.e. folds are stratified
# over individual trials; trials of one subject can therefore fall into both
# the training and the validation part - confirm this is intended.
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])

print(f"Starting {CONFIG['n_folds']}-fold cross-validation...\n")

In [None]:
# Train all models.
# For every model: run the shared cross-validation folds, store the fitted
# model / best weights of each fold, and record mean +/- std validation accuracy.
for model_info in models_to_train:
    model_name = model_info['name']
    model_type = model_info['type']

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}\n")

    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"\nFold {fold + 1}/{CONFIG['n_folds']}")

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        if model_type == 'sklearn':
            model = FBCSP(freq_bands=CONFIG['fbcsp_bands'],
                          n_components=CONFIG['fbcsp_n_components'],
                          sfreq=CONFIG['sfreq'])
            model.fit(X_train, y_train)
            val_acc = model.score(X_val, y_val)

            model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pkl")
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
        else:
            train_dataset = EEGDataset(X_train, y_train)
            val_dataset = EEGDataset(X_val, y_val)

            train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                      shuffle=True, num_workers=0)
            val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                                    shuffle=False, num_workers=0)

            base_kwargs = {
                'n_channels': CONFIG['n_channels'],
                'n_classes': CONFIG['n_classes'],
                'n_timepoints': CONFIG['n_timepoints'],
            }

            # Name -> constructor; built lazily so only the requested model is
            # instantiated.
            pytorch_builders = {
                'CNN-SAE': lambda: CNNSAE(**base_kwargs),
                'EEGNet': lambda: EEGNet(**base_kwargs),
                'ACS-SE-CNN': lambda: ACSECNN(**base_kwargs),
                'G-CARM': lambda: GCARM(**base_kwargs),
                'Baseline-EEG-ARNN': lambda: BaselineEEGARNN(
                    hidden_dim=CONFIG['hidden_dim'], **base_kwargs),
                'Adaptive-Gating-EEG-ARNN': lambda: AdaptiveGatingEEGARNN(
                    hidden_dim=CONFIG['hidden_dim'],
                    gate_init=CONFIG['gating']['gate_init'],
                    **base_kwargs),
            }
            if model_name not in pytorch_builders:
                raise ValueError(f"Unknown model: {model_name}")
            model = pytorch_builders[model_name]()

            best_state, val_acc = train_pytorch_model(model, train_loader, val_loader,
                                                      CONFIG, model_name)

            model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
            torch.save(best_state, model_path)

            # Clean up GPU memory. Drop every reference first (including the
            # copied best weights), collect garbage, and only then ask the
            # caching allocator to release its blocks - it can only return
            # memory whose tensors are already gone.
            del model, best_state, pytorch_builders
            del train_loader, val_loader, train_dataset, val_dataset
            gc.collect()
            torch.cuda.empty_cache()

        fold_accuracies.append(val_acc)
        print(f"Fold {fold + 1} Validation Accuracy: {val_acc:.4f}")

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)

    print(f"\n{model_name} Results:")
    print(f"Mean Accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")

    all_results.append({
        'model': model_name,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

print(f"\n{'='*60}")
print("All models trained successfully!")
print(f"{'='*60}")

In [None]:
# Save initial results: rank the models by mean accuracy and write the table
# to the results directory.
results_df = pd.DataFrame(all_results).sort_values('mean_accuracy', ascending=False)
summary_path = os.path.join(CONFIG['results_dir'], 'summary_all_models.csv')
results_df.to_csv(summary_path, index=False)

print("\nInitial Results:")
summary_columns = ['model', 'mean_accuracy', 'std_accuracy']
print(results_df[summary_columns])

# After the descending sort the first row is the best-performing model.
winner = results_df.iloc[0]
print(f"\nBest Model: {winner['model']}")
print(f"Accuracy: {winner['mean_accuracy']:.4f} +/- {winner['std_accuracy']:.4f}")

## 7. Channel Selection Evaluation

In [None]:
# Channel selection utilities
def get_channel_importance_aggregation(model, dataloader, device):
    """Score channels by mean absolute feature activation (Aggregation Selection).

    Every batch is moved to ``device`` and passed through
    ``model._prepare_input`` followed by ``model._forward_features``. The
    absolute activations are averaged over axes 1 and 3 of the feature tensor,
    and the per-sample results are then averaged over the whole loader.

    Args:
        model: Network exposing ``_prepare_input``, ``_forward_features`` and
            ``n_channels``; it is switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device the input batches are moved to.

    Returns:
        NumPy array of importance scores, or ``np.zeros(model.n_channels)``
        when the loader yields no batches.
    """
    model.eval()

    with torch.no_grad():
        per_batch = [
            model._forward_features(model._prepare_input(inputs.to(device)))
            .abs()
            .mean(dim=(1, 3))
            .cpu()
            for inputs, _ in dataloader
        ]

    if len(per_batch) == 0:
        return np.zeros(model.n_channels)
    # Concatenate along the sample axis so every sample is weighted equally.
    return torch.cat(per_batch, dim=0).mean(dim=0).numpy()


def compute_gate_importance(model, dataloader, device):
    """Average the model's recorded gate values over every batch in a loader.

    After each forward pass the model's ``latest_gate_values`` attribute is
    read; batches for which it is missing or ``None`` are skipped.

    Args:
        model: Callable network exposing ``n_channels``; it is switched to
            eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device the input batches are moved to.

    Returns:
        NumPy array with the gate values averaged over axis 0 of the
        concatenated batches, or a uniform ``1 / n_channels`` vector when no
        gate values were recorded.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for inputs, _ in dataloader:
            # The prediction itself is discarded; only the gate attribute is read.
            model(inputs.to(device))
            gates = getattr(model, 'latest_gate_values', None)
            if gates is None:
                continue
            collected.append(gates.cpu())

    if collected:
        return torch.cat(collected, dim=0).mean(dim=0).numpy()
    # Nothing was recorded: fall back to uniform scores.
    return np.ones(model.n_channels) / model.n_channels


def select_top_k_channels(importance_scores, k):
    """Select the indices of the ``k`` highest-scoring channels.

    Args:
        importance_scores: Per-channel scores, expected to be one-dimensional;
            a higher score means a more important channel.
        k: Number of channels to keep. Values larger than the number of
            channels keep every channel.

    Returns:
        Ascending list of the selected channel indices as plain ``int``s.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # ``order[-0:]`` is ``order[0:]`` and would return *every* channel.
        return []
    order = np.argsort(importance_scores)
    return sorted(order[-k:].tolist())


def apply_channel_selection(X, selected_channels):
    """Keep only the chosen entries along axis 1 of a three-axis array.

    Args:
        X: Array indexed as ``X[:, channels, :]``.
        selected_channels: Indices to keep along axis 1.

    Returns:
        The result of indexing ``X`` with ``selected_channels`` on axis 1 and
        full slices on axes 0 and 2.
    """
    channel_index = (slice(None), selected_channels, slice(None))
    return X[channel_index]

In [None]:
# Channel-selection plan: which importance methods are evaluated for each
# model. Only the gating model is paired with the 'gate' method.
models_to_evaluate = [
    dict(name=model_label, methods=list(method_names))
    for model_label, method_names in (
        ('Baseline-EEG-ARNN', ('edge', 'aggregation')),
        ('Adaptive-Gating-EEG-ARNN', ('edge', 'aggregation', 'gate')),
    )
]

# Accumulates the channel-selection result records.
channel_selection_results = list()

def build_model(model_name, n_channels):
    """Build a fresh EEG-ARNN variant for the given number of channels.

    Args:
        model_name: ``'Baseline-EEG-ARNN'`` or ``'Adaptive-Gating-EEG-ARNN'``.
        n_channels: Number of input channels the model is built for.

    Returns:
        A new ``BaselineEEGARNN`` or ``AdaptiveGatingEEGARNN`` instance
        configured from ``CONFIG``.

    Raises:
        ValueError: If ``model_name`` is not one of the two supported names.
    """
    # Validate first so a typo fails loudly instead of silently building the
    # gating model under the wrong name.
    if model_name not in ('Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN'):
        raise ValueError(f"Unknown model: {model_name}")

    kwargs = {
        'n_channels': n_channels,
        'n_classes': CONFIG['n_classes'],
        'n_timepoints': CONFIG['n_timepoints'],
    }
    if model_name == 'Baseline-EEG-ARNN':
        return BaselineEEGARNN(hidden_dim=CONFIG['hidden_dim'], **kwargs)
    return AdaptiveGatingEEGARNN(hidden_dim=CONFIG['hidden_dim'],
                                 gate_init=CONFIG['gating']['gate_init'],
                                 **kwargs)

In [None]:
# Channel selection evaluation loop
print(f"\n{'='*60}")
print("CHANNEL SELECTION EVALUATION")
print(f"{'='*60}\n")


def _score_channels_for_fold(model_name, method, fold, X_train, y_train):
    """Load one fold's full-channel checkpoint and return its channel scores.

    Args:
        model_name: Model whose ``{model_name}_fold{fold+1}.pt`` checkpoint is
            loaded from ``CONFIG['models_dir']``.
        method: ``'edge'`` or ``'aggregation'``; any other value is treated as
            ``'gate'``.
        fold: Zero-based fold index.
        X_train: Training inputs of the fold (used by the data-driven methods).
        y_train: Training labels of the fold.

    Returns:
        Importance scores as produced by the chosen method.
    """
    model = build_model(model_name, CONFIG['n_channels'])
    model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
    state_dict = torch.load(model_path, map_location=CONFIG['device'])
    model.load_state_dict(state_dict)
    model = model.to(CONFIG['device'])
    model.eval()

    if method == 'edge':
        scores = model.get_channel_importance_edge()
    else:
        # Data-driven methods score the channels on the training split only.
        scoring_loader = DataLoader(EEGDataset(X_train, y_train),
                                    batch_size=CONFIG['batch_size'],
                                    shuffle=False, num_workers=0)
        if method == 'aggregation':
            scores = get_channel_importance_aggregation(model, scoring_loader,
                                                        CONFIG['device'])
        else:  # gate
            scores = compute_gate_importance(model, scoring_loader, CONFIG['device'])

    # Release the full-size model before the reduced model is trained.
    del model
    torch.cuda.empty_cache()
    gc.collect()
    return scores


# Materialise the folds once so every (method, k) run uses exactly the same
# train/validation partition and cached scores always match their fold.
cv_splits = list(skf.split(X, y))

for model_info in models_to_evaluate:
    model_name = model_info['name']
    selection_methods = model_info['methods']

    print(f"\nEvaluating {model_name}\n{'-'*60}")

    for method in selection_methods:
        print(f"\nMethod: {method.upper()}")

        # Scores depend on (model, method, fold) but not on k, so they are
        # computed once per fold and reused for every k instead of reloading
        # the checkpoint and re-scoring the training set each time.
        importance_by_fold = {}

        for k in CONFIG['k_values']:
            print(f"\n  k={k} channels:", end=' ')
            fold_accuracies = []

            for fold, (train_idx, val_idx) in enumerate(cv_splits):
                X_train, X_val = X[train_idx], X[val_idx]
                y_train, y_val = y[train_idx], y[val_idx]

                if fold not in importance_by_fold:
                    importance_by_fold[fold] = _score_channels_for_fold(
                        model_name, method, fold, X_train, y_train)
                importance_scores = importance_by_fold[fold]

                # Select channels
                selected_channels = select_top_k_channels(importance_scores, k)
                X_train_selected = apply_channel_selection(X_train, selected_channels)
                X_val_selected = apply_channel_selection(X_val, selected_channels)

                # Train new model with selected channels. The model is sized
                # from the channels actually selected so it always matches the
                # data it is fed.
                new_model = build_model(model_name, len(selected_channels))
                train_dataset = EEGDataset(X_train_selected, y_train)
                val_dataset = EEGDataset(X_val_selected, y_val)

                train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                          shuffle=True, num_workers=0)
                val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                                        shuffle=False, num_workers=0)

                # Only the validation accuracy is needed here; the returned
                # weights are not saved.
                _, val_acc = train_pytorch_model(new_model, train_loader, val_loader,
                                                 CONFIG, f"{model_name}-{method}-k{k}")
                fold_accuracies.append(val_acc)

                # Clean up
                del new_model
                torch.cuda.empty_cache()
                gc.collect()

            mean_acc = np.mean(fold_accuracies)
            std_acc = np.std(fold_accuracies)
            print(f"{mean_acc:.4f} +/- {std_acc:.4f}")

            channel_selection_results.append({
                'model': model_name,
                'method': method,
                'k': k,
                'mean_accuracy': mean_acc,
                'std_accuracy': std_acc,
                'fold_accuracies': fold_accuracies
            })

print(f"\n{'='*60}")
print("Channel selection evaluation complete!")
print(f"{'='*60}")

In [None]:
# Persist the channel-selection sweep and report the best setting per model.
cs_csv_path = os.path.join(CONFIG['results_dir'], 'channel_selection_results.csv')
cs_df = pd.DataFrame(channel_selection_results)
cs_df.to_csv(cs_csv_path, index=False)

print("\nChannel Selection Results:")
print(cs_df.loc[:, ['model', 'method', 'k', 'mean_accuracy', 'std_accuracy']])

banner = "=" * 60
print("\n" + banner)
print("Best Channel Selection Methods:")
print(banner)

for model_name in ('Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN'):
    # Row with the highest mean accuracy among this model's (method, k) runs.
    model_results = cs_df.loc[cs_df['model'] == model_name]
    best_row_label = model_results['mean_accuracy'].idxmax()
    best_result = model_results.loc[best_row_label]

    report_lines = (
        f"\n{model_name}:",
        f"  Best Method: {best_result['method'].upper()}",
        f"  Best k: {best_result['k']}",
        f"  Accuracy: {best_result['mean_accuracy']:.4f} +/- {best_result['std_accuracy']:.4f}",
    )
    for report_line in report_lines:
        print(report_line)

## 8. Retention Analysis

In [None]:
# Retention analysis
print(f"\n{'='*60}")
print("RETENTION ANALYSIS: Adaptive-Gating-EEG-ARNN with Gate Selection")
print(f"{'='*60}\n")

retention_results = []


def _gate_scores_for_fold(fold, X_train, y_train):
    """Load one fold's full-channel gating checkpoint and return gate scores.

    Args:
        fold: Zero-based fold index; selects
            ``Adaptive-Gating-EEG-ARNN_fold{fold+1}.pt`` in ``CONFIG['models_dir']``.
        X_train: Training inputs of the fold.
        y_train: Training labels of the fold.

    Returns:
        Importance scores from ``compute_gate_importance`` on the training split.
    """
    model = build_model('Adaptive-Gating-EEG-ARNN', CONFIG['n_channels'])
    model_path = os.path.join(CONFIG['models_dir'], f"Adaptive-Gating-EEG-ARNN_fold{fold+1}.pt")
    state_dict = torch.load(model_path, map_location=CONFIG['device'])
    model.load_state_dict(state_dict)
    model = model.to(CONFIG['device'])
    model.eval()

    scoring_loader = DataLoader(EEGDataset(X_train, y_train),
                                batch_size=CONFIG['batch_size'],
                                shuffle=False, num_workers=0)
    scores = compute_gate_importance(model, scoring_loader, CONFIG['device'])

    # Release the full-size model before the reduced model is trained.
    del model
    torch.cuda.empty_cache()
    gc.collect()
    return scores


# Materialise the folds once so every k is evaluated on the same partition and
# the cached gate scores always correspond to their fold.
retention_splits = list(skf.split(X, y))

# Gate scores depend on the fold but not on k, so they are computed once per
# fold instead of reloading the checkpoint and re-scoring for every k.
gate_scores_by_fold = {}

for k in CONFIG['retention_k_values']:
    print(f"\nTesting with k={k} channels:", end=' ')
    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(retention_splits):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        if fold not in gate_scores_by_fold:
            gate_scores_by_fold[fold] = _gate_scores_for_fold(fold, X_train, y_train)
        importance_scores = gate_scores_by_fold[fold]
        selected_channels = select_top_k_channels(importance_scores, k)

        # Apply channel selection
        X_train_selected = apply_channel_selection(X_train, selected_channels)
        X_val_selected = apply_channel_selection(X_val, selected_channels)

        # Train new model, sized from the channels actually selected so it
        # always matches the data it is fed.
        new_model = build_model('Adaptive-Gating-EEG-ARNN', len(selected_channels))
        train_dataset = EEGDataset(X_train_selected, y_train)
        val_dataset = EEGDataset(X_val_selected, y_val)

        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                  shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                                shuffle=False, num_workers=0)

        # Only the validation accuracy is needed; the weights are not saved.
        _, val_acc = train_pytorch_model(new_model, train_loader, val_loader,
                                         CONFIG, f"Retention-k{k}")
        fold_accuracies.append(val_acc)

        # Clean up
        del new_model
        torch.cuda.empty_cache()
        gc.collect()

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)
    print(f"{mean_acc:.4f} +/- {std_acc:.4f}")

    retention_results.append({
        'k': k,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

# Save retention results
retention_df = pd.DataFrame(retention_results)
retention_df.to_csv(os.path.join(CONFIG['results_dir'], 'retention_analysis.csv'), index=False)

print("\nRetention Analysis Results:")
print(retention_df[['k', 'mean_accuracy', 'std_accuracy']])

## 9. Generate Publication-Ready Results

In [None]:
# Set plotting style
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'; older
# releases only know the unprefixed name, so use whichever one is available.
paper_style = 'seaborn-v0_8-paper'
if paper_style not in plt.style.available:
    paper_style = 'seaborn-paper'
plt.style.use(paper_style)
sns.set_palette("husl")
plt.rcParams['figure.dpi'] = 300
plt.rcParams['font.size'] = 10
plt.rcParams['font.family'] = 'serif'

print("Generating publication-ready results...")

In [None]:
# Table II: Model Comparison
# Rank the models by mean accuracy and pre-format "mean ± std" in percent.
model_results_sorted = results_df.sort_values('mean_accuracy', ascending=False).reset_index(drop=True)
model_results_sorted['rank'] = range(1, len(model_results_sorted) + 1)
model_results_sorted['accuracy_str'] = model_results_sorted.apply(
    lambda row: f"{row['mean_accuracy']*100:.2f} ± {row['std_accuracy']*100:.2f}",
    axis=1
)

# Generate LaTeX table.
# The header is a *raw* string, so a LaTeX row break is written as exactly two
# backslashes; four would emit an extra empty row before \midrule.
latex_table = r"""\begin{table}[htbp]
\centering
\caption{Comparison of baseline methods on PhysioNet Motor Imagery dataset}
\label{tab:model_comparison}
\begin{tabular}{clc}
\toprule
\textbf{Rank} & \textbf{Method} & \textbf{Accuracy (\%)} \\
\midrule
"""

# Body rows are ordinary f-strings, where "\\\\" yields the two backslashes of
# a LaTeX row break. The top-ranked model is typeset in bold.
table_rows = []
for _, row in model_results_sorted.iterrows():
    rank = int(row['rank'])
    if rank == 1:
        table_rows.append(f"{rank} & \\textbf{{{row['model']}}} & \\textbf{{{row['accuracy_str']}}} \\\\\n")
    else:
        table_rows.append(f"{rank} & {row['model']} & {row['accuracy_str']} \\\\\n")
latex_table += "".join(table_rows)

latex_table += r"""\bottomrule
\end{tabular}
\end{table}
"""

table_path = os.path.join(CONFIG['results_dir'], 'table_ii_model_comparison.tex')
# The table contains the non-ASCII "±", so write it as UTF-8 explicitly rather
# than relying on the platform default encoding.
with open(table_path, 'w', encoding='utf-8') as f:
    f.write(latex_table)

print(f"Table II saved to: {table_path}")

In [None]:
# Table III: Retention Analysis
# Work on a copy so the display column does not leak into retention_df.
retention_display = retention_df.copy()
retention_display['accuracy_str'] = retention_display.apply(
    lambda row: f"{row['mean_accuracy']*100:.2f} ± {row['std_accuracy']*100:.2f}",
    axis=1
)

# The header is a *raw* string, so a LaTeX row break is written as exactly two
# backslashes; four would emit an extra empty row before \midrule.
latex_retention = r"""\begin{table}[htbp]
\centering
\caption{Performance retention with channel selection using Gate Selection method}
\label{tab:retention_analysis}
\begin{tabular}{cc}
\toprule
\textbf{Channels (k)} & \textbf{Accuracy (\%)} \\
\midrule
"""

# Body rows are ordinary f-strings, where "\\\\" yields the two backslashes of
# a LaTeX row break. k is cast to int so it is always printed without decimals.
latex_retention += "".join(
    f"{int(row['k'])} & {row['accuracy_str']} \\\\\n"
    for _, row in retention_display.iterrows()
)

latex_retention += r"""\bottomrule
\end{tabular}
\end{table}
"""

retention_table_path = os.path.join(CONFIG['results_dir'], 'table_iii_retention.tex')
# The table contains the non-ASCII "±", so write it as UTF-8 explicitly rather
# than relying on the platform default encoding.
with open(retention_table_path, 'w', encoding='utf-8') as f:
    f.write(latex_retention)

print(f"Table III saved to: {retention_table_path}")

In [None]:
# Figure 1: Model Comparison Bar Chart
# One bar per model (mean accuracy in percent) with the fold standard
# deviation as error bar; saved as both PDF and PNG.
fig, ax = plt.subplots(figsize=(10, 6))

x_pos = np.arange(len(model_results_sorted))
mean_pct = model_results_sorted['mean_accuracy'] * 100
std_pct = model_results_sorted['std_accuracy'] * 100
bars = ax.bar(x_pos, mean_pct, yerr=std_pct,
              capsize=5, alpha=0.8, edgecolor='black', linewidth=1.5)

# Highlight the first bar, i.e. the first row of model_results_sorted.
top_bar = bars[0]
top_bar.set_color('#FF6B6B')
top_bar.set_alpha(1.0)

axis_label_style = dict(fontsize=12, fontweight='bold')
ax.set_xlabel('Method', **axis_label_style)
ax.set_ylabel('Accuracy (%)', **axis_label_style)
ax.set_title('Comparison of Baseline Methods on PhysioNet Motor Imagery Dataset',
             fontsize=14, fontweight='bold', pad=20)
ax.set_xticks(x_pos)
ax.set_xticklabels(model_results_sorted['model'], rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3, linestyle='--')

# Print each bar's value slightly above its top edge.
for bar, mean_acc in zip(bars, model_results_sorted['mean_accuracy']):
    label_x = bar.get_x() + bar.get_width() / 2.
    label_y = bar.get_height() + 0.5
    ax.text(label_x, label_y, f'{mean_acc*100:.2f}',
            ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()

fig_path = os.path.join(CONFIG['figures_dir'], 'figure_model_comparison.pdf')
png_path = os.path.join(CONFIG['figures_dir'], 'figure_model_comparison.png')
for out_path, out_format in ((fig_path, 'pdf'), (png_path, 'png')):
    plt.savefig(out_path, format=out_format, bbox_inches='tight', dpi=300)

print(f"Figure 1 saved to: {fig_path}")
plt.show()

In [None]:
# Figure 2: Retention Analysis Curves
# Mean accuracy versus number of retained channels, with a +/- one standard
# deviation band and the full-channel accuracy as a reference line.
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(retention_df['k'], retention_df['mean_accuracy'] * 100,
        marker='o', linewidth=2, markersize=8,
        label='Adaptive-Gating-EEG-ARNN (Gate Selection)',
        color='#4ECDC4')

ax.fill_between(retention_df['k'],
                (retention_df['mean_accuracy'] - retention_df['std_accuracy']) * 100,
                (retention_df['mean_accuracy'] + retention_df['std_accuracy']) * 100,
                alpha=0.2, color='#4ECDC4')

# Reference line: the gating model's accuracy from the model-comparison table.
baseline_acc = model_results_sorted[model_results_sorted['model'] == 'Adaptive-Gating-EEG-ARNN']['mean_accuracy'].values[0]
# The channel count in the legend comes from CONFIG so that it cannot drift
# from the channel count the models were actually built with.
ax.axhline(y=baseline_acc * 100, color='red', linestyle='--', linewidth=2,
           label=f"Full {CONFIG['n_channels']} channels ({baseline_acc*100:.2f}%)", alpha=0.7)

ax.set_xlabel('Number of Channels (k)', fontsize=12, fontweight='bold')
ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
ax.set_title('Performance Retention with Channel Selection',
             fontsize=14, fontweight='bold', pad=20)
ax.grid(True, alpha=0.3, linestyle='--')
ax.legend(fontsize=10, loc='lower right')

plt.tight_layout()

retention_fig_path = os.path.join(CONFIG['figures_dir'], 'figure_retention_curves.pdf')
plt.savefig(retention_fig_path, format='pdf', bbox_inches='tight', dpi=300)
plt.savefig(os.path.join(CONFIG['figures_dir'], 'figure_retention_curves.png'),
            format='png', bbox_inches='tight', dpi=300)

print(f"Figure 2 saved to: {retention_fig_path}")
plt.show()

In [None]:
# Summary statistics for paper
# Collects the headline numbers (winner, per-model ranking, best channel
# selection setting, retention threshold) into one JSON file.
top_model = model_results_sorted.iloc[0]
summary = {
    'winner': {
        'model': top_model['model'],
        'accuracy': float(top_model['mean_accuracy']),
        'std': float(top_model['std_accuracy']),
        'accuracy_pct': f"{top_model['mean_accuracy']*100:.2f}",
        'std_pct': f"{top_model['std_accuracy']*100:.2f}"
    },
    'all_models': {},
    'channel_selection': {},
    'retention': {}
}

for _, row in model_results_sorted.iterrows():
    summary['all_models'][row['model']] = {
        'rank': int(row['rank']),
        'accuracy': float(row['mean_accuracy']),
        'std': float(row['std_accuracy']),
        'accuracy_pct': f"{row['mean_accuracy']*100:.2f}",
        'std_pct': f"{row['std_accuracy']*100:.2f}"
    }

# Best channel selection
best_cs = cs_df.loc[cs_df['mean_accuracy'].idxmax()]
summary['channel_selection']['best_method'] = best_cs['method']
summary['channel_selection']['best_k'] = int(best_cs['k'])
summary['channel_selection']['best_accuracy'] = f"{best_cs['mean_accuracy']*100:.2f}"

# Retention analysis
# The retention sweep retrains Adaptive-Gating-EEG-ARNN, so "90% retention" is
# measured against that model's own accuracy in the model-comparison table
# (the reference line used in Figure 2), not against whichever model happens
# to rank first. Fall back to the top-ranked model only if the gating model
# is absent from the table.
gating_rows = model_results_sorted[model_results_sorted['model'] == 'Adaptive-Gating-EEG-ARNN']
if len(gating_rows) > 0:
    baseline_acc = gating_rows['mean_accuracy'].values[0]
else:
    baseline_acc = top_model['mean_accuracy']
target_acc = baseline_acc * 0.9
retention_90pct = retention_df[retention_df['mean_accuracy'] >= target_acc]
if len(retention_90pct) > 0:
    # Smallest k whose mean accuracy still reaches the 90% target.
    min_k = retention_90pct['k'].min()
    summary['retention']['channels_for_90pct_retention'] = int(min_k)

summary_path = os.path.join(CONFIG['results_dir'], 'paper_summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print("\nPaper Summary:")
print(json.dumps(summary, indent=2))
print(f"\nSummary saved to: {summary_path}")

## 10. Final Verification and Key Findings

In [None]:
# Check all required files
# Paths are derived from CONFIG (the directories the pipeline writes to) so the
# check stays correct if the output directories are changed; normpath keeps
# the printed paths free of a leading "./".
results_dir = CONFIG['results_dir']
figures_dir = CONFIG['figures_dir']
required_files = [
    (os.path.normpath(os.path.join(out_dir, filename)), description)
    for out_dir, filename, description in (
        (results_dir, 'summary_all_models.csv', 'Model Results'),
        (results_dir, 'channel_selection_results.csv', 'Channel Selection Results'),
        (results_dir, 'retention_analysis.csv', 'Retention Analysis'),
        (results_dir, 'table_ii_model_comparison.tex', 'LaTeX Table II'),
        (results_dir, 'table_iii_retention.tex', 'LaTeX Table III'),
        (figures_dir, 'figure_model_comparison.pdf', 'Figure 1 - Model Comparison'),
        (figures_dir, 'figure_retention_curves.pdf', 'Figure 2 - Retention Curves'),
        (results_dir, 'paper_summary.json', 'Paper Summary'),
    )
]

print("\n" + "="*60)
print("VERIFICATION: Output Files")
print("="*60 + "\n")

all_present = True
for filepath, description in required_files:
    if os.path.exists(filepath):
        file_size = os.path.getsize(filepath)
        print(f"[OK] {description}")
        print(f"     Path: {filepath}")
        print(f"     Size: {file_size:,} bytes\n")
    else:
        print(f"[MISSING] {description}")
        print(f"          Expected: {filepath}\n")
        all_present = False

if all_present:
    print("\n" + "="*60)
    print("SUCCESS: All outputs generated successfully!")
    print("="*60)
else:
    print("\n" + "="*60)
    print("WARNING: Some outputs are missing!")
    print("="*60)

In [None]:
# Display key findings
# Prints the headline numbers stored in `summary` plus the model ranking.
print("\n" + "="*60)
print("KEY FINDINGS FOR PAPER")
print("="*60 + "\n")

winner = summary['winner']
print(f"1. BEST MODEL: {winner['model']}")
print(f"   Accuracy: {winner['accuracy_pct']}% ± {winner['std_pct']}%\n")

print(f"2. CHANNEL SELECTION:")
print(f"   Best Method: {summary['channel_selection']['best_method'].upper()}")
print(f"   Optimal k: {summary['channel_selection']['best_k']}")
print(f"   Accuracy: {summary['channel_selection']['best_accuracy']}%\n")

# This entry exists only when some k reached the 90% retention target.
if 'channels_for_90pct_retention' in summary['retention']:
    channels_needed = summary['retention']['channels_for_90pct_retention']
    print(f"3. RETENTION:")
    print(f"   90% retention achieved with: {channels_needed} channels")
    # Relative to the full channel count from CONFIG rather than a hard-coded
    # 64, so the percentage stays correct for other channel counts.
    reduction = (1 - channels_needed / CONFIG['n_channels']) * 100
    print(f"   Channel reduction: {reduction:.1f}%\n")

print(f"4. MODEL RANKING:")
for i, row in enumerate(model_results_sorted.itertuples(), 1):
    print(f"   {i}. {row.model}: {row.mean_accuracy*100:.2f}%")

print("\n" + "="*60)
print("Pipeline completed successfully!")
print("All results ready for paper submission!")
print("="*60)