# PhysioNet CARMv2 - Complete Standalone Pipeline

In [None]:
# Imports and Setup
import json, random, warnings
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
import mne
import matplotlib.pyplot as plt
import seaborn as sns

# Keep notebook output readable: silence Python warnings and MNE chatter,
# and use seaborn's default "notebook" scaling for plots.
warnings.filterwarnings('ignore')
mne.set_log_level('WARNING')
sns.set_context('notebook', font_scale=1.0)

def seed(s=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with ``s``.

    Also forces cuDNN into deterministic mode and disables its autotuner so
    repeated runs pick the same convolution algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix all RNGs, then run on the first CUDA device when one is visible.
seed(42)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

In [None]:
# Configuration - AUTO-DETECT KAGGLE DATASET
import os

# Auto-detect environment and find dataset.
# On Kaggle, look under /kaggle/input; elsewhere, read the location from the
# PHYSIONET_DATA_DIR environment variable (default: ./data/physionet).
# Previously DATA_DIR was never assigned outside Kaggle, so the prints below
# failed with NameError.
DATA_DIR = None

if os.path.exists('/kaggle/input'):
    print("Running on Kaggle")
    print("\nAvailable datasets:")

    kaggle_input = Path('/kaggle/input')
    # Sorted so that the "first available dataset" fallback is deterministic.
    datasets = sorted(d for d in kaggle_input.iterdir() if d.is_dir())

    for ds in datasets:
        print(f"  - {ds.name}")

    # Try to find the PhysioNet dataset by its known input name(s)
    possible_names = ['physioneteegmi']

    for ds_name in possible_names:
        test_path = kaggle_input / ds_name
        if test_path.exists():
            DATA_DIR = test_path
            print(f"\n✓ Found dataset: {DATA_DIR}")
            break

    # If not found by name, use first available dataset
    if DATA_DIR is None and datasets:
        DATA_DIR = datasets[0]
        print(f"\n⚠ Using first available dataset: {DATA_DIR}")

    if DATA_DIR is None:
        raise ValueError(
            "No dataset found! Please add the PhysioNet EEG Motor Movement/Imagery "
            "dataset to this notebook's inputs."
        )
else:
    print("Running locally")
    DATA_DIR = Path(os.environ.get('PHYSIONET_DATA_DIR', 'data/physionet'))

print(f"\nData directory: {DATA_DIR}")
print(f"Directory exists: {DATA_DIR.exists()}")

if DATA_DIR.exists():
    # Show structure
    contents = [d.name for d in sorted(DATA_DIR.iterdir()) if d.is_dir()][:10]
    print(f"\nContents (first 10): {contents}")

    # Look for subject folders (PhysioNet subjects are named S001, S002, ...)
    subjects_preview = [d.name for d in sorted(DATA_DIR.iterdir()) if d.is_dir() and d.name.startswith('S')][:5]
    if subjects_preview:
        print(f"Subject folders found: {subjects_preview}")

        # Check first subject's runs
        first_subj = DATA_DIR / subjects_preview[0]
        runs = sorted([f.name for f in first_subj.glob('*.edf')])
        print(f"Runs in {subjects_preview[0]}: {len(runs)} files")
        if runs:
            print(f"Sample runs: {runs[:5]}")
    else:
        print("\n⚠ No subject folders starting with 'S' found!")
        print("Please check if the dataset structure is correct.")

# Central experiment configuration; every stage below reads its settings here.
EXPERIMENT_CONFIG = {
    'data': {
        'raw_data_dir': DATA_DIR,
        # Integer event codes produced from the EDF annotations.
        # NOTE(review): for PhysioNet EEGMMIDB the annotations T0/T1/T2 map to
        # codes 1/2/3, so [1, 2] selects rest (T0) vs T1. Confirm this is
        # intended; T1 vs T2 would be [2, 3].
        'selected_classes': [1, 2],
        'tmin': -1.0,          # epoch start (s) relative to event onset
        'tmax': 5.0,           # epoch end (s) relative to event onset
        'baseline': (-0.5, 0)  # baseline-correction window (s)
    },
    'preprocessing': {
        'l_freq': 0.5,
        'h_freq': 40.0,
        # EEGMMIDB was recorded in the US (BCI2000, Wadsworth Center), so mains
        # interference sits at 60 Hz rather than 50 Hz.
        'notch_freq': 60.0,
        'target_sfreq': 128.0,
        'apply_car': True      # common average reference
    },
    'model': {
        'hidden_dim': 40,
        'epochs': 30,
        'learning_rate': 1e-3,
        'batch_size': 32,
        'n_folds': 3,
        'patience': 8          # early-stopping patience in epochs
    },
    'carmv2': {
        'topk_k': 8,           # neighbours kept in the feature-similarity graph
        'lambda_feat': 0.3,    # weight of the feature graph vs. learned graph
        'hop_alpha': 0.5,      # weight of the 2-hop learned adjacency
        'edge_dropout': 0.1,
        'use_pairnorm': True,
        'use_residual': True,
        'low_rank_r': 0        # >0 uses a rank-r factorised learned adjacency
    },
    'output': {
        'results_dir': Path('results'),
        'results_file': 'carmv2_subject_results.csv',
        'channel_selection_file': 'carmv2_channel_selection.csv',
        'adjacency_prefix': 'carmv2_adjacency'
    },
    'max_subjects': 20,
    'min_runs_per_subject': 8
}

EXPERIMENT_CONFIG['output']['results_dir'].mkdir(exist_ok=True, parents=True)
print(f"\n✓ Configuration loaded successfully!")
print(f"✓ Training: {EXPERIMENT_CONFIG['max_subjects']} subjects, {EXPERIMENT_CONFIG['model']['n_folds']}-fold CV, {EXPERIMENT_CONFIG['model']['epochs']} epochs")

In [None]:
# Data Loading and Preprocessing Functions

def preprocess_raw(raw, config):
    """Clean, filter, re-reference and resample an MNE Raw object in place.

    Steps, in order:
      1. strip trailing '.' from channel names and keep only EEG channels;
      2. attach the standard 10-20 montage (missing positions are ignored);
      3. notch at ``notch_freq`` when it lies below the Nyquist frequency;
      4. FIR band-pass between ``l_freq`` and ``h_freq``;
      5. optionally apply a common average reference;
      6. resample to ``target_sfreq``.

    Returns the same ``raw`` object.
    """
    raw.rename_channels({ch: ch.rstrip('.') for ch in raw.ch_names})
    raw.pick_types(eeg=True)
    raw.set_montage('standard_1020', on_missing='ignore', match_case=False)

    prep = config['preprocessing']
    notch = prep['notch_freq']
    if notch < raw.info['sfreq'] / 2.0:
        raw.notch_filter(freqs=notch, verbose=False)

    raw.filter(l_freq=prep['l_freq'], h_freq=prep['h_freq'],
               method='fir', fir_design='firwin', verbose=False)

    if prep['apply_car']:
        raw.set_eeg_reference('average', projection=False, verbose=False)

    raw.resample(prep['target_sfreq'], npad='auto', verbose=False)
    return raw


def load_and_preprocess_edf(edf_path, config):
    """Load one EDF run, preprocess it, and cut epochs around annotated events.

    Returns ``(data, codes, ch_names)``:
      * ``data`` is ``epochs.get_data()``;
      * ``codes`` are the integer event codes of the kept epochs;
      * ``ch_names`` are the channel names after preprocessing.

    When the run has no events, returns ``(None, None, ch_names)``.
    """
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
    raw = preprocess_raw(raw, config)

    # preprocess_raw keeps only EEG channels (pick_types(eeg=True)), so no stim
    # channel survives and mne.find_events could never succeed here. Events are
    # therefore always taken from the EDF annotations.
    #
    # Annotations of the form "T<n>" get the fixed code n + 1. This equals MNE's
    # sorted auto-numbering when T0, T1 and T2 are all present, but stays stable
    # when a run lacks one of them; otherwise auto-numbering would silently
    # shift class codes between runs. Any other annotation layout falls back to
    # MNE's default mapping.
    task_labels = sorted({d for d in raw.annotations.description
                          if d.startswith('T') and d[1:].isdigit()})
    event_id = {d: int(d[1:]) + 1 for d in task_labels} if task_labels else 'auto'
    events, event_ids = mne.events_from_annotations(raw, event_id=event_id, verbose='ERROR')

    if len(events) == 0:
        return None, None, raw.ch_names

    epochs = mne.Epochs(
        raw,
        events,
        event_id=event_ids,
        tmin=config['data']['tmin'],
        tmax=config['data']['tmax'],
        baseline=tuple(config['data']['baseline']),
        preload=True,
        verbose='ERROR'
    )

    return epochs.get_data(), epochs.events[:, 2], raw.ch_names


def filter_classes(x, y, selected_classes):
    """Keep trials whose label is in ``selected_classes`` and relabel them 0..K-1.

    Labels are remapped by their rank in ``sorted(selected_classes)``. The
    returned labels are an int64 array aligned with the returned trials.
    """
    keep = np.isin(y, selected_classes)
    x_kept = x[keep]
    y_kept = y[keep]
    remap = dict(zip(sorted(selected_classes), range(len(selected_classes))))
    labels = np.fromiter((remap[int(v)] for v in y_kept), dtype=np.int64, count=len(y_kept))
    return x_kept, labels


def normalize(x):
    """Z-score each channel of ``x``.

    Mean and std are pooled over axes 0 and 2 (trials and time), so channels
    along axis 1 are scaled independently. 1e-8 guards against zero std.
    """
    pooled_axes = (0, 2)
    centred = x - x.mean(axis=pooled_axes, keepdims=True)
    return centred / (x.std(axis=pooled_axes, keepdims=True) + 1e-8)


def load_subject_data(data_dir, subject_id, run_ids, config):
    """Load, preprocess and concatenate all requested runs of one subject.

    Looks for ``<data_dir>/<subject_id>/<subject_id><run_id>.edf`` for each run
    id. Missing files, runs that fail to load, and runs with no trials of the
    selected classes are skipped. The first successfully loaded run fixes the
    channel layout and epoch shape. Later runs that do not match it are skipped
    with a warning, rather than crashing ``np.concatenate`` or silently
    mixing channel orders.

    Returns ``(x, y, channel_names)``; ``x`` and ``y`` are None when nothing
    could be loaded. All three are None when the subject folder is missing.
    """
    subject_dir = data_dir / subject_id
    if not subject_dir.exists():
        return None, None, None

    all_x, all_y = [], []
    channel_names = None

    for run_id in run_ids:
        edf_path = subject_dir / f'{subject_id}{run_id}.edf'
        if not edf_path.exists():
            continue

        # Best effort: a corrupt or unusual run is reported and skipped.
        try:
            x, y, ch_names = load_and_preprocess_edf(edf_path, config)
            if x is None or len(y) == 0:
                continue

            x, y = filter_classes(x, y, config['data']['selected_classes'])
            if len(y) == 0:
                continue
        except Exception as e:
            print(f"  Warning: Failed to load {edf_path.name}: {e}")
            continue

        if channel_names is None:
            channel_names = list(ch_names)
        elif list(ch_names) != channel_names or x.shape[1:] != all_x[0].shape[1:]:
            print(f"  Warning: Skipping {edf_path.name}: channel layout or epoch shape differs from earlier runs")
            continue

        all_x.append(x)
        all_y.append(y)

    if not all_x:
        return None, None, channel_names

    return np.concatenate(all_x, 0), np.concatenate(all_y, 0), channel_names


def get_available_subjects(data_dir, min_runs=8):
    """List subject folders (names starting with 'S') holding >= ``min_runs`` EDF files.

    Folders are returned by name in sorted order.

    Raises:
        ValueError: if ``data_dir`` does not exist.
    """
    if not data_dir.exists():
        raise ValueError(f"Data directory not found: {data_dir}")

    def _edf_count(folder):
        return len(list(folder.glob('*.edf')))

    return [
        folder.name
        for folder in sorted(data_dir.iterdir())
        if folder.is_dir() and folder.name.startswith('S') and _edf_count(folder) >= min_runs
    ]


# Scan for available subjects
print("\nScanning for subjects...")
data_dir = EXPERIMENT_CONFIG['data']['raw_data_dir']
print(f"Looking for data in: {data_dir}")

all_subjects = get_available_subjects(data_dir, min_runs=EXPERIMENT_CONFIG['min_runs_per_subject'])
subjects = all_subjects[:EXPERIMENT_CONFIG['max_subjects']]

print(f"Found {len(all_subjects)} subjects with >= {EXPERIMENT_CONFIG['min_runs_per_subject']} runs")
print(f"Will process {len(subjects)} subjects: {subjects}")

# Run protocol of the PhysioNet EEG Motor Movement/Imagery dataset:
#   R01 / R02                eyes-open / eyes-closed baselines
#   R03, R07, R11            executed left/right fist
#   R04, R08, R12            imagined left/right fist
#   R05, R09, R13            executed both fists / both feet
#   R06, R10, R14            imagined both fists / both feet
# (The previous lists labelled R07-R14 as "imagery" and R03-R06 as "execution",
# which mixes both kinds in each list.)
MOTOR_EXECUTION_RUNS = ['R03', 'R05', 'R07', 'R09', 'R11', 'R13']
MOTOR_IMAGERY_RUNS = ['R04', 'R06', 'R08', 'R10', 'R12', 'R14']
ALL_TASK_RUNS = sorted(MOTOR_EXECUTION_RUNS + MOTOR_IMAGERY_RUNS)
# NOTE(review): T1/T2 mean left/right fist in the fist runs but both-fists/
# both-feet in the fists/feet runs, so pooling every task run mixes class
# semantics. Confirm this is intended.
print(f"Using runs: {ALL_TASK_RUNS}")

print("\nData loading functions ready!")

In [None]:
# PyTorch Dataset
class EEGDataset(Dataset):
    """In-memory dataset of EEG trials for a Conv2d front-end.

    ``x`` is converted to float32 and given a singleton dimension at axis 1,
    so each item is ``x[i][None]``. ``y`` is converted to int64 labels.
    """

    def __init__(self, x, y):
        trials = torch.FloatTensor(x)
        self.x = trials.unsqueeze(1)
        self.y = torch.LongTensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        sample, label = self.x[i], self.y[i]
        return sample, label

In [None]:
# CARMv2 Model Architecture

def pairnorm(x, node_dim=2, eps=1e-6):
    """Centre ``x`` along ``node_dim`` and divide by the RMS along that dim.

    Every slice orthogonal to ``node_dim`` ends up with zero mean and (up to
    ``eps``) unit variance across nodes. Note that this rescales per feature.
    The reference PairNorm instead rescales by the mean row norm over all
    features.
    """
    centred = x - x.mean(dim=node_dim, keepdim=True)
    variance = (centred * centred).mean(dim=node_dim, keepdim=True)
    return centred / (variance + eps).sqrt()


def build_feat_topk_adj(x, k):
    B, H, C, T = x.shape
    E = x.permute(2, 1, 0, 3).contiguous().view(C, H, B*T).mean(2)
    En = F.normalize(E, p=2, dim=1)
    S = (En @ En.t()).clamp_min(0.0)
    k = max(1, min(int(k), C))
    vals, idx = torch.topk(S, k, dim=1)
    M = torch.zeros_like(S)
    M.scatter_(1, idx, 1.0)
    A = S * M
    A = torch.softmax(A, 1)
    A = 0.5 * (A + A.t())
    return A


class CARMv2(nn.Module):
    """Graph-convolution block over EEG channels (graph nodes).

    ``forward`` takes ``x`` unpacked as (B, H, C, T) and returns a tensor of the
    same shape. The propagation matrix combines:
      * a learned adjacency ``Al`` (from ``_learned``): sigmoid of a C x C
        parameter, or of ``B @ B.T`` when ``low_rank_r > 0``, symmetrised,
        given self-loops and symmetrically degree-normalised;
      * its two-hop power, blended in with weight ``hop_alpha``;
      * a batch-dependent top-k similarity graph (``build_feat_topk_adj``),
        blended in with weight ``lambda_feat``.

    Args:
        C: number of channels / graph nodes.
        H: feature dimension.
        cfg: dict with keys topk_k, lambda_feat, hop_alpha, edge_dropout,
            use_pairnorm, use_residual, low_rank_r.
    """

    def __init__(self, C, H, cfg):
        super().__init__()
        self.C = C
        self.H = H
        self.k = int(cfg['topk_k'])
        self.lf = float(cfg['lambda_feat'])
        self.ha = float(cfg['hop_alpha'])
        self.ed = float(cfg['edge_dropout'])
        self.pn = bool(cfg['use_pairnorm'])
        self.res = bool(cfg['use_residual'])
        r = int(cfg['low_rank_r'])

        if r > 0:
            self.B = nn.Parameter(torch.empty(C, r))
            nn.init.xavier_uniform_(self.B)
            self.W = None
        else:
            self.W = nn.Parameter(torch.empty(C, C))
            nn.init.xavier_uniform_(self.W)
            self.B = None

        self.th = nn.Linear(H, H, bias=False)
        self.bn = nn.BatchNorm2d(H)
        self.act = nn.ELU()
        # Adjacencies from the most recent eval-mode forward (numpy arrays).
        self.last = None

    def _learned(self, dev):
        """Return D^-1/2 (sym(sigmoid(W)) + I) D^-1/2 for the learned graph."""
        W = self.W if self.B is None else (self.B @ self.B.t())
        A = torch.sigmoid(W)
        A = 0.5 * (A + A.t())
        I = torch.eye(self.C, device=dev, dtype=A.dtype)
        At = A + I
        d = torch.pow(At.sum(1).clamp_min(1e-6), -0.5)
        return d[:, None] * At * d[None, :]

    def forward(self, x):
        B, H, C, T = x.shape
        Al = self._learned(x.device)
        Ah = (1 - self.ha) * Al + self.ha * (Al @ Al)   # 1-hop / 2-hop blend
        Af = build_feat_topk_adj(x, self.k)
        A = (1 - self.lf) * Ah + self.lf * Af

        if self.training and self.ed > 0:
            # Randomly drop edges, re-symmetrise, then add self-loops.
            # NOTE(review): the +I is only applied in training, so the relative
            # self-loop weight differs between train and eval. Confirm intended.
            M = (torch.rand_like(A) > self.ed).float()
            AM = A * M
            A = 0.5 * (AM + AM.t())
            A = A + torch.eye(C, device=A.device, dtype=A.dtype)

        # Symmetric normalisation D^-1/2 A D^-1/2, via broadcasting instead of
        # two dense matmuls with diagonal matrices.
        d = torch.pow(A.sum(1).clamp_min(1e-6), -0.5)
        A = d[:, None] * A * d[None, :]

        # Propagate over channels: (B, H, C, T) -> (B*T, C, H) -> back.
        xb = x.permute(0, 3, 2, 1).reshape(B * T, C, H)
        xg = self.th(A @ xb)
        xg = xg.view(B, T, C, H).permute(0, 3, 2, 1)

        out = xg + x if self.res else xg
        out = pairnorm(out, 2) if self.pn else out
        out = self.bn(out)
        out = self.act(out)

        # Copying to CPU forces a device sync, so record adjacencies only in eval
        # mode (the pipeline reads them after evaluate()), not on every training step.
        if not self.training:
            self.last = {
                'learned': Al.detach().cpu().numpy(),
                'effective': A.detach().cpu().numpy()
            }
        return out

    def get_adjs(self):
        """Return the adjacencies recorded by the last eval forward, or {}."""
        return self.last or {}


class TFEM(nn.Module):
    """Temporal Feature Extraction Module.

    Applies a (1, k) convolution along the last axis (no bias, k//2 padding),
    then BatchNorm2d and ELU. When ``pool`` is True, also applies (1, 2)
    average pooling.
    """

    def __init__(self, i, o, k=16, pool=True):
        super().__init__()
        self.pool = pool
        self.cv = nn.Conv2d(i, o, kernel_size=(1, k), padding=(0, k // 2), bias=False)
        self.bn = nn.BatchNorm2d(o)
        self.act = nn.ELU()
        if pool:
            self.pl = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pl = None

    def forward(self, x):
        h = self.cv(x)
        h = self.bn(h)
        h = self.act(h)
        if not self.pool:
            return h
        return self.pl(h)


class EEGARNN_CARMv2(nn.Module):
    """Three TFEM -> CARMv2 stages followed by a two-layer MLP classifier.

    Args:
        C: number of channels.
        T: samples per trial.
        K: number of classes.
        H: number of hidden feature maps.
        cfg: CARMv2 configuration dict.

    The input is (batch, 1, C, T), matching the shape probe below; the output is
    (batch, K) logits.
    """

    def __init__(self, C, T, K, H, cfg):
        super().__init__()
        self.t1 = TFEM(1, H, 16, False)
        self.g1 = CARMv2(C, H, cfg)
        self.t2 = TFEM(H, H, 16, True)
        self.g2 = CARMv2(C, H, cfg)
        self.t3 = TFEM(H, H, 16, True)
        self.g3 = CARMv2(C, H, cfg)

        # Infer the flattened feature size with a dummy forward pass. Run it in
        # eval mode: in train mode the all-zero probe updated the BatchNorm
        # running statistics and triggered edge dropout.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            fs = torch.flatten(self._f(torch.zeros(1, 1, C, T)), 1).size(1)
        self.train(was_training)

        self.fc1 = nn.Linear(fs, 256)
        self.do = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, K)

    def _f(self, x):
        """Feature extractor: alternating temporal conv and graph blocks."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        x = torch.flatten(self._f(x), 1)
        x = F.relu(self.fc1(x))
        x = self.do(x)
        return self.fc2(x)

    def get_final_adjs(self):
        """Adjacencies recorded by the last graph block (see CARMv2.get_adjs)."""
        return self.g3.get_adjs()


print("Model architecture defined!")

In [None]:
# Training Functions

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns ``(mean batch loss, accuracy over every sample seen)``. The loss is
    averaged over batches, with an empty loader treated as one batch.
    """
    model.train()
    running_loss = 0.0
    predictions, targets = [], []

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predictions.extend(outputs.argmax(dim=1).cpu().tolist())
        targets.extend(labels.cpu().tolist())

    n_batches = max(1, len(dataloader))
    return running_loss / n_batches, accuracy_score(targets, predictions)


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` in eval mode without gradients.

    Returns ``(mean batch loss, accuracy, predicted labels, true labels)``. The
    label lists are Python ints in loader order.
    """
    model.eval()
    loss_sum = 0.0
    predicted, actual = [], []

    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        scores = model(batch_x)
        loss_sum += criterion(scores, batch_y).item()
        predicted.extend(scores.argmax(dim=1).cpu().tolist())
        actual.extend(batch_y.cpu().tolist())

    mean_loss = loss_sum / max(1, len(dataloader))
    return mean_loss, accuracy_score(actual, predicted), predicted, actual


def train_model(model, train_loader, val_loader, device, epochs, lr, patience):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    
    try:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3, verbose=False
        )
    except TypeError:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
    
    best_acc = 0.0
    best_state = None
    no_improve = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        try:
            scheduler.step(val_loss)
        except Exception:
            pass
        
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = {k: v.detach().cpu() if hasattr(v, 'detach') else v 
                         for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
        
        if no_improve >= patience:
            break
    
    if best_state is None:
        best_state = {k: v.detach().cpu() if hasattr(v, 'detach') else v 
                     for k, v in model.state_dict().items()}
    
    model.load_state_dict(best_state)
    return history, best_state


def cross_validate_subject(x, y, channel_names, T, K, device, config):
    C = x.shape[1]
    skf = StratifiedKFold(
        n_splits=int(config['model']['n_folds']),
        shuffle=True,
        random_state=42
    )
    
    batch_size = int(config['model']['batch_size'])
    epochs = int(config['model']['epochs'])
    lr = float(config['model']['learning_rate'])
    patience = int(config['model']['patience'])
    
    folds = []
    adjacencies = []
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(x, y)):
        X_train, X_val = normalize(x[train_idx]), normalize(x[val_idx])
        Y_train, Y_val = y[train_idx], y[val_idx]
        
        train_loader = DataLoader(
            EEGDataset(X_train, Y_train),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )
        val_loader = DataLoader(
            EEGDataset(X_val, Y_val),
            batch_size=batch_size,
            shuffle=False,
            num_workers=0
        )
        
        model = EEGARNN_CARMv2(C, T, K, config['model']['hidden_dim'], config['carmv2']).to(device)
        history, best_state = train_model(model, train_loader, val_loader, device, epochs, lr, patience)
        model.load_state_dict(best_state)
        
        _, accuracy, _, _ = evaluate(model, val_loader, nn.CrossEntropyLoss(), device)
        
        adjacency = model.get_final_adjs().get('learned', None)
        adjacencies.append(adjacency)
        folds.append({'fold': fold, 'val_acc': accuracy, 'history': history})
    
    avg_acc = float(np.mean([f['val_acc'] for f in folds]))
    std_acc = float(np.std([f['val_acc'] for f in folds]))
    avg_adjacency = np.mean(np.stack([a for a in adjacencies if a is not None], 0), 0) \
                    if any(a is not None for a in adjacencies) else None
    
    return {
        'fold_results': folds,
        'avg_accuracy': avg_acc,
        'std_accuracy': std_acc,
        'adjacency_matrix': avg_adjacency,
        'channel_names': channel_names
    }


print("Training functions ready!")

In [None]:
# Channel Selection

class ChannelSelector:
    """Rank EEG channels from a square (C, C) adjacency matrix.

    Args:
        adjacency: square array; rows/columns are aligned with ``channel_names``.
        channel_names: channel labels in adjacency order.
    """

    def __init__(self, adjacency, channel_names):
        self.A = adjacency
        self.names = np.array(channel_names)
        self.C = adjacency.shape[0]

    def edge_selection(self, k):
        """Select channels touched by the ``k`` strongest edges (ES).

        Edge (i, j) with i < j is scored ``|A[i, j]| + |A[j, i]|``. Ties keep
        row-major (i, j) order. Up to ``2 * k`` channels can be returned, and
        ``k <= 0`` selects nothing.

        Returns:
            ``(names, indices)``, with indices sorted ascending.
        """
        A = np.asarray(self.A)
        rows, cols = np.triu_indices(self.C, k=1)
        weights = np.abs(A[rows, cols]) + np.abs(A[cols, rows])
        # Stable sort on the negated weights = descending order with stable ties.
        top = np.argsort(-weights, kind='stable')[:max(0, int(k))]
        indices = np.unique(np.concatenate([rows[top], cols[top]]))
        return self.names[indices].tolist(), indices

    def aggregation_selection(self, k):
        """Select the ``k`` channels with the largest total ``|A|`` row sum (AS).

        ``k`` is clamped to [0, C]. Previously ``k == 0`` returned every channel,
        because ``argsort(...)[-0:]`` is the whole array.

        Returns:
            ``(names, indices)``, with indices sorted ascending.
        """
        scores = np.sum(np.abs(self.A), 1)
        n_keep = min(max(0, int(k)), len(scores))
        ranked = np.argsort(scores)
        indices = np.sort(ranked[len(ranked) - n_keep:])
        return self.names[indices].tolist(), indices


def viz_adjacency(adjacency, channel_names, save_path=None):
    """Draw ``adjacency`` as a labelled heatmap.

    If ``save_path`` is truthy, the figure is written there at 200 dpi and
    closed, and ``save_path`` is returned. Otherwise the figure is shown and
    None is returned.
    """
    plt.figure(figsize=(10, 8))
    style = dict(cmap='RdYlGn', center=0, square=True, linewidths=0.4)
    sns.heatmap(adjacency, xticklabels=channel_names, yticklabels=channel_names, **style)
    plt.title('CARMv2 Learned Adjacency Matrix')
    plt.tight_layout()

    if not save_path:
        plt.show()
        return None

    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close()
    return save_path


print("Channel selection ready!")

In [None]:
# MAIN TRAINING LOOP
# For each subject: load data, run stratified cross-validation and collect
# per-subject results for the summary cell below.

results_dir = EXPERIMENT_CONFIG['output']['results_dir']
all_results = []
summary_records = []
n_folds = int(EXPERIMENT_CONFIG['model']['n_folds'])
K = len(set(EXPERIMENT_CONFIG['data']['selected_classes']))

print("\nStarting training...\n")

for subject_id in tqdm(subjects, desc='Training CARMv2'):
    print(f"\nProcessing {subject_id}...")

    # Load subject data with preprocessing
    X, Y, channel_names = load_subject_data(
        data_dir,
        subject_id,
        ALL_TASK_RUNS,
        EXPERIMENT_CONFIG
    )

    if X is None or len(Y) == 0:
        print("  Skipped: No data available")
        continue

    C, T = X.shape[1], X.shape[2]
    class_counts = np.bincount(Y, minlength=K)

    print(f"  Data shape: {X.shape} (trials={X.shape[0]}, channels={C}, timepoints={T})")
    print(f"  Label distribution: {class_counts}")

    # StratifiedKFold raises when a class has fewer members than folds. Skip
    # such subjects instead of aborting the whole run.
    if class_counts.min() < n_folds:
        print(f"  Skipped: each class needs >= {n_folds} trials for {n_folds}-fold CV")
        continue

    # Cross-validate
    result = cross_validate_subject(X, Y, channel_names, T, K, device, EXPERIMENT_CONFIG)

    print(f"  Accuracy: {result['avg_accuracy']:.4f} ± {result['std_accuracy']:.4f}")

    summary = {
        'subject': subject_id,
        'num_trials': X.shape[0],
        'num_channels': C,
        'carmv2_acc': result['avg_accuracy'],
        'carmv2_std': result['std_accuracy']
    }
    summary_records.append(summary)
    all_results.append({
        **summary,
        'adjacency_matrix': result['adjacency_matrix'],
        'channel_names': result['channel_names'],
        'fold_results': result['fold_results']
    })

print("\n" + "="*50)
print("Training Complete!")
print("="*50)

In [None]:
# Save Results and Visualizations
# Writes these files into results_dir:
#   * the per-subject CSV;
#   * the config JSON;
#   * the best subject's adjacency heatmap and its channel selections;
#   * a 2x2 summary figure (which is also shown).

results_df = pd.DataFrame.from_records(summary_records)
print(f"\nSubjects trained: {len(results_df)}")

if len(results_df) > 0:
    print(f"\nMean accuracy: {results_df['carmv2_acc'].mean():.4f} ± {results_df['carmv2_acc'].std():.4f}")
    print(f"Best subject: {results_df.loc[results_df['carmv2_acc'].idxmax(), 'subject']}")
    print(f"Worst subject: {results_df.loc[results_df['carmv2_acc'].idxmin(), 'subject']}")

    # Save results CSV
    results_path = results_dir / EXPERIMENT_CONFIG['output']['results_file']
    results_df.to_csv(results_path, index=False)
    print(f"\nSaved results to {results_path}")

    # Save config (Path objects are stringified)
    config_path = results_dir / 'experiment_config.json'
    with open(config_path, 'w') as f:
        json.dump(EXPERIMENT_CONFIG, f, indent=2, default=str)
    print(f"Saved config to {config_path}")

    # Get best subject's adjacency matrix
    best_idx = int(np.argmax([r['carmv2_acc'] for r in all_results]))
    best_result = all_results[best_idx]

    adjacency = best_result['adjacency_matrix']
    channel_names = best_result['channel_names']

    if adjacency is not None and channel_names is not None:
        # Save adjacency visualization
        adj_path = results_dir / f"{EXPERIMENT_CONFIG['output']['adjacency_prefix']}_{best_result['subject']}.png"
        viz_adjacency(adjacency, channel_names, adj_path)
        print(f"\nSaved adjacency matrix to {adj_path}")

        # Channel selection: ES keeps the top-k edges (up to 2k channels);
        # AS keeps the k channels with the highest aggregated connectivity.
        selector = ChannelSelector(adjacency, channel_names)
        channel_selections = []

        for method in ['ES', 'AS']:
            for k in [10, 15, 20]:
                if method == 'ES':
                    selected, _ = selector.edge_selection(k)
                else:
                    selected, _ = selector.aggregation_selection(k)

                channel_selections.append({
                    'subject': best_result['subject'],
                    'method': method,
                    'k': k,
                    'channels': selected
                })
                print(f"  Selected ({method}, k={k}): {selected}")

        # Save channel selections
        channel_df = pd.DataFrame(channel_selections)
        channel_path = results_dir / EXPERIMENT_CONFIG['output']['channel_selection_file']
        channel_df.to_csv(channel_path, index=False)
        print(f"\nSaved channel selections to {channel_path}")

    # Summary visualizations
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Accuracy distribution
    axes[0, 0].hist(results_df['carmv2_acc'], bins=15, color='steelblue', alpha=0.8, edgecolor='black')
    axes[0, 0].set_title('Accuracy Distribution')
    axes[0, 0].set_xlabel('Accuracy')
    axes[0, 0].set_ylabel('Frequency')

    # Trials vs Accuracy
    axes[0, 1].scatter(results_df['num_trials'], results_df['carmv2_acc'], alpha=0.6)
    axes[0, 1].set_title('Number of Trials vs Accuracy')
    axes[0, 1].set_xlabel('Number of Trials')
    axes[0, 1].set_ylabel('Accuracy')

    # Top subjects
    top_n = min(10, len(results_df))
    top_subjects = results_df.nlargest(top_n, 'carmv2_acc')
    axes[1, 0].barh(range(len(top_subjects)), top_subjects['carmv2_acc'], color='forestgreen', alpha=0.7)
    axes[1, 0].set_yticks(range(len(top_subjects)))
    axes[1, 0].set_yticklabels(top_subjects['subject'])
    axes[1, 0].invert_yaxis()
    axes[1, 0].set_title(f'Top {top_n} Subjects')
    axes[1, 0].set_xlabel('Accuracy')

    # Ranking curve
    sorted_results = results_df.sort_values('carmv2_acc')
    axes[1, 1].plot(range(len(sorted_results)), sorted_results['carmv2_acc'], marker='o', color='coral')
    axes[1, 1].set_title('Subject Ranking')
    axes[1, 1].set_xlabel('Rank')
    axes[1, 1].set_ylabel('Accuracy')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    # Persist the summary figure as well, so the final message is accurate.
    summary_path = results_dir / 'carmv2_summary.png'
    fig.savefig(summary_path, dpi=200, bbox_inches='tight')
    print(f"Saved summary plots to {summary_path}")
    plt.show()

    print("\nAll done! Check the 'results' folder for outputs.")
else:
    print("\nNo results to save.")