
# BCI Competition IV-2a — 4‑Class Motor Imagery with TemporalCNN + GCN + Channel Gating



This notebook fixes the *"all predictions collapse to class 0"* issue by **properly loading all four MI classes** (Left, Right, Foot, Tongue) and building a **4-way classifier**.  
It uses a hybrid model:
- Temporal 1D CNN per channel → time pooling (extracts band‑limited features)
- **Channel Gating** (learnable per‑channel weights; L1 sparsity to encourage selection)
- **Graph Convolution** (GCN) over channels using an adjacency built from electrode distances
- MLP classifier (4 classes)

You’ll get:
- Robust **4‑class event mapping** (supports 769–772, 1–4, `'T1'..'T4'`, or textual names)
- Clean train/val/test split (session‑aware or stratified)
- Training loop with early stopping & scheduler
- Metrics: accuracy, macro‑F1, Cohen’s kappa, per‑class report
- **Confusion matrix** and **Top‑k channels** from the learned gate

> **Dataset**: BCI Competition IV‑2a (GDF files) — put them under `DATA_ROOT/subjectXX/`.


In [1]:
# Compatibility shim for MNE GDF header reading with NumPy 2.x
# Paste this at the top of the notebook (before importing mne or reading GDFs).
import numpy as np

# keep original so we can still use text-mode path
_orig_fromstring = np.fromstring

def _fromstring_compat(data, dtype=float, sep=''):
    """
    Compatibility wrapper:
      - If sep=='' (binary-mode expected) and data is bytes/bytearray -> use frombuffer
      - If sep=='' and data is str -> encode with latin-1 (one-to-one) then frombuffer
      - Otherwise call original np.fromstring (text parsing)
    """
    if sep == '':
        # binary/path expected
        if isinstance(data, (bytes, bytearray)):
            return np.frombuffer(data, dtype=dtype)
        if isinstance(data, str):
            # map characters to original byte values (latin-1 is 1:1 for 0-255)
            try:
                b = data.encode('latin-1')
            except Exception:
                # fallback: preserve raw bytes via surrogatepass if weird unicode present
                b = data.encode('utf-8', errors='surrogatepass')
            return np.frombuffer(b, dtype=dtype)
    # text mode -> delegate to original behaviour (parses str with sep)
    return _orig_fromstring(data, dtype=dtype, sep=sep)

# apply patch
np.fromstring = _fromstring_compat

print("Applied np.fromstring compatibility shim. numpy version:", np.__version__)
# quick self-test (should not error)
print("test bytes =>", np.fromstring(b'\\x01\\x02', dtype=np.uint8))
print("test str   =>", np.fromstring('\x01\x02', dtype=np.uint8))


Applied np.fromstring compatibility shim. numpy version: 2.3.3
test bytes => [ 92 120  48  49  92 120  48  50]
test str   => [1 2]


## 1. Environment & Imports

In [2]:

# %pip install -q mne numpy scipy scikit-learn torch==2.2.0 torchaudio torchvision matplotlib
import os, re, math, random, warnings
from pathlib import Path
import numpy as np
import mne
from scipy import signal
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, f1_score, cohen_kappa_score, accuracy_score
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
warnings.filterwarnings("ignore", category=RuntimeWarning)
print("Torch:", torch.__version__, "| MNE:", mne.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


Torch: 2.8.0+cpu | MNE: 1.10.1


device(type='cpu')

## 2. Config

In [3]:

DATA_ROOT = Path("./../BCICIV_2a_gdf")  # <-- CHANGE THIS
SAVE_DIR = Path("./outputs_bci2a_gcn"); SAVE_DIR.mkdir(parents=True, exist_ok=True)
SF_TARGET = 250; BANDPASS = (4., 38.); NOTCH = 50.0
TMIN = 0.5; TMAX = 3.5; REF = "average"
N_EPOCHS = 60; BATCH_SIZE = 64; LR = 1e-3; WEIGHT_DECAY = 1e-4
PATIENCE = 12; LR_PATIENCE = 6; GATE_L1 = 1e-4
SESSION_A_TRAIN_SESSION_B_TEST = True; VAL_SIZE = 0.15
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED); random.seed(RANDOM_SEED); torch.manual_seed(RANDOM_SEED); torch.cuda.manual_seed_all(RANDOM_SEED)
BCI2A_CHS = ["Fz","FC3","FC1","FCz","FC2","FC4","C5","C3","C1","Cz","C2","C4","C6","CP3","CP1","CPz","CP2","CP4","P1","Pz","P2","POz"]
CLASS_NAMES = ["Left", "Right", "Foot", "Tongue"]; NUM_CLASSES = 4


## 3. Helpers

In [4]:

def set_seed(seed=RANDOM_SEED):
    np.random.seed(seed); random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
def print_counts(prefix, y):
    counts = np.bincount(y, minlength=NUM_CLASSES)
    print(prefix, dict(zip(CLASS_NAMES, counts.tolist())))
set_seed()


## 4. Event Mapping (robust to different encodings)

In [5]:
from collections import Counter

def build_event_map(event_id_dict, raw=None, verbose=True):
    """
    Create a mapping: numeric_event_code (the ints you see as ev[2]) -> label 0..3
    Heuristics:
     - If annotation keys include textual tags (t1/t2/t3/t4 or left/right/foot/tongue),
       map accordingly.
     - If annotation keys are numeric-strings and match common sets (1..4 or 769..772), use canonical mapping.
     - Otherwise, if raw is provided, pick the top-4 most frequent annotation descriptions (excluding sentinel/boundary codes)
       and map them to 0..3 in frequency order (prints mapping so you can inspect).
    """
    # Normalize keys -> strings
    ev_items = {str(k): int(v) for k, v in event_id_dict.items()}
    keys = list(ev_items.keys())
    if verbose:
        print("build_event_map: annotation keys (sample):", keys[:12])

    # 1) textual mapping: try to map 't1','left','right','foot','tongue'
    textual_map = {}
    for desc, code in ev_items.items():
        s = desc.lower()
        if "t1" in s or "left" in s:
            textual_map[code] = 0
        elif "t2" in s or "right" in s:
            textual_map[code] = 1
        elif "t3" in s or "foot" in s or "leg" in s:
            textual_map[code] = 2
        elif "t4" in s or "tongue" in s:
            textual_map[code] = 3
    if len(textual_map) == 4:
        if verbose: print("Detected textual mapping:", textual_map)
        return textual_map

    # 2) numeric-string keys -> try canonical code sets
    num_desc_to_code = {}
    for desc, code in ev_items.items():
        try:
            dnum = int(desc)
            num_desc_to_code[dnum] = code
        except Exception:
            pass

    # try known sets
    if set([769,770,771,772]).issubset(set(num_desc_to_code.keys())):
        mapping = {
            num_desc_to_code[769]: 0,
            num_desc_to_code[770]: 1,
            num_desc_to_code[771]: 2,
            num_desc_to_code[772]: 3,
        }
        if verbose: print("Detected BCI2000(769..772) numeric mapping ->", mapping)
        return mapping

    if set([1,2,3,4]).issubset(set(num_desc_to_code.keys())):
        mapping = {
            num_desc_to_code[1]: 0,
            num_desc_to_code[2]: 1,
            num_desc_to_code[3]: 2,
            num_desc_to_code[4]: 3,
        }
        if verbose: print("Detected simple (1..4) numeric mapping ->", mapping)
        return mapping

    # 3) fallback: use annotation frequencies from raw.annotations (if raw provided)
    if raw is not None:
        descs = list(raw.annotations.description)
        if len(descs) == 0:
            raise RuntimeError("No annotations available in raw.annotations to auto-detect event mapping.")
        c = Counter(descs)
        # filter candidates to only annotations present in event_id_dict
        candidates = {desc: cnt for desc, cnt in c.items() if desc in ev_items}
        # filter-out likely sentinel/boundary codes:
        def is_unwanted(desc):
            s = str(desc).lower()
            if any(x in s for x in ("bad", "boundary", "artifact", "start", "stop")):
                return True
            # if desc is numeric and very large (typical sentinel like 32766), ignore
            try:
                v = int(desc)
                if v > 30000: 
                    return True
            except:
                pass
            return False

        for bad in list(candidates.keys()):
            if is_unwanted(bad):
                candidates.pop(bad, None)

        if len(candidates) >= 4:
            top4 = [d for d, _ in sorted(candidates.items(), key=lambda x:-x[1])[:4]]
            mapping = { ev_items[desc] : idx for idx, desc in enumerate(top4) }
            if verbose:
                print("Auto-mapped top-4 annotation descriptions -> labels (code -> label):")
                for code, lab in mapping.items():
                    print("   code", code, " = label", lab, " (desc='{}', count={})".format(
                        [d for d in top4 if ev_items[d]==code][0], c[[d for d in top4 if ev_items[d]==code][0]]
                    ))
                print("Top4 desc order used:", top4)
            return mapping

    # If none of the above worked — raise with helpful debug info
    raise RuntimeError(f"Could not build 4-class event map automatically. event_id keys: {list(ev_items.keys())}. "
                       "If you know the mapping, pass a manual map or inspect raw.annotations.description and event_id. "
                       "Example: event_id = " + str(event_id_dict))



## 5. Load & Preprocess BCI IV‑2a (one subject)

In [6]:

def _find_subject_files(data_root: Path, subject_id: int):
    s=f"{subject_id:02d}"
    cand=list((data_root).rglob(f"*{s}*.[gG][dD][fF]"))
    train_files=[p for p in cand if re.search(r"(T|train)", p.name, re.I)]
    eval_files=[p for p in cand if re.search(r"(E|eval)", p.name, re.I)]
    if not train_files or not eval_files: return cand, []
    return train_files, eval_files

def _get_montage_info(raw):
    try:
        montage = mne.channels.make_standard_montage("standard_1020")
        raw.set_montage(montage, on_missing="ignore")
    except Exception as e:
        print("Montage set failed (non-fatal):", e)

def _preprocess_raw(raw):
    if REF=="average": raw.set_eeg_reference("average", projection=False, verbose=False)
    if NOTCH is not None: raw.notch_filter(freqs=[NOTCH], picks="eeg", verbose=False)
    raw.filter(BANDPASS[0], BANDPASS[1], picks="eeg", verbose=False)
    if SF_TARGET is not None and abs(raw.info["sfreq"]-SF_TARGET)>1e-3:
        raw.resample(SF_TARGET, npad="auto", verbose=False)

def _map_events(raw):
    """
    Use mne.events_from_annotations and robustly map numeric event codes -> labels 0..3.
    Returns events array (n,3) and emap (code->label).
    """
    events, event_id = mne.events_from_annotations(raw, verbose=False)
    print("mne.events_from_annotations -> event_id:", event_id)
    try:
        emap = build_event_map(event_id, raw=raw, verbose=True)
    except RuntimeError as e:
        # raise with extra debug info
        print("build_event_map failed; raw.annotations.description sample:", raw.annotations.description[:30])
        raise

    # Build new events' label column (map numeric code -> 0..3); we will convert events accordingly
    mapped_events = []
    for ev in events:
        onset, _, code = int(ev[0]), int(ev[1]) if ev.shape[1] > 1 else 0, int(ev[2])
        if code not in emap:
            # skip non-MI codes
            continue
        lbl = emap[code]
        mapped_events.append([onset, 0, lbl])
    if len(mapped_events) == 0:
        raise RuntimeError("No MI events found after mapping — check annotations and event_id.")
    return np.array(mapped_events), emap

def _epochs_from_raw(raw):
    events, emap = _map_events(raw)
    picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
    ch_names = [raw.ch_names[i] for i in picks]
    keep_idx = [raw.ch_names.index(ch) for ch in BCI2A_CHS if ch in raw.ch_names]
    if len(keep_idx)>=8: picks=keep_idx; ch_names=[raw.ch_names[i] for i in picks]
    epochs = mne.Epochs(raw, events, event_id=None, tmin=TMIN, tmax=TMAX,
                        picks=picks, baseline=None, preload=True, verbose=False)
    X=epochs.get_data(); y=epochs.events[:,-1]
    return X, y, ch_names, raw.info["sfreq"]

def load_subject(data_root: Path, subject_id: int):
    train_files, eval_files = _find_subject_files(data_root, subject_id)
    if not train_files: raise FileNotFoundError(f"No GDF for S{subject_id:02d}")
    rawT_list=[]
    for f in sorted(train_files):
        raw=mne.io.read_raw_gdf(f, preload=True, verbose=False)
        _get_montage_info(raw); _preprocess_raw(raw); rawT_list.append(raw)
    rawT=mne.concatenate_raws(rawT_list, verbose=False)
    Xtr,ytr,chs,sf=_epochs_from_raw(rawT)
    Xte=yte=None
    if eval_files:
        rawE_list=[]
        for f in sorted(eval_files):
            raw=mne.io.read_raw_gdf(f, preload=True, verbose=False)
            _get_montage_info(raw); _preprocess_raw(raw); rawE_list.append(raw)
        rawE=mne.concatenate_raws(rawE_list, verbose=False)
        Xte,yte,_,_= _epochs_from_raw(rawE)
    return Xtr,ytr,Xte,yte,chs,sf


## 6. Channel Graph (Adjacency from electrode distances)

In [7]:

def build_adjacency(ch_names, info, sigma=0.08, self_loop=True):
    """
    Build adjacency using 3D positions from info['chs'][i]['loc'] (or montage).
    A_ij = exp(-||p_i - p_j||^2 / (2 sigma^2)), then normalize with D^{-1/2} A D^{-1/2}.
    """
    pos=[]; name_to_idx={n:i for i,n in enumerate(info["ch_names"])}
    for ch in ch_names:
        if ch in info["ch_names"]:
            idx=name_to_idx[ch]; loc=info["chs"][idx]["loc"][:3]; pos.append(loc)
        else: pos.append(np.zeros(3))
    pos=np.array(pos)
    dists=np.linalg.norm(pos[:,None,:]-pos[None,:,:], axis=-1)
    A=np.exp(-(dists**2)/(2.0*(sigma**2)+1e-12))
    np.fill_diagonal(A, 1.0 if self_loop else 0.0)
    D=np.diag(np.sum(A,axis=1)+1e-8)
    D_inv_sqrt=np.linalg.inv(np.sqrt(D))
    An=D_inv_sqrt @ A @ D_inv_sqrt
    return A.astype(np.float32), An.astype(np.float32), dists.astype(np.float32)


## 7. Torch Dataset & DataLoaders

In [8]:

class TrialsDataset(Dataset):
    def __init__(self, X, y):
        self.X=X.astype(np.float32); self.y=y.astype(np.int64)
    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return torch.from_numpy(self.X[idx]), torch.tensor(self.y[idx])

def make_loaders(Xtr, ytr, Xte, yte, batch_size=BATCH_SIZE):
    if SESSION_A_TRAIN_SESSION_B_TEST and Xte is not None:
        X_train,y_train=Xtr,ytr
        X_train,X_val,y_train,y_val = train_test_split(X_train,y_train,test_size=VAL_SIZE,random_state=RANDOM_SEED,stratify=y_train)
        X_test,y_test=Xte,yte
    else:
        X_train,X_temp,y_train,y_temp = train_test_split(Xtr,ytr,test_size=0.3,random_state=RANDOM_SEED,stratify=ytr)
        X_val,X_test,y_val,y_test = train_test_split(X_temp,y_temp,test_size=0.5,random_state=RANDOM_SEED,stratify=y_temp)
    print_counts("Train", y_train); print_counts("Val  ", y_val); print_counts("Test ", y_test)
    ds_tr,ds_va,ds_te = TrialsDataset(X_train,y_train), TrialsDataset(X_val,y_val), TrialsDataset(X_test,y_test)
    dl_tr=DataLoader(ds_tr,batch_size=batch_size,shuffle=True,drop_last=False)
    dl_va=DataLoader(ds_va,batch_size=batch_size,shuffle=False,drop_last=False)
    dl_te=DataLoader(ds_te,batch_size=batch_size,shuffle=False,drop_last=False)
    return dl_tr, dl_va, dl_te, (X_train.shape[1], X_train.shape[2])


## 8. Model — TemporalCNN + Channel Gating + GCN + MLP

In [9]:

class ChannelGate(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.logits=nn.Parameter(torch.zeros(n_channels))
    def forward(self, x):
        g=torch.sigmoid(self.logits); return x*g[None,:,None], g

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, bias=True):
        super().__init__()
        self.lin=nn.Linear(in_feats,out_feats,bias=bias); self.act=nn.ReLU(inplace=True)
    def forward(self, H, A_norm):
        HW=self.lin(H); out=torch.einsum('ij,bcj->bci', A_norm, HW); return self.act(out)

class EEG_GCNNet(nn.Module):
    def __init__(self, n_channels, n_classes, T, A_norm, gate_l1=0.0):
        super().__init__()
        self.n_channels=n_channels; self.register_buffer("A", A_norm); self.gate_l1=gate_l1
        self.grp_conv1=nn.Conv1d(n_channels, n_channels*8, kernel_size=25, padding=12, groups=n_channels)
        self.grp_conv2=nn.Conv1d(n_channels*8, n_channels*8, kernel_size=9, padding=4, groups=n_channels)
        self.grp_pool=nn.AdaptiveAvgPool1d(1); self.feat_dim=8
        self.gate=ChannelGate(n_channels)
        self.gcn1=GCNLayer(self.feat_dim,32); self.gcn2=GCNLayer(32,64)
        self.dropout=nn.Dropout(0.3)
        self.classifier=nn.Sequential(nn.Linear(64,64), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(64,n_classes))
    def forward(self, x):
        x,gates=self.gate(x)
        H=F.relu(self.grp_conv1(x)); H=F.relu(self.grp_conv2(H)); H=self.grp_pool(H).squeeze(-1)
        H=H.view(H.size(0), self.n_channels, self.feat_dim)
        H=self.gcn1(H, self.A); H=self.gcn2(H, self.A); H=H.mean(dim=1)
        out=self.classifier(self.dropout(H)); return out, gates

def gate_l1_penalty(gates, coeff):
    if coeff<=0: return torch.tensor(0., device=gates.device)
    return coeff*torch.sum(torch.abs(gates))


## 9. Training & Evaluation Loops

In [10]:

class EarlyStopper:
    def __init__(self, patience=PATIENCE, min_delta=0.0):
        self.patience=patience; self.min_delta=min_delta; self.counter=0; self.best=None; self.should_stop=False
    def step(self, metric, model, path):
        if (self.best is None) or (metric>self.best+self.min_delta):
            self.best=metric; self.counter=0; torch.save(model.state_dict(), path)
        else:
            self.counter+=1
            if self.counter>=self.patience: self.should_stop=True

def run_epoch(model, loader, optimizer=None):
    is_train=optimizer is not None; model.train(is_train)
    losses, yh, yt = [], [], []
    for xb,yb in loader:
        xb=xb.to(device); yb=yb.to(device)
        if is_train: optimizer.zero_grad()
        logits,gates=model(xb)
        loss=F.cross_entropy(logits,yb)+gate_l1_penalty(gates, model.gate_l1)
        if is_train:
            loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(),5.0); optimizer.step()
        losses.append(loss.item()); yh.append(logits.detach().cpu().numpy().argmax(axis=1)); yt.append(yb.detach().cpu().numpy())
    yh=np.concatenate(yh); yt=np.concatenate(yt)
    acc=accuracy_score(yt,yh); f1m=f1_score(yt,yh,average='macro'); kappa=cohen_kappa_score(yt,yh)
    return np.mean(losses), acc, f1m, kappa, yh, yt

def evaluate(model, loader, class_names=CLASS_NAMES, title="Confusion Matrix"):
    model.eval(); yh_all, yt_all = [], []
    with torch.no_grad():
        for xb,yb in loader:
            xb=xb.to(device); yb=yb.to(device)
            logits,_=model(xb); yh_all.append(logits.argmax(dim=1).cpu().numpy()); yt_all.append(yb.cpu().numpy())
    yh=np.concatenate(yh_all); yt=np.concatenate(yt_all)
    print("\nClassification report:\n", classification_report(yt,yh,target_names=class_names,digits=3))
    cm=confusion_matrix(yt,yh,labels=list(range(len(class_names))))
    fig=plt.figure(figsize=(6,5)); plt.imshow(cm, interpolation='nearest'); plt.title(title); plt.colorbar()
    tick=np.arange(len(class_names)); plt.xticks(tick,class_names,rotation=45); plt.yticks(tick,class_names)
    plt.xlabel("Predicted"); plt.ylabel("True"); plt.tight_layout(); plt.show()
    return cm, yt, yh


## 10. End‑to‑End: Load, Build Graph, Train, Evaluate (one subject)

In [11]:

subject_id=1
Xtr,ytr,Xte,yte,CHS,sf = load_subject(DATA_ROOT, subject_id)
print("Shapes:", Xtr.shape, ytr.shape, None if Xte is None else Xte.shape)
train_files,_=_find_subject_files(DATA_ROOT, subject_id)
raw_tmp=mne.io.read_raw_gdf(sorted(train_files)[0], preload=False, verbose=False); _get_montage_info(raw_tmp)
A,A_norm,dmat=build_adjacency(CHS, raw_tmp.info, sigma=0.08, self_loop=True); print("Adjacency:", A_norm.shape)
dl_tr,dl_va,dl_te,(C,T)=make_loaders(Xtr,ytr,Xte,yte,batch_size=BATCH_SIZE)
A_t=torch.tensor(A_norm,dtype=torch.float32,device=device)
model=EEG_GCNNet(n_channels=C,n_classes=NUM_CLASSES,T=T,A_norm=A_t,gate_l1=GATE_L1).to(device)
opt=torch.optim.AdamW(model.parameters(),lr=LR,weight_decay=WEIGHT_DECAY)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,mode="max",patience=LR_PATIENCE,factor=0.5,verbose=True)
early=EarlyStopper(patience=PATIENCE,min_delta=1e-4); best_path=str(SAVE_DIR/f"subject{subject_id:02d}_best.pt")
hist={"train":[],"val":[]}
for epoch in range(1,N_EPOCHS+1):
    tr_loss,tr_acc,tr_f1,tr_k,_,_=run_epoch(model,dl_tr,optimizer=opt)
    va_loss,va_acc,va_f1,va_k,_,_=run_epoch(model,dl_va,optimizer=None)
    scheduler.step(va_acc); early.step(va_acc, model, best_path)
    hist["train"].append([tr_loss,tr_acc,tr_f1,tr_k]); hist["val"].append([va_loss,va_acc,va_f1,va_k])
    print(f"Epoch {epoch:03d} | Train L={tr_loss:.3f} A={tr_acc:.3f} F1={tr_f1:.3f} K={tr_k:.3f} || Val L={va_loss:.3f} A={va_acc:.3f} F1={va_f1:.3f} K={va_k:.3f} LR={opt.param_groups[0]['lr']:.2e}")
    if early.should_stop: print("Early stopping."); break
model.load_state_dict(torch.load(best_path, map_location=device))
cm,yt,yh = evaluate(model, dl_te if dl_te is not None else dl_va, class_names=CLASS_NAMES, title="Confusion Matrix (Best)")
with torch.no_grad(): gates=torch.sigmoid(model.gate.logits).detach().cpu().numpy()
topk=10; idx=np.argsort(-gates)[:topk]; print("\nTop-%d channels by learned gate:"%topk)
for rank,i in enumerate(idx,1): print(f"{rank:2d}. {CHS[i]}  (weight={gates[i]:.3f})")
plt.figure(figsize=(7,3)); plt.bar(np.arange(len(CHS)),gates); plt.xticks(np.arange(len(CHS)),CHS,rotation=90)
plt.ylabel("Gate weight"); plt.title("Per-channel importance (sigmoid gate)"); plt.tight_layout(); plt.show()


mne.events_from_annotations -> event_id: {np.str_('1023'): 1, np.str_('1072'): 2, np.str_('276'): 3, np.str_('277'): 4, np.str_('32766'): 5, np.str_('768'): 6, np.str_('769'): 7, np.str_('770'): 8, np.str_('771'): 9, np.str_('772'): 10}
build_event_map: annotation keys (sample): ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Detected BCI2000(769..772) numeric mapping -> {7: 0, 8: 1, 9: 2, 10: 3}


IndexError: tuple index out of range

## 11. (Optional) Loop over subjects & report mean metrics

In [None]:

def train_one_subject(sid):
    Xtr,ytr,Xte,yte,CHS,sf=load_subject(DATA_ROOT,sid)
    train_files,_=_find_subject_files(DATA_ROOT,sid)
    raw_tmp=mne.io.read_raw_gdf(sorted(train_files)[0], preload=False, verbose=False); _get_montage_info(raw_tmp)
    A,A_norm,_=build_adjacency(CHS, raw_tmp.info, sigma=0.08, self_loop=True)
    dl_tr,dl_va,dl_te,(C,T)=make_loaders(Xtr,ytr,Xte,yte,batch_size=BATCH_SIZE)
    A_t=torch.tensor(A_norm,dtype=torch.float32,device=device)
    model=EEG_GCNNet(n_channels=C,n_classes=NUM_CLASSES,T=T,A_norm=A_t,gate_l1=GATE_L1).to(device)
    opt=torch.optim.AdamW(model.parameters(),lr=LR,weight_decay=WEIGHT_DECAY)
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,mode="max",patience=LR_PATIENCE,factor=0.5,verbose=False)
    early=EarlyStopper(patience=PATIENCE,min_delta=1e-4); best_path=str(SAVE_DIR/f"subject{sid:02d}_best.pt")
    for epoch in range(1,N_EPOCHS+1):
        tr_loss,tr_acc,tr_f1,tr_k,_,_=run_epoch(model,dl_tr,optimizer=opt)
        va_loss,va_acc,va_f1,va_k,_,_=run_epoch(model,dl_va,optimizer=None)
        scheduler.step(va_acc); early.step(va_acc, model, best_path)
        if early.should_stop: break
    model.load_state_dict(torch.load(best_path, map_location=device))
    _,yt,yh=evaluate(model, dl_te if dl_te is not None else dl_va, class_names=CLASS_NAMES, title=f"Subject {sid:02d} CM")
    acc=accuracy_score(yt,yh); f1m=f1_score(yt,yh,average='macro'); k=cohen_kappa_score(yt,yh)
    return acc,f1m,k
# Example:
# results=[train_one_subject(s) for s in range(1,10)]
# print("Mean acc=%.3f, f1=%.3f, kappa=%.3f"%tuple(np.mean(results,axis=0)))


## 12. Troubleshooting Checklist


- **Label sanity:** print `np.bincount(y, minlength=4)` right after epoching. Each class should have similar counts.
- **Shapes:** Make sure inputs are `(batch, channels, time)` and final logits are `(batch, 4)`.
- **Learning rate:** If accuracy stagnates at a single class, try lowering LR to `3e-4` or increasing `N_EPOCHS`.
- **Artifact rejection:** Heavy artifacts can poison features. Consider removing bad trials with MNE’s autoreject or simple amplitude thresholds.
- **Window:** You can tweak `TMIN/TMAX` to 0.5–2.5s or 1.0–3.0s after cue onset; small changes can help.
- **Graph:** `sigma` in `build_adjacency` controls the spatial neighborhood. Try `0.06..0.12`.
- **Channel selection:** Increase `GATE_L1` to push more sparsity; inspect the top‑k list for neuro‑plausible channels (C3/Cz/C4/CPz/FCz, etc.).
