# PhysioNet Motor Imagery - EEG-ARNN Models

## Baseline EEG-ARNN vs Adaptive Gating EEG-ARNN

This notebook trains and evaluates:
1. **Baseline EEG-ARNN** - Pure CNN-GCN architecture
2. **Adaptive Gating EEG-ARNN** - Input-dependent channel gating

## Configuration:
- **35 epochs**, **0.002 LR**, **NO EARLY STOPPING** (patience is set to 999, which exceeds the epoch count)
- **10 subjects**, **5-fold CV**
- **Channel Selection**: ES/AS/GS at k=[10,15,20,25,30]

## Metrics:
- Accuracy, Precision, Recall, F1-Score, AUC-ROC, Specificity

## Output:
- `eeg_arnn_baseline_results.csv`
- `eeg_arnn_adaptive_results.csv`
- `eeg_arnn_baseline_retrain_results.csv`
- `eeg_arnn_adaptive_retrain_results.csv`

## 1. Setup and Imports

In [None]:
import json
import random
import warnings
from pathlib import Path
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
import gc

import mne

# Suppress every Python warning for the rest of the notebook.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# the libraries used below — consider narrowing the filter.
warnings.filterwarnings('ignore')
# Seaborn plotting context used for any figures.
sns.set_context('notebook', font_scale=1.0)
# Only emit MNE log messages at WARNING level.
mne.set_log_level('WARNING')

def set_seed(s=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU plus all CUDA devices)
    with ``s``, then puts cuDNN into deterministic mode with benchmarking
    turned off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(s)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

## 2. Configuration

In [None]:
import os
from pathlib import Path

def locate_kaggle_dataset(kaggle_input, possible_names):
    """Return the dataset directory to use below ``kaggle_input``.

    Args:
        kaggle_input: ``Path`` of the directory that holds the attached datasets.
        possible_names: Directory names to look for, in order of preference.

    Returns:
        The first existing ``kaggle_input / name`` for ``name`` in
        ``possible_names``; otherwise the alphabetically first sub-directory.

    Raises:
        FileNotFoundError: If ``kaggle_input`` contains no sub-directory at all.
            Failing here gives a clear message instead of letting ``None``
            propagate into the configuration.
    """
    # Sorted so the fallback choice does not depend on directory listing order.
    datasets = sorted(d for d in kaggle_input.iterdir() if d.is_dir())
    print(f"Available datasets: {[d.name for d in datasets]}")

    for ds_name in possible_names:
        candidate = kaggle_input / ds_name
        if candidate.exists():
            print(f"Found dataset: {candidate}")
            return candidate

    if datasets:
        print(f"Using first available dataset: {datasets[0]}")
        return datasets[0]

    raise FileNotFoundError(
        f"No dataset directory found under {kaggle_input}; "
        f"expected one of {list(possible_names)}"
    )


if os.path.exists('/kaggle/input'):
    print("Running on Kaggle")
    DATA_DIR = locate_kaggle_dataset(
        Path('/kaggle/input'),
        ['physioneteegmi', 'eeg-motor-movementimagery-dataset'],
    )
else:
    print("Running locally")
    DATA_DIR = Path('data/physionet/files')

# Central experiment configuration read by the loading, training and
# channel-selection code below.
CONFIG = {
    'data': {
        # Root directory containing one sub-directory per subject.
        'raw_data_dir': DATA_DIR,
        # Event codes kept by filter_classes(); they are remapped to 0..K-1
        # in sorted order.
        'selected_classes': [1, 2],
        # Epoch window and baseline interval handed to mne.Epochs.
        'tmin': -1.0,
        'tmax': 5.0,
        'baseline': (-0.5, 0)
    },
    'preprocessing': {
        # Band-pass edges passed to raw.filter().
        'l_freq': 0.5,
        'h_freq': 40.0,
        # Notch frequency; only applied when it is below the Nyquist frequency.
        'notch_freq': 50.0,
        # Sampling rate the recordings are resampled to.
        'target_sfreq': 128.0,
        # When True, an average reference is applied.
        'apply_car': True
    },
    'model': {
        'hidden_dim': 16,
        'epochs': 35,
        'learning_rate': 0.002,
        'batch_size': 32,
        'n_folds': 5,
        # Larger than 'epochs', so the early-stopping check in train_model()
        # never triggers.
        'patience': 999,
        # Step size of the custom adjacency-matrix update.
        'adj_lr': 0.001
    },
    'gating': {
        # Weight of the L1 penalty on gate values (adaptive model only).
        'l1_lambda': 1e-3,
        'gate_init': 0.9
    },
    'channel_selection': {
        # Numbers of channels to keep when retraining on selected channels.
        'k_values': [10, 15, 20, 25, 30]
    },
    'output': {
        'results_dir': Path('results'),
    },
    # Only the first 'max_subjects' usable subjects are processed.
    'max_subjects': 10,
    # Minimum number of .edf files a subject directory must contain.
    'min_runs_per_subject': 8
}

# Make sure the output directory exists before anything is written to it.
CONFIG['output']['results_dir'].mkdir(exist_ok=True, parents=True)

# Banner summarising the run. Every setting is read from CONFIG so the printed
# values cannot drift from the values actually used (the previous hard-coded
# text reported a batch size of 20 while CONFIG specified 32).
print("="*80)
print("PhysioNet Motor Imagery - EEG-ARNN (OPTIMIZED SETTINGS)")
print("="*80)
print("CRITICAL FIXES APPLIED:")
print("  1. MIN-MAX NORMALIZATION (was z-score)")
print("  2. CUSTOM ADJACENCY LEARNING (paper's method)")
print(f"  3. BATCH SIZE {CONFIG['model']['batch_size']} (was 64)")
print(f"  4. HIDDEN DIM {CONFIG['model']['hidden_dim']} (was 40)")
print("="*80)
print("USER PREFERENCES:")
print(f"  - {CONFIG['model']['epochs']} epochs (faster training)")
print(f"  - {CONFIG['model']['n_folds']}-fold CV (balanced)")
print(f"  - LR {CONFIG['model']['learning_rate']} (slightly higher)")
print("="*80)
print(f"Training: {CONFIG['max_subjects']} subjects, {CONFIG['model']['n_folds']}-fold CV, {CONFIG['model']['epochs']} epochs")
print(f"Learning rate: {CONFIG['model']['learning_rate']}, Adjacency LR: {CONFIG['model']['adj_lr']}")
print(f"Channel selection k values: {CONFIG['channel_selection']['k_values']}")
print("="*80)

## 3. Data Cleaning - Remove Faulty Subjects

In [None]:
# Subject IDs that are never loaded. The two lists are kept separate so the
# reason for exclusion stays visible; they are merged into one set below.
KNOWN_BAD_SUBJECTS = [
    f'S{number:03d}'
    for number in (88, 89, 92, 100, 104, 106, 107, 108, 109)
]

HIGH_ISSUE_SUBJECTS = [
    f'S{number:03d}'
    for number in (
        3, 4, 9, 10, 12, 13, 17, 18, 19,
        21, 22, 23, 24, 25, 26, 27, 28, 29,
    )
]

EXCLUDED_SUBJECTS = {*KNOWN_BAD_SUBJECTS, *HIGH_ISSUE_SUBJECTS}

print(f"Total excluded subjects: {len(EXCLUDED_SUBJECTS)}")
print(f"Excluded subjects: {sorted(EXCLUDED_SUBJECTS)}")

## 4. Data Loading and Preprocessing Functions

In [None]:
def preprocess_raw(raw, config):
    """Apply the notebook's preprocessing chain to ``raw`` and return it.

    The steps run in this order on the object that was passed in: strip
    trailing dots from channel names, keep EEG channels, attach the
    ``standard_1020`` montage, notch filter (only when the notch frequency is
    below Nyquist), FIR band-pass, optional average reference, and resample.
    All settings come from ``config['preprocessing']``.
    """
    prep = config['preprocessing']

    raw.rename_channels({ch: ch.rstrip('.') for ch in raw.ch_names})
    raw.pick_types(eeg=True)
    raw.set_montage('standard_1020', on_missing='ignore', match_case=False)

    # A notch at or above Nyquist cannot be realised, so it is skipped.
    notch = prep['notch_freq']
    if notch < raw.info['sfreq'] / 2.0:
        raw.notch_filter(freqs=notch, verbose=False)

    bandpass = dict(
        l_freq=prep['l_freq'],
        h_freq=prep['h_freq'],
        method='fir',
        fir_design='firwin',
        verbose=False,
    )
    raw.filter(**bandpass)

    if prep['apply_car']:
        raw.set_eeg_reference('average', projection=False, verbose=False)

    raw.resample(prep['target_sfreq'], npad='auto', verbose=False)
    return raw


def load_and_preprocess_edf(edf_path, config):
    """Read one EDF file, preprocess it and cut it into epochs.

    Returns:
        ``(data, labels, ch_names)`` where ``data`` is ``epochs.get_data()``
        and ``labels`` is the event-code column of ``epochs.events``. When the
        recording has no events, ``(None, None, ch_names)`` is returned.

    NOTE(review): the event codes come from ``mne.events_from_annotations``,
    so which annotation maps to which integer is decided by MNE per file —
    confirm that ``selected_classes`` refers to the intended annotations.
    """
    raw = preprocess_raw(
        mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR'),
        config,
    )

    events, event_ids = mne.events_from_annotations(raw, verbose='ERROR')
    if len(events) == 0:
        return None, None, raw.ch_names

    data_cfg = config['data']
    epochs = mne.Epochs(
        raw,
        events,
        event_id=event_ids,
        tmin=data_cfg['tmin'],
        tmax=data_cfg['tmax'],
        baseline=tuple(data_cfg['baseline']),
        preload=True,
        verbose='ERROR',
    )

    data = epochs.get_data()
    labels = epochs.events[:, 2]
    return data, labels, raw.ch_names


def filter_classes(x, y, selected_classes):
    """Keep only trials whose label is in ``selected_classes`` and relabel them.

    Args:
        x: Array of trials, indexed along its first axis.
        y: Array of integer labels, one per trial.
        selected_classes: Iterable of label values to keep. Duplicates are
            ignored so the output labels are always contiguous.

    Returns:
        ``(x_kept, y_new)`` where ``y_new`` is ``int64`` and maps the kept
        classes, in ascending order, to ``0..K-1``.
    """
    # Deduplicate first: a repeated class would otherwise leave gaps in the
    # new label range (e.g. [2, 2, 3] -> labels {1, 2} instead of {0, 1}).
    classes = np.array(sorted(set(selected_classes)))
    y = np.asarray(y)
    mask = np.isin(y, classes)
    # Every kept label is present in the sorted ``classes`` array, so its
    # insertion point is exactly its new index.
    labels = np.searchsorted(classes, y[mask]).astype(np.int64)
    return x[mask], labels


def normalize(x):
    """Min-max normalization matching the paper implementation - CRITICAL FIX

    Each trial ``x[i]`` is scaled independently to ``[0, 1)`` using its own
    minimum and maximum over all remaining axes. A small epsilon in the
    denominator keeps constant trials finite (they map to all zeros).

    Args:
        x: Array whose first axis indexes trials.

    Returns:
        A new array of the same shape. Floating-point inputs keep their dtype;
        any other dtype is converted to ``float64`` first so the scaled values
        are not truncated back to integers.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)
    if x.size == 0:
        # Nothing to scale; min/max of an empty trial would raise.
        return x.copy()

    # Reduce over every axis except the trial axis, keeping dims to broadcast.
    trial_axes = tuple(range(1, x.ndim))
    x_min = x.min(axis=trial_axes, keepdims=True)
    x_max = x.max(axis=trial_axes, keepdims=True)
    return (x - x_min) / (x_max - x_min + 1e-8)


def load_subject_data(data_dir, subject_id, run_ids, config):
    """Load and concatenate all requested runs of one subject.

    Runs whose EDF file is missing, that yield no epochs, or that contain none
    of the selected classes are skipped. A run that raises while loading is
    reported with a warning and skipped as well.

    Returns:
        ``(x, y, channel_names)``. ``x`` and ``y`` are ``None`` when no run
        produced data; all three are ``None`` when the subject directory does
        not exist.
    """
    subject_dir = data_dir / subject_id
    if not subject_dir.exists():
        return None, None, None

    trials, labels = [], []
    channel_names = None

    for run_id in run_ids:
        edf_path = subject_dir / f'{subject_id}{run_id}.edf'
        if not edf_path.exists():
            continue

        try:
            x, y, ch_names = load_and_preprocess_edf(edf_path, config)
            if x is not None and len(y) > 0:
                x, y = filter_classes(x, y, config['data']['selected_classes'])
                if len(y) > 0:
                    # Channel names are taken from the first run that contributes data.
                    if not channel_names:
                        channel_names = ch_names
                    trials.append(x)
                    labels.append(y)
        except Exception as e:
            # Best effort: one unreadable run must not discard the whole subject.
            print(f"  Warning: Failed to load {edf_path.name}: {e}")

    if not trials:
        return None, None, channel_names

    return np.concatenate(trials, 0), np.concatenate(labels, 0), channel_names


def get_available_subjects(data_dir, min_runs=8, excluded=None):
    """List subject directory names under ``data_dir`` that are usable.

    A subject qualifies when its entry is a directory whose name starts with
    ``'S'``, is not in ``excluded``, and holds at least ``min_runs`` ``*.edf``
    files. Names are returned in sorted directory order.

    Raises:
        ValueError: If ``data_dir`` does not exist.
    """
    if not data_dir.exists():
        raise ValueError(f"Data directory not found: {data_dir}")

    skip = excluded if excluded else set()

    def is_usable(path):
        name = path.name
        if not (path.is_dir() and name.startswith('S')):
            return False
        if name in skip:
            return False
        return len(list(path.glob('*.edf'))) >= min_runs

    return [entry.name for entry in sorted(data_dir.iterdir()) if is_usable(entry)]


# Discover which subjects can be used for this run.
print("\nScanning for subjects...")
data_dir = CONFIG['data']['raw_data_dir']
print(f"Looking for data in: {data_dir}")

# All non-excluded subjects that have enough .edf files on disk.
all_subjects = get_available_subjects(
    data_dir, 
    min_runs=CONFIG['min_runs_per_subject'],
    excluded=EXCLUDED_SUBJECTS
)
# Only the first 'max_subjects' (in sorted order) are actually processed.
subjects = all_subjects[:CONFIG['max_subjects']]

print(f"Found {len(all_subjects)} clean subjects with >= {CONFIG['min_runs_per_subject']} runs")
print(f"Will process {len(subjects)} subjects: {subjects}")

# Run identifiers are appended to the subject ID to form the EDF file name
# (see load_subject_data). Both groups are loaded together below.
# NOTE(review): the run-to-task grouping is taken from the variable names —
# confirm against the dataset documentation.
MOTOR_IMAGERY_RUNS = ['R07', 'R08', 'R09', 'R10', 'R11', 'R12', 'R13', 'R14']
MOTOR_EXECUTION_RUNS = ['R03', 'R04', 'R05', 'R06']
ALL_TASK_RUNS = MOTOR_IMAGERY_RUNS + MOTOR_EXECUTION_RUNS
print(f"Using runs: {ALL_TASK_RUNS}")
print("NORMALIZATION FIXED: Using min-max normalization from paper!")

## 5. PyTorch Dataset

In [None]:
class EEGDataset(Dataset):
    """In-memory dataset pairing EEG trials with integer class labels.

    Trials are stored as float32 with a singleton axis inserted at position 1;
    labels are stored as int64.
    """

    def __init__(self, x, y):
        signals = torch.FloatTensor(x)
        self.x = signals[:, None]
        self.y = torch.LongTensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        sample, label = self.x[i], self.y[i]
        return sample, label

## 6. Comprehensive Metrics Functions

In [None]:
@torch.no_grad()
def calculate_comprehensive_metrics(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return binary-classification metrics.

    Returns a dict with ``accuracy``, ``precision``, ``recall``, ``f1_score``,
    ``auc_roc``, ``specificity`` and ``sensitivity``. AUC falls back to 0.0
    when only one class is present in the labels. When the confusion matrix is
    not 2x2, specificity is 0.0 and sensitivity mirrors recall.
    """
    model.eval()
    preds, labels, positive_probs = [], [], []

    for X_batch, y_batch in dataloader:
        logits = model(X_batch.to(device))
        class_probs = torch.softmax(logits, dim=1)

        preds.extend(logits.argmax(dim=1).cpu().numpy())
        labels.extend(y_batch.numpy())
        # Column 1 is used as the positive-class score for AUC.
        positive_probs.extend(class_probs[:, 1].cpu().numpy())

    all_preds = np.array(preds)
    all_labels = np.array(labels)
    all_probs = np.array(positive_probs)

    binary_kwargs = {'average': 'binary', 'zero_division': 0}
    has_both_classes = len(np.unique(all_labels)) > 1

    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, **binary_kwargs),
        'recall': recall_score(all_labels, all_preds, **binary_kwargs),
        'f1_score': f1_score(all_labels, all_preds, **binary_kwargs),
        'auc_roc': roc_auc_score(all_labels, all_probs) if has_both_classes else 0.0,
    }

    cm = confusion_matrix(all_labels, all_preds)
    if cm.shape != (2, 2):
        metrics['specificity'] = 0.0
        metrics['sensitivity'] = metrics['recall']
        return metrics

    tn, fp, fn, tp = cm.ravel()
    negatives = tn + fp
    positives = tp + fn
    metrics['specificity'] = tn / negatives if negatives > 0 else 0.0
    metrics['sensitivity'] = tp / positives if positives > 0 else 0.0
    return metrics


print("Comprehensive metrics functions defined!")

## 7. Model Architectures

In [None]:
class GraphConvLayer(nn.Module):
    """Graph convolution over EEG channels with a learnable adjacency matrix.

    The raw parameter ``A`` is squashed with a sigmoid and symmetrised. In
    ``forward`` self-loops are added and the result is degree-normalised
    before it mixes information across the channel axis of a
    ``(B, H, C, T)`` input. A bias-free linear map over the hidden axis,
    batch norm and ELU follow.
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.num_channels, self.hidden_dim = num_channels, hidden_dim

        self.A = nn.Parameter(torch.randn(num_channels, num_channels))
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.act = nn.ELU()

    def _symmetric_adjacency(self):
        """Return the sigmoid-squashed, symmetrised adjacency (no self-loops)."""
        squashed = torch.sigmoid(self.A)
        return 0.5 * (squashed + squashed.t())

    def forward(self, x):
        B, H, C, T = x.shape

        adj_hat = self._symmetric_adjacency()
        adj_hat = adj_hat + torch.eye(C, device=adj_hat.device)
        degree = adj_hat.sum(1).clamp_min(1e-6)
        deg_inv_sqrt = torch.diag(torch.pow(degree, -0.5))
        adj_norm = deg_inv_sqrt @ adj_hat @ deg_inv_sqrt

        # Move channels/hidden to the last two axes so one matmul handles
        # every (batch, time) position.
        nodes = x.permute(0, 3, 2, 1).contiguous().view(B * T, C, H)
        mixed = self.theta(adj_norm @ nodes)
        mixed = mixed.view(B, T, C, H).permute(0, 3, 2, 1)

        return self.act(self.bn(mixed))

    def get_adjacency(self):
        """Return the symmetrised adjacency as a NumPy array (no gradient)."""
        with torch.no_grad():
            return self._symmetric_adjacency().cpu().numpy()


class TemporalConv(nn.Module):
    """Convolution along the last (time) axis followed by batch norm and ELU.

    The kernel is ``(1, kernel_size)``, so channels are never mixed. When
    ``pool`` is true the output is additionally average-pooled by a factor of
    two along time.
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        self.pool = pool
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            padding=(0, kernel_size // 2),
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        if pool:
            self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pool_layer = None

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)
        if self.pool:
            out = self.pool_layer(out)
        return out


class BaselineEEGARNN(nn.Module):
    """Three temporal-conv / graph-conv stages followed by a two-layer classifier.

    Args:
        C: Number of EEG channels.
        T: Number of time samples per trial.
        K: Number of output classes.
        H: Hidden feature dimension shared by all stages.
    """

    def __init__(self, C, T, K, H):
        super().__init__()
        self.t1 = TemporalConv(1, H, 16, False)
        self.g1 = GraphConvLayer(C, H)
        self.t2 = TemporalConv(H, H, 16, True)
        self.g2 = GraphConvLayer(C, H)
        self.t3 = TemporalConv(H, H, 16, True)
        self.g3 = GraphConvLayer(C, H)

        # Push one all-zero trial through the feature stack to discover the
        # flattened feature size for the first fully connected layer.
        with torch.no_grad():
            probe = self._forward_features(torch.zeros(1, 1, C, T))
            fs = probe.view(1, -1).size(1)

        self.fc1 = nn.Linear(fs, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, K)

    def _forward_features(self, x):
        """Run the three (temporal conv -> graph conv) stages in order."""
        stages = ((self.t1, self.g1), (self.t2, self.g2), (self.t3, self.g3))
        for temporal, graph in stages:
            x = graph(temporal(x))
        return x

    def forward(self, x):
        features = self._forward_features(x)
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

    def get_final_adjacency(self):
        """Return the learned adjacency of the last graph layer."""
        return self.g3.get_adjacency()


class AdaptiveGatedEEGARNN(BaselineEEGARNN):
    """Baseline network preceded by an input-dependent per-channel gate.

    For every trial the gate network receives the per-channel mean and
    standard deviation over time and outputs one value in ``(0, 1)`` per
    channel. The input is multiplied by these gates before it enters the
    baseline feature stack.

    Args:
        C, T, K, H: Same meaning as in ``BaselineEEGARNN``.
        gate_init: Desired initial gate value in ``(0, 1)``. The bias of the
            last gate layer is set to ``logit(gate_init)`` so that gates start
            near this value. Previously this argument was ignored and the
            bias was always 2.0.
    """

    def __init__(self, C, T, K, H, gate_init=0.9):
        super().__init__(C, T, K, H)

        self.gate_net = nn.Sequential(
            nn.Linear(C * 2, C),
            nn.ReLU(),
            nn.Linear(C, C),
            nn.Sigmoid()
        )

        # Clamp away from 0 and 1 so the logit stays finite.
        target = min(max(float(gate_init), 1e-4), 1.0 - 1e-4)
        init_bias = torch.logit(torch.tensor(target)).item()
        with torch.no_grad():
            self.gate_net[-2].bias.fill_(init_bias)

        # Gates of the most recent forward pass, detached and on the CPU.
        self.latest_gates = None

    def compute_gates(self, x):
        """Return differentiable gates of shape ``(batch, channels)`` for ``x``."""
        x_squeeze = x.squeeze(1)
        ch_mean = x_squeeze.mean(dim=2)
        ch_std = x_squeeze.std(dim=2)
        stats = torch.cat([ch_mean, ch_std], dim=1)
        return self.gate_net(stats)

    def forward(self, x):
        gates = self.compute_gates(x)
        self.latest_gates = gates.detach().cpu()
        # Broadcast the per-channel gate over the feature and time axes.
        x = x * gates.view(-1, 1, gates.size(1), 1)
        return super().forward(x)

    def get_gate_values(self):
        """Return the batch-mean gate per channel from the last forward pass.

        Returns ``None`` before the first forward pass. The tensor is detached,
        so it carries no gradient; it only reflects the most recent batch.
        """
        if self.latest_gates is not None:
            return self.latest_gates.mean(dim=0)
        return None


print("EEG-ARNN architectures defined!")

## 8. Training Functions

In [None]:
def train_epoch_with_adj_learning(model, dataloader, criterion, optimizer, device, adj_lr=0.001, l1_lambda=0.0):
    """Training epoch with custom adjacency matrix learning - PAPER IMPLEMENTATION

    For every batch the loss is back-propagated, each learnable adjacency
    matrix ``A`` is updated with ``A <- (1 - adj_lr) * A - adj_lr * grad`` and
    all remaining parameters are updated by ``optimizer``.

    Two defects of the earlier version are fixed:

    * The L1 gate penalty was computed from detached gate values, so it added
      a constant to the loss and never influenced training. When the model
      offers ``compute_gates`` the penalty is now computed from differentiable
      gates for the current batch.
    * The adjacency gradients were zeroed rather than cleared, so the
      optimizer still stepped ``A`` (for example through weight decay) on top
      of the custom update. The gradient is now set to ``None`` so the
      optimizer skips these parameters.

    Args:
        model: Network to train. Adjacency matrices are every sub-module
            attribute named ``A`` that is an ``nn.Parameter``.
        dataloader: Iterable of ``(inputs, labels)`` batches supporting ``len``.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.
        adj_lr: Step size of the custom adjacency update.
        l1_lambda: Weight of the mean-absolute gate penalty; 0 disables it.

    Returns:
        ``(mean_batch_loss, accuracy)`` for the epoch. Accuracy is 0.0 when
        the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    adjacency_params = [
        module.A
        for module in model.modules()
        if isinstance(getattr(module, 'A', None), nn.Parameter)
    ]

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        if l1_lambda > 0:
            gates = None
            if hasattr(model, 'compute_gates'):
                # Recomputed with gradient so the penalty can shape the gate network.
                gates = model.compute_gates(x)
            elif hasattr(model, 'get_gate_values'):
                # Fallback for models that only expose detached gate values.
                gates = model.get_gate_values()
            if gates is not None:
                loss = loss + l1_lambda * gates.abs().mean()

        loss.backward()

        with torch.no_grad():
            for adj in adjacency_params:
                if adj.grad is not None:
                    adj.copy_((1 - adj_lr) * adj - adj_lr * adj.grad)
                    # None (not zero) so optimizer.step() leaves A untouched.
                    adj.grad = None

        optimizer.step()

        total_loss += loss.item()
        correct += (torch.argmax(logits, 1) == y).sum().item()
        seen += y.numel()

    accuracy = correct / seen if seen else 0.0
    return total_loss / max(1, len(dataloader)), accuracy


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Return ``(mean_batch_loss, accuracy)`` of ``model`` on ``dataloader``.

    The model is put in eval mode and no gradients are tracked. The loss is
    averaged over batches (not over samples).
    """
    model.eval()
    loss_sum = 0.0
    predictions, targets = [], []

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = model(inputs)

        loss_sum += criterion(scores, labels).item()
        predictions.extend(torch.argmax(scores, 1).cpu().tolist())
        targets.extend(labels.cpu().tolist())

    n_batches = len(dataloader)
    mean_loss = loss_sum / (n_batches if n_batches > 1 else 1)
    return mean_loss, accuracy_score(targets, predictions)


def train_model(model, train_loader, val_loader, device, epochs, lr, patience, adj_lr=0.001, l1_lambda=0.0, verbose=True):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Uses cross-entropy loss, Adam (weight decay 1e-4) and a
    ``ReduceLROnPlateau`` scheduler driven by the validation loss. Training
    stops early once validation accuracy has not improved for ``patience``
    consecutive epochs.

    Args:
        model: Network to train (modified in place).
        train_loader, val_loader: Training and validation data loaders.
        device: Device used for the batches.
        epochs: Maximum number of epochs.
        lr: Learning rate for Adam.
        patience: Epochs without improvement before stopping.
        adj_lr: Step size of the custom adjacency update.
        l1_lambda: Weight of the gate L1 penalty.
        verbose: Show a progress bar and the early-stopping message.

    Returns:
        ``(best_state, best_acc)``: a deep copy of the best state dict and the
        corresponding validation accuracy.

    NOTE(review): the checkpoint is selected on ``val_loader``; metrics later
    computed on that same loader are therefore optimistically biased.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # The ``verbose`` keyword is deliberately omitted: it is deprecated since
    # PyTorch 2.2, and leaving it out keeps the default (silent) behaviour on
    # every version, including ones that no longer accept the argument.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    best_acc = 0.0
    best_state = None
    no_improve = 0

    epoch_iterator = tqdm(range(epochs), desc='    Epochs', leave=False) if verbose else range(epochs)

    for epoch in epoch_iterator:
        train_loss, train_acc = train_epoch_with_adj_learning(
            model, train_loader, criterion, optimizer, device, adj_lr, l1_lambda
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        scheduler.step(val_loss)

        # Update the best checkpoint first so the progress bar reports the
        # current best instead of the previous epoch's value.
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1

        if verbose:
            epoch_iterator.set_postfix({
                'train_loss': f'{train_loss:.4f}',
                'train_acc': f'{train_acc:.4f}',
                'val_loss': f'{val_loss:.4f}',
                'val_acc': f'{val_acc:.4f}',
                'best': f'{best_acc:.4f}'
            })

        if no_improve >= patience:
            if verbose:
                print(f'      Early stopping at epoch {epoch+1}/{epochs}')
            break

    # Validation accuracy never exceeded 0: fall back to the final weights.
    if best_state is None:
        best_state = deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    return best_state, best_acc


print("Training functions defined with CUSTOM ADJACENCY MATRIX LEARNING!")

## 9. Main Training Loop

In [None]:
# Per-model list of per-subject result dicts, filled by the loop below and
# consumed by the channel-selection, saving and summary cells.
all_results = {'baseline': [], 'adaptive': []}

print("\nStarting training with PAPER IMPLEMENTATION for PhysioNet...\n")

for subject_id in tqdm(subjects, desc='Training subjects'):
    print(f"\nProcessing {subject_id}...")
    
    X, Y, channel_names = load_subject_data(
        data_dir,
        subject_id,
        ALL_TASK_RUNS,
        CONFIG
    )
    
    if X is None or len(Y) == 0:
        print(f"  Skipped: No data available")
        continue
    
    # C = channels, T = time samples per trial, K = classes, H = hidden size.
    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))
    H = CONFIG['model']['hidden_dim']
    
    print(f"  Data shape: {X.shape}")
    print(f"  Label distribution: {np.bincount(Y)}")
    
    for model_type in ['baseline', 'adaptive']:
        print(f"\n  Training {model_type.upper()}...")
        
        # Fixed random_state: both model types are evaluated on the same splits.
        skf = StratifiedKFold(n_splits=CONFIG['model']['n_folds'], shuffle=True, random_state=42)
        fold_results = []
        adjacencies = []
        gate_values_list = []
        
        for fold, (train_idx, val_idx) in enumerate(skf.split(X, Y)):
            # normalize() scales each trial on its own, so no statistics leak
            # between the training and validation split.
            X_train, X_val = normalize(X[train_idx]), normalize(X[val_idx])
            Y_train, Y_val = Y[train_idx], Y[val_idx]
            
            train_loader = DataLoader(
                EEGDataset(X_train, Y_train),
                batch_size=CONFIG['model']['batch_size'],
                shuffle=True,
                num_workers=0
            )
            val_loader = DataLoader(
                EEGDataset(X_val, Y_val),
                batch_size=CONFIG['model']['batch_size'],
                shuffle=False,
                num_workers=0
            )
            
            # A fresh model per fold; the gate L1 penalty only applies to the
            # adaptive variant.
            if model_type == 'baseline':
                model = BaselineEEGARNN(C, T, K, H).to(device)
                l1_lambda = 0.0
            else:
                model = AdaptiveGatedEEGARNN(C, T, K, H, CONFIG['gating']['gate_init']).to(device)
                l1_lambda = CONFIG['gating']['l1_lambda']
            
            best_state, best_acc = train_model(
                model, train_loader, val_loader, device,
                CONFIG['model']['epochs'],
                CONFIG['model']['learning_rate'],
                CONFIG['model']['patience'],
                adj_lr=CONFIG['model']['adj_lr'],
                l1_lambda=l1_lambda
            )
            model.load_state_dict(best_state)
            
            # NOTE(review): the checkpoint was chosen by accuracy on this same
            # validation fold, so these metrics are optimistic.
            metrics = calculate_comprehensive_metrics(model, val_loader, device)
            fold_results.append({'fold': fold, **metrics})
            
            adjacency = model.get_final_adjacency()
            adjacencies.append(adjacency)
            
            # Only the adaptive model has gates. The values come from the most
            # recent forward pass, i.e. the last validation batch evaluated above.
            if hasattr(model, 'get_gate_values'):
                gate_values = model.get_gate_values()
                if gate_values is not None:
                    if isinstance(gate_values, torch.Tensor):
                        gate_values = gate_values.detach().cpu().numpy()
                    gate_values_list.append(gate_values)
            
            # Release the model before the next fold allocates a new one.
            del model
            torch.cuda.empty_cache()
            gc.collect()
        
        # Mean and standard deviation across folds ('sensitivity' is computed
        # per fold but not aggregated here).
        avg_metrics = {}
        for key in ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity']:
            values = [f[key] for f in fold_results]
            avg_metrics[f'avg_{key}'] = float(np.mean(values))
            avg_metrics[f'std_{key}'] = float(np.std(values))
        
        # Fold-averaged adjacency, later used for channel selection.
        avg_adjacency = np.mean(np.stack(adjacencies, 0), 0)
        
        result = {
            'subject': subject_id,
            'num_trials': X.shape[0],
            'num_channels': C,
            **avg_metrics,
            'adjacency_matrix': avg_adjacency,
            'channel_names': channel_names
        }
        
        if gate_values_list:
            result['avg_gate_values'] = np.mean(np.stack(gate_values_list, 0), 0)
        
        all_results[model_type].append(result)
        
        print(f"    Accuracy: {avg_metrics['avg_accuracy']:.4f} ± {avg_metrics['std_accuracy']:.4f}")
        print(f"    F1-Score: {avg_metrics['avg_f1_score']:.4f} ± {avg_metrics['std_f1_score']:.4f}")

print("\n" + "="*80)
print("Training Complete!")
print("="*80)

## 10. Channel Selection Functions

In [None]:
class ChannelSelector:
    """Rank EEG channels and return the ``k`` highest-scoring ones.

    Args:
        adjacency: Square array of channel-to-channel weights.
        channel_names: Channel names in the same order as the adjacency rows.
        gate_values: Optional per-channel gate values for ``gate_selection``.

    Every selection method returns ``(names, indices)`` with ``indices`` sorted
    in ascending channel order. ``k`` is clipped to ``[0, number of scores]``,
    so ``k=0`` yields an empty selection (it used to return every channel) and
    an oversized ``k`` yields all channels.
    """

    def __init__(self, adjacency, channel_names, gate_values=None):
        self.A = adjacency
        self.names = np.array(channel_names)
        self.C = adjacency.shape[0]
        self.gate_values = gate_values

    def _top_k(self, scores, k):
        """Return names and sorted indices of the ``k`` largest ``scores``."""
        scores = np.asarray(scores)
        k = min(max(int(k), 0), scores.shape[0])
        order = np.argsort(scores)
        # Slice from an explicit start index: ``order[-k:]`` with k == 0 would
        # return the whole array.
        indices = np.sort(order[order.shape[0] - k:])
        return self.names[indices].tolist(), indices

    def edge_selection(self, k):
        """Score each channel by the absolute weight of all edges touching it.

        Self-loops (the diagonal) are ignored; every off-diagonal entry counts
        for both its row channel and its column channel.
        """
        weights = np.abs(self.A)
        off_diagonal = weights - np.diag(np.diag(weights))
        edge_importance = off_diagonal.sum(axis=1) + off_diagonal.sum(axis=0)
        return self._top_k(edge_importance, k)

    def aggregation_selection(self, k):
        """Score each channel by the absolute row sum of the adjacency matrix."""
        agg_scores = np.sum(np.abs(self.A), 1)
        return self._top_k(agg_scores, k)

    def gate_selection(self, k):
        """Score each channel by its gate value.

        Raises:
            ValueError: If no gate values were supplied.
        """
        if self.gate_values is None:
            raise ValueError("Gate values not available.")

        return self._top_k(self.gate_values, k)


def retrain_with_selected_channels(x, y, selected_indices, T, K, device, config, model_type='baseline'):
    """Cross-validate a fresh model that only sees the selected channels.

    Args:
        x: Trials indexed as ``x[trial, channel, time]``.
        y: Integer labels, one per trial.
        selected_indices: Channel indices to keep.
        T: Number of time samples per trial.
        K: Number of classes.
        device: Device used for training and evaluation.
        config: Experiment configuration (``'model'`` and ``'gating'`` sections).
        model_type: ``'baseline'`` or anything else for the adaptive model.

    Returns:
        Dict with ``avg_<metric>`` and ``std_<metric>`` across folds for
        accuracy, precision, recall, f1_score, auc_roc and specificity.
    """
    model_cfg = config['model']
    x_selected = x[:, selected_indices, :]
    C = len(selected_indices)
    H = model_cfg['hidden_dim']

    def make_loader(data, labels, shuffle):
        return DataLoader(
            EEGDataset(data, labels),
            batch_size=model_cfg['batch_size'],
            shuffle=shuffle,
            num_workers=0
        )

    splitter = StratifiedKFold(n_splits=model_cfg['n_folds'], shuffle=True, random_state=42)
    fold_results = []

    for train_idx, val_idx in splitter.split(x_selected, y):
        X_train = normalize(x_selected[train_idx])
        X_val = normalize(x_selected[val_idx])

        train_loader = make_loader(X_train, y[train_idx], True)
        val_loader = make_loader(X_val, y[val_idx], False)

        if model_type != 'baseline':
            model = AdaptiveGatedEEGARNN(C, T, K, H, config['gating']['gate_init']).to(device)
            l1_lambda = config['gating']['l1_lambda']
        else:
            model = BaselineEEGARNN(C, T, K, H).to(device)
            l1_lambda = 0.0

        best_state, _ = train_model(
            model, train_loader, val_loader, device,
            model_cfg['epochs'],
            model_cfg['learning_rate'],
            model_cfg['patience'],
            adj_lr=model_cfg['adj_lr'],
            l1_lambda=l1_lambda,
            verbose=False
        )
        model.load_state_dict(best_state)

        fold_results.append(calculate_comprehensive_metrics(model, val_loader, device))

        # Free the model before the next fold builds a new one.
        del model
        torch.cuda.empty_cache()
        gc.collect()

    avg_metrics = {}
    for key in ('accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity'):
        per_fold = [fold[key] for fold in fold_results]
        avg_metrics[f'avg_{key}'] = float(np.mean(per_fold))
        avg_metrics[f'std_{key}'] = float(np.std(per_fold))

    return avg_metrics


print("Channel selection functions defined!")

## 11. Channel Selection and Retraining

In [None]:
# Per-model list of retraining results, one entry per
# (subject, selection method, k) combination.
retrain_results = {'baseline': [], 'adaptive': []}

print("\n" + "="*80)
print("CHANNEL SELECTION AND RETRAINING")
print("="*80 + "\n")

for subject_id in tqdm(subjects, desc='Retraining'):
    print(f"\nProcessing {subject_id}...")

    # The subject's data is loaded (and preprocessed) again here; it is not
    # kept from the main training loop.
    X, Y, channel_names = load_subject_data(
        data_dir,
        subject_id,
        ALL_TASK_RUNS,
        CONFIG
    )

    if X is None:
        continue

    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))

    for model_type in ['baseline', 'adaptive']:
        # Look up this subject's full-channel result for the current model.
        subj_result = None
        for res in all_results[model_type]:
            if res['subject'] == subject_id:
                subj_result = res
                break

        if subj_result is None:
            continue

        adjacency = subj_result['adjacency_matrix']
        gate_values = subj_result.get('avg_gate_values', None)
        selector = ChannelSelector(adjacency, channel_names, gate_values)

        # ES = edge selection, AS = aggregation selection, GS = gate selection
        # (GS needs gate values, so it is only run for the adaptive model).
        selection_methods = ['ES', 'AS']
        if model_type == 'adaptive':
            selection_methods.append('GS')

        for method_name in selection_methods:
            for k in CONFIG['channel_selection']['k_values']:
                if method_name == 'ES':
                    selected_channels, selected_indices = selector.edge_selection(k)
                elif method_name == 'AS':
                    selected_channels, selected_indices = selector.aggregation_selection(k)
                elif method_name == 'GS':
                    selected_channels, selected_indices = selector.gate_selection(k)

                retrain_metrics = retrain_with_selected_channels(
                    X, Y, selected_indices, T, K, device, CONFIG, model_type
                )

                # Positive values mean the reduced channel set performs worse
                # than the full channel set.
                acc_drop = subj_result['avg_accuracy'] - retrain_metrics['avg_accuracy']

                retrain_results[model_type].append({
                    'subject': subject_id,
                    'method': method_name,
                    'k': k,
                    'num_channels_selected': len(selected_channels),
                    **retrain_metrics,
                    'full_channels_acc': subj_result['avg_accuracy'],
                    'accuracy_drop': acc_drop,
                    # NOTE(review): divides by the full-channel accuracy; this
                    # raises ZeroDivisionError if that accuracy is exactly 0.
                    'accuracy_drop_pct': (acc_drop / subj_result['avg_accuracy'] * 100)
                })

                print(f"  {model_type.upper()} - {method_name}, k={k}: "
                      f"{retrain_metrics['avg_accuracy']:.4f} (drop: {acc_drop:.4f})")

print("\n" + "="*80)
print("Retraining Complete!")
print("="*80)

## 12. Save Results

In [None]:
# Write one CSV per model for the full-channel results and one per model for
# the channel-selection retraining results. Models without results are skipped.
results_dir = CONFIG['output']['results_dir']

for model_type in ['baseline', 'adaptive']:
    if not all_results[model_type]:
        continue

    rows = []
    for r in all_results[model_type]:
        row = {name: r[name] for name in ('subject', 'num_trials', 'num_channels')}
        # Each metric is followed directly by its across-fold standard deviation.
        for metric in ('accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity'):
            row[metric] = r[f'avg_{metric}']
            row[f'std_{metric}'] = r[f'std_{metric}']
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(results_dir / f'eeg_arnn_{model_type}_results.csv', index=False)
    print(f"Saved: eeg_arnn_{model_type}_results.csv")

for model_type in ['baseline', 'adaptive']:
    if not retrain_results[model_type]:
        continue

    df = pd.DataFrame(retrain_results[model_type])
    df.to_csv(results_dir / f'eeg_arnn_{model_type}_retrain_results.csv', index=False)
    print(f"Saved: eeg_arnn_{model_type}_retrain_results.csv")

print(f"\nAll results saved to {results_dir}")

## 13. Results Summary

In [None]:
# Print across-subject means for each model, then the accuracy difference
# between the adaptive and the baseline model when both have results.
banner = "=" * 80
print("\n" + banner)
print("RESULTS SUMMARY")
print(banner + "\n")

for model_type in ['baseline', 'adaptive']:
    subject_results = all_results[model_type]
    if not subject_results:
        continue

    print(f"{model_type.upper()} Results:")
    print(f"  Subjects: {len(subject_results)}")
    for label, key in (
        ('accuracy', 'avg_accuracy'),
        ('F1-Score', 'avg_f1_score'),
        ('AUC-ROC', 'avg_auc_roc'),
    ):
        scores = [r[key] for r in subject_results]
        print(f"  Mean {label}: {np.mean(scores):.4f} ± {np.std(scores):.4f}")
    print()

if all_results['baseline'] and all_results['adaptive']:
    baseline_acc = np.mean([r['avg_accuracy'] for r in all_results['baseline']])
    adaptive_acc = np.mean([r['avg_accuracy'] for r in all_results['adaptive']])
    improvement = adaptive_acc - baseline_acc

    print(f"\nAdaptive vs Baseline:")
    print(f"  Improvement: {improvement:.4f} ({improvement/baseline_acc*100:.2f}%)")

print("\n" + banner)
print("DONE!")
print(banner)