# PhysioNet MI – Comprehensive Benchmark Notebook

This notebook performs end-to-end EDA, preprocessing, model training, channel selection, and visualization for the PhysioNet Motor Imagery EEG dataset. It targets TPU on Kaggle and compares classical and neural methods including FBCSP, CNN-SAE, EEGNet, ACS-SE-CNN, G-CARM, EEG-ARNN, and Gated EEG-ARNN.

## 0. Environment Setup

In [None]:
%%capture!pip install --upgrade pip!pip install mne einops plotly kaleido scienceplots

## 1. Imports

In [None]:
import osimport jsonimport mathimport timeimport randomfrom pathlib import Pathfrom collections import defaultdictimport numpy as npimport pandas as pdimport mnefrom mne.preprocessing import ICAimport matplotlib.pyplot as pltimport seaborn as snsimport plotly.express as pximport plotly.graph_objects as goimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom sklearn.metrics import accuracy_score, precision_score, recall_score, f1_scorefrom sklearn.model_selection import StratifiedKFoldsns.set_context('talk')plt.style.use('seaborn-v0_8')

## 2. Configuration

In [None]:
# Central configuration read by every later cell of the notebook.
CONFIG = {
    # Where the recordings live and how trials are cut out of them.
    'data': {
        'root_dir': '/kaggle/input/eeg-motor-movementimagery-dataset',
        # Subject identifiers S001 ... S020.
        'subjects': [f'S{number:03d}' for number in range(1, 21)],
        'runs': {
            'imagery': list(range(7, 15)),
            'execution': list(range(3, 7)),
        },
        'target_sampling': 128,
        'tmin': -1.0,
        'tmax': 5.0,
        'baseline': (-0.5, 0),
        # Annotation name -> integer class label.
        'classes': {'T1': 0, 'T2': 1},
    },
    # Filter settings applied to each raw recording.
    'preprocessing': {
        'notch_freq': 60.0,
        'l_freq': 0.5,
        'h_freq': 40.0,
        'ica_n_components': 20,
    },
    # Optimisation and cross-validation hyper-parameters.
    'model': {
        'batch_size': 64,
        'epochs': 20,
        'learning_rate': 1e-3,
        'weight_decay': 1e-4,
        'n_folds': 3,
        'hidden_dim': 64,
        'gate_init': 0.9,
    },
    # Channel counts evaluated in the retention analysis.
    'channel_selection': {
        'k_values': [10, 15, 20, 25, 30],
        'retention_metric': 'accuracy',
    },
    'random_seed': 42,
}

# Seed the Python, NumPy and PyTorch random number generators with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(CONFIG['random_seed'])

## 3. TPU / Accelerator Setup

In [None]:
# Choose the compute device once for the whole notebook.
# USE_TPU is read later by the training loop to pick the optimizer-step call.
USE_TPU = False
if 'TPU_NAME' not in os.environ:
    # No TPU advertised in the environment: prefer CUDA, otherwise fall back to CPU.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', DEVICE)
else:
    # torch_xla is only imported when the TPU_NAME environment variable is set.
    import torch_xla
    import torch_xla.core.xla_model as xm
    USE_TPU = True
    DEVICE = xm.xla_device()
    print('TPU detected. Using device:', DEVICE)

## 4. Utility Functions

In [None]:
def describe_raw(raw):
    """Summarise a raw recording as a small dictionary.

    Returns the recording's ``subject_info`` entry, its sampling frequency,
    the number of channel names and the duration in minutes
    (``n_times / sfreq / 60``).
    """
    sampling_rate = raw.info['sfreq']
    return dict(
        subject=raw.info['subject_info'],
        sfreq=sampling_rate,
        n_channels=len(raw.ch_names),
        duration_min=raw.n_times / sampling_rate / 60,
    )


def plot_frequency_spectrum(raw, title):
    """Plot the channel-averaged FFT magnitude of ``raw`` between 0 and 60 Hz."""
    # Magnitude of the real FFT along axis 1, then averaged over axis 0.
    magnitudes = np.abs(np.fft.rfft(raw.get_data(), axis=1))
    mean_magnitude = magnitudes.mean(axis=0)
    frequency_axis = np.fft.rfftfreq(raw.n_times, d=1. / raw.info['sfreq'])

    plt.figure(figsize=(10, 4))
    plt.plot(frequency_axis, mean_magnitude)
    plt.title(title)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Amplitude')
    plt.xlim(0, 60)
    plt.show()


def plot_voltage_distribution(raw, title):
    """Plot a 200-bin histogram (with KDE overlay) of every sample value in ``raw``."""
    # Scale by 1e6 so the x-axis reads in microvolts.
    microvolts = raw.get_data() * 1e6

    plt.figure(figsize=(10, 4))
    sns.histplot(microvolts.flatten(), bins=200, kde=True)
    plt.title(title)
    plt.xlabel('Voltage (µV)')
    plt.show()

## 5. Exploratory Data Analysis

In [None]:
# Exploratory pass over the first three configured subjects: one summary row
# and two diagnostic plots per subject, based on a single EDF file each.
eda_reports = []
for subject_id in CONFIG['data']['subjects'][:3]:
    subject_dir = Path(CONFIG['data']['root_dir']) / f'subject_{subject_id}'
    # Sort the matches: glob order depends on the filesystem, so taking the
    # first unsorted entry would make the inspected file non-reproducible.
    edf_files = sorted(subject_dir.glob('*.edf')) if subject_dir.exists() else []
    if not edf_files:
        # Say so explicitly; otherwise a wrong data path yields an empty table with no hint why.
        print(f'[EDA] No EDF files found for {subject_id} in {subject_dir}; skipping.')
        continue
    raw = mne.io.read_raw_edf(edf_files[0], preload=True, verbose=False)
    info = describe_raw(raw)
    # Replace the recording's own subject_info with the notebook's subject identifier.
    info['subject'] = subject_id
    eda_reports.append(info)
    plot_frequency_spectrum(raw, f'{subject_id} Frequency Spectrum')
    plot_voltage_distribution(raw, f'{subject_id} Voltage Distribution')

eda_df = pd.DataFrame(eda_reports)
display(eda_df)

## 6. Preprocessing Pipeline

In [None]:
def preprocess_subject(subject_id, config):
    """Load, filter, resample and epoch every configured run of one subject.

    Args:
        subject_id: subject identifier used both for the directory name
            (``subject_<id>``) and the file prefix (``<id>R<run>.edf``).
        config: notebook configuration; the ``data`` and ``preprocessing``
            sections are read.

    Returns:
        ``(X, y, channel_names)`` where ``X`` is the concatenation of
        ``epochs.get_data()`` over all runs, ``y`` holds the class labels from
        ``config['data']['classes']`` and ``channel_names`` lists the epoched
        channels. Returns ``(None, None, None)`` when no run produced epochs.

    NOTE(review): every configured run is pooled and T1/T2 receive the same
    label regardless of run -- confirm this matches the intended task.
    """
    data_cfg = config['data']
    prep_cfg = config['preprocessing']
    subject_dir = Path(data_cfg['root_dir']) / f'subject_{subject_id}'
    runs = data_cfg['runs']['imagery'] + data_cfg['runs']['execution']
    class_map = data_cfg['classes']

    epochs_list = []
    labels_list = []
    channel_names = None
    for run in runs:
        edf_path = subject_dir / f'{subject_id}R{run:02d}.edf'
        if not edf_path.exists():
            continue
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        raw.notch_filter(prep_cfg['notch_freq'], verbose=False)
        raw.filter(prep_cfg['l_freq'], prep_cfg['h_freq'], verbose=False)
        raw.resample(data_cfg['target_sampling'], verbose=False)
        raw.set_montage('standard_1020', on_missing='ignore', match_case=False)

        # Extract events AFTER resampling: event positions are sample indices,
        # so events taken at the original rate would point at the wrong samples.
        events, found_ids = mne.events_from_annotations(raw, verbose=False)

        # Epoch with the integer codes MNE actually assigned to the wanted
        # annotations, then translate those codes to the configured labels.
        # Passing the label map itself as event_id would look for event codes
        # that do not correspond to T1/T2 in the events array.
        event_id = {name: found_ids[name] for name in class_map if name in found_ids}
        if not event_id:
            continue
        code_to_label = {code: class_map[name] for name, code in event_id.items()}

        picks = mne.pick_types(raw.info, eeg=True, exclude='bads')
        epochs = mne.Epochs(raw, events, event_id=event_id, tmin=data_cfg['tmin'],
                            tmax=data_cfg['tmax'], baseline=data_cfg['baseline'], picks=picks,
                            preload=True, verbose=False)
        epochs_list.append(epochs.get_data())
        labels_list.append(
            np.array([code_to_label[code] for code in epochs.events[:, -1]], dtype=np.int64)
        )
        channel_names = epochs.ch_names

    if not epochs_list:
        return None, None, None
    X = np.concatenate(epochs_list, axis=0)
    y = np.concatenate(labels_list, axis=0)
    return X, y, channel_names

## 7. Dataset Wrapper

In [None]:
class EEGTensorDataset(Dataset):
    """Dataset pairing trials with labels, both converted to tensors up front.

    ``X`` is stored as float32 and ``y`` as int64 (``torch.long``); item
    ``idx`` is the tuple ``(X[idx], y[idx])``.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X, self.y = features, targets

    def __len__(self):
        # The number of samples is defined by the label tensor.
        return len(self.y)

    def __getitem__(self, idx):
        sample = (self.X[idx], self.y[idx])
        return sample

## 8. Model Zoo

In [None]:
class FBCSPNet(nn.Module):
    """2-D CNN: a convolution spanning all channels, then a temporal convolution.

    Args:
        C: number of input channels (height of the first kernel).
        T: number of time samples per trial.
        n_classes: number of output logits.

    ``forward`` takes ``(batch, C, T)`` and inserts the singleton image-channel
    axis itself.
    """

    def __init__(self, C, T, n_classes=2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, (C, 1), padding=0),
            nn.ELU(),
            nn.Conv2d(16, 32, (1, 15), padding=(0, 7)),
            nn.ELU(),
            nn.AvgPool2d((1, 4))
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            # The temporal conv keeps the length T; the pooling divides it by 4.
            nn.Linear(32 * (T // 4), 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv(x)
        return self.fc(x)


class CNNSAENet(nn.Module):
    """1-D convolutional autoencoder with a classifier on the latent code.

    Args:
        C: number of input channels.
        T: number of time samples per trial.
        n_classes: number of output logits.

    ``forward`` returns ``(logits, recon)``. ``recon`` always has the same
    shape as the input so it can be compared to it directly.
    """

    def __init__(self, C, T, n_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(C, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(64, C, kernel_size=4, stride=2, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * (T // 4), 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        # Two floor-halvings followed by two exact doublings give 4 * (T // 4)
        # samples, which is shorter than T whenever T is not a multiple of 4.
        # Resize so a reconstruction loss against x does not fail on shape.
        if recon.shape[-1] != x.shape[-1]:
            recon = F.interpolate(recon, size=x.shape[-1], mode='linear', align_corners=False)
        logits = self.classifier(z)
        return logits, recon


class EEGNet(nn.Module):
    """Temporal conv -> grouped spatial conv -> temporal conv -> linear classifier.

    Args:
        C: number of input channels (height of the spatial kernel).
        T: number of time samples per trial.
        n_classes: number of output logits.

    ``forward`` takes ``(batch, C, T)`` and inserts the singleton image-channel
    axis itself.
    """

    @staticmethod
    def _flattened_time(T):
        """Return the time length that reaches the classifier for input length ``T``."""
        after_first = T + 1              # kernel 64 with padding 32 adds one sample
        after_pool1 = after_first // 4   # AvgPool2d((1, 4))
        after_sep = after_pool1 + 1      # kernel 16 with padding 8 adds one sample
        return after_sep // 8            # AvgPool2d((1, 8))

    def __init__(self, C, T, n_classes=2):
        super().__init__()
        self.first = nn.Sequential(
            nn.Conv2d(1, 16, (1, 64), padding=(0, 32), bias=False),
            nn.BatchNorm2d(16)
        )
        self.depthwise = nn.Sequential(
            nn.Conv2d(16, 32, (C, 1), groups=16, bias=False),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.25)
        )
        self.sep = nn.Sequential(
            nn.Conv2d(32, 32, (1, 16), padding=(0, 8), bias=False),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Size from the real layer arithmetic: the shortcut T // 32 is
            # wrong for some T (e.g. T = 127) because both padded convolutions
            # lengthen the signal by one sample.
            nn.Linear(32 * self._flattened_time(T), n_classes)
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.first(x)
        x = self.depthwise(x)
        x = self.sep(x)
        return self.classifier(x)


class ChannelSE(nn.Module):
    """Squeeze-and-excitation style gate over the channel axis.

    The input is averaged over its last axis, passed through a bottleneck MLP
    ending in a sigmoid, and the resulting per-channel weights rescale the input.

    Args:
        C: size of the channel axis (the second-to-last axis of the input).
        reduction: bottleneck divisor; the hidden width is ``C // reduction``.
    """

    def __init__(self, C, reduction=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, C // reduction),
            nn.ReLU(),
            nn.Linear(C // reduction, C),
            nn.Sigmoid()
        )

    def forward(self, x):
        z = x.mean(-1)
        w = self.net(z)
        return x * w.unsqueeze(-1)


class ACSSECNN(nn.Module):
    """1-D CNN with a ChannelSE gate on its 64 feature maps and global average pooling.

    Args:
        C: number of input channels.
        T: number of time samples (accepted for a uniform constructor
            signature; the adaptive pooling makes the network length-agnostic).
        n_classes: number of output logits.
    """

    def __init__(self, C, T, n_classes=2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(C, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.se = ChannelSE(64)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        h = self.conv(x)
        h = self.se(h)
        h = self.pool(h)
        return self.classifier(h)


class GraphConvLayer(nn.Module):
    """Graph convolution with a learnable ``C x C`` adjacency parameter ``A``.

    ``forward`` takes ``(batch, C, hidden_dim)``, mixes the ``C`` nodes with
    the row-softmax of ``A``, applies a linear map on the feature axis, batch
    normalisation over the ``C`` axis and an ELU.
    """

    def __init__(self, C, hidden_dim):
        super().__init__()
        self.A = nn.Parameter(torch.randn(C, C))
        self.theta = nn.Linear(hidden_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(C)

    def forward(self, x):
        adj = torch.softmax(self.A, dim=-1)
        x = torch.matmul(adj, x)
        x = self.theta(x)
        x = self.bn(x)
        return F.elu(x)


class GCARM(nn.Module):
    """Linear projection of each channel's time series followed by two graph convolutions.

    Args:
        C: number of channels (graph nodes).
        T: number of time samples per trial.
        hidden_dim: per-channel feature width.
        n_classes: number of output logits.
    """

    def __init__(self, C, T, hidden_dim=64, n_classes=2):
        super().__init__()
        self.input = nn.Linear(T, hidden_dim)
        self.gconv1 = GraphConvLayer(C, hidden_dim)
        self.gconv2 = GraphConvLayer(C, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(C * hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        h = self.input(x)
        h = self.gconv1(h)
        h = self.gconv2(h)
        return self.classifier(h)


class EEGARNN(nn.Module):
    """Bidirectional GRU followed by one graph convolution and an MLP classifier.

    The GRU has input size ``T`` and runs batch-first over ``(batch, C, T)``,
    so the channel axis is the sequence axis and each step sees one channel's
    full time series.

    Args:
        C: number of channels.
        T: number of time samples per trial.
        hidden_dim: GRU hidden size per direction.
        n_classes: number of output logits.
    """

    def __init__(self, C, T, hidden_dim=64, n_classes=2):
        super().__init__()
        self.temporal = nn.GRU(T, hidden_dim, batch_first=True, bidirectional=True)
        self.graph = GraphConvLayer(C, hidden_dim * 2)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(C * hidden_dim * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        h, _ = self.temporal(x)
        h = self.graph(h)
        return self.classifier(h)


class GatedEEGARNN(EEGARNN):
    """EEGARNN preceded by a learned per-channel gate.

    The gate for each channel is ``sigmoid(gate_net([mean, std]) + gate_bias)``,
    where mean and std are taken over the time axis. ``gate_bias`` is the logit
    of ``gate_init``, so a zero ``gate_net`` output yields a gate equal to
    ``gate_init``.

    Args:
        C, T, hidden_dim, n_classes: forwarded to ``EEGARNN``.
        gate_init: gate value corresponding to a zero ``gate_net`` output;
            must lie strictly between 0 and 1 for the logit to be defined.
    """

    def __init__(self, C, T, hidden_dim=64, n_classes=2, gate_init=0.9):
        super().__init__(C, T, hidden_dim, n_classes)
        self.gate_net = nn.Sequential(
            nn.Linear(C * 2, C),
            nn.ReLU(),
            nn.Linear(C, C)
        )
        self.gate_bias = math.log(gate_init / (1 - gate_init))

    def forward(self, x):
        mean = x.mean(-1)
        std = x.std(-1)
        g_in = torch.cat([mean, std], dim=-1)
        gates = torch.sigmoid(self.gate_net(g_in) + self.gate_bias)
        x = x * gates.unsqueeze(-1)
        return super().forward(x)

## 9. Training Utilities

In [None]:
# Maps a display name to the TorchModelWrapper that trains that architecture.
MODEL_REGISTRY = {}


class TorchModelWrapper:
    """Cross-validated trainer around a model factory.

    Args:
        name: display name of the model.
        builder: callable ``builder(C, T)`` returning a fresh ``nn.Module``.
        is_autoencoder: set when the model's forward returns
            ``(logits, reconstruction)`` instead of logits only.
    """

    def __init__(self, name, builder, is_autoencoder=False):
        self.name = name
        self.builder = builder
        self.is_autoencoder = is_autoencoder

    def _logits(self, outputs):
        """Extract the classification logits from a forward-pass result."""
        return outputs[0] if self.is_autoencoder else outputs

    def train(self, X, y, channel_names, config, device):
        """Train one fresh model per stratified fold and score it on the held-out split.

        Args:
            X: array of trials; axis 1 is read as the channel count and axis 2
                as the number of time samples.
            y: integer class label per trial.
            channel_names: accepted for interface symmetry; not used here.
            config: notebook configuration (``model`` section and ``random_seed``).
            device: torch device the model and batches are moved to.

        Returns:
            ``(fold_metrics, model)``: one dict per fold with ``accuracy``,
            ``precision``, ``recall`` and ``f1`` measured after the final
            epoch, plus the model trained on the last fold.
        """
        C, T = X.shape[1], X.shape[2]
        model_cfg = config['model']
        dataset = EEGTensorDataset(X, y)
        skf = StratifiedKFold(n_splits=model_cfg['n_folds'], shuffle=True,
                              random_state=config['random_seed'])
        if USE_TPU:
            import torch_xla.core.xla_model as xm

        fold_metrics = []
        model = None
        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            train_loader = DataLoader(torch.utils.data.Subset(dataset, train_idx),
                                      batch_size=model_cfg['batch_size'], shuffle=True)
            val_loader = DataLoader(torch.utils.data.Subset(dataset, val_idx),
                                    batch_size=model_cfg['batch_size'], shuffle=False)
            model = self.builder(C, T).to(device)
            optimizer = optim.Adam(model.parameters(), lr=model_cfg['learning_rate'],
                                   weight_decay=model_cfg['weight_decay'])
            criterion = nn.CrossEntropyLoss()

            for epoch in range(model_cfg['epochs']):
                model.train()
                for xb, yb in train_loader:
                    xb, yb = xb.to(device), yb.to(device)
                    optimizer.zero_grad()
                    outputs = model(xb)
                    loss = criterion(self._logits(outputs), yb)
                    if self.is_autoencoder:
                        # Add the reconstruction error to the classification loss.
                        loss = loss + F.mse_loss(outputs[1], xb)
                    # Gradients must be computed on every backend; only the
                    # step call differs between TPU and CPU/GPU.
                    loss.backward()
                    if USE_TPU:
                        xm.optimizer_step(optimizer, barrier=True)
                    else:
                        optimizer.step()

            # Score the fold once, after training has finished, so the list
            # holds exactly one entry per fold and averages are not diluted by
            # early, undertrained epochs.
            model.eval()
            y_true, y_pred = [], []
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb = xb.to(device)
                    # Autoencoders return a tuple; argmax needs the logits only.
                    logits = self._logits(model(xb))
                    preds = torch.argmax(logits, dim=-1).cpu().numpy()
                    y_true.append(yb.numpy())
                    y_pred.append(preds)
            y_true = np.concatenate(y_true)
            y_pred = np.concatenate(y_pred)
            fold_metrics.append({
                'accuracy': accuracy_score(y_true, y_pred),
                'precision': precision_score(y_true, y_pred, average='weighted', zero_division=0),
                'recall': recall_score(y_true, y_pred, average='weighted', zero_division=0),
                'f1': f1_score(y_true, y_pred, average='weighted', zero_division=0)
            })
        return fold_metrics, model


def register_models():
    """Populate MODEL_REGISTRY with one wrapper per benchmarked architecture."""
    MODEL_REGISTRY['FBCSP'] = TorchModelWrapper('FBCSP', lambda C, T: FBCSPNet(C, T))
    MODEL_REGISTRY['CNN-SAE'] = TorchModelWrapper('CNN-SAE', lambda C, T: CNNSAENet(C, T), is_autoencoder=True)
    MODEL_REGISTRY['EEGNet'] = TorchModelWrapper('EEGNet', lambda C, T: EEGNet(C, T))
    MODEL_REGISTRY['ACS-SE-CNN'] = TorchModelWrapper('ACS-SE-CNN', lambda C, T: ACSSECNN(C, T))
    MODEL_REGISTRY['G-CARM'] = TorchModelWrapper('G-CARM', lambda C, T: GCARM(C, T))
    MODEL_REGISTRY['EEG-ARNN'] = TorchModelWrapper('EEG-ARNN', lambda C, T: EEGARNN(C, T))
    MODEL_REGISTRY['Gated EEG-ARNN'] = TorchModelWrapper(
        'Gated EEG-ARNN',
        lambda C, T: GatedEEGARNN(C, T, gate_init=CONFIG['model']['gate_init'])
    )


register_models()

## 10. Channel Selection

In [None]:
class ChannelSelector:
    """Rank channels by one of three scores and return the top ``k`` names.

    Args:
        adjacency: square matrix with one row per channel.
        channel_names: channel names, as a list or array, in the same order as
            the rows of ``adjacency``.
        gates: optional score per channel used by ``gate_selection``.
    """

    def __init__(self, adjacency, channel_names, gates=None):
        self.adj = adjacency
        self.names = channel_names
        self.gates = gates

    def _top_k(self, scores, k):
        """Return the names of the ``k`` highest-scoring channels, best first."""
        order = np.argsort(scores)[::-1][:k]
        # A plain Python list cannot be indexed with an index array, so go
        # through an ndarray; this also keeps array inputs working as before.
        return np.asarray(self.names)[order]

    def edge_selection(self, k):
        """Rank channels by the sum of absolute adjacency weights in their row."""
        scores = np.abs(self.adj).sum(axis=1)
        return self._top_k(scores, k)

    def aggregation_selection(self, activations, k):
        """Rank channels by the mean absolute activation over the last axis."""
        scores = np.abs(activations).mean(axis=-1)
        return self._top_k(scores, k)

    def gate_selection(self, k):
        """Rank channels by their gate score.

        Raises:
            ValueError: if the selector was built without gate values.
        """
        if self.gates is None:
            raise ValueError('Gate values not provided')
        return self._top_k(self.gates, k)

## 11. Retention Analysis

In [None]:
def evaluate_retention(model_wrapper, X, y, channel_names, selector, k_values, config, device):
    """Retrain on the top-``k`` channels for every ``k`` and tabulate mean accuracy.

    Channels are ranked with ``selector.gate_selection`` when the selector
    carries gate values, otherwise with ``selector.edge_selection``.

    Returns:
        DataFrame with one row per ``k`` and the columns ``k`` and ``accuracy``
        (the mean of the ``accuracy`` entries returned by ``model_wrapper.train``).
    """
    def _rank(k):
        if selector.gates is None:
            return selector.edge_selection(k)
        return selector.gate_selection(k)

    rows = []
    for k in k_values:
        kept = _rank(k)
        # Positions of the kept channels along axis 1 of X.
        positions = list(map(channel_names.index, kept))
        scores, _ = model_wrapper.train(X[:, positions, :], y, kept, config, device)
        mean_accuracy = np.mean([entry['accuracy'] for entry in scores])
        rows.append({'k': k, 'accuracy': mean_accuracy})
    return pd.DataFrame(rows)

## 12. Master Training Loop

In [None]:
# Train every registered model on every subject and collect averaged metrics.
# For the gated model, additionally rank channels and run the retention analysis.
all_subject_results = []
retention_curves = {}
for subject_id in CONFIG['data']['subjects']:
    print(f"Processing {subject_id}")
    X, y, channel_names = preprocess_subject(subject_id, CONFIG)
    if X is None:
        continue
    subject_metrics = {'subject': subject_id}
    for model_name, wrapper in MODEL_REGISTRY.items():
        metrics, trained_model = wrapper.train(X, y, channel_names, CONFIG, DEVICE)
        avg_metrics = {
            f'{model_name}_acc': np.mean([m['accuracy'] for m in metrics]),
            f'{model_name}_f1': np.mean([m['f1'] for m in metrics])
        }
        subject_metrics.update(avg_metrics)
        if model_name == 'Gated EEG-ARNN':
            adjacency = trained_model.graph.A.detach().cpu().numpy()
            # The gate network's input must match the model's dtype and device.
            # X is converted to float32 explicitly, the statistics are moved to
            # DEVICE, and the result is brought back to the CPU before the
            # NumPy conversion.
            X_tensor = torch.as_tensor(X, dtype=torch.float32)
            gate_inputs = torch.cat([X_tensor.mean(-1), X_tensor.std(-1)], dim=-1).to(DEVICE)
            with torch.no_grad():
                # Raw gate_net outputs (before bias and sigmoid), averaged over trials.
                gates = trained_model.gate_net(gate_inputs).mean(0).cpu().numpy()
            selector = ChannelSelector(adjacency, channel_names, gates)
            retention_df = evaluate_retention(wrapper, X, y, channel_names, selector,
                                              CONFIG['channel_selection']['k_values'], CONFIG, DEVICE)
            retention_df['subject'] = subject_id
            retention_curves[subject_id] = retention_df
    all_subject_results.append(subject_metrics)

## 13. Results Summary

In [None]:
# One row per subject, one column per model metric.
results_df = pd.DataFrame(all_subject_results)
display(results_df.head())

# Average each metric column across subjects. numeric_only=True leaves out the
# string 'subject' column, which otherwise makes mean() raise on pandas >= 2.0.
# The resulting frame covers both the *_acc and the *_f1 columns.
summary = results_df.mean(numeric_only=True).to_frame('mean_accuracy')
display(summary)

## 14. Retention Visualization

In [None]:
# Stack the per-subject retention tables and plot accuracy against channel count.
if retention_curves:
    retention_df = pd.concat(list(retention_curves.values()), ignore_index=True)
    fig = px.line(retention_df, x='k', y='accuracy', color='subject', markers=True,
                  title='Retention vs Channels (Gated EEG-ARNN)')
    fig.show()
else:
    # pd.concat raises on an empty collection. Keep an empty, correctly shaped
    # frame so the cell that saves retention_df still runs.
    retention_df = pd.DataFrame(columns=['k', 'accuracy', 'subject'])
    print('No retention curves were produced; skipping the retention plot.')

## 15. Save Artifacts

In [None]:
# Persist both result tables as CSV (without the index column).
for csv_name, table in (('subject_metrics.csv', results_df),
                        ('retention_curves.csv', retention_df)):
    table.to_csv(csv_name, index=False)

# Store the exact configuration used for this run next to the results.
with open('config_used.json', 'w') as f:
    f.write(json.dumps(CONFIG, indent=2))

print("Artifacts saved: subject_metrics.csv, retention_curves.csv, config_used.json")