# PhysioNet Motor Imagery - SparseGCN-CARM

## Improved Method for Channel Selection

**Previous Results:**
- Baseline: 86.49% accuracy, 0.74% drop @ k=20
- Gated CARM: 87.20% accuracy, 3.08% drop @ k=20
- CARMv2: 86.83% accuracy, 0.71% drop @ k=20 (best selection)

**SparseGCN-CARM Target:**
- Accuracy: > 88.5%
- Channel selection: < 0.5% drop @ k=20
- Auto channel pruning: 64 -> 30-40 channels

**Key Innovations:**
1. Progressive channel pruning during training
2. Multi-head attention for explicit channel importance
3. Multi-scale temporal convolution [8, 16, 32]
4. Feature-adaptive GCN from CARMv2 (keeps what works)

## 1. Setup and Imports

In [1]:
import json
import random
import warnings
from pathlib import Path
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score

import mne

warnings.filterwarnings('ignore')
sns.set_context('notebook', font_scale=1.0)
mne.set_log_level('WARNING')

def set_seed(s=42):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


## 2. Configuration

In [2]:
# Auto-detect Kaggle environment
import os
if os.path.exists('/kaggle/input'):
    print("Running on Kaggle")
    kaggle_input = Path('/kaggle/input')
    datasets = [d for d in kaggle_input.iterdir() if d.is_dir()]
    print(f"Available datasets: {[d.name for d in datasets]}")

    DATA_DIR = None
    possible_names = ['physioneteegmi', 'eeg-motor-movementimagery-dataset']
    for ds_name in possible_names:
        test_path = kaggle_input / ds_name
        if test_path.exists():
            DATA_DIR = test_path
            print(f"Found dataset: {DATA_DIR}")
            break

    if DATA_DIR is None and datasets:
        DATA_DIR = datasets[0]
        print(f"Using first available dataset: {DATA_DIR}")
else:
    print("Running locally")
    DATA_DIR = Path('data/physionet/files')

CONFIG = {
    'data': {
        'raw_data_dir': DATA_DIR,
        'selected_classes': [1, 2],
        'tmin': -1.0,
        'tmax': 5.0,
        'baseline': (-0.5, 0)
    },
    'preprocessing': {
        'l_freq': 0.5,
        'h_freq': 40.0,
        'notch_freq': 50.0,
        'target_sfreq': 128.0,
        'apply_car': True
    },
    'model': {
        'hidden_dim': 40,
        'epochs': 40,
        'learning_rate': 1e-3,
        'batch_size': 32,
        'n_folds': 3,
        'patience': 10
    },
    'sparsegcn': {
        'topk_k': 8,
        'lambda_feat': 0.3,
        'hop_alpha': 0.5,
        'edge_dropout': 0.1,
        'use_pairnorm': True,
        'use_residual': True,
        'use_channel_attention': True,
        'attention_heads': 4,
        'prune_enabled': True,
        'prune_start_epoch': 10,
        'prune_every': 2,
        'prune_ratio': 0.05,
        'min_channels': 20,
        'temporal_scales': [8, 16, 32],
        'channel_importance_loss': 1e-3
    },
    'channel_selection': {
        'k_values': [10, 15, 20, 25]
    },
    'output': {
        'results_dir': Path('results'),
    },
    'max_subjects': 20,
    'min_runs_per_subject': 8
}

CONFIG['output']['results_dir'].mkdir(exist_ok=True, parents=True)
print("\nConfiguration loaded!")
print(f"Training: {CONFIG['max_subjects']} subjects, {CONFIG['model']['n_folds']}-fold CV")
print(f"Progressive pruning: Start epoch {CONFIG['sparsegcn']['prune_start_epoch']}")

Running on Kaggle
Available datasets: ['physioneteegmi']
Found dataset: /kaggle/input/physioneteegmi

Configuration loaded!
Training: 20 subjects, 3-fold CV
Progressive pruning: Start epoch 10


## 3. Data Cleaning - Remove Faulty Subjects

Based on quality analysis, exclude subjects with:
- Less than 10 good runs
- Good run ratio < 70%
- Known problematic subjects

In [3]:
# Known faulty subjects from data cleaning analysis
KNOWN_BAD_SUBJECTS = [
    'S088', 'S089', 'S092', 'S100', 'S104', 'S106', 'S107', 'S108', 'S109'
]

# Additional subjects with high clipping or amplitude issues
HIGH_ISSUE_SUBJECTS = [
    'S003', 'S004', 'S009', 'S010', 'S012', 'S013', 'S017', 'S018', 'S019',
    'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029'
]

EXCLUDED_SUBJECTS = set(KNOWN_BAD_SUBJECTS + HIGH_ISSUE_SUBJECTS)

print(f"Total excluded subjects: {len(EXCLUDED_SUBJECTS)}")
print(f"Excluded: {sorted(EXCLUDED_SUBJECTS)}")

Total excluded subjects: 27
Excluded: ['S003', 'S004', 'S009', 'S010', 'S012', 'S013', 'S017', 'S018', 'S019', 'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029', 'S088', 'S089', 'S092', 'S100', 'S104', 'S106', 'S107', 'S108', 'S109']


## 4. Data Loading and Preprocessing

In [4]:
def preprocess_raw(raw, config):
    cleaned_names = {name: name.rstrip('.') for name in raw.ch_names}
    raw.rename_channels(cleaned_names)
    raw.pick_types(eeg=True)
    raw.set_montage('standard_1020', on_missing='ignore', match_case=False)

    nyquist = raw.info['sfreq'] / 2.0
    if config['preprocessing']['notch_freq'] < nyquist:
        raw.notch_filter(freqs=config['preprocessing']['notch_freq'], verbose=False)

    raw.filter(
        l_freq=config['preprocessing']['l_freq'],
        h_freq=config['preprocessing']['h_freq'],
        method='fir',
        fir_design='firwin',
        verbose=False
    )

    if config['preprocessing']['apply_car']:
        raw.set_eeg_reference('average', projection=False, verbose=False)

    raw.resample(config['preprocessing']['target_sfreq'], npad='auto', verbose=False)
    return raw


def load_and_preprocess_edf(edf_path, config):
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
    raw = preprocess_raw(raw, config)

    try:
        events = mne.find_events(raw, verbose='ERROR')
        event_ids = {f'T{i}': i for i in np.unique(events[:, 2])}
        assert len(events) > 0
    except Exception:
        events, event_ids = mne.events_from_annotations(raw, verbose='ERROR')

    if len(events) == 0:
        return None, None, raw.ch_names

    epochs = mne.Epochs(
        raw,
        events,
        event_id=event_ids,
        tmin=config['data']['tmin'],
        tmax=config['data']['tmax'],
        baseline=tuple(config['data']['baseline']),
        preload=True,
        verbose='ERROR'
    )

    return epochs.get_data(), epochs.events[:, 2], raw.ch_names


def filter_classes(x, y, selected_classes):
    mask = np.isin(y, selected_classes)
    y, x = y[mask], x[mask]
    label_map = {old: new for new, old in enumerate(sorted(selected_classes))}
    y = np.array([label_map[int(label)] for label in y], dtype=np.int64)
    return x, y


def normalize(x):
    mu = x.mean(axis=(0, 2), keepdims=True)
    sd = x.std(axis=(0, 2), keepdims=True) + 1e-8
    return (x - mu) / sd


def load_subject_data(data_dir, subject_id, run_ids, config):
    subject_dir = data_dir / subject_id
    if not subject_dir.exists():
        return None, None, None

    all_x, all_y = [], []
    channel_names = None

    for run_id in run_ids:
        edf_path = subject_dir / f'{subject_id}{run_id}.edf'
        if not edf_path.exists():
            continue

        try:
            x, y, ch_names = load_and_preprocess_edf(edf_path, config)
            if x is None or len(y) == 0:
                continue

            x, y = filter_classes(x, y, config['data']['selected_classes'])
            if len(y) == 0:
                continue

            channel_names = channel_names or ch_names
            all_x.append(x)
            all_y.append(y)
        except Exception as e:
            print(f"  Warning: Failed to load {edf_path.name}: {e}")
            continue

    if len(all_x) == 0:
        return None, None, channel_names

    return np.concatenate(all_x, 0), np.concatenate(all_y, 0), channel_names


def get_available_subjects(data_dir, min_runs=8, excluded=None):
    if not data_dir.exists():
        raise ValueError(f"Data directory not found: {data_dir}")

    excluded = excluded or set()
    subjects = []

    for subject_dir in sorted(data_dir.iterdir()):
        if not subject_dir.is_dir() or not subject_dir.name.startswith('S'):
            continue

        if subject_dir.name in excluded:
            continue

        edf_files = list(subject_dir.glob('*.edf'))
        if len(edf_files) >= min_runs:
            subjects.append(subject_dir.name)

    return subjects


# Scan for available subjects
print("\nScanning for subjects...")
data_dir = CONFIG['data']['raw_data_dir']
print(f"Looking for data in: {data_dir}")

all_subjects = get_available_subjects(
    data_dir,
    min_runs=CONFIG['min_runs_per_subject'],
    excluded=EXCLUDED_SUBJECTS
)
subjects = all_subjects[:CONFIG['max_subjects']]

print(f"Found {len(all_subjects)} clean subjects")
print(f"Will process {len(subjects)} subjects: {subjects}")

# Define which runs to use
MOTOR_IMAGERY_RUNS = ['R07', 'R08', 'R09', 'R10', 'R11', 'R12', 'R13', 'R14']
MOTOR_EXECUTION_RUNS = ['R03', 'R04', 'R05', 'R06']
ALL_TASK_RUNS = MOTOR_IMAGERY_RUNS + MOTOR_EXECUTION_RUNS
print(f"Using runs: {ALL_TASK_RUNS}")


Scanning for subjects...
Looking for data in: /kaggle/input/physioneteegmi
Found 82 clean subjects
Will process 20 subjects: ['S001', 'S002', 'S005', 'S006', 'S007', 'S008', 'S011', 'S014', 'S015', 'S016', 'S020', 'S030', 'S031', 'S032', 'S033', 'S034', 'S035', 'S036', 'S037', 'S038']
Using runs: ['R07', 'R08', 'R09', 'R10', 'R11', 'R12', 'R13', 'R14', 'R03', 'R04', 'R05', 'R06']


## 5. PyTorch Dataset

In [5]:
class EEGDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.FloatTensor(x).unsqueeze(1)
        self.y = torch.LongTensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        return self.x[i], self.y[i]

## 6. SparseGCN-CARM Architecture

### 6.1 Channel Attention Module
Multi-head attention to learn channel importance explicitly

In [6]:
class ChannelAttention(nn.Module):
    def __init__(self, num_channels, num_heads=4):
        super().__init__()
        self.num_channels = num_channels
        self.num_heads = num_heads
        self.head_dim = max(1, num_channels // num_heads)

        # Channel embedding
        self.channel_embed = nn.Parameter(torch.randn(num_channels, self.head_dim * num_heads))
        nn.init.xavier_uniform_(self.channel_embed)

        # Attention weights
        self.query = nn.Linear(self.head_dim * num_heads, self.head_dim * num_heads)
        self.key = nn.Linear(self.head_dim * num_heads, self.head_dim * num_heads)

        # Final projection to importance scores
        self.importance_proj = nn.Sequential(
            nn.Linear(self.head_dim * num_heads, 1),
            nn.Sigmoid()
        )

    def forward(self, x=None):
        embed = self.channel_embed  # (C, D)

        Q = self.query(embed)  # (C, D)
        K = self.key(embed)    # (C, D)

        # Multi-head attention
        Q = Q.view(self.num_channels, self.num_heads, self.head_dim)
        K = K.view(self.num_channels, self.num_heads, self.head_dim)

        # Attention scores
        attn = torch.einsum('chd,khd->chk', Q, K) / np.sqrt(self.head_dim)
        attn = F.softmax(attn, dim=-1)

        # Aggregate across heads
        attn_pooled = attn.mean(dim=1)  # (C, C)
        channel_scores = attn_pooled.sum(dim=1)  # (C,)
        channel_scores = channel_scores / channel_scores.sum()

        # Also use learned projection
        importance = self.importance_proj(embed).squeeze(-1) if embed.dim() > 1 else self.importance_proj(embed)

        # Combine both signals
        final_importance = 0.5 * channel_scores + 0.5 * importance
        final_importance = final_importance / (final_importance.sum() + 1e-8)

        return final_importance, attn



### 6.2 Multi-Scale Temporal Convolution

In [7]:
class MultiScaleTemporalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=[8, 16, 32], pool=True):
        super().__init__()
        self.pool = pool
        k_num = len(kernel_sizes)
        base = out_channels // k_num
        rem = out_channels - base * k_num
        out_list = [base + (1 if i < rem else 0) for i in range(k_num)]

        branches = []
        for i, k in enumerate(kernel_sizes):
            oc = out_list[i]
            branches.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, oc, kernel_size=(1, k), padding=(0, k//2), bias=False),
                    nn.BatchNorm2d(oc),
                    nn.ELU()
                )
            )
        self.branches = nn.ModuleList(branches)
        self.pool_layer = nn.AvgPool2d(kernel_size=(1, 2)) if pool else None

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.branches]
        x = torch.cat(branch_outputs, dim=1)
        return self.pool_layer(x) if self.pool else x


def pairnorm(x, node_dim=2, eps=1e-6):
    m = x.mean(dim=node_dim, keepdim=True)
    xc = x - m
    v = (xc * xc).mean(dim=node_dim, keepdim=True)
    return xc / torch.sqrt(v + eps)


def build_feat_topk_adj(x, k, active_channels=None):
    B, H, C, T = x.shape

    if active_channels is not None:
        mask = active_channels.view(C, 1).to(x.device)
        x_masked = x * mask.view(1, 1, C, 1)
    else:
        x_masked = x

    E = x_masked.permute(2, 1, 0, 3).contiguous().view(C, H, B*T).mean(2)
    En = F.normalize(E, p=2, dim=1)
    S = (En @ En.t()).clamp_min(0.0)
    k = max(1, min(int(k), C))
    vals, idx = torch.topk(S, k, dim=1)
    M = torch.zeros_like(S)
    M.scatter_(1, idx, 1.0)
    A = S * M
    A = torch.softmax(A, 1)
    A = 0.5 * (A + A.t())
    return A



### 6.3 Feature-Adaptive GCN Layer (from CARMv2)

In [8]:
class AdaptiveGCNLayer(nn.Module):
    def __init__(self, C, H, topk_k=8, lambda_feat=0.3, hop_alpha=0.5, edge_dropout=0.1,
                 use_pairnorm=True, use_residual=True):
        super().__init__()
        self.C = C
        self.H = H
        self.k = topk_k
        self.lf = lambda_feat
        self.ha = hop_alpha
        self.ed = edge_dropout
        self.pn = use_pairnorm
        self.res = use_residual

        self.W = nn.Parameter(torch.empty(C, C))
        nn.init.xavier_uniform_(self.W)

        self.th = nn.Linear(H, H, bias=False)
        self.bn = nn.BatchNorm2d(H)
        self.act = nn.ELU()
        self.last = None

    def _learned(self, dev, active_channels=None):
        A = torch.sigmoid(self.W)
        A = 0.5 * (A + A.t())

        if active_channels is not None:
            mask = active_channels.view(self.C, 1).to(dev)
            A = A * mask * mask.t()

        I = torch.eye(self.C, device=dev, dtype=A.dtype)
        At = A + I
        d = torch.pow(At.sum(1).clamp_min(1e-6), -0.5)
        D = torch.diag(d)
        return D @ At @ D

    def forward(self, x, active_channels=None):
        B, H, C, T = x.shape

        Al = self._learned(x.device, active_channels)
        A2 = Al @ Al
        Ah = (1 - self.ha) * Al + self.ha * A2
        Af = build_feat_topk_adj(x, self.k, active_channels)
        A = (1 - self.lf) * Ah + self.lf * Af

        if self.training and self.ed > 0:
            M = (torch.rand_like(A) > self.ed).float()
            A = 0.5 * ((A * M) + (A * M).t())
            A = A + torch.eye(C, device=A.device, dtype=A.dtype)

        d = torch.pow(A.sum(1).clamp_min(1e-6), -0.5)
        D = torch.diag(d)
        A = D @ A @ D

        xb = x.permute(0, 3, 2, 1).contiguous().view(B*T, C, H)
        xg = A @ xb
        xg = self.th(xg)
        xg = xg.view(B, T, C, H).permute(0, 3, 2, 1)

        if self.res:
            if active_channels is not None:
                x_res = x * active_channels.view(1, 1, C, 1)
            else:
                x_res = x
            out = xg + x_res
        else:
            out = xg
        out = pairnorm(out, 2) if self.pn else out
        out = self.bn(out)
        out = self.act(out)

        self.last = {'learned': Al.detach().cpu().numpy()}
        return out



### 6.4 SparseGCN-CARM Main Model

In [9]:
class SparseGCNCARMModel(nn.Module):
    def __init__(self, C, T, K, H, config):
        super().__init__()
        self.C = C
        self.config = config

        # Channel attention
        if config.get('use_channel_attention', True):
            self.channel_attention = ChannelAttention(C, config.get('attention_heads', 4))
        else:
            self.channel_attention = None

        # Multi-scale temporal + adaptive GCN layers
        scales = config.get('temporal_scales', [8, 16, 32])
        self.t1 = MultiScaleTemporalConv(1, H, scales, False)
        self.g1 = AdaptiveGCNLayer(C, H, config.get('topk_k', 8), config.get('lambda_feat', 0.3),
                                   config.get('hop_alpha', 0.5), config.get('edge_dropout', 0.1),
                                   config.get('use_pairnorm', True), config.get('use_residual', True))
        self.t2 = MultiScaleTemporalConv(H, H, scales, True)
        self.g2 = AdaptiveGCNLayer(C, H, config.get('topk_k', 8), config.get('lambda_feat', 0.3),
                                   config.get('hop_alpha', 0.5), config.get('edge_dropout', 0.1),
                                   config.get('use_pairnorm', True), config.get('use_residual', True))
        self.t3 = MultiScaleTemporalConv(H, H, scales, True)
        self.g3 = AdaptiveGCNLayer(C, H, config.get('topk_k', 8), config.get('lambda_feat', 0.3),
                                   config.get('hop_alpha', 0.5), config.get('edge_dropout', 0.1),
                                   config.get('use_pairnorm', True), config.get('use_residual', True))

        with torch.no_grad():
            ft = self._forward_features(torch.zeros(1, 1, C, T), None)
            fs = ft.view(1, -1).size(1)

        self.fc1 = nn.Linear(fs, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, K)

        # Active channels mask
        self.register_buffer('active_channels', torch.ones(C))
        self.channel_importance = None

    def _forward_features(self, x, active_channels):
        x = self.g1(self.t1(x), active_channels)
        x = self.g2(self.t2(x), active_channels)
        x = self.g3(self.t3(x), active_channels)
        return x

    def forward(self, x):
        # Get channel importance
        if self.channel_attention is not None:
            importance, attn = self.channel_attention(x)
            self.channel_importance = importance

            # Apply soft channel gating during training
            if self.training:
                x = x * importance.view(1, 1, self.C, 1)

        # Forward through network
        x = self._forward_features(x, self.active_channels)

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def prune_channels(self, prune_ratio, min_channels=20):
        if self.channel_importance is None:
            return 0

        num_active = int(self.active_channels.sum().item())
        num_to_prune = max(1, int(num_active * prune_ratio))

        if num_active - num_to_prune < min_channels:
            return 0

        importance = self.channel_importance.detach().cpu()
        active_importance = importance * self.active_channels.cpu()

        _, sorted_indices = torch.sort(active_importance)

        pruned = 0
        for idx in sorted_indices:
            if self.active_channels[idx] > 0:
                self.active_channels[idx] = 0
                pruned += 1
                if pruned >= num_to_prune:
                    break

        return pruned

    def get_channel_importance(self):
        if self.channel_importance is not None:
            return self.channel_importance.detach().cpu().numpy()
        return None

    def get_active_channels_mask(self):
        return self.active_channels.cpu().numpy()

    def get_final_adjacency(self):
        return self.g3.last.get('learned', None) if self.g3.last else None


# Training functions

## 7. Training with Progressive Pruning

In [10]:
def train_epoch_sparse(model, dataloader, criterion, optimizer, device, config):
    model.train()
    total_loss = 0.0
    all_preds, all_labels = [], []

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits = model(x)
        loss = criterion(logits, y)

        # Add channel importance regularization
        if model.channel_importance is not None:
            importance_loss = config.get('channel_importance_loss', 1e-3) * model.channel_importance.abs().mean()
            loss = loss + importance_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        all_preds += torch.argmax(logits, 1).cpu().tolist()
        all_labels += y.cpu().tolist()

    return total_loss / max(1, len(dataloader)), accuracy_score(all_labels, all_preds)


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        total_loss += loss.item()
        all_preds += torch.argmax(logits, 1).cpu().tolist()
        all_labels += y.cpu().tolist()

    return total_loss / max(1, len(dataloader)), accuracy_score(all_labels, all_preds)


def train_with_pruning(model, train_loader, val_loader, device, epochs, lr, patience, config):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=False
    )

    best_acc = 0.0
    best_state = None
    no_improve = 0

    prune_enabled = config.get('prune_enabled', True)
    prune_start = config.get('prune_start_epoch', 10)
    prune_every = config.get('prune_every', 2)
    prune_ratio = config.get('prune_ratio', 0.05)
    min_channels = config.get('min_channels', 20)

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch_sparse(model, train_loader, criterion, optimizer, device, config)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        scheduler.step(val_loss)

        # Progressive pruning
        if prune_enabled and epoch >= prune_start and (epoch - prune_start) % prune_every == 0:
            num_pruned = model.prune_channels(prune_ratio, min_channels)
            if num_pruned > 0:
                active = int(model.active_channels.sum().item())
                print(f"    Epoch {epoch+1}: Pruned {num_pruned} channels, {active} remaining")

        if val_acc > best_acc:
            best_acc = val_acc
            best_state = deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= patience:
            break

    if best_state is None:
        best_state = deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    return best_state, best_acc



## 8. Improved Channel Selector

In [11]:
class ImprovedChannelSelector:
    def __init__(self, adjacency, importance_scores, channel_names):
        self.A = adjacency
        self.importance = importance_scores
        self.names = np.array(channel_names)
        self.C = adjacency.shape[0]

    def importance_selection(self, k):
        k = min(int(k), self.C)
        indices = np.argsort(self.importance)[-k:]
        indices = np.sort(indices)
        return self.names[indices].tolist(), indices

    def hybrid_selection(self, k):
        connectivity = np.sum(np.abs(self.A), 1)
        combined_score = 0.7 * self.importance + 0.3 * (connectivity / connectivity.max())

        k = min(int(k), self.C)
        indices = np.argsort(combined_score)[-k:]
        indices = np.sort(indices)
        return self.names[indices].tolist(), indices

    def edge_selection(self, k):
        edges = []
        for i in range(self.C):
            for j in range(i+1, self.C):
                edges.append((i, j, abs(self.A[i, j]) + abs(self.A[j, i])))
        edges.sort(key=lambda t: t[2], reverse=True)
        top_edges = edges[:int(k)]
        indices = sorted(set([i for i, _, _ in top_edges] + [j for _, j, _ in top_edges]))
        return self.names[indices].tolist(), np.array(indices)



## 9. Main Training Loop

In [12]:
all_results = []

for subject_id in tqdm(subjects, desc='Training SparseGCN'):
    print(f"\nProcessing {subject_id}...")

    X, Y, channel_names = load_subject_data(data_dir, subject_id, ALL_TASK_RUNS, CONFIG)
    if X is None:
        print(f"  Skipped (no data)")
        continue

    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))
    print(f"  Data shape: {X.shape}, Classes: {K}")

    skf = StratifiedKFold(n_splits=CONFIG['model']['n_folds'], shuffle=True, random_state=42)
    fold_results = []
    adjacencies = []
    importance_scores = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, Y)):
        X_train, X_val = normalize(X[train_idx]), normalize(X[val_idx])
        Y_train, Y_val = Y[train_idx], Y[val_idx]

        train_loader = DataLoader(
            EEGDataset(X_train, Y_train),
            batch_size=CONFIG['model']['batch_size'],
            shuffle=True,
            num_workers=0
        )
        val_loader = DataLoader(
            EEGDataset(X_val, Y_val),
            batch_size=CONFIG['model']['batch_size'],
            shuffle=False,
            num_workers=0
        )

        model = SparseGCNCARMModel(C, T, K, CONFIG['model']['hidden_dim'], CONFIG['sparsegcn']).to(device)

        best_state, best_acc = train_with_pruning(
            model, train_loader, val_loader, device,
            CONFIG['model']['epochs'],
            CONFIG['model']['learning_rate'],
            CONFIG['model']['patience'],
            CONFIG['sparsegcn']
        )

        model.load_state_dict(best_state)
        _, accuracy = evaluate(model, val_loader, nn.CrossEntropyLoss(), device)

        adjacency = model.get_final_adjacency()
        importance = model.get_channel_importance()
        active_mask = model.get_active_channels_mask()

        adjacencies.append(adjacency)
        importance_scores.append(importance)

        fold_results.append({
            'fold': fold,
            'val_acc': accuracy,
            'num_active': int(active_mask.sum())
        })

        print(f"  Fold {fold+1}: {accuracy:.4f} ({int(active_mask.sum())}/{C} channels)")

    avg_acc = np.mean([f['val_acc'] for f in fold_results])
    avg_active = np.mean([f['num_active'] for f in fold_results])

    all_results.append({
        'subject': subject_id,
        'accuracy': avg_acc,
        'avg_active_channels': avg_active,
        'adjacency_matrix': np.mean(adjacencies, 0),
        'channel_importance': np.mean(importance_scores, 0),
        'channel_names': channel_names
    })

    print(f"  Final: {avg_acc:.4f} ({avg_active:.0f} active channels)")

print("\nTraining complete!")

Training SparseGCN:   0%|          | 0/20 [00:00<?, ?it/s]


Processing S001...
  Data shape: (252, 64, 769), Classes: 2
    Epoch 11: Pruned 3 channels, 61 remaining
    Epoch 13: Pruned 3 channels, 58 remaining
  Fold 1: 0.9286 (64/64 channels)
    Epoch 11: Pruned 3 channels, 61 remaining
    Epoch 13: Pruned 3 channels, 58 remaining
  Fold 2: 0.9167 (64/64 channels)
    Epoch 11: Pruned 3 channels, 61 remaining
    Epoch 13: Pruned 3 channels, 58 remaining
    Epoch 15: Pruned 2 channels, 56 remaining
    Epoch 17: Pruned 2 channels, 54 remaining
    Epoch 19: Pruned 2 channels, 52 remaining
  Fold 3: 0.9167 (64/64 channels)
  Final: 0.9206 (64 active channels)

Processing S002...
  Data shape: (252, 64, 769), Classes: 2
    Epoch 11: Pruned 3 channels, 61 remaining
    Epoch 13: Pruned 3 channels, 58 remaining
    Epoch 15: Pruned 2 channels, 56 remaining
    Epoch 17: Pruned 2 channels, 54 remaining
  Fold 1: 0.7976 (64/64 channels)
    Epoch 11: Pruned 3 channels, 61 remaining
    Epoch 13: Pruned 3 channels, 58 remaining
    Epoch 15: P

## 10. Results and Comparison

In [13]:
results_df = pd.DataFrame(all_results)

print("SparseGCN-CARM Results:")
print(f"Mean accuracy: {results_df['accuracy'].mean():.4f}")
print(f"Std accuracy: {results_df['accuracy'].std():.4f}")
print(f"Avg active channels: {results_df['avg_active_channels'].mean():.1f}")

# Display results
print("\nPer-subject results:")
print(results_df[['subject', 'accuracy', 'avg_active_channels']])

SparseGCN-CARM Results:
Mean accuracy: 0.8792
Std accuracy: 0.0567
Avg active channels: 62.8

Per-subject results:
   subject  accuracy  avg_active_channels
0     S001  0.920635            64.000000
1     S002  0.813492            64.000000
2     S005  0.857143            64.000000
3     S006  0.873016            60.333333
4     S007  0.940476            62.000000
5     S008  0.944444            64.000000
6     S011  0.857143            64.000000
7     S014  0.821429            59.333333
8     S015  0.876984            63.000000
9     S016  0.761905            63.000000
10    S020  0.968254            64.000000
11    S030  0.884921            64.000000
12    S031  0.884921            64.000000
13    S032  0.956349            62.000000
14    S033  0.916667            62.000000
15    S034  0.873016            64.000000
16    S035  0.940476            58.666667
17    S036  0.873016            63.000000
18    S037  0.809524            63.000000
19    S038  0.809524            63.000000


## 11. Channel Selection Experiments

In [None]:
k_values = CONFIG['channel_selection']['k_values']
selection_results = []

for result in all_results:
    subject_id = result['subject']
    full_acc = result['accuracy']
    adjacency = result['adjacency_matrix']
    importance = result['channel_importance']
    channel_names = result['channel_names']

    # Load subject data again
    X, Y, _ = load_subject_data(data_dir, subject_id, ALL_TASK_RUNS, CONFIG)
    if X is None:
        continue

    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))

    selector = ImprovedChannelSelector(adjacency, importance, channel_names)

    for k in k_values:
        # Try different selection methods
        for method_name, method_func in [
            ('IMP', selector.importance_selection),
            ('HYB', selector.hybrid_selection),
            ('ES', selector.edge_selection)
        ]:
            selected_names, selected_indices = method_func(k)

            # Retrain with selected channels
            X_subset = X[:, selected_indices, :]

            skf = StratifiedKFold(n_splits=CONFIG['model']['n_folds'], shuffle=True, random_state=42)
            fold_accs = []

            for fold, (train_idx, val_idx) in enumerate(skf.split(X_subset, Y)):
                X_train = normalize(X_subset[train_idx])
                X_val = normalize(X_subset[val_idx])
                Y_train, Y_val = Y[train_idx], Y[val_idx]

                train_loader = DataLoader(
                    EEGDataset(X_train, Y_train),
                    batch_size=CONFIG['model']['batch_size'],
                    shuffle=True,
                    num_workers=0
                )
                val_loader = DataLoader(
                    EEGDataset(X_val, Y_val),
                    batch_size=CONFIG['model']['batch_size'],
                    shuffle=False,
                    num_workers=0
                )

                # Use smaller model for selected channels
                model = SparseGCNCARMModel(
                    len(selected_indices), T, K,
                    CONFIG['model']['hidden_dim'],
                    {**CONFIG['sparsegcn'], 'prune_enabled': False}
                ).to(device)

                best_state, best_acc = train_with_pruning(
                    model, train_loader, val_loader, device,
                    CONFIG['model']['epochs'],
                    CONFIG['model']['learning_rate'],
                    CONFIG['model']['patience'],
                    {**CONFIG['sparsegcn'], 'prune_enabled': False}
                )

                model.load_state_dict(best_state)
                _, accuracy = evaluate(model, val_loader, nn.CrossEntropyLoss(), device)
                fold_accs.append(accuracy)

            subset_acc = np.mean(fold_accs)
            drop = full_acc - subset_acc

            selection_results.append({
                'subject': subject_id,
                'method': method_name,
                'k': k,
                'full_acc': full_acc,
                'subset_acc': subset_acc,
                'drop': drop,
                'channels': ','.join(selected_names)
            })

            print(f"{subject_id} {method_name} k={k}: {subset_acc:.4f} (drop: {drop:.4f})")

selection_df = pd.DataFrame(selection_results)
print("\nChannel selection complete!")

S001 IMP k=10: 0.8770 (drop: 0.0437)
S001 HYB k=10: 0.9246 (drop: -0.0040)
S001 ES k=10: 0.9048 (drop: 0.0159)
S001 IMP k=15: 0.9048 (drop: 0.0159)
S001 HYB k=15: 0.9206 (drop: 0.0000)
S001 ES k=15: 0.9127 (drop: 0.0079)
S001 IMP k=20: 0.9206 (drop: 0.0000)
S001 HYB k=20: 0.9008 (drop: 0.0198)
S001 ES k=20: 0.9167 (drop: 0.0040)
S001 IMP k=25: 0.9048 (drop: 0.0159)
S001 HYB k=25: 0.9008 (drop: 0.0198)
S001 ES k=25: 0.9048 (drop: 0.0159)
S002 IMP k=10: 0.7579 (drop: 0.0556)
S002 HYB k=10: 0.7183 (drop: 0.0952)
S002 ES k=10: 0.7183 (drop: 0.0952)
S002 IMP k=15: 0.7619 (drop: 0.0516)
S002 HYB k=15: 0.8016 (drop: 0.0119)
S002 ES k=15: 0.7857 (drop: 0.0278)
S002 IMP k=20: 0.7778 (drop: 0.0357)
S002 HYB k=20: 0.7897 (drop: 0.0238)
S002 ES k=20: 0.8095 (drop: 0.0040)
S002 IMP k=25: 0.7857 (drop: 0.0278)
S002 HYB k=25: 0.8056 (drop: 0.0079)
S002 ES k=25: 0.8214 (drop: -0.0079)
S005 IMP k=10: 0.8294 (drop: 0.0278)
S005 HYB k=10: 0.8214 (drop: 0.0357)
S005 ES k=10: 0.7698 (drop: 0.0873)
S005 IMP

## 12. Summary and Visualization

In [1]:
# Summarize by method and k
summary = selection_df.groupby(['method', 'k']).agg({
    'drop': ['mean', 'std']
}).reset_index()

print("\nChannel Selection Summary:")
print(summary)

# Plot comparison
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

for method in ['IMP', 'HYB', 'ES']:
    method_data = selection_df[selection_df['method'] == method]
    grouped = method_data.groupby('k')['drop'].mean()
    ax.plot(grouped.index, grouped.values, marker='o', label=method)

ax.set_xlabel('Number of Selected Channels (k)')
ax.set_ylabel('Accuracy Drop')
ax.set_title('SparseGCN-CARM: Channel Selection Performance')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nBest selection method at k=20: {selection_df[selection_df['k']==20].groupby('method')['drop'].mean().idxmin()}")
print(f"Mean drop at k=20: {selection_df[selection_df['k']==20]['drop'].mean():.4f}")

NameError: name 'selection_df' is not defined

## 13. Save Results

In [None]:
results_df.to_csv('results/sparsegcn_results.csv', index=False)
selection_df.to_csv('results/sparsegcn_selection.csv', index=False)
print("Results saved!")
print(f"\nFinal SparseGCN-CARM Performance:")
print(f"- Mean accuracy: {results_df['accuracy'].mean():.4f}")
print(f"- Mean active channels: {results_df['avg_active_channels'].mean():.1f}")
print(f"- Best selection drop @ k=20: {selection_df[selection_df['k']==20]['drop'].min():.4f}")