# Pipeline 2: EEG-ARNN Methods - FINAL VERSION

**COMPLETE TRAINING:** 30 epochs, NO early stopping, Comprehensive metrics

## Models
1. **Baseline-EEG-ARNN** - Pure CNN-GCN architecture
2. **Adaptive-Gating-EEG-ARNN** - With data-dependent channel gating

## Experiments
### Part 1: Train Both Models
- 2-fold cross-validation (matching Pipeline 1)
- 30 epochs WITHOUT early stopping (full training); the epoch with the best validation accuracy is restored as each fold's final checkpoint
- Comprehensive metrics: Accuracy, Precision, Recall, F1, AUC-ROC, Specificity, Sensitivity

### Part 2: Channel Selection
**Baseline**: 2 methods × 5 k-values = 10 experiments
- Edge Selection (ES)
- Aggregation Selection (AS)

**Adaptive**: 3 methods × 5 k-values = 15 experiments
- Edge Selection (ES)
- Aggregation Selection (AS)
- Gate Selection (GS)

k-values: [10, 15, 20, 25, 30]

## Configuration
- **Dataset:** `/kaggle/input/eeg-preprocessed-data/derived`
- **Epochs:** 30 (NO early stopping)
- **Cross-validation:** 2-fold
- **Learning rate:** 0.0015
- **Batch size:** 64

## Expected Runtime: ~6-7 hours on Kaggle GPU
- Initial training: ~1 hour (2 models × 2 folds × 30 epochs)
- Channel selection: ~5-6 hours (25 experiments × 2 folds)

## Outputs (Matching Pipeline 1 format)
```
results/eegarnn_baseline_eeg_arnn_results.csv         - Per-fold results with ALL metrics
results/eegarnn_adaptive_gating_eeg_arnn_results.csv  - Per-fold results with ALL metrics
results/eegarnn_initial_summary.csv                   - Summary statistics (both models, ranked by mean accuracy)
results/channel_selection_results.csv         - All selection methods
results/channel_selection_summary.csv         - Summary by method
results/training_histories.pkl                - Training curves
plots/training_curves.png                     - Training visualization
plots/model_comparison.png                    - Metrics comparison
plots/channel_selection_curves.png            - Retention curves
models/eegarnn_*.pt                           - Model checkpoints
```

## 1. Setup and Configuration

In [None]:
import os
import gc
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
import pickle
from copy import deepcopy
import warnings
# Silence all Python warnings and MNE's sub-ERROR log output to keep notebook output short.
# NOTE(review): the blanket 'ignore' filter also hides deprecation warnings from
# torch / sklearn / mne; consider narrowing it when upgrading library versions.
warnings.filterwarnings('ignore')
mne.set_log_level('ERROR')

# Set plotting style (applied first; the seaborn palette below then overrides the colour cycle)
# NOTE(review): 'seaborn-v0_8-*' style names assume a recent matplotlib — confirm runtime version.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

print("All imports successful!")

In [None]:
# Configuration - COMPLETE TRAINING WITHOUT EARLY STOPPING
CONFIG = {
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    'models_dir': './models',
    'results_dir': './results',
    'plots_dir': './plots',

    'n_folds': 2,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Training hyperparameters - NO EARLY STOPPING
    'batch_size': 64,
    'epochs': 30,  # every run trains for exactly this many epochs
    'learning_rate': 0.0015,
    'weight_decay': 1e-4,
    'scheduler_patience': 3,   # ReduceLROnPlateau patience (epochs without val-loss improvement)
    'scheduler_factor': 0.5,   # ReduceLROnPlateau LR multiplier
    'use_early_stopping': False,  # DISABLED - train for full epochs (informational flag)
    'min_lr': 1e-6,

    # Data parameters (matching Pipeline 1)
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'n_timepoints': 513,  # (tmax - tmin) * sfreq + 1 samples per epoch
    'hidden_dim': 64,
    'mi_runs': [7, 8, 11, 12],

    # Gating parameters
    'gating': {
        'gate_init': 0.9,   # initial gate value; the final gate-layer bias is set to its logit
        'l1_lambda': 1e-3,  # weight of the L1 gate penalty added to the training loss
    },

    # Channel selection k-values (STANDARDIZED - matching Pipeline 1)
    'k_values': [10, 15, 20, 25, 30],
}

os.makedirs(CONFIG['models_dir'], exist_ok=True)
os.makedirs(CONFIG['results_dir'], exist_ok=True)
os.makedirs(CONFIG['plots_dir'], exist_ok=True)

np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"Device: {CONFIG['device']}")
print(f"Epochs: {CONFIG['epochs']} (NO EARLY STOPPING - FULL TRAINING)")
print(f"Folds: {CONFIG['n_folds']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"K-values: {CONFIG['k_values']}")

# Runtime estimates
MINUTES_PER_RUN = 7  # rough wall-clock estimate for one training run
initial_runs = 2 * CONFIG['n_folds']  # 2 models × n_folds
# Baseline: 2 methods, Adaptive: 3 methods; each method runs every k-value on every fold
cs_runs = (2 + 3) * len(CONFIG['k_values']) * CONFIG['n_folds']
total_runs = initial_runs + cs_runs

print(f"\nEstimated training runs:")
print(f"  Initial: {initial_runs} runs × {CONFIG['epochs']} epochs")
print(f"  Channel selection: {cs_runs} runs × {CONFIG['epochs']} epochs")
print(f"  TOTAL: {total_runs} runs")
print(f"\nEstimated runtime (~{MINUTES_PER_RUN} min/run): {total_runs * MINUTES_PER_RUN / 60:.1f} hours")

## 2. Data Loading (EXACT MATCH with Pipeline 1)

In [None]:
def load_physionet_data(data_path):
    """Load preprocessed PhysioNet motor-imagery epochs - MATCHING PIPELINE 1.

    Walks the subject directories (names starting with 'S', case-insensitive) under
    ``data_path`` (or its ``preprocessed/`` subdirectory when present), reads each
    ``{subject}R{run:02d}_preproc_raw.fif`` file for the runs in ``CONFIG['mi_runs']``,
    epochs the T1/T2 annotations over ``[CONFIG['tmin'], CONFIG['tmax']]`` and maps
    T1 -> label 0, T2 -> label 1. Runs that are missing or fail to load are skipped.

    Args:
        data_path: Root directory of the preprocessed data.

    Returns:
        (X, y, subjects): X is (n_trials, n_channels, n_times), y holds the 0/1 labels
        and ``subjects`` the numeric subject id of every trial.

    Raises:
        FileNotFoundError: If ``data_path`` is not a directory.
        ValueError: If no trials could be loaded, or runs produce epochs of different shapes.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    tmin, tmax = CONFIG['tmin'], CONFIG['tmax']
    mi_runs = CONFIG['mi_runs']
    event_id = {'T1': 1, 'T2': 2}
    label_map = {1: 0, 2: 1}

    # Prefer a 'preprocessed' subdirectory when present
    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        data_root = preprocessed_dir
        print(f"Using preprocessed data from: {data_root}")
    else:
        print(f"Using data from: {data_root}")

    subject_dirs = [d for d in sorted(os.listdir(data_root))
                    if os.path.isdir(os.path.join(data_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []
    print(f"\nLoading data from {len(subject_dirs)} subjects...")

    for subject_dir in subject_dirs:
        # Subject folders look like 'S001'. A directory such as 'Scripts' also starts
        # with 'S' but is not a subject: skip it instead of aborting the whole load.
        suffix = subject_dir[1:]
        if not suffix:
            subject_num = -1
        else:
            try:
                subject_num = int(suffix)
            except ValueError:
                print(f"  Warning: Skipping non-subject directory {subject_dir}")
                continue
        subject_path = os.path.join(data_root, subject_dir)

        for run_id in mi_runs:
            run_file = f"{subject_dir}R{run_id:02d}_preproc_raw.fif"
            run_path = os.path.join(subject_path, run_file)

            if not os.path.exists(run_path):
                continue

            try:
                # Load preprocessed data
                raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)

                # Pick EEG channels only
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue

                # Extract T1/T2 events (other annotations are ignored)
                events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False)
                if len(events) == 0:
                    continue

                # Create epochs
                epochs = mne.Epochs(
                    raw, events, event_id=event_id,
                    tmin=tmin, tmax=tmax,
                    baseline=None,  # No baseline - data already preprocessed
                    preload=True,
                    picks=picks,
                    verbose=False
                )

                # Get data and labels; labels come from epochs.events so they stay
                # aligned with any epochs MNE dropped.
                data = epochs.get_data()  # (n_epochs, n_channels, n_times)
                labels = np.array([label_map.get(epochs.events[i, 2], -1) for i in range(len(epochs))])
                valid = labels >= 0

                if np.any(valid):
                    all_X.append(data[valid])
                    all_y.append(labels[valid])
                    all_subjects.append(np.full(np.sum(valid), subject_num))

            except Exception as e:
                # Best-effort loading: report and move on to the next run.
                print(f"  Warning: Failed to load {run_file}: {e}")
                continue

    if not all_X:
        raise ValueError("No data loaded! Check data path and file format.")

    # np.concatenate would fail with an opaque message if runs differ in channel count
    # or number of samples (e.g. a different sampling rate); report it explicitly.
    epoch_shapes = {arr.shape[1:] for arr in all_X}
    if len(epoch_shapes) > 1:
        raise ValueError(
            f"Inconsistent epoch shapes across runs: {sorted(epoch_shapes)}. "
            "Check that all recordings share the same channels and sampling rate."
        )

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subjects = np.concatenate(all_subjects, axis=0)

    # minlength=2 keeps indexing safe even if one class is entirely absent
    class_counts = np.bincount(y, minlength=2)

    print(f"\nData loaded successfully:")
    print(f"  Total trials: {len(X)}")
    print(f"  Unique subjects: {len(np.unique(subjects))}")
    print(f"  Data shape: {X.shape} (trials, channels, timepoints)")
    print(f"  Label distribution: {dict(zip(*np.unique(y, return_counts=True)))}")
    print(f"  Class balance: {class_counts[0]}/{class_counts[1]} (class 0/class 1)")

    return X, y, subjects


class EEGDataset(Dataset):
    """Map-style dataset that yields ``(trial, label)`` tensor pairs.

    Both arrays are converted once, up front: trials to float32 and labels to int64,
    so indexing is a cheap tensor slice.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X, self.y = features, targets

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        return sample, label

## 3. Model Architectures

In [None]:
# Graph Convolution Layer
class GraphConvLayer(nn.Module):
    """Graph convolution over EEG channels with a learned adjacency matrix.

    The raw parameter ``A`` is squashed with a sigmoid and symmetrized to give edge
    weights in (0, 1). Self-loops are added and the result is symmetrically normalized,
    D^-1/2 (A + I) D^-1/2, before being applied along the channel axis. A shared linear
    map ``theta`` then mixes the hidden features, followed by BatchNorm and ELU.
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim

        # Learnable adjacency logits, initialized near 0 (=> sigmoid weights near 0.5)
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _edge_weights(self):
        """Return the symmetric sigmoid edge-weight matrix (without self-loops)."""
        A = torch.sigmoid(self.A)
        return 0.5 * (A + A.t())

    def forward(self, x):
        """Forward pass with graph convolution.

        Args:
            x: Input tensor of shape (B, H, C, T)

        Returns:
            Output tensor of shape (B, H, C, T)
        """
        B, H, C, T = x.shape

        # Normalize adjacency matrix: D^-1/2 (A + I) D^-1/2
        A = self._edge_weights()
        I = torch.eye(C, device=A.device)
        A_hat = A + I
        D = torch.diag(torch.pow(A_hat.sum(1).clamp_min(1e-6), -0.5))
        A_norm = D @ A_hat @ D

        # Apply graph convolution: treat every (batch, time) pair as a C x H node-feature matrix
        x_perm = x.permute(0, 3, 2, 1).contiguous().view(B * T, C, H)
        x_g = A_norm @ x_perm
        x_g = self.theta(x_g)
        x_g = x_g.view(B, T, C, H).permute(0, 3, 2, 1)

        return self.act(self.bn(x_g))

    def get_adjacency(self):
        """Return the learned symmetric edge weights (no self-loops) as a numpy array."""
        with torch.no_grad():
            return self._edge_weights().cpu().numpy()


# Temporal Convolution
class TemporalConv(nn.Module):
    """Convolution along the time axis only, then BatchNorm, ELU and optional pooling.

    The (1, kernel_size) kernel never mixes EEG channels. With padding of
    kernel_size // 2, an even kernel lengthens the sequence by one step (T -> T + 1)
    before the optional (1, 2) average pool halves it.
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        time_padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, time_padding),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pool_layer = None

    def forward(self, x):
        out = self.act(self.bn(self.conv(x)))
        return out if self.pool_layer is None else self.pool_layer(out)


print("Basic layers defined!")

In [None]:
# Baseline EEG-ARNN (without gating)
class BaselineEEGARNN(nn.Module):
    """Baseline EEG-ARNN model without gating mechanism.

    Three TemporalConv -> GraphConvLayer stages extract features from a ``(B, C, T)``
    input (a ``(B, 1, C, T)`` input is also accepted). The features are flattened and
    classified by a two-layer MLP head.

    Args:
        n_channels: Number of EEG channels (graph nodes).
        n_classes: Number of output classes.
        n_timepoints: Samples per trial; only used to size the classifier input.
        hidden_dim: Feature maps per temporal/graph layer.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.use_gate_regularizer = False  # train_model_full reads this to enable the L1 gate penalty

        # Temporal-Graph layers. Keep this creation order: it fixes the sequence of random
        # draws used for weight initialization, and the state_dict key names.
        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        # Calculate the flattened feature dimension with a dummy pass. Run it in eval mode:
        # in train mode the all-zero probe would update every BatchNorm's running stats
        # (num_batches_tracked += 1, running_var pulled towards 0).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, n_channels, n_timepoints)
                feat = self._forward_features(self._prepare_input(dummy))
                self.feature_dim = feat.view(1, -1).size(1)
        finally:
            self.train(was_training)

        # Classifier
        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, n_classes)

    def _prepare_input(self, x):
        """Prepare input by adding channel dimension if needed."""
        if x.dim() == 3:
            x = x.unsqueeze(1)  # (B, C, T) -> (B, 1, C, T)
        return x

    def _forward_features(self, x):
        """Extract features through temporal-graph layers."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        """Forward pass: features -> flatten -> fc1 + ReLU -> dropout -> fc2 (logits)."""
        prepared = self._prepare_input(x)
        features = self._forward_features(prepared)
        x = features.view(features.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def get_final_adjacency(self):
        """Get the final graph layer's symmetric edge-weight matrix (numpy)."""
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        """Edge Selection score: per-channel sum of final-layer edge weights."""
        adjacency = self.get_final_adjacency()
        return np.sum(np.abs(adjacency), axis=1)


print("Baseline EEG-ARNN defined!")

In [None]:
# Adaptive Gating EEG-ARNN
class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """EEG-ARNN with adaptive, data-dependent channel gating.

    A small MLP maps per-channel summary statistics (mean and std over time) of each
    input trial to one gate in (0, 1) per channel. The input is scaled channel-wise by
    these gates before the shared temporal-graph feature extractor. The most recent
    gates are cached for the L1 penalty (``gate_penalty_tensor``, with gradient) and for
    Gate Selection (``latest_gate_values``, detached).
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True

        stats_dim = 2 * n_channels  # per-channel mean and std, concatenated
        self.gate_net = nn.Sequential(
            nn.Linear(stats_dim, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid(),
        )

        # Bias the last Linear to logit(gate_init) so gates start high (most channels open)
        clipped = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        logit = math.log(clipped / (1.0 - clipped))
        output_linear = self.gate_net[2]
        with torch.no_grad():
            output_linear.bias.fill_(logit)

        self.latest_gate_values = None
        self.gate_penalty_tensor = None

    def compute_gates(self, x):
        """Compute data-dependent channel gates.

        Args:
            x: Input tensor of shape (B, 1, C, T)

        Returns:
            Gate values of shape (B, C) in range [0, 1]
        """
        signals = x.squeeze(1)  # (B, C, T)
        summary = torch.cat((signals.mean(dim=2), signals.std(dim=2)), dim=1)  # (B, 2C)
        return self.gate_net(summary)

    def forward(self, x):
        """Gate the input channel-wise, then run the baseline feature extractor and head."""
        prepared = self._prepare_input(x)

        gates = self.compute_gates(prepared)
        self.gate_penalty_tensor = gates          # keeps the graph for the L1 penalty
        self.latest_gate_values = gates.detach()  # used by Gate Selection

        batch_size, channel_count = gates.shape
        gated = prepared * gates.view(batch_size, 1, channel_count, 1)

        features = self._forward_features(gated)
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

    def get_channel_importance_gate(self):
        """Mean gate per channel from the most recent forward pass, or None before any pass."""
        cached = self.latest_gate_values
        return None if cached is None else cached.mean(dim=0).cpu().numpy()


print("Adaptive Gating EEG-ARNN defined!")

## 4. Training and Evaluation Utilities

In [None]:
def calculate_comprehensive_metrics(model, dataloader, device):
    """Calculate ALL metrics - matching Pipeline 1.

    Args:
        model: Classifier returning logits of shape (B, 2).
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the inputs are moved to.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1_score', 'auc_roc',
        'specificity' and 'sensitivity'. Class 1 is the positive class. 'auc_roc' is 0.0
        when only one class is present in the labels. Ratios with a zero denominator are 0.0.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch = X_batch.to(device)
            outputs = model(X_batch)
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(y_batch.numpy())
            all_probs.extend(probs.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Calculate metrics
    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='binary', zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='binary', zero_division=0),
        'f1_score': f1_score(all_labels, all_preds, average='binary', zero_division=0),
        'auc_roc': roc_auc_score(all_labels, all_probs[:, 1]) if len(np.unique(all_labels)) == 2 else 0.0,
    }

    # Pin labels=[0, 1] so the confusion matrix is always 2x2. Without it, a fold whose
    # labels and predictions contain a single class yields a 1x1 matrix and loses the
    # real specificity/sensitivity values.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    return metrics


def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Run one optimization pass over ``dataloader``.

    When ``l1_lambda > 0`` and the model exposes a non-None ``gate_penalty_tensor``
    after its forward pass, ``l1_lambda * mean(|gates|)`` is added to each batch loss.

    Returns:
        (mean_batch_loss, accuracy): the (penalized) loss averaged over batches, each
        batch weighted equally, and the fraction of correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        penalty_source = getattr(model, 'gate_penalty_tensor', None)
        if penalty_source is not None and l1_lambda > 0:
            batch_loss = batch_loss + l1_lambda * penalty_source.abs().mean()

        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predictions = logits.max(dim=1).indices
        n_seen += targets.size(0)
        n_correct += (predictions == targets).sum().item()

    return loss_sum / len(dataloader), n_correct / n_seen


def evaluate_epoch(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Returns:
        (mean_batch_loss, accuracy): the loss averaged over batches, each batch
        weighted equally, and the fraction of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)

            loss_sum += criterion(logits, targets).item()
            predictions = logits.max(dim=1).indices
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    return loss_sum / len(dataloader), n_correct / n_seen


def train_model_full(model, train_loader, val_loader, config, model_name=''):
    """Train model for FULL epochs without early stopping.

    Uses Adam and a ReduceLROnPlateau scheduler stepped on validation loss. After every
    epoch the weights are snapshotted if validation accuracy improved, with ties broken
    by lower validation loss. The best snapshot is loaded back into ``model`` before
    returning.

    NOTE(review): the "best" epoch and the LR schedule are both driven by ``val_loader``.
    If that same split is later used to report final metrics, those metrics are
    optimistically biased.

    Args:
        model: Model to train (moved to ``config['device']``).
        train_loader: Training batches.
        val_loader: Validation batches.
        config: Dict providing 'device', 'learning_rate', 'weight_decay',
            'scheduler_factor', 'scheduler_patience', 'min_lr', 'epochs' and
            'gating' -> 'l1_lambda'.
        model_name: Label used in progress messages.

    Returns:
        (best_state, best_val_acc, history); ``history`` holds per-epoch lists under
        'train_loss', 'train_acc', 'val_loss' and 'val_acc'.
    """
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    # `verbose` is intentionally not passed: it is deprecated for LR schedulers in recent
    # PyTorch releases, and its old default (False) already matched the intent here.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config['scheduler_factor'],
        patience=config['scheduler_patience'],
        min_lr=config['min_lr'],
    )

    # The L1 gate penalty only applies to models that opt in (the gating variant)
    l1_lambda = config['gating']['l1_lambda'] if getattr(model, 'use_gate_regularizer', False) else 0.0

    best_state = deepcopy(model.state_dict())
    best_val_acc = 0.0
    best_val_loss = float('inf')

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    print(f"[{model_name}] Training for {config['epochs']} epochs (NO early stopping)")

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Save best model (higher val accuracy; ties -> lower val loss)
        if val_acc > best_val_acc or (val_acc == best_val_acc and val_loss < best_val_loss):
            best_state = deepcopy(model.state_dict())
            best_val_acc = val_acc
            best_val_loss = val_loss

        # Print progress on the first epoch and every 5th epoch
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"  Epoch {epoch+1:2d}/{config['epochs']} - "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | "
                  f"Best: {best_val_acc:.4f}")

    # Load best model
    model.load_state_dict(best_state)

    return best_state, best_val_acc, history


print("Training utilities defined!")

## 5. Channel Selection Utilities

In [None]:
def get_channel_importance_aggregation(model, dataloader, device):
    """Aggregation Selection: score channels by their mean absolute feature activation.

    Every batch is passed through the model's temporal-graph feature extractor. For each
    sample, |features| is averaged over the hidden and time axes, giving one value per
    channel. The returned score is the mean of those per-sample values over all samples.

    If the model defines ``compute_gates`` (the adaptive-gating variant), the input is
    gated exactly as in that model's ``forward`` before feature extraction. The scored
    activations are therefore the ones the trained model actually produces.

    Args:
        model: Model providing ``_prepare_input``, ``_forward_features`` and ``n_channels``.
        dataloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device the inputs are moved to.

    Returns:
        np.ndarray of shape (n_channels,); all zeros when ``dataloader`` yields no batches.
    """
    model.eval()
    compute_gates = getattr(model, 'compute_gates', None)
    channel_stats = []

    with torch.no_grad():
        for X_batch, _ in dataloader:
            prepared = model._prepare_input(X_batch.to(device))
            if compute_gates is not None:
                gates = compute_gates(prepared)  # (B, C)
                prepared = prepared * gates.view(gates.size(0), 1, gates.size(1), 1)
            features = model._forward_features(prepared)  # expected (B, H, C, T)
            # Mean |activation| over hidden (dim 1) and time (dim 3) -> (B, C)
            activations = torch.mean(torch.abs(features), dim=(1, 3))
            channel_stats.append(activations.cpu())

    if not channel_stats:
        return np.zeros(model.n_channels)

    stacked = torch.cat(channel_stats, dim=0)  # (total_samples, C)
    return stacked.mean(dim=0).numpy()


def compute_gate_importance(model, dataloader, device):
    """Gate Selection: mean gate value per channel over every sample in ``dataloader``.

    Each batch is run through ``model`` purely for its side effect of refreshing
    ``model.latest_gate_values``. Models without that attribute (or that leave it None)
    contribute nothing; if nothing at all is collected, a uniform score of
    ``1 / n_channels`` per channel is returned.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for inputs, _unused_labels in dataloader:
            model(inputs.to(device))
            gate_values = getattr(model, 'latest_gate_values', None)
            if gate_values is None:
                continue
            collected.append(gate_values.cpu())

    if collected:
        return torch.cat(collected, dim=0).mean(dim=0).numpy()
    return np.ones(model.n_channels) / model.n_channels


def select_top_k_channels(importance_scores, k):
    """Select the ``k`` highest-scoring channels, returned as sorted integer indices.

    Args:
        importance_scores: 1-D sequence of per-channel scores.
        k: Number of channels to keep. ``k == 0`` returns an empty list; ``k`` larger
            than the number of channels returns every channel.

    Returns:
        list[int] of selected channel indices in ascending order.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # Guard: the slice [-0:] would otherwise select *all* channels.
        return []
    # Stable sort resolves ties deterministically (among equal scores, higher indices win).
    order = np.argsort(np.asarray(importance_scores), kind='stable')
    return sorted(int(idx) for idx in order[-k:])


def apply_channel_selection(X, selected_channels):
    """Keep only ``selected_channels`` along axis 1 of a (trials, channels, time) array."""
    channel_index = (slice(None), selected_channels, slice(None))
    return X[channel_index]


print("Channel selection utilities defined!")

## 6. Load Data

In [None]:
print("=" * 80, "LOADING DATA", "=" * 80, sep="\n")

# Load every trial once; X, y and subjects are reused by the training cells below.
X, y, subjects = load_physionet_data(CONFIG['data_path'])

print("\nData ready for training!")

## 7. Train EEG-ARNN Models (FULL TRAINING)

In [None]:
# Setup cross-validation
# NOTE(review): folds are stratified over individual trials, not subjects (`subjects`
# is not passed to the splitter), so the same subject's trials can land in both the
# training and validation folds.
skf = StratifiedKFold(
    n_splits=CONFIG['n_folds'],
    shuffle=True,
    random_state=CONFIG['random_seed'],
)

models_to_train = [
    {'name': 'Baseline-EEG-ARNN', 'class': BaselineEEGARNN},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'class': AdaptiveGatingEEGARNN},
]

print("\n" + "=" * 80, "TRAINING EEG-ARNN MODELS", "=" * 80, sep="\n")
print(
    f"Configuration: {CONFIG['epochs']} epochs, NO early stopping",
    f"Cross-validation: {CONFIG['n_folds']} folds",
    "=" * 80 + "\n",
    sep="\n",
)

In [None]:
# Training loop with comprehensive metrics
# For each model in models_to_train: run the shared stratified k-fold split, build and
# train a fresh model per fold with train_model_full, evaluate the restored best-epoch
# weights on that fold's validation split, and save the weights to
# CONFIG['models_dir'] as eegarnn_<model name>_fold<k>.pt (the channel-selection
# experiments reload checkpoints by this same name pattern).
# NOTE(review): the checkpoint is selected on the same validation fold the reported
# metrics are computed on, so these numbers are likely optimistic.
all_results = {}    # model name -> list of per-fold metric dicts
all_histories = {}  # model name -> list of per-fold training histories

for model_info in models_to_train:
    model_name = model_info['name']
    model_class = model_info['class']
    
    print(f"\n{'='*80}")
    print(f"Training: {model_name}")
    print(f"{'='*80}\n")
    
    fold_results = []
    fold_histories = []
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"\nFold {fold + 1}/{CONFIG['n_folds']}")
        print("-" * 80)
        
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        
        print(f"Train: {len(X_train)} samples, Val: {len(X_val)} samples")
        
        # Create dataloaders (only the training loader is shuffled)
        train_dataset = EEGDataset(X_train, y_train)
        val_dataset = EEGDataset(X_val, y_val)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
        
        # Build model (only the gating variant accepts gate_init)
        if model_class == AdaptiveGatingEEGARNN:
            model = model_class(
                n_channels=CONFIG['n_channels'],
                n_classes=CONFIG['n_classes'],
                n_timepoints=CONFIG['n_timepoints'],
                hidden_dim=CONFIG['hidden_dim'],
                gate_init=CONFIG['gating']['gate_init']
            )
        else:
            model = model_class(
                n_channels=CONFIG['n_channels'],
                n_classes=CONFIG['n_classes'],
                n_timepoints=CONFIG['n_timepoints'],
                hidden_dim=CONFIG['hidden_dim']
            )
        
        # Train model
        best_state, val_acc, history = train_model_full(
            model, train_loader, val_loader, CONFIG, 
            f"{model_name}-Fold{fold+1}"
        )
        
        # Get comprehensive metrics
        # (train_model_full has already loaded best_state; reloading here is harmless)
        model.load_state_dict(best_state)
        model = model.to(CONFIG['device'])
        metrics = calculate_comprehensive_metrics(model, val_loader, CONFIG['device'])
        
        print(f"\nFold {fold + 1} Final Metrics:")
        print(f"  Accuracy:    {metrics['accuracy']:.4f}")
        print(f"  Precision:   {metrics['precision']:.4f}")
        print(f"  Recall:      {metrics['recall']:.4f}")
        print(f"  F1-Score:    {metrics['f1_score']:.4f}")
        print(f"  AUC-ROC:     {metrics['auc_roc']:.4f}")
        print(f"  Specificity: {metrics['specificity']:.4f}")
        print(f"  Sensitivity: {metrics['sensitivity']:.4f}")
        
        # Save model
        model_path = os.path.join(CONFIG['models_dir'], f"eegarnn_{model_name}_fold{fold+1}.pt")
        torch.save(best_state, model_path)
        print(f"\nModel saved: {model_path}")
        
        # Store results
        fold_results.append({
            'fold': fold + 1,
            'accuracy': metrics['accuracy'],
            'precision': metrics['precision'],
            'recall': metrics['recall'],
            'f1_score': metrics['f1_score'],
            'auc_roc': metrics['auc_roc'],
            'specificity': metrics['specificity'],
            'sensitivity': metrics['sensitivity']
        })
        fold_histories.append(history)
        
        # Cleanup: release the model and cached GPU memory before the next fold
        del model
        torch.cuda.empty_cache()
        gc.collect()
    
    # Store results
    all_results[model_name] = fold_results
    all_histories[model_name] = fold_histories
    
    # Print model summary (pandas .std() is the sample std, ddof=1)
    df_temp = pd.DataFrame(fold_results)
    print(f"\n{'='*80}")
    print(f"{model_name} - Summary Across Folds")
    print("="*80)
    for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']:
        mean_val = df_temp[metric].mean()
        std_val = df_temp[metric].std()
        print(f"  {metric:15s}: {mean_val:.4f} ± {std_val:.4f}")
    print("="*80)

print(f"\n{'='*80}")
print("INITIAL TRAINING COMPLETE!")
print("="*80)

## 8. Save Initial Results

In [None]:
# Save detailed results (per-fold)
# One CSV per model: eegarnn_<model name lower-cased, '-' and ' ' -> '_'>_results.csv,
# e.g. 'Baseline-EEG-ARNN' -> eegarnn_baseline_eeg_arnn_results.csv
for model_name, fold_results in all_results.items():
    df = pd.DataFrame(fold_results)
    df['model'] = model_name
    
    # Reorder columns
    cols = ['model', 'fold', 'accuracy', 'precision', 'recall', 'f1_score', 
            'auc_roc', 'specificity', 'sensitivity']
    df = df[cols]
    
    filename = model_name.lower().replace('-', '_').replace(' ', '_')
    filepath = os.path.join(CONFIG['results_dir'], f'eegarnn_{filename}_results.csv')
    df.to_csv(filepath, index=False)
    print(f"Saved: {filepath}")

# Save summary statistics
# A single table (eegarnn_initial_summary.csv) with mean/std of every metric per model,
# ranked by mean accuracy (rank 1 = best).
summary_data = []
for model_name, fold_results in all_results.items():
    df_temp = pd.DataFrame(fold_results)
    summary = {'model': model_name}
    
    for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']:
        summary[f'mean_{metric}'] = df_temp[metric].mean()
        summary[f'std_{metric}'] = df_temp[metric].std()
    
    summary_data.append(summary)

summary_df = pd.DataFrame(summary_data)
summary_df = summary_df.sort_values('mean_accuracy', ascending=False).reset_index(drop=True)
summary_df['rank'] = range(1, len(summary_df) + 1)

# Reorder columns
cols = ['rank', 'model'] + [col for col in summary_df.columns if col not in ['rank', 'model']]
summary_df = summary_df[cols]

filepath = os.path.join(CONFIG['results_dir'], 'eegarnn_initial_summary.csv')
summary_df.to_csv(filepath, index=False)
print(f"Saved: {filepath}")

# Save training histories (model name -> list of per-fold history dicts)
filepath = os.path.join(CONFIG['results_dir'], 'training_histories.pkl')
with open(filepath, 'wb') as f:
    pickle.dump(all_histories, f)
print(f"Saved: {filepath}")

print("\nInitial Results Summary:")
print(summary_df[['rank', 'model', 'mean_accuracy', 'mean_f1_score', 'mean_auc_roc']].to_string(index=False))

## 9. Channel Selection Experiments

In [None]:
# Channel selection configuration
cs_experiments = [
    {'model': 'Baseline-EEG-ARNN', 'methods': ['edge', 'aggregation']},
    {'model': 'Adaptive-Gating-EEG-ARNN', 'methods': ['edge', 'aggregation', 'gate']},
]

# Derive the counts from cs_experiments so the summary cannot drift from the actual plan.
_n_k = len(CONFIG['k_values'])
_n_experiments = _n_k * sum(len(_e['methods']) for _e in cs_experiments)

print("\n" + "="*80)
print("CHANNEL SELECTION EVALUATION")
print("="*80)
print(f"k-values: {CONFIG['k_values']}")
for _exp in cs_experiments:
    _label = _exp['model'].split('-')[0]  # 'Baseline' / 'Adaptive'
    _n_methods = len(_exp['methods'])
    print(f"{_label}: {_n_methods} methods × {_n_k} k-values = {_n_methods * _n_k} experiments")
print(f"Total: {_n_experiments} experiments × {CONFIG['n_folds']} folds = {_n_experiments * CONFIG['n_folds']} runs")
print("="*80 + "\n")

In [None]:
# Run channel selection experiments.
#
# For each (model, method, k) and each CV fold this cell:
#   1. reloads the full-channel checkpoint saved for that fold,
#   2. scores channel importance with `method` ('edge', 'aggregation' or 'gate'),
#   3. keeps the top-k channels and trains a fresh model of the same class on them,
#   4. evaluates that model on the fold's validation split.
# Per-fold metrics are averaged into one row per (model, method, k).

_CS_METHODS = ('edge', 'aggregation', 'gate')
_CS_METRIC_NAMES = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']


def _checkpoint_path(model_name, fold):
    """Return the full-channel checkpoint path for `model_name` on 0-based `fold`."""
    return os.path.join(CONFIG['models_dir'], f"eegarnn_{model_name}_fold{fold+1}.pt")


def _build_eegarnn(model_class, n_channels):
    """Instantiate `model_class` with `n_channels` inputs and the CONFIG hyper-parameters."""
    kwargs = dict(
        n_channels=n_channels,
        n_classes=CONFIG['n_classes'],
        n_timepoints=CONFIG['n_timepoints'],
        hidden_dim=CONFIG['hidden_dim'],
    )
    if model_class == AdaptiveGatingEEGARNN:
        # Only the adaptive-gating variant takes a gate initialisation.
        kwargs['gate_init'] = CONFIG['gating']['gate_init']
    return model_class(**kwargs)


def _channel_importance(model, method, X_train, y_train):
    """Return per-channel importance scores of `model` for the given selection method.

    'edge' needs only the trained model; 'aggregation' and 'gate' run the model
    over this fold's training split in a fixed (unshuffled) order.

    Raises:
        ValueError: if `method` is not one of _CS_METHODS.
    """
    if method == 'edge':
        return model.get_channel_importance_edge()
    if method not in ('aggregation', 'gate'):
        # The original if/elif/else sent any unknown name to gate scoring.
        raise ValueError(f"Unknown channel selection method: {method!r}")
    loader = DataLoader(EEGDataset(X_train, y_train), batch_size=CONFIG['batch_size'], shuffle=False)
    if method == 'aggregation':
        return get_channel_importance_aggregation(model, loader, CONFIG['device'])
    return compute_gate_importance(model, loader, CONFIG['device'])


# Fail fast -- before hours of retraining -- on a misconfigured method name or a
# missing full-channel checkpoint.
for exp in cs_experiments:
    unknown_methods = [m for m in exp['methods'] if m not in _CS_METHODS]
    if unknown_methods:
        raise ValueError(f"{exp['model']}: unknown channel selection method(s) {unknown_methods}")
missing_checkpoints = [
    _checkpoint_path(exp['model'], fold)
    for exp in cs_experiments
    for fold in range(CONFIG['n_folds'])
    if not os.path.exists(_checkpoint_path(exp['model'], fold))
]
if missing_checkpoints:
    raise FileNotFoundError(f"Missing trained checkpoints: {missing_checkpoints}")

channel_selection_results = []

for exp in cs_experiments:
    model_name = exp['model']
    methods = exp['methods']
    model_class = BaselineEEGARNN if 'Baseline' in model_name else AdaptiveGatingEEGARNN

    print(f"\n{'='*80}")
    print(f"Model: {model_name}")
    print(f"Methods: {', '.join([m.upper() for m in methods])}")
    print("="*80)

    for method in methods:
        print(f"\n{'-'*80}")
        print(f"Method: {method.upper()}")
        print("-"*80)

        for k in CONFIG['k_values']:
            print(f"\n  k={k} channels:", end=' ')
            fold_metrics_list = []

            # NOTE(review): this assumes skf.split(X, y) reproduces the exact folds
            # used when the checkpoints were saved (e.g. fixed random_state);
            # otherwise the loaded model may have trained on this fold's validation
            # data. Confirm where skf is constructed.
            for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
                X_train, X_val = X[train_idx], X[val_idx]
                y_train, y_val = y[train_idx], y[val_idx]

                # Reload the full-channel model trained on this fold to score channels.
                model = _build_eegarnn(model_class, CONFIG['n_channels'])
                state_dict = torch.load(_checkpoint_path(model_name, fold), map_location=CONFIG['device'])
                model.load_state_dict(state_dict)
                model = model.to(CONFIG['device'])
                model.eval()

                importance_scores = _channel_importance(model, method, X_train, y_train)

                # Keep the top-k channels in both splits.
                selected_channels = select_top_k_channels(importance_scores, k)
                X_train_selected = apply_channel_selection(X_train, selected_channels)
                X_val_selected = apply_channel_selection(X_val, selected_channels)

                # Retrain a fresh model of the same class on the reduced montage.
                new_model = _build_eegarnn(model_class, k)
                train_loader = DataLoader(EEGDataset(X_train_selected, y_train), batch_size=CONFIG['batch_size'], shuffle=True)
                val_loader = DataLoader(EEGDataset(X_val_selected, y_val), batch_size=CONFIG['batch_size'], shuffle=False)

                best_state, _, _ = train_model_full(
                    new_model, train_loader, val_loader, CONFIG,
                    f"{model_name}-{method}-k{k}-F{fold+1}"
                )

                # Evaluate the state returned by train_model_full on the validation split.
                new_model.load_state_dict(best_state)
                new_model = new_model.to(CONFIG['device'])
                fold_metrics_list.append(
                    calculate_comprehensive_metrics(new_model, val_loader, CONFIG['device'])
                )

                # Drop the large per-fold references (models and state dicts) before
                # the next fold.
                del model, new_model, state_dict, best_state
                torch.cuda.empty_cache()
                gc.collect()

            # Aggregate over folds.
            # NOTE(review): np.std is the population std (ddof=0), whereas the
            # pandas .std() used for the full-channel numbers in the final summary
            # is the sample std (ddof=1); confirm which convention Pipeline 1 uses
            # before comparing the ± values.
            mean_metrics = {}
            for metric_name in _CS_METRIC_NAMES:
                values = [m[metric_name] for m in fold_metrics_list]
                mean_metrics[f'mean_{metric_name}'] = np.mean(values)
                mean_metrics[f'std_{metric_name}'] = np.std(values)

            print(f"{mean_metrics['mean_accuracy']:.4f} ± {mean_metrics['std_accuracy']:.4f}")

            channel_selection_results.append({
                'model': model_name,
                'method': method,
                'k': k,
                **mean_metrics,
            })

print(f"\n{'='*80}")
print("CHANNEL SELECTION COMPLETE!")
print("="*80)

## 10. Save Channel Selection Results

In [None]:
# Save detailed channel-selection results: one row per (model, method, k).
cs_df = pd.DataFrame(channel_selection_results)

# Identifying columns first, then the mean_/std_ metric columns in their original order.
id_cols = ['model', 'method', 'k']
cs_df = cs_df[id_cols + [col for col in cs_df.columns if col not in id_cols]]

filepath = os.path.join(CONFIG['results_dir'], 'channel_selection_results.csv')
cs_df.to_csv(filepath, index=False)
print(f"Saved: {filepath}")

# Summary by (model, method): the best k with its scores, plus accuracy averaged
# over all k. Written to channel_selection_summary.csv, which the notebook header
# lists as an output but which was previously never created.
cs_summary_rows = []
for (cs_model, cs_method), group in cs_df.groupby(['model', 'method'], sort=False):
    best = group.loc[group['mean_accuracy'].idxmax()]
    cs_summary_rows.append({
        'model': cs_model,
        'method': cs_method,
        'best_k': int(best['k']),
        'best_mean_accuracy': best['mean_accuracy'],
        'best_std_accuracy': best['std_accuracy'],
        'best_mean_f1_score': best['mean_f1_score'],
        'best_mean_auc_roc': best['mean_auc_roc'],
        'mean_accuracy_over_k': group['mean_accuracy'].mean(),
    })
cs_summary_df = pd.DataFrame(cs_summary_rows)

filepath = os.path.join(CONFIG['results_dir'], 'channel_selection_summary.csv')
cs_summary_df.to_csv(filepath, index=False)
print(f"Saved: {filepath}")

print("\nChannel Selection Results:")
print(cs_df[['model', 'method', 'k', 'mean_accuracy', 'std_accuracy', 'mean_f1_score']].to_string(index=False))

# Best (method, k) per model by mean accuracy.
print("\n" + "="*80)
print("BEST CHANNEL SELECTION RESULTS")
print("="*80)

for model_name in ['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']:
    model_data = cs_df[cs_df['model'] == model_name]
    if model_data.empty:
        # idxmax() raises on an empty selection; report and move on instead.
        print(f"\n{model_name}: no channel selection results")
        continue
    best_row = model_data.loc[model_data['mean_accuracy'].idxmax()]

    print(f"\n{model_name}:")
    print(f"  Method: {best_row['method'].upper()}")
    print(f"  k: {int(best_row['k'])}")
    print(f"  Accuracy: {best_row['mean_accuracy']:.4f} ± {best_row['std_accuracy']:.4f}")
    print(f"  F1-Score: {best_row['mean_f1_score']:.4f} ± {best_row['std_f1_score']:.4f}")
    print(f"  AUC-ROC: {best_row['mean_auc_roc']:.4f} ± {best_row['std_auc_roc']:.4f}")

## 11. Visualizations

In [None]:
print("\n" + "="*80)
print("CREATING VISUALIZATIONS")
print("="*80)

# 1. Training curves (top row) and full-channel model comparison (bottom row),
# saved together as training_curves.png.
print("\n1. Training curves...")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Training History - Baseline vs Adaptive Gating', fontsize=16, fontweight='bold')

colors = {'Baseline-EEG-ARNN': 'blue', 'Adaptive-Gating-EEG-ARNN': 'red'}

for model_name, histories in all_histories.items():
    color = colors.get(model_name, 'gray')

    # Average the per-epoch curves across folds. np.mean(axis=0) assumes every
    # fold's history has the same length (the header states a fixed epoch count).
    train_loss = np.mean([h['train_loss'] for h in histories], axis=0)
    val_loss = np.mean([h['val_loss'] for h in histories], axis=0)
    train_acc = np.mean([h['train_acc'] for h in histories], axis=0)
    val_acc = np.mean([h['val_acc'] for h in histories], axis=0)

    epochs = range(1, len(train_loss) + 1)

    # Colour and line style go in as keywords: a fmt string such as 'blue--'
    # raises ValueError, because matplotlib fmt strings only accept
    # single-letter colour codes ('b', 'r', ...).
    axes[0, 0].plot(epochs, train_loss, color=color, linestyle='--', alpha=0.7, linewidth=2, label=f'{model_name} (train)')
    axes[0, 0].plot(epochs, val_loss, color=color, linestyle='-', linewidth=2, label=f'{model_name} (val)')

    axes[0, 1].plot(epochs, train_acc, color=color, linestyle='--', alpha=0.7, linewidth=2, label=f'{model_name} (train)')
    axes[0, 1].plot(epochs, val_acc, color=color, linestyle='-', linewidth=2, label=f'{model_name} (val)')

axes[0, 0].set_xlabel('Epoch', fontsize=12, fontweight='bold')
axes[0, 0].set_ylabel('Loss', fontsize=12, fontweight='bold')
axes[0, 0].set_title('Training & Validation Loss', fontsize=13, fontweight='bold')
axes[0, 0].legend(fontsize=10)
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].set_xlabel('Epoch', fontsize=12, fontweight='bold')
axes[0, 1].set_ylabel('Accuracy', fontsize=12, fontweight='bold')
axes[0, 1].set_title('Training & Validation Accuracy', fontsize=13, fontweight='bold')
axes[0, 1].legend(fontsize=10)
axes[0, 1].grid(True, alpha=0.3)

# Grouped bars: fold-averaged full-channel metrics per model.
comparison_data = []
for model_name, fold_results in all_results.items():
    df_temp = pd.DataFrame(fold_results)
    comparison_data.append({
        'Model': 'Baseline' if 'Baseline' in model_name else 'Adaptive',
        'Accuracy': df_temp['accuracy'].mean(),
        'Precision': df_temp['precision'].mean(),
        'Recall': df_temp['recall'].mean(),
        'F1-Score': df_temp['f1_score'].mean(),
        'AUC-ROC': df_temp['auc_roc'].mean()
    })

df_comparison = pd.DataFrame(comparison_data)
x = np.arange(len(df_comparison))
bar_width = 0.15

metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC']
colors_bar = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

for i, metric in enumerate(metrics):
    offset = bar_width * (i - 2)  # centre the 5 bars on each tick
    axes[1, 0].bar(x + offset, df_comparison[metric], bar_width, label=metric, color=colors_bar[i], alpha=0.8)

axes[1, 0].set_xlabel('Model', fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel('Score', fontsize=12, fontweight='bold')
axes[1, 0].set_title('Model Performance Comparison', fontsize=13, fontweight='bold')
axes[1, 0].set_xticks(x)
axes[1, 0].set_xticklabels(df_comparison['Model'])
axes[1, 0].legend(fontsize=9)
axes[1, 0].grid(True, alpha=0.3, axis='y')
axes[1, 0].set_ylim([0, 1.0])

# Relative (%) improvement of Adaptive over Baseline for each metric.
if len(df_comparison) == 2:
    improvements = []
    for metric in metrics:
        baseline_val = df_comparison.loc[df_comparison['Model'] == 'Baseline', metric].values[0]
        adaptive_val = df_comparison.loc[df_comparison['Model'] == 'Adaptive', metric].values[0]
        improvements.append(((adaptive_val - baseline_val) / baseline_val) * 100)

    bars = axes[1, 1].barh(metrics, improvements, color=['green' if v > 0 else 'red' for v in improvements], alpha=0.7)
    axes[1, 1].axvline(0, color='black', linewidth=0.8)
    axes[1, 1].set_xlabel('Improvement (%)', fontsize=12, fontweight='bold')
    axes[1, 1].set_title('Adaptive Gating Improvement over Baseline', fontsize=13, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3, axis='x')

    # Value labels at the tip of each bar. A separate name is used so the
    # bar_width above is not clobbered.
    for bar, improvement in zip(bars, improvements):
        bar_len = bar.get_width()
        axes[1, 1].text(bar_len, bar.get_y() + bar.get_height()/2,
                       f'{improvement:+.1f}%',
                       ha='left' if bar_len > 0 else 'right',
                       va='center', fontsize=10, fontweight='bold')
else:
    axes[1, 1].text(0.5, 0.5, 'Improvement plot\nrequires 2 models',
                   ha='center', va='center', transform=axes[1, 1].transAxes)

plt.tight_layout()
filepath = os.path.join(CONFIG['plots_dir'], 'training_curves.png')
plt.savefig(filepath, dpi=300, bbox_inches='tight')
print(f"Saved: {filepath}")
plt.show()

In [None]:
# 2. Channel Selection Curves: mean ± std accuracy vs k for each selection
# method, with the model's full-channel accuracy as a reference line.
print("\n2. Channel selection curves...")

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Channel Selection Performance', fontsize=16, fontweight='bold')

colors_methods = {'edge': 'blue', 'aggregation': 'green', 'gate': 'red'}
markers_methods = {'edge': 'o', 'aggregation': 's', 'gate': '^'}

for idx, model_name in enumerate(['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']):
    ax = axes[idx]
    model_data = cs_df[cs_df['model'] == model_name]

    for method in sorted(model_data['method'].unique()):
        method_data = model_data[model_data['method'] == method].sort_values('k')
        ax.errorbar(
            method_data['k'],
            method_data['mean_accuracy'],
            yerr=method_data['std_accuracy'],
            label=method.upper(),
            marker=markers_methods.get(method, 'o'),
            color=colors_methods.get(method, 'gray'),
            capsize=5,
            linewidth=2,
            markersize=10,
            alpha=0.8
        )

    # Full-channel reference; the channel count comes from CONFIG instead of a
    # hard-coded "64".
    full_acc = pd.DataFrame(all_results[model_name])['accuracy'].mean()
    ax.axhline(full_acc, color='black', linestyle='--', linewidth=2, alpha=0.5,
               label=f"Full ({CONFIG['n_channels']} ch)")

    # Keep the usual [0.5, 1.0] window, but lower the floor if any point or
    # error bar falls below it instead of silently clipping it off the chart.
    # pandas .min() skips NaN; a NaN/empty result compares False and keeps 0.5.
    lowest = pd.concat([
        model_data['mean_accuracy'] - model_data['std_accuracy'],
        pd.Series([full_acc]),
    ]).min()
    y_low = 0.5
    if lowest < y_low:
        y_low = max(0.0, lowest - 0.05)

    ax.set_xlabel('Number of Channels (k)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    title = 'Baseline' if 'Baseline' in model_name else 'Adaptive Gating'
    ax.set_title(f'{title}\nChannel Selection Methods', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([y_low, 1.0])
    ax.set_xticks(CONFIG['k_values'])

plt.tight_layout()
filepath = os.path.join(CONFIG['plots_dir'], 'channel_selection_curves.png')
plt.savefig(filepath, dpi=300, bbox_inches='tight')
print(f"Saved: {filepath}")
plt.show()

In [None]:
# 3. Retention Analysis: reduced-channel mean accuracy as a percentage of the
# same model's full-channel mean accuracy.
print("\n3. Retention analysis...")

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Performance Retention with Channel Reduction', fontsize=16, fontweight='bold')

for idx, model_name in enumerate(['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']):
    ax = axes[idx]
    model_data = cs_df[cs_df['model'] == model_name]

    # Reference: this model's fold-averaged full-channel accuracy.
    full_acc = pd.DataFrame(all_results[model_name])['accuracy'].mean()

    retention_series = []
    for method in sorted(model_data['method'].unique()):
        method_data = model_data[model_data['method'] == method].sort_values('k')
        retention = (method_data['mean_accuracy'] / full_acc) * 100
        retention_series.append(retention)

        ax.plot(
            method_data['k'],
            retention,
            label=method.upper(),
            marker=markers_methods.get(method, 'o'),
            color=colors_methods.get(method, 'gray'),
            linewidth=2,
            markersize=10,
            alpha=0.8
        )

    # The 100% line is this model's own full-channel accuracy. The old
    # "(baseline)" label suggested the Baseline model on the Adaptive panel.
    ax.axhline(100, color='black', linestyle='--', linewidth=2, alpha=0.5, label='100% (full channels)')
    ax.axhline(90, color='orange', linestyle=':', linewidth=1.5, alpha=0.5, label='90% threshold')

    # Widen the default [70, 105] window instead of clipping points outside it
    # (a reduced montage can beat the full model, i.e. retention > 100%).
    # NaN/empty min/max compare False and leave the defaults untouched.
    all_retention = pd.concat(retention_series) if retention_series else pd.Series(dtype=float)
    y_low, y_high = 70, 105
    if all_retention.min() < y_low:
        y_low = all_retention.min() - 5
    if all_retention.max() > y_high:
        y_high = all_retention.max() + 5

    ax.set_xlabel('Number of Channels (k)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Retention (%)', fontsize=12, fontweight='bold')
    title = 'Baseline' if 'Baseline' in model_name else 'Adaptive Gating'
    ax.set_title(f'{title}\nPerformance Retention', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='lower left')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([y_low, y_high])
    ax.set_xticks(CONFIG['k_values'])

plt.tight_layout()
filepath = os.path.join(CONFIG['plots_dir'], 'retention_analysis.png')
plt.savefig(filepath, dpi=300, bbox_inches='tight')
print(f"Saved: {filepath}")
plt.show()

print("\nVisualization complete!")

## 12. Final Summary

In [None]:
print("\n" + "="*80)
print("PIPELINE 2 - COMPLETE SUMMARY")
print("="*80)

BASELINE_NAME = 'Baseline-EEG-ARNN'
ADAPTIVE_NAME = 'Adaptive-Gating-EEG-ARNN'
RETENTION_K = 30  # channel budget reported in section 3c

# Fold-averaged full-channel accuracy per model, reused in sections 3a, 3c and 3d.
initial_mean_acc = {
    name: pd.DataFrame(fold_results)['accuracy'].mean()
    for name, fold_results in all_results.items()
}

# 1. Initial (full-channel) performance. pandas .std() is the sample std (ddof=1).
print(f"\n1. INITIAL MODEL PERFORMANCE (Full {CONFIG['n_channels']} Channels):")
print("-" * 80)

for model_name, fold_results in all_results.items():
    df_temp = pd.DataFrame(fold_results)
    print(f"\n{model_name}:")
    print(f"  Accuracy:    {df_temp['accuracy'].mean():.4f} ± {df_temp['accuracy'].std():.4f}")
    print(f"  Precision:   {df_temp['precision'].mean():.4f} ± {df_temp['precision'].std():.4f}")
    print(f"  Recall:      {df_temp['recall'].mean():.4f} ± {df_temp['recall'].std():.4f}")
    print(f"  F1-Score:    {df_temp['f1_score'].mean():.4f} ± {df_temp['f1_score'].std():.4f}")
    print(f"  AUC-ROC:     {df_temp['auc_roc'].mean():.4f} ± {df_temp['auc_roc'].std():.4f}")
    print(f"  Specificity: {df_temp['specificity'].mean():.4f} ± {df_temp['specificity'].std():.4f}")
    print(f"  Sensitivity: {df_temp['sensitivity'].mean():.4f} ± {df_temp['sensitivity'].std():.4f}")

# 2. Channel selection accuracy per model / method / k.
print("\n" + "="*80)
print("2. CHANNEL SELECTION RESULTS:")
print("-" * 80)

for model_name in [BASELINE_NAME, ADAPTIVE_NAME]:
    print(f"\n{model_name}:")
    model_data = cs_df[cs_df['model'] == model_name]

    for method in sorted(model_data['method'].unique()):
        print(f"\n  {method.upper()} Method:")
        method_data = model_data[model_data['method'] == method].sort_values('k')
        for _, row in method_data.iterrows():
            print(f"    k={int(row['k']):2d}: {row['mean_accuracy']:.4f} ± {row['std_accuracy']:.4f}")

# 3. Key findings.
print("\n" + "="*80)
print("3. KEY FINDINGS:")
print("-" * 80)

# a) Best full-channel model (first one wins on ties).
print("\na) Best Initial Model:")
if initial_mean_acc:
    best_initial = max(initial_mean_acc, key=initial_mean_acc.get)
    best_initial_acc = initial_mean_acc[best_initial]
    print(f"   {best_initial}: {best_initial_acc:.4f}")
else:
    print("   (no initial results)")

# b) Best channel-selection configuration overall. idxmax() raises when there
# is no non-NaN accuracy, so guard it.
print("\nb) Best Channel Selection:")
if cs_df['mean_accuracy'].notna().any():
    best_cs = cs_df.loc[cs_df['mean_accuracy'].idxmax()]
    print(f"   Model: {best_cs['model']}")
    print(f"   Method: {best_cs['method'].upper()}")
    print(f"   k: {int(best_cs['k'])}")
    print(f"   Accuracy: {best_cs['mean_accuracy']:.4f} ± {best_cs['std_accuracy']:.4f}")
else:
    print("   (no channel selection results)")

# c) Retention of the best method at k=RETENTION_K relative to full channels.
print(f"\nc) Performance Retention at k={RETENTION_K}:")
for model_name in [BASELINE_NAME, ADAPTIVE_NAME]:
    model_data = cs_df[cs_df['model'] == model_name]
    k_data = model_data[model_data['k'] == RETENTION_K]

    if k_data.empty or model_name not in initial_mean_acc:
        # Previously skipped silently, e.g. when RETENTION_K is not in CONFIG['k_values'].
        print(f"\n   {model_name}: no results at k={RETENTION_K}")
        continue

    best_k_row = k_data.loc[k_data['mean_accuracy'].idxmax()]
    full_acc = initial_mean_acc[model_name]
    retention = (best_k_row['mean_accuracy'] / full_acc) * 100

    print(f"\n   {model_name}:")
    print(f"     Method: {best_k_row['method'].upper()}")
    print(f"     Accuracy: {best_k_row['mean_accuracy']:.4f} (vs {full_acc:.4f} full)")
    print(f"     Retention: {retention:.1f}%")

# d) Relative full-channel gain of Adaptive over Baseline. Checks the actual
# keys rather than len(all_results) == 2, which could KeyError on other names.
if BASELINE_NAME in initial_mean_acc and ADAPTIVE_NAME in initial_mean_acc:
    baseline_acc = initial_mean_acc[BASELINE_NAME]
    adaptive_acc = initial_mean_acc[ADAPTIVE_NAME]
    improvement = ((adaptive_acc - baseline_acc) / baseline_acc) * 100

    print("\nd) Adaptive Gating Improvement:")
    print(f"   Baseline: {baseline_acc:.4f}")
    print(f"   Adaptive: {adaptive_acc:.4f}")
    print(f"   Improvement: {improvement:+.2f}%")

print("\n" + "="*80)
print("PIPELINE 2 COMPLETE!")
print("="*80)

print("\nAll results saved to:")
print("  results/eegarnn_*_results.csv          - Detailed per-fold results")
print("  results/eegarnn_initial_summary.csv    - Initial training summary")
print("  results/channel_selection_results.csv  - Channel selection results")
print("  results/training_histories.pkl         - Training curves data")

print("\nAll visualizations saved to:")
print("  plots/training_curves.png              - Training history")
print("  plots/channel_selection_curves.png     - Channel selection performance")
print("  plots/retention_analysis.png           - Performance retention")

print("\nAll models saved to:")
print("  models/eegarnn_*.pt                    - Trained model checkpoints")

print("\n" + "="*80)
print("Ready for comparison with Pipeline 1!")
print("="*80)