# BCI 2a EEG-ARNN Training Pipeline

This notebook trains subject-specific EEG-ARNN models on the BCI Competition IV 2a dataset with:
- Subject-specific 3-fold cross-validation (20 epochs per fold)
- Channel selection experiments reused from the PhysioNet pipeline


## Setup and Imports

In [7]:
import sys
from pathlib import Path
import warnings
import json
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import mne

# IMPORTANT: Force reload train_utils to get latest fixes
import importlib
import train_utils
importlib.reload(train_utils)

from models import EEGARNN, ChannelSelector
from train_utils import (
    load_preprocessed_data, filter_classes, normalize_data,
    cross_validate_subject, EEGDataset
)

warnings.filterwarnings('ignore')
mne.set_log_level('ERROR')
sns.set_context('notebook', font_scale=1.1)
plt.style.use('seaborn-v0_8')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Verify the fix is loaded
print("\n" + "="*80)
print("VERIFICATION: Checking if train_utils.py fix is loaded...")
print("="*80)
import inspect
source = inspect.getsource(load_preprocessed_data)
if 'events_from_annotations' in source:
    print("✓ GOOD: train_utils.py has the annotations fix!")
else:
    print("✗ ERROR: train_utils.py is still using old code!")
    print("  → Please restart Jupyter kernel: Kernel → Restart Kernel")
    print("  → Then re-run all cells from the top")
print("="*80)

Using device: cpu

VERIFICATION: Checking if train_utils.py fix is loaded...
✓ GOOD: train_utils.py has the annotations fix!


## Configuration

In [8]:
EXPERIMENT_CONFIG = {
    'data': {
        'raw_dir': Path('data/BCI_2a'),
        'subjects': [f'A0{i}' for i in range(1, 10)],
        'selected_classes': [769, 770, 771, 772],
        'tmin': 0.0,
        'tmax': 4.0,
        'baseline': None
    },
    'model': {
        'hidden_dim': 40,
        'epochs': 20,
        'learning_rate': 0.001,
        'batch_size': 32,
        'n_folds': 3
    },
    'channel_selection': {
        'k_values': [10, 15, 20, 25, 'all'],
        'methods': ['ES', 'AS']
    },
    'output': {
        'results_dir': Path('results/bci_2a'),
        'models_dir': Path('saved_models/bci_2a'),
        'subject_results_file': 'bci2a_baseline_subject_results.csv',
        'channel_selection_results_file': 'bci2a_channel_selection_results.csv',
        'retrain_results_file': 'bci2a_baseline_retrain_results.csv',
        'config_file': 'bci2a_baseline_experiment_config.json',
        'results_summary_figure': 'bci2a_baseline_results_summary.png'
    },
    'max_subjects': None
}

EXPERIMENT_CONFIG['output']['results_dir'].mkdir(parents=True, exist_ok=True)
EXPERIMENT_CONFIG['output']['models_dir'].mkdir(parents=True, exist_ok=True)

print('Experiment Configuration:')
print(json.dumps(EXPERIMENT_CONFIG, indent=2, default=str))


Experiment Configuration:
{
  "data": {
    "raw_dir": "data\\BCI_2a",
    "subjects": [
      "A01",
      "A02",
      "A03",
      "A04",
      "A05",
      "A06",
      "A07",
      "A08",
      "A09"
    ],
    "selected_classes": [
      769,
      770,
      771,
      772
    ],
    "tmin": 0.0,
    "tmax": 4.0,
    "baseline": null
  },
  "model": {
    "hidden_dim": 40,
    "epochs": 20,
    "learning_rate": 0.001,
    "batch_size": 32,
    "n_folds": 3
  },
  "channel_selection": {
    "k_values": [
      10,
      15,
      20,
      25,
      "all"
    ],
    "methods": [
      "ES",
      "AS"
    ]
  },
  "output": {
    "results_dir": "results\\bci_2a",
    "models_dir": "saved_models\\bci_2a",
    "subject_results_file": "bci2a_baseline_subject_results.csv",
    "channel_selection_results_file": "bci2a_channel_selection_results.csv",
    "retrain_results_file": "bci2a_baseline_retrain_results.csv",
    "config_file": "bci2a_baseline_experiment_config.json",
    "result

## Build BCI 2a Session Index

In [9]:
raw_dir = EXPERIMENT_CONFIG['data']['raw_dir']
selected_classes = EXPERIMENT_CONFIG['data']['selected_classes']

records = []
missing_subjects = []

for subject_id in EXPERIMENT_CONFIG['data']['subjects']:
    gdf_path = raw_dir / f"{subject_id}T.gdf"
    if not gdf_path.exists():
        missing_subjects.append(subject_id)
        continue

    try:
        raw = mne.io.read_raw_gdf(gdf_path, preload=False, verbose='ERROR')
        events, event_ids = mne.events_from_annotations(raw, verbose='ERROR')
        selected_event_ids = [event_ids[str(cls)] for cls in selected_classes if str(cls) in event_ids]
        trial_mask = np.isin(events[:, 2], selected_event_ids) if selected_event_ids else np.array([])
        num_trials = int(trial_mask.sum()) if trial_mask.size else 0
    except Exception as exc:
        print(f"[warn] Could not parse {gdf_path.name}: {exc}")
        num_trials = 0

    records.append({
        'subject': subject_id,
        'session': 'T',
        'path': gdf_path,
        'num_trials': num_trials
    })

bci_sessions = pd.DataFrame(records)
motor_runs = bci_sessions[bci_sessions['num_trials'] > 0].copy()

print(f"Total subjects configured: {len(EXPERIMENT_CONFIG['data']['subjects'])}")
print(f"Subjects with labelled training data: {motor_runs['subject'].nunique()}")
print(f"Total labelled trials: {int(motor_runs['num_trials'].sum())}")

if missing_subjects:
    print('Missing training files for subjects:', missing_subjects)


Total subjects configured: 9
Subjects with labelled training data: 9
Total labelled trials: 2592


## Subject Selection

Identify BCI 2a subjects with labelled training (T) sessions and aggregate their available trials.

In [10]:
subject_counts = (motor_runs.groupby('subject')['num_trials']
                  .sum()
                  .reset_index()
                  .sort_values('subject'))

selected_subjects = subject_counts['subject'].tolist()

if not selected_subjects:
    raise RuntimeError('No BCI 2a subjects with labelled trials were found.')

max_subjects = EXPERIMENT_CONFIG.get('max_subjects')
if max_subjects:
    selected_subjects = selected_subjects[:max_subjects]
    subject_counts = subject_counts[subject_counts['subject'].isin(selected_subjects)]

print('Subject trial counts:')
print(subject_counts.to_string(index=False))
print(f"Will train on {len(selected_subjects)} subjects")
print(f"Selected subjects: {selected_subjects}")


Subject trial counts:
subject  num_trials
    A01         288
    A02         288
    A03         288
    A04         288
    A05         288
    A06         288
    A07         288
    A08         288
    A09         288
Will train on 9 subjects
Selected subjects: ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09']


## Helper Functions

In [11]:
def load_subject_data(subject_id, subject_sessions_df, config):
    '''
    Load all labelled motor imagery trials for a BCI 2a subject.

    Returns
    -------
    data : np.ndarray or None
        (n_trials, n_channels, n_timepoints)
    labels : np.ndarray or None
        (n_trials,)
    channel_names : list[str] or None
        Channel labels preserved from the recording
    '''
    subject_rows = subject_sessions_df[subject_sessions_df['subject'] == subject_id]

    if subject_rows.empty:
        return None, None, None

    selected_classes = config['data']['selected_classes']

    all_data = []
    all_labels = []
    channel_names = None

    for _, row in subject_rows.iterrows():
        gdf_path = Path(row['path'])
        if not gdf_path.exists():
            print(f"[warn] Missing file: {gdf_path}")
            continue

        try:
            raw = mne.io.read_raw_gdf(gdf_path, preload=True, verbose='ERROR')
            events, event_ids = mne.events_from_annotations(raw, verbose='ERROR')

            selected_event_ids = {str(cls): event_ids[str(cls)] for cls in selected_classes if str(cls) in event_ids}
            if not selected_event_ids:
                print(f"[warn] No target events found in {gdf_path.name}")
                continue

            epochs = mne.Epochs(
                raw,
                events,
                event_id=selected_event_ids,
                tmin=config['data']['tmin'],
                tmax=config['data']['tmax'],
                baseline=config['data']['baseline'],
                preload=True,
                event_repeated='merge',
                picks='eeg',
                verbose='ERROR'
            )

            data = epochs.get_data()
            label_lookup = {event_ids[key]: int(key) for key in selected_event_ids}
            labels = np.array([label_lookup[event_code] for event_code in epochs.events[:, 2]])

            data, labels = filter_classes(data, labels, selected_classes)

            if data.size == 0:
                continue

            all_data.append(data)
            all_labels.append(labels)

            if channel_names is None:
                channel_names = epochs.ch_names

        except Exception as exc:
            print(f"[warn] Failed to load {gdf_path.name}: {exc}")
            continue

    if not all_data:
        return None, None, None

    data = np.concatenate(all_data, axis=0)
    labels = np.concatenate(all_labels, axis=0)

    return data, labels, channel_names


## Main Training Loop

Train subject-specific models with 3-fold cross-validation

In [12]:
all_results = []

for subject_id in tqdm(selected_subjects, desc='Training subjects'):
    print(f"{'='*80}")
    print(f"Training subject: {subject_id}")
    print(f"{'='*80}")

    data, labels, channel_names = load_subject_data(
        subject_id,
        motor_runs,
        EXPERIMENT_CONFIG
    )

    if data is None or len(data) < 30:
        print(f"Skipping {subject_id}: insufficient data")
        continue

    print(f"Data shape: {data.shape}")
    print(f"Labels: {np.unique(labels, return_counts=True)}")
    print(f"Channels: {len(channel_names)}")

    num_channels = data.shape[1]
    num_timepoints = data.shape[2]
    num_classes = len(np.unique(labels))

    cv_results = cross_validate_subject(
        data, labels,
        num_channels=num_channels,
        num_timepoints=num_timepoints,
        num_classes=num_classes,
        device=device,
        n_splits=EXPERIMENT_CONFIG['model']['n_folds'],
        epochs=EXPERIMENT_CONFIG['model']['epochs'],
        lr=EXPERIMENT_CONFIG['model']['learning_rate']
    )

    print(f"Average accuracy (all channels): {cv_results['avg_accuracy']:.4f} +/- {cv_results['std_accuracy']:.4f}")

    result = {
        'subject': subject_id,
        'num_trials': len(data),
        'num_channels': num_channels,
        'num_timepoints': num_timepoints,
        'num_classes': num_classes,
        'all_channels_acc': cv_results['avg_accuracy'],
        'all_channels_std': cv_results['std_accuracy'],
        'adjacency_matrix': cv_results['adjacency_matrix'],
        'channel_names': channel_names
    }

    all_results.append(result)

print(f"{'='*80}")
print(f"Training complete for {len(all_results)} subjects")
print(f"{'='*80}")


Training subjects:   0%|          | 0/9 [00:00<?, ?it/s]

Training subject: A01
Data shape: (288, 25, 1001)
Labels: (array([0, 1, 2, 3]), array([72, 72, 72, 72], dtype=int64))
Channels: 25
  Fold 1/3 

Training subjects:   0%|          | 0/9 [01:06<?, ?it/s]


KeyboardInterrupt: 

## Channel Selection Experiments

Test different k values with Edge Selection and Aggregation Selection

In [None]:
channel_selection_results = []

if len(all_results) > 0:
    for result in tqdm(all_results, desc="Channel selection experiments"):
        subject_id = result['subject']
        adj_matrix = result['adjacency_matrix']
        channel_names = result['channel_names']
        
        print(f"\nProcessing channel selection for {subject_id}")
        
        selector = ChannelSelector(adj_matrix, channel_names)
        
        for method in EXPERIMENT_CONFIG['channel_selection']['methods']:
            print(f"  Method: {method}")
            
            for k in EXPERIMENT_CONFIG['channel_selection']['k_values']:
                if k == 'all':
                    k_val = result['num_channels']
                    selected_channels = channel_names
                else:
                    k_val = min(k, result['num_channels'])  # Don't exceed available channels
                    
                    if method == 'ES':
                        selected_channels, _ = selector.edge_selection(k_val)
                    else:  # AS
                        selected_channels, _ = selector.aggregation_selection(k_val)
                
                print(f"    k={k_val}: {len(selected_channels)} channels selected")
                
                channel_selection_results.append({
                    'subject': subject_id,
                    'method': method,
                    'k': k_val,
                    'num_selected': len(selected_channels),
                    'selected_channels': selected_channels,
                    'accuracy_full': result['all_channels_acc']
                })

    channel_selection_df = pd.DataFrame(channel_selection_results)
    print(f"\nChannel selection results: {len(channel_selection_df)} experiments")
    display(channel_selection_df.head(10))
else:
    channel_selection_df = pd.DataFrame()
    print("\nNo results available for channel selection experiments.")

In [None]:
from train_utils import retrain_with_selected_channels

# Store all retraining results
retrain_results = []

if len(all_results) > 0:
    # We need the original data for each subject
    subject_data_cache = {}

    for result in all_results:
        subject_id = result['subject']
        print(f"Loading data for {subject_id}")

        data, labels, channel_names = load_subject_data(
            subject_id,
            motor_runs,
            EXPERIMENT_CONFIG
        )

        if data is None:
            continue

        subject_data_cache[subject_id] = {
            'data': data,
            'labels': labels,
            'channel_names': channel_names
        }

    print(f"{'='*80}")
    print("RETRAINING WITH SELECTED CHANNELS")
    print(f"{'='*80}")

    for result in tqdm(all_results, desc="Retraining subjects"):
        subject_id = result['subject']

        if subject_id not in subject_data_cache:
            continue

        cache = subject_data_cache[subject_id]
        data = cache['data']
        labels = cache['labels']
        channel_names = cache['channel_names']

        print(f"Retraining {subject_id}")

        selector = ChannelSelector(result['adjacency_matrix'], channel_names)

        for method in EXPERIMENT_CONFIG['channel_selection']['methods']:
            for k in EXPERIMENT_CONFIG['channel_selection']['k_values']:
                if k == 'all':
                    continue

                k_val = min(k, result['num_channels'])

                if method == 'ES':
                    selected_channels, selected_indices = selector.edge_selection(k_val)
                else:
                    selected_channels, selected_indices = selector.aggregation_selection(k_val)

                print(f"  {method} k={k_val}: Retraining with {len(selected_channels)} channels...")

                retrain_res = retrain_with_selected_channels(
                    data, labels,
                    selected_channel_indices=selected_indices,
                    num_timepoints=result['num_timepoints'],
                    num_classes=result['num_classes'],
                    device=device,
                    n_splits=EXPERIMENT_CONFIG['model']['n_folds'],
                    epochs=EXPERIMENT_CONFIG['model']['epochs'],
                    lr=EXPERIMENT_CONFIG['model']['learning_rate']
                )

                acc_drop = result['all_channels_acc'] - retrain_res['avg_accuracy']

                print(f"    Accuracy: {retrain_res['avg_accuracy']:.4f} +/- {retrain_res['std_accuracy']:.4f}")
                print(f"    Drop from full: {acc_drop:.4f} ({acc_drop/result['all_channels_acc']*100:.1f}%)")

                retrain_results.append({
                    'subject': subject_id,
                    'method': method,
                    'k': k_val,
                    'num_channels_selected': len(selected_channels),
                    'selected_channels': selected_channels,
                    'accuracy': retrain_res['avg_accuracy'],
                    'std': retrain_res['std_accuracy'],
                    'full_channels_acc': result['all_channels_acc'],
                    'accuracy_drop': acc_drop,
                    'accuracy_drop_pct': acc_drop / result['all_channels_acc'] * 100
                })

    retrain_df = pd.DataFrame(retrain_results)
    print(f"{'='*80}")
    print(f"Retraining complete: {len(retrain_df)} experiments")
    print(f"{'='*80}")

    retrain_path = EXPERIMENT_CONFIG['output']['results_dir'] / EXPERIMENT_CONFIG['output']['retrain_results_file']
    retrain_df.to_csv(retrain_path, index=False)
    print(f"Retrain results saved to: {retrain_path}")
else:
    retrain_df = pd.DataFrame()
    print("No results to retrain. Please run training first.")


## Retrain with Selected Channels

Now retrain the model using ONLY the selected channels and compare accuracy

## Results Summary

In [None]:
results_df = pd.DataFrame(all_results)

print("=" * 80)
print("OVERALL RESULTS SUMMARY")
print("=" * 80)
print(f"\nSubjects trained: {len(results_df)}")

if len(results_df) > 0:
    print(f"Mean accuracy (all channels): {results_df['all_channels_acc'].mean():.4f} ± {results_df['all_channels_acc'].std():.4f}")
    print(f"Best subject: {results_df.loc[results_df['all_channels_acc'].idxmax(), 'subject']} ({results_df['all_channels_acc'].max():.4f})")
    print(f"Worst subject: {results_df.loc[results_df['all_channels_acc'].idxmin(), 'subject']} ({results_df['all_channels_acc'].min():.4f})")

    # Save results
    results_path = EXPERIMENT_CONFIG['output']['results_dir'] / 'subject_results.csv'
    results_df[['subject', 'num_trials', 'num_channels', 'all_channels_acc', 'all_channels_std']].to_csv(results_path, index=False)
    print(f"\nResults saved to: {results_path}")

    display(results_df[['subject', 'num_trials', 'num_channels', 'all_channels_acc', 'all_channels_std']].head(10))
else:
    print("\nNo subjects were successfully trained. Check the data loading and preprocessing steps.")
    results_df

## Visualizations

In [None]:
if len(results_df) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Accuracy distribution
    axes[0, 0].hist(results_df['all_channels_acc'], bins=20, color='steelblue', edgecolor='black', alpha=0.7)
    axes[0, 0].axvline(results_df['all_channels_acc'].mean(), color='red', linestyle='--', linewidth=2, label='Mean')
    axes[0, 0].set_title('Accuracy Distribution (All Channels)', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Accuracy')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy vs num trials
    axes[0, 1].scatter(results_df['num_trials'], results_df['all_channels_acc'], alpha=0.6, s=100)
    axes[0, 1].set_title('Accuracy vs Number of Trials', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Number of Trials')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].grid(True, alpha=0.3)

    # Top 10 subjects
    top_10 = results_df.nlargest(min(10, len(results_df)), 'all_channels_acc')
    axes[1, 0].barh(range(len(top_10)), top_10['all_channels_acc'], color='green', alpha=0.7)
    axes[1, 0].set_yticks(range(len(top_10)))
    axes[1, 0].set_yticklabels(top_10['subject'])
    axes[1, 0].set_title(f'Top {len(top_10)} Subjects by Accuracy', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Accuracy')
    axes[1, 0].invert_yaxis()
    axes[1, 0].grid(True, alpha=0.3, axis='x')

    # Subject ranking
    sorted_results = results_df.sort_values('all_channels_acc')
    axes[1, 1].plot(range(len(sorted_results)), sorted_results['all_channels_acc'], marker='o', markersize=4, alpha=0.6)
    axes[1, 1].set_title('Subject Ranking', fontsize=14, fontweight='bold')
    axes[1, 1].set_xlabel('Rank')
    axes[1, 1].set_ylabel('Accuracy')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    summary_path = EXPERIMENT_CONFIG['output']['results_dir'] / EXPERIMENT_CONFIG['output']['results_summary_figure']
    plt.savefig(summary_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Visualizations saved to: {summary_path}")
else:
    print('No results to visualize. Please ensure subjects were successfully trained.')


## Visualize Learned Adjacency Matrix (Example Subject)

In [None]:
if len(all_results) > 0:
    # Pick best subject
    best_idx = results_df['all_channels_acc'].idxmax()
    best_result = all_results[best_idx]
    
    print(f"Visualizing adjacency matrix for best subject: {best_result['subject']}")
    print(f"Accuracy: {best_result['all_channels_acc']:.4f}")
    
    selector = ChannelSelector(best_result['adjacency_matrix'], best_result['channel_names'])
    
    fig = selector.visualize_adjacency(
        save_path=EXPERIMENT_CONFIG['output']['results_dir'] / f"adjacency_{best_result['subject']}.png"
    )
    plt.show()
    
    # Show top edges
    print("\nTop 10 Edges (Edge Selection):")
    selected_channels_es, _ = selector.edge_selection(10)
    print(f"Selected channels: {selected_channels_es}")
    
    print("\nTop 10 Channels (Aggregation Selection):")
    selected_channels_as, _ = selector.aggregation_selection(10)
    print(f"Selected channels: {selected_channels_as}")

## Export Results

In [None]:
if len(results_df) > 0:
    results_dir = EXPERIMENT_CONFIG['output']['results_dir']
    subject_results_path = results_dir / EXPERIMENT_CONFIG['output']['subject_results_file']
    results_df[['subject', 'num_trials', 'num_channels', 'all_channels_acc', 'all_channels_std']].to_csv(subject_results_path, index=False)

    if len(channel_selection_df) > 0:
        channel_selection_path = results_dir / EXPERIMENT_CONFIG['output']['channel_selection_results_file']
        channel_selection_df.to_csv(channel_selection_path, index=False)
    else:
        channel_selection_path = None

    config_path = results_dir / EXPERIMENT_CONFIG['output']['config_file']
    with open(config_path, 'w') as f:
        json.dump(EXPERIMENT_CONFIG, f, indent=2, default=str)

    print('All results exported successfully!')
    print(f'  - Subject results: {subject_results_path}')
    if channel_selection_path:
        print(f'  - Channel selection: {channel_selection_path}')
    print(f'  - Config: {config_path}')
else:
    print('No results to export.')
