# Pipeline 2: EEG-ARNN - MEMORY OPTIMIZED VERSION

**OPTIMIZED FOR KAGGLE GPU:** No model saving, aggressive memory cleanup

## Models
1. **Baseline-EEG-ARNN** - Pure CNN-GCN architecture
2. **Adaptive-Gating-EEG-ARNN** - With data-dependent channel gating

## Memory Optimizations
- NO model checkpoints saved (only channel importance scores)
- Batch size kept at 64 to match Pipeline 1 (see `CONFIG['batch_size']`; it is not reduced to 32)
- Aggressive CUDA cache clearing after each fold
- Minimal history tracking
- Delete models immediately after use

## Configuration
- **Dataset:** `/kaggle/input/eeg-preprocessed-data/derived`
- **Epochs:** 30 (NO early stopping - full training)
- **Cross-validation:** 2-fold
- **Batch size:** 64 (same as Pipeline 1)
- **Learning rate:** 0.002

## Outputs
```
results/eegarnn_baseline_eeg_arnn_results.csv         - Per-fold results
results/eegarnn_adaptive_gating_eeg_arnn_results.csv  - Per-fold results
results/eegarnn_initial_summary.csv           - Summary statistics
results/channel_selection_results.csv         - All selection methods
results/training_histories.pkl                - Training curves
plots/*.png                                    - Visualizations
```

## 1. Setup

In [None]:
import os
import gc
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
import pickle
from copy import deepcopy
import warnings
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Only let MNE emit messages at ERROR level.
mne.set_log_level('ERROR')

# Global plotting defaults (matplotlib style sheet + seaborn colour palette).
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

print("All imports successful!")

In [None]:
# Central configuration dictionary read by the loading, training and
# channel-selection code in this notebook.
CONFIG = {
    # Input dataset root and output directories (created below).
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    'results_dir': './results',
    'plots_dir': './plots',
    
    # Cross-validation, reproducibility and compute device.
    'n_folds': 2,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Training hyperparameters (optimizer and ReduceLROnPlateau settings).
    'batch_size': 64,  # Same as Pipeline 1
    'epochs': 30,  # Full training
    'learning_rate': 0.002,
    'weight_decay': 1e-4,
    'scheduler_patience': 3,
    'scheduler_factor': 0.5,
    'min_lr': 1e-6,
    
    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,  # epoch window start passed to mne.Epochs
    'tmax': 4.0,  # epoch window end passed to mne.Epochs
    # presumably (tmax - tmin) * sfreq + 1 samples per epoch — TODO confirm against the loaded data
    'n_timepoints': 513,
    'hidden_dim': 128,
    'mi_runs': [7, 8, 11, 12],  # run numbers used to build the "R{run:02d}" file names
    
    # Gating parameters
    'gating': {
        'gate_init': 0.9,   # initial gate value for the adaptive-gating model
        'l1_lambda': 1e-3,  # weight of the gate penalty added to the training loss
    },
    
    # Channel selection k-values
    'k_values': [10, 15, 20, 25, 30],
    
    # Memory optimization flags
    'save_models': False,  # DO NOT SAVE MODELS
    'aggressive_cleanup': True,
}

# Make sure the output directories exist before anything is written.
os.makedirs(CONFIG['results_dir'], exist_ok=True)
os.makedirs(CONFIG['plots_dir'], exist_ok=True)

# Seed NumPy and PyTorch; on CUDA also request deterministic cuDNN kernels.
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Echo the effective settings.
print(f"Device: {CONFIG['device']}")
print(f"Batch size: {CONFIG['batch_size']} (memory optimized)")
print(f"Epochs: {CONFIG['epochs']} (NO early stopping)")
print(f"Save models: {CONFIG['save_models']} (memory saving)")
print(f"K-values: {CONFIG['k_values']}")

## 2. Memory Cleanup Utilities

In [None]:
def cleanup_memory():
    """Run the garbage collector and, when CUDA is present, free cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Return cached blocks to the driver, then wait for pending GPU work.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def print_memory_usage():
    """Report allocated and reserved CUDA memory in GiB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    bytes_per_gib = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_gib
    reserved = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"GPU Memory: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")

# Confirm the helpers exist and print the GPU memory baseline (no output without CUDA).
print("Memory utilities defined!")
print_memory_usage()

## 3. Data Loading

In [None]:
def load_physionet_data(data_path):
    """Load preprocessed PhysioNet data - MATCHING PIPELINE 1.

    Walks every sub-directory of ``data_path`` (or of ``data_path/preprocessed``
    when that folder exists) whose name starts with "S", reads the runs listed
    in ``CONFIG['mi_runs']`` from ``<subject>R<run>_preproc_raw.fif`` files,
    epochs them between ``CONFIG['tmin']`` and ``CONFIG['tmax']`` around the
    ``T1``/``T2`` annotations and concatenates all epochs.

    Loading is best-effort: a run file that is missing is skipped silently,
    and a run file that fails to load or epoch is skipped but reported at the
    end, so a broken dataset does not silently shrink.

    Args:
        data_path: Directory containing the per-subject folders.

    Returns:
        Tuple ``(X, y, subjects)``: the concatenated epoch data as returned by
        ``epochs.get_data()``, the labels (``T1`` -> 0, ``T2`` -> 1) and the
        numeric subject id of every trial.

    Raises:
        FileNotFoundError: If ``data_path`` is not an existing directory.
        ValueError: If no trial could be loaded at all.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    tmin, tmax = CONFIG['tmin'], CONFIG['tmax']
    mi_runs = CONFIG['mi_runs']
    event_id = {'T1': 1, 'T2': 2}
    label_map = {1: 0, 2: 1}

    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        data_root = preprocessed_dir
        print(f"Using preprocessed data from: {data_root}")
    else:
        print(f"Using data from: {data_root}")

    subject_dirs = [d for d in sorted(os.listdir(data_root))
                    if os.path.isdir(os.path.join(data_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []
    failed_runs = []  # (file name, error) for runs that raised while loading
    print(f"Loading data from {len(subject_dirs)} subjects...")

    for subject_dir in subject_dirs:
        # A directory such as "scripts" also starts with "S"; skip it instead
        # of crashing on int().
        suffix = subject_dir[1:]
        try:
            subject_num = int(suffix) if suffix else -1
        except ValueError:
            print(f"  Skipping directory without a numeric subject id: {subject_dir}")
            continue
        subject_path = os.path.join(data_root, subject_dir)

        for run_id in mi_runs:
            run_file = f"{subject_dir}R{run_id:02d}_preproc_raw.fif"
            run_path = os.path.join(subject_path, run_file)

            if not os.path.exists(run_path):
                continue

            try:
                raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue

                events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False)
                if len(events) == 0:
                    continue

                epochs = mne.Epochs(
                    raw, events, event_id=event_id,
                    tmin=tmin, tmax=tmax,
                    baseline=None,
                    preload=True,
                    picks=picks,
                    verbose=False
                )

                data = epochs.get_data()
                # Column 2 of the events array holds the event code.
                labels = np.array(
                    [label_map.get(code, -1) for code in epochs.events[:, 2]],
                    dtype=np.int64,
                )
                valid = labels >= 0

                if np.any(valid):
                    all_X.append(data[valid])
                    all_y.append(labels[valid])
                    all_subjects.append(np.full(np.sum(valid), subject_num))

            except Exception as e:
                # Best-effort loader: keep going, but remember what was dropped.
                failed_runs.append((run_file, e))
                continue

    if failed_runs:
        print(f"Skipped {len(failed_runs)} run file(s) that could not be processed:")
        for run_file, err in failed_runs[:10]:
            print(f"  {run_file}: {type(err).__name__}: {err}")
        if len(failed_runs) > 10:
            print(f"  ... and {len(failed_runs) - 10} more")

    if len(all_X) == 0:
        raise ValueError("No data loaded!")

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subjects = np.concatenate(all_subjects, axis=0)

    print(f"\nData loaded: {len(X)} trials from {len(np.unique(subjects))} subjects")
    print(f"Shape: {X.shape}, Labels: {np.bincount(y)}")

    return X, y, subjects


class EEGDataset(Dataset):
    """In-memory dataset holding trials as float32 tensors and labels as int64 tensors."""

    def __init__(self, X, y):
        # torch.tensor copies the input, so the dataset owns its data.
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X, self.y = features, targets

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

## 4. Model Architectures

In [None]:
# Graph convolution over the electrode axis with a learnable adjacency matrix
class GraphConvLayer(nn.Module):
    """Graph convolution whose adjacency matrix is a trainable parameter.

    ``forward`` unpacks its input as ``(B, H, C, T)`` and returns a tensor of
    the same shape. The adjacency is squashed with a sigmoid, symmetrised,
    given self-loops and normalised as ``D^-1/2 (A + I) D^-1/2``.
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _symmetric_adjacency(self):
        """Return the sigmoid-squashed adjacency averaged with its transpose."""
        squashed = torch.sigmoid(self.A)
        return 0.5 * (squashed + squashed.t())

    def forward(self, x):
        batch, hidden, n_nodes, n_steps = x.shape

        # Add self-loops, then apply symmetric degree normalisation.
        adj = self._symmetric_adjacency()
        adj_loops = adj + torch.eye(n_nodes, device=adj.device)
        inv_sqrt_degree = torch.pow(adj_loops.sum(1).clamp_min(1e-6), -0.5)
        degree = torch.diag(inv_sqrt_degree)
        adj_norm = degree @ adj_loops @ degree

        # Treat every (sample, time step) pair as one graph of shape (C, H).
        nodes = x.permute(0, 3, 2, 1).contiguous().view(batch * n_steps, n_nodes, hidden)
        mixed = self.theta(adj_norm @ nodes)
        restored = mixed.view(batch, n_steps, n_nodes, hidden).permute(0, 3, 2, 1)
        return self.act(self.bn(restored))

    def get_adjacency(self):
        """Return the symmetrised adjacency (without self-loops) as a NumPy array."""
        with torch.no_grad():
            return self._symmetric_adjacency().cpu().numpy()


class TemporalConv(nn.Module):
    """Convolution along the last axis, then BatchNorm, ELU and optional 2x average pooling.

    The kernel is ``(1, kernel_size)`` with ``kernel_size // 2`` padding on the
    last axis, so the second-to-last axis is left untouched.
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, kernel_size // 2),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pool_layer = None

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)
        if self.pool_layer is None:
            return out
        return self.pool_layer(out)


class BaselineEEGARNN(nn.Module):
    """CNN-GCN classifier: three (temporal conv -> graph conv) stages and an MLP head.

    Args:
        n_channels: Number of EEG channels (size of the graph).
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; only used to size the first linear layer.
        hidden_dim: Feature width of every temporal/graph stage.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.use_gate_regularizer = False

        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        # Probe the flattened feature size with a dummy forward pass. The
        # probe runs in eval mode so that the all-zero dummy input does not
        # update the BatchNorm running statistics before real training starts.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, n_channels, n_timepoints)
            feat = self._forward_features(self._prepare_input(dummy))
            self.feature_dim = feat.reshape(1, -1).size(1)
        self.train(was_training)

        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, n_classes)

    def _prepare_input(self, x):
        """Insert a singleton feature axis at dim 1 when given a 3-D tensor."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return x

    def _forward_features(self, x):
        """Run the three temporal/graph stages and return the feature map."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        """Return class logits for a batch of trials."""
        prepared = self._prepare_input(x)
        features = self._forward_features(prepared)
        # reshape (not view) so flattening works regardless of memory layout.
        x = features.reshape(features.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def get_final_adjacency(self):
        """Return the learned adjacency of the last graph layer as a NumPy array."""
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        """Score each channel by the row sum of absolute final-adjacency weights."""
        adjacency = self.get_final_adjacency()
        return np.sum(np.abs(adjacency), axis=1)


class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """Baseline network preceded by a per-sample, per-channel multiplicative gate.

    The gate network maps each channel's mean and standard deviation over the
    last axis to one sigmoid value per channel. ``gate_init`` sets the bias of
    the last linear layer to the logit of the desired starting gate value.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True

        gate_layers = [
            nn.Linear(n_channels * 2, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid(),
        ]
        self.gate_net = nn.Sequential(*gate_layers)

        # Keep the target value strictly inside (0, 1) so the logit is finite.
        start_value = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        start_logit = math.log(start_value / (1.0 - start_value))
        with torch.no_grad():
            # Index 2 is the linear layer directly in front of the sigmoid.
            self.gate_net[2].bias.fill_(start_logit)

        self.latest_gate_values = None
        self.gate_penalty_tensor = None

    def compute_gates(self, x):
        """Return gate values computed from per-channel mean and std of ``x``."""
        signal = x.squeeze(1)
        summary = torch.cat([signal.mean(dim=2), signal.std(dim=2)], dim=1)
        return self.gate_net(summary)

    def forward(self, x):
        """Gate the input channels, then classify as the baseline does."""
        prepared = self._prepare_input(x)
        gates = self.compute_gates(prepared)
        # Graph-attached copy for the training penalty, detached copy for inspection.
        self.gate_penalty_tensor = gates
        self.latest_gate_values = gates.detach()
        n_samples, n_gates = gates.size(0), gates.size(1)
        gated = prepared * gates.view(n_samples, 1, n_gates, 1)
        features = self._forward_features(gated)
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

    def get_channel_importance_gate(self):
        """Return the batch-mean of the most recent gate values, or None before any forward pass."""
        latest = self.latest_gate_values
        if latest is None:
            return None
        return latest.mean(dim=0).cpu().numpy()


# Progress marker for the notebook cell.
print("Models defined!")

## 5. Training Utilities

In [None]:
def calculate_comprehensive_metrics(model, dataloader, device):
    """Evaluate a binary classifier and return a dictionary of metrics.

    Args:
        model: Module returning logits with two columns.
        dataloader: Iterable of ``(X_batch, y_batch)`` pairs.
        device: Device the input batches are moved to.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1_score``,
        ``auc_roc``, ``specificity`` and ``sensitivity``. ``auc_roc`` is 0.0
        when the labels do not contain both classes; ratios with a zero
        denominator are reported as 0.0.
    """
    model.eval()
    pred_chunks, label_chunks, prob_chunks = [], [], []

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch = X_batch.to(device)
            outputs = model(X_batch)
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            pred_chunks.append(predicted.cpu().numpy())
            label_chunks.append(y_batch.cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    # One concatenation per array instead of element-wise list growth.
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
    all_labels = np.concatenate(label_chunks) if label_chunks else np.array([])
    all_probs = np.concatenate(prob_chunks) if prob_chunks else np.array([])

    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='binary', zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='binary', zero_division=0),
        'f1_score': f1_score(all_labels, all_preds, average='binary', zero_division=0),
        'auc_roc': roc_auc_score(all_labels, all_probs[:, 1]) if len(np.unique(all_labels)) == 2 else 0.0,
    }

    # Fix the label set so the matrix is always 2x2. Without it, a split in
    # which only one class occurs yields a 1x1 matrix and both rates would be
    # reported as 0.0 even for perfect predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    return metrics


def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Train for one pass over ``dataloader``.

    When ``l1_lambda`` is positive and the model exposes a non-None
    ``gate_penalty_tensor``, the mean absolute gate value scaled by
    ``l1_lambda`` is added to the loss.

    Returns:
        ``(mean loss per batch, fraction of correctly classified samples)``.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)

        # The gate tensor is refreshed by the model's forward pass above.
        gates = getattr(model, 'gate_penalty_tensor', None)
        if gates is not None and l1_lambda > 0:
            loss = loss + l1_lambda * gates.abs().mean()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = torch.max(logits, 1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    return running_loss / len(dataloader), n_correct / n_seen


def evaluate_epoch(model, dataloader, criterion, device):
    """Evaluate without gradient tracking.

    Returns:
        ``(mean loss per batch, fraction of correctly classified samples)``.
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            running_loss += criterion(logits, targets).item()
            predicted = torch.max(logits, 1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    return running_loss / len(dataloader), n_correct / n_seen


def _cpu_state_snapshot(model):
    """Return a detached CPU copy of ``model.state_dict()`` (metadata preserved)."""
    state = model.state_dict()
    snapshot = type(state)(
        (key, value.detach().to('cpu', copy=True) if torch.is_tensor(value) else deepcopy(value))
        for key, value in state.items()
    )
    metadata = getattr(state, '_metadata', None)
    if metadata is not None:
        snapshot._metadata = deepcopy(metadata)
    return snapshot


def train_model_full(model, train_loader, val_loader, config, model_name=''):
    """Train for ``config['epochs']`` epochs and keep the best-validation weights.

    There is no early stopping: every epoch runs, and the weights of the epoch
    with the highest validation accuracy are loaded back into ``model`` at the
    end. The learning rate is reduced on validation-loss plateaus.

    Args:
        model: Network to train; moved to ``config['device']``.
        train_loader: Training batches.
        val_loader: Validation batches.
        config: Dict providing ``device``, ``learning_rate``, ``weight_decay``,
            ``scheduler_factor``, ``scheduler_patience``, ``min_lr``,
            ``epochs`` and ``gating['l1_lambda']``.
        model_name: Label used in progress messages.

    Returns:
        ``(best_state, best_val_acc, history)`` where ``best_state`` is a
        state dict held on CPU and ``history`` has per-epoch ``train_acc``
        and ``val_acc`` lists.
    """
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    # `verbose` is intentionally not passed: False was its default and the
    # argument is deprecated in recent PyTorch releases.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config['scheduler_factor'],
        patience=config['scheduler_patience'], min_lr=config['min_lr']
    )

    l1_lambda = config['gating']['l1_lambda'] if getattr(model, 'use_gate_regularizer', False) else 0.0
    # Snapshots live on CPU so the best weights do not occupy a second copy
    # of GPU memory for the whole training run.
    best_state = _cpu_state_snapshot(model)
    best_val_acc = 0.0

    # Minimal history tracking
    history = {'train_acc': [], 'val_acc': []}

    print(f"[{model_name}] Training {config['epochs']} epochs")

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        if val_acc > best_val_acc:
            best_state = _cpu_state_snapshot(model)
            best_val_acc = val_acc

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"  Epoch {epoch+1}/{config['epochs']} - Val Acc: {val_acc:.4f} | Best: {best_val_acc:.4f}")

    # load_state_dict copies the CPU tensors onto the model's device.
    model.load_state_dict(best_state)
    return best_state, best_val_acc, history


# Progress marker for the notebook cell.
print("Training utilities defined!")

## 6. Channel Selection Utilities

In [None]:
def get_channel_importance_aggregation(model, dataloader, device):
    """Score channels by the mean absolute activation of the model's feature map.

    Features come from ``model._forward_features(model._prepare_input(batch))``
    and are averaged over dims 1 and 3 per sample, then over all samples.
    Returns zeros of length ``model.n_channels`` when the loader yields nothing.
    """
    model.eval()
    per_batch = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            feats = model._forward_features(model._prepare_input(inputs.to(device)))
            per_batch.append(feats.abs().mean(dim=(1, 3)).cpu())
    if len(per_batch) == 0:
        return np.zeros(model.n_channels)
    return torch.cat(per_batch, dim=0).mean(dim=0).numpy()


def compute_gate_importance(model, dataloader, device):
    """Average the model's gate values over every sample in ``dataloader``.

    After each forward pass the model's ``latest_gate_values`` attribute is
    collected (when present and not None). If nothing was collected, a uniform
    vector ``1 / n_channels`` is returned instead.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            model(inputs.to(device))
            gate_values = getattr(model, 'latest_gate_values', None)
            if gate_values is None:
                continue
            collected.append(gate_values.cpu())
    if len(collected) == 0:
        return np.ones(model.n_channels) / model.n_channels
    return torch.cat(collected, dim=0).mean(dim=0).numpy()


def select_top_k_channels(importance_scores, k):
    """Return the indices of the ``k`` highest-scoring channels in ascending index order.

    Args:
        importance_scores: 1-D sequence of per-channel scores.
        k: Number of channels to keep. Values larger than the number of
            channels select all of them; ``k <= 0`` selects none.

    Returns:
        Sorted list of channel indices.
    """
    # Guard first: the slice [-0:] would otherwise return every channel.
    if k <= 0:
        return []
    top_k_indices = np.argsort(importance_scores)[-k:]
    return sorted(top_k_indices)


def apply_channel_selection(X, selected_channels):
    """Keep only ``selected_channels`` along axis 1 of a 3-D array."""
    channel_index = (slice(None), selected_channels, slice(None))
    return X[channel_index]


# Progress marker for the notebook cell.
print("Channel selection utilities defined!")

## 7. Load Data

In [None]:
print("="*80)
print("LOADING DATA")
print("="*80)

# Load every trial into memory once; X, y and subjects are reused by all
# later cells as module-level globals.
X, y, subjects = load_physionet_data(CONFIG['data_path'])
# Free temporaries created while loading and report the memory baseline.
cleanup_memory()
print_memory_usage()

print("\nData ready!")

## 8. Train Initial Models (NO CHECKPOINTS)

In [None]:
# Shuffled stratified K-fold splitter with a fixed seed; the same object is
# reused for the initial training and for the channel-selection experiments.
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])

# Architectures trained on all channels: display name plus class to instantiate.
models_to_train = [
    {'name': 'Baseline-EEG-ARNN', 'class': BaselineEEGARNN},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'class': AdaptiveGatingEEGARNN},
]

print("\n" + "="*80)
print("TRAINING INITIAL MODELS (NO CHECKPOINTS SAVED)")
print("="*80 + "\n")

In [None]:
# Channel-importance scores per model and fold. Only these small arrays are
# kept after each fold; the trained models themselves are deleted.
channel_importance_store = {}

# Per-fold metric rows and accuracy histories, keyed by model name.
all_results = {}
all_histories = {}

for model_info in models_to_train:
    model_name = model_info['name']
    model_class = model_info['class']

    print(f"\n{'='*80}")
    print(f"Training: {model_name}")
    print(f"{'='*80}\n")

    fold_results = []
    fold_histories = []
    channel_importance_store[model_name] = {}

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"Fold {fold + 1}/{CONFIG['n_folds']}")
        cleanup_memory()

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        train_dataset = EEGDataset(X_train, y_train)
        val_dataset = EEGDataset(X_val, y_val)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

        # Constructor arguments shared by both architectures; only the gating
        # model additionally takes the initial gate value.
        model_kwargs = {
            'n_channels': CONFIG['n_channels'],
            'n_classes': CONFIG['n_classes'],
            'n_timepoints': CONFIG['n_timepoints'],
            'hidden_dim': CONFIG['hidden_dim'],
        }
        if model_class == AdaptiveGatingEEGARNN:
            model_kwargs['gate_init'] = CONFIG['gating']['gate_init']
        model = model_class(**model_kwargs)

        best_state, val_acc, history = train_model_full(
            model, train_loader, val_loader, CONFIG, f"{model_name}-F{fold+1}"
        )

        # Evaluate the best-validation weights.
        model.load_state_dict(best_state)
        model = model.to(CONFIG['device'])
        metrics = calculate_comprehensive_metrics(model, val_loader, CONFIG['device'])

        print(f"Results: Acc={metrics['accuracy']:.4f}, F1={metrics['f1_score']:.4f}")

        # Store channel importance for later use (NOT full model)
        importance_info = {
            'adjacency': model.get_final_adjacency(),
            'edge_importance': model.get_channel_importance_edge(),
        }

        if hasattr(model, 'get_channel_importance_gate'):
            # Gate importance is averaged over the training split only.
            gate_imp = compute_gate_importance(model, train_loader, CONFIG['device'])
            importance_info['gate_importance'] = gate_imp

        # Activation-based importance, also computed on the training split.
        agg_imp = get_channel_importance_aggregation(model, train_loader, CONFIG['device'])
        importance_info['aggregation_importance'] = agg_imp

        channel_importance_store[model_name][fold] = importance_info

        fold_results.append({
            'fold': fold + 1,
            'accuracy': metrics['accuracy'],
            'precision': metrics['precision'],
            'recall': metrics['recall'],
            'f1_score': metrics['f1_score'],
            'auc_roc': metrics['auc_roc'],
            'specificity': metrics['specificity'],
            'sensitivity': metrics['sensitivity']
        })
        fold_histories.append(history)

        # AGGRESSIVE CLEANUP
        del model, best_state, train_loader, val_loader, train_dataset, val_dataset
        cleanup_memory()

    all_results[model_name] = fold_results
    all_histories[model_name] = fold_histories

    df_temp = pd.DataFrame(fold_results)
    print(f"\n{model_name} Summary:")
    # "±" was previously mis-encoded as "Â±" in these messages.
    print(f"  Accuracy: {df_temp['accuracy'].mean():.4f} ± {df_temp['accuracy'].std():.4f}")
    print(f"  F1-Score: {df_temp['f1_score'].mean():.4f} ± {df_temp['f1_score'].std():.4f}")
    cleanup_memory()

print(f"\n{'='*80}")
print("INITIAL TRAINING COMPLETE!")
print("="*80)
print_memory_usage()

## 9. Save Initial Results

In [None]:
# Write one CSV of per-fold metrics per model. The file name is derived from
# the model name: lower-cased, with "-" and spaces replaced by "_".
for model_name, fold_results in all_results.items():
    df = pd.DataFrame(fold_results)
    df['model'] = model_name
    # Fixed column order with the model name first.
    cols = ['model', 'fold', 'accuracy', 'precision', 'recall', 'f1_score', 
            'auc_roc', 'specificity', 'sensitivity']
    df = df[cols]
    filename = model_name.lower().replace('-', '_').replace(' ', '_')
    filepath = os.path.join(CONFIG['results_dir'], f'eegarnn_{filename}_results.csv')
    df.to_csv(filepath, index=False)
    print(f"Saved: {filepath}")

# Summary: mean and std of every metric across folds, one row per model.
# pandas' .std() is the sample standard deviation (ddof=1).
summary_data = []
for model_name, fold_results in all_results.items():
    df_temp = pd.DataFrame(fold_results)
    summary = {'model': model_name}
    for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']:
        summary[f'mean_{metric}'] = df_temp[metric].mean()
        summary[f'std_{metric}'] = df_temp[metric].std()
    summary_data.append(summary)

summary_df = pd.DataFrame(summary_data)
filepath = os.path.join(CONFIG['results_dir'], 'eegarnn_initial_summary.csv')
summary_df.to_csv(filepath, index=False)
print(f"Saved: {filepath}")

# Histories: pickle the per-fold train/val accuracy curves of every model.
filepath = os.path.join(CONFIG['results_dir'], 'training_histories.pkl')
with open(filepath, 'wb') as f:
    pickle.dump(all_histories, f)
print(f"Saved: {filepath}")

print("\nInitial results summary:")
print(summary_df[['model', 'mean_accuracy', 'mean_f1_score']].to_string(index=False))

## 10. Channel Selection Experiments

In [None]:
# Channel-selection experiments: which importance methods to evaluate for
# each model. 'gate' is listed only for the adaptive-gating model.
cs_experiments = [
    {'model': 'Baseline-EEG-ARNN', 'methods': ['edge', 'aggregation']},
    {'model': 'Adaptive-Gating-EEG-ARNN', 'methods': ['edge', 'aggregation', 'gate']},
]

print("\n" + "="*80)
print("CHANNEL SELECTION")
print("="*80)
print(f"k-values: {CONFIG['k_values']}")
print("="*80 + "\n")

In [None]:
# One row per (model, method, k) with fold-averaged metrics.
channel_selection_results = []
metric_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']

for exp in cs_experiments:
    model_name = exp['model']
    methods = exp['methods']
    model_class = BaselineEEGARNN if 'Baseline' in model_name else AdaptiveGatingEEGARNN

    print(f"\n{model_name}")

    for method in methods:
        print(f"\n  {method.upper()}:", end='')

        for k in CONFIG['k_values']:
            fold_metrics_list = []

            # Fold numbers index into channel_importance_store, which was
            # filled from splits produced by this same splitter.
            for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
                cleanup_memory()

                X_train, X_val = X[train_idx], X[val_idx]
                y_train, y_val = y[train_idx], y[val_idx]

                # Importance scores stored for this fold under
                # 'edge_importance', 'aggregation_importance' or 'gate_importance'.
                importance_scores = channel_importance_store[model_name][fold][f'{method}_importance']

                # Select channels
                selected_channels = select_top_k_channels(importance_scores, k)
                X_train_selected = apply_channel_selection(X_train, selected_channels)
                X_val_selected = apply_channel_selection(X_val, selected_channels)

                # Train a fresh model sized to the channels actually selected
                # (equal to k unless k exceeds the number of available channels).
                model_kwargs = {
                    'n_channels': len(selected_channels),
                    'n_classes': CONFIG['n_classes'],
                    'n_timepoints': CONFIG['n_timepoints'],
                    'hidden_dim': CONFIG['hidden_dim'],
                }
                if model_class == AdaptiveGatingEEGARNN:
                    model_kwargs['gate_init'] = CONFIG['gating']['gate_init']
                new_model = model_class(**model_kwargs)

                train_dataset = EEGDataset(X_train_selected, y_train)
                val_dataset = EEGDataset(X_val_selected, y_val)
                train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
                val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

                best_state, val_acc, _ = train_model_full(
                    new_model, train_loader, val_loader, CONFIG, f"{model_name}-{method}-k{k}-F{fold+1}"
                )

                new_model.load_state_dict(best_state)
                new_model = new_model.to(CONFIG['device'])
                metrics = calculate_comprehensive_metrics(new_model, val_loader, CONFIG['device'])
                fold_metrics_list.append(metrics)

                del new_model, best_state, train_loader, val_loader, train_dataset, val_dataset
                cleanup_memory()

            # Mean and sample standard deviation across folds. ddof=1 matches
            # pandas' .std(), which the initial-results summary uses, so the
            # two result tables are directly comparable.
            mean_metrics = {}
            for metric_name in metric_names:
                values = [m[metric_name] for m in fold_metrics_list]
                mean_metrics[f'mean_{metric_name}'] = np.mean(values)
                mean_metrics[f'std_{metric_name}'] = np.std(values, ddof=1)

            print(f" k={k}:{mean_metrics['mean_accuracy']:.3f}", end='')

            result = {'model': model_name, 'method': method, 'k': k}
            result.update(mean_metrics)
            channel_selection_results.append(result)

        print()  # newline

print(f"\n{'='*80}")
print("CHANNEL SELECTION COMPLETE!")
print("="*80)
print_memory_usage()

## 11. Save Channel Selection Results

In [None]:
# Build the results table with the identifying columns first.
cs_df = pd.DataFrame(channel_selection_results)
cols = ['model', 'method', 'k'] + [col for col in cs_df.columns if col not in ['model', 'method', 'k']]
cs_df = cs_df[cols]

filepath = os.path.join(CONFIG['results_dir'], 'channel_selection_results.csv')
cs_df.to_csv(filepath, index=False)
print(f"Saved: {filepath}")

# For each model, report the (method, k) row with the highest mean accuracy.
print("\nBest results:")
for model_name in ['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']:
    model_data = cs_df[cs_df['model'] == model_name]
    best_row = model_data.loc[model_data['mean_accuracy'].idxmax()]
    print(f"{model_name}: {best_row['method'].upper()} k={int(best_row['k'])} "
          f"Acc={best_row['mean_accuracy']:.4f}")

## 12. Final Summary

In [None]:
# Closing banner for the notebook run.
print("\n" + "="*80)
print("PIPELINE 2 COMPLETE!")
print("="*80)

print("\nResults saved to results/")
print("Ready for comparison with Pipeline 1!")
print("="*80)