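# Harmonization validation notebook
This notebook compares ROI features before and after harmonization and checks:
 - between-scanner variance before vs. after harmonization
 - scanner-classifier accuracy before vs. after harmonization
 - preservation of biological effects (age slopes and effect sizes)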


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
import statsmodels.formula.api as smf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import warnings
# Silence only library deprecation noise. Convergence, cross-validation and
# numerical warnings stay visible because they bear directly on whether the
# validation metrics computed below can be trusted.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
sns.set(style="whitegrid")


In [None]:
# Paths
pre_csv = "../derivatives/preproc/anat/roi_features.csv"
combat_csv = "../derivatives/harmonized/structural/roi_features_combat.csv"
# NOTE(review): not read anywhere in this notebook, and the filename looks like
# a possible misspelling of "longcombat" -- confirm the path before using it.
longcombat_csv = "../derivatives/harmonized/structural/roi_longcombt.csv"

pre = pd.read_csv(pre_csv)
post = pd.read_csv(combat_csv)

# Sanity check: both tables must cover the same subjects. This compares the
# sets of subject IDs only (not row order or duplicate rows); the visible cells
# below analyse `pre` and `post` separately and never pair rows between them.
only_in_pre = set(pre["subject_id"]) - set(post["subject_id"])
only_in_post = set(post["subject_id"]) - set(pre["subject_id"])
if only_in_pre or only_in_post:
    raise ValueError(
        f"subject_id mismatch between {pre_csv} and {combat_csv}: "
        f"{len(only_in_pre)} only in pre (e.g. {list(only_in_pre)[:5]}), "
        f"{len(only_in_post)} only in post (e.g. {list(only_in_post)[:5]})"
    )


In [None]:
# Choose one ROI feature column as a worked example for the single-ROI checks.
roi_candidates = [c for c in pre.columns if c.startswith("roi_")]
if not roi_candidates:
    raise ValueError("No 'roi_*' feature columns found in the pre-harmonization table")
roi = roi_candidates[0]
print("Example ROI:", roi)

# Copy the example ROI into a fixed-name column 'roi'; the variance and R^2
# cells below refer to it by that name. No mixed model is fitted here: scanner
# effects are summarised with group-mean variance and OLS R^2 proxies instead.
pre['roi'] = pre[roi]
post['roi'] = post[roi]

# Between-scanner share of variance for one feature column.
def scanner_variance(df, col="roi", group="scanner_id"):
    """Split the variance of ``df[col]`` into a between-``group`` component.

    The between-scanner component is the size-weighted spread of the scanner
    means around the grand mean, normalised like the total sample variance.
    ``prop_between`` is therefore the eta-squared share of variance explained
    by scanner membership and always lies in [0, 1]. This holds for unbalanced
    sites and for a single scanner, where it is 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing ``col`` and ``group``.
    col : str, default "roi"
        Numeric column to analyse.
    group : str, default "scanner_id"
        Column identifying the scanner of each row.

    Returns
    -------
    dict
        - ``grand_var``: sample variance (ddof=1) of ``col``.
        - ``between_var``: between-group sum of squares / (n - 1).
        - ``prop_between``: ``between_var / grand_var``, or NaN when the total
          variance is zero or undefined.

        Rows with a missing value in ``col`` or ``group`` are ignored.
    """
    data = df[[col, group]].dropna()
    values = data[col]
    n = len(values)
    grand_mean = values.mean()
    grand_var = values.var()  # sample variance, ddof=1
    # Weight each scanner by its row count. An unweighted variance of scanner
    # means over-weights small sites and is not bounded by the total variance.
    per_group = data.groupby(group)[col].agg(["mean", "size"])
    ss_between = (per_group["size"] * (per_group["mean"] - grand_mean) ** 2).sum()
    between_var = ss_between / (n - 1) if n > 1 else np.nan
    prop_between = between_var / grand_var if grand_var > 0 else np.nan
    return dict(grand_var=grand_var, between_var=between_var, prop_between=prop_between)

# Report the variance decomposition for both tables, pre first.
for label, frame in (("Pre", pre), ("Post", post)):
    print(f"{label}-harmonization variance:", scanner_variance(frame))


In [None]:
# Scanner R^2 proxy: OLS with scanner as a categorical fixed effect.
def scanner_effect_R2(df, col="roi", group="scanner_id"):
    """Estimate how much variance in ``df[col]`` scanner membership explains.

    Fits ``col ~ age + sex + C(group)`` and the reduced ``col ~ age + sex`` by
    OLS on the same rows. The scanner effect is reported as the increase in
    R^2 when scanner is added (delta-R^2, i.e. the squared semi-partial
    correlation of scanner).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``col``, ``age``, ``sex`` and ``group``.
    col : str, default "roi"
        Outcome column.
    group : str, default "scanner_id"
        Scanner column, entered as a categorical fixed effect.

    Returns
    -------
    tuple of float
        ``(full_r2, base_r2, r2_scanner)``, where ``r2_scanner = full_r2 - base_r2``.
    """
    # Fit both models on identical rows. Otherwise rows with a missing scanner
    # would be dropped only from the full model, and the two R^2 values would
    # describe different samples.
    data = df[[col, "age", "sex", group]].dropna()
    # Q() quotes column names so names that are not Python identifiers parse.
    target = f"Q({col!r})"
    model = smf.ols(f"{target} ~ age + sex + C(Q({group!r}))", data=data).fit()
    base = smf.ols(f"{target} ~ age + sex", data=data).fit()
    r2_scanner = model.rsquared - base.rsquared
    return model.rsquared, base.rsquared, r2_scanner

# Show (full R^2, base R^2, scanner delta) for each table in turn.
for label, frame in (("Pre", pre), ("Post", post)):
    print(f"{label} R2s:", scanner_effect_R2(frame))


In [None]:
# Scanner-prediction check on all ROI features. If harmonization removed the
# scanner signal, cross-validated accuracy should fall toward chance.
roi_cols = [c for c in pre.columns if c.startswith("roi_")]
X_pre = pre[roi_cols].values
y_pre = pre['scanner_id'].values

X_post = post[roi_cols].values
y_post = post['scanner_id'].values

clf = RandomForestClassifier(n_estimators=200, random_state=42)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
score_pre = cross_val_score(clf, X_pre, y_pre, cv=cv, scoring='balanced_accuracy').mean()
score_post = cross_val_score(clf, X_post, y_post, cv=cv, scoring='balanced_accuracy').mean()
# Balanced accuracy (mean per-class recall) of an uninformative classifier is
# 1 / number of scanners; print it so the post-harmonization score has context.
chance = 1 / len(np.unique(y_pre))
print(
    f"Scanner prediction balanced accuracy — pre: {score_pre:.3f}, "
    f"post: {score_post:.3f} (chance: {chance:.3f})"
)


In [None]:
# Age slope per ROI, adjusted for sex.
def age_slopes(df, roi_cols):
    """Return the OLS ``age`` coefficient of each ROI, adjusted for ``sex``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``age``, ``sex`` and every column in ``roi_cols``.
    roi_cols : iterable of str
        Feature columns, each regressed on ``age + sex`` separately.

    Returns
    -------
    pandas.Series of float
        Indexed by ROI name. NaN where the model could not be fitted; such
        failures are listed on stdout instead of being dropped silently.
    """
    slopes = {}
    failed = []
    for roi in roi_cols:
        try:
            # Q() quotes the column name, so ROI names that are not valid Python
            # identifiers (e.g. containing '-' or '.') parse as one variable
            # instead of failing and turning into NaN.
            res = smf.ols(f"Q({roi!r}) ~ age + sex", data=df).fit()
            slopes[roi] = res.params['age']
        except Exception as exc:  # best-effort: one bad ROI must not stop the loop
            failed.append(f"{roi}: {type(exc).__name__}: {exc}")
            slopes[roi] = np.nan
    if failed:
        print(f"age_slopes: {len(failed)} ROI model(s) failed and were set to NaN:")
        for msg in failed:
            print("  ", msg)
    return pd.Series(slopes, dtype=float)

# Fit per-ROI age slopes on each table and line them up side by side.
slopes_pre = age_slopes(pre, roi_cols)
slopes_post = age_slopes(post, roi_cols)
comp = pd.DataFrame({"pre": slopes_pre, "post": slopes_post}).dropna()

# Relative slope change, in percent of the pre-harmonization magnitude.
# Exact-zero baselines become NaN instead of dividing by zero.
slope_delta = (comp["post"] - comp["pre"]).abs()
baseline = comp["pre"].replace(0, np.nan).abs()
comp["abs_change_pct"] = 100 * slope_delta / baseline
comp.describe().loc[["mean", "50%", "max"]]


In [None]:
# Overlay the pre/post distributions of the first six ROIs on a 2x3 grid.
sample_rois = roi_cols[:6]
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
series_styles = (("pre", pre, "C0"), ("post", post, "C1"))
for ax, roi_name in zip(axes.flat, sample_rois):
    for label, frame, color in series_styles:
        sns.kdeplot(frame[roi_name], ax=ax, label=label, color=color)
    ax.set_title(roi_name)
    ax.legend()
plt.tight_layout()


In [None]:
# Persist the headline validation metrics as JSON.
import json

outdir = "../results/validation"
os.makedirs(outdir, exist_ok=True)

# scanner_effect_R2 returns (full_r2, base_r2, r2_scanner). The scanner metric
# is the R^2 gained by adding scanner (index 2). The full-model R^2 (index 0)
# also contains the variance explained by age and sex.
r2_scanner_pre = scanner_effect_R2(pre)[2]
r2_scanner_post = scanner_effect_R2(post)[2]

metrics = {
    "scanner_bal_acc_pre": float(score_pre),
    "scanner_bal_acc_post": float(score_post),
    "scanner_R2_pre": float(r2_scanner_pre),
    "scanner_R2_post": float(r2_scanner_post),
    "mean_age_slope_change_pct": float(comp['abs_change_pct'].mean()),
    # The mean of a percentage ratio is dominated by ROIs with near-zero
    # baseline slopes, so also record the outlier-robust median.
    "median_age_slope_change_pct": float(comp['abs_change_pct'].median()),
}
with open(os.path.join(outdir, "validation_summary.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)
print("Saved validation summary:", metrics)
