In [1]:
# Imports
import os, glob, numpy as np
import mne
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Patch for NumPy–MNE compatibility (NumPy ≥2.0)
if hasattr(np, "fromstring"):
    _orig = np.fromstring
    def _safe_fromstring(string, dtype=float, sep=""):
        try:
            return _orig(string, dtype=dtype, sep=sep)
        except ValueError:
            if isinstance(string, str):
                string = string.encode()
            return np.frombuffer(string, dtype=dtype)
    np.fromstring = _safe_fromstring

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# Config
DATA_ROOT = "./../BCICIV_2a_gdf"
TARGET_SFREQ = 128.0
TMIN, TMAX = 0.0, 4.0
NT = int((TMAX - TMIN) * TARGET_SFREQ)
BATCH = 32
EPOCHS = 80
LR = 1e-3
PATIENCE = 10


Device: cpu


In [4]:
import os
DATA_ROOT = "./../BCICIV_2a_gdf"   # adjust if needed
files = sorted(os.listdir(DATA_ROOT))
print("Files in folder:", len(files))
for f in files:
    print(f)
# also check recursively
for root,dirs,fs in os.walk(DATA_ROOT):
    for f in fs:
        if any(tok in f.lower() for tok in ("label","truth","answer","y","targets",".mat",".csv",".txt")):
            print("POSSIBLE LABEL FILE:", os.path.join(root,f))


Files in folder: 18
A01E.gdf
A01T.gdf
A02E.gdf
A02T.gdf
A03E.gdf
A03T.gdf
A04E.gdf
A04T.gdf
A05E.gdf
A05T.gdf
A06E.gdf
A06T.gdf
A07E.gdf
A07T.gdf
A08E.gdf
A08T.gdf
A09E.gdf
A09T.gdf


In [2]:
# Robust BCICIV-2a loader (fixed fallback + clearer prints)
import os, re, numpy as np
import mne
from collections import Counter

def load_bci_gdf_robust(path,
                        l_freq=8.0, h_freq=30.0,
                        target_sfreq=128.0, tmin=0.0, tmax=4.0,
                        verbose=True):
    """
    Robust loader for BCICIV_2a .gdf files.
    Returns: X (n_trials, n_ch, nt), y (n_trials,), ch_names
    """
    raw = mne.io.read_raw_gdf(path, preload=True, verbose=False)

    # 1) ensure unique channel names (avoid MNE duplicate-name surprises)
    if len(set(raw.ch_names)) != len(raw.ch_names):
        mapping = {old: f"{old}_{i}" for i, old in enumerate(raw.ch_names)}
        raw.rename_channels(mapping)
        if verbose:
            print("Renamed duplicate channel names to unique names.")

    # 2) filter and resample
    raw.filter(l_freq, h_freq, verbose=False)
    if abs(raw.info['sfreq'] - target_sfreq) > 1e-3:
        raw.resample(target_sfreq, npad='auto', verbose=False)

    # 3) picks (EEG only)
    picks = mne.pick_types(raw.info, eeg=True, eog=False, exclude='bads')
    ch_names = [raw.ch_names[i] for i in picks]
    nt = int((tmax - tmin) * target_sfreq)

    # 4) annotation descriptions -> counts
    ann = raw.annotations
    descs = [str(d) for d in ann.description]
    desc_counts = Counter(descs)

    if verbose:
        print("Annotation description counts (sample):", list(desc_counts.items())[:20])

    # 5) choose MI event descriptions robustly
    # Known boundary/meta codes to exclude
    exclude_set = set(['32766', '768', '1023', '1072', '276', '277', '783'])

    # Prefer canonical BCICIV-2a cues
    preferred = ['769','770','771','772']
    found_pref = [p for p in preferred if p in desc_counts]

    mi_descs = []
    if len(found_pref) == 4:
        mi_descs = found_pref
        if verbose:
            print("Found canonical MI cues:", mi_descs)
    else:
        # find codes with counts around 72
        candidates = [d for d, c in desc_counts.items() if (60 <= c <= 80) and (d not in exclude_set)]
        if len(candidates) >= 4:
            mi_descs = sorted(candidates, key=lambda s: int(re.search(r'\d+', s).group()) if re.search(r'\d+', s) else s)[:4]
            if verbose:
                print("Found MI candidates by count (~72):", mi_descs, [desc_counts[d] for d in mi_descs])
        else:
            # final fallback: top-4 frequent non-excluded descriptions
            sorted_by_count = sorted(desc_counts.items(), key=lambda x: -x[1])
            filtered = [(d,c) for d,c in sorted_by_count if d not in exclude_set]
            if len(filtered) >= 4:
                mi_descs = [d for d,c in filtered[:4]]
                if verbose:
                    print("Fallback selected top-4 non-boundary descriptions:", mi_descs, [desc_counts[d] for d in mi_descs])
            else:
                # Nothing sensible found
                if verbose:
                    print("ERROR: Could not find 4 MI-related annotation descriptions in file:", path)
                    print("Available descriptions (all):", sorted_by_count[:20])
                return np.zeros((0,len(picks),nt), dtype=np.float32), np.zeros((0,),dtype=int), ch_names

    # 6) deterministic mapping: sort by integer if possible
    try:
        mi_descs_sorted = sorted(mi_descs, key=lambda s: int(re.search(r'\d+', s).group()))
    except:
        mi_descs_sorted = sorted(mi_descs)
    desc_to_label = {d: i for i, d in enumerate(mi_descs_sorted)}

    if verbose:
        print("Mapping desc -> label:", desc_to_label)
        print("Counts for chosen MI descs:", {d: desc_counts[d] for d in mi_descs_sorted})

    # 7) extract epochs aligned to MI cue onsets
    Xs, ys = [], []
    sfreq = raw.info['sfreq']
    for d, onset in zip(descs, ann.onset):
        if d in desc_to_label:
            sample = int(round(onset * sfreq))
            start = sample + int(round(tmin * sfreq))
            stop  = sample + int(round(tmax * sfreq))
            if start < 0 or stop > raw.n_times:
                continue
            data = raw.get_data(picks=picks, start=start, stop=stop)
            # fix time length
            if data.shape[1] != nt:
                if data.shape[1] > nt:
                    data = data[:, :nt]
                else:
                    pad = np.zeros((len(picks), nt - data.shape[1]), dtype=np.float32)
                    data = np.hstack([data, pad])
            Xs.append(data.astype(np.float32))
            ys.append(desc_to_label[d])

    if len(Xs) == 0:
        if verbose:
            print("No epochs extracted for MI codes. Returning empty arrays.")
        return np.zeros((0, len(picks), nt), dtype=np.float32), np.zeros((0,), dtype=int), ch_names

    X = np.stack(Xs, axis=0)
    y = np.array(ys, dtype=int)

    if verbose:
        print("Extracted epochs shape:", X.shape, "label distribution:", np.bincount(y))

    return X, y, ch_names


In [3]:
DATA_ROOT = "./../BCICIV_2a_gdf"  # adjust if needed
subjects = [f"A0{i}" for i in range(1,10)]

for subj in subjects:
    tpath = os.path.join(DATA_ROOT, subj + "T.gdf")
    epath = os.path.join(DATA_ROOT, subj + "E.gdf")
    print("\n----", subj, "----")
    Xtr, ytr, chs = load_bci_gdf_robust(tpath, verbose=True)
    Xte, yte, _   = load_bci_gdf_robust(epath, verbose=True)
    print(f"{subj} TRAIN shape: {Xtr.shape}  TEST shape: {Xte.shape}  channels: {len(chs)}")
    if ytr.size>0:
        print(" Train class counts:", np.bincount(ytr))
    if yte.size>0:
        print(" Test  class counts:", np.bincount(yte))



---- A01 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('772', 72), ('771', 72), ('770', 72), ('769', 72), ('1023', 15)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 7)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A01E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('32766', 9), ('1023', 7), ('276', 1), ('277', 1), ('1072', 1)]
A01 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A02 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('769', 72), ('770', 72), ('771', 72), ('772', 72), ('1023', 18)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 5)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A02E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('32766', 9), ('1023', 5), ('276', 1), ('277', 1), ('1072', 1)]
A02 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A03 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('769', 72), ('770', 72), ('771', 72), ('772', 72), ('1023', 18)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 15)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A03E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 15), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A03 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A04 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 7), ('1072', 1), ('768', 288), ('772', 72), ('769', 72), ('770', 72), ('771', 72), ('1023', 26)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 60)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A04E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 60), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A04 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A05 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('1023', 26), ('769', 72), ('770', 72), ('771', 72), ('772', 72)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 12)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A05E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 12), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A05 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A06 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('772', 72), ('770', 72), ('769', 72), ('771', 72), ('1023', 69)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 73)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A06E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 73), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A06 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A07 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('769', 72), ('770', 72), ('771', 72), ('772', 72), ('1023', 17)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 11)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A07E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 11), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A07 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A08 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('769', 72), ('770', 72), ('1023', 24), ('771', 72), ('772', 72)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 17)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A08E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 17), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A08 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]

---- A09 ----


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('1023', 51), ('772', 72), ('770', 72), ('769', 72), ('771', 72)]
Found canonical MI cues: ['769', '770', '771', '772']
Mapping desc -> label: {'769': 0, '770': 1, '771': 2, '772': 3}
Counts for chosen MI descs: {'769': 72, '770': 72, '771': 72, '772': 72}
Extracted epochs shape: (288, 25, 512) label distribution: [72 72 72 72]


  next(self.gen)


Annotation description counts (sample): [('32766', 9), ('276', 1), ('277', 1), ('1072', 1), ('768', 288), ('783', 288), ('1023', 24)]
ERROR: Could not find 4 MI-related annotation descriptions in file: ./../BCICIV_2a_gdf\A09E.gdf
Available descriptions (all): [('768', 288), ('783', 288), ('1023', 24), ('32766', 9), ('276', 1), ('277', 1), ('1072', 1)]
A09 TRAIN shape: (288, 25, 512)  TEST shape: (0, 25, 512)  channels: 25
 Train class counts: [72 72 72 72]


In [3]:
class TFEMBlock(nn.Module):
    def __init__(self, nch, F=16, k_t=15, pool=False, pool_k=4, drop=0.25):
        super().__init__()
        pad_t = (k_t - 1)//2
        self.conv = nn.Conv2d(1, F, kernel_size=(1,k_t), padding=(0,pad_t))
        self.bn = nn.BatchNorm2d(F)
        self.pw = nn.Conv2d(F, 1, kernel_size=1)
        self.pool = nn.AvgPool2d((1,pool_k)) if pool else None
        self.elu = nn.ELU()
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        b,nch,t = x.shape
        x = x.unsqueeze(1)
        x = self.conv(x); x = self.bn(x); x = self.elu(x); x = self.pw(x)
        if self.pool: x = self.pool(x)
        x = self.drop(x)
        return x.squeeze(1)

class CARM(nn.Module):
    def __init__(self, Wref, tdim, drop=0.25):
        super().__init__()
        self.Wref = Wref
        self.Theta = nn.Parameter(torch.randn(tdim, tdim)*0.01)
        self.elu = nn.ELU(); self.drop = nn.Dropout(drop)
    def forward(self,x):
        h = torch.einsum('ij,bjf->bif', self.Wref, x)
        out = torch.einsum('bif,fg->big', h, self.Theta)
        out = self.elu(out); out = self.drop(out)
        return out

class EEG_ARNN(nn.Module):
    def __init__(self, nch, T0, ncls=4, F=16, pool_k=4, rho=0.001):
        super().__init__()
        self.nch, self.T0, self.rho = nch, T0, rho
        W0 = torch.ones(nch,nch) - torch.eye(nch)
        Wt = W0 + torch.eye(nch)
        D = Wt.sum(dim=1)
        Dinv = torch.diag(1.0/torch.sqrt(D + 1e-12))
        self.W = nn.Parameter(Dinv @ Wt @ Dinv)
        self.tf1 = TFEMBlock(nch, F=F, k_t=15, pool=False)
        self.c1  = CARM(self.W, tdim=T0)
        self.tf2 = TFEMBlock(nch, F=F, pool=True, pool_k=pool_k)
        T2 = T0//pool_k
        self.c2  = CARM(self.W, tdim=T2)
        self.tf3 = TFEMBlock(nch, F=F, pool=True, pool_k=pool_k)
        T3 = T2//pool_k
        self.c3  = CARM(self.W, tdim=T3)
        self.fuse = nn.Conv2d(1,16,kernel_size=(nch,1))
        self.bn = nn.BatchNorm2d(16)
        self.elu = nn.ELU(); self.drop = nn.Dropout(0.25)
        self.fc = nn.Linear(16*T3, ncls)
    def forward(self,x):
        x = self.tf1(x); x = self.c1(x)
        x = self.tf2(x); x = self.c2(x)
        x = self.tf3(x); x = self.c3(x)
        x = x.unsqueeze(1)
        x = self.fuse(x); x = self.bn(x); x = self.elu(x); x = self.drop(x)
        x = x.squeeze(2)
        b,oc,t = x.shape
        x = x.view(b, oc*t)
        return self.fc(x)


In [4]:
def train_one_epoch(model, loader, opt, crit):
    model.train(); total=0; n=0
    for xb,yb in loader:
        xb,yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        out = model(xb); loss = crit(out,yb)
        loss.backward(); opt.step()
        with torch.no_grad():
            if model.W.grad is not None:
                model.W.data = (1.0 - model.rho)*model.W.data - model.rho*model.W.grad.data
                model.W.grad.zero_()
        total += loss.item()*xb.size(0); n += xb.size(0)
    return total/max(1,n)

@torch.no_grad()
def evaluate(model, loader):
    model.eval(); preds=[]; ys=[]
    for xb,yb in loader:
        xb=xb.to(device); out=model(xb)
        preds.append(out.argmax(1).cpu().numpy())
        ys.append(yb.numpy())
    preds=np.concatenate(preds); ys=np.concatenate(ys)
    return accuracy_score(ys,preds)


In [5]:
subjects = [f"A0{i}" for i in range(1,10)]
accs = []

for subj in subjects:
    train_path = os.path.join(DATA_ROOT, f"{subj}T.gdf")
    test_path  = os.path.join(DATA_ROOT, f"{subj}E.gdf")
    print(f"\n===== {subj} =====")
    Xtr,ytr,chs = load_bci_gdf(train_path)
    Xte,yte,_   = load_bci_gdf(test_path)
    print(f"Train {Xtr.shape}, Test {Xte.shape}")
    tr_loader = DataLoader(TensorDataset(torch.tensor(Xtr),torch.tensor(ytr)), batch_size=BATCH, shuffle=True)
    te_loader = DataLoader(TensorDataset(torch.tensor(Xte),torch.tensor(yte)), batch_size=BATCH, shuffle=False)

    model = EEG_ARNN(nch=Xtr.shape[1], T0=Xtr.shape[2], ncls=4).to(device)
    opt = torch.optim.Adam([p for n,p in model.named_parameters() if n!='W'], lr=LR)
    crit = nn.CrossEntropyLoss()

    best_acc, best_ep, patience = 0.0, 0, 0
    for ep in range(1, EPOCHS+1):
        loss = train_one_epoch(model,tr_loader,opt,crit)
        acc = evaluate(model,te_loader)
        if acc > best_acc:
            best_acc, best_ep, patience = acc, ep, 0
        else:
            patience += 1
        if patience >= PATIENCE:
            print(f"Early stop @ {ep}, best {best_acc:.4f}")
            break
        if ep % 10 == 0 or ep==1:
            print(f"Ep {ep:03d} loss={loss:.4f} test_acc={acc:.4f}")
    accs.append(best_acc)
    print(f"Subject {subj} best acc={best_acc:.4f} (epoch {best_ep})")

print("\n===== SUMMARY =====")
for s,a in zip(subjects,accs):
    print(f"{s}: {a*100:.2f}%")
print(f"Mean accuracy: {np.mean(accs)*100:.2f}%  ± {np.std(accs)*100:.2f}")



===== A01 =====


  next(self.gen)
  next(self.gen)


Train (18, 25, 512), Test (10, 25, 512)
Ep 001 loss=1.3242 test_acc=0.7000
Ep 010 loss=0.3561 test_acc=0.7000
Early stop @ 11, best 0.7000
Subject A01 best acc=0.7000 (epoch 1)

===== A02 =====


  next(self.gen)
  next(self.gen)


Train (21, 25, 512), Test (8, 25, 512)
Ep 001 loss=1.1737 test_acc=0.1250
Ep 010 loss=0.6437 test_acc=0.1250
Early stop @ 11, best 0.1250
Subject A02 best acc=0.1250 (epoch 1)

===== A03 =====


  next(self.gen)
  next(self.gen)


Train (21, 25, 512), Test (18, 25, 512)
Ep 001 loss=1.6647 test_acc=0.0556
Ep 010 loss=0.5912 test_acc=0.0556
Early stop @ 11, best 0.0556
Subject A03 best acc=0.0556 (epoch 1)

===== A04 =====


  next(self.gen)
  next(self.gen)


Train (322, 25, 512), Test (63, 25, 512)
Ep 001 loss=0.5994 test_acc=0.0159
Ep 010 loss=0.3090 test_acc=0.0159
Early stop @ 11, best 0.0159
Subject A04 best acc=0.0159 (epoch 1)

===== A05 =====


  next(self.gen)
  next(self.gen)


Train (29, 25, 512), Test (15, 25, 512)
Ep 001 loss=1.4386 test_acc=0.0667
Ep 010 loss=0.3985 test_acc=0.8000
Early stop @ 13, best 0.8000
Subject A05 best acc=0.8000 (epoch 3)

===== A06 =====


  next(self.gen)
  next(self.gen)


Train (72, 25, 512), Test (76, 25, 512)
Ep 001 loss=1.4198 test_acc=0.0132
Ep 010 loss=0.2371 test_acc=0.9605
Early stop @ 15, best 0.9605
Subject A06 best acc=0.9605 (epoch 5)

===== A07 =====


  next(self.gen)
  next(self.gen)


Train (20, 25, 512), Test (14, 25, 512)
Ep 001 loss=1.3986 test_acc=0.0714
Ep 010 loss=0.5947 test_acc=0.7857
Early stop @ 13, best 0.7857
Subject A07 best acc=0.7857 (epoch 3)

===== A08 =====


  next(self.gen)
  next(self.gen)


Train (27, 25, 512), Test (20, 25, 512)
Ep 001 loss=1.3231 test_acc=0.8500
Ep 010 loss=0.5067 test_acc=0.8500
Early stop @ 11, best 0.8500
Subject A08 best acc=0.8500 (epoch 1)

===== A09 =====


  next(self.gen)
  next(self.gen)


Train (54, 25, 512), Test (27, 25, 512)
Ep 001 loss=1.2606 test_acc=0.0370
Ep 010 loss=0.3133 test_acc=0.8889
Early stop @ 15, best 0.8889
Subject A09 best acc=0.8889 (epoch 5)

===== SUMMARY =====
A01: 70.00%
A02: 12.50%
A03: 5.56%
A04: 1.59%
A05: 80.00%
A06: 96.05%
A07: 78.57%
A08: 85.00%
A09: 88.89%
Mean accuracy: 57.57%  ± 36.79


In [11]:
# DIAGNOSTIC: inspect .gdf file
import numpy as np, re, pprint
import mne

def inspect_gdf(path):
    print("FILE:", path)
    raw = mne.io.read_raw_gdf(path, preload=False, verbose=False)
    print("Raw info channels:", len(raw.ch_names))
    print("First 30 channel names:", raw.ch_names[:30])
    # show if duplicate channel names exist
    if len(set(raw.ch_names)) != len(raw.ch_names):
        print("WARNING: duplicate channel names detected.")
    print("\nFirst 20 annotations (description, onset [s], duration):")
    for a in raw.annotations[:20]:
        print(a['description'], round(a['onset'],3), round(a['duration'],3))
    # events_from_annotations
    events, event_id = mne.events_from_annotations(raw, verbose=False)
    print("\nMNE event_id dict:")
    pprint.pprint(event_id)
    unique, counts = np.unique(events[:,2], return_counts=True)
    print("\nUnique numeric event codes and counts (from events array):")
    for u,c in zip(unique, counts):
        print(u, c)
    return raw, events, event_id

# Example usage (change to a subject file)
raw, events, event_id = inspect_gdf("./../BCICIV_2a_gdf/A02T.gdf")


FILE: ./../BCICIV_2a_gdf/A02T.gdf
Raw info channels: 25
First 30 channel names: ['EEG-Fz', 'EEG-0', 'EEG-1', 'EEG-2', 'EEG-3', 'EEG-4', 'EEG-5', 'EEG-C3', 'EEG-6', 'EEG-Cz', 'EEG-7', 'EEG-C4', 'EEG-8', 'EEG-9', 'EEG-10', 'EEG-11', 'EEG-12', 'EEG-13', 'EEG-14', 'EEG-Pz', 'EEG-15', 'EEG-16', 'EOG-left', 'EOG-central', 'EOG-right']

First 20 annotations (description, onset [s], duration):
32766 0.0 0.004
276 0.0 126.048
32766 126.052 0.004
277 126.052 104.352
32766 230.408 0.004
1072 230.408 154.224
32766 384.636 0.004
768 386.036 7.5
769 388.036 1.252
768 394.048 7.5
770 396.048 1.252
768 401.72 7.5
770 403.72 1.252
768 409.528 7.5
769 411.528 1.252
768 417.56 7.5
770 419.56 1.252
768 426.004 7.5
769 428.004 1.252
768 433.672 7.5

MNE event_id dict:
{np.str_('1023'): 1,
 np.str_('1072'): 2,
 np.str_('276'): 3,
 np.str_('277'): 4,
 np.str_('32766'): 5,
 np.str_('768'): 6,
 np.str_('769'): 7,
 np.str_('770'): 8,
 np.str_('771'): 9,
 np.str_('772'): 10}

Unique numeric event codes and count

  next(self.gen)
