# Baseline vs Adaptive Gating: Comprehensive Comparison

## Final Draft - Publication Ready

**Objective:** Comprehensively compare the baseline EEG-ARNN against the adaptive-gating EEG-ARNN

**Dataset:** PhysioNet Motor Movement/Imagery Dataset

**Configuration:**
- **Subjects:** 50 (the first 50 subjects remaining after excluding known-bad recordings)
- **Epochs:** up to 30 per fold (early stopping with patience 10)
- **Classes:** 2 (Left Fist vs Right Fist)
- **Cross-validation:** 3-fold stratified
- **Channel Selection:** k=[10, 15, 20, 25, 30, 35, 40]

**Methods Compared:**
1. **Baseline EEG-ARNN** - Pure CNN + Graph Convolution (reference implementation)
2. **Adaptive Gating EEG-ARNN** - Input-dependent channel gating

**Comprehensive Metrics:**
- Accuracy, Precision, Recall, F1-Score
- Confusion Matrix
- ROC-AUC, PR-AUC
- Cohen's Kappa
- Training time, Convergence analysis
- Channel importance analysis
- Statistical significance tests

## 1. Setup and Imports

In [None]:
import json
import random
import warnings
import time
from pathlib import Path
from copy import deepcopy
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report,
    roc_auc_score, roc_curve, precision_recall_curve,
    average_precision_score, cohen_kappa_score
)
from scipy import stats

import mne

# NOTE(review): this silences *every* warning for the rest of the session,
# including library deprecation and numerical warnings.
warnings.filterwarnings('ignore')

# Set publication-quality plotting: 300 dpi for displayed and saved figures,
# serif font at 10 pt.
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 10

# Seaborn theme: white grid, "paper" context with fonts scaled by 1.2.
sns.set_style('whitegrid')
sns.set_context('paper', font_scale=1.2)
# Only show MNE log messages at WARNING level or above.
mne.set_log_level('WARNING')

def set_seed(s=42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices). Then switches cuDNN to deterministic kernels with
    autotuning disabled, so that runs are reproducible.

    Args:
        s: Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(s)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once for the whole notebook.
set_seed(42)
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run banner: device, PyTorch version and start timestamp.
print("="*80)
print("BASELINE VS ADAPTIVE GATING - COMPREHENSIVE COMPARISON")
print("="*80)
print(f"\nDevice: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

## 2. Configuration

In [None]:
import os
from pathlib import Path

# Auto-detect environment. On Kaggle, prefer a known PhysioNet dataset mount
# and fall back to the first mounted dataset. Locally, use the repository path.
if os.path.exists('/kaggle/input'):
    print("Running on Kaggle")
    kaggle_input = Path('/kaggle/input')
    datasets = [d for d in kaggle_input.iterdir() if d.is_dir()]
    DATA_DIR = None
    for ds_name in ['physioneteegmi', 'eeg-motor-movementimagery-dataset']:
        test_path = kaggle_input / ds_name
        if test_path.exists():
            DATA_DIR = test_path
            break
    if DATA_DIR is None and datasets:
        DATA_DIR = datasets[0]
    if DATA_DIR is None:
        # Fail here with a clear message. Otherwise the notebook would only
        # crash later, when subject scanning calls .exists() on None.
        raise FileNotFoundError(
            "No dataset found under /kaggle/input; attach the PhysioNet EEG "
            "Motor Movement/Imagery dataset to this notebook."
        )
else:
    print("Running locally")
    DATA_DIR = Path('data/physionet/files')

print(f"Data directory: {DATA_DIR}")

# Configuration
CONFIG = {
    'experiment_name': 'baseline_vs_adaptive_final',
    'data': {
        'raw_data_dir': DATA_DIR,
        # Event codes as produced by mne.events_from_annotations, which numbers
        # annotation descriptions in sorted order starting at 1:
        # T0 (rest) -> 1, T1 -> 2, T2 -> 3. In the left/right-fist runs listed
        # below, T1 = left fist and T2 = right fist.
        'selected_classes': [2, 3],
        'tmin': -1.0,
        'tmax': 5.0,
        'baseline': (-0.5, 0)
    },
    'preprocessing': {
        'l_freq': 0.5,
        'h_freq': 40.0,
        'notch_freq': 60.0,  # EEGMMIDB was recorded with 60 Hz mains power
        'target_sfreq': 128.0,
        'apply_car': True
    },
    'model': {
        'hidden_dim': 40,
        'epochs': 30,  # FINAL: 30 epochs (upper bound; early stopping applies)
        'learning_rate': 1e-3,
        'batch_size': 64,
        'n_folds': 3,
        'patience': 10  # Increased for 30 epochs
    },
    'gating': {
        'l1_lambda': 1e-3,
        'gate_init': 0.9
    },
    'channel_selection': {
        'k_values': [10, 15, 20, 25, 30, 35, 40]
    },
    'output': {
        'results_dir': Path('results/baseline_vs_adaptive_final'),
    },
    'max_subjects': 50,  # FINAL: 50 subjects
    'min_runs_per_subject': 8,
    'save_models': True,
    'save_predictions': True
}

# Create output directories
CONFIG['output']['results_dir'].mkdir(exist_ok=True, parents=True)
(CONFIG['output']['results_dir'] / 'figures').mkdir(exist_ok=True)
(CONFIG['output']['results_dir'] / 'tables').mkdir(exist_ok=True)
(CONFIG['output']['results_dir'] / 'models').mkdir(exist_ok=True)

# Known bad subjects (from data cleaning)
EXCLUDED_SUBJECTS = {
    'S003', 'S004', 'S009', 'S010', 'S012', 'S013', 'S017', 'S018', 'S019',
    'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029',
    'S088', 'S089', 'S092', 'S100', 'S104', 'S106', 'S107', 'S108', 'S109'
}

# PhysioNet EEGMMIDB task runs (what T1/T2 mean depends on the run):
#   R03, R07, R11  executed  left fist (T1) vs right fist (T2)
#   R04, R08, R12  imagined  left fist (T1) vs right fist (T2)
#   R05, R09, R13  executed  both fists (T1) vs both feet (T2)
#   R06, R10, R14  imagined  both fists (T1) vs both feet (T2)
# Only the left/right-fist runs carry the Left Fist vs Right Fist labels.
MOTOR_IMAGERY_RUNS = ['R04', 'R08', 'R12']
MOTOR_EXECUTION_RUNS = ['R03', 'R07', 'R11']
ALL_TASK_RUNS = MOTOR_IMAGERY_RUNS + MOTOR_EXECUTION_RUNS

print("\n" + "="*80)
print("EXPERIMENT CONFIGURATION")
print("="*80)
print(f"Experiment: {CONFIG['experiment_name']}")
print(f"Target subjects: {CONFIG['max_subjects']}")
print(f"Epochs per fold: {CONFIG['model']['epochs']}")
print(f"Cross-validation folds: {CONFIG['model']['n_folds']}")
print(f"Channel selection k values: {CONFIG['channel_selection']['k_values']}")
print(f"Excluded subjects: {len(EXCLUDED_SUBJECTS)}")
print(f"Task runs: {len(ALL_TASK_RUNS)} ({', '.join(ALL_TASK_RUNS)})")

# Experiment-count estimate (each "experiment" is one cross-validated run).
# Retraining uses ES and AS for the baseline (2 channel selections) and
# ES, AS and GS for adaptive gating (3 selections), as in the retraining loop.
_n_subj = CONFIG['max_subjects']
_n_k = len(CONFIG['channel_selection']['k_values'])
_n_initial = _n_subj * 2
_n_retrain = _n_subj * _n_k * (2 + 3)
print("\nEstimated total experiments:")
print(f"  Initial training: {_n_subj} subjects × 2 methods = {_n_initial}")
print(f"  Retraining: {_n_subj} subjects × {_n_k} k × (2 baseline + 3 adaptive selections) = {_n_retrain}")
print(f"  Total: {_n_initial + _n_retrain} experiments")

## 3. Data Loading Functions

Uses the same preprocessing pipeline as the reference implementation.

In [None]:
def preprocess_raw(raw, config):
    """Apply preprocessing to raw EEG data.

    Steps, in this order (each called on ``raw`` itself):
      1. strip trailing '.' characters from channel names;
      2. keep EEG channels only;
      3. attach the 'standard_1020' montage (names matched case-insensitively,
         channels without a montage position ignored);
      4. notch filter at ``notch_freq``, only if it is below Nyquist;
      5. FIR band-pass ``l_freq``-``h_freq`` (firwin design);
      6. if ``apply_car``, re-reference to the common average (applied
         directly rather than as a projector);
      7. resample to ``target_sfreq``.

    Args:
        raw: MNE Raw object. Presumably it must already be loaded into memory
            for the in-place filtering; the caller in this notebook reads
            files with ``preload=True``.
        config: Experiment config; only the ``'preprocessing'`` section is read.

    Returns:
        The same ``raw`` object after the steps above.
    """
    # Presumably the EDF channel labels are padded with dots (e.g. 'C3..');
    # stripping them lets the labels match standard 10-20 montage names.
    cleaned_names = {name: name.rstrip('.') for name in raw.ch_names}
    raw.rename_channels(cleaned_names)
    raw.pick_types(eeg=True)
    raw.set_montage('standard_1020', on_missing='ignore', match_case=False)
    
    # A notch at or above Nyquist cannot be applied at this sampling rate.
    nyquist = raw.info['sfreq'] / 2.0
    if config['preprocessing']['notch_freq'] < nyquist:
        raw.notch_filter(freqs=config['preprocessing']['notch_freq'], verbose=False)
    
    raw.filter(
        l_freq=config['preprocessing']['l_freq'],
        h_freq=config['preprocessing']['h_freq'],
        method='fir',
        fir_design='firwin',
        verbose=False
    )
    
    if config['preprocessing']['apply_car']:
        raw.set_eeg_reference('average', projection=False, verbose=False)
    
    # Resample last, after band-pass filtering.
    raw.resample(config['preprocessing']['target_sfreq'], npad='auto', verbose=False)
    return raw


def load_and_preprocess_edf(edf_path, config):
    """Read one EDF run, preprocess it and cut it into event-locked epochs.

    Events come from the file's annotations via
    ``mne.events_from_annotations``. Epochs span ``tmin``..``tmax`` around
    each event and are baseline-corrected over ``config['data']['baseline']``.

    Args:
        edf_path: Path to the ``.edf`` file.
        config: Experiment config (``'preprocessing'`` and ``'data'`` sections).

    Returns:
        ``(data, codes, ch_names)``:
        - ``data`` is ``Epochs.get_data()``;
        - ``codes`` is the integer event code of every kept epoch;
        - ``ch_names`` are the channel names after preprocessing.
        ``data`` and ``codes`` are ``None`` when no events were found.
    """
    raw = preprocess_raw(
        mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR'), config
    )
    events, event_ids = mne.events_from_annotations(raw, verbose='ERROR')
    if not len(events):
        return None, None, raw.ch_names

    window = config['data']
    epochs = mne.Epochs(
        raw,
        events,
        event_id=event_ids,
        tmin=window['tmin'],
        tmax=window['tmax'],
        baseline=tuple(window['baseline']),
        preload=True,
        verbose='ERROR',
    )
    return epochs.get_data(), epochs.events[:, 2], raw.ch_names


def filter_classes(x, y, selected_classes):
    """Keep only trials whose label is in ``selected_classes`` and relabel them.

    Kept labels are remapped to consecutive integers following the sorted
    order of ``selected_classes`` (smallest code -> 0).

    Returns:
        ``(x_kept, y_kept)`` where ``y_kept`` is an ``int64`` array.
    """
    keep = np.isin(y, selected_classes)
    remap = dict(zip(sorted(selected_classes), range(len(selected_classes))))
    kept_labels = y[keep]
    relabeled = np.fromiter(
        (remap[int(code)] for code in kept_labels),
        dtype=np.int64,
        count=len(kept_labels),
    )
    return x[keep], relabeled


def normalize(x):
    """Z-score every channel using statistics pooled over trials and time.

    Mean and standard deviation are taken over axes 0 and 2 (so one value per
    index of axis 1). A ``1e-8`` epsilon guards against division by zero.
    """
    pooled_axes = (0, 2)
    centered = x - x.mean(axis=pooled_axes, keepdims=True)
    scale = x.std(axis=pooled_axes, keepdims=True) + 1e-8
    return centered / scale


def load_subject_data(data_dir, subject_id, run_ids, config):
    """Load all runs for a subject, preprocess, and concatenate.

    A run is skipped in any of these cases:
    - its file is missing;
    - it yields no events;
    - it contains none of the selected classes;
    - it fails to load (the failure is printed, not silently dropped);
    - its channel list differs from the first successfully loaded run
      (concatenating it would misalign channels).

    Args:
        data_dir: Directory (``Path`` or ``str``) holding one sub-directory
            per subject.
        subject_id: Subject folder name, e.g. ``'S001'``.
        run_ids: Run suffixes such as ``'R04'``; each file is expected at
            ``<data_dir>/<subject_id>/<subject_id><run_id>.edf``.
        config: Experiment config passed to the loading helpers.

    Returns:
        ``(x, y, channel_names)`` concatenated along the trial axis.
        ``x`` and ``y`` are ``None`` when no run produced data.
        All three are ``None`` when the subject directory does not exist.
    """
    subject_dir = Path(data_dir) / subject_id
    if not subject_dir.exists():
        return None, None, None

    all_x, all_y = [], []
    channel_names = None

    for run_id in run_ids:
        edf_path = subject_dir / f'{subject_id}{run_id}.edf'
        if not edf_path.exists():
            continue

        try:
            x, y, ch_names = load_and_preprocess_edf(edf_path, config)
            if x is None or len(y) == 0:
                continue
            x, y = filter_classes(x, y, config['data']['selected_classes'])
        except Exception as exc:
            # Best-effort: one unreadable run must not abort the whole subject,
            # but report it. Uses print because warnings may be filtered out.
            print(f"    Warning: skipped {edf_path.name}: {type(exc).__name__}: {exc}")
            continue

        if len(y) == 0:
            continue

        if channel_names is None:
            channel_names = ch_names
        elif list(ch_names) != list(channel_names):
            # Runs are concatenated along the trial axis; a different channel
            # list would silently misalign (or break) axis 1.
            print(f"    Warning: skipped {edf_path.name}: channel list differs from earlier runs")
            continue

        all_x.append(x)
        all_y.append(y)

    if not all_x:
        return None, None, channel_names

    return np.concatenate(all_x, 0), np.concatenate(all_y, 0), channel_names


def get_available_subjects(data_dir, min_runs=8, excluded=None):
    """List subject folders that qualify for the experiment.

    A folder qualifies when all of these hold:
    - it is a directory;
    - its name starts with ``'S'``;
    - it is not in ``excluded``;
    - it contains at least ``min_runs`` ``*.edf`` files.

    Folders are visited in sorted order.

    Raises:
        ValueError: If ``data_dir`` does not exist.
    """
    if not data_dir.exists():
        raise ValueError(f"Data directory not found: {data_dir}")
    skip = excluded or set()

    def _qualifies(folder):
        return (
            folder.is_dir()
            and folder.name.startswith('S')
            and folder.name not in skip
            and sum(1 for _ in folder.glob('*.edf')) >= min_runs
        )

    return [folder.name for folder in sorted(data_dir.iterdir()) if _qualifies(folder)]


# Scan for subjects: find folders that are not excluded and have at least
# `min_runs_per_subject` EDF files (visited in sorted name order). Then keep
# the first `max_subjects` of them.
print("\nScanning for available subjects...")
all_subjects = get_available_subjects(
    CONFIG['data']['raw_data_dir'],
    min_runs=CONFIG['min_runs_per_subject'],
    excluded=EXCLUDED_SUBJECTS
)
subjects = all_subjects[:CONFIG['max_subjects']]

# Summary; only the first 10 subject IDs are listed.
print(f"Found {len(all_subjects)} clean subjects with >= {CONFIG['min_runs_per_subject']} runs")
print(f"Will process: {len(subjects)} subjects")
print(f"Subjects: {', '.join(subjects[:10])}{'...' if len(subjects) > 10 else ''}")

## 4. PyTorch Dataset

In [None]:
class EEGDataset(Dataset):
    """In-memory map-style dataset over epoched EEG trials.

    Trials are stored as ``float32`` with a singleton axis inserted at
    position 1. Assuming ``x`` is ``(n_trials, C, T)``, each sample is a
    ``(1, C, T)`` tensor, as the first ``Conv2d`` layer expects. Labels are
    stored as ``int64``.
    """

    def __init__(self, x, y):
        # Convert once up front so __getitem__ is a plain tensor index.
        self.x = torch.tensor(x, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        trial, label = self.x[i], self.y[i]
        return trial, label

## 5. Model Architectures

### 5.1 Baseline EEG-ARNN (Reference Implementation)

In [None]:
class GraphConvLayer(nn.Module):
    """Graph convolution over EEG channels with a learned adjacency matrix.

    The learnable ``A`` (``C x C``) is passed through a sigmoid and
    symmetrised. Self-loops are then added and the result is normalised as
    ``D^{-1/2} (A + I) D^{-1/2}``. At every time step, channel features are
    mixed with that matrix and projected by the bias-free linear map
    ``theta`` along the feature axis. The output goes through BatchNorm and
    ELU.

    Input and output shape: ``(B, H, C, T)`` (batch, features, channels, time).
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim
        self.A = nn.Parameter(torch.randn(num_channels, num_channels))
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _symmetric_adjacency(self):
        """sigmoid(A) averaged with its transpose; entries lie in (0, 1)."""
        weights = torch.sigmoid(self.A)
        return 0.5 * (weights + weights.t())

    def forward(self, x):
        batch, feats, chans, steps = x.shape
        adj = self._symmetric_adjacency() + torch.eye(chans, device=self.A.device)
        inv_sqrt_deg = torch.diag(adj.sum(1).clamp_min(1e-6).pow(-0.5))
        norm_adj = inv_sqrt_deg @ adj @ inv_sqrt_deg

        # (B, H, C, T) -> (B*T, C, H): one channel-by-feature matrix per step.
        nodes = x.permute(0, 3, 2, 1).contiguous().view(batch * steps, chans, feats)
        mixed = self.theta(norm_adj @ nodes)
        out = mixed.view(batch, steps, chans, feats).permute(0, 3, 2, 1)
        return self.act(self.bn(out))

    def get_adjacency(self):
        """Return the symmetrised adjacency (without self-loops) as a NumPy array."""
        with torch.no_grad():
            return self._symmetric_adjacency().cpu().numpy()


class TemporalConv(nn.Module):
    """Temporal block: Conv2d along time, then BatchNorm, ELU and optional AvgPool.

    The kernel is ``(1, kernel_size)``, so samples are mixed only over time
    within each channel row. Padding is ``kernel_size // 2`` on both sides,
    so an even kernel lengthens the time axis by one sample. With ``pool``
    enabled, a ``(1, 2)`` average pool then halves the time axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        self.pool = pool
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, kernel_size // 2),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2)) if pool else None

    def forward(self, x):
        features = self.act(self.bn(self.conv(x)))
        if not self.pool:
            return features
        return self.pool_layer(features)


class BaselineEEGARNN(nn.Module):
    """Baseline EEG-ARNN - Reference Implementation.

    Three stages of (temporal convolution -> graph convolution over channels),
    then a fully connected head: Linear(256) -> ReLU -> Dropout(0.5) -> Linear(K).

    Args:
        C: Number of EEG channels (graph nodes).
        T: Number of time samples per trial.
        K: Number of output classes.
        H: Number of feature maps used by every temporal and graph layer.

    Input shape ``(B, 1, C, T)``; returns logits of shape ``(B, K)``.
    """
    def __init__(self, C, T, K, H):
        super().__init__()
        self.t1 = TemporalConv(1, H, 16, False)
        self.g1 = GraphConvLayer(C, H)
        self.t2 = TemporalConv(H, H, 16, True)
        self.g2 = GraphConvLayer(C, H)
        self.t3 = TemporalConv(H, H, 16, True)
        self.g3 = GraphConvLayer(C, H)

        # Infer the flattened feature size with a dummy forward pass, in eval
        # mode. In train mode, BatchNorm would fold this all-zeros batch into
        # its running statistics: running_var drops below 1 and
        # num_batches_tracked increments before any real data is seen.
        # torch.no_grad() alone does not stop those updates.
        self.eval()
        with torch.no_grad():
            ft = self._forward_features(torch.zeros(1, 1, C, T))
            fs = ft.view(1, -1).size(1)
        self.train()

        self.fc1 = nn.Linear(fs, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, K)

    def _forward_features(self, x):
        """Run the three temporal/graph stages; output is ``(B, H, C, T')``."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def get_final_adjacency(self):
        """Symmetrised adjacency of the last graph layer, as a NumPy array."""
        return self.g3.get_adjacency()


print("Baseline EEG-ARNN defined!")

### 5.2 Adaptive Gating EEG-ARNN

In [None]:
class AdaptiveGatedEEGARNN(BaselineEEGARNN):
    """EEG-ARNN with adaptive input-dependent channel gates.

    A small MLP maps per-trial channel statistics (mean and standard deviation
    over time) to one gate in (0, 1) per channel. The input is scaled
    channel-wise by these gates before the baseline network runs.

    Args:
        C, T, K, H: As in ``BaselineEEGARNN``.
        gate_init: Desired initial gate value in (0, 1). The gate network's
            output bias is set to ``logit(gate_init)``, so gates start near
            this value (the random output weights perturb them around it).
    """
    def __init__(self, C, T, K, H, gate_init=0.9):
        super().__init__(C, T, K, H)

        # Gate network: 2-layer MLP
        self.gate_net = nn.Sequential(
            nn.Linear(C * 2, C),  # Input: mean + std per channel
            nn.ReLU(),
            nn.Linear(C, C),
            nn.Sigmoid()
        )

        # Bias = logit(gate_init), so the final Sigmoid starts near gate_init.
        # Clamping keeps the logit finite for gate_init at or beyond 0 or 1.
        p = min(max(float(gate_init), 1e-4), 1.0 - 1e-4)
        with torch.no_grad():
            self.gate_net[-2].bias.fill_(float(np.log(p / (1.0 - p))))

        # Detached CPU copy of the (B, C) gates from the most recent forward pass.
        self.latest_gates = None

    def compute_gates(self, x):
        """Compute input-dependent gates.

        Args:
            x: Batch of shape ``(B, 1, C, T)``.

        Returns:
            Tensor ``(B, C)`` of gates in (0, 1), differentiable w.r.t. the
            gate network.
        """
        signals = x.squeeze(1)  # (B, C, T)
        channel_stats = torch.cat([signals.mean(dim=2), signals.std(dim=2)], dim=1)
        return self.gate_net(channel_stats)

    def forward(self, x):
        """Scale each input channel by its gate, then run the baseline network."""
        gates = self.compute_gates(x)
        self.latest_gates = gates.detach().cpu()
        x = x * gates.view(-1, 1, gates.size(1), 1)
        return super().forward(x)

    def get_gate_values(self):
        """Return batch-mean gates (length C) of the last forward pass, or ``None``."""
        if self.latest_gates is not None:
            return self.latest_gates.mean(dim=0)
        return None


print("Adaptive Gating EEG-ARNN defined!")

## 6. Training Functions with Comprehensive Metrics

In [None]:
def compute_metrics(y_true, y_pred, y_prob=None):
    """Compute comprehensive evaluation metrics.

    Args:
        y_true: True integer class labels.
        y_pred: Predicted integer class labels.
        y_prob: Optional ``(n_samples, n_classes)`` class-probability array
            whose column ``j`` belongs to class ``j``.

    Returns:
        Dict with ``accuracy``, support-weighted ``precision``, ``recall`` and
        ``f1``, Cohen's ``kappa`` and ``confusion_matrix``.
        When ``y_prob`` is given:
        - the confusion matrix is always ``n_classes x n_classes`` (one row
          and column per class index, even for classes absent from this fold),
          so matrices from different folds line up;
        - ``roc_auc`` and ``pr_auc`` are added, and both are ``None`` when
          scikit-learn rejects the input (e.g. only one class in ``y_true``).
    """
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='weighted', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='weighted', zero_division=0),
        'f1': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        'kappa': cohen_kappa_score(y_true, y_pred)
    }

    labels = None
    if y_prob is not None:
        y_prob = np.asarray(y_prob)
        labels = np.arange(y_prob.shape[1])

    # Confusion matrix (fixed K x K whenever the number of classes is known)
    metrics['confusion_matrix'] = confusion_matrix(y_true, y_pred, labels=labels)

    # ROC-AUC and PR-AUC if probabilities provided
    if y_prob is not None:
        try:
            # Decide binary vs multiclass from the probability columns, not from
            # the labels present in this particular fold.
            if y_prob.shape[1] == 2:
                metrics['roc_auc'] = roc_auc_score(y_true, y_prob[:, 1])
                metrics['pr_auc'] = average_precision_score(y_true, y_prob[:, 1])
            else:
                metrics['roc_auc'] = roc_auc_score(y_true, y_prob, average='weighted', multi_class='ovr')
                metrics['pr_auc'] = average_precision_score(y_true, y_prob, average='weighted')
        except ValueError:
            # scikit-learn raises ValueError when the score is undefined.
            metrics['roc_auc'] = None
            metrics['pr_auc'] = None

    return metrics


def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Train for one epoch.

    When ``l1_lambda > 0`` and the model exposes ``compute_gates`` (adaptive
    gating), the loss gets an extra term ``l1_lambda * mean(|gates|)``. The
    gates are recomputed from the current batch with autograd enabled, so
    the penalty back-propagates into the gate network. The model's
    ``get_gate_values()`` returns a detached CPU copy and would contribute
    no gradient.

    Returns:
        ``compute_metrics`` output over all training predictions of the epoch,
        plus ``'loss'``: the mean per-batch loss including the gate penalty.
    """
    model.train()
    use_gate_penalty = l1_lambda > 0 and hasattr(model, 'compute_gates')
    total_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        # L1 sparsity penalty on the (differentiable) adaptive gates
        if use_gate_penalty:
            gates = model.compute_gates(x)
            loss = loss + l1_lambda * gates.abs().mean()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        probs = F.softmax(logits, dim=1)
        all_probs.append(probs.detach().cpu().numpy())
        all_preds += torch.argmax(logits, 1).cpu().tolist()
        all_labels += y.cpu().tolist()

    all_probs = np.vstack(all_probs)
    metrics = compute_metrics(all_labels, all_preds, all_probs)
    metrics['loss'] = total_loss / max(1, len(dataloader))

    return metrics


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` in eval mode without gradient tracking.

    Returns:
        ``compute_metrics`` output computed from the true labels, argmax
        predictions and softmax probabilities over the whole loader. Also
        includes ``'loss'``, the mean per-batch ``criterion`` value.
    """
    model.eval()
    batch_losses = []
    labels, preds, prob_chunks = [], [], []

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs)
        batch_losses.append(criterion(logits, targets).item())
        prob_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
        preds.extend(logits.argmax(dim=1).cpu().tolist())
        labels.extend(targets.cpu().tolist())

    metrics = compute_metrics(labels, preds, np.vstack(prob_chunks))
    metrics['loss'] = sum(batch_losses) / max(1, len(dataloader))
    return metrics


def train_model(model, train_loader, val_loader, device, epochs, lr, patience, l1_lambda=0.0):
    """Train model with early stopping and return training history.

    Optimiser: Adam with ``lr`` and weight decay 1e-4. Each epoch trains on
    ``train_loader`` and evaluates on ``val_loader``. The learning rate is
    halved after 3 epochs without a lower validation loss. Training stops
    after ``patience`` epochs without a new best validation accuracy. The
    best-accuracy weights are loaded back into ``model`` before returning.

    NOTE(review): the cross-validation helpers in this notebook pass the fold
    they later report on as ``val_loader``. Early stopping, LR scheduling and
    checkpoint selection therefore all see the evaluation data.

    Returns:
        ``(best_state, best_acc, history, training_time)``:
        - ``best_state``: state dict of the best epoch, or of the final epoch
          if accuracy never exceeded 0;
        - ``best_acc``: its validation accuracy;
        - ``history``: per-epoch ``{'train': [...], 'val': [...]}`` metric dicts;
        - ``training_time``: wall-clock seconds.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # `verbose` is deprecated on LR schedulers since PyTorch 2.2; its default
    # is off, so it is not passed.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    best_acc = 0.0
    best_state = None
    no_improve = 0
    history = {'train': [], 'val': []}
    start_time = time.time()

    for epoch in range(epochs):
        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_metrics = evaluate(model, val_loader, criterion, device)

        history['train'].append(train_metrics)
        history['val'].append(val_metrics)

        scheduler.step(val_metrics['loss'])

        if val_metrics['accuracy'] > best_acc:
            best_acc = val_metrics['accuracy']
            best_state = deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= patience:
            break

    training_time = time.time() - start_time

    if best_state is None:
        best_state = deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    return best_state, best_acc, history, training_time


print("Training functions with comprehensive metrics defined!")

## 7. Cross-Validation with Full Metrics

In [None]:
def cross_validate_subject(x, y, channel_names, T, K, device, config, model_type='baseline'):
    """Cross-validate subject with comprehensive metrics tracking.

    Uses ``StratifiedKFold`` (``n_folds`` from config, shuffled, seed 42).
    In each fold:
    - every channel is z-scored with the mean/std of the fold's *training*
      trials, and the same statistics are applied to the validation trials,
      so no validation statistics enter preprocessing;
    - a fresh model is trained with ``train_model``;
    - the restored best checkpoint is evaluated on the validation trials.

    NOTE(review): the validation fold also drives early stopping and
    checkpoint selection inside ``train_model``, so the reported fold metrics
    are optimistically biased. An inner train/validation split would remove
    that bias.

    Args:
        x: Trial array, assumed ``(n_trials, C, T)`` (``C`` is read from axis 1).
        y: Integer class label per trial.
        channel_names: Channel names, stored unchanged in the result.
        T: Time samples per trial (model input width).
        K: Number of classes.
        device: Torch device used for training.
        config: Experiment config (``'model'`` and ``'gating'`` sections).
        model_type: ``'baseline'`` or ``'adaptive'``.

    Returns:
        Dict with:
        - ``fold_results``: per-fold metrics;
        - training histories and timing;
        - the fold-averaged adjacency of the last graph layer;
        - ``avg_*`` / ``std_*`` summaries across folds;
        - for the adaptive model only, ``avg_gate_values``: per-channel gates
          averaged over all validation trials, then over folds.

    Raises:
        ValueError: If ``model_type`` is unknown.
    """
    C = x.shape[1]
    skf = StratifiedKFold(n_splits=config['model']['n_folds'], shuffle=True, random_state=42)

    batch_size = config['model']['batch_size']
    epochs = config['model']['epochs']
    lr = config['model']['learning_rate']
    patience = config['model']['patience']

    folds = []
    adjacencies = []
    gate_values_list = []
    all_histories = []
    total_training_time = 0

    for fold, (train_idx, val_idx) in enumerate(skf.split(x, y)):
        # Per-channel z-scoring fitted on this fold's training trials only.
        x_train_raw, x_val_raw = x[train_idx], x[val_idx]
        mu = x_train_raw.mean(axis=(0, 2), keepdims=True)
        sd = x_train_raw.std(axis=(0, 2), keepdims=True) + 1e-8
        X_train, X_val = (x_train_raw - mu) / sd, (x_val_raw - mu) / sd
        Y_train, Y_val = y[train_idx], y[val_idx]

        train_loader = DataLoader(
            EEGDataset(X_train, Y_train),
            batch_size=batch_size, shuffle=True, num_workers=0
        )
        val_loader = DataLoader(
            EEGDataset(X_val, Y_val),
            batch_size=batch_size, shuffle=False, num_workers=0
        )

        # Create model
        if model_type == 'baseline':
            model = BaselineEEGARNN(C, T, K, config['model']['hidden_dim']).to(device)
            l1_lambda = 0.0
        elif model_type == 'adaptive':
            model = AdaptiveGatedEEGARNN(C, T, K, config['model']['hidden_dim'],
                                        config['gating']['gate_init']).to(device)
            l1_lambda = config['gating']['l1_lambda']
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Train (train_model already restores the best checkpoint)
        best_state, _, history, train_time = train_model(
            model, train_loader, val_loader, device, epochs, lr, patience, l1_lambda
        )
        model.load_state_dict(best_state)

        # Final evaluation
        final_metrics = evaluate(model, val_loader, nn.CrossEntropyLoss(), device)
        final_metrics['fold'] = fold
        final_metrics['training_time'] = train_time
        final_metrics['epochs_trained'] = len(history['val'])

        folds.append(final_metrics)
        all_histories.append(history)
        total_training_time += train_time

        # Store adjacency
        adjacencies.append(model.get_final_adjacency())

        # Gates averaged over the whole validation fold. get_gate_values() only
        # reflects the most recent batch.
        if hasattr(model, 'compute_gates'):
            model.eval()
            with torch.no_grad():
                fold_gates = torch.cat(
                    [model.compute_gates(xb.to(device)).cpu() for xb, _ in val_loader]
                )
            gate_values_list.append(fold_gates.mean(dim=0).numpy())

    # Aggregate metrics
    avg_metrics = {}
    for key in ['accuracy', 'precision', 'recall', 'f1', 'kappa', 'roc_auc', 'pr_auc']:
        values = [f[key] for f in folds if f.get(key) is not None]
        if values:
            avg_metrics[f'avg_{key}'] = float(np.mean(values))
            avg_metrics[f'std_{key}'] = float(np.std(values))

    avg_adjacency = np.mean(np.stack([a for a in adjacencies if a is not None], 0), 0) \
                    if any(a is not None for a in adjacencies) else None

    result = {
        'fold_results': folds,
        'training_histories': all_histories,
        'total_training_time': total_training_time,
        'avg_training_time_per_fold': total_training_time / len(folds),
        'adjacency_matrix': avg_adjacency,
        'channel_names': channel_names,
        **avg_metrics
    }

    if gate_values_list:
        result['avg_gate_values'] = np.mean(np.stack(gate_values_list, 0), 0)

    return result


print("Cross-validation with full metrics defined!")

## 8. Channel Selection Functions

In [None]:
class ChannelSelector:
    """Select channels from trained model using different strategies.

    Args:
        adjacency: ``(C, C)`` channel adjacency matrix (array-like).
        channel_names: ``C`` channel names aligned with the adjacency rows.
        gate_values: Optional per-channel gate scores (adaptive model only),
            required by ``gate_selection``.

    Every strategy returns ``(names, indices)`` for the ``k`` highest-scoring
    channels. ``indices`` are sorted ascending (original channel order), and
    ``k`` is clipped to ``[0, C]``.
    """

    def __init__(self, adjacency, channel_names, gate_values=None):
        self.A = np.asarray(adjacency)
        self.names = np.array(channel_names)
        self.C = self.A.shape[0]
        self.gate_values = gate_values

    def _top_k(self, scores, k):
        """Return ``(names, ascending indices)`` of the ``k`` largest ``scores``."""
        order = np.argsort(scores)
        k = min(max(int(k), 0), len(order))
        # Slicing from len - k yields nothing for k == 0, whereas [-k:]
        # (i.e. [-0:]) would return every channel.
        indices = np.sort(order[len(order) - k:])
        return self.names[indices].tolist(), indices

    def edge_selection(self, k):
        """Edge Selection (ES).

        Every off-diagonal |A[i, j]| is credited to both endpoints i and j.
        So a channel's score is its off-diagonal row sum plus its column sum
        of |A|.
        """
        off_diag = np.abs(self.A)
        np.fill_diagonal(off_diag, 0)
        return self._top_k(off_diag.sum(axis=1) + off_diag.sum(axis=0), k)

    def aggregation_selection(self, k):
        """Aggregation Selection (AS): score = row sum of |A| (diagonal included)."""
        return self._top_k(np.sum(np.abs(self.A), 1), k)

    def gate_selection(self, k):
        """Gate Selection (GS) - for adaptive gating only; score = gate value.

        Raises:
            ValueError: If no gate values were provided.
        """
        if self.gate_values is None:
            raise ValueError("Gate values not available. Use ES or AS instead.")
        return self._top_k(np.asarray(self.gate_values), k)


def retrain_with_selected_channels(x, y, selected_indices, T, K, device, config, model_type='baseline'):
    """Retrain model with selected channels and return metrics.

    Fold setup is ``StratifiedKFold`` (``n_folds`` from config, shuffled,
    seed 42). In each fold:
    - every kept channel is z-scored with the mean/std of that fold's
      *training* trials, and the same statistics are applied to the
      validation trials;
    - a fresh model is trained with early stopping;
    - the restored best checkpoint is evaluated on the validation trials.

    Args:
        x: Trial array, assumed ``(n_trials, C, T)``.
        y: Integer class label per trial.
        selected_indices: Indices into axis 1 of ``x`` (channels) to keep.
        T: Time samples per trial.
        K: Number of classes.
        device: Torch device used for training.
        config: Experiment config (``'model'`` and ``'gating'`` sections).
        model_type: ``'baseline'`` or ``'adaptive'``.

    Returns:
        Dict of ``avg_<metric>`` / ``std_<metric>`` across folds for accuracy,
        precision, recall, f1, kappa, roc_auc and pr_auc. An AUC metric is
        omitted if it is undefined in every fold.

    Raises:
        ValueError: If ``model_type`` is unknown.
    """
    x_selected = x[:, selected_indices, :]
    C = len(selected_indices)

    skf = StratifiedKFold(n_splits=config['model']['n_folds'], shuffle=True, random_state=42)

    batch_size = config['model']['batch_size']
    epochs = config['model']['epochs']
    lr = config['model']['learning_rate']
    patience = config['model']['patience']

    fold_metrics = []

    for train_idx, val_idx in skf.split(x_selected, y):
        # Per-channel z-scoring fitted on this fold's training trials only.
        x_train_raw, x_val_raw = x_selected[train_idx], x_selected[val_idx]
        mu = x_train_raw.mean(axis=(0, 2), keepdims=True)
        sd = x_train_raw.std(axis=(0, 2), keepdims=True) + 1e-8
        X_train, X_val = (x_train_raw - mu) / sd, (x_val_raw - mu) / sd
        Y_train, Y_val = y[train_idx], y[val_idx]

        train_loader = DataLoader(
            EEGDataset(X_train, Y_train),
            batch_size=batch_size, shuffle=True, num_workers=0
        )
        val_loader = DataLoader(
            EEGDataset(X_val, Y_val),
            batch_size=batch_size, shuffle=False, num_workers=0
        )

        if model_type == 'baseline':
            model = BaselineEEGARNN(C, T, K, config['model']['hidden_dim']).to(device)
            l1_lambda = 0.0
        elif model_type == 'adaptive':
            model = AdaptiveGatedEEGARNN(C, T, K, config['model']['hidden_dim'],
                                        config['gating']['gate_init']).to(device)
            l1_lambda = config['gating']['l1_lambda']
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        best_state, _, _, _ = train_model(
            model, train_loader, val_loader, device, epochs, lr, patience, l1_lambda
        )
        model.load_state_dict(best_state)

        fold_metrics.append(evaluate(model, val_loader, nn.CrossEntropyLoss(), device))

    # Aggregate
    avg_metrics = {}
    for key in ['accuracy', 'precision', 'recall', 'f1', 'kappa', 'roc_auc', 'pr_auc']:
        values = [m[key] for m in fold_metrics if m.get(key) is not None]
        if values:
            avg_metrics[f'avg_{key}'] = float(np.mean(values))
            avg_metrics[f'std_{key}'] = float(np.std(values))

    return avg_metrics


print("Channel selection functions defined!")

## 9. Main Training Loop

### Training both methods on all subjects

In [None]:
# Storage for results: one summary dict per subject, per method.
all_results = {'baseline': [], 'adaptive': []}

print("\n" + "="*80)
print("MAIN TRAINING - BASELINE VS ADAPTIVE GATING")
print("="*80)
print(f"Training {len(subjects)} subjects with 2 methods...\n")

for subject_idx, subject_id in enumerate(tqdm(subjects, desc='Processing subjects')):
    print(f"\n[{subject_idx+1}/{len(subjects)}] Processing {subject_id}...")

    # Load subject data
    X, Y, channel_names = load_subject_data(
        CONFIG['data']['raw_data_dir'],
        subject_id,
        ALL_TASK_RUNS,
        CONFIG
    )

    if X is None or len(Y) == 0:
        print("  Skipped: No data available")
        continue

    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))

    print(f"  Data: {X.shape[0]} trials, {C} channels, {T} timepoints")
    print(f"  Class distribution: {np.bincount(Y)}")

    # Train both methods on identical data and folds
    for method in ['baseline', 'adaptive']:
        print(f"\n  Training {method.upper()}...")
        result = cross_validate_subject(X, Y, channel_names, T, K, device, CONFIG, method)

        print(f"    Accuracy: {result['avg_accuracy']:.4f} ± {result['std_accuracy']:.4f}")
        print(f"    F1-Score: {result['avg_f1']:.4f} ± {result['std_f1']:.4f}")
        # Compare with None: an AUC of 0.0 is a valid value and must be reported.
        if result.get('avg_roc_auc') is not None:
            print(f"    ROC-AUC: {result['avg_roc_auc']:.4f} ± {result['std_roc_auc']:.4f}")
        print(f"    Training time: {result['avg_training_time_per_fold']:.1f}s per fold")

        # Store results
        all_results[method].append({
            'subject': subject_id,
            'num_trials': X.shape[0],
            'num_channels': C,
            **{k: v for k, v in result.items() if k.startswith('avg_') or k.startswith('std_')},
            'adjacency_matrix': result['adjacency_matrix'],
            'channel_names': result['channel_names'],
            'gate_values': result.get('avg_gate_values', None),
            'training_time_per_fold': result['avg_training_time_per_fold'],
            'total_training_time': result['total_training_time'],
            'fold_results': result['fold_results'],
            'training_histories': result['training_histories']
        })

print("\n" + "="*80)
print("INITIAL TRAINING COMPLETE!")
print("="*80)

## 10. Channel Selection and Retraining

In [None]:
# Storage for retraining results: one row per (subject, selection method, k).
retrain_results = {'baseline': [], 'adaptive': []}

print("\n" + "="*80)
print("CHANNEL SELECTION AND RETRAINING")
print("="*80)
print(f"Testing {len(CONFIG['channel_selection']['k_values'])} different k values...\n")

# NOTE(review): channels are ranked using models trained on the same subject
# and the same CV splits on which the retrained models are scored. The
# reduced-channel accuracies therefore carry a selection bias.

for subject_idx, subject_id in enumerate(tqdm(subjects, desc='Retraining subjects')):
    print(f"\n[{subject_idx+1}/{len(subjects)}] Retraining {subject_id}...")

    # Load subject data
    X, Y, channel_names = load_subject_data(
        CONFIG['data']['raw_data_dir'],
        subject_id,
        ALL_TASK_RUNS,
        CONFIG
    )

    if X is None:
        continue

    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))

    # Process both methods
    for method in ['baseline', 'adaptive']:
        # Full-channel result for this subject from the main training loop
        subj_result = next(
            (res for res in all_results[method] if res['subject'] == subject_id),
            None,
        )
        if subj_result is None:
            continue

        selector = ChannelSelector(
            subj_result['adjacency_matrix'],
            channel_names,
            subj_result.get('gate_values', None),
        )

        # Selection strategies; gate selection (GS) only exists for the
        # adaptive model, which is the only one with gate values.
        selectors = {
            'ES': selector.edge_selection,
            'AS': selector.aggregation_selection,
        }
        if method != 'baseline':
            selectors['GS'] = selector.gate_selection

        full_acc = subj_result['avg_accuracy']

        # Test different k values
        for sel_method, select in selectors.items():
            for k in CONFIG['channel_selection']['k_values']:
                selected_channels, selected_indices = select(k)

                # Retrain
                retrain_metrics = retrain_with_selected_channels(
                    X, Y, selected_indices, T, K, device, CONFIG, method
                )

                # Compute drops; the relative drop is undefined at 0.0 accuracy.
                acc_drop = full_acc - retrain_metrics['avg_accuracy']
                acc_drop_pct = acc_drop / full_acc * 100 if full_acc else float('nan')

                retrain_results[method].append({
                    'subject': subject_id,
                    'method': sel_method,
                    'k': k,
                    'num_channels_selected': len(selected_channels),
                    'selected_channels': selected_channels,
                    **retrain_metrics,
                    'full_channels_accuracy': full_acc,
                    'accuracy_drop': acc_drop,
                    'accuracy_drop_pct': acc_drop_pct
                })

                print(f"  {method.upper()}-{sel_method}, k={k}: "
                      f"{retrain_metrics['avg_accuracy']:.4f} (drop: {acc_drop:.4f})")

print("\n" + "="*80)
print("RETRAINING COMPLETE!")
print("="*80)

## 11. Save All Results

In [None]:
import pickle

print("\nSaving results...")

# Save main results: a summary CSV per method plus a pickle holding
# everything (fold results, training histories, adjacency/gate arrays).
for method in ['baseline', 'adaptive']:
    if not all_results[method]:
        # No subject produced results for this method; there is nothing to save.
        print(f"  No results for {method}; nothing saved.")
        continue

    # Summary columns: the union of avg_*/std_* keys over all subjects, in
    # first-seen order, so a metric missing for the first subject is kept.
    metric_cols = list(dict.fromkeys(
        key
        for res in all_results[method]
        for key in res
        if key.startswith('avg_') or key.startswith('std_')
    ))
    df_cols = ['subject', 'num_trials', 'num_channels'] + metric_cols
    df_cols += ['training_time_per_fold', 'total_training_time']

    df = pd.DataFrame([{k: res[k] for k in df_cols if k in res} for res in all_results[method]])
    df.to_csv(CONFIG['output']['results_dir'] / 'tables' / f'{method}_results.csv', index=False)

    # Save full results with histories
    with open(CONFIG['output']['results_dir'] / f'{method}_full_results.pkl', 'wb') as f:
        pickle.dump(all_results[method], f)

# Save retrain results (an empty list yields an empty CSV)
for method in ['baseline', 'adaptive']:
    df = pd.DataFrame(retrain_results[method])
    df.to_csv(CONFIG['output']['results_dir'] / 'tables' / f'{method}_retrain_results.csv', index=False)

print(f"Results saved to: {CONFIG['output']['results_dir']}")

## 12. Comprehensive Visualizations

### Figure 1: Overall Performance Comparison

In [None]:
# TO BE CONTINUED IN NEXT CELL...