# Pipeline 2: EEG-ARNN Methods with Channel Selection

This notebook implements the complete evaluation of EEG-ARNN methods with adaptive gating.

## Models
1. **Baseline-EEG-ARNN** - Without adaptive gating
2. **Adaptive-Gating-EEG-ARNN** - With adaptive, data-dependent channel gating (the proposed contribution)

## Experiments
### Part 1: Train Both Models
- 2-fold cross-validation
- 20 epochs with early stopping

### Part 2: Channel Selection (up to 3 methods × 5 k-values; Gate Selection applies only to Adaptive-Gating)
1. **Edge Selection (ES)** - Graph adjacency-based
2. **Aggregation Selection (AS)** - Feature activation-based  
3. **Gate Selection (GS)** - Adaptive gating-based (Adaptive-Gating only)

k-values tested: [10, 20, 30, 40, 50]

### Part 3: Retention Analysis
- Tests performance degradation with channel reduction
- Uses Gate Selection (the method built on the learned adaptive gates)
- k-values: [10, 15, 20, 25, 30, 35]

## Configuration
- **Epochs:** 20 (optimized for speed)
- **Cross-validation:** 2-fold (faster than 3-fold)
- **Learning rate:** 0.002
- **Batch size:** 64

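## Expected Runtime: ~11 hours on Kaggle GPU (assuming ~10 min per training run)
- Initial training: ~40 min (2 models × 2 folds = 4 runs)
- Channel selection: ~8.5 hours (50 runs: Baseline 2 methods + Adaptive 3 methods, × 5 k-values × 2 folds)
- Retention: ~2 hours (12 runs: 6 k-values × 2 folds)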

## Outputs
```
results/eegarnn_baseline_eeg_arnn_results.csv          - Baseline-EEG-ARNN (2 folds)
results/eegarnn_adaptive_gating_eeg_arnn_results.csv   - Adaptive-Gating-EEG-ARNN (2 folds)
results/channel_selection_results.csv       - All selection methods
results/retention_analysis.csv              - Retention curve (Gate-based)
results/eegarnn_complete_summary.csv        - Final summary
models/eegarnn_*.pt                         - Model checkpoints
```

## 1. Setup and Configuration

In [None]:
# All notebook dependencies: stdlib, scientific stack, PyTorch, MNE, scikit-learn.
import os
import gc
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from copy import deepcopy
import warnings
# Silence every Python warning to keep the long training logs readable.
# NOTE(review): this also hides deprecation / numerical warnings from torch,
# sklearn and mne; consider filtering specific categories instead.
warnings.filterwarnings('ignore')
# Only report MNE messages at ERROR level (hides its info/warning logging).
mne.set_log_level('ERROR')

print("All imports successful!")

In [None]:
# Configuration - OPTIMIZED FOR SPEED
# Every tunable setting of the notebook lives in this dict; later cells read from it.
CONFIG = {
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    'models_dir': './models',
    'results_dir': './results',

    'n_folds': 2,  # Reduced from 3 for faster runtime
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Training hyperparameters
    'batch_size': 64,
    'epochs': 20,  # Optimized for speed
    'learning_rate': 0.002,
    'weight_decay': 1e-4,
    'patience': 5,             # early stopping: epochs without validation improvement
    'scheduler_patience': 2,   # ReduceLROnPlateau patience
    'scheduler_factor': 0.5,   # LR multiplier applied on a plateau
    'use_early_stopping': True,
    'min_lr': 1e-6,

    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'n_timepoints': 513,  # must equal (tmax - tmin) * sfreq + 1; checked below
    'hidden_dim': 128,
    'mi_runs': [7, 8, 11, 12],  # run numbers whose .fif files are loaded per subject

    # Gating parameters
    'gating': {
        'gate_init': 0.9,    # initial gate level (sets the gate network's output bias)
        'l1_lambda': 1e-3,   # weight of the L1 penalty on gate values
    },

    # Channel selection k-values
    'k_values': [10, 20, 30, 40, 50],
    'retention_k_values': [10, 15, 20, 25, 30, 35],
}

# The models are built for exactly n_timepoints samples per trial. An epoch spanning
# [tmin, tmax] inclusive at sfreq Hz has (tmax - tmin) * sfreq + 1 samples.
_expected_timepoints = int(round((CONFIG['tmax'] - CONFIG['tmin']) * CONFIG['sfreq'])) + 1
assert CONFIG['n_timepoints'] == _expected_timepoints, (
    f"n_timepoints={CONFIG['n_timepoints']} is inconsistent with tmin/tmax/sfreq "
    f"(expected {_expected_timepoints})"
)

os.makedirs(CONFIG['models_dir'], exist_ok=True)
os.makedirs(CONFIG['results_dir'], exist_ok=True)

np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])

print(f"Device: {CONFIG['device']}")
print(f"Epochs: {CONFIG['epochs']}")
print(f"Folds: {CONFIG['n_folds']}")
print(f"Learning rate: {CONFIG['learning_rate']}")

# Runtime estimates: one "run" = training one model on one CV fold.
# Channel selection evaluates 2 methods for the baseline (edge, aggregation) and
# 3 for the adaptive model (edge, aggregation, gate): 5 model/method pairs, not 6.
n_cs_model_method_pairs = 2 + 3
initial_runs = 2 * CONFIG['n_folds']
cs_runs = n_cs_model_method_pairs * len(CONFIG['k_values']) * CONFIG['n_folds']
retention_runs = len(CONFIG['retention_k_values']) * CONFIG['n_folds']
total_runs = initial_runs + cs_runs + retention_runs

print(f"\nEstimated training runs:")
print(f"  Initial: {initial_runs}")
print(f"  Channel selection: {cs_runs}")
print(f"  Retention: {retention_runs}")
print(f"  TOTAL: {total_runs} runs")
print(f"\nEstimated runtime (~10 min/run): {total_runs * 10 / 60:.1f} hours")

## 2. Data Loading

In [None]:
def load_physionet_data(data_path, tmin=None, tmax=None, mi_runs=None):
    """Load preprocessed PhysioNet motor-imagery trials from per-subject .fif files.

    Looks in ``data_path`` (or its ``preprocessed/`` subdirectory when present) for
    subject directories whose names start with ``S``. For each run in ``mi_runs`` it
    reads ``{subject}R{run:02d}_preproc_raw.fif`` and epochs the EEG channels around
    the ``T1``/``T2`` annotations.

    Args:
        data_path: Root directory of the preprocessed dataset.
        tmin: Epoch start in seconds relative to each event (default ``CONFIG['tmin']``).
        tmax: Epoch end in seconds relative to each event (default ``CONFIG['tmax']``).
        mi_runs: Run numbers to load (default ``CONFIG['mi_runs']``).

    Returns:
        Tuple ``(X, y, subjects)``:
        - ``X``: the concatenated epoch data (trials stacked on axis 0).
        - ``y``: 0 for T1 and 1 for T2.
        - ``subjects``: each trial's numeric subject id, or -1 when the directory
          name has no numeric suffix.

    Raises:
        FileNotFoundError: if ``data_path`` is not a directory.
        ValueError: if no trials could be loaded at all.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    tmin = CONFIG['tmin'] if tmin is None else tmin
    tmax = CONFIG['tmax'] if tmax is None else tmax
    mi_runs = CONFIG['mi_runs'] if mi_runs is None else mi_runs
    event_id = {'T1': 1, 'T2': 2}
    label_map = {1: 0, 2: 1}

    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        data_root = preprocessed_dir

    subject_dirs = [d for d in sorted(os.listdir(data_root))
                    if os.path.isdir(os.path.join(data_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []
    failed_runs = []  # (path, error repr) for run files that exist but failed to load
    print(f"Loading data from {len(subject_dirs)} subjects...")

    for subject_dir in subject_dirs:
        # 'S001' -> 1; a non-numeric or empty suffix maps to -1 instead of raising.
        try:
            subject_num = int(subject_dir[1:])
        except ValueError:
            subject_num = -1
        subject_path = os.path.join(data_root, subject_dir)

        for run_id in mi_runs:
            run_file = f"{subject_dir}R{run_id:02d}_preproc_raw.fif"
            run_path = os.path.join(subject_path, run_file)

            if not os.path.exists(run_path):
                continue

            try:
                raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue

                events, _ = mne.events_from_annotations(raw, event_id=event_id)
                if len(events) == 0:
                    continue

                # on_missing='ignore': keep a run even if only one of T1/T2 occurs in it
                # (the default 'raise' would discard the whole run).
                epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                                    baseline=None, preload=True, picks=picks,
                                    on_missing='ignore', verbose=False)

                data = epochs.get_data()
                labels = np.array([label_map.get(code, -1) for code in epochs.events[:, 2]])
                valid = labels >= 0

                if np.any(valid):
                    all_X.append(data[valid])
                    all_y.append(labels[valid])
                    all_subjects.append(np.full(np.sum(valid), subject_num))
            except Exception as exc:  # best effort: skip unreadable runs, but report them
                failed_runs.append((run_path, repr(exc)))

    if failed_runs:
        first_path, first_err = failed_runs[0]
        print(f"WARNING: skipped {len(failed_runs)} run file(s) that failed to load "
              f"(first: {first_path}: {first_err})")

    if not all_X:
        raise ValueError(f"No motor-imagery trials could be loaded from {data_root} "
                         f"(runs {list(mi_runs)})")

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subjects = np.concatenate(all_subjects, axis=0)

    print(f"Loaded {len(X)} trials from {len(np.unique(subjects))} subjects")
    print(f"Data shape: {X.shape}")
    print(f"Labels: {np.bincount(y)}")

    return X, y, subjects


class EEGDataset(Dataset):
    """Map-style dataset over in-memory EEG trials.

    ``X`` is copied into a float32 tensor and ``y`` into an int64 (long) tensor of
    class labels. Item ``i`` is the pair ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        n_items = len(self.X)
        return n_items

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        return sample, label

## 3. Model Architectures

In [None]:
# Graph Convolution Layer
class GraphConvLayer(nn.Module):
    """Graph convolution over EEG channels with a learnable symmetric adjacency.

    The raw parameter ``A`` is passed through a sigmoid and symmetrised. Self-loops
    are then added and the matrix is normalised as D^{-1/2} (A + I) D^{-1/2}. The
    normalised matrix mixes channel features at every time step, the shared linear
    map ``theta`` acts on the hidden axis, and BatchNorm2d + ELU follow.

    Input and output layout: (batch, hidden_dim, num_channels, time).
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _symmetric_adjacency(self):
        """sigmoid(A) averaged with its own transpose (no self-loops)."""
        weights = torch.sigmoid(self.A)
        return 0.5 * (weights + weights.t())

    def forward(self, x):
        batch, hidden, n_ch, n_t = x.shape
        adj = self._symmetric_adjacency()
        adj_loops = adj + torch.eye(n_ch, device=adj.device)
        inv_sqrt_deg = torch.diag(torch.pow(adj_loops.sum(1).clamp_min(1e-6), -0.5))
        adj_norm = inv_sqrt_deg @ adj_loops @ inv_sqrt_deg

        # (B, H, C, T) -> (B*T, C, H): each time step is an independent graph signal.
        node_feats = x.permute(0, 3, 2, 1).contiguous().view(batch * n_t, n_ch, hidden)
        mixed = self.theta(adj_norm @ node_feats)
        restored = mixed.view(batch, n_t, n_ch, hidden).permute(0, 3, 2, 1)
        return self.act(self.bn(restored))

    def get_adjacency(self):
        """Return the learned symmetric adjacency (without self-loops) as a NumPy array."""
        with torch.no_grad():
            return self._symmetric_adjacency().cpu().numpy()


# Temporal Convolution
class TemporalConv(nn.Module):
    """Convolution along the time axis only (kernel 1 x kernel_size), then BatchNorm2d + ELU.

    When ``pool`` is True, a 1x2 average pool halves the time axis afterwards.
    With an even ``kernel_size`` and padding ``kernel_size // 2``, the convolution
    output is one sample longer than its input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        half_width = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, half_width),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2)) if pool else None

    def forward(self, x):
        activated = self.act(self.bn(self.conv(x)))
        return activated if self.pool_layer is None else self.pool_layer(activated)

In [None]:
# Baseline EEG-ARNN (without gating)
class BaselineEEGARNN(nn.Module):
    """EEG-ARNN without channel gating.

    Three stages of temporal convolution followed by graph convolution over channels
    (t1 -> g1, t2 -> g2, t3 -> g3), then a two-layer MLP classifier. Accepts input of
    shape (batch, channels, time) or (batch, 1, channels, time).
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        # Read by train_pytorch_model to decide whether to add the gate L1 penalty.
        self.use_gate_regularizer = False

        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        self.feature_dim = self._infer_feature_dim(n_channels, n_timepoints)

        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, n_classes)

    def _infer_feature_dim(self, n_channels, n_timepoints):
        """Flattened size of the feature extractor's output, found with one dummy pass.

        The pass runs in eval mode so the BatchNorm running statistics (and
        num_batches_tracked) are not updated by the all-zeros dummy batch. The
        previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, n_channels, n_timepoints)
                feat = self._forward_features(self._prepare_input(dummy))
        finally:
            self.train(was_training)
        return feat.reshape(1, -1).size(1)

    def _prepare_input(self, x):
        """Add a singleton feature axis to 3-D input: (B, C, T) -> (B, 1, C, T)."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return x

    def _forward_features(self, x):
        """Run the three temporal-conv -> graph-conv stages; output is (B, hidden_dim, C, T')."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        prepared = self._prepare_input(x)
        features = self._forward_features(prepared)
        # flatten(1) also copes with non-contiguous feature tensors, unlike view().
        x = features.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def get_final_adjacency(self):
        """Learned symmetric adjacency of the last graph layer (g3), as a NumPy array."""
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        """Edge Selection score: total edge weight of each channel in g3's adjacency."""
        adjacency = self.get_final_adjacency()
        return np.sum(adjacency, axis=1)

In [None]:
# Adaptive Gating EEG-ARNN (YOUR CONTRIBUTION!)
class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """EEG-ARNN whose input channels are scaled by learned, per-trial gates.

    A small MLP maps per-channel summary statistics (mean and std over time) to one
    gate in (0, 1) per channel. The gates multiply the input before the shared
    EEG-ARNN trunk. The most recent gates are kept for the L1 penalty and for
    Gate Selection.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True

        # [mean_1..mean_C, std_1..std_C] -> C gates in (0, 1)
        stats_width = 2 * n_channels
        self.gate_net = nn.Sequential(
            nn.Linear(stats_width, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid(),
        )

        # Bias the last pre-sigmoid layer to logit(gate_init) so the gates start high.
        p_open = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        logit = math.log(p_open / (1.0 - p_open))
        output_linear = self.gate_net[2]
        with torch.no_grad():
            output_linear.bias.fill_(logit)

        self.latest_gate_values = None   # detached gates from the latest forward pass
        self.gate_penalty_tensor = None  # same gates with autograd history, for the L1 term

    def compute_gates(self, x):
        """Compute data-dependent channel gates from (batch, 1, channels, time) input.

        Returns a (batch, channels) tensor built from each channel's mean and
        (unbiased) standard deviation over time.
        """
        signal = x.squeeze(1)
        summary = torch.cat([signal.mean(dim=2), signal.std(dim=2)], dim=1)
        return self.gate_net(summary)

    def forward(self, x):
        x4d = self._prepare_input(x)
        gates = self.compute_gates(x4d)
        self.gate_penalty_tensor = gates
        self.latest_gate_values = gates.detach()

        # Broadcast one gate per (trial, channel) over the feature and time axes.
        features = self._forward_features(x4d * gates[:, None, :, None])
        hidden = F.relu(self.fc1(features.view(features.size(0), -1)))
        return self.fc2(self.dropout(hidden))

    def get_channel_importance_gate(self):
        """Mean gate per channel over the most recent batch, or None before any forward pass."""
        recent = self.latest_gate_values
        return None if recent is None else recent.mean(dim=0).cpu().numpy()

## 4. Training Utilities

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Run one optimisation pass over ``dataloader``.

    When ``l1_lambda > 0`` and the model exposes a non-None ``gate_penalty_tensor``
    (set by its forward pass), ``l1_lambda * mean(|gates|)`` is added to each batch loss.

    Returns:
        (mean batch loss including any gate penalty, accuracy over all samples).
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        penalty_source = getattr(model, 'gate_penalty_tensor', None)
        if l1_lambda > 0 and penalty_source is not None:
            batch_loss = batch_loss + l1_lambda * penalty_source.abs().mean()

        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predictions = logits.detach().max(dim=1).indices
        n_seen += targets.size(0)
        n_correct += (predictions == targets).sum().item()

    return loss_sum / max(1, len(dataloader)), n_correct / max(1, n_seen)


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` in eval mode without gradients.

    Returns:
        (mean batch loss, accuracy over all samples).
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            predictions = logits.max(dim=1).indices
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    return loss_sum / max(1, len(dataloader)), n_correct / max(1, n_seen)


def train_pytorch_model(model, train_loader, val_loader, config, model_name=''):
    """Train ``model`` with Adam + ReduceLROnPlateau and keep the best validation checkpoint.

    An epoch counts as an improvement if validation accuracy rises, or stays equal
    while validation loss drops. Training stops early after ``config['patience']``
    epochs without improvement (if ``config['use_early_stopping']``). Models with
    ``use_gate_regularizer`` get the gate L1 penalty ``config['gating']['l1_lambda']``.

    Returns:
        (best_state, best_val_acc): the state_dict of the best epoch, which is also
        loaded back into ``model``, and its validation accuracy.

    NOTE(review): ``val_loader`` is used both for checkpoint selection / early
    stopping and for the accuracy that is returned, so the returned accuracy is an
    optimistically biased estimate (the best epoch on the evaluation data).
    """
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])
    # `verbose` is not passed: it is deprecated for LR schedulers in recent PyTorch
    # releases, and False was already the default behaviour.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config['scheduler_factor'],
        patience=config['scheduler_patience'], min_lr=config['min_lr']
    )

    l1_lambda = config['gating']['l1_lambda'] if getattr(model, 'use_gate_regularizer', False) else 0.0

    best_state = deepcopy(model.state_dict())
    best_val_acc = 0.0
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)  # LR schedule follows validation loss

        improved = val_acc > best_val_acc or (val_acc == best_val_acc and val_loss < best_val_loss)
        if improved:
            best_state = deepcopy(model.state_dict())
            best_val_acc = val_acc
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 5 == 0 or improved:
            print(f"[{model_name}] Epoch {epoch+1}/{config['epochs']} - "
                  f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_val_acc:.4f}")

        if config['use_early_stopping'] and patience_counter >= config['patience']:
            print(f"Early stopping at epoch {epoch+1}")
            break

    model.load_state_dict(best_state)
    return best_state, best_val_acc

## 5. Load Data

In [None]:
print("Loading PhysioNet data...")
X, y, subjects = load_physionet_data(CONFIG['data_path'])

# Fail fast if the trial geometry does not match what the models are built for;
# otherwise the mismatch only surfaces later as a shape error inside the first forward pass.
expected_trial_shape = (CONFIG['n_channels'], CONFIG['n_timepoints'])
assert X.shape[1:] == expected_trial_shape, (
    f"Loaded trials have shape {X.shape[1:]}, but CONFIG expects {expected_trial_shape} "
    "(n_channels, n_timepoints); check the preprocessing sampling rate / channel set."
)
assert len(X) == len(y) == len(subjects)
print(f"\nData ready!")

## 6. Train EEG-ARNN Models

In [None]:
# Setup cross-validation
# Stratified K-fold over individual trials (class-balanced folds). shuffle=True with a
# fixed random_state means every later `skf.split(X, y)` call reproduces the same folds,
# which is what lets later cells match a fold index to its saved checkpoint.
# NOTE(review): `subjects` is not used as a grouping variable, so trials from the same
# subject can appear in both the training and validation fold of a split; the
# reported accuracies are therefore not subject-independent estimates.
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])

# Architectures trained on every fold: display name (also used in checkpoint filenames) -> class
models_to_train = [
    {'name': 'Baseline-EEG-ARNN', 'class': BaselineEEGARNN},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'class': AdaptiveGatingEEGARNN},
]

print(f"\n{'='*60}")
print("TRAINING EEG-ARNN MODELS")
print(f"{'='*60}\n")

In [None]:
# Training loop
# Train every architecture on each CV fold, checkpoint the best weights of each fold,
# and collect the per-fold validation accuracies in `all_results`.
all_results = {}

for model_info in models_to_train:
    model_name = model_info['name']
    model_class = model_info['class']

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}\n")

    # Constructor arguments shared by both architectures; only the gated one takes gate_init.
    build_kwargs = {
        'n_channels': CONFIG['n_channels'],
        'n_classes': CONFIG['n_classes'],
        'n_timepoints': CONFIG['n_timepoints'],
        'hidden_dim': CONFIG['hidden_dim'],
    }
    if model_class is AdaptiveGatingEEGARNN:
        build_kwargs['gate_init'] = CONFIG['gating']['gate_init']

    fold_results = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        fold_no = fold_idx + 1
        print(f"\nFold {fold_no}/{CONFIG['n_folds']}")

        train_loader = DataLoader(EEGDataset(X[train_idx], y[train_idx]),
                                  batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(EEGDataset(X[val_idx], y[val_idx]),
                                batch_size=CONFIG['batch_size'], shuffle=False)

        model = model_class(**build_kwargs)
        best_state, val_acc = train_pytorch_model(model, train_loader, val_loader, CONFIG, model_name)

        # The channel-selection and retention cells rebuild this exact filename.
        checkpoint_path = os.path.join(CONFIG['models_dir'], f"eegarnn_{model_name}_fold{fold_no}.pt")
        torch.save(best_state, checkpoint_path)

        fold_results.append({'fold': fold_no, 'accuracy': val_acc})
        print(f"Fold {fold_no} Accuracy: {val_acc:.4f}")

        del model
        torch.cuda.empty_cache()
        gc.collect()

    all_results[model_name] = fold_results
    fold_accs = [entry['accuracy'] for entry in fold_results]
    mean_acc, std_acc = np.mean(fold_accs), np.std(fold_accs)

    print(f"\n{model_name} Summary:")
    print(f"Mean Accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")

print(f"\n{'='*60}")
print("EEG-ARNN MODELS TRAINED!")
print(f"{'='*60}")

In [None]:
# Save initial results: one CSV of per-fold accuracies per architecture.
# The printed path reflects CONFIG['results_dir'] instead of a hardcoded "results/".
for model_name, fold_results in all_results.items():
    slug = model_name.lower().replace('-', '_').replace(' ', '_')
    out_path = os.path.join(CONFIG['results_dir'], f'eegarnn_{slug}_results.csv')
    pd.DataFrame(fold_results).assign(model=model_name).to_csv(out_path, index=False)
    print(f"Saved: {out_path}")

## 7. Channel Selection Utilities

In [None]:
def get_channel_importance_aggregation(model, dataloader, device):
    """Aggregation Selection (AS): score channels by mean absolute feature activation.

    Runs the model's feature extractor (``_prepare_input`` then ``_forward_features``)
    in eval mode. For EEG-ARNN models those features are (batch, hidden, channels,
    time). The absolute activations are averaged over axes 1 and 3, then over all
    trials.

    Returns:
        np.ndarray with one score per channel, or zeros(model.n_channels) for an
        empty dataloader.
    """
    model.eval()
    per_batch_scores = []

    with torch.no_grad():
        for inputs, _labels in dataloader:
            feats = model._forward_features(model._prepare_input(inputs.to(device)))
            per_batch_scores.append(feats.abs().mean(dim=(1, 3)).cpu())

    if len(per_batch_scores) == 0:
        return np.zeros(model.n_channels)
    all_scores = torch.cat(per_batch_scores, dim=0)
    return all_scores.mean(dim=0).numpy()


def compute_gate_importance(model, dataloader, device):
    """Gate Selection (GS): average the model's per-trial channel gates over a dataset.

    Runs a full forward pass per batch in eval mode, then reads the detached gates
    the model stored in ``latest_gate_values``. Models without gates (or an empty
    dataloader) yield a uniform score of ``1 / n_channels`` per channel.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for inputs, _labels in dataloader:
            model(inputs.to(device))
            batch_gates = getattr(model, 'latest_gate_values', None)
            if batch_gates is not None:
                collected.append(batch_gates.cpu())

    if len(collected) == 0:
        n_ch = model.n_channels
        return np.ones(n_ch) / n_ch
    return torch.cat(collected, dim=0).mean(dim=0).numpy()


def select_top_k_channels(importance_scores, k):
    """Return the indices of the ``k`` highest-scoring channels, in ascending index order.

    Args:
        importance_scores: 1-D array-like with one score per channel.
        k: Number of channels to keep. Values larger than the number of channels
            keep all of them; ``k == 0`` keeps none.

    Returns:
        Sorted list of ``min(k, n_channels)`` channel indices (Python ints).

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # argsort(...)[-0:] would select *every* channel, so k == 0 is handled explicitly.
        return []
    scores = np.asarray(importance_scores)
    # A stable sort makes tie-breaking deterministic: among equal scores the
    # higher indices end up last and are kept.
    order = np.argsort(scores, kind='stable')
    return sorted(order[-k:].tolist())


def apply_channel_selection(X, selected_channels):
    """Keep only ``selected_channels`` along axis 1 of a (trials, channels, time) array.

    With a list of indices this is NumPy advanced indexing, so the result is a copy.
    """
    channel_subset = X[:, selected_channels, :]
    return channel_subset

## 8. Channel Selection Experiments

In [None]:
# Channel selection configuration
# Importance methods evaluated per trained architecture. Gate Selection needs
# learned gates, so only the adaptive model gets 'gate'.
cs_experiments = [
    dict(model='Baseline-EEG-ARNN', methods=['edge', 'aggregation']),
    dict(model='Adaptive-Gating-EEG-ARNN', methods=['edge', 'aggregation', 'gate']),
]

channel_selection_results = []  # one dict per (model, method, k), filled by the next cell

print("\n" + "=" * 60)
print("CHANNEL SELECTION EVALUATION")
print("=" * 60 + "\n")

In [None]:
# Run channel selection experiments
# For every (model, method, k) and every CV fold:
#   1. rebuild the architecture and load that fold's full-channel checkpoint,
#   2. score all channels with the chosen importance method on the fold's training data,
#   3. keep the top-k channels and train a fresh model of the same architecture on them,
#   4. record the fresh model's best validation accuracy.
# NOTE(review): importance scores depend only on (model, method, fold), not on k, yet
# they are recomputed for every k. Model construction draws from the global torch RNG,
# so reordering these statements changes the random stream and the exact results.
# NOTE(review): the loaded checkpoint was early-stopped on this fold's validation data,
# so channel choice is indirectly informed by the data it is evaluated on.
for exp in cs_experiments:
    model_name = exp['model']
    methods = exp['methods']
    # Architecture is inferred from the experiment's display name
    model_class = BaselineEEGARNN if 'Baseline' in model_name else AdaptiveGatingEEGARNN
    
    print(f"\n{'='*60}")
    print(f"{model_name}")
    print(f"{'='*60}\n")
    
    for method in methods:
        print(f"\nMethod: {method.upper()}")
        
        for k in CONFIG['k_values']:
            print(f"  k={k}:", end=' ')
            fold_accuracies = []
            
            # Same folds as the training cell (skf has a fixed random_state), so
            # `fold` lines up with the saved fold checkpoints.
            for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
                X_train, X_val = X[train_idx], X[val_idx]
                y_train, y_val = y[train_idx], y[val_idx]
                
                # Load trained model
                if model_class == AdaptiveGatingEEGARNN:
                    model = model_class(
                        n_channels=CONFIG['n_channels'],
                        n_classes=CONFIG['n_classes'],
                        n_timepoints=CONFIG['n_timepoints'],
                        hidden_dim=CONFIG['hidden_dim'],
                        gate_init=CONFIG['gating']['gate_init']
                    )
                else:
                    model = model_class(
                        n_channels=CONFIG['n_channels'],
                        n_classes=CONFIG['n_classes'],
                        n_timepoints=CONFIG['n_timepoints'],
                        hidden_dim=CONFIG['hidden_dim']
                    )
                
                # Checkpoint written by the training cell for this model and fold
                model_path = os.path.join(CONFIG['models_dir'], f"eegarnn_{model_name}_fold{fold+1}.pt")
                state_dict = torch.load(model_path, map_location=CONFIG['device'])
                model.load_state_dict(state_dict)
                model = model.to(CONFIG['device'])
                model.eval()
                
                # Compute importance
                if method == 'edge':
                    # Edge Selection: row sums of the last graph layer's adjacency (no data needed)
                    importance_scores = model.get_channel_importance_edge()
                elif method == 'aggregation':
                    # Aggregation Selection: mean |activation| per channel over the training fold
                    train_dataset = EEGDataset(X_train, y_train)
                    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
                    importance_scores = get_channel_importance_aggregation(model, train_loader, CONFIG['device'])
                else:  # gate
                    # Gate Selection: mean learned gate per channel over the training fold
                    train_dataset = EEGDataset(X_train, y_train)
                    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
                    importance_scores = compute_gate_importance(model, train_loader, CONFIG['device'])
                
                # Select channels
                selected_channels = select_top_k_channels(importance_scores, k)
                X_train_selected = apply_channel_selection(X_train, selected_channels)
                X_val_selected = apply_channel_selection(X_val, selected_channels)
                
                # Train new model
                # Fresh (randomly initialised) model sized for the k retained channels
                if model_class == AdaptiveGatingEEGARNN:
                    new_model = model_class(
                        n_channels=k,
                        n_classes=CONFIG['n_classes'],
                        n_timepoints=CONFIG['n_timepoints'],
                        hidden_dim=CONFIG['hidden_dim'],
                        gate_init=CONFIG['gating']['gate_init']
                    )
                else:
                    new_model = model_class(
                        n_channels=k,
                        n_classes=CONFIG['n_classes'],
                        n_timepoints=CONFIG['n_timepoints'],
                        hidden_dim=CONFIG['hidden_dim']
                    )
                
                train_dataset = EEGDataset(X_train_selected, y_train)
                val_dataset = EEGDataset(X_val_selected, y_val)
                train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
                val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
                
                # Retrained models are not checkpointed; only their accuracy is kept
                best_state, val_acc = train_pytorch_model(new_model, train_loader, val_loader, 
                                                          CONFIG, f"{model_name}-{method}-k{k}")
                fold_accuracies.append(val_acc)
                
                del model, new_model
                torch.cuda.empty_cache()
                gc.collect()
            
            mean_acc = np.mean(fold_accuracies)
            std_acc = np.std(fold_accuracies)
            print(f"{mean_acc:.4f} +/- {std_acc:.4f}")
            
            channel_selection_results.append({
                'model': model_name,
                'method': method,
                'k': k,
                'mean_accuracy': mean_acc,
                'std_accuracy': std_acc,
                'fold_accuracies': fold_accuracies
            })

print(f"\n{'='*60}")
print("CHANNEL SELECTION COMPLETE!")
print(f"{'='*60}")

In [None]:
# Save channel selection results (the per-fold `fold_accuracies` lists are written as
# text in the CSV). The printed path reflects CONFIG['results_dir'].
cs_df = pd.DataFrame(channel_selection_results)
cs_path = os.path.join(CONFIG['results_dir'], 'channel_selection_results.csv')
cs_df.to_csv(cs_path, index=False)

print("\nChannel Selection Results:")
print(cs_df[['model', 'method', 'k', 'mean_accuracy', 'std_accuracy']])
print(f"\nSaved: {cs_path}")

## 9. Retention Analysis

In [None]:
# Retention analysis using Gate Selection
# For each k in CONFIG['retention_k_values'] and each CV fold: load the fold's
# Adaptive-Gating checkpoint, rank channels by mean gate value on the training fold,
# keep the top-k, retrain a fresh Adaptive-Gating model on them, and record the
# validation accuracy. This traces how accuracy changes as channels are removed.
# NOTE(review): k=10/20/30 repeat configurations already run by the channel-selection
# cell (gate method); model construction draws from the global torch RNG, so results
# differ between the two cells and depend on statement order.
print(f"\n{'='*60}")
print("RETENTION ANALYSIS: Adaptive-Gating-EEG-ARNN (Gate Selection)")
print(f"{'='*60}\n")

retention_results = []

for k in CONFIG['retention_k_values']:
    print(f"k={k}:", end=' ')
    fold_accuracies = []
    
    # Same folds as the training cell (skf has a fixed random_state)
    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        
        # Load trained model
        model = AdaptiveGatingEEGARNN(
            n_channels=CONFIG['n_channels'],
            n_classes=CONFIG['n_classes'],
            n_timepoints=CONFIG['n_timepoints'],
            hidden_dim=CONFIG['hidden_dim'],
            gate_init=CONFIG['gating']['gate_init']
        )
        model_path = os.path.join(CONFIG['models_dir'], f"eegarnn_Adaptive-Gating-EEG-ARNN_fold{fold+1}.pt")
        state_dict = torch.load(model_path, map_location=CONFIG['device'])
        model.load_state_dict(state_dict)
        model = model.to(CONFIG['device'])
        model.eval()
        
        # Compute gate importance
        train_dataset = EEGDataset(X_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
        importance_scores = compute_gate_importance(model, train_loader, CONFIG['device'])
        selected_channels = select_top_k_channels(importance_scores, k)
        
        # Apply selection
        X_train_selected = apply_channel_selection(X_train, selected_channels)
        X_val_selected = apply_channel_selection(X_val, selected_channels)
        
        # Train new model
        # Fresh (randomly initialised) gated model sized for the k retained channels
        new_model = AdaptiveGatingEEGARNN(
            n_channels=k,
            n_classes=CONFIG['n_classes'],
            n_timepoints=CONFIG['n_timepoints'],
            hidden_dim=CONFIG['hidden_dim'],
            gate_init=CONFIG['gating']['gate_init']
        )
        
        train_dataset = EEGDataset(X_train_selected, y_train)
        val_dataset = EEGDataset(X_val_selected, y_val)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
        
        best_state, val_acc = train_pytorch_model(new_model, train_loader, val_loader, 
                                                  CONFIG, f"Retention-k{k}")
        fold_accuracies.append(val_acc)
        
        del model, new_model
        torch.cuda.empty_cache()
        gc.collect()
    
    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)
    print(f"{mean_acc:.4f} +/- {std_acc:.4f}")
    
    retention_results.append({
        'k': k,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

In [None]:
# Save retention results (the per-fold `fold_accuracies` lists are written as text in
# the CSV). The printed path reflects CONFIG['results_dir'].
retention_df = pd.DataFrame(retention_results)
retention_path = os.path.join(CONFIG['results_dir'], 'retention_analysis.csv')
retention_df.to_csv(retention_path, index=False)

print("\nRetention Analysis Results:")
print(retention_df[['k', 'mean_accuracy', 'std_accuracy']])
print(f"\nSaved: {retention_path}")

## 10. Final Summary

In [None]:
# Create complete summary: print the headline numbers and save them as one table,
# eegarnn_complete_summary.csv, which the Outputs list at the top of the notebook promises.
print(f"\n{'='*60}")
print("COMPLETE SUMMARY")
print(f"{'='*60}\n")

summary_rows = []

# 1. Full-channel model comparison
print("1. EEG-ARNN Model Comparison:")
for model_name, fold_results in all_results.items():
    accs = [r['accuracy'] for r in fold_results]
    mean_acc, std_acc = float(np.mean(accs)), float(np.std(accs))
    print(f"   {model_name}: {mean_acc:.4f} +/- {std_acc:.4f}")
    summary_rows.append({
        'section': 'model_comparison', 'model': model_name, 'method': 'all_channels',
        'k': CONFIG['n_channels'], 'mean_accuracy': mean_acc, 'std_accuracy': std_acc,
    })

# 2. Best channel-selection configuration (highest mean accuracy over model, method and k)
print("\n2. Best Channel Selection Method:")
if cs_df.empty:
    # idxmax() on an empty column would raise
    print("   No channel selection results available.")
else:
    best_cs = cs_df.loc[cs_df['mean_accuracy'].idxmax()]
    print(f"   Model: {best_cs['model']}")
    print(f"   Method: {best_cs['method'].upper()}")
    print(f"   k: {best_cs['k']}")
    print(f"   Accuracy: {best_cs['mean_accuracy']:.4f} +/- {best_cs['std_accuracy']:.4f}")
    summary_rows.append({
        'section': 'best_channel_selection', 'model': best_cs['model'],
        'method': best_cs['method'], 'k': int(best_cs['k']),
        'mean_accuracy': float(best_cs['mean_accuracy']),
        'std_accuracy': float(best_cs['std_accuracy']),
    })

# 3. Retention curve (Gate Selection on the adaptive model)
print("\n3. Retention Analysis (Gate Selection):")
for _, row in retention_df.iterrows():
    print(f"   k={row['k']}: {row['mean_accuracy']:.4f} +/- {row['std_accuracy']:.4f}")
    summary_rows.append({
        'section': 'retention', 'model': 'Adaptive-Gating-EEG-ARNN', 'method': 'gate',
        'k': int(row['k']), 'mean_accuracy': float(row['mean_accuracy']),
        'std_accuracy': float(row['std_accuracy']),
    })

summary_df = pd.DataFrame(summary_rows)
summary_path = os.path.join(CONFIG['results_dir'], 'eegarnn_complete_summary.csv')
summary_df.to_csv(summary_path, index=False)

print(f"\n{'='*60}")
print("PIPELINE 2 COMPLETE!")
print(f"{'='*60}")
print(f"\nSummary saved: {summary_path}")
print(f"All results saved to {CONFIG['results_dir']}")
print("Ready for comparison with Pipeline 1 baseline methods!")