In [None]:
import os, numpy as np, pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt, seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
import statsmodels.formula.api as smf



In [None]:
# Lookup table; other cells read its 'fc_file', 'scanner_id' and 'age' columns.
lookup = pd.read_csv("config/scanners_lookup_fmri.csv")
# Directories holding the pre- and post-harmonization FC matrices (used as an example).
pre_dir = "derivatives/preproc/fmri/connectivity_matrices"
post_dir = "derivatives/harmonized/fmri/fc_covbat"
# Keep only the first 50 rows for speed. Take an explicit copy so that columns
# added to lookup_sub later are written to an independent frame rather than to
# a slice of `lookup` (which raises pandas' SettingWithCopyWarning).
lookup_sub = lookup.head(50).copy()
def load_fc(row, d):
    """Load the array stored at ``<d>/<row['fc_file']>`` using ``np.load``.

    Parameters
    ----------
    row : mapping-like
        Any object indexable by ``'fc_file'`` (e.g. one row of the lookup table).
    d : str
        Directory that contains the file.

    Returns
    -------
    Whatever ``np.load`` returns for that file.
    """
    fc_path = os.path.join(d, row['fc_file'])
    return np.load(fc_path)

# Materialize the subset's rows once, then load each subject's matrix from the
# pre and post directories and stack them along a new leading (subject) axis.
subject_rows = [row for _, row in lookup_sub.iterrows()]
pre_mats = np.stack([load_fc(row, pre_dir) for row in subject_rows])
post_mats = np.stack([load_fc(row, post_dir) for row in subject_rows])


In [None]:
def vec_upper(mat):
    """Return the strictly-upper-triangular entries of ``mat`` as a 1-D array.

    The index set is built from ``mat.shape[0]`` with diagonal offset 1, so the
    main diagonal is excluded and entries come out in row-major order.
    """
    n_nodes = mat.shape[0]
    row_idx, col_idx = np.triu_indices(n_nodes, k=1)
    return mat[row_idx, col_idx]

# One feature row per subject: the strict upper triangle of each FC matrix.
X_pre = np.stack([vec_upper(m) for m in pre_mats])
X_post = np.stack([vec_upper(m) for m in post_mats])
# Target labels: the 'scanner_id' column of the subject subset.
y = lookup_sub['scanner_id'].values

# The same classifier settings and the same seeded, shuffled, stratified 5-fold
# split are applied to both feature sets, so the two scores are comparable.
# NOTE(review): StratifiedKFold(5) expects every scanner to have at least 5
# subjects in the subset -- confirm this holds for the rows selected.
clf = RandomForestClassifier(n_estimators=200, random_state=0)
cv = StratifiedKFold(5, shuffle=True, random_state=0)
score_pre = cross_val_score(clf, X_pre, y, cv=cv, scoring='balanced_accuracy').mean()
score_post = cross_val_score(clf, X_post, y, cv=cv, scoring='balanced_accuracy').mean()
# The separator was a mis-encoded em dash (mojibake); it is written as an escape
# so the source stays ASCII-safe and prints the intended character.
print("Scanner balanced accuracy \u2014 pre:", score_pre, " post:", score_post)


In [None]:
# Per-subject "global FC": the mean over all entries of each matrix.
# NOTE(review): m.mean() averages every entry, diagonal included -- confirm
# that including the diagonal is intended.
# assign() returns a new frame that is rebound to the same name, so no column
# is written into a slice of another DataFrame (avoids SettingWithCopyWarning)
# and re-running the cell yields the same result.
lookup_sub = lookup_sub.assign(
    global_fc_pre=[m.mean() for m in pre_mats],
    global_fc_post=[m.mean() for m in post_mats],
)
# OLS of global FC on age with scanner_id as a categorical covariate,
# fitted separately for the pre and post values.
res_pre = smf.ols("global_fc_pre ~ age + C(scanner_id)", data=lookup_sub).fit()
res_post = smf.ols("global_fc_post ~ age + C(scanner_id)", data=lookup_sub).fit()
print("pre age coef:", res_pre.params['age'], "post age coef:", res_post.params['age'])


In [None]:
# Load sample pre/post NIfTI files for one subject as arrays.
pre_img = nib.load("derivatives/preproc/fmri/voxel_maps/sub-01_map.nii.gz").get_fdata()
post_img = nib.load("derivatives/harmonized/fmri/voxel_combat/sub-01_map_harm.nii.gz").get_fdata()
# Middle index along the third axis of the pre image; the same index is used
# for both volumes.
# NOTE(review): assumes both images share the same grid -- confirm.
slice_idx = pre_img.shape[2]//2
pre_slice = pre_img[:, :, slice_idx].T
post_slice = post_img[:, :, slice_idx].T

# Use one colour scale, symmetric about zero, for both panels. With a diverging
# colormap and independent autoscaling, the same colour would otherwise stand
# for different values in "pre" and "post". Non-finite voxels are ignored when
# choosing the limit; fall back to 1.0 if nothing finite/non-zero is present.
finite_abs = np.abs(np.concatenate([
    pre_slice[np.isfinite(pre_slice)],
    post_slice[np.isfinite(post_slice)],
]))
vlim = float(finite_abs.max()) if finite_abs.size and finite_abs.max() > 0 else 1.0

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, img_slice, title in zip(ax, (pre_slice, post_slice), ("pre", "post")):
    im = panel.imshow(img_slice, origin='lower', cmap='RdBu_r', vmin=-vlim, vmax=vlim)
    panel.set_title(title)
# A single colorbar is enough because both panels share the same limits.
fig.colorbar(im, ax=ax, shrink=0.8)
plt.show()
