# PhysioNet - CARMv2 (Standalone)
Standalone training with upgraded CARMv2. No imports from project modules.

In [None]:
import json, random, warnings
from pathlib import Path
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
import mne
import matplotlib.pyplot as plt, seaborn as sns
# Notebook-wide display setup: silence library warnings and use seaborn's
# 'notebook' plotting context at the default font scale.
warnings.filterwarnings('ignore')
sns.set_context('notebook', font_scale=1.0)
def seed(s=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices), then puts cuDNN into deterministic mode with benchmarking
    disabled.

    Args:
        s: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(s)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Single configuration dictionary for the whole notebook: data locations and
# epoching window, training hyper-parameters, CARMv2 graph options, and the
# names of every output artifact.
EXPERIMENT_CONFIG = {
    'data': {
        'preprocessed_dir': Path('data/physionet/derived/preprocessed'),
        'index_file': Path('data/physionet/derived/physionet_preprocessed_index.csv'),
        # Event codes kept by filter_classes(); they are remapped to 0..K-1.
        'selected_classes': [1, 2],
        # Epoch window and baseline interval handed to mne.Epochs.
        'tmin': -1.0,
        'tmax': 5.0,
        'baseline': (-0.5, 0),
    },
    'model': {
        'hidden_dim': 40,
        'epochs': 10,
        'learning_rate': 1e-3,
        'batch_size': 32,
        'n_folds': 2,
        'patience': 8,
    },
    'carmv2': {
        'topk_k': 8,
        'lambda_feat': 0.3,
        'hop_alpha': 0.5,
        'edge_dropout': 0.1,
        'use_pairnorm': True,
        'use_residual': True,
        # 0 selects a full C x C learnable matrix; > 0 selects a C x r factor.
        'low_rank_r': 0,
    },
    'output': {
        'results_dir': Path('results'),
        'results_file': 'trial_carmv2_subject_results.csv',
        'channel_selection_file': 'trial_carmv2_channel_selection_results.csv',
        'comparison_file': 'trial_carmv2_vs_baseline.csv',
        'results_summary_figure': 'trial_carmv2_results_summary.png',
        'adjacency_prefix': 'trial_carmv2_adjacency',
    },
    'max_subjects': 5,
    'min_runs_per_subject': 10,
}

# Make sure the output directory exists, then echo the configuration
# (Path and tuple values are stringified by ``default=str``).
EXPERIMENT_CONFIG['output']['results_dir'].mkdir(exist_ok=True, parents=True)
print(json.dumps(EXPERIMENT_CONFIG, indent=2, default=str))


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class EEGDataset(Dataset):
    """In-memory dataset pairing EEG trials with integer class labels.

    Trials are stored in ``self.x`` as a float tensor with a singleton axis
    inserted at position 1; labels are stored in ``self.y`` as a long tensor.
    """

    def __init__(self, x, y):
        """Convert ``x`` and ``y`` to tensors and keep them on the instance."""
        trials = torch.FloatTensor(x)
        self.x = trials.unsqueeze(1)
        self.y = torch.LongTensor(y)

    def __len__(self):
        """Return the number of labels (one per trial)."""
        return len(self.y)

    def __getitem__(self, i):
        """Return the ``(trial, label)`` pair stored at index ``i``."""
        sample = self.x[i]
        target = self.y[i]
        return sample, target

def _row_path_value(row):
    for c in ['preprocessed_path','preprocessed_file','file','filepath','path','fif_path','raw_path']:
        if c in row and pd.notnull(row[c]):
            v=str(row[c])
            if v.lower().endswith('.fif'):
                return v
    return None

def load_preprocessed_data(fif,tmin,tmax,baseline):
    """Read a preprocessed FIF recording and cut it into epochs.

    Events are taken from ``mne.find_events`` when that call succeeds and
    yields at least one event; otherwise they are derived from the recording's
    annotations with ``mne.events_from_annotations``.

    Args:
        fif: Path of the FIF file to read.
        tmin: Epoch start passed to ``mne.Epochs``.
        tmax: Epoch end passed to ``mne.Epochs``.
        baseline: Baseline argument passed to ``mne.Epochs``.

    Returns:
        ``(data, event_codes, channel_names)`` where ``data`` comes from
        ``Epochs.get_data()`` and ``event_codes`` is the third column of the
        epochs' event array. When no events are found at all, returns
        ``(None, None, channel_names)``.
    """
    raw = mne.io.read_raw_fif(fif, preload=True, verbose='ERROR')

    # Explicit control flow instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently disable the annotation fallback.
    try:
        ev = mne.find_events(raw, verbose='ERROR')
    except Exception:
        # Best-effort probe (presumably fails when the file has no stim
        # channel); any failure falls back to annotations below.
        ev = None

    if ev is not None and len(ev) > 0:
        ids = {f'T{i}': i for i in np.unique(ev[:, 2])}
    else:
        ev, ids = mne.events_from_annotations(raw, verbose='ERROR')

    if len(ev) == 0:
        return None, None, raw.ch_names

    ep = mne.Epochs(raw, ev, event_id=ids, tmin=tmin, tmax=tmax,
                    baseline=baseline, preload=True, verbose='ERROR')
    return ep.get_data(), ep.events[:, 2], raw.ch_names

def filter_classes(x, y, cls):
    """Keep only the trials whose label is in ``cls`` and re-index the labels.

    Labels are remapped to consecutive integers following the ascending order
    of ``cls`` (smallest selected class becomes 0, and so on).

    Args:
        x: Array of trials, indexed along axis 0 in step with ``y``.
        y: Array of original labels.
        cls: Collection of labels to keep.

    Returns:
        ``(x_kept, y_remapped)`` with ``y_remapped`` of dtype ``int64``.
    """
    keep = np.isin(y, cls)
    kept_labels = y[keep]
    kept_trials = x[keep]
    ordered = sorted(cls)
    new_index = dict(zip(ordered, range(len(ordered))))
    remapped = np.fromiter(
        (new_index[int(label)] for label in kept_labels), dtype=np.int64
    )
    return kept_trials, remapped

def normalize(x):
    """Standardize ``x`` per index of axis 1.

    Mean and standard deviation are computed jointly over axes 0 and 2 (kept
    as broadcastable dimensions); a small constant is added to the standard
    deviation to avoid division by zero.

    Args:
        x: NumPy array with at least three axes.

    Returns:
        Array of the same shape as ``x``.
    """
    reduce_axes = (0, 2)
    centre = x.mean(axis=reduce_axes, keepdims=True)
    spread = x.std(axis=reduce_axes, keepdims=True) + 1e-8
    return (x - centre) / spread

def load_subject(index_df,subject,cfg):
    """Load and concatenate all usable runs of one subject.

    Rows of ``index_df`` are restricted to the given subject with
    ``status == 'success'``. For each row the FIF path is resolved, the file is
    epoched, and trials are filtered to the configured classes. Rows without a
    path, without events, or without any selected-class trial are skipped.

    Args:
        index_df: DataFrame with at least ``subject`` and ``status`` columns
            plus one of the path columns understood by ``_row_path_value``.
        subject: Subject identifier to match against the ``subject`` column.
        cfg: Experiment configuration; reads ``cfg['data']`` keys ``tmin``,
            ``tmax``, ``baseline`` and ``selected_classes``. ``baseline`` may be
            ``None`` or a two-element sequence.

    Returns:
        ``(x, y, channel_names)`` with trials and labels concatenated along
        axis 0, or ``(None, None, channel_names)`` when nothing was loaded
        (``channel_names`` is then ``None`` as well).
    """
    data_cfg = cfg['data']
    tmin = data_cfg['tmin']
    tmax = data_cfg['tmax']
    cls = data_cfg['selected_classes']
    # ``None`` is passed through untouched instead of crashing in ``tuple(None)``;
    # it is handed to mne.Epochs as-is.
    baseline = data_cfg['baseline']
    base = None if baseline is None else tuple(baseline)

    wanted = (index_df['subject'] == subject) & (index_df['status'] == 'success')
    df = index_df[wanted].copy()

    xs, ys = [], []
    chn = None
    for _, r in df.iterrows():
        p = _row_path_value(r)
        if not p:
            continue
        x, y, ch = load_preprocessed_data(p, tmin, tmax, base)
        if x is None:
            continue
        x, y = filter_classes(x, y, cls)
        if len(y) == 0:
            continue
        # Channel names are taken from the first run that contributes trials.
        chn = chn or ch
        xs.append(x)
        ys.append(y)

    if not xs:
        return None, None, chn
    return np.concatenate(xs, 0), np.concatenate(ys, 0), chn

def subjects_with_min_runs(index_df, min_runs):
    """List subjects that have at least ``min_runs`` successful runs.

    Only rows with ``status == 'success'`` are considered. When a ``run``
    column exists, non-null ``run`` values are counted per subject; otherwise
    the number of rows per subject is used.

    Args:
        index_df: DataFrame with ``subject`` and ``status`` columns.
        min_runs: Minimum number of runs (converted with ``int``).

    Returns:
        List of qualifying subject identifiers, in group order.
    """
    successful = index_df[index_df['status'] == 'success'].copy()
    by_subject = successful.groupby('subject')
    if 'run' in successful.columns:
        run_counts = by_subject['run'].count()
    else:
        run_counts = by_subject.size()
    threshold = int(min_runs)
    enough = run_counts >= threshold
    return list(run_counts[enough].index)

# Read the preprocessing index, keep subjects with enough successful runs,
# and cap the list at the configured number of subjects.
index_df = pd.read_csv(EXPERIMENT_CONFIG['data']['index_file'])
subjects = subjects_with_min_runs(
    index_df, EXPERIMENT_CONFIG['min_runs_per_subject']
)
subjects = subjects[:EXPERIMENT_CONFIG['max_subjects']]
print('Subjects:', subjects)


In [None]:
def pairnorm(x, node_dim=2, eps=1e-6):
    """Center and rescale ``x`` along ``node_dim``.

    The mean over ``node_dim`` is subtracted, and the result is divided by the
    square root of its mean squared value over the same dimension (plus
    ``eps`` for numerical safety).

    Args:
        x: Input tensor.
        node_dim: Dimension over which statistics are computed.
        eps: Constant added under the square root.

    Returns:
        Tensor with the same shape as ``x``.
    """
    centered = x - x.mean(dim=node_dim, keepdim=True)
    mean_square = (centered * centered).mean(dim=node_dim, keepdim=True)
    scale = torch.sqrt(mean_square + eps)
    return centered / scale

def build_feat_topk_adj(x,k):
    """Build a sparse, symmetric channel adjacency from feature similarity.

    Each channel is summarized by its feature vector averaged over batch and
    time. Pairwise cosine similarities (negative values clamped to zero) are
    computed, the ``k`` largest entries of every row are kept, and a softmax is
    taken over those kept entries only. The row-stochastic result is then
    symmetrized by averaging with its transpose.

    Args:
        x: Tensor unpacked as ``(B, H, C, T)``; ``C`` is the axis for which the
            adjacency is built.
        k: Number of neighbours kept per row; clamped to ``[1, C]``.

    Returns:
        ``(C, C)`` symmetric tensor. Before symmetrization each row has exactly
        ``k`` non-zero weights summing to one.
    """
    B, H, C, T = x.shape
    # (C, H) channel embeddings: mean over every (batch, time) position.
    E = x.permute(2, 1, 0, 3).contiguous().view(C, H, B * T).mean(2)
    En = F.normalize(E, p=2, dim=1)
    S = (En @ En.t()).clamp_min(0.0)

    k = max(1, min(int(k), C))
    _, idx = torch.topk(S, k, dim=1)
    M = torch.zeros_like(S)
    M.scatter_(1, idx, 1.0)

    # Exclude non-selected entries from the softmax with -inf. Multiplying by
    # the mask (zeros) would still give them weight exp(0) and make the
    # adjacency fully dense, defeating the top-k selection. Every row keeps at
    # least one finite entry because k >= 1, so no NaNs can appear.
    A = torch.softmax(S.masked_fill(M == 0, float('-inf')), dim=1)
    return 0.5 * (A + A.t())

class CARMv2(nn.Module):
    """Graph-mixing layer over the channel axis with a blended adjacency.

    The adjacency used in ``forward`` combines (a) a learnable, symmetrically
    normalized matrix and its square (weighted by ``hop_alpha``) with (b) a
    feature-similarity top-k adjacency (weighted by ``lambda_feat``). During
    training with ``edge_dropout > 0`` entries are randomly dropped, the matrix
    is re-symmetrized and an identity is added. The mixed features go through a
    bias-free linear map, an optional residual connection, optional pairnorm,
    batch normalization and ELU.

    The most recent learned/effective adjacencies are cached as NumPy arrays
    in ``self.last`` and exposed through ``get_adjs``.
    """

    def __init__(self, C, H, cfg):
        """Create the layer.

        Args:
            C: Number of graph nodes (channels).
            H: Feature dimension per node.
            cfg: Mapping with keys ``topk_k``, ``lambda_feat``, ``hop_alpha``,
                ``edge_dropout``, ``use_pairnorm``, ``use_residual`` and
                ``low_rank_r``.
        """
        super().__init__()
        self.C, self.H = C, H
        self.k = int(cfg['topk_k'])
        self.lf = float(cfg['lambda_feat'])
        self.ha = float(cfg['hop_alpha'])
        self.ed = float(cfg['edge_dropout'])
        self.pn = bool(cfg['use_pairnorm'])
        self.res = bool(cfg['use_residual'])

        rank = int(cfg['low_rank_r'])
        if rank <= 0:
            # Full C x C learnable matrix.
            self.W = nn.Parameter(torch.empty(C, C))
            nn.init.xavier_uniform_(self.W)
            self.B = None
        else:
            # Low-rank factor; the matrix is formed as B @ B^T on the fly.
            self.B = nn.Parameter(torch.empty(C, rank))
            nn.init.xavier_uniform_(self.B)
            self.W = None

        self.th = nn.Linear(H, H, bias=False)
        self.bn = nn.BatchNorm2d(H)
        self.act = nn.ELU()
        self.last = None

    @staticmethod
    def _sym_normalize(adj):
        """Return ``D^-1/2 @ adj @ D^-1/2`` with ``D`` the clamped row sums."""
        inv_sqrt_degree = torch.pow(adj.sum(1).clamp_min(1e-6), -0.5)
        scaling = torch.diag(inv_sqrt_degree)
        return scaling @ adj @ scaling

    def _learned(self, dev):
        """Build the normalized learnable adjacency (with self-loops) on ``dev``."""
        raw = self.W if self.B is None else (self.B @ self.B.t())
        gated = torch.sigmoid(raw)
        symmetric = 0.5 * (gated + gated.t())
        identity = torch.eye(self.C, device=dev, dtype=symmetric.dtype)
        return self._sym_normalize(symmetric + identity)

    def forward(self, x):
        """Apply graph mixing to ``x`` of shape ``(B, H, C, T)``; same shape out."""
        B, H, C, T = x.shape

        # Learned part: blend of one-hop and two-hop propagation.
        learned = self._learned(x.device)
        two_hop = learned @ learned
        hop_blend = (1 - self.ha) * learned + self.ha * two_hop

        # Feature-driven part, then the final mixture.
        feature_adj = build_feat_topk_adj(x, self.k)
        adj = (1 - self.lf) * hop_blend + self.lf * feature_adj

        if self.training and self.ed > 0:
            keep = (torch.rand_like(adj) > self.ed).float()
            dropped = adj * keep
            adj = 0.5 * (dropped + dropped.t())
            # Identity is added only on this training/edge-dropout path.
            adj = adj + torch.eye(C, device=adj.device, dtype=adj.dtype)

        adj = self._sym_normalize(adj)

        # Treat every (batch, time) position as an independent C x H graph signal.
        nodes = x.permute(0, 3, 2, 1).contiguous().view(B * T, C, H)
        mixed = self.th(adj @ nodes)
        mixed = mixed.view(B, T, C, H).permute(0, 3, 2, 1)

        out = mixed + x if self.res else mixed
        if self.pn:
            out = pairnorm(out, 2)
        out = self.act(self.bn(out))

        self.last = {
            'learned': learned.detach().cpu().numpy(),
            'effective': adj.detach().cpu().numpy(),
        }
        return out

    def get_adjs(self):
        """Return the adjacencies cached by the last forward pass, or ``{}``."""
        if not self.last:
            return {}
        return self.last

class TFEM(nn.Module):
    """Convolution block acting along the last axis only.

    Applies a bias-free ``(1, k)`` 2-D convolution with ``k // 2`` padding on
    the last axis, batch normalization and ELU, optionally followed by
    average pooling with a ``(1, 2)`` kernel.
    """

    def __init__(self, i, o, k=16, pool=True):
        """Create the block.

        Args:
            i: Number of input feature maps.
            o: Number of output feature maps.
            k: Kernel length along the last axis.
            pool: Whether to halve the last axis with average pooling.
        """
        super().__init__()
        self.pool = pool
        self.cv = nn.Conv2d(
            i, o, kernel_size=(1, k), padding=(0, k // 2), bias=False
        )
        self.bn = nn.BatchNorm2d(o)
        self.act = nn.ELU()
        if pool:
            self.pl = nn.AvgPool2d(kernel_size=(1, 2))
        else:
            self.pl = None

    def forward(self, x):
        """Run convolution, normalization and activation, then optional pooling."""
        features = self.cv(x)
        features = self.bn(features)
        features = self.act(features)
        if not self.pool:
            return features
        return self.pl(features)

class EEGARNN_CARMv2(nn.Module):
    """EEG classifier that interleaves TFEM and CARMv2 stages.

    The feature extractor is three ``TFEM -> CARMv2`` pairs (the second and
    third TFEM blocks pool the last axis). Its flattened output feeds a
    ``Linear(·, 256) -> ReLU -> Dropout(0.5) -> Linear(256, K)`` head.
    """

    def __init__(self, C, T, K, H, cfg):
        """Create the network.

        Args:
            C: Number of EEG channels (graph nodes).
            T: Number of time samples per trial.
            K: Number of output classes.
            H: Hidden feature dimension shared by all stages.
            cfg: CARMv2 configuration mapping forwarded to every CARMv2 stage.
        """
        super().__init__()
        self.t1 = TFEM(1, H, 16, False)
        self.g1 = CARMv2(C, H, cfg)
        self.t2 = TFEM(H, H, 16, True)
        self.g2 = CARMv2(C, H, cfg)
        self.t3 = TFEM(H, H, 16, True)
        self.g3 = CARMv2(C, H, cfg)

        # Probe the flattened feature size with a dummy forward pass. It runs
        # in eval mode so the all-zero batch neither updates BatchNorm running
        # statistics nor draws edge-dropout random numbers; training mode is
        # restored right after.
        self.eval()
        with torch.no_grad():
            fs = self._f(torch.zeros(1, 1, C, T)).flatten(1).size(1)
        self.train()

        self.fc1 = nn.Linear(fs, 256)
        self.do = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, K)

    def _f(self, x):
        """Run the three TFEM/CARMv2 stages and return the feature tensor."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        """Return class logits of shape ``(batch, K)`` for input ``x``."""
        x = self._f(x)
        # ``flatten`` copies only if needed, so it also works when the feature
        # tensor is not contiguous (``view`` would raise in that case).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.do(x)
        return self.fc2(x)

    def get_final_adjs(self):
        """Return the adjacencies cached by the last CARMv2 stage."""
        return self.g3.get_adjs()


In [None]:
def train_epoch(m, ld, crit, opt, dev):
    """Train ``m`` for one pass over ``ld``.

    Args:
        m: Model to optimize (switched to training mode).
        ld: Iterable of ``(inputs, labels)`` batches supporting ``len``.
        crit: Loss function called as ``crit(logits, labels)``.
        opt: Optimizer stepping the model parameters.
        dev: Device that batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)`` where the loss is averaged over the
        number of batches and accuracy is computed over all seen samples.
    """
    m.train()
    loss_sum = 0.0
    predictions, targets = [], []
    for inputs, labels in ld:
        inputs = inputs.to(dev)
        labels = labels.to(dev)
        opt.zero_grad()
        logits = m(inputs)
        loss = crit(logits, labels)
        loss.backward()
        opt.step()
        loss_sum += loss.item()
        predictions.extend(torch.argmax(logits, 1).cpu().tolist())
        targets.extend(labels.cpu().tolist())
    n_batches = max(1, len(ld))
    return loss_sum / n_batches, accuracy_score(targets, predictions)

@torch.no_grad()
def evaluate(m, ld, crit, dev):
    """Evaluate ``m`` on ``ld`` without tracking gradients.

    Args:
        m: Model to evaluate (switched to eval mode).
        ld: Iterable of ``(inputs, labels)`` batches supporting ``len``.
        crit: Loss function called as ``crit(logits, labels)``.
        dev: Device that batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy, predictions, labels)``; the last two are
        flat Python lists covering every evaluated sample.
    """
    m.eval()
    loss_sum = 0.0
    predictions, targets = [], []
    for inputs, labels in ld:
        inputs = inputs.to(dev)
        labels = labels.to(dev)
        logits = m(inputs)
        loss_sum += crit(logits, labels).item()
        predictions.extend(torch.argmax(logits, 1).cpu().tolist())
        targets.extend(labels.cpu().tolist())
    mean_loss = loss_sum / max(1, len(ld))
    return mean_loss, accuracy_score(targets, predictions), predictions, targets

def _clone_state_dict(m):
    """Return a CPU copy of ``m.state_dict()`` that owns its own memory."""
    # ``.cpu()`` is a no-op for tensors already on the CPU, so without
    # ``.clone()`` the snapshot would alias the live parameters and be silently
    # overwritten by later optimizer steps.
    return {
        k: v.detach().cpu().clone() if hasattr(v, 'detach') else v
        for k, v in m.state_dict().items()
    }


def train_model(m,tr,va,dev,ep,lr,pt):
    """Train ``m`` with early stopping and restore the best-validation weights.

    Uses cross-entropy loss, Adam (weight decay ``1e-4``) and a
    ``ReduceLROnPlateau`` scheduler driven by the validation loss. The weights
    with the highest validation accuracy are snapshotted; training stops once
    ``pt`` consecutive epochs bring no improvement.

    Args:
        m: Model to train (already on ``dev``).
        tr: Training data loader.
        va: Validation data loader.
        dev: Device that batches are moved to.
        ep: Maximum number of epochs.
        lr: Initial learning rate.
        pt: Early-stopping patience in epochs.

    Returns:
        ``(history, best_state)`` where ``history`` maps ``train_loss``,
        ``train_acc``, ``val_loss`` and ``val_acc`` to per-epoch lists, and
        ``best_state`` is the state dict that was loaded back into ``m``.
    """
    crit = nn.CrossEntropyLoss()
    opt = optim.Adam(m.parameters(), lr=lr, weight_decay=1e-4)
    # ``verbose`` is left at its default; passing it explicitly is rejected by
    # some PyTorch versions.
    sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)

    best_acc = 0.0
    best = None
    noimp = 0
    hist = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for _ in range(ep):
        tl, ta = train_epoch(m, tr, crit, opt, dev)
        vl, va_acc, _, _ = evaluate(m, va, crit, dev)
        hist['train_loss'].append(tl)
        hist['train_acc'].append(ta)
        hist['val_loss'].append(vl)
        hist['val_acc'].append(va_acc)

        # Let scheduler errors surface instead of silently swallowing them.
        sch.step(vl)

        if va_acc > best_acc:
            best_acc = va_acc
            best = _clone_state_dict(m)
            noimp = 0
        else:
            noimp += 1
        if noimp >= pt:
            break

    # No epoch ever beat the initial accuracy of 0.0 (or ``ep`` was 0):
    # fall back to the current weights.
    if best is None:
        best = _clone_state_dict(m)
    m.load_state_dict(best)
    return hist, best

def cross_validate_subject(x, y, chn, T, K, dev, cfg):
    """Run stratified k-fold training of EEGARNN_CARMv2 for one subject.

    For every fold the training and validation splits are each standardized
    with their own statistics, a fresh model is trained, and the validation
    accuracy plus the learned adjacency of the final CARMv2 stage are recorded.

    Args:
        x: Trials array; axis 1 is taken as the channel count.
        y: Integer labels aligned with ``x``.
        chn: Channel names, returned unchanged in the result.
        T: Number of time samples, forwarded to the model constructor.
        K: Number of classes.
        dev: Device used for training and evaluation.
        cfg: Experiment configuration; reads ``cfg['model']`` and ``cfg['carmv2']``.

    Returns:
        Dict with ``fold_results``, ``avg_accuracy``, ``std_accuracy``,
        ``adjacency_matrix`` (mean learned adjacency over folds, or ``None``)
        and ``channel_names``.
    """
    model_cfg = cfg['model']
    n_channels = x.shape[1]
    splitter = StratifiedKFold(
        n_splits=int(model_cfg['n_folds']), shuffle=True, random_state=42
    )
    batch_size = int(model_cfg['batch_size'])
    n_epochs = int(model_cfg['epochs'])
    learning_rate = float(model_cfg['learning_rate'])
    patience = int(model_cfg['patience'])

    fold_records = []
    fold_adjs = []
    for fold_idx, (train_idx, val_idx) in enumerate(splitter.split(x, y)):
        x_train, x_val = normalize(x[train_idx]), normalize(x[val_idx])
        y_train, y_val = y[train_idx], y[val_idx]
        train_loader = DataLoader(
            EEGDataset(x_train, y_train),
            batch_size=batch_size, shuffle=True, num_workers=0,
        )
        val_loader = DataLoader(
            EEGDataset(x_val, y_val),
            batch_size=batch_size, shuffle=False, num_workers=0,
        )
        model = EEGARNN_CARMv2(
            n_channels, T, K, model_cfg['hidden_dim'], cfg['carmv2']
        ).to(dev)
        history, best_state = train_model(
            model, train_loader, val_loader, dev, n_epochs, learning_rate, patience
        )
        model.load_state_dict(best_state)
        _, fold_acc, _, _ = evaluate(model, val_loader, nn.CrossEntropyLoss(), dev)
        fold_adjs.append(model.get_final_adjs().get('learned', None))
        fold_records.append(
            {'fold': fold_idx, 'val_acc': fold_acc, 'history': history}
        )

    accuracies = [record['val_acc'] for record in fold_records]
    usable_adjs = [adj for adj in fold_adjs if adj is not None]
    mean_adj = np.mean(np.stack(usable_adjs, 0), 0) if usable_adjs else None
    return {
        'fold_results': fold_records,
        'avg_accuracy': float(np.mean(accuracies)),
        'std_accuracy': float(np.std(accuracies)),
        'adjacency_matrix': mean_adj,
        'channel_names': chn,
    }


In [None]:
class ChannelSelector:
    """Pick channels from a square adjacency matrix.

    Two strategies are offered: ``edge_selection`` keeps the endpoints of the
    strongest edges, ``aggregation_selection`` keeps the channels with the
    largest total absolute connectivity.
    """

    def __init__(self, A, names):
        """Store the adjacency ``A`` (``C x C``) and the ``C`` channel names."""
        self.A = A
        self.names = np.array(names)
        self.C = A.shape[0]

    def edge_selection(self, k):
        """Select the channels touched by the ``k`` strongest edges.

        The strength of the undirected edge ``(i, j)`` is
        ``|A[i, j]| + |A[j, i]|``. Ties keep row-major ``(i, j)`` order.

        Args:
            k: Number of edges to keep; clamped to ``[0, number of edges]``.

        Returns:
            ``(names, indices)`` with channel indices sorted ascending. Both
            are empty when ``k <= 0``.
        """
        magnitude = np.abs(self.A)
        strength = magnitude + magnitude.T
        rows, cols = np.triu_indices(self.C, k=1)
        weights = strength[rows, cols]
        n_edges = max(0, min(int(k), weights.size))
        # Stable sort on the negated weights = descending order, ties in
        # original (i, j) enumeration order.
        strongest = np.argsort(-weights, kind='stable')[:n_edges]
        idx = np.unique(np.concatenate([rows[strongest], cols[strongest]]))
        return self.names[idx].tolist(), idx

    def aggregation_selection(self, k):
        """Select the ``k`` channels with the largest absolute row sums.

        Args:
            k: Number of channels to keep; clamped to ``[0, C]``.

        Returns:
            ``(names, indices)`` with channel indices sorted ascending. Both
            are empty when ``k <= 0``.
        """
        n_keep = max(0, min(int(k), self.C))
        if n_keep == 0:
            # Guard explicitly: ``[-0:]`` would slice the whole array and
            # return every channel instead of none.
            idx = np.array([], dtype=int)
            return self.names[idx].tolist(), idx
        s = np.sum(np.abs(self.A), 1)
        idx = np.sort(np.argsort(s)[-n_keep:])
        return self.names[idx].tolist(), idx

def viz_adj(A, names, path=None):
    """Draw ``A`` as a labelled heatmap and optionally save it.

    Args:
        A: Adjacency matrix to plot.
        names: Tick labels used on both axes.
        path: Destination file. When falsy the figure is left open and
            nothing is saved.

    Returns:
        ``path`` after saving and closing the figure, otherwise ``None``.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    plt.figure(figsize=(10, 8))
    heatmap_options = dict(
        xticklabels=names,
        yticklabels=names,
        cmap='RdYlGn',
        center=0,
        square=True,
        linewidths=0.4,
    )
    sns.heatmap(A, **heatmap_options)
    plt.title('Adjacency (CARMv2)')
    plt.tight_layout()
    if not path:
        return None
    plt.savefig(path, dpi=200, bbox_inches='tight')
    plt.close()
    return path

results_dir = EXPERIMENT_CONFIG['output']['results_dir']
all_results = []  # full per-subject results (adjacency, channel names, fold histories)
recs = []         # flat per-subject rows that end up in the results CSV

# Cross-validate every selected subject; subjects without usable trials are skipped.
for subj in tqdm(subjects, desc='Training (CARMv2)'):
    X, Y, names = load_subject(index_df, subj, EXPERIMENT_CONFIG)
    if X is None or len(Y) == 0:
        print('Skip', subj)
        continue
    C, T = X.shape[1], X.shape[2]
    K = len(set(EXPERIMENT_CONFIG['data']['selected_classes']))
    res = cross_validate_subject(X, Y, names, T, K, device, EXPERIMENT_CONFIG)
    summary = {
        'subject': subj,
        'num_trials': X.shape[0],
        'num_channels': C,
        'carmv2_acc': res['avg_accuracy'],
        'carmv2_std': res['std_accuracy'],
    }
    all_results.append({
        **summary,
        'adjacency_matrix': res['adjacency_matrix'],
        'channel_names': res['channel_names'],
        'fold_results': res['fold_results'],
    })
    recs.append(summary)

# Tabulate the per-subject scores collected by the training loop.
rdf=pd.DataFrame.from_records(recs)
print('Subjects trained:',len(rdf))

if len(rdf)>0:
    # --- Headline numbers and CSV/config export ---
    print(f"Mean acc: {rdf['carmv2_acc'].mean():.4f} ± {rdf['carmv2_acc'].std():.4f}")
    print('Best:',rdf.loc[rdf['carmv2_acc'].idxmax(),'subject'],'Worst:',rdf.loc[rdf['carmv2_acc'].idxmin(),'subject'])
    p=results_dir/EXPERIMENT_CONFIG['output']['results_file']
    rdf.to_csv(p,index=False)
    print('Saved results to',p)
    cfgp=results_dir/'experiment_config_carmv2.json'
    # Context manager so the file handle is flushed and closed deterministically.
    with open(cfgp,'w') as cfg_file:
        cfg_file.write(json.dumps(EXPERIMENT_CONFIG,indent=2,default=str))
    print('Saved config to',cfgp)

    # --- Adjacency heatmap and channel selection for the best subject ---
    bi=int(np.argmax([r['carmv2_acc'] for r in all_results]))
    br=all_results[bi]
    A=br['adjacency_matrix']
    names=br['channel_names']
    if A is not None and names is not None:
        ap=results_dir/f"{EXPERIMENT_CONFIG['output']['adjacency_prefix']}_{br['subject']}.png"
        viz_adj(A,names,ap)
        print('Adjacency saved to',ap)
        sel=[]
        cs=ChannelSelector(A,names)
        # 'ES' = edge selection, 'AS' = aggregation selection.
        for m in ['ES','AS']:
            for k in [10,15,20]:
                n,i=(cs.edge_selection(k) if m=='ES' else cs.aggregation_selection(k))
                sel.append({'subject':br['subject'],'method':m,'k':k,'channels':n})
                print(f'Selected ({m},k={k}):',n)
        chp=results_dir/EXPERIMENT_CONFIG['output']['channel_selection_file']
        pd.DataFrame(sel).to_csv(chp,index=False)
        print('Channel selection saved to',chp)

    # --- 2x2 summary figure: distribution, trials vs accuracy, top subjects, ranking ---
    fig,ax=plt.subplots(2,2,figsize=(10,8))
    ax[0,0].hist(rdf['carmv2_acc'],bins=15,color='steelblue',alpha=0.8)
    ax[0,0].set_title('Acc Dist')
    ax[0,1].scatter(rdf['num_trials'],rdf['carmv2_acc'])
    ax[0,1].set_title('#Trials vs Acc')
    top=rdf.nlargest(min(10,len(rdf)),'carmv2_acc')
    ax[1,0].barh(range(len(top)),top['carmv2_acc'])
    ax[1,0].set_yticks(range(len(top)))
    ax[1,0].set_yticklabels(top['subject'])
    ax[1,0].invert_yaxis()
    ax[1,0].set_title('Top Subjects')
    srt=rdf.sort_values('carmv2_acc')
    ax[1,1].plot(range(len(srt)),srt['carmv2_acc'],marker='o')
    ax[1,1].set_title('Ranking')
    plt.tight_layout()
    fp=results_dir/EXPERIMENT_CONFIG['output']['results_summary_figure']
    plt.savefig(fp,dpi=200,bbox_inches='tight')
    plt.show()
    print('Saved summary to',fp)

# Join CARMv2 scores with the baseline scores per subject, if both CSVs exist.
bp = results_dir / 'subject_results.csv'
cp = results_dir / EXPERIMENT_CONFIG['output']['results_file']
if not (Path(bp).exists() and Path(cp).exists()):
    print('No baseline/carmv2 results yet for comparison')
else:
    base = pd.read_csv(bp)
    carm = pd.read_csv(cp)
    m = base[['subject', 'all_channels_acc', 'all_channels_std']].merge(
        carm[['subject', 'carmv2_acc', 'carmv2_std']],
        on='subject',
        how='inner',
    )
    # Positive delta means CARMv2 scored higher than the baseline.
    m['accuracy_delta'] = m['carmv2_acc'] - m['all_channels_acc']
    mp = results_dir / EXPERIMENT_CONFIG['output']['comparison_file']
    m.rename(
        columns={
            'all_channels_acc': 'baseline_acc',
            'all_channels_std': 'baseline_std',
        },
        inplace=True,
    )
    m.to_csv(mp, index=False)
    print('Saved comparison to', mp)
