# PhysioNet EEG: Channel Selection Evaluation

This notebook evaluates three channel selection methods:
1. **Edge Selection (ES)**: Based on sum of outgoing edge weights from CARM adjacency matrix
2. **Aggregation Selection (AS)**: Based on feature aggregation after CARM layers
3. **Gate Selection (GS)**: Based on adaptive gate values (only for Adaptive-Gating-EEG-ARNN)

**Expected Runtime**: 8-10 hours on Kaggle GPU

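**Input**:
- Preprocessed EEG data in the `derived/` folder of the Kaggle input dataset, set via `CONFIG['data_path']` (default `/kaggle/input/eeg-preprocessed-data/derived`)
- `models/` folder with the per-fold checkpoints trained in notebook 01 (`<model-name>_fold<n>.pt`)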

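**Output** (written to `CONFIG['results_dir']`, default `./results/`):
- `channel_selection_results.csv` - Mean/std accuracy for every model, selection method and k value
- `retention_analysis.csv` - Accuracy vs. number of channels retained (Adaptive-Gating-EEG-ARNN with Gate Selection)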

## Configuration and Imports

In [None]:
import os
import math
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
import pickle
from tqdm.auto import tqdm
import warnings
# Silence every Python warning for the rest of the notebook session.
# NOTE(review): this also hides deprecation/FutureWarnings from torch and mne
# (e.g. API changes in torch.load or LR schedulers); consider a narrower filter.
warnings.filterwarnings('ignore')

# Only let MNE print errors, keeping the per-file loading output readable.
mne.set_log_level('ERROR')


In [None]:
# Configuration
CONFIG = {
    # NOTE: the notebook header mentions /kaggle/input/physionet-preprocessed/derived;
    # point this at whichever Kaggle dataset actually holds the preprocessed data.
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',  # Change this for local testing
    'output_dir': './',
    'results_dir': './results',
    'models_dir': './models',
    'figures_dir': './figures',

    'n_folds': 3,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Training hyperparameters
    'batch_size': 64,
    'epochs': 20,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'patience': 10,
    'scheduler_patience': 3,
    'use_early_stopping': False,

    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'n_timepoints': 513,  # 4 seconds at 128 Hz + 1
    'hidden_dim': 128,
    'mi_runs': [7, 8, 11, 12],

    # Channel selection: numbers of channels (k) to retain for each selection method.
    # The evaluation loop reads this key; every value must lie in 1..n_channels.
    'k_values': [10, 20, 30, 40],

    # FBCSP parameters
    'fbcsp_bands': [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40)],
    'fbcsp_n_components': 4,

    # Gating regularization
    'gating': {
        'gate_init': 0.9,
        'l1_lambda': 1e-3,
    },
}

# Fail early (instead of hours into the run) if a channel budget is impossible.
invalid_k = [k for k in CONFIG['k_values'] if not 0 < k <= CONFIG['n_channels']]
if invalid_k:
    raise ValueError(f"k_values must lie in 1..{CONFIG['n_channels']}, got {invalid_k}")

# Create output directories
os.makedirs(CONFIG['results_dir'], exist_ok=True)
os.makedirs(CONFIG['models_dir'], exist_ok=True)
os.makedirs(CONFIG['figures_dir'], exist_ok=True)

# Set random seeds
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])

print(f"Device: {CONFIG['device']}")
print(f"Data path: {CONFIG['data_path']}")


## Copy Model Architectures and Data Loading from Notebook 01

In [None]:
def load_physionet_data(data_path, subject_ids=None):
    """
    Load preprocessed PhysioNet motor imagery data from the derived folder.
    Supports both the newer folder structure (derived/preprocessed/S***/S***R**_preproc_raw.fif)
    and the legacy flat directory containing epoch files.

    The epoch window (``tmin``/``tmax``) and the runs to read (``mi_runs``) come
    from the notebook-global ``CONFIG`` dict when it exists, otherwise from the
    defaults below. Only T1/T2 events are kept, relabelled to 0/1 respectively.

    Args:
        data_path: Root of the derived data. If it contains a ``preprocessed``
            subfolder, subject folders (names starting with 'S') are searched
            there; otherwise directly under ``data_path``. If no raw runs are
            loaded that way, ``.fif`` epoch files directly under ``data_path``
            are tried instead.
        subject_ids: Optional iterable of subject IDs (ints, or strings such as
            ``"S001"``/``"1"``). IDs that cannot be parsed are ignored; if none
            parse, every subject is loaded.

    Returns:
        ``(X, y, subjects)``: trial data concatenated along the first axis (as
        returned by MNE ``get_data``), integer labels (0 = T1, 1 = T2), and the
        numeric subject ID of each trial (-1 when it could not be parsed).

    Raises:
        FileNotFoundError: If ``data_path`` is not a directory.
        ValueError: If no usable files or trials are found.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    # Epoch window and run list: notebook-global CONFIG when present, else defaults.
    config = globals().get('CONFIG', {})
    tmin = float(config.get('tmin', 0.0))
    tmax = float(config.get('tmax', 4.0))
    mi_runs = [int(r) for r in config.get('mi_runs', [7, 8, 11, 12])]
    # Annotation descriptions to turn into events in the raw-run path.
    event_id = {'T1': 1, 'T2': 2}

    def normalize_subject(value):
        # Parse 'S001', 's1', '1' or 1 into an int subject ID; None if unparseable.
        if value is None:
            return None
        if isinstance(value, str) and value.upper().startswith('S'):
            value = value[1:]
        try:
            return int(value)
        except Exception:
            return None

    # Build the set of requested subject IDs; an empty/unparseable request means "all".
    subject_filter = None
    if subject_ids is not None:
        subject_filter = set()
        for sid in subject_ids:
            norm = normalize_subject(sid)
            if norm is not None:
                subject_filter.add(norm)
        if not subject_filter:
            subject_filter = None

    def aggregate_results(blocks_X, blocks_y, blocks_subjects):
        # Concatenate the per-run/per-file blocks along the trial axis and print a summary.
        X = np.concatenate(blocks_X, axis=0)
        y = np.concatenate(blocks_y, axis=0)
        subjects = np.concatenate(blocks_subjects, axis=0)
        print(f"Loaded {len(X)} trials from {len(np.unique(subjects))} subjects")
        print(f"Data shape: {X.shape}")
        print(f"Label distribution: {np.bincount(y)}")
        return X, y, subjects

    # Prefer a 'preprocessed' subfolder when it exists.
    subject_root = data_root
    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        subject_root = preprocessed_dir
    subject_dirs = [d for d in sorted(os.listdir(subject_root))
                    if os.path.isdir(os.path.join(subject_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []
    if subject_dirs:
        print(f"Detected {len(subject_dirs)} preprocessed subject folders under {subject_root}")
        # Event codes from event_id -> class labels (T1 -> 0, T2 -> 1).
        label_map = {event_id['T1']: 0, event_id['T2']: 1}
        for subject_dir in subject_dirs:
            subject_numeric = normalize_subject(subject_dir)
            if subject_filter and subject_numeric not in subject_filter:
                continue
            subject_path = os.path.join(subject_root, subject_dir)
            for run_id in mi_runs:
                # Several file-naming conventions are accepted; the first that exists wins.
                candidate_names = [
                    f"{subject_dir}R{run_id:02d}_preproc_raw.fif",
                    f"{subject_dir}R{run_id:02d}_raw.fif",
                    f"{subject_dir}R{run_id:02d}.fif",
                    f"{subject_dir}_R{run_id:02d}.fif",
                ]
                run_path = None
                for name in candidate_names:
                    candidate = os.path.join(subject_path, name)
                    if os.path.exists(candidate):
                        run_path = candidate
                        break
                if run_path is None:
                    continue
                # Each failure below skips just this run (best-effort loading).
                try:
                    raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                except Exception as e:
                    print(f"Error loading {run_path}: {e}")
                    continue
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue
                try:
                    events, _ = mne.events_from_annotations(raw, event_id=event_id)
                except Exception as e:
                    print(f"Error parsing annotations for {run_path}: {e}")
                    continue
                if len(events) == 0:
                    continue
                try:
                    # baseline=None: no baseline correction; window is [tmin, tmax] seconds.
                    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                                        baseline=None, preload=True, picks=picks, verbose=False)
                except Exception as e:
                    print(f"Error epoching {run_path}: {e}")
                    continue
                data = epochs.get_data()
                labels = epochs.events[:, 2]
                # Unknown event codes map to -1 and are dropped.
                mapped = np.array([label_map.get(lbl, -1) for lbl in labels])
                valid_mask = mapped >= 0
                if not np.any(valid_mask):
                    continue
                all_X.append(data[valid_mask])
                all_y.append(mapped[valid_mask])
                subj_label = subject_numeric if subject_numeric is not None else -1
                all_subjects.append(np.full(np.sum(valid_mask), subj_label))
        if all_X:
            return aggregate_results(all_X, all_y, all_subjects)
        print("No data loaded from preprocessed folders, falling back to legacy format...")

    # Legacy format fallback (flat directory with epoch files)
    legacy_files = [f for f in os.listdir(data_root) if f.endswith('.fif')]
    if not legacy_files:
        raise ValueError(
            "No valid PhysioNet files found. Ensure the derived folder contains either "
            "preprocessed subject subfolders or .fif epoch files."
        )
    if subject_filter:
        # Assumes legacy file names start with the subject ID before the first '_'
        # (e.g. 'S001_...fif') — TODO confirm against the actual dataset.
        filtered = []
        for fname in legacy_files:
            parts = fname.split('_')
            # NOTE(review): str.split always returns at least one element, so this never triggers.
            if not parts:
                continue
            subj = normalize_subject(parts[0])
            if subj is not None and subj in subject_filter:
                filtered.append(fname)
        legacy_files = filtered
        if not legacy_files:
            raise ValueError("No files matched the requested subject IDs in legacy format.")

    print(f"Found {len(legacy_files)} legacy epoch files. Loading...")
    for fname in legacy_files:
        filepath = os.path.join(data_root, fname)
        try:
            epochs = mne.read_epochs(filepath, preload=True, verbose=False)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
            continue
        # Event codes can differ per file, so T1/T2 codes are looked up in each file's own event_id.
        current_event_id = epochs.event_id
        if not current_event_id:
            continue
        label_lookup = {}
        if 'T1' in current_event_id:
            label_lookup[current_event_id['T1']] = 0
        if 'T2' in current_event_id:
            label_lookup[current_event_id['T2']] = 1
        if not label_lookup:
            continue
        labels = np.array([label_lookup.get(epochs.events[i, -1], -1) for i in range(len(epochs))])
        valid = labels >= 0
        if not np.any(valid):
            continue
        data = epochs.get_data()[valid]
        labels = labels[valid]
        subj = normalize_subject(fname.split('_')[0])
        subj_arr = np.full(len(labels), subj if subj is not None else -1)
        all_X.append(data)
        all_y.append(labels)
        all_subjects.append(subj_arr)
    if not all_X:
        raise ValueError("No valid trials were loaded from the provided PhysioNet files.")
    return aggregate_results(all_X, all_y, all_subjects)


In [None]:
# EEG-ARNN Baseline + Adaptive Gating
class GraphConvLayer(nn.Module):
    """Graph convolution across EEG channels with a learnable, symmetric adjacency.

    Inputs and outputs are laid out as (batch, hidden, channels, time). At every
    time step the channel features are mixed through the symmetrically
    normalised adjacency D^-1/2 (A + I) D^-1/2, then projected by a shared
    linear map, batch-normalised and passed through ELU.
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.num_channels = num_channels
        self.hidden_dim = hidden_dim
        # Raw (pre-sigmoid) edge weights; created first so parameter init order is unchanged.
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _symmetric_weights(self):
        """Squash the raw weights into (0, 1) with a sigmoid and symmetrise them."""
        squashed = torch.sigmoid(self.A)
        return 0.5 * (squashed + squashed.t())

    def forward(self, x):
        batch, hidden, n_chan, n_time = x.shape

        # Add self-loops, then apply symmetric degree normalisation.
        adj_hat = self._symmetric_weights() + torch.eye(n_chan, device=self.A.device)
        inv_sqrt_deg = torch.diag(torch.pow(adj_hat.sum(1).clamp_min(1e-6), -0.5))
        adj_norm = inv_sqrt_deg @ adj_hat @ inv_sqrt_deg

        # (B, H, C, T) -> (B*T, C, H): one channel-by-feature matrix per time step.
        per_step = x.permute(0, 3, 2, 1).contiguous().view(batch * n_time, n_chan, hidden)
        mixed = self.theta(adj_norm @ per_step)
        mixed = mixed.view(batch, n_time, n_chan, hidden).permute(0, 3, 2, 1)
        return self.act(self.bn(mixed))

    def get_adjacency(self):
        """Return the symmetric (0, 1) adjacency, before self-loops are added, as a NumPy array."""
        with torch.no_grad():
            return self._symmetric_weights().cpu().numpy()


class TemporalConv(nn.Module):
    """Temporal (1 x kernel_size) convolution -> BatchNorm -> ELU, optionally followed by 1x2 average pooling.

    With padding ``kernel_size // 2`` and an even kernel, the time axis grows by
    one sample before pooling (e.g. 513 -> 514 for ``kernel_size=16``).
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        self.pool = pool
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, kernel_size // 2),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pool_layer = None

    def forward(self, x):
        out = self.act(self.bn(self.conv(x)))
        return out if self.pool_layer is None else self.pool_layer(out)


class BaselineEEGARNN(nn.Module):
    """EEG-ARNN baseline: three (temporal conv -> graph conv) stages plus a two-layer classifier.

    Accepts input shaped (batch, channels, time), to which ``_prepare_input``
    adds a singleton feature axis, or input already shaped (batch, 1, channels, time).
    The baseline has no gates, so ``forward`` always resets ``gate_penalty_tensor``
    and ``latest_gate_values`` to None.
    """
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.use_gate_regularizer = False
        self.gate_penalty_tensor = None
        self.latest_gate_values = None

        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        # Infer the flattened feature size with a dummy forward pass. Run it in eval
        # mode: in train mode every BatchNorm layer would fold this all-zero dummy
        # batch into its running statistics (and bump num_batches_tracked) before
        # any real data is seen. The previous mode is restored afterwards.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, n_channels, n_timepoints)
            feat = self._forward_features(self._prepare_input(dummy))
            self.feature_dim = feat.view(1, -1).size(1)
        self.train(was_training)

        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, n_classes)

    def _prepare_input(self, x):
        """Add the singleton feature axis to (batch, channels, time) input; 4-D input passes through."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return x

    def _forward_features(self, x):
        """Run the three temporal-conv/graph-conv stages."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def _forward_from_prepared(self, x):
        """Features -> flatten -> fc1/ReLU/dropout -> fc2 logits."""
        features = self._forward_features(x)
        x = features.view(features.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def forward(self, x):
        prepared = self._prepare_input(x)
        # No gates in the baseline: clear any values a previous call might have left.
        self.gate_penalty_tensor = None
        self.latest_gate_values = None
        return self._forward_from_prepared(prepared)

    def get_final_adjacency(self):
        """Symmetric adjacency of the last graph-conv layer (NumPy array)."""
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        """Edge Selection score: row sums of the final adjacency (the self-edge/diagonal weight is included)."""
        adjacency = self.get_final_adjacency()
        return np.sum(adjacency, axis=1)


class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """EEG-ARNN with a per-trial, per-channel input gate.

    A small MLP maps each channel's temporal mean and standard deviation to a
    gate in (0, 1) (sigmoid output), and the input is scaled channel-wise by
    these gates before the baseline network. The differentiable gates are kept
    in ``gate_penalty_tensor`` (for the L1 penalty) and a detached copy in
    ``latest_gate_values`` (for Gate Selection).
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True
        self.gate_net = nn.Sequential(
            nn.Linear(n_channels * 2, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid(),
        )
        # Set the output layer's bias to logit(gate_init), so the sigmoid starts
        # around gate_init (the weights still contribute). Clip to keep the log finite.
        target = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        logit = math.log(target / (1.0 - target))
        output_linear = self.gate_net[2]
        with torch.no_grad():
            output_linear.bias.fill_(logit)
        self.latest_gate_values = None

    def compute_gates(self, x):
        """Return (batch, channels) gates computed from per-channel mean/std over time."""
        signal = x.squeeze(1)
        summary = torch.cat([signal.mean(dim=2), signal.std(dim=2)], dim=1)
        return self.gate_net(summary)

    def forward(self, x):
        prepared = self._prepare_input(x)
        gates = self.compute_gates(prepared)
        self.gate_penalty_tensor = gates
        self.latest_gate_values = gates.detach()
        batch_size, n_gated = gates.size(0), gates.size(1)
        # Broadcast the gates over the feature and time axes of (B, 1, C, T).
        scaled = prepared * gates.view(batch_size, 1, n_gated, 1)
        return self._forward_from_prepared(scaled)

    def get_channel_importance_gate(self):
        """Per-channel mean gate over the most recent forward batch, or None before any forward."""
        latest = self.latest_gate_values
        if latest is None:
            return None
        return latest.mean(dim=0).cpu().numpy()


In [None]:
# Training utilities
def train_epoch(model, dataloader, criterion, optimizer, device, l1_lambda=0.0):
    """Run one optimisation pass over ``dataloader``.

    If ``l1_lambda > 0`` and the model left gate values in
    ``gate_penalty_tensor`` during its forward pass, ``l1_lambda * mean(|gates|)``
    is added to each batch loss.

    Returns:
        (mean per-batch loss including any gate penalty, accuracy over all samples).
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        gates = getattr(model, 'gate_penalty_tensor', None)
        if gates is not None and l1_lambda > 0:
            batch_loss = batch_loss + l1_lambda * gates.abs().mean()

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        _, predictions = logits.detach().max(dim=1)
        n_seen += targets.size(0)
        n_correct += (predictions == targets).sum().item()

    return running_loss / max(1, len(dataloader)), n_correct / max(1, n_seen)


def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

    denom = max(1, len(dataloader))
    return total_loss / denom, correct / max(1, total)


def train_pytorch_model(model, train_loader, val_loader, config, model_name=''):
    """Train ``model`` with Adam + ReduceLROnPlateau and restore its best epoch.

    The best epoch has the highest validation accuracy, with ties broken by lower
    validation loss. Early stopping on that criterion happens only when
    ``config['use_early_stopping']`` is true and ``config['patience']`` is set.
    Models with ``use_gate_regularizer`` get ``config['gating']['l1_lambda']``
    as the gate L1 weight in ``train_epoch``.

    NOTE(review): in this notebook ``val_loader`` is also the held-out fold whose
    accuracy is reported, so picking the best epoch on it is optimistic.

    Args:
        model: Network to train (moved to ``config['device']``).
        train_loader / val_loader: DataLoaders yielding ``(inputs, labels)``.
        config: Dict with device, learning_rate, weight_decay, epochs and the
            optional keys read below.
        model_name: Prefix for the per-epoch log lines.

    Returns:
        ``(best_state, best_val_acc)``. The model is left loaded with ``best_state``.
    """
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])
    # `verbose` is deprecated since PyTorch 2.2 and removed in later releases; it was False anyway.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 3)
    )

    uses_gates = getattr(model, 'use_gate_regularizer', False)
    l1_lambda = config.get('gating', {}).get('l1_lambda', 0.0) if uses_gates else 0.0
    use_early_stopping = config.get('use_early_stopping', False) and config.get('patience') is not None
    max_patience = config.get('patience', 0)
    patience_counter = 0
    prefix = model_name if model_name else 'Model'

    best_state = deepcopy(model.state_dict())
    best_val_acc = 0.0
    best_val_loss = float('inf')

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, l1_lambda)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        improved = val_acc > best_val_acc or (val_acc == best_val_acc and val_loss < best_val_loss)
        if improved:
            best_state = deepcopy(model.state_dict())
            best_val_acc = val_acc
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        print(f"[{prefix}] Epoch {epoch + 1}/{config['epochs']} - Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

        if use_early_stopping and patience_counter >= max_patience:
            print(f"Early stopping triggered for {prefix} at epoch {epoch + 1}")
            break

    model.load_state_dict(best_state)
    return best_state, best_val_acc


## Channel Selection Functions

In [None]:
def get_channel_importance_aggregation(model, dataloader, device):
    """Aggregation Selection (AS) using averaged feature activations.

    For every trial, the absolute output of ``model._forward_features`` (the
    final graph-conv stage) is averaged over all axes except the channel axis.
    This gives one score per channel, which is then averaged over all trials.

    Models with an input gate (a ``compute_gates`` method, e.g.
    AdaptiveGatingEEGARNN) have their inputs gated exactly as in their
    ``forward`` before feature extraction. Without this, the scores would
    describe inputs the trained network never sees.

    Returns:
        NumPy array of length ``model.n_channels``; all zeros if ``dataloader`` is empty.
    """
    model.eval()
    gate_fn = getattr(model, 'compute_gates', None)
    channel_stats = []

    with torch.no_grad():
        for X_batch, _ in dataloader:
            X_batch = X_batch.to(device)
            prepared = model._prepare_input(X_batch)
            if callable(gate_fn):
                gates = gate_fn(prepared)
                prepared = prepared * gates.view(gates.size(0), 1, gates.size(1), 1)
            features = model._forward_features(prepared)
            # features assumed (batch, hidden, channels, time) -> per-trial channel scores (batch, channels)
            activations = torch.mean(torch.abs(features), dim=(1, 3))
            channel_stats.append(activations.cpu())

    if not channel_stats:
        return np.zeros(model.n_channels)
    stacked = torch.cat(channel_stats, dim=0)
    return stacked.mean(dim=0).numpy()


def compute_gate_importance(model, dataloader, device):
    """Average adaptive gate values across the entire dataset.

    Runs a forward pass per batch and collects ``model.latest_gate_values``.
    Models that never set it (e.g. the baseline) get a uniform score of
    ``1 / n_channels`` for every channel.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for batch_inputs, _ in dataloader:
            model(batch_inputs.to(device))
            gate_values = getattr(model, 'latest_gate_values', None)
            if gate_values is None:
                continue
            collected.append(gate_values.cpu())

    if collected:
        return torch.cat(collected, dim=0).mean(dim=0).numpy()
    return np.ones(model.n_channels) / model.n_channels


def select_top_k_channels(importance_scores, k):
    """Return the indices of the ``k`` highest-scoring channels, in ascending index order.

    Ties are broken in favour of the lower channel index and NaN scores rank
    last, so the selection is deterministic.

    Args:
        importance_scores: 1-D array-like with one score per channel.
        k: Number of channels to keep; must satisfy 1 <= k <= number of channels.

    Returns:
        Sorted list of Python ints.

    Raises:
        ValueError: If the scores are not 1-D or ``k`` is out of range.
    """
    scores = np.asarray(importance_scores, dtype=float)
    if scores.ndim != 1:
        raise ValueError(f"importance_scores must be 1-D, got shape {scores.shape}")
    k = int(k)
    # Explicit check: with k == 0, argsort(...)[-0:] would silently select every channel,
    # and k > n would also return all channels and break the k-channel model later.
    if not 1 <= k <= scores.size:
        raise ValueError(f"k must be between 1 and {scores.size}, got {k}")
    # Negating sorts the largest first. The stable sort keeps lower indices first among
    # equal scores, and NaNs stay NaN, which argsort places at the end.
    order = np.argsort(-scores, kind='stable')
    return sorted(int(i) for i in order[:k])


def apply_channel_selection(X, selected_channels):
    """Keep only ``selected_channels`` along axis 1 of a 3-D (trials, channels, time) array.

    Uses fancy indexing, so for NumPy input the result is a new array, not a view.
    """
    channel_subset = X[:, selected_channels, :]
    return channel_subset


## Load Data

In [None]:
class EEGDataset(Dataset):
    """Map-style dataset pairing EEG trials with integer class labels.

    Trials are stored as float32 (the models' parameter dtype) and labels as
    int64 (what ``nn.CrossEntropyLoss`` expects).
    """

    # Defined here because the experiment cells below build DataLoaders from it.
    def __init__(self, X, y):
        self.X = torch.as_tensor(np.asarray(X, dtype=np.float32))
        self.y = torch.as_tensor(np.asarray(y, dtype=np.int64))
        if len(self.X) != len(self.y):
            raise ValueError(f"X has {len(self.X)} trials but y has {len(self.y)} labels")

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


print("Loading PhysioNet data...")
X, y, subject_labels = load_physionet_data(CONFIG['data_path'])
# float32 is what the models consume; converting once halves the memory held by X.
X = X.astype(np.float32, copy=False)
print(f"Data loaded: {X.shape}")

# The pretrained checkpoints were built for a fixed input shape. Fail fast here
# instead of hitting a size mismatch inside load_state_dict or the first forward pass.
expected_trial_shape = (CONFIG['n_channels'], CONFIG['n_timepoints'])
if X.ndim != 3 or tuple(X.shape[1:]) != expected_trial_shape:
    raise ValueError(
        f"Expected trials shaped (n_trials, {expected_trial_shape[0]}, {expected_trial_shape[1]}) "
        f"to match the pretrained models, got {X.shape}"
    )

## Channel Selection Experiments

In [None]:
# Models to evaluate, each with the channel selection methods it supports
# (Gate Selection needs adaptive gates, so only the gating model uses it).
models_to_evaluate = [
    dict(name='Baseline-EEG-ARNN', methods=['edge', 'aggregation']),
    dict(name='Adaptive-Gating-EEG-ARNN', methods=['edge', 'aggregation', 'gate']),
]

# Result accumulators, filled by the evaluation and retention cells
channel_selection_results, retention_results = [], []

# A fixed random_state makes StratifiedKFold yield the same folds on every split() call.
skf = StratifiedKFold(
    n_splits=CONFIG['n_folds'],
    shuffle=True,
    random_state=CONFIG['random_seed'],
)

In [None]:
# Main evaluation loop
base_kwargs = {
    'n_channels': CONFIG['n_channels'],
    'n_classes': CONFIG['n_classes'],
    'n_timepoints': CONFIG['n_timepoints'],
}


def build_model(model_name, n_channels):
    """Create an untrained model of the named architecture for ``n_channels`` input channels.

    Raises:
        ValueError: For an unknown ``model_name``.
    """
    kwargs = dict(base_kwargs)
    kwargs['n_channels'] = n_channels
    if model_name == 'Baseline-EEG-ARNN':
        return BaselineEEGARNN(hidden_dim=CONFIG['hidden_dim'], **kwargs)
    if model_name == 'Adaptive-Gating-EEG-ARNN':
        return AdaptiveGatingEEGARNN(hidden_dim=CONFIG['hidden_dim'],
                                     gate_init=CONFIG['gating']['gate_init'],
                                     **kwargs)
    # Never silently fall back to a different architecture.
    raise ValueError(f"Unknown model name: {model_name}")


def load_pretrained_model(model_name, fold):
    """Load checkpoint ``{model_name}_fold{fold+1}.pt`` from ``models_dir`` in eval mode.

    Assumes the checkpoints were saved by notebook 01 with the same fold order — TODO confirm.
    """
    model = build_model(model_name, CONFIG['n_channels'])
    model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
    state_dict = torch.load(model_path, map_location=CONFIG['device'])
    model.load_state_dict(state_dict)
    model = model.to(CONFIG['device'])
    model.eval()
    return model


def make_loader(X_part, y_part, shuffle):
    """Wrap arrays in an EEGDataset/DataLoader using the configured batch size."""
    return DataLoader(EEGDataset(X_part, y_part), batch_size=CONFIG['batch_size'],
                      shuffle=shuffle, num_workers=0)


def compute_channel_importance(model, method, X_train, y_train):
    """Score every channel of a pretrained ``model`` with the named selection method."""
    if method == 'edge':
        return model.get_channel_importance_edge()
    if method == 'aggregation':
        return get_channel_importance_aggregation(model, make_loader(X_train, y_train, shuffle=False),
                                                  CONFIG['device'])
    if method == 'gate':
        return compute_gate_importance(model, make_loader(X_train, y_train, shuffle=False),
                                       CONFIG['device'])
    raise ValueError(f"Unknown channel selection method: {method}")


# Fall back to a default grid if the config cell does not define k_values.
k_values = CONFIG.get('k_values', [10, 20, 30, 40])
# Materialise the folds once so every method and k value reuses identical splits.
fold_splits = list(skf.split(X, y))

for model_info in models_to_evaluate:
    model_name = model_info['name']
    selection_methods = model_info['methods']

    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}")
    print(f"{'='*60}\n")

    for method in selection_methods:
        print(f"\n{'-'*60}")
        print(f"Channel Selection Method: {method.upper()}")
        print(f"{'-'*60}\n")

        # Importance depends on (model, method, fold) but not on k, so compute it
        # once per fold instead of reloading the checkpoint for every k.
        fold_importance = []
        for fold, (train_idx, _) in enumerate(fold_splits):
            pretrained = load_pretrained_model(model_name, fold)
            fold_importance.append(
                compute_channel_importance(pretrained, method, X[train_idx], y[train_idx])
            )
            del pretrained

        for k in k_values:
            print(f"\nEvaluating with k={k} channels...")
            fold_accuracies = []
            fold_channels = []

            for fold, (train_idx, val_idx) in enumerate(fold_splits):
                print(f"Fold {fold + 1}/{CONFIG['n_folds']}")

                selected_channels = select_top_k_channels(fold_importance[fold], k)
                X_train_selected = apply_channel_selection(X[train_idx], selected_channels)
                X_val_selected = apply_channel_selection(X[val_idx], selected_channels)

                new_model = build_model(model_name, k)
                train_loader = make_loader(X_train_selected, y[train_idx], shuffle=True)
                val_loader = make_loader(X_val_selected, y[val_idx], shuffle=False)

                _, val_acc = train_pytorch_model(new_model, train_loader, val_loader,
                                                 CONFIG, f"{model_name}-{method}-k{k}")
                fold_accuracies.append(val_acc)
                fold_channels.append([int(c) for c in selected_channels])
                print(f"Acc: {val_acc:.4f}")

                # Free GPU memory between runs; this loop trains many models back to back.
                del new_model, train_loader, val_loader
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            mean_acc = np.mean(fold_accuracies)
            std_acc = np.std(fold_accuracies)

            print(f"k={k}: {mean_acc:.4f} +/- {std_acc:.4f}")
            channel_selection_results.append({
                'model': model_name,
                'method': method,
                'k': k,
                'mean_accuracy': mean_acc,
                'std_accuracy': std_acc,
                'fold_accuracies': fold_accuracies,
                'selected_channels': fold_channels,
            })

print(f"\n\n{'='*60}")
print("Channel selection evaluation complete!")
print(f"{'='*60}")


## Retention Analysis

Evaluate how Adaptive-Gating-EEG-ARNN accuracy changes as the number of channels retained by Gate Selection varies (k = 5 to 35):

In [None]:
# Retention analysis: Adaptive gating with different channel budgets
retention_k_values = [5, 10, 15, 20, 25, 30, 35]
retention_model_name = 'Adaptive-Gating-EEG-ARNN'
# A fixed random_state gives the same folds as the evaluation cell; materialise them once.
retention_folds = list(skf.split(X, y))

print(f"\n{'='*60}")
print("Retention Analysis: Adaptive-Gating-EEG-ARNN with Gate Selection")
print(f"{'='*60}\n")

# Gate importance depends only on the fold's pretrained model and training data, not on k,
# so compute it once per fold rather than once per (k, fold).
retention_importance = []
for fold, (train_idx, _) in enumerate(retention_folds):
    gate_model = build_model(retention_model_name, CONFIG['n_channels'])
    model_path = os.path.join(CONFIG['models_dir'], f"{retention_model_name}_fold{fold+1}.pt")
    state_dict = torch.load(model_path, map_location=CONFIG['device'])
    gate_model.load_state_dict(state_dict)
    gate_model = gate_model.to(CONFIG['device'])
    gate_model.eval()

    importance_loader = DataLoader(EEGDataset(X[train_idx], y[train_idx]),
                                   batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)
    retention_importance.append(compute_gate_importance(gate_model, importance_loader, CONFIG['device']))
    del gate_model, importance_loader

for k in retention_k_values:
    print(f"\nTesting with k={k} channels...")
    fold_accuracies = []
    fold_channels = []

    for fold, (train_idx, val_idx) in enumerate(retention_folds):
        selected_channels = select_top_k_channels(retention_importance[fold], k)
        X_train_selected = apply_channel_selection(X[train_idx], selected_channels)
        X_val_selected = apply_channel_selection(X[val_idx], selected_channels)

        new_model = build_model(retention_model_name, k)
        train_loader = DataLoader(EEGDataset(X_train_selected, y[train_idx]),
                                  batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0)
        val_loader = DataLoader(EEGDataset(X_val_selected, y[val_idx]),
                                batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)

        _, val_acc = train_pytorch_model(new_model, train_loader, val_loader,
                                         CONFIG, f"Retention-k{k}")
        fold_accuracies.append(val_acc)
        fold_channels.append([int(c) for c in selected_channels])
        print(f"Fold {fold + 1}/{CONFIG['n_folds']} Acc: {val_acc:.4f}")

        # Free GPU memory between runs.
        del new_model, train_loader, val_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)

    print(f"k={k}: {mean_acc:.4f} +/- {std_acc:.4f}")
    retention_results.append({
        'k': k,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies,
        'selected_channels': fold_channels,
    })


## Save Results

In [None]:
# Save channel selection results, then retention results, to CONFIG['results_dir']
results_dir = CONFIG['results_dir']
summary_columns = ['model', 'method', 'k', 'mean_accuracy', 'std_accuracy']

cs_df = pd.DataFrame(channel_selection_results)
cs_df.to_csv(os.path.join(results_dir, 'channel_selection_results.csv'), index=False)
print("\nChannel Selection Results:")
print(cs_df[summary_columns])

# Save retention analysis results
retention_df = pd.DataFrame(retention_results)
retention_df.to_csv(os.path.join(results_dir, 'retention_analysis.csv'), index=False)
print("\nRetention Analysis Results:")
print(retention_df[['k', 'mean_accuracy', 'std_accuracy']])

print(f"\nResults saved to {results_dir}/")

## Summary

Find the best channel selection method (and its k) for each model:

In [None]:
# Find best method for each model
print("\nBest Channel Selection Methods:")
print("="*60)

# idxmax raises on an empty selection, so models without (non-NaN) results are skipped.
has_results = not cs_df.empty and {'model', 'mean_accuracy'}.issubset(cs_df.columns)

for model_name in ['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']:
    model_results = None
    if has_results:
        model_results = cs_df[cs_df['model'] == model_name].dropna(subset=['mean_accuracy'])
    if model_results is None or model_results.empty:
        print(f"\n{model_name}: no channel selection results to summarise")
        continue

    best_result = model_results.loc[model_results['mean_accuracy'].idxmax()]

    print(f"\n{model_name}:")
    print(f"  Best Method: {best_result['method'].upper()}")
    print(f"  Best k: {best_result['k']}")
    print(f"  Accuracy: {best_result['mean_accuracy']:.4f} +/- {best_result['std_accuracy']:.4f}")