# PhysioNet EEG: Channel Selection Evaluation

This notebook evaluates three channel selection methods:
1. **Edge Selection (ES)**: Based on sum of outgoing edge weights from CARM adjacency matrix
2. **Aggregation Selection (AS)**: Based on feature aggregation after CARM layers
3. **Gate Selection (GS)**: Based on adaptive gate values (only for Adaptive-Gating-EEG-ARNN)

**Expected Runtime**: 8-10 hours on Kaggle GPU

**Input**: 
- `/kaggle/input/eeg-preprocessed-data/derived` (preprocessed EEG epochs; this is the path set in `CONFIG['data_path']`)
- `models/` folder (trained models from notebook 01)

**Output**: 
- `channel_selection_results.csv` - Results for all methods and k values
- `retention_analysis.csv` - Performance vs number of channels retained

## Configuration and Imports

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
import pickle
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

mne.set_log_level('ERROR')

In [None]:
# Configuration
# Single dictionary of paths and hyper-parameters read by the cells below.
CONFIG = {
    # Directory searched by load_physionet_data for S###_R##.fif epoch files.
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    # Checkpoints are read from here as "<model name>_fold<N>.pt".
    'models_dir': './models',
    # Destination directory for the CSV files written at the end of the notebook.
    'results_dir': './results',

    # Number of StratifiedKFold splits; also determines which fold checkpoints are loaded.
    'n_folds': 3,
    # Seed shared by NumPy, PyTorch and the fold splitter.
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Optimisation settings consumed by the DataLoaders and train_pytorch_model.
    'batch_size': 64,
    'epochs': 20,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    # Early stopping: epochs without validation-accuracy improvement before training stops.
    'patience': 10,

    # Dimensions passed to the model constructors.
    'n_channels': 64,
    'n_classes': 2,
    # NOTE(review): 'sfreq' is not read by any code visible in this notebook.
    'sfreq': 128,
    # Input size of the first Linear layer of the gate network.
    'n_timepoints': 513,

    # Channel selection k values
    'k_values': [5, 10, 15, 20, 25, 30, 35],
}

# Seed the NumPy and PyTorch (CPU and, when present, CUDA) random generators.
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])

print(f"Device: {CONFIG['device']}")

## Copy Model Architectures and Data Loading from Notebook 01

In [None]:
# Data loading utilities (same as notebook 01)
def load_physionet_data(data_path, subject_ids=None):
    """Load epoched recordings and build a two-class (T1 vs. T2) trial set.

    For every subject the runs 3, 4, 7, 8, 11 and 12 are read from
    ``S<subject:03d>_R<run:02d>.fif`` files in ``data_path``, concatenated,
    and reduced to the epochs whose event is named ``T1`` (label 0) or
    ``T2`` (label 1). Missing files are skipped; unreadable files are
    skipped with a printed notice.

    Args:
        data_path: Directory containing the ``.fif`` epoch files.
        subject_ids: Iterable of integer subject ids. When ``None`` the ids
            are discovered from the ``S<id>_*.fif`` filenames in ``data_path``.

    Returns:
        Tuple ``(X, y, subject_labels)``: the stacked epoch data as returned
        by ``epochs.get_data()``, the 0/1 label of every trial, and the
        subject id of every trial.

    Raises:
        ValueError: If no usable trial was found for any subject.
    """
    import re  # local import: only needed for filename discovery

    if subject_ids is None:
        # Ignore .fif files that do not follow the S<id>_... naming scheme
        # instead of crashing on int() while parsing them.
        name_pattern = re.compile(r'^S(\d+)_.*\.fif$')
        matches = (name_pattern.match(name) for name in os.listdir(data_path))
        subject_ids = sorted({int(m.group(1)) for m in matches if m})

    print(f"Loading data from {len(subject_ids)} subjects...")

    run_ids = (3, 4, 7, 8, 11, 12)
    label_by_event_name = {'T1': 0, 'T2': 1}

    all_X = []
    all_y = []
    all_subjects = []

    for subject_id in tqdm(subject_ids):
        subject_runs = []
        for run_id in run_ids:
            filepath = os.path.join(data_path, f"S{subject_id:03d}_R{run_id:02d}.fif")
            if not os.path.exists(filepath):
                continue
            try:
                subject_runs.append(mne.read_epochs(filepath, preload=True, verbose=False))
            except Exception as exc:
                # Loading stays best-effort, but a corrupt file should not
                # silently shrink the dataset.
                print(f"Skipping unreadable file {filepath}: {exc}")

        if not subject_runs:
            continue

        epochs = mne.concatenate_epochs(subject_runs)
        X = epochs.get_data()

        # Translate event codes through the mapping stored with the epochs
        # rather than assuming T1 == 1 and T2 == 2.
        code_to_label = {
            code: label_by_event_name[name]
            for name, code in epochs.event_id.items()
            if name in label_by_event_name
        }
        y = np.array([code_to_label.get(code, -1) for code in epochs.events[:, -1]])

        valid_mask = y != -1
        X = X[valid_mask]
        y = y[valid_mask]

        if len(X) == 0:
            continue

        all_X.append(X)
        all_y.append(y)
        all_subjects.append(np.full(len(y), subject_id))

    if not all_X:
        raise ValueError(f"No usable T1/T2 trials found in {data_path!r}")

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subject_labels = np.concatenate(all_subjects, axis=0)

    print(f"Loaded {len(X)} trials from {len(np.unique(subject_labels))} subjects")

    return X, y, subject_labels

class EEGDataset(Dataset):
    """Map-style dataset pairing each trial with its class label.

    ``X`` is stored as a float tensor and ``y`` as a long tensor, so items
    can be fed to a model and a classification loss without further casting.
    """

    def __init__(self, X, y):
        features = torch.FloatTensor(X)
        targets = torch.LongTensor(y)
        self.X, self.y = features, targets

    def __len__(self):
        # Number of trials equals the length of the leading axis of X.
        n_trials = len(self.X)
        return n_trials

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

In [None]:
# Model architectures (same as notebook 01 - only need the ones with channel selection)
class CARMBlock(nn.Module):
    """Learnable dense mixing of the node (axis 1) dimension.

    A square parameter ``A`` is turned into mixing weights with a softmax
    over each row; output node ``i`` is the weighted sum of all input nodes
    ``j`` with weight ``softmax(A)[i, j]``. The ``norm`` layer is created
    but not applied in ``forward``.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        initial_weights = 0.01 * torch.randn(n_channels, n_channels)
        self.A = nn.Parameter(initial_weights)
        self.norm = nn.LayerNorm(n_channels)

    def forward(self, x):
        # Unpacking doubles as a check that the input has exactly three axes.
        _n_batch, _n_nodes, _n_steps = x.shape
        weights = torch.softmax(self.A, dim=1)
        time_major = x.permute(0, 2, 1)
        mixed = time_major @ weights.t()
        return mixed.permute(0, 2, 1)

    def get_adjacency_matrix(self):
        """Return the row-normalised weights, detached from the graph."""
        adjacency = torch.softmax(self.A, dim=1)
        return adjacency.detach()

class TFEMBlock(nn.Module):
    """Temporal feature extractor: Conv1d -> BatchNorm -> ReLU -> pooling.

    The convolution maps ``n_channels`` input channels to ``hidden_dim``
    feature maps with a length-preserving kernel (size 5, padding 2); the
    adaptive average pool then resamples the last axis to 64 positions.
    """

    def __init__(self, n_channels, hidden_dim=128):
        super().__init__()
        self.temporal_conv = nn.Conv1d(
            in_channels=n_channels,
            out_channels=hidden_dim,
            kernel_size=5,
            padding=2,
        )
        self.temporal_bn = nn.BatchNorm1d(hidden_dim)
        self.freq_pool = nn.AdaptiveAvgPool1d(64)

    def forward(self, x):
        features = self.temporal_conv(x)
        features = self.temporal_bn(features)
        features = torch.relu(features)
        return self.freq_pool(features)

class BaselineEEGARNN(nn.Module):
    """TFEM -> CARM -> bidirectional LSTM -> two-layer classifier.

    Args:
        n_channels: Number of input channels given to the TFEM convolution.
        n_classes: Size of the output layer.
        n_timepoints: Accepted for signature parity with the gated model;
            not used by this class.
        hidden_dim: Number of TFEM feature maps, and the LSTM hidden size.

    NOTE(review): the CARM block is built on ``hidden_dim`` nodes, so its
    adjacency matrix and the edge importance scores have ``hidden_dim``
    entries (one per TFEM feature map), not ``n_channels`` entries.
    Confirm this matches how the scores are used to index input channels.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128):
        super().__init__()
        self.tfem = TFEMBlock(n_channels, hidden_dim)
        self.carm = CARMBlock(hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(hidden_dim * 2, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return class logits for a batch of trials."""
        x = self.tfem(x)
        x = self.carm(x)
        # The LSTM is batch_first, so move the feature axis last.
        x = x.permute(0, 2, 1)
        x, _ = self.lstm(x)
        # Classify from the LSTM output at the final sequence position.
        x = x[:, -1, :]
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def get_channel_importance_edge(self):
        """Return the summed outgoing edge weight of every CARM node.

        The adjacency matrix is softmax-normalised along ``dim=1``, so each
        row sums to 1 and a row sum carries no ranking information. In the
        CARM forward pass entry ``A[i, j]`` weights source node ``j`` in
        output node ``i``, so the outgoing weight of node ``j`` is the sum
        of column ``j``, i.e. a reduction over ``dim=0``.
        """
        A = self.carm.get_adjacency_matrix()
        return torch.sum(A, dim=0).cpu().numpy()

class AdaptiveGatingEEGARNN(nn.Module):
    """Baseline network preceded by a learned per-channel input gate.

    A small MLP shared across channels maps each channel's time series
    (``n_timepoints`` values) to one gate value in (0, 1); the input is
    multiplied by its gates before entering the TFEM -> CARM -> LSTM ->
    classifier stack.

    Args:
        n_channels: Number of input channels.
        n_classes: Size of the output layer.
        n_timepoints: Length of the time axis; input size of the gate MLP.
        hidden_dim: Number of TFEM feature maps, and the LSTM hidden size.
        gate_init: Constant written into the bias of both Linear layers of
            the gate MLP. NOTE(review): the final layer is followed by a
            sigmoid, so the initial gate value is not ``gate_init`` itself;
            confirm this is the intended initialisation.

    NOTE(review): the CARM block is built on ``hidden_dim`` nodes, so the
    edge importance scores have ``hidden_dim`` entries, not ``n_channels``.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=128, gate_init=0.9):
        super().__init__()
        self.n_channels = n_channels

        self.gate_net = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        for layer in self.gate_net:
            if isinstance(layer, nn.Linear):
                nn.init.constant_(layer.bias, gate_init)

        self.tfem = TFEMBlock(n_channels, hidden_dim)
        self.carm = CARMBlock(hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(hidden_dim * 2, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, n_classes)

        # Gates of the most recent forward pass, shape (batch, channels).
        self.gate_values = None

    def forward(self, x):
        """Return class logits and cache the gates used for this batch."""
        # Linear layers act on the last axis, so one call gates every
        # channel at once: (batch, channels, time) -> (batch, channels, 1).
        gates = self.gate_net(x).squeeze(-1)
        self.gate_values = gates.detach()

        x = x * gates.unsqueeze(2)
        x = self.tfem(x)
        x = self.carm(x)
        # The LSTM is batch_first, so move the feature axis last.
        x = x.permute(0, 2, 1)
        x, _ = self.lstm(x)
        # Classify from the LSTM output at the final sequence position.
        x = x[:, -1, :]
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def get_channel_importance_gate(self):
        """Return the batch-mean gate per channel of the last forward pass.

        Returns ``None`` when no forward pass has been run yet.
        """
        if self.gate_values is None:
            return None
        return torch.mean(self.gate_values, dim=0).cpu().numpy()

    def get_channel_importance_edge(self):
        """Return the summed outgoing edge weight of every CARM node.

        The adjacency matrix is softmax-normalised along ``dim=1``, so each
        row sums to 1 and a row sum carries no ranking information. In the
        CARM forward pass entry ``A[i, j]`` weights source node ``j`` in
        output node ``i``, so the outgoing weight of node ``j`` is the sum
        of column ``j``, i.e. a reduction over ``dim=0``.
        """
        A = self.carm.get_adjacency_matrix()
        return torch.sum(A, dim=0).cpu().numpy()

In [None]:
# Training utilities
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Puts the model in training mode, performs one optimiser step per batch,
    and returns ``(mean loss per batch, accuracy over all samples)``.
    """
    model.train()

    running_loss = 0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Index of the largest logit is the predicted class.
        predicted = torch.max(logits.data, 1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    mean_batch_loss = running_loss / len(dataloader)
    accuracy = n_correct / n_seen
    return mean_batch_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    """Score the model on ``dataloader`` without updating it.

    Puts the model in evaluation mode, disables gradient tracking, and
    returns ``(mean loss per batch, accuracy over all samples)``.
    """
    model.eval()

    running_loss = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            running_loss += criterion(logits, targets).item()

            # Index of the largest logit is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    mean_batch_loss = running_loss / len(dataloader)
    accuracy = n_correct / n_seen
    return mean_batch_loss, accuracy

def train_pytorch_model(model, train_loader, val_loader, config):
    """Train ``model`` with Adam and early stopping on validation accuracy.

    Args:
        model: Network to train; it is moved to ``config['device']``.
        train_loader: Batches used for optimisation.
        val_loader: Batches used to measure validation accuracy each epoch.
        config: Mapping providing ``device``, ``learning_rate``,
            ``weight_decay``, ``epochs`` and ``patience``.

    Returns:
        Tuple ``(model, best_val_acc)``. The returned model carries the
        weights of the epoch that reached ``best_val_acc``, so the two
        values describe the same network.
    """
    import copy  # local import: used only to snapshot the best weights

    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                          weight_decay=config['weight_decay'])

    best_val_acc = 0
    best_state = None
    patience_counter = 0

    for epoch in range(config['epochs']):
        train_epoch(model, train_loader, criterion, optimizer, device)
        _, val_acc = evaluate(model, val_loader, criterion, device)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live tensors, so a deep
            # copy is needed to keep this epoch's weights.
            best_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= config['patience']:
            break

    # Without this the caller would receive last-epoch weights alongside the
    # accuracy of an earlier epoch.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, best_val_acc

## Channel Selection Functions

In [None]:
def get_channel_importance_aggregation(model, dataloader, device):
    """
    Aggregation Selection (AS): Compute channel importance based on
    feature aggregation after CARM layers.

    Every batch is passed through ``model.tfem`` and ``model.carm``; the
    absolute activations are averaged over the last axis and then over all
    samples served by ``dataloader``.

    Args:
        model: Network exposing ``tfem`` and ``carm`` sub-modules. It is
            switched to evaluation mode.
        dataloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        1-D NumPy array with one score per entry of axis 1 of the CARM
        output. NOTE(review): that axis has as many entries as the CARM
        block has nodes, which may differ from the number of input channels.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    score_sum = None
    n_samples = 0

    with torch.no_grad():
        for X_batch, _ in dataloader:
            X_batch = X_batch.to(device)

            # Forward pass through TFEM and CARM
            x = model.carm(model.tfem(X_batch))

            # Aggregate over time dimension
            activations = torch.mean(torch.abs(x), dim=2)  # (batch, channels)

            # Keep a running per-channel sum instead of storing every
            # sample's activations until the end.
            batch_total = activations.sum(dim=0).cpu()
            score_sum = batch_total if score_sum is None else score_sum + batch_total
            n_samples += activations.size(0)

    if n_samples == 0:
        raise ValueError("Cannot compute channel importance from an empty dataloader")

    # Sample-weighted average across all batches
    return (score_sum / n_samples).numpy()

def select_top_k_channels(importance_scores, k):
    """Select top k channels based on importance scores.

    Args:
        importance_scores: 1-D sequence or array with one score per channel.
        k: Number of channels to keep, ``0 <= k <= len(importance_scores)``.

    Returns:
        Ascending list of the ``k`` indices with the highest scores. Ties
        are broken deterministically in favour of the higher index.

    Raises:
        ValueError: If the scores are missing or not one-dimensional, or if
            ``k`` is outside the valid range.
    """
    if importance_scores is None:
        raise ValueError("importance_scores is None; no scores were computed")

    scores = np.asarray(importance_scores)
    if scores.ndim != 1:
        raise ValueError(f"importance_scores must be 1-D, got shape {scores.shape}")
    if not 0 <= k <= scores.size:
        raise ValueError(f"k must be between 0 and {scores.size}, got {k}")

    # A "[-k:]" slice with k == 0 would select every channel, so handle it here.
    if k == 0:
        return []

    # A stable sort makes the ordering of tied scores reproducible.
    ranked = np.argsort(scores, kind='stable')
    return sorted(ranked[-k:].tolist())

def apply_channel_selection(X, selected_channels):
    """Return ``X`` restricted to ``selected_channels`` along axis 1."""
    # Full slices on the first and last axes, the selection on the middle one.
    index = (slice(None), selected_channels, slice(None))
    return X[index]

## Load Data

In [None]:
# Load all trials once; subject_ids is left at its default, so the loader
# discovers the subjects from the .fif filenames in CONFIG['data_path'].
print("Loading PhysioNet data...")
X, y, subject_labels = load_physionet_data(CONFIG['data_path'])
print(f"Data loaded: {X.shape}")

## Channel Selection Experiments

In [None]:
# Models to evaluate
# 'name' is also the checkpoint filename prefix; 'methods' lists the importance
# methods run for that model ('gate' only for the model with a gate network).
models_to_evaluate = [
    {'name': 'Baseline-EEG-ARNN', 'methods': ['edge', 'aggregation']},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'methods': ['edge', 'aggregation', 'gate']},
]

# Results storage
# The evaluation loops below append one summary dict per configuration.
channel_selection_results = []
retention_results = []

# Folds are stratified on the class label at trial level; subject_labels is not
# used for grouping.
# NOTE(review): checkpoints are loaded by fold index, so these splits presumably
# need to match the ones used when the checkpoints were trained — confirm.
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])

In [None]:
# Main evaluation loop
def _build_model(model_name, n_channels):
    """Create an untrained network of the architecture named ``model_name``."""
    if model_name == 'Baseline-EEG-ARNN':
        model_cls = BaselineEEGARNN
    else:
        model_cls = AdaptiveGatingEEGARNN
    return model_cls(n_channels=n_channels,
                     n_classes=CONFIG['n_classes'],
                     n_timepoints=CONFIG['n_timepoints'])


def _load_pretrained_model(model_name, fold):
    """Load the full-channel checkpoint of ``model_name`` for a 0-based fold."""
    model = _build_model(model_name, CONFIG['n_channels'])
    model_path = os.path.join(CONFIG['models_dir'], f"{model_name}_fold{fold+1}.pt")
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=CONFIG['device']))
    model = model.to(CONFIG['device'])
    model.eval()
    return model


def _mean_gate_importance(model, dataloader, device):
    """Average the gate of every channel over all samples in ``dataloader``."""
    gate_sum = None
    n_samples = 0
    with torch.no_grad():
        for X_batch, _ in dataloader:
            # The forward pass stores this batch's gates in model.gate_values.
            model(X_batch.to(device))
            batch_gates = model.gate_values  # (batch, channels)
            batch_total = batch_gates.sum(dim=0).cpu()
            gate_sum = batch_total if gate_sum is None else gate_sum + batch_total
            n_samples += batch_gates.size(0)
    if n_samples == 0:
        raise ValueError("Cannot compute gate importance from an empty dataloader")
    return (gate_sum / n_samples).numpy()


def _channel_importance(model, method, X_train, y_train):
    """Score channels with ``method`` using only the training split."""
    if method == 'edge':
        return model.get_channel_importance_edge()

    loader = DataLoader(EEGDataset(X_train, y_train), batch_size=CONFIG['batch_size'],
                        shuffle=False, num_workers=0)
    if method == 'aggregation':
        return get_channel_importance_aggregation(model, loader, CONFIG['device'])
    if method == 'gate':
        return _mean_gate_importance(model, loader, CONFIG['device'])
    raise ValueError(f"Unknown channel selection method: {method!r}")


# The splitter is seeded, so materialising the folds once gives the same
# splits as calling skf.split(X, y) inside every loop.
folds = list(skf.split(X, y))

for model_info in models_to_evaluate:
    model_name = model_info['name']
    selection_methods = model_info['methods']

    print(f"\n{'='*60}")
    print(f"Evaluating {model_name}")
    print(f"{'='*60}\n")

    for method in selection_methods:
        print(f"\n{'-'*60}")
        print(f"Channel Selection Method: {method.upper()}")
        print(f"{'-'*60}\n")

        # Importance depends on (model, method, fold) but not on k, so it is
        # computed once per fold instead of once per (k, fold).
        fold_importance = []
        for fold, (train_idx, _) in enumerate(folds):
            pretrained = _load_pretrained_model(model_name, fold)
            importance_scores = _channel_importance(pretrained, method,
                                                    X[train_idx], y[train_idx])

            # The scores are used as indices into axis 1 of X, so there must
            # be exactly one score per data channel.
            if importance_scores is None or len(importance_scores) != X.shape[1]:
                n_scores = None if importance_scores is None else len(importance_scores)
                raise ValueError(
                    f"{model_name}/{method}: got {n_scores} importance scores for "
                    f"{X.shape[1]} data channels; the scores cannot be used to "
                    f"select input channels"
                )
            fold_importance.append(importance_scores)
            del pretrained

        for k in CONFIG['k_values']:
            print(f"\nEvaluating with k={k} channels...")

            fold_accuracies = []

            for fold, (train_idx, val_idx) in enumerate(folds):
                print(f"Fold {fold + 1}/{CONFIG['n_folds']}", end=' ')

                X_train, X_val = X[train_idx], X[val_idx]
                y_train, y_val = y[train_idx], y[val_idx]

                # Select top k channels
                selected_channels = select_top_k_channels(fold_importance[fold], k)

                # Apply channel selection
                X_train_selected = apply_channel_selection(X_train, selected_channels)
                X_val_selected = apply_channel_selection(X_val, selected_channels)

                # Retrain a fresh model on the selected channels only
                new_model = _build_model(model_name, k)

                train_dataset = EEGDataset(X_train_selected, y_train)
                val_dataset = EEGDataset(X_val_selected, y_val)

                train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                        shuffle=True, num_workers=0)
                val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                                       shuffle=False, num_workers=0)

                new_model, val_acc = train_pytorch_model(new_model, train_loader, val_loader, CONFIG)

                fold_accuracies.append(val_acc)
                print(f"Acc: {val_acc:.4f}")

            # Compute statistics
            mean_acc = np.mean(fold_accuracies)
            std_acc = np.std(fold_accuracies)

            print(f"k={k}: {mean_acc:.4f} +/- {std_acc:.4f}")

            # Store results
            channel_selection_results.append({
                'model': model_name,
                'method': method,
                'k': k,
                'mean_accuracy': mean_acc,
                'std_accuracy': std_acc,
                'fold_accuracies': fold_accuracies
            })

print(f"\n\n{'='*60}")
print("Channel selection evaluation complete!")
print(f"{'='*60}")

## Retention Analysis

Evaluate accuracy as a function of the number of retained channels, using Adaptive-Gating-EEG-ARNN with gate selection:

In [None]:
# Retention analysis: Test with different numbers of channels
# NOTE(review): this uses the same model, selection method and k grid as the
# 'gate' pass of the channel selection loop; consider reusing those results
# instead of retraining.
retention_k_values = [5, 10, 15, 20, 25, 30, 35]

print(f"\n{'='*60}")
print("Retention Analysis: Adaptive-Gating-EEG-ARNN with Gate Selection")
print(f"{'='*60}\n")

# The splitter is seeded, so materialising the folds once gives the same
# splits as calling skf.split(X, y) for every k.
retention_folds = list(skf.split(X, y))

# Gate importance depends on the fold but not on k, so compute it once per fold.
retention_importance = []
for fold, (train_idx, _) in enumerate(retention_folds):
    # Load pre-trained model
    model = AdaptiveGatingEEGARNN(n_channels=CONFIG['n_channels'],
                                 n_classes=CONFIG['n_classes'],
                                 n_timepoints=CONFIG['n_timepoints'])
    model_path = os.path.join(CONFIG['models_dir'], f"Adaptive-Gating-EEG-ARNN_fold{fold+1}.pt")
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=CONFIG['device']))
    model = model.to(CONFIG['device'])
    model.eval()

    # Average the gates over the whole training split, not only the first batch.
    importance_loader = DataLoader(EEGDataset(X[train_idx], y[train_idx]),
                                   batch_size=CONFIG['batch_size'],
                                   shuffle=False, num_workers=0)
    gate_sum = None
    n_seen = 0
    with torch.no_grad():
        for X_batch, _ in importance_loader:
            # The forward pass stores this batch's gates in model.gate_values.
            model(X_batch.to(CONFIG['device']))
            batch_gates = model.gate_values  # (batch, channels)
            batch_total = batch_gates.sum(dim=0).cpu()
            gate_sum = batch_total if gate_sum is None else gate_sum + batch_total
            n_seen += batch_gates.size(0)
    if n_seen == 0:
        raise ValueError(f"Fold {fold + 1} has no training samples for gate importance")
    retention_importance.append((gate_sum / n_seen).numpy())

for k in retention_k_values:
    print(f"\nTesting with k={k} channels...")

    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(retention_folds):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Select top k channels
        selected_channels = select_top_k_channels(retention_importance[fold], k)

        # Apply selection and retrain
        X_train_selected = apply_channel_selection(X_train, selected_channels)
        X_val_selected = apply_channel_selection(X_val, selected_channels)

        new_model = AdaptiveGatingEEGARNN(n_channels=k,
                                         n_classes=CONFIG['n_classes'],
                                         n_timepoints=CONFIG['n_timepoints'])

        train_dataset = EEGDataset(X_train_selected, y_train)
        val_dataset = EEGDataset(X_val_selected, y_val)

        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                                shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                               shuffle=False, num_workers=0)

        new_model, val_acc = train_pytorch_model(new_model, train_loader, val_loader, CONFIG)
        fold_accuracies.append(val_acc)

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)

    print(f"k={k}: {mean_acc:.4f} +/- {std_acc:.4f}")

    retention_results.append({
        'k': k,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

## Save Results

In [None]:
# to_csv does not create missing directories, so make sure the target exists
# before the hours-long results are written.
os.makedirs(CONFIG['results_dir'], exist_ok=True)

# Save channel selection results
cs_df = pd.DataFrame(channel_selection_results)
cs_df.to_csv(os.path.join(CONFIG['results_dir'], 'channel_selection_results.csv'), index=False)

print("\nChannel Selection Results:")
print(cs_df[['model', 'method', 'k', 'mean_accuracy', 'std_accuracy']])

# Save retention analysis results
retention_df = pd.DataFrame(retention_results)
retention_df.to_csv(os.path.join(CONFIG['results_dir'], 'retention_analysis.csv'), index=False)

print("\nRetention Analysis Results:")
print(retention_df[['k', 'mean_accuracy', 'std_accuracy']])

print(f"\nResults saved to {CONFIG['results_dir']}/")

## Summary

Find the best channel selection method:

In [None]:
# Find best method for each model
print("\nBest Channel Selection Methods:")
print("="*60)

for model_name in ['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']:
    # A partial run can leave a model without rows (or the frame without
    # columns); idxmax on an empty selection would raise.
    if cs_df.empty:
        model_results = cs_df
    else:
        model_results = cs_df[cs_df['model'] == model_name]

    print(f"\n{model_name}:")
    if model_results.empty:
        print("  No results available")
        continue

    # Row with the highest mean accuracy over all methods and k values.
    best_result = model_results.loc[model_results['mean_accuracy'].idxmax()]

    print(f"  Best Method: {best_result['method'].upper()}")
    print(f"  Best k: {best_result['k']}")
    print(f"  Accuracy: {best_result['mean_accuracy']:.4f} +/- {best_result['std_accuracy']:.4f}")