# PhysioNet Motor Imagery - Legacy Baseline Methods

## Comprehensive Evaluation of 5 Baseline Methods + Channel Selection

This notebook trains and evaluates:
1. **FBCSP** - Filter Bank Common Spatial Patterns with LDA
2. **CNN-SAE** - CNN with Spatial Attention
3. **EEGNet** - Compact convolutional network
4. **ACS-SE-CNN** - Attention + Squeeze-Excitation CNN
5. **G-CARM** - Graph-based CNN with learnable channel-adjacency (CARM) blocks

## Channel Selection Methods:
- **FBCSP**: CSP pattern-based selection
- **G-CARM**: Edge Selection (ES) / Aggregation Selection (AS)
- **CNN-SAE, EEGNet, ACS-SE-CNN**: Gradient-based attribution

## Configuration:
- **30 epochs**, **0.002 LR**, **no early stopping** (`patience=999`); the checkpoint with the best validation accuracy is restored after training (PyTorch models)
- **10 subjects**, **3-fold CV**
- **9 filter banks**, **4 CSP components** (for FBCSP)
- **Channel Selection**: k=[10,15,20,25,30]
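
As a rough illustration of how the `k` values above might be applied, the sketch below ranks channels by an importance score and keeps the top `k`. The `select_top_k` helper, the channel names, and the random scores are all hypothetical; in this notebook the scores would come from CSP patterns, the G-CARM adjacency, or gradient attributions.

```python
import numpy as np


def select_top_k(importance, channel_names, k):
    """Return indices and names of the k highest-scoring channels (hypothetical helper)."""
    importance = np.asarray(importance, dtype=float)
    if k > importance.size:
        raise ValueError("k exceeds the number of channels")
    # Sort by descending score; a stable sort keeps ties reproducible
    order = np.argsort(-importance, kind="stable")[:k]
    # Re-sort the selected indices so the original channel order is preserved
    idx = np.sort(order)
    return idx, [channel_names[i] for i in idx]


rng = np.random.default_rng(0)
names = [f"ch{i:02d}" for i in range(64)]  # placeholder channel names
scores = rng.random(64)                    # placeholder importance scores

for k in [10, 15, 20, 25, 30]:
    idx, picked = select_top_k(scores, names, k)
    # X[:, idx, :] would then give a (trials, k, time) array for retraining
    print(k, len(picked))
```

Keeping the selected indices in ascending order means the reduced array preserves the montage order of the full recording.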

## Metrics:
- Accuracy, Precision, Recall, F1-Score, AUC-ROC, Specificity

## Output:
- `legacy_*_results.csv` (full channel results)
- `legacy_*_retrain_results.csv` (channel selection results)

## 1. Setup and Imports

In [4]:
import json
import random
import warnings
from pathlib import Path
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
from scipy.signal import butter, filtfilt
import gc

import mne
from mne.decoding import CSP

warnings.filterwarnings('ignore')
sns.set_context('notebook', font_scale=1.0)
mne.set_log_level('WARNING')

def set_seed(s=42):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


## 2. Configuration

In [5]:
import os
from pathlib import Path

if os.path.exists('/kaggle/input'):
    print("Running on Kaggle")
    kaggle_input = Path('/kaggle/input')
    datasets = [d for d in kaggle_input.iterdir() if d.is_dir()]
    print(f"Available datasets: {[d.name for d in datasets]}")

    DATA_DIR = None
    possible_names = ['physioneteegmi', 'eeg-motor-movementimagery-dataset']
    for ds_name in possible_names:
        test_path = kaggle_input / ds_name
        if test_path.exists():
            DATA_DIR = test_path
            print(f"Found dataset: {DATA_DIR}")
            break

    if DATA_DIR is None and datasets:
        DATA_DIR = datasets[0]
        print(f"Using first available dataset: {DATA_DIR}")
else:
    print("Running locally")
    DATA_DIR = Path('data/physionet/files')

CONFIG = {
    'data': {
        'raw_data_dir': DATA_DIR,
        'selected_classes': [1, 2],
        'tmin': -1.0,
        'tmax': 5.0,
        'baseline': (-0.5, 0)
    },
    'preprocessing': {
        'l_freq': 0.5,
        'h_freq': 40.0,
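        'notch_freq': 50.0,  # Hz; note: this dataset was recorded in a 60 Hz mains region, and the 40 Hz low-pass also attenuates line noise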
        'target_sfreq': 128.0,
        'apply_car': True
    },
    'model': {
        'epochs': 30,
        'learning_rate': 0.002,
        'batch_size': 64,
        'n_folds': 3,
        'patience': 999
    },
    'fbcsp': {
        'freq_bands': [
            (4, 8), (8, 12), (12, 16), (16, 20), (20, 24),
            (24, 28), (28, 32), (32, 36), (36, 40)
        ],
        'n_components': 4
    },
    'channel_selection': {
        'k_values': [10, 15, 20, 25, 30]
    },
    'output': {
        'results_dir': Path('results'),
    },
    'max_subjects': 10,
    'min_runs_per_subject': 8
}

CONFIG['output']['results_dir'].mkdir(exist_ok=True, parents=True)

print(f"\nConfiguration loaded!")
print(f"Training: {CONFIG['max_subjects']} subjects, {CONFIG['model']['n_folds']}-fold CV, {CONFIG['model']['epochs']} epochs")
print(f"Learning rate: {CONFIG['model']['learning_rate']}, No early stopping (patience={CONFIG['model']['patience']})")
print(f"FBCSP: {len(CONFIG['fbcsp']['freq_bands'])} filter banks, {CONFIG['fbcsp']['n_components']} components")
print(f"Channel selection k values: {CONFIG['channel_selection']['k_values']}")

Running on Kaggle
Available datasets: ['physioneteegmi']
Found dataset: /kaggle/input/physioneteegmi

Configuration loaded!
Training: 10 subjects, 3-fold CV, 30 epochs
Learning rate: 0.002, No early stopping (patience=999)
FBCSP: 9 filter banks, 4 components
Channel selection k values: [10, 15, 20, 25, 30]


## 3. Data Cleaning - Remove Faulty Subjects

In [6]:
KNOWN_BAD_SUBJECTS = [
    'S088', 'S089', 'S092', 'S100', 'S104', 'S106', 'S107', 'S108', 'S109'
]

HIGH_ISSUE_SUBJECTS = [
    'S003', 'S004', 'S009', 'S010', 'S012', 'S013', 'S017', 'S018', 'S019',
    'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029'
]

EXCLUDED_SUBJECTS = set(KNOWN_BAD_SUBJECTS + HIGH_ISSUE_SUBJECTS)

print(f"Total excluded subjects: {len(EXCLUDED_SUBJECTS)}")
print(f"Excluded subjects: {sorted(EXCLUDED_SUBJECTS)}")

Total excluded subjects: 27
Excluded subjects: ['S003', 'S004', 'S009', 'S010', 'S012', 'S013', 'S017', 'S018', 'S019', 'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029', 'S088', 'S089', 'S092', 'S100', 'S104', 'S106', 'S107', 'S108', 'S109']


## 4. Data Loading and Preprocessing Functions

In [7]:
def preprocess_raw(raw, config):
    cleaned_names = {name: name.rstrip('.') for name in raw.ch_names}
    raw.rename_channels(cleaned_names)
    raw.pick_types(eeg=True)
    raw.set_montage('standard_1020', on_missing='ignore', match_case=False)
    
    nyquist = raw.info['sfreq'] / 2.0
    if config['preprocessing']['notch_freq'] < nyquist:
        raw.notch_filter(freqs=config['preprocessing']['notch_freq'], verbose=False)
    
    raw.filter(
        l_freq=config['preprocessing']['l_freq'],
        h_freq=config['preprocessing']['h_freq'],
        method='fir',
        fir_design='firwin',
        verbose=False
    )
    
    if config['preprocessing']['apply_car']:
        raw.set_eeg_reference('average', projection=False, verbose=False)
    
    raw.resample(config['preprocessing']['target_sfreq'], npad='auto', verbose=False)
    return raw


def load_and_preprocess_edf(edf_path, config):
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
    raw = preprocess_raw(raw, config)
    
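    # events_from_annotations numbers the annotation labels in sorted order
    # (T0 -> 1, T1 -> 2, T2 -> 3 for these files), so selected_classes [1, 2]
    # keeps rest (T0) versus T1.
    events, event_ids = mne.events_from_annotations(raw, verbose='ERROR')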
    
    if len(events) == 0:
        return None, None, raw.ch_names
    
    epochs = mne.Epochs(
        raw,
        events,
        event_id=event_ids,
        tmin=config['data']['tmin'],
        tmax=config['data']['tmax'],
        baseline=tuple(config['data']['baseline']),
        preload=True,
        verbose='ERROR'
    )
    
    return epochs.get_data(), epochs.events[:, 2], raw.ch_names


def filter_classes(x, y, selected_classes):
    mask = np.isin(y, selected_classes)
    y, x = y[mask], x[mask]
    label_map = {old: new for new, old in enumerate(sorted(selected_classes))}
    y = np.array([label_map[int(label)] for label in y], dtype=np.int64)
    return x, y


def normalize(x):
    mu = x.mean(axis=(0, 2), keepdims=True)
    sd = x.std(axis=(0, 2), keepdims=True) + 1e-8
    return (x - mu) / sd


def load_subject_data(data_dir, subject_id, run_ids, config):
    subject_dir = data_dir / subject_id
    if not subject_dir.exists():
        return None, None, None
    
    all_x, all_y = [], []
    channel_names = None
    
    for run_id in run_ids:
        edf_path = subject_dir / f'{subject_id}{run_id}.edf'
        if not edf_path.exists():
            continue
        
        try:
            x, y, ch_names = load_and_preprocess_edf(edf_path, config)
            if x is None or len(y) == 0:
                continue
            
            x, y = filter_classes(x, y, config['data']['selected_classes'])
            if len(y) == 0:
                continue
            
            channel_names = channel_names or ch_names
            all_x.append(x)
            all_y.append(y)
        except Exception as e:
            print(f"  Warning: Failed to load {edf_path.name}: {e}")
            continue
    
    if len(all_x) == 0:
        return None, None, channel_names
    
    return np.concatenate(all_x, 0), np.concatenate(all_y, 0), channel_names


def get_available_subjects(data_dir, min_runs=8, excluded=None):
    if not data_dir.exists():
        raise ValueError(f"Data directory not found: {data_dir}")
    
    excluded = excluded or set()
    subjects = []
    
    for subject_dir in sorted(data_dir.iterdir()):
        if not subject_dir.is_dir() or not subject_dir.name.startswith('S'):
            continue
        
        if subject_dir.name in excluded:
            continue
        
        edf_files = list(subject_dir.glob('*.edf'))
        if len(edf_files) >= min_runs:
            subjects.append(subject_dir.name)
    
    return subjects


print("\nScanning for subjects...")
data_dir = CONFIG['data']['raw_data_dir']
print(f"Looking for data in: {data_dir}")

all_subjects = get_available_subjects(
    data_dir, 
    min_runs=CONFIG['min_runs_per_subject'],
    excluded=EXCLUDED_SUBJECTS
)
subjects = all_subjects[:CONFIG['max_subjects']]

print(f"Found {len(all_subjects)} clean subjects with >= {CONFIG['min_runs_per_subject']} runs")
print(f"Will process {len(subjects)} subjects: {subjects}")

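# Note: in the PhysioNet EEGMMIDB protocol, runs 4, 6, 8, 10, 12, 14 are imagery
# and runs 3, 5, 7, 9, 11, 13 are execution, so the two list names below do not
# match the protocol. Only their union (all 12 task runs) is used downstream.
MOTOR_IMAGERY_RUNS = ['R07', 'R08', 'R09', 'R10', 'R11', 'R12', 'R13', 'R14']
MOTOR_EXECUTION_RUNS = ['R03', 'R04', 'R05', 'R06']
ALL_TASK_RUNS = MOTOR_IMAGERY_RUNS + MOTOR_EXECUTION_RUNS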
print(f"Using runs: {ALL_TASK_RUNS}")


Scanning for subjects...
Looking for data in: /kaggle/input/physioneteegmi
Found 82 clean subjects with >= 8 runs
Will process 10 subjects: ['S001', 'S002', 'S005', 'S006', 'S007', 'S008', 'S011', 'S014', 'S015', 'S016']
Using runs: ['R07', 'R08', 'R09', 'R10', 'R11', 'R12', 'R13', 'R14', 'R03', 'R04', 'R05', 'R06']
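
The `normalize` helper above standardizes each channel with the statistics of whatever array it receives, so a training split and a validation split end up scaled by different means and standard deviations. One possible alternative, sketched here under the assumption that sharing training statistics is desired, splits the step into a "fit" and an "apply" function. Both function names and the random data are hypothetical.

```python
import numpy as np


def fit_channel_stats(x_train, eps=1e-8):
    """Per-channel mean/std over trials and time, computed on training data only."""
    mu = x_train.mean(axis=(0, 2), keepdims=True)
    sd = x_train.std(axis=(0, 2), keepdims=True) + eps
    return mu, sd


def apply_channel_stats(x, mu, sd):
    """Standardize any split with previously fitted statistics."""
    return (x - mu) / sd


rng = np.random.default_rng(0)
x_train = rng.normal(loc=3.0, scale=2.0, size=(20, 4, 50))  # (trials, channels, time)
x_val = rng.normal(loc=3.0, scale=2.0, size=(10, 4, 50))

mu, sd = fit_channel_stats(x_train)
x_train_n = apply_channel_stats(x_train, mu, sd)
x_val_n = apply_channel_stats(x_val, mu, sd)  # validation reuses training statistics
print(mu.shape, x_val_n.shape)
```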


## 5. PyTorch Dataset

In [8]:
class EEGDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.FloatTensor(x).unsqueeze(1)
        self.y = torch.LongTensor(y)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, i):
        return self.x[i], self.y[i]

## 6. Comprehensive Metrics Functions

In [9]:
@torch.no_grad()
def calculate_comprehensive_metrics(model, dataloader, device):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []

    for X_batch, y_batch in dataloader:
        X_batch = X_batch.to(device)
        outputs = model(X_batch)
        probs = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(y_batch.numpy())
        all_probs.extend(probs[:, 1].cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='binary', zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='binary', zero_division=0),
        'f1_score': f1_score(all_labels, all_preds, average='binary', zero_division=0),
        'auc_roc': roc_auc_score(all_labels, all_probs) if len(np.unique(all_labels)) > 1 else 0.0,
    }

    cm = confusion_matrix(all_labels, all_preds)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    else:
        metrics['specificity'] = 0.0
        metrics['sensitivity'] = metrics['recall']

    return metrics


def calculate_sklearn_metrics(y_true, y_pred):
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='binary', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='binary', zero_division=0),
        'f1_score': f1_score(y_true, y_pred, average='binary', zero_division=0),
        'auc_roc': 0.0,  # placeholder: only hard labels are passed in, so AUC-ROC is not computed here
    }

    cm = confusion_matrix(y_true, y_pred)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    else:
        metrics['specificity'] = 0.0
        metrics['sensitivity'] = metrics['recall']

    return metrics


print("Comprehensive metrics functions defined!")

Comprehensive metrics functions defined!
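
To make the confusion-matrix metrics above concrete, here is a small worked example in plain Python. The counts are made up for illustration: a fold with 56 negatives and 28 positives (the same 2:1 ratio as the label distribution printed later) and a classifier that predicts the negative class for every trial. `binary_metrics` is a hypothetical helper, not part of the notebook.

```python
def binary_metrics(tn, fp, fn, tp):
    """Compute the binary metrics used in this notebook from raw counts."""
    total = tn + fp + fn + tp
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0  # same as sensitivity
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {
        'accuracy': (tp + tn) / total,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'specificity': specificity,
    }


# Illustrative counts only: every trial is predicted as the negative class
m = binary_metrics(tn=56, fp=0, fn=28, tp=0)
for name, value in m.items():
    print(f"{name}: {value:.4f}")
```

With imbalanced labels, accuracy stays at two thirds while F1 drops to zero, which is why the two are reported side by side.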


## 7. Model Architectures

### 7.1 FBCSP

In [10]:
class FBCSP:
    def __init__(self, freq_bands, n_components=4, sfreq=128.0):
        self.freq_bands = freq_bands
        self.n_components = n_components
        self.sfreq = sfreq
        self.csps = []
        self.lda = LinearDiscriminantAnalysis()
    
    def _bandpass_filter(self, data, low_freq, high_freq):
        nyquist = self.sfreq / 2.0
        low = low_freq / nyquist
        high = high_freq / nyquist
        b, a = butter(5, [low, high], btype='band')
        return filtfilt(b, a, data, axis=-1)
    
    def fit(self, X, y):
        self.csps = []
        all_features = []
        
        for low_freq, high_freq in self.freq_bands:
            X_filtered = self._bandpass_filter(X.copy(), low_freq, high_freq)
            
            csp = CSP(n_components=self.n_components, reg='ledoit_wolf', log=True, norm_trace=False)
            csp.fit(X_filtered, y)
            self.csps.append(csp)
            
            features = csp.transform(X_filtered)
            all_features.append(features)
        
        all_features = np.concatenate(all_features, axis=1)
        self.lda.fit(all_features, y)
        return self
    
    def predict(self, X):
        all_features = []
        
        for idx, (low_freq, high_freq) in enumerate(self.freq_bands):
            X_filtered = self._bandpass_filter(X.copy(), low_freq, high_freq)
            features = self.csps[idx].transform(X_filtered)
            all_features.append(features)
        
        all_features = np.concatenate(all_features, axis=1)
        return self.lda.predict(all_features)

print("FBCSP defined!")

FBCSP defined!
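
The filter-bank arithmetic behind the class above can be checked without any signal-processing library. This sketch uses the band list, sampling rate, and component count from `CONFIG`; the `normalized_band` helper is hypothetical. It confirms that every band edge lies strictly between 0 and the Nyquist frequency (as a Butterworth band-pass design requires) and shows the size of the concatenated feature vector passed to LDA.

```python
freq_bands = [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24),
              (24, 28), (28, 32), (32, 36), (36, 40)]
sfreq = 128.0
n_components = 4


def normalized_band(low_hz, high_hz, sfreq):
    """Convert band edges in Hz to fractions of the Nyquist frequency."""
    nyquist = sfreq / 2.0
    low, high = low_hz / nyquist, high_hz / nyquist
    if not 0.0 < low < high < 1.0:
        raise ValueError(f"Invalid band ({low_hz}, {high_hz}) Hz for sfreq={sfreq}")
    return low, high


bands = [normalized_band(lo, hi, sfreq) for lo, hi in freq_bands]

# One CSP per band, n_components log-variance features each
n_features = len(freq_bands) * n_components
print(bands[0], bands[-1], n_features)
```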


In [11]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    all_preds, all_labels = [], []
    
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        all_preds += torch.argmax(logits, 1).cpu().tolist()
        all_labels += y.cpu().tolist()
    
    return total_loss / max(1, len(dataloader)), accuracy_score(all_labels, all_preds)


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []
    
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        
        total_loss += loss.item()
        all_preds += torch.argmax(logits, 1).cpu().tolist()
        all_labels += y.cpu().tolist()
    
    return total_loss / max(1, len(dataloader)), accuracy_score(all_labels, all_preds)


def train_model(model, train_loader, val_loader, device, epochs, lr, patience, verbose=True):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # `verbose` is omitted: it is deprecated in recent PyTorch releases and defaulted to False
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    
    best_acc = 0.0
    best_state = None
    no_improve = 0
    
    epoch_iterator = tqdm(range(epochs), desc='    Epochs', leave=False) if verbose else range(epochs)
    
    for epoch in epoch_iterator:
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        
        scheduler.step(val_loss)
        
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1
        
        # Update the progress bar after best_acc so that 'best' includes this epoch
        if verbose:
            epoch_iterator.set_postfix({
                'train_loss': f'{train_loss:.4f}',
                'train_acc': f'{train_acc:.4f}',
                'val_loss': f'{val_loss:.4f}',
                'val_acc': f'{val_acc:.4f}',
                'best': f'{best_acc:.4f}'
            })
        
        if no_improve >= patience:
            if verbose:
                print(f'      Early stopping at epoch {epoch+1}/{epochs}')
            break
    
    if best_state is None:
        best_state = deepcopy(model.state_dict())
    
    model.load_state_dict(best_state)
    return best_state, best_acc


print("Training functions defined!")

Training functions defined!


### 7.2 CNN-SAE

In [12]:
class SpatialAttention(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(n_channels, n_channels // 4),
            nn.ReLU(),
            nn.Linear(n_channels // 4, n_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        pooled = torch.mean(x, dim=2)
        weights = self.attention(pooled)
        return x * weights.unsqueeze(2)


class CNNSAE(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=769):
        super().__init__()
        self.spatial_attention = SpatialAttention(n_channels)
        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)

        with torch.no_grad():
            test_input = torch.zeros(1, n_channels, n_timepoints)
            test_output = self._forward_features(test_input)
            flattened_size = test_output.view(1, -1).size(1)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _forward_features(self, x):
        x = self.spatial_attention(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        if x.dim() == 4:
            x = x.squeeze(1)
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

print("CNN-SAE defined!")

CNN-SAE defined!


### 7.3 EEGNet

In [13]:
class EEGNet(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=769, F1=8, D=2, F2=16):
        super().__init__()
        self.conv1 = nn.Conv2d(1, F1, (1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        self.conv2 = nn.Conv2d(F1, F1*D, (n_channels, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1*D)
        self.pool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout(0.5)
        self.conv3 = nn.Conv2d(F1*D, F2, (1, 16), padding=(0, 8), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)
        self.pool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout(0.5)

        with torch.no_grad():
            test_input = torch.zeros(1, 1, n_channels, n_timepoints)
            test_output = self._forward_features(test_input)
            flattened_size = test_output.view(1, -1).size(1)

        self.fc = nn.Linear(flattened_size, n_classes)

    def _forward_features(self, x):
        x = self.bn1(self.conv1(x))
        x = self.dropout1(self.pool1(F.elu(self.bn2(self.conv2(x)))))
        x = self.dropout2(self.pool2(F.elu(self.bn3(self.conv3(x)))))
        return x

    def forward(self, x):
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

print("EEGNet defined!")

EEGNet defined!
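
The `EEGNet` constructor above finds the flattened feature size by pushing a dummy tensor through the feature layers. The same number can be derived by hand from the standard convolution and pooling output-length formula, which is a useful sanity check. The helper names below are hypothetical; the kernel, padding, and pooling sizes are copied from the class definition.

```python
def conv_out(length, kernel, padding=0, stride=1):
    """Output length of a 1-D convolution or pooling window along one axis."""
    return (length + 2 * padding - kernel) // stride + 1


def eegnet_flat_size(n_timepoints, F2=16):
    t = conv_out(n_timepoints, kernel=64, padding=32)  # conv1: (1, 64), padding (0, 32)
    # conv2 has kernel (n_channels, 1): it collapses the channel axis to 1
    # and leaves the time axis unchanged
    t = conv_out(t, kernel=4, stride=4)                # pool1: AvgPool2d((1, 4))
    t = conv_out(t, kernel=16, padding=8)              # conv3: (1, 16), padding (0, 8)
    t = conv_out(t, kernel=8, stride=8)                # pool2: AvgPool2d((1, 8))
    return F2 * 1 * t


# 769 samples = 6 s window (tmin=-1.0, tmax=5.0) at 128 Hz, inclusive of both ends
print(eegnet_flat_size(769))
```

For 769 time points the time axis goes 769 → 770 → 192 → 193 → 24, so the linear layer receives 16 × 24 = 384 features.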


### 7.4 ACS-SE-CNN

In [14]:
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.fc1 = nn.Linear(channels, max(1, channels // reduction))
        self.fc2 = nn.Linear(max(1, channels // reduction), channels)

    def forward(self, x):
        squeeze = torch.mean(x, dim=2)
        excitation = F.relu(self.fc1(squeeze))
        excitation = torch.sigmoid(self.fc2(excitation))
        return x * excitation.unsqueeze(2)


class ACSECNN(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=769):
        super().__init__()
        self.channel_attention = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self.se1 = SEBlock(n_channels)
        self.se2 = SEBlock(128)
        self.se3 = SEBlock(256)
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)

        with torch.no_grad():
            test_input = torch.zeros(1, n_channels, n_timepoints)
            test_output = self._forward_features(test_input)
            flattened_size = test_output.view(1, -1).size(1)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _forward_features(self, x):
        # Apply the shared attention MLP along the time axis of every channel at once:
        # (B, C, T) -> (B, C, 1), equivalent to looping over the channels
        channel_weights = self.channel_attention(x)
        x = x * channel_weights
        x = self.se1(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.se2(x)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.se3(x)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        if x.dim() == 4:
            x = x.squeeze(1)
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

print("ACS-SE-CNN defined!")

ACS-SE-CNN defined!


### 7.5 G-CARM

In [15]:
class CARMBlock(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.A = nn.Parameter(torch.randn(n_channels, n_channels) * 0.01)

    def forward(self, x):
        A_norm = torch.softmax(self.A, dim=1)
        x_reshaped = x.permute(0, 2, 1)
        x_graph = torch.matmul(x_reshaped, A_norm.t())
        return x_graph.permute(0, 2, 1)
    
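    def get_adjacency(self):
        # Note: forward() normalizes A with a row-wise softmax, whereas this
        # method returns an element-wise sigmoid, so the exported values differ
        # from the weights actually applied in forward().
        with torch.no_grad():
            return torch.sigmoid(self.A).cpu().numpy()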


class GCARM(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, n_timepoints=769):
        super().__init__()
        self.carm1 = CARMBlock(n_channels)
        self.carm2 = CARMBlock(n_channels)
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)

        with torch.no_grad():
            test_input = torch.zeros(1, n_channels, n_timepoints)
            test_output = self._forward_features(test_input)
            flattened_size = test_output.view(1, -1).size(1)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _forward_features(self, x):
        x = self.carm1(x)
        x = self.carm2(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        if x.dim() == 4:
            x = x.squeeze(1)
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
    
    def get_final_adjacency(self):
        return self.carm2.get_adjacency()

print("G-CARM defined!")

G-CARM defined!
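
The graph-mixing step in `CARMBlock.forward` can be reproduced in NumPy to see what it computes: each output channel is a weighted sum of all input channels, with weights taken from one row of the softmax-normalized matrix. This is a minimal sketch with hypothetical helper names and random data, not the notebook's implementation.

```python
import numpy as np


def row_softmax(a):
    """Softmax over each row (numerically stabilized)."""
    a = a - a.max(axis=1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=1, keepdims=True)


def graph_mix(x, A):
    """x: (batch, channels, time); A: (channels, channels) raw weights."""
    A_norm = row_softmax(A)
    # out[b, i, t] = sum_j A_norm[i, j] * x[b, j, t]
    return np.einsum('ij,bjt->bit', A_norm, x)


rng = np.random.default_rng(0)
A = rng.normal(scale=0.01, size=(4, 4))  # small random init, as in CARMBlock
x = rng.normal(size=(2, 4, 5))
out = graph_mix(x, A)

# Same result as the permute / matmul / permute sequence used in forward()
out_permute = (x.transpose(0, 2, 1) @ row_softmax(A).T).transpose(0, 2, 1)
print(out.shape, np.allclose(out, out_permute))
```

With the small random initialization, every row of the normalized matrix is close to uniform, so each output channel starts out as roughly the average of all input channels.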


## 8. Training Functions

In [16]:
def train_model(model, train_loader, val_loader, device, epochs, lr, patience):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # `verbose` is omitted: it is deprecated in recent PyTorch releases and defaulted to False
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    
    best_acc = 0.0
    best_state = None
    no_improve = 0
    
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        
        scheduler.step(val_loss)
        
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1
        
        if no_improve >= patience:
            break
    
    if best_state is None:
        best_state = deepcopy(model.state_dict())
    
    model.load_state_dict(best_state)
    return best_state, best_acc


print("Training functions defined!")

Training functions defined!
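
The checkpoint-selection and patience logic inside `train_model` can be isolated in a few lines of plain Python. The sketch below uses a hypothetical `track_best` helper and made-up validation accuracies to show why `patience=999` never triggers a stop within 30 epochs, while the best-scoring epoch is still the one whose weights are restored.

```python
def track_best(val_accs, patience):
    """Return (best_acc, best_epoch, last_epoch_run) for a sequence of validation accuracies."""
    best_acc, best_epoch, no_improve = 0.0, None, 0
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:
            best_acc, best_epoch, no_improve = acc, epoch, 0
        else:
            no_improve += 1
        if no_improve >= patience:
            return best_acc, best_epoch, epoch  # stopped early
    return best_acc, best_epoch, len(val_accs)


accs = [0.60, 0.70, 0.68, 0.69, 0.71, 0.70]  # made-up validation accuracies

print(track_best(accs, patience=2))    # stops after two epochs without improvement
print(track_best(accs, patience=999))  # runs every epoch, keeps the best one
```

Because the restored checkpoint is chosen on the same validation fold that is later used for reporting, the reported fold metrics are best read as an optimistic estimate.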


## 9. Main Training Loop

In [17]:
all_results = {
    'fbcsp': [],
    'cnn_sae': [],
    'eegnet': [],
    'acs_se_cnn': [],
    'g_carm': []
}

print("\nStarting training for legacy methods...\n")

for subject_id in tqdm(subjects, desc='Training subjects'):
    print(f"\nProcessing {subject_id}...")
    
    X, Y, channel_names = load_subject_data(
        data_dir,
        subject_id,
        ALL_TASK_RUNS,
        CONFIG
    )
    
    if X is None or len(Y) == 0:
        print(f"  Skipped: No data available")
        continue
    
    C, T = X.shape[1], X.shape[2]
    K = len(set(CONFIG['data']['selected_classes']))
    
    print(f"  Data shape: {X.shape}")
    print(f"  Label distribution: {np.bincount(Y)}")
    
    for model_name in ['fbcsp', 'cnn_sae', 'eegnet', 'acs_se_cnn', 'g_carm']:
        print(f"\n  Training {model_name.upper()}...")
        
        skf = StratifiedKFold(n_splits=CONFIG['model']['n_folds'], shuffle=True, random_state=42)
        fold_results = []
        fold_models = []
        fold_csp_patterns = []
        fold_adjacencies = []
        
        for fold, (train_idx, val_idx) in enumerate(skf.split(X, Y)):
            X_train, X_val = X[train_idx], X[val_idx]
            Y_train, Y_val = Y[train_idx], Y[val_idx]
            
            if model_name == 'fbcsp':
                model = FBCSP(
                    freq_bands=CONFIG['fbcsp']['freq_bands'],
                    n_components=CONFIG['fbcsp']['n_components'],
                    sfreq=CONFIG['preprocessing']['target_sfreq']
                )
                model.fit(X_train, Y_train)
                y_pred = model.predict(X_val)
                metrics = calculate_sklearn_metrics(Y_val, y_pred)
                fold_results.append(metrics)
                
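                # MNE stores one CSP pattern per row of patterns_ (components x channels).
                # Average |pattern| over the components used by transform() to get
                # one importance score per channel.
                csp_patterns = []
                for csp in model.csps:
                    patterns = csp.patterns_[:csp.n_components]
                    csp_patterns.append(np.abs(patterns).mean(axis=0))
                avg_csp_importance = np.mean(np.stack(csp_patterns, 0), 0)
                fold_csp_patterns.append(avg_csp_importance)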
                fold_models.append(model)
                
            else:
                X_train_norm = normalize(X_train)
                X_val_norm = normalize(X_val)
                
                train_loader = DataLoader(
                    EEGDataset(X_train_norm, Y_train),
                    batch_size=CONFIG['model']['batch_size'],
                    shuffle=True,
                    num_workers=0
                )
                val_loader = DataLoader(
                    EEGDataset(X_val_norm, Y_val),
                    batch_size=CONFIG['model']['batch_size'],
                    shuffle=False,
                    num_workers=0
                )
                
                if model_name == 'cnn_sae':
                    model = CNNSAE(n_channels=C, n_classes=K, n_timepoints=T).to(device)
                elif model_name == 'eegnet':
                    model = EEGNet(n_channels=C, n_classes=K, n_timepoints=T).to(device)
                elif model_name == 'acs_se_cnn':
                    model = ACSECNN(n_channels=C, n_classes=K, n_timepoints=T).to(device)
                elif model_name == 'g_carm':
                    model = GCARM(n_channels=C, n_classes=K, n_timepoints=T).to(device)
                
                best_state, best_acc = train_model(
                    model, train_loader, val_loader, device,
                    CONFIG['model']['epochs'],
                    CONFIG['model']['learning_rate'],
                    CONFIG['model']['patience']
                )
                model.load_state_dict(best_state)
                
                metrics = calculate_comprehensive_metrics(model, val_loader, device)
                fold_results.append(metrics)
                
                if model_name == 'g_carm':
                    adjacency = model.get_final_adjacency()
                    fold_adjacencies.append(adjacency)
                
                fold_models.append(deepcopy(model).cpu())
                
                del model
                torch.cuda.empty_cache()
                gc.collect()
        
        avg_metrics = {}
        for key in ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity']:
            values = [f[key] for f in fold_results]
            avg_metrics[f'avg_{key}'] = float(np.mean(values))
            avg_metrics[f'std_{key}'] = float(np.std(values))
        
        result = {
            'subject': subject_id,
            'num_trials': X.shape[0],
            'num_channels': C,
            **avg_metrics,
            'channel_names': channel_names,
            'fold_models': fold_models
        }
        
        if model_name == 'fbcsp':
            result['csp_importance'] = np.mean(np.stack(fold_csp_patterns, 0), 0)
        elif model_name == 'g_carm':
            result['adjacency_matrix'] = np.mean(np.stack(fold_adjacencies, 0), 0)
        
        all_results[model_name].append(result)
        
        print(f"    Accuracy: {avg_metrics['avg_accuracy']:.4f} ± {avg_metrics['std_accuracy']:.4f}")
        print(f"    F1-Score: {avg_metrics['avg_f1_score']:.4f} ± {avg_metrics['std_f1_score']:.4f}")

print("\n" + "="*80)
print("Training Complete!")
print("="*80)


Starting training for legacy methods...





Processing S001...
  Data shape: (252, 64, 769)
  Label distribution: [168  84]

  Training FBCSP...
    Accuracy: 0.6151 ± 0.0148
    F1-Score: 0.3294 ± 0.0486

  Training CNN_SAE...
    Accuracy: 0.9286 ± 0.0097
    F1-Score: 0.8886 ± 0.0168

  Training EEGNET...
    Accuracy: 0.9365 ± 0.0056
    F1-Score: 0.9025 ± 0.0070

  Training ACS_SE_CNN...
    Accuracy: 0.8651 ± 0.0312
    F1-Score: 0.7859 ± 0.0652

  Training G_CARM...
    Accuracy: 0.8095 ± 0.0168
    F1-Score: 0.6883 ± 0.0473

Processing S002...
  Data shape: (252, 64, 769)
  Label distribution: [168  84]

  Training FBCSP...
    Accuracy: 0.6786 ± 0.0097
    F1-Score: 0.4209 ± 0.0454

  Training CNN_SAE...
    Accuracy: 0.7103 ± 0.0297
    F1-Score: 0.3236 ± 0.1868

  Training EEGNET...
    Accuracy: 0.7897 ± 0.0245
    F1-Score: 0.6519 ± 0.0489

  Training ACS_SE_CNN...
    Accuracy: 0.6984 ± 0.0112
    F1-Score: 0.4566 ± 0.1556

  Training G_CARM...
    Accuracy: 0.6746 ± 0.0245
    F1-Score: 0.2197 ± 0.1150

[output truncated]
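
With a label distribution of `[168 84]`, a classifier that always predicts the majority class already reaches 168/252 ≈ 0.667 accuracy, so the raw accuracies above are best read against that floor. The sketch below computes this baseline and a balanced accuracy from the recall and specificity that the notebook already reports. The helper names and the recall/specificity pair are illustrative assumptions, not values from this run.

```python
def majority_baseline(class_counts):
    """Accuracy of a classifier that always predicts the most frequent class."""
    total = sum(class_counts)
    if total == 0:
        raise ValueError("class_counts must contain at least one trial")
    return max(class_counts) / total


def balanced_accuracy(recall, specificity):
    """Binary balanced accuracy: the mean of the two per-class recalls."""
    return 0.5 * (recall + specificity)


# Label distribution printed above for S001 and S002
baseline = majority_baseline([168, 84])
print(f"Majority-class baseline: {baseline:.4f}")

# Hypothetical recall/specificity pair, for illustration only
print(f"Balanced accuracy: {balanced_accuracy(0.30, 0.80):.4f}")
```

A model whose accuracy sits near the baseline while its F1-score is low is mostly predicting the majority class, which balanced accuracy makes visible.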

## 10. Channel Selection Functions

In [20]:
def compute_gradient_attribution(model, X, Y, device, use_true_target=True, max_samples=None):
    """
    Compute per-channel gradient-based attribution scores averaged over samples.
    Assumes X shape (n_samples, n_channels, n_time).
    Tries both input shapes (batch, C, T) and (batch, 1, C, T) automatically.
    """
    model.eval()
    model.to(device)

    X = np.asarray(X)
    Y = np.asarray(Y)
    n_total = X.shape[0]
    n_used = n_total if max_samples is None else min(n_total, int(max_samples))

    grad_scores = []

    for i in range(n_used):
        x_np = X[i:i+1]         # shape (1, C, T)
        target = int(Y[i]) if (use_true_target and i < len(Y)) else int(Y[0])

        # Prepare tensor candidate 1: (1, C, T)
        x1 = torch.from_numpy(x_np).float().to(device)    # (1, C, T)
        # Candidate 2: (1, 1, C, T)
        x2 = x1.unsqueeze(1)

        # We'll try forward with candidate shapes to detect which one works
        chosen = None
        out = None
        for x_tensor in (x1, x2):
            x_tensor = x_tensor.clone().detach().requires_grad_(True)
            try:
                with torch.enable_grad():
                    model.zero_grad()
                    out_try = model(x_tensor)
                # if forward succeeded, choose this tensor
                chosen = x_tensor
                out = out_try
                break
            except Exception:
                # not the right input shape / model choked; try next
                try:
                    # cleanup before trying next
                    if x_tensor.grad is not None:
                        x_tensor.grad.zero_()
                    del x_tensor, out_try
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                continue

        if chosen is None or out is None:
            # If both shapes fail, skip this sample but warn minimally
            # (avoid printing in batch runs; user can debug by running a single forward)
            continue

        # Now compute gradient w.r.t chosen input and target
        with torch.enable_grad():
            try:
                # pick logit for target class
                if out.dim() == 0:
                    score = out
                elif out.dim() == 1:
                    # shape (n_classes,): index the target logit, because
                    # backward() needs a scalar
                    score = out[target] if out.numel() > 1 else out.squeeze()
                else:
                    score = out[0, target]
                # zero grads and backward
                model.zero_grad()
                if chosen.grad is not None:
                    chosen.grad.zero_()
                score.backward(retain_graph=False)
            except Exception:
                # If backward failed, skip this sample
                try:
                    del chosen, out, score
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                continue

            # extract gradient and reduce to per-channel importance
            g = chosen.grad.detach().cpu()   # (1, C, T) or (1, 1, C, T)
            # collapse the singleton leading dims to (C, T); unlike repeated
            # squeeze() calls this also handles single-channel input
            g = g.reshape(g.shape[-2], g.shape[-1])
            # mean absolute gradient over time gives one score per channel
            channel_imp = g.abs().mean(dim=-1).numpy()

            grad_scores.append(channel_imp)

        # cleanup per-sample tensors
        try:
            del chosen, out, score, g, x1, x2
            torch.cuda.empty_cache()
        except Exception:
            pass

    if len(grad_scores) == 0:
        return np.zeros(X.shape[1], dtype=float)

    return np.mean(np.stack(grad_scores, axis=0), axis=0)


class ChannelSelectorLegacy:
    def __init__(self, channel_names, model_name, **kwargs):
        self.names = np.array(channel_names)
        self.C = len(channel_names)
        self.model_name = model_name
        
        if model_name == 'fbcsp':
            self.csp_importance = kwargs.get('csp_importance')
        elif model_name == 'g_carm':
            self.adjacency = kwargs.get('adjacency')
        else:
            self.gradient_scores = kwargs.get('gradient_scores')
    
    def select_channels(self, k, method='default'):
        if self.model_name == 'fbcsp':
            importance = self.csp_importance
        elif self.model_name == 'g_carm':
            if method == 'ES':
                importance = self._edge_selection_scores()
            else:
                importance = self._aggregation_selection_scores()
        else:
            importance = self.gradient_scores

        if importance is None:
            # fallback: select first-k channels if no importance provided
            indices = np.arange(min(int(k), self.C))
            return self.names[indices].tolist(), indices
        
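        # clamp k so that k=0 selects no channels (a bare [-0:] slice returns all)
        importance = np.asarray(importance)
        k = max(0, min(int(k), importance.size))
        indices = np.sort(np.argsort(importance)[importance.size - k:])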
        return self.names[indices].tolist(), indices
    
    def _edge_selection_scores(self):
        # compute delta_ij = |f_ij| + |f_ji| for i<j, then accumulate to nodes (no double counting)
        edge_importance = np.zeros(self.C)
        if self.adjacency is None:
            return edge_importance
        for i in range(self.C):
            for j in range(i+1, self.C):
                val = abs(self.adjacency[i, j]) + abs(self.adjacency[j, i])
                edge_importance[i] += val
                edge_importance[j] += val
        return edge_importance
    
    def _aggregation_selection_scores(self):
        if self.adjacency is None:
            return np.zeros(self.C)
        return np.sum(np.abs(self.adjacency), axis=1)


def retrain_legacy_model(X, Y, selected_indices, model_name, config, device):
    X_selected = X[:, selected_indices, :]
    C, T = X_selected.shape[1], X_selected.shape[2]
    K = len(set(config['data']['selected_classes']))
    
    skf = StratifiedKFold(n_splits=config['model']['n_folds'], shuffle=True, random_state=42)
    fold_results = []
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(X_selected, Y)):
        X_train, X_val = X_selected[train_idx], X_selected[val_idx]
        Y_train, Y_val = Y[train_idx], Y[val_idx]
        
        if model_name == 'fbcsp':
            model = FBCSP(
                freq_bands=config['fbcsp']['freq_bands'],
                n_components=config['fbcsp']['n_components'],
                sfreq=config['preprocessing']['target_sfreq']
            )
            model.fit(X_train, Y_train)
            y_pred = model.predict(X_val)
            metrics = calculate_sklearn_metrics(Y_val, y_pred)
            fold_results.append(metrics)
        else:
            X_train_norm = normalize(X_train)
            X_val_norm = normalize(X_val)
            
            train_loader = DataLoader(
                EEGDataset(X_train_norm, Y_train),
                batch_size=config['model']['batch_size'],
                shuffle=True,
                num_workers=0
            )
            val_loader = DataLoader(
                EEGDataset(X_val_norm, Y_val),
                batch_size=config['model']['batch_size'],
                shuffle=False,
                num_workers=0
            )
            
            if model_name == 'cnn_sae':
                model = CNNSAE(n_channels=C, n_classes=K, n_timepoints=T).to(device)
            elif model_name == 'eegnet':
                model = EEGNet(n_channels=C, n_classes=K, n_timepoints=T).to(device)
            elif model_name == 'acs_se_cnn':
                model = ACSECNN(n_channels=C, n_classes=K, n_timepoints=T).to(device)
            elif model_name == 'g_carm':
                model = GCARM(n_channels=C, n_classes=K, n_timepoints=T).to(device)
            else:
                # Fail loudly: silently substituting CNNSAE would report its
                # metrics under the wrong model name
                raise ValueError(f"Unknown model name: {model_name!r}")
            
            # Pass verbose=False only if train_model accepts it. Probing the
            # signature avoids a second full training run when an unrelated
            # TypeError is raised during training.
            import inspect
            train_kwargs = {}
            if 'verbose' in inspect.signature(train_model).parameters:
                train_kwargs['verbose'] = False
            result = train_model(
                model, train_loader, val_loader, device,
                config['model']['epochs'],
                config['model']['learning_rate'],
                config['model']['patience'],
                **train_kwargs
            )

            # train_model returns (best_state, best_acc), as in the full-channel
            # training loop; best_acc is not needed because the metrics are
            # recomputed below from the restored weights
            best_state, _ = result
            model.load_state_dict(best_state)

            metrics = calculate_comprehensive_metrics(model, val_loader, device)
            fold_results.append(metrics)
            
            # cleanup
            try:
                del model
                torch.cuda.empty_cache()
                gc.collect()
            except Exception:
                pass
    
    avg_metrics = {}
    for key in ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity']:
        values = [f[key] for f in fold_results]
        avg_metrics[f'avg_{key}'] = float(np.mean(values))
        avg_metrics[f'std_{key}'] = float(np.std(values))
    
    return avg_metrics


print("Channel selection functions defined!")


Channel selection functions defined!
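
To make the two G-CARM scoring rules concrete, the sketch below evaluates them on a small made-up adjacency matrix. It is a vectorized restatement of `_edge_selection_scores` and `_aggregation_selection_scores` as defined above; the 3x3 matrix and the names `es_scores`, `as_scores`, and `top_k` are assumptions for illustration.

```python
import numpy as np


def es_scores(adjacency):
    """Edge Selection: node i accumulates |a_ij| + |a_ji| over all j != i."""
    a = np.abs(np.asarray(adjacency, dtype=float))
    # Row sum + column sum counts the diagonal twice, so remove it
    return a.sum(axis=1) + a.sum(axis=0) - 2.0 * np.diag(a)


def as_scores(adjacency):
    """Aggregation Selection: row sum of |a_ij| (diagonal included)."""
    return np.abs(np.asarray(adjacency, dtype=float)).sum(axis=1)


def top_k(scores, k):
    """Indices of the k largest scores, returned in ascending index order."""
    k = min(int(k), len(scores))
    return np.sort(np.argsort(scores)[len(scores) - k:])


# Toy, asymmetric adjacency matrix (illustrative values only)
A = np.array([
    [1.0, 0.2, -0.5],
    [0.1, 1.0, 0.0],
    [0.4, -0.3, 1.0],
])

print("ES scores:", es_scores(A))
print("AS scores:", as_scores(A))
print("Top-2 by ES:", top_k(es_scores(A), 2))
```

ES ignores self-loops, while AS as written in the notebook includes the diagonal, so the two rankings can differ even for a symmetric matrix.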


In [21]:
retrain_results = {
    'fbcsp': [],
    'cnn_sae': [],
    'eegnet': [],
    'acs_se_cnn': [],
    'g_carm': []
}

print("\n" + "="*80)
print("CHANNEL SELECTION AND RETRAINING")
print("="*80 + "\n")

for subject_id in tqdm(subjects, desc='Retraining with channel selection'):
    print(f"\nProcessing {subject_id}...")
    
    X, Y, channel_names = load_subject_data(
        data_dir,
        subject_id,
        ALL_TASK_RUNS,
        CONFIG
    )
    
    if X is None:
        continue
    
    for model_name in ['fbcsp', 'cnn_sae', 'eegnet', 'acs_se_cnn', 'g_carm']:
        subj_result = None
        for res in all_results[model_name]:
            if res['subject'] == subject_id:
                subj_result = res
                break
        
        if subj_result is None:
            continue
        
        print(f"\n  {model_name.upper()} channel selection...")
        
        if model_name == 'fbcsp':
            selector = ChannelSelectorLegacy(
                channel_names, model_name,
                csp_importance=subj_result['csp_importance']
            )
            methods = ['default']
            
        elif model_name == 'g_carm':
            selector = ChannelSelectorLegacy(
                channel_names, model_name,
                adjacency=subj_result['adjacency_matrix']
            )
            methods = ['ES', 'AS']
            
        else:
            X_norm = normalize(X)
            fold_models = subj_result['fold_models']
            
            gradient_scores_list = []
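            for model in fold_models:
                model.to(device)  # nn.Module.to() moves the module in place
                # Attribution uses only the first 10 trials to keep this step fast
                grad_scores = compute_gradient_attribution(model, X_norm[:10], Y[:10], device)
                gradient_scores_list.append(grad_scores)
                model.cpu()  # move back so the stored fold models release GPU memory
                torch.cuda.empty_cache()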
            
            avg_gradient_scores = np.mean(np.stack(gradient_scores_list, 0), 0)
            
            selector = ChannelSelectorLegacy(
                channel_names, model_name,
                gradient_scores=avg_gradient_scores
            )
            methods = ['default']
        
        for method in methods:
            for k in CONFIG['channel_selection']['k_values']:
                method_label = f'{model_name.upper()}-{method}' if method != 'default' else model_name.upper()
                
                selected_channels, selected_indices = selector.select_channels(k, method)
                
                retrain_metrics = retrain_legacy_model(
                    X, Y, selected_indices, model_name, CONFIG, device
                )
                
                acc_drop = subj_result['avg_accuracy'] - retrain_metrics['avg_accuracy']
                
                retrain_results[model_name].append({
                    'subject': subject_id,
                    'method': method if method != 'default' else model_name.upper(),
                    'k': k,
                    'num_channels_selected': len(selected_channels),
                    **retrain_metrics,
                    'full_channels_acc': subj_result['avg_accuracy'],
                    'accuracy_drop': acc_drop,
                    'accuracy_drop_pct': (acc_drop / subj_result['avg_accuracy'] * 100) if subj_result['avg_accuracy'] > 0 else 0.0
                })
                
                print(f"    {method_label}, k={k}: {retrain_metrics['avg_accuracy']:.4f} (drop: {acc_drop:.4f})")

print("\n" + "="*80)
print("Channel Selection Complete!")
print("="*80)


CHANNEL SELECTION AND RETRAINING





Processing S001...

  FBCSP channel selection...
    FBCSP, k=10: 0.6667 (drop: -0.0516)
    FBCSP, k=15: 0.6667 (drop: -0.0516)
    FBCSP, k=20: 0.6746 (drop: -0.0595)
    FBCSP, k=25: 0.6389 (drop: -0.0238)
    FBCSP, k=30: 0.6746 (drop: -0.0595)

  CNN_SAE channel selection...
    CNN_SAE, k=10: 0.9405 (drop: -0.0119)
    CNN_SAE, k=15: 0.9286 (drop: 0.0000)
    CNN_SAE, k=20: 0.9206 (drop: 0.0079)
    CNN_SAE, k=25: 0.9365 (drop: -0.0079)
    CNN_SAE, k=30: 0.9008 (drop: 0.0278)

  EEGNET channel selection...
    EEGNET, k=10: 0.9484 (drop: -0.0119)
    EEGNET, k=15: 0.9603 (drop: -0.0238)
    EEGNET, k=20: 0.9365 (drop: 0.0000)
    EEGNET, k=25: 0.9444 (drop: -0.0079)
    EEGNET, k=30: 0.9524 (drop: -0.0159)

  ACS_SE_CNN channel selection...
    ACS_SE_CNN, k=10: 0.9167 (drop: -0.0516)
    ACS_SE_CNN, k=15: 0.9008 (drop: -0.0357)
    ACS_SE_CNN, k=20: 0.9008 (drop: -0.0357)
    ACS_SE_CNN, k=25: 0.9087 (drop: -0.0437)
    ACS_SE_CNN, k=30: 0.8929 (drop: -0.0278)

[output truncated]
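
The retraining cell passes `X_norm[:10]` and `Y[:10]` to `compute_gradient_attribution`, so the gradient scores depend on whichever ten trials happen to come first. One alternative, sketched here under the assumption that labels are integer class ids, is to draw a small class-balanced subset. `balanced_subset_indices` is a hypothetical helper, not part of the notebook, and the labels are synthetic.

```python
import numpy as np


def balanced_subset_indices(y, per_class, seed=42):
    """Return sorted indices with up to `per_class` trials from each class."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    picked = []
    for cls in np.unique(y):
        cls_idx = np.flatnonzero(y == cls)
        n_take = min(per_class, cls_idx.size)
        # Sample without replacement so no trial is used twice
        picked.append(rng.choice(cls_idx, size=n_take, replace=False))
    return np.sort(np.concatenate(picked))


# Synthetic labels mirroring the [168 84] distribution in the training output
y_demo = np.array([0] * 168 + [1] * 84)
idx = balanced_subset_indices(y_demo, per_class=5)
print(idx, np.bincount(y_demo[idx]))
```

The resulting indices could then replace the fixed slice, for example `X_norm[idx], Y[idx]` in the attribution call.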

## 10. Save Results

In [22]:
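results_dir = Path(CONFIG['output']['results_dir'])
results_dir.mkdir(parents=True, exist_ok=True)

# Model objects and arrays do not serialize meaningfully to CSV, so drop them
non_tabular = ['fold_models', 'csp_importance', 'adjacency_matrix']

for model_name in ['fbcsp', 'cnn_sae', 'eegnet', 'acs_se_cnn', 'g_carm']:
    if len(all_results[model_name]) > 0:
        df = pd.DataFrame(all_results[model_name]).drop(columns=non_tabular, errors='ignore')
        df.to_csv(results_dir / f'legacy_{model_name}_results.csv', index=False)
        print(f"Saved: legacy_{model_name}_results.csv")

print(f"\nAll results saved to {results_dir}")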

Saved: legacy_fbcsp_results.csv
Saved: legacy_cnn_sae_results.csv
Saved: legacy_eegnet_results.csv
Saved: legacy_acs_se_cnn_results.csv
Saved: legacy_g_carm_results.csv

All results saved to results
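
The notebook header also lists `legacy_*_retrain_results.csv` as an output, but the cell above writes only the full-channel results. A minimal sketch of writing the retraining records follows. It assumes `retrain_results` has the structure built in the channel-selection cell (a dict mapping each model name to a list of flat dicts); the helper `save_retrain_results` and the demo record are illustrative, not measured values.

```python
import tempfile
from pathlib import Path

import pandas as pd


def save_retrain_results(retrain_results, results_dir):
    """Write one CSV per model that has at least one retraining record."""
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for model_name, records in retrain_results.items():
        if not records:
            continue  # skip models that were never retrained
        path = results_dir / f"legacy_{model_name}_retrain_results.csv"
        pd.DataFrame(records).to_csv(path, index=False)
        written.append(path)
    return written


# Demo with a single made-up record written to a temporary directory
demo = {
    "eegnet": [{"subject": "S000", "method": "EEGNET", "k": 10,
                "avg_accuracy": 0.5, "accuracy_drop": 0.0}],
    "fbcsp": [],
}
with tempfile.TemporaryDirectory() as tmp:
    paths = save_retrain_results(demo, tmp)
    saved_names = [p.name for p in paths]
    reloaded = pd.read_csv(paths[0])
print(saved_names)
```

In the notebook this would be called as `save_retrain_results(retrain_results, results_dir)` after the retraining loop has finished.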


## 11. Results Summary

In [23]:
print("\n" + "="*80)
print("RESULTS SUMMARY")
print("="*80 + "\n")

for model_name in ['fbcsp', 'cnn_sae', 'eegnet', 'acs_se_cnn', 'g_carm']:
    if len(all_results[model_name]) > 0:
        accs = [r['avg_accuracy'] for r in all_results[model_name]]
        f1s = [r['avg_f1_score'] for r in all_results[model_name]]
        aucs = [r['avg_auc_roc'] for r in all_results[model_name]]
        
        print(f"{model_name.upper()} Results:")
        print(f"  Subjects: {len(all_results[model_name])}")
        print(f"  Mean accuracy: {np.mean(accs):.4f} ± {np.std(accs):.4f}")
        print(f"  Mean F1-Score: {np.mean(f1s):.4f} ± {np.std(f1s):.4f}")
        print(f"  Mean AUC-ROC: {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
        print()

print("\n" + "="*80)
print("DONE!")
print("="*80)


RESULTS SUMMARY

FBCSP Results:
  Subjects: 10
  Mean accuracy: 0.6317 ± 0.0747
  Mean F1-Score: 0.3568 ± 0.1249
  Mean AUC-ROC: 0.0000 ± 0.0000

CNN_SAE Results:
  Subjects: 10
  Mean accuracy: 0.8425 ± 0.0667
  Mean F1-Score: 0.7174 ± 0.1606
  Mean AUC-ROC: 0.8604 ± 0.0939

EEGNET Results:
  Subjects: 10
  Mean accuracy: 0.9024 ± 0.0561
  Mean F1-Score: 0.8449 ± 0.0949
  Mean AUC-ROC: 0.9238 ± 0.0755

ACS_SE_CNN Results:
  Subjects: 10
  Mean accuracy: 0.8143 ± 0.0644
  Mean F1-Score: 0.6744 ± 0.1470
  Mean AUC-ROC: 0.8201 ± 0.0923

G_CARM Results:
  Subjects: 10
  Mean accuracy: 0.7798 ± 0.0661
  Mean F1-Score: 0.5893 ± 0.1854
  Mean AUC-ROC: 0.7769 ± 0.0929


DONE!
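
The FBCSP row reports an AUC-ROC of 0.0000 ± 0.0000, which most likely reflects a placeholder rather than a measured value: the FBCSP retraining path scores hard labels from `model.predict`, and AUC-ROC needs a continuous score per trial. scikit-learn's `LinearDiscriminantAnalysis` provides `decision_function`, whose output can be passed to `roc_auc_score`; whether the notebook's `FBCSP` class exposes it is an assumption. The sketch below shows the underlying rank-based computation with NumPy only, using made-up labels and scores.

```python
import numpy as np


def auc_from_scores(y_true, scores):
    """AUC-ROC as P(score_pos > score_neg), counting ties as one half."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUC-ROC needs at least one trial of each class")
    # Compare every positive score with every negative score
    diff = pos[:, None] - neg[None, :]
    wins = np.count_nonzero(diff > 0)
    ties = np.count_nonzero(diff == 0)
    return (wins + 0.5 * ties) / (pos.size * neg.size)


# Made-up labels and continuous decision scores (illustration only)
y_demo = [0, 0, 1, 1]
score_demo = [0.1, 0.4, 0.35, 0.8]
print(auc_from_scores(y_demo, score_demo))
```

Until FBCSP is scored this way, its AUC-ROC column should be treated as missing rather than compared against the PyTorch models.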
