In [None]:
# =========================
# Colab Cell 1: Mount Drive + project paths
# =========================
from google.colab import drive
drive.mount("/content/drive")

import os, pathlib, textwrap, json, re, shutil, glob, math, random, time
from datetime import datetime


BASE_DIR = pathlib.Path("/content/drive/MyDrive/EMBC_project")

DATA_DIR = BASE_DIR / "Data"
RAW_DIR  = DATA_DIR / "raw"
PROC_DIR = DATA_DIR / "processed"

OUT_DIR  = BASE_DIR / "outputs"
FIG_DIR  = OUT_DIR / "figures"
TAB_DIR  = OUT_DIR / "tables"
CACHE_DIR = OUT_DIR / "cache"

for p in [RAW_DIR, PROC_DIR, FIG_DIR, TAB_DIR, CACHE_DIR]:
    p.mkdir(parents=True, exist_ok=True)

print("BASE_DIR:", BASE_DIR)
print("RAW_DIR :", RAW_DIR)
print("PROC_DIR:", PROC_DIR)
print("OUT_DIR :", OUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
BASE_DIR: /content/drive/MyDrive/EMBC_project
RAW_DIR : /content/drive/MyDrive/EMBC_project/Data/raw
PROC_DIR: /content/drive/MyDrive/EMBC_project/Data/processed
OUT_DIR : /content/drive/MyDrive/EMBC_project/outputs


In [None]:
# =========================
# Colab Cell 2: Install dependencies
# =========================
!pip -q install nibabel pandas numpy scipy scikit-image scikit-learn matplotlib tqdm joblib \
  SimpleITK pyradiomics umap-learn torch torchvision torchaudio

# Optional: faster file listing / progress bars
!pip -q install rich

!pip install -q pyradiomics



  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  [1;31merror[0m: [1ms

In [None]:
!pip install SimpleITK
!pip install pyradiomics

Collecting pyradiomics
  Using cached pyradiomics-3.1.0.tar.gz (34.5 MB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Discarding [4;34mhttps://files.pythonhosted.org/packages/03/c1/20fc2c50ab1e3304da36d866042a1905a2b05a1431ece35448ab6b4578f2/pyradiomics-3.1.0.tar.gz (from https://pypi.org/simple/pyradiomics/)[0m: [33mRequested pyradiomics from https://files.pythonhosted.org/packages/03/c1/20fc2c50ab1e3304da36d866042a1905a2b05a1431ece35448ab6b4578f2/pyradiomics-3.1.0.tar.gz has inconsistent version: expected '3.1.0', but metadata has '3.0.1a1'[0m
  Using cached pyradiomics-3.0.1.tar.gz (34.5 MB)
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess

In [None]:
# =========================
# Colab Cell 3: Imports + global config
# =========================
import numpy as np
import pandas as pd

from tqdm.auto import tqdm
from joblib import Memory

import nibabel as nib
import SimpleITK as sitk

from scipy.spatial.distance import cosine as cosine_distance
from scipy.stats import wilcoxon, ttest_rel

import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import umap

import torch
import torchvision
import torchvision.transforms.functional as TF

try:
    from radiomics import featureextractor
except Exception as e:
    print("Radiomics import issue:", e)
    featureextractor = None

mem = Memory(location=str(CACHE_DIR), verbose=0)

# -------------------------
# Runtime knobs
# -------------------------
FAST_DEBUG = False
DEBUG_MAX_PAIRED_SUBJECTS = 3          # used only if FAST_DEBUG=True
LOWFIELD_SESSION_POLICY = "ses-01"     # or "first" or "all"

# Which modalities to include
USE_MODALITIES = ["T1w", "T2w", "FLAIR", "ADC"]  # treat ADC as diffusion summary

# Registration settings
REG_SHRINK_FACTORS = [4, 2, 1]         # multi-resolution pyramid
REG_SMOOTHING_SIGMAS = [2, 1, 0]       # in physical units
REG_MAX_ITERS = 80 if FAST_DEBUG else 200

# Embedding extraction
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


Radiomics import issue: No module named 'radiomics'
DEVICE: cuda


In [None]:
ROOT_3T   = pathlib.Path("/content/drive/MyDrive/EMBC_project/Data/3T")
ROOT_64MT = pathlib.Path("/content/drive/MyDrive/EMBC_project/Data/64T")

In [None]:
# =========================
# REPLACEMENT Colab Cell 5: Build a manifest (fixed 64mT anat parsing, safer columns)
# =========================
import re
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm

def list_subjects(root: Path):
    return sorted([p.name for p in root.glob("sub-*") if p.is_dir()])

def parse_3t_anat_files(sub_dir: Path):
    anat = sub_dir / "anat"
    rows = []
    if not anat.exists():
        return rows
    for f in anat.glob("*.nii.gz"):
        name = f.name
        m = re.match(r"(sub-\d+)_acq-(highres|lowres)_(T1w|T2w|FLAIR)\.nii\.gz$", name)
        if m:
            sub, acq, mod = m.group(1), m.group(2), m.group(3)
            json_sidecar = f.with_suffix("").with_suffix(".json")
            rows.append({
                "field": "3T",
                "subject": sub,
                "session": None,
                "acq": acq,
                "run": None,
                "modality": mod,
                "nifti_path": str(f),
                "json_path": str(json_sidecar) if json_sidecar.exists() else None,
                "bval_path": None,
                "bvec_path": None,
            })
    return rows

def parse_3t_dwi_adc_files(sub_dir: Path):
    dwi = sub_dir / "dwi"
    rows = []
    if not dwi.exists():
        return rows
    for f in dwi.glob("*.nii.gz"):
        name = f.name
        m = re.match(r"(sub-\d+)_acq-(highres|lowres)_(adc)\.nii\.gz$", name, re.IGNORECASE)
        if m:
            sub, acq = m.group(1), m.group(2)
            json_sidecar = f.with_suffix("").with_suffix(".json")
            rows.append({
                "field": "3T",
                "subject": sub,
                "session": None,
                "acq": acq,
                "run": None,
                "modality": "ADC",
                "nifti_path": str(f),
                "json_path": str(json_sidecar) if json_sidecar.exists() else None,
                "bval_path": None,
                "bvec_path": None,
            })
    return rows

def parse_64mt_anat_files(sub_dir: Path):
    """
    Robust to both:
      sub-0001_ses-01_T1w.nii.gz
      sub-0001_ses-01_acq-something_T1w.nii.gz
    Excludes localizers like *_T1w_acq-localizer.nii.gz
    """
    rows = []
    for ses_dir in sorted([p for p in sub_dir.glob("ses-*") if p.is_dir()]):
        anat = ses_dir / "anat"
        if not anat.exists():
            continue

        for f in anat.glob("*.nii.gz"):
            name = f.name

            # Exclude localizer (common pattern)
            if "acq-localizer" in name:
                continue

            m = re.match(
                r"(sub-\d+)_(ses-\d+)(?:_(acq-[^_]+))?_(T1w|T2w|FLAIR)\.nii\.gz$",
                name
            )
            if m:
                sub, ses, acq_opt, mod = m.group(1), m.group(2), m.group(3), m.group(4)
                json_sidecar = f.with_suffix("").with_suffix(".json")
                rows.append({
                    "field": "64mT",
                    "subject": sub,
                    "session": ses,
                    "acq": acq_opt,          # may be None
                    "run": None,
                    "modality": mod,
                    "nifti_path": str(f),
                    "json_path": str(json_sidecar) if json_sidecar.exists() else None,
                    "bval_path": None,
                    "bvec_path": None,
                })
    return rows

def parse_64mt_dwi_adc_files(sub_dir: Path):
    rows = []
    for ses_dir in sorted([p for p in sub_dir.glob("ses-*") if p.is_dir()]):
        dwi = ses_dir / "dwi"
        if not dwi.exists():
            continue
        for f in dwi.glob("*.nii.gz"):
            name = f.name
            m = re.match(r"(sub-\d+)_(ses-\d+)_run-(\d+)_(ADC)\.nii\.gz$", name, re.IGNORECASE)
            if m:
                sub, ses, run = m.group(1), m.group(2), m.group(3)
                json_sidecar = f.with_suffix("").with_suffix(".json")
                rows.append({
                    "field": "64mT",
                    "subject": sub,
                    "session": ses,
                    "acq": None,
                    "run": run,
                    "modality": "ADC",
                    "nifti_path": str(f),
                    "json_path": str(json_sidecar) if json_sidecar.exists() else None,
                    "bval_path": None,
                    "bvec_path": None,
                })
    return rows

rows = []

for sub in tqdm(list_subjects(ROOT_3T), desc="Index 3T"):
    sub_dir = ROOT_3T / sub
    rows.extend(parse_3t_anat_files(sub_dir))
    rows.extend(parse_3t_dwi_adc_files(sub_dir))

for sub in tqdm(list_subjects(ROOT_64MT), desc="Index 64mT"):
    sub_dir = ROOT_64MT / sub
    rows.extend(parse_64mt_anat_files(sub_dir))
    rows.extend(parse_64mt_dwi_adc_files(sub_dir))

manifest = pd.DataFrame(rows)

# Make sure these columns exist and have consistent types
for col in ["session", "acq", "run", "json_path", "bval_path", "bvec_path"]:
    if col not in manifest.columns:
        manifest[col] = None

manifest_path = TAB_DIR / "manifest_all.csv"
manifest.to_csv(manifest_path, index=False)

print("Saved manifest:", manifest_path)
print("Rows:", len(manifest))
display(manifest.head(15))
display(manifest.groupby(["field", "modality"]).size().reset_index(name="n"))


Index 3T:   0%|          | 0/11 [00:00<?, ?it/s]

Index 64mT:   0%|          | 0/65 [00:00<?, ?it/s]

Saved manifest: /content/drive/MyDrive/EMBC_project/outputs/tables/manifest_all.csv
Rows: 173


Unnamed: 0,field,subject,session,acq,run,modality,nifti_path,json_path,bval_path,bvec_path
0,3T,sub-0011,,lowres,,T1w,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
1,3T,sub-0011,,lowres,,FLAIR,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
2,3T,sub-0011,,highres,,FLAIR,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
3,3T,sub-0011,,highres,,T1w,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
4,3T,sub-0011,,highres,,T2w,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
5,3T,sub-0011,,lowres,,T2w,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
6,3T,sub-0011,,lowres,,ADC,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
7,3T,sub-0011,,highres,,ADC,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
8,3T,sub-0015,,lowres,,T1w,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,
9,3T,sub-0015,,lowres,,T2w,/content/drive/MyDrive/EMBC_project/Data/3T/su...,/content/drive/MyDrive/EMBC_project/Data/3T/su...,,


Unnamed: 0,field,modality,n
0,3T,ADC,22
1,3T,FLAIR,22
2,3T,T1w,22
3,3T,T2w,22
4,64mT,ADC,85


In [None]:
# =========================
# Colab Cell 6: Build paired cohort table (3T subjects ∩ 64mT subjects)
# =========================
subs_3t = set(manifest.loc[manifest["field"]=="3T", "subject"].unique())
subs_64 = set(manifest.loc[manifest["field"]=="64mT", "subject"].unique())
paired_subjects = sorted(list(subs_3t.intersection(subs_64)))

print("Paired subjects:", len(paired_subjects))
print(paired_subjects[:20])

if FAST_DEBUG:
    paired_subjects = paired_subjects[:DEBUG_MAX_PAIRED_SUBJECTS]
    print("FAST_DEBUG on -> using:", paired_subjects)

# Choose one low-field session per subject (default ses-01 if present else first)
def pick_lowfield_session(sub: str) -> str:
    ses = sorted(manifest.loc[(manifest.field=="64mT") & (manifest.subject==sub), "session"].dropna().unique().tolist())
    if not ses:
        return None
    if LOWFIELD_SESSION_POLICY == "all":
        return "all"
    if LOWFIELD_SESSION_POLICY == "first":
        return ses[0]
    # default: prefer ses-01
    if "ses-01" in ses:
        return "ses-01"
    return ses[0]

paired_rows = []
for sub in paired_subjects:
    paired_rows.append({
        "subject": sub,
        "lowfield_session": pick_lowfield_session(sub),
    })

paired_df = pd.DataFrame(paired_rows)
paired_df_path = TAB_DIR / "paired_subjects.csv"
paired_df.to_csv(paired_df_path, index=False)
print("Saved:", paired_df_path)
display(paired_df)


Paired subjects: 10
['sub-0011', 'sub-0015', 'sub-0025', 'sub-0027', 'sub-0035', 'sub-0046', 'sub-0047', 'sub-0048', 'sub-0064', 'sub-0066']
Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/paired_subjects.csv


Unnamed: 0,subject,lowfield_session
0,sub-0011,ses-01
1,sub-0015,ses-01
2,sub-0025,ses-01
3,sub-0027,ses-01
4,sub-0035,ses-01
5,sub-0046,ses-01
6,sub-0047,ses-01
7,sub-0048,ses-01
8,sub-0064,ses-01
9,sub-0066,ses-01


In [None]:
# =========================
# Colab Cell 7: Preprocessing utilities (I/O, masking, normalization, rigid registration)
# =========================
def sitk_read(path: str) -> sitk.Image:
    img = sitk.ReadImage(path)
    # Force float for registration / normalization
    return sitk.Cast(img, sitk.sitkFloat32)

def sitk_write(img: sitk.Image, path: str):
    outp = pathlib.Path(path)
    outp.parent.mkdir(parents=True, exist_ok=True)
    sitk.WriteImage(img, str(outp), True)

def make_brain_mask_simple(img: sitk.Image) -> sitk.Image:
    """
    Fast, label-free mask:
      1) Otsu threshold
      2) keep largest connected component
      3) morphological closing
    Works reasonably for healthy brain MRIs; not a replacement for SynthStrip/HD-BET.
    """
    arr = sitk.GetArrayFromImage(img)
    arr = np.nan_to_num(arr, nan=0.0)

    # Robust clipping
    p1, p99 = np.percentile(arr, [1, 99])
    arrc = np.clip(arr, p1, p99)

    imgc = sitk.GetImageFromArray(arrc.astype(np.float32))
    imgc.CopyInformation(img)

    thr = sitk.OtsuThreshold(imgc, 0, 1)
    cc = sitk.ConnectedComponent(thr)
    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(cc)
    if stats.GetNumberOfLabels() == 0:
        return thr

    largest = max(stats.GetLabels(), key=lambda L: stats.GetPhysicalSize(L))
    mask = sitk.BinaryThreshold(cc, lowerThreshold=largest, upperThreshold=largest, insideValue=1, outsideValue=0)

    mask = sitk.BinaryMorphologicalClosing(mask, [2,2,2])
    mask = sitk.Cast(mask, sitk.sitkUInt8)
    return mask

def zscore_normalize(img: sitk.Image, mask: sitk.Image) -> sitk.Image:
    arr = sitk.GetArrayFromImage(img).astype(np.float32)
    m = sitk.GetArrayFromImage(mask).astype(np.uint8) > 0
    vals = arr[m]
    if vals.size < 50:
        return img
    mu = float(vals.mean())
    sd = float(vals.std() + 1e-6)
    arrn = (arr - mu) / sd
    out = sitk.GetImageFromArray(arrn.astype(np.float32))
    out.CopyInformation(img)
    return out

def rigid_register(fixed: sitk.Image, moving: sitk.Image) -> tuple[sitk.Image, sitk.Transform]:
    """
    Rigid registration (moving -> fixed) using mutual information.
    Returns: (resampled_moving, transform)
    """
    initial = sitk.CenteredTransformInitializer(
        fixed, moving, sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY
    )

    reg = sitk.ImageRegistrationMethod()
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=32)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    reg.SetMetricSamplingPercentage(0.2 if FAST_DEBUG else 0.4)
    reg.SetInterpolator(sitk.sitkLinear)

    reg.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=REG_MAX_ITERS,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10
    )
    reg.SetOptimizerScalesFromPhysicalShift()

    reg.SetShrinkFactorsPerLevel(REG_SHRINK_FACTORS)
    reg.SetSmoothingSigmasPerLevel(REG_SMOOTHING_SIGMAS)
    reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    reg.SetInitialTransform(initial, inPlace=False)
    final_transform = reg.Execute(fixed, moving)

    resampled = sitk.Resample(
        moving, fixed, final_transform, sitk.sitkLinear, 0.0, sitk.sitkFloat32
    )
    return resampled, final_transform

def resample_to_reference(moving: sitk.Image, reference: sitk.Image, interp=sitk.sitkLinear) -> sitk.Image:
    identity = sitk.Transform(3, sitk.sitkIdentity)
    return sitk.Resample(moving, reference, identity, interp, 0.0, sitk.sitkFloat32)

print("Utilities ready.")


Utilities ready.


In [None]:
# =========================
# Colab Cell 8: Preprocess paired cohort (register 64mT -> 3T lowres; create masks; normalize)
# Outputs per subject/modality:
#   - fixed_3T_lowres.nii.gz
#   - 3T_highres_resampled_to_lowres.nii.gz
#   - 64mT_registered_to_3T_lowres.nii.gz
#   - brain_mask_fixed.nii.gz
#   - normalized versions of the above
# =========================
# =========================
# DROP-IN PATCH: Replace get_path in Colab Cell 8 with this version
# (so it matches the new manifest schema and handles run as string safely)
# =========================
def get_path(field: str, subject: str, modality: str, session=None, acq=None, run=None):
    df = manifest
    m = (df.field == field) & (df.subject == subject) & (df.modality == modality)

    if session is not None:
        m &= (df.session == session)

    if acq is not None:
        m &= (df.acq == acq)

    if run is not None:
        # run stored as string in some rows; normalize comparison
        m &= (df.run.astype(str) == str(run))

    cand = df.loc[m, "nifti_path"].tolist()
    return cand[0] if cand else None


preproc_index = []

for _, row in tqdm(paired_df.iterrows(), total=len(paired_df), desc="Preprocess paired"):
    sub = row["subject"]
    ses = row["lowfield_session"]
    if ses is None:
        continue

    for mod in USE_MODALITIES:
        if mod not in USE_MODALITIES:
            continue

        # Fixed: 3T lowres
        p_3t_low = get_path("3T", sub, mod, acq="lowres")
        p_3t_high = get_path("3T", sub, mod, acq="highres")

        # Low-field:
        if mod == "ADC":
            # pick run-1 by default for ADC, but you can change to run-2 or average them later
            p_lf = get_path("64mT", sub, "ADC", session=ses, run="1")
            if p_lf is None:
                p_lf = get_path("64mT", sub, "ADC", session=ses, run="2")
        else:
            p_lf = get_path("64mT", sub, mod, session=ses)

        if not (p_3t_low and p_3t_high and p_lf):
            continue

        out_sub_dir = PROC_DIR / sub / mod
        out_sub_dir.mkdir(parents=True, exist_ok=True)

        fixed_path = out_sub_dir / "fixed_3T_lowres.nii.gz"
        high_resamp_path = out_sub_dir / "3T_highres_resampled_to_lowres.nii.gz"
        lf_reg_path = out_sub_dir / "64mT_registered_to_3T_lowres.nii.gz"
        mask_path = out_sub_dir / "brain_mask_fixed.nii.gz"

        fixedN_path = out_sub_dir / "fixed_3T_lowres_norm.nii.gz"
        highN_path  = out_sub_dir / "3T_highres_resampled_to_lowres_norm.nii.gz"
        lfN_path    = out_sub_dir / "64mT_registered_to_3T_lowres_norm.nii.gz"

        tfm_path = out_sub_dir / "rigid_64mT_to_3Tlowres.tfm"

        # Read
        fixed = sitk_read(p_3t_low)
        high  = sitk_read(p_3t_high)
        lf    = sitk_read(p_lf)

        # Save fixed
        if not fixed_path.exists():
            sitk_write(fixed, fixed_path)

        # Mask (fixed space)
        if not mask_path.exists():
            mask = make_brain_mask_simple(fixed)
            sitk_write(mask, mask_path)
        else:
            mask = sitk.ReadImage(str(mask_path))

        # Resample highres to lowres grid (no registration; just match grid)
        if not high_resamp_path.exists():
            high_rs = resample_to_reference(high, fixed, interp=sitk.sitkLinear)
            sitk_write(high_rs, high_resamp_path)
        else:
            high_rs = sitk_read(str(high_resamp_path))

        # Register LF to fixed
        if not lf_reg_path.exists() or not tfm_path.exists():
            lf_reg, tfm = rigid_register(fixed, lf)
            sitk_write(lf_reg, lf_reg_path)
            sitk.WriteTransform(tfm, str(tfm_path))
        else:
            lf_reg = sitk_read(str(lf_reg_path))

        # Normalize all in fixed space
        if not fixedN_path.exists():
            sitk_write(zscore_normalize(fixed, mask), fixedN_path)
        if not highN_path.exists():
            sitk_write(zscore_normalize(high_rs, mask), highN_path)
        if not lfN_path.exists():
            sitk_write(zscore_normalize(lf_reg, mask), lfN_path)

        preproc_index.append({
            "subject": sub,
            "modality": mod,
            "session_64mT": ses,
            "path_fixed_3T_lowres": str(fixed_path),
            "path_3T_highres_resampled": str(high_resamp_path),
            "path_64mT_registered": str(lf_reg_path),
            "path_mask_fixed": str(mask_path),
            "path_fixed_3T_lowres_norm": str(fixedN_path),
            "path_3T_highres_resampled_norm": str(highN_path),
            "path_64mT_registered_norm": str(lfN_path),
            "path_transform": str(tfm_path),
        })

preproc_df = pd.DataFrame(preproc_index)
preproc_df_path = TAB_DIR / "preproc_index.csv"
preproc_df.to_csv(preproc_df_path, index=False)
print("Saved:", preproc_df_path)
display(preproc_df.head(20))


Preprocess paired:   0%|          | 0/10 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/preproc_index.csv


Unnamed: 0,subject,modality,session_64mT,path_fixed_3T_lowres,path_3T_highres_resampled,path_64mT_registered,path_mask_fixed,path_fixed_3T_lowres_norm,path_3T_highres_resampled_norm,path_64mT_registered_norm,path_transform
0,sub-0011,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
1,sub-0015,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
2,sub-0025,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
3,sub-0027,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
4,sub-0035,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
5,sub-0046,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
6,sub-0047,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
7,sub-0048,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
8,sub-0064,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...
9,sub-0066,ADC,ses-01,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...,/content/drive/MyDrive/EMBC_project/Data/proce...


In [None]:
# =========================
# NEW Colab Cell: Registration QC metrics + example overlays
# Run this after the preprocessing cell (Cell 8) finishes and preproc_df exists.
# =========================
from skimage.metrics import structural_similarity as ssim
from skimage.filters import sobel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm

def load_arr_mask(img_path: str, mask_path: str):
    img = sitk_read(img_path)
    arr = sitk.GetArrayFromImage(img).astype(np.float32)
    m = sitk.ReadImage(mask_path)
    mask = sitk.GetArrayFromImage(m).astype(np.uint8) > 0
    return arr, mask

def mid_slice(arr, mask):
    z = arr.shape[0] // 2
    return arr[z], mask[z]

def masked_corr(a2, b2, m2):
    aa = a2[m2].astype(np.float32)
    bb = b2[m2].astype(np.float32)
    if aa.size < 100:
        return np.nan
    aa = aa - aa.mean()
    bb = bb - bb.mean()
    denom = (np.linalg.norm(aa) * np.linalg.norm(bb)) + 1e-8
    return float((aa @ bb) / denom)

def masked_ssim(a2, b2, m2):
    # Compute SSIM on a bounding box around the mask to avoid background dominance
    if m2.sum() < 200:
        return np.nan
    ys, xs = np.where(m2)
    y0, y1 = ys.min(), ys.max()
    x0, x1 = xs.min(), xs.max()
    A = a2[y0:y1+1, x0:x1+1]
    B = b2[y0:y1+1, x0:x1+1]
    # Normalize to [0,1] for SSIM stability
    def norm01(x):
        lo, hi = np.percentile(x, [1, 99])
        x = np.clip(x, lo, hi)
        return (x - lo) / (hi - lo + 1e-6)
    A = norm01(A)
    B = norm01(B)
    return float(ssim(A, B, data_range=1.0))

def overlay_edges(base2, overlay2, m2):
    # edge maps for quick visual alignment check
    b = base2.copy()
    o = overlay2.copy()

    # normalize for display
    def norm01(x):
        lo, hi = np.percentile(x[m2], [1, 99]) if m2.sum() > 0 else np.percentile(x, [1, 99])
        x = np.clip(x, lo, hi)
        return (x - lo) / (hi - lo + 1e-6)

    b = norm01(b)
    o = norm01(o)

    eb = sobel(b)
    eo = sobel(o)

    # create simple RGB overlay: red = fixed edges, green = moving edges
    rgb = np.zeros((b.shape[0], b.shape[1], 3), dtype=np.float32)
    rgb[..., 0] = eb
    rgb[..., 1] = eo
    # dim outside mask
    if m2.sum() > 0:
        rgb[~m2] *= 0.1
    return rgb

qc_rows = []
MAX_QC_SUBJECTS_PER_MODALITY = 6 if not FAST_DEBUG else 3

for mod in sorted(preproc_df["modality"].unique()):
    dfm = preproc_df[preproc_df["modality"] == mod].copy()
    dfm = dfm.head(MAX_QC_SUBJECTS_PER_MODALITY)

    for _, r in tqdm(dfm.iterrows(), total=len(dfm), desc=f"QC {mod}"):
        sub = r["subject"]
        maskp = r["path_mask_fixed"]

        fixedp = r["path_fixed_3T_lowres_norm"]
        highp  = r["path_3T_highres_resampled_norm"]
        lfp    = r["path_64mT_registered_norm"]

        fixed_arr, mask = load_arr_mask(fixedp, maskp)
        high_arr, _ = load_arr_mask(highp, maskp)
        lf_arr, _ = load_arr_mask(lfp, maskp)

        fixed2, m2 = mid_slice(fixed_arr, mask)
        high2, _ = mid_slice(high_arr, mask)
        lf2, _ = mid_slice(lf_arr, mask)

        qc_rows.append({
            "subject": sub,
            "modality": mod,
            "ssim_64mT_vs_3Tlow": masked_ssim(lf2, fixed2, m2),
            "corr_64mT_vs_3Tlow": masked_corr(lf2, fixed2, m2),
            "ssim_3Thigh_vs_3Tlow": masked_ssim(high2, fixed2, m2),
            "corr_3Thigh_vs_3Tlow": masked_corr(high2, fixed2, m2),
        })

        # Save a QC figure per subject and modality
        rgb = overlay_edges(fixed2, lf2, m2)
        diff = (fixed2 - lf2)
        if m2.sum() > 0:
            dvals = diff[m2]
            dlo, dhi = np.percentile(dvals, [1, 99])
        else:
            dlo, dhi = np.percentile(diff, [1, 99])

        fig = plt.figure(figsize=(12, 4))
        ax1 = plt.subplot(1, 3, 1)
        ax1.imshow(fixed2, cmap="gray")
        ax1.set_title(f"{sub} {mod}\nFixed: 3T lowres")
        ax1.axis("off")

        ax2 = plt.subplot(1, 3, 2)
        ax2.imshow(lf2, cmap="gray")
        ax2.set_title("Registered: 64mT")
        ax2.axis("off")

        ax3 = plt.subplot(1, 3, 3)
        ax3.imshow(diff, cmap="gray", vmin=dlo, vmax=dhi)
        ax3.set_title("Difference (fixed - 64mT)")
        ax3.axis("off")

        qc_fig_path = FIG_DIR / "qc_registration" / mod / f"{sub}_{mod}_qc.png"
        qc_fig_path.parent.mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(qc_fig_path, dpi=200)
        plt.close(fig)

        # Also save edge overlay (quick alignment check)
        fig2 = plt.figure(figsize=(5, 5))
        plt.imshow(rgb)
        plt.title(f"{sub} {mod}\nEdges: red=3Tlow, green=64mT")
        plt.axis("off")
        qc_overlay_path = FIG_DIR / "qc_registration" / mod / f"{sub}_{mod}_edge_overlay.png"
        plt.tight_layout()
        plt.savefig(qc_overlay_path, dpi=200)
        plt.close(fig2)

qc_df = pd.DataFrame(qc_rows)
qc_path = TAB_DIR / "registration_qc_metrics.csv"
qc_df.to_csv(qc_path, index=False)

print("Saved QC metrics:", qc_path)
display(qc_df)

print("QC figures saved under:", FIG_DIR / "qc_registration")


QC ADC:   0%|          | 0/6 [00:00<?, ?it/s]

Saved QC metrics: /content/drive/MyDrive/EMBC_project/outputs/tables/registration_qc_metrics.csv


Unnamed: 0,subject,modality,ssim_64mT_vs_3Tlow,corr_64mT_vs_3Tlow,ssim_3Thigh_vs_3Tlow,corr_3Thigh_vs_3Tlow
0,sub-0011,ADC,0.108655,0.035642,0.680127,0.713839
1,sub-0015,ADC,0.23308,0.12217,0.659965,0.620309
2,sub-0025,ADC,0.215972,0.088463,0.532457,0.503962
3,sub-0027,ADC,0.218815,0.022588,0.689063,0.640076
4,sub-0035,ADC,0.040739,0.052656,0.72705,0.803709
5,sub-0046,ADC,0.149886,0.173398,0.74797,0.779418


QC figures saved under: /content/drive/MyDrive/EMBC_project/outputs/figures/qc_registration


In [None]:
# =========================
# Colab Cell 9: Classical features + (optional) radiomics
# =========================
from skimage.feature import graycomatrix, graycoprops

def load_np(img_path: str):
    img = sitk_read(img_path)
    arr = sitk.GetArrayFromImage(img).astype(np.float32)  # z,y,x
    return img, arr

def basic_intensity_features(arr: np.ndarray, mask: np.ndarray) -> dict:
    vals = arr[mask]
    if vals.size < 50:
        return {k: np.nan for k in ["mean","std","p10","p50","p90","skew","kurtosis","energy"]}
    mean = float(vals.mean())
    std  = float(vals.std() + 1e-6)
    p10, p50, p90 = [float(x) for x in np.percentile(vals, [10,50,90])]
    # skew/kurtosis (manual to avoid extra deps)
    z = (vals - mean) / std
    skew = float((z**3).mean())
    kurt = float((z**4).mean() - 3.0)
    energy = float((vals**2).mean())
    return {"mean":mean, "std":std, "p10":p10, "p50":p50, "p90":p90, "skew":skew, "kurtosis":kurt, "energy":energy}

def glcm_features_mid_slice(arr: np.ndarray, mask: np.ndarray, levels=32) -> dict:
    """
    Compute simple 2D GLCM features on the mid-axial slice (fast + stable).
    """
    zmid = arr.shape[0] // 2
    img2 = arr[zmid]
    m2 = mask[zmid]

    if m2.sum() < 100:
        return {f"glcm_{k}": np.nan for k in ["contrast","dissimilarity","homogeneity","ASM","energy","correlation"]}

    vals = img2[m2]
    vmin, vmax = np.percentile(vals, [1, 99])
    imgc = np.clip(img2, vmin, vmax)

    # quantize
    q = ((imgc - vmin) / (vmax - vmin + 1e-6) * (levels-1)).astype(np.uint8)
    # set background to 0 for stability
    q[~m2] = 0

    glcm = graycomatrix(q, distances=[1,2], angles=[0, np.pi/4, np.pi/2], levels=levels, symmetric=True, normed=True)
    feats = {}
    for prop in ["contrast","dissimilarity","homogeneity","ASM","energy","correlation"]:
        feats[f"glcm_{prop}"] = float(graycoprops(glcm, prop).mean())
    return feats

# Radiomics extractor (kept minimal to avoid huge compute)
rad_extractor = None
if featureextractor is not None:
    params = {
        "binWidth": 25,
        "resampledPixelSpacing": None,  # already in fixed space; do not resample here
        "interpolator": "sitkBSpline",
        "normalize": False,
        "removeOutliers": 3,
        "force2D": True,               # faster; use 2D on slices internally
        "force2Ddimension": 0
    }
    rad_extractor = featureextractor.RadiomicsFeatureExtractor(**params)
    # enable just a small set
    rad_extractor.enableFeatureClassByName("firstorder")
    rad_extractor.enableFeatureClassByName("glcm")
    rad_extractor.enableFeatureClassByName("glrlm")

def run_features_for_one(img_path: str, mask_path: str, do_radiomics: bool) -> dict:
    img, arr = load_np(img_path)
    mimg = sitk.ReadImage(mask_path)
    mask = sitk.GetArrayFromImage(mimg).astype(np.uint8) > 0

    out = {}
    out.update(basic_intensity_features(arr, mask))
    out.update(glcm_features_mid_slice(arr, mask))

    if do_radiomics and (rad_extractor is not None):
        try:
            # Pyradiomics wants SimpleITK images
            feats = rad_extractor.execute(sitk_read(img_path), sitk.Cast(mimg, sitk.sitkUInt8))
            # Keep only actual features (skip diagnostics)
            for k, v in feats.items():
                if k.startswith("diagnostics_"):
                    continue
                # many radiomics are numpy types
                try:
                    out[f"rad_{k}"] = float(v)
                except Exception:
                    pass
        except Exception as e:
            out["radiomics_error"] = str(e)

    return out

feature_rows = []
DO_RADIOMICS = True

for _, r in tqdm(preproc_df.iterrows(), total=len(preproc_df), desc="Classical/Radiomics"):
    sub = r["subject"]
    mod = r["modality"]
    maskp = r["path_mask_fixed"]

    # Use normalized images for stability
    conds = {
        "64mT": r["path_64mT_registered_norm"],
        "3T_lowres": r["path_fixed_3T_lowres_norm"],
        "3T_highres": r["path_3T_highres_resampled_norm"],
    }

    for cond, imgp in conds.items():
        feats = run_features_for_one(imgp, maskp, do_radiomics=DO_RADIOMICS)
        feats.update({"subject": sub, "modality": mod, "condition": cond})
        feature_rows.append(feats)

features_df = pd.DataFrame(feature_rows)
features_path = TAB_DIR / "features_classical_radiomics.csv"
features_df.to_csv(features_path, index=False)
print("Saved:", features_path)
display(features_df.head(10))


Classical/Radiomics:   0%|          | 0/10 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/features_classical_radiomics.csv


Unnamed: 0,mean,std,p10,p50,p90,skew,kurtosis,energy,glcm_contrast,glcm_dissimilarity,glcm_homogeneity,glcm_ASM,glcm_energy,glcm_correlation,subject,modality,condition
0,-7.236945e-08,1.000001,-1.695912,0.373543,0.98159,-0.619682,-0.744771,1.0,9.023577,1.188919,0.733748,0.285817,0.534585,0.953102,sub-0011,ADC,64mT
1,-1.612915e-07,1.000001,-0.807003,-0.242787,1.166441,1.57983,5.2342,1.0,11.7742,1.446278,0.72508,0.272687,0.52215,0.848394,sub-0011,ADC,3T_lowres
2,-5.928135e-08,1.000001,-1.513282,-0.115348,1.18358,0.956344,2.537601,1.0,11.768318,1.470994,0.72334,0.290321,0.538775,0.868644,sub-0011,ADC,3T_highres
3,-3.529699e-07,1.000001,-1.221868,0.177662,0.939684,-1.114428,2.130154,1.0,8.636661,1.123967,0.744685,0.289766,0.538266,0.962025,sub-0015,ADC,64mT
4,8.824248e-08,1.000001,-0.777029,-0.258318,1.150763,1.896019,6.967762,1.0,13.178249,1.527951,0.724635,0.270835,0.520374,0.822219,sub-0015,ADC,3T_lowres
5,1.625519e-07,1.000001,-1.370167,-0.04351,1.197443,1.102234,3.177437,1.0,12.368946,1.424588,0.741767,0.309971,0.5567,0.842448,sub-0015,ADC,3T_highres
6,-3.15229e-07,1.000001,-1.214307,0.189217,1.011194,-0.724166,1.309797,1.0,10.565342,1.355277,0.707523,0.246819,0.496768,0.952656,sub-0025,ADC,64mT
7,7.099753e-09,1.000001,-0.893845,-0.234183,1.151762,1.657719,5.393216,1.0,19.186008,1.984051,0.671589,0.230472,0.480028,0.805068,sub-0025,ADC,3T_lowres
8,1.277956e-08,1.000001,-1.31327,-0.055092,1.192535,1.0109,2.396991,1.0,17.223641,1.818216,0.698994,0.278117,0.527304,0.827498,sub-0025,ADC,3T_highres
9,1.448218e-07,1.000001,-1.411177,0.253305,0.958067,-0.872476,0.919041,1.0,8.730052,1.220243,0.722411,0.261494,0.511329,0.963529,sub-0027,ADC,64mT


In [None]:
# =========================
# REPLACEMENT Colab Cell: Learned embeddings (dtype-safe ResNet50)
# =========================
import torch
import torchvision
import numpy as np
import SimpleITK as sitk

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

# Build model and force float32
resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
resnet.fc = torch.nn.Identity()
resnet = resnet.to(device=DEVICE, dtype=torch.float32)
resnet.eval()

# ImageNet normalization tensors (float32)
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE, dtype=torch.float32).view(3,1,1)
IMNET_STD  = torch.tensor([0.229, 0.224, 0.225], device=DEVICE, dtype=torch.float32).view(3,1,1)

def sitk_read_float32(path: str) -> sitk.Image:
    img = sitk.ReadImage(path)
    return sitk.Cast(img, sitk.sitkFloat32)

@torch.no_grad()
def volume_to_resnet_embedding(img_path: str, mask_path: str, n_slices=32, target_hw=224) -> np.ndarray:
    img = sitk_read_float32(img_path)
    arr = sitk.GetArrayFromImage(img).astype(np.float32)  # z,y,x

    mimg = sitk.ReadImage(mask_path)
    mask = (sitk.GetArrayFromImage(mimg).astype(np.uint8) > 0)

    z = arr.shape[0]
    mid = z // 2
    half = n_slices // 2
    z0 = max(0, mid - half)
    z1 = min(z, mid + half)

    slcs = []
    for zi in range(z0, z1):
        m2 = mask[zi]
        if m2.sum() < 50:
            continue

        s = arr[zi].astype(np.float32)

        # robust scale within mask
        vals = s[m2]
        vmin, vmax = np.percentile(vals, [1, 99])
        s = np.clip(s, vmin, vmax)
        s = (s - vmin) / (vmax - vmin + 1e-6)
        s = s.astype(np.float32)

        # to torch float32 on DEVICE
        t = torch.from_numpy(s).to(device=DEVICE, dtype=torch.float32)  # H,W
        t = t.unsqueeze(0).unsqueeze(0)  # 1,1,H,W

        t = torch.nn.functional.interpolate(
            t, size=(target_hw, target_hw), mode="bilinear", align_corners=False
        )  # 1,1,224,224
        t = t.squeeze(0)  # 1,224,224

        t3 = t.repeat(3, 1, 1)  # 3,224,224
        t3 = (t3 - IMNET_MEAN) / IMNET_STD
        slcs.append(t3)

    if len(slcs) == 0:
        return np.full((2048,), np.nan, dtype=np.float32)

    batch = torch.stack(slcs, dim=0).to(dtype=torch.float32)  # N,3,224,224
    feats = resnet(batch).float()  # N,2048
    emb = feats.mean(dim=0).detach().cpu().numpy().astype(np.float32)
    return emb

print("ResNet embedding function ready (dtype-safe).")


DEVICE: cuda
ResNet embedding function ready (dtype-safe).


In [None]:
# =========================
# Colab Cell 11: Representation-level comparisons (paired distances + stats)
#   - Reload embeddings if they exist on disk (recommended)
#   - Otherwise compute embeddings from preproc_df using volume_to_resnet_embedding
#   - Then compute paired distances + paired stats and save outputs
# =========================

# --- Imports this cell depends on (safe to re-run) ---
import numpy as np
import pandas as pd

from scipy.spatial.distance import cosine as cosine_distance
from scipy.stats import wilcoxon, ttest_rel
from tqdm.auto import tqdm

# --- Paths to embedding artifacts ---
embed_meta_path = TAB_DIR / "embeddings_resnet50_meta.csv"
embed_npy_path  = OUT_DIR / "embeddings_resnet50.npy"

# --- Load if available; else compute and save ---
if embed_meta_path.exists() and embed_npy_path.exists():
    print("Loading embeddings from disk...")
    embed_meta = pd.read_csv(embed_meta_path)
    embeddings = np.load(embed_npy_path)
else:
    print("Embeddings not found on disk; computing now...")

    # Preconditions: preproc_df exists and volume_to_resnet_embedding is defined
    assert "preproc_df" in globals(), "preproc_df not found. Run preprocessing cell first."
    assert "volume_to_resnet_embedding" in globals(), "volume_to_resnet_embedding not found. Run embedding model cell first."

    embed_rows = []
    embed_mat = []

    for _, r in tqdm(preproc_df.iterrows(), total=len(preproc_df), desc="Embeddings"):
        sub = r["subject"]
        mod = r["modality"]
        maskp = r["path_mask_fixed"]

        conds = {
            "64mT": r["path_64mT_registered_norm"],
            "3T_lowres": r["path_fixed_3T_lowres_norm"],
            "3T_highres": r["path_3T_highres_resampled_norm"],
        }

        for cond, imgp in conds.items():
            emb = volume_to_resnet_embedding(imgp, maskp, n_slices=16 if FAST_DEBUG else 32)
            embed_rows.append({"subject": sub, "modality": mod, "condition": cond})
            embed_mat.append(emb)

    embeddings = np.vstack(embed_mat).astype(np.float32)
    embed_meta = pd.DataFrame(embed_rows)

    embed_meta.to_csv(embed_meta_path, index=False)
    np.save(embed_npy_path, embeddings)

    print("Saved embeddings:")
    print(" -", embed_meta_path)
    print(" -", embed_npy_path)

# --- Sanity checks (catch silent issues early) ---
assert embeddings.shape[0] == len(embed_meta), "Row mismatch: embeddings rows != embed_meta rows"
conds = set(embed_meta["condition"].unique().tolist())
expected = {"64mT", "3T_lowres", "3T_highres"}
assert expected.issubset(conds), f"Missing conditions. Found: {sorted(list(conds))}"

print("embed_meta shape:", embed_meta.shape)
print("embeddings shape:", embeddings.shape)
display(embed_meta.head(5))

# --- Distance table ---
def paired_distance_table(embed_meta: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame:
    df = embed_meta.copy()
    df["idx"] = np.arange(len(df))

    out = []
    for (sub, mod), g in df.groupby(["subject", "modality"]):
        idx = {row["condition"]: int(row["idx"]) for _, row in g.iterrows()}
        needed = ["64mT", "3T_lowres", "3T_highres"]
        if not all(k in idx for k in needed):
            continue

        v_lf = embeddings[idx["64mT"]]
        v_lo = embeddings[idx["3T_lowres"]]
        v_hi = embeddings[idx["3T_highres"]]

        def cos_sim(a, b):
            if np.any(np.isnan(a)) or np.any(np.isnan(b)):
                return np.nan
            return 1.0 - cosine_distance(a, b)

        def l2(a, b):
            if np.any(np.isnan(a)) or np.any(np.isnan(b)):
                return np.nan
            return float(np.linalg.norm(a - b))

        out.append({
            "subject": sub,
            "modality": mod,

            # field-strength effect (resolution matched): 64mT vs 3T_lowres
            "cos_lf_vs_3Tlow": cos_sim(v_lf, v_lo),
            "l2_lf_vs_3Tlow":  l2(v_lf, v_lo),

            # resolution effect at 3T: lowres vs highres
            "cos_3Tlow_vs_3Thigh": cos_sim(v_lo, v_hi),
            "l2_3Tlow_vs_3Thigh":  l2(v_lo, v_hi),

            # combined effect: 64mT vs 3T_highres
            "cos_lf_vs_3Thigh": cos_sim(v_lf, v_hi),
            "l2_lf_vs_3Thigh":  l2(v_lf, v_hi),
        })

    return pd.DataFrame(out)

dist_df = paired_distance_table(embed_meta, embeddings)
dist_path = TAB_DIR / "paired_embedding_distances.csv"
dist_df.to_csv(dist_path, index=False)

print("Saved:", dist_path)
display(dist_df.head(20))

# --- Paired stats ---
def paired_stats(df: pd.DataFrame, colA: str, colB: str, label: str):
    a = df[colA].astype(float).values
    b = df[colB].astype(float).values
    m = np.isfinite(a) & np.isfinite(b)
    a, b = a[m], b[m]

    if len(a) < 5:
        return {
            "label": label,
            "n": int(len(a)),
            "wilcoxon_p": np.nan,
            "ttest_p": np.nan,
            "mean_A": float(np.mean(a)) if len(a) else np.nan,
            "mean_B": float(np.mean(b)) if len(b) else np.nan
        }

    try:
        w = wilcoxon(a, b).pvalue
    except Exception:
        w = np.nan

    try:
        t = ttest_rel(a, b).pvalue
    except Exception:
        t = np.nan

    return {
        "label": label,
        "n": int(len(a)),
        "wilcoxon_p": float(w) if w is not None else np.nan,
        "ttest_p": float(t) if t is not None else np.nan,
        "mean_A": float(np.mean(a)),
        "mean_B": float(np.mean(b)),
    }

stats_rows = []
stats_rows.append(paired_stats(
    dist_df,
    "l2_lf_vs_3Tlow",
    "l2_3Tlow_vs_3Thigh",
    "L2: (64mT vs 3T_lowres) vs (3T_lowres vs 3T_highres)"
))
stats_rows.append(paired_stats(
    dist_df,
    "cos_lf_vs_3Tlow",
    "cos_3Tlow_vs_3Thigh",
    "Cosine: (64mT vs 3T_lowres) vs (3T_lowres vs 3T_highres)"
))

stats_df = pd.DataFrame(stats_rows)
stats_path = TAB_DIR / "paired_stats_summary.csv"
stats_df.to_csv(stats_path, index=False)

print("Saved:", stats_path)
display(stats_df)


Loading embeddings from disk...
embed_meta shape: (9, 3)
embeddings shape: (9, 2048)


Unnamed: 0,subject,modality,condition
0,sub-0011,ADC,64mT
1,sub-0011,ADC,3T_lowres
2,sub-0011,ADC,3T_highres
3,sub-0015,ADC,64mT
4,sub-0015,ADC,3T_lowres


Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/paired_embedding_distances.csv


Unnamed: 0,subject,modality,cos_lf_vs_3Tlow,l2_lf_vs_3Tlow,cos_3Tlow_vs_3Thigh,l2_3Tlow_vs_3Thigh,cos_lf_vs_3Thigh,l2_lf_vs_3Thigh
0,sub-0011,ADC,0.866413,5.985075,0.982775,1.972973,0.8764,5.766547
1,sub-0015,ADC,0.885738,5.456633,0.972881,2.451736,0.888034,5.405859
2,sub-0025,ADC,0.883664,5.567214,0.961624,2.946868,0.88404,5.554993


Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/paired_stats_summary.csv


Unnamed: 0,label,n,wilcoxon_p,ttest_p,mean_A,mean_B
0,L2: (64mT vs 3T_lowres) vs (3T_lowres vs 3T_hi...,3,,,5.669641,2.457192
1,Cosine: (64mT vs 3T_lowres) vs (3T_lowres vs 3...,3,,,0.878605,0.972426


In [None]:
# =========================
# Colab Cell 12: Visualizations + protocol-style summaries (robust + no-crash)
#   - Uses tick_labels (Matplotlib 3.9+)
#   - Skips empty boxplots safely
#   - Robust UMAP: auto-picks a modality with enough valid embeddings
#   - Saves figures + summary CSV
# =========================
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
import umap

def save_fig(path: pathlib.Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()

# -------------------------
# 1) Boxplots of distances per modality
# -------------------------
for mod in sorted(dist_df["modality"].dropna().unique().tolist()):
    d = dist_df[dist_df["modality"] == mod].copy()

    a = d["l2_lf_vs_3Tlow"].dropna().values
    b = d["l2_3Tlow_vs_3Thigh"].dropna().values
    c = d["l2_lf_vs_3Thigh"].dropna().values

    # If everything is empty (possible in tiny debug runs), skip
    if len(a) == 0 and len(b) == 0 and len(c) == 0:
        print(f"[skip] No distance values available for modality={mod}")
        continue

    plt.figure(figsize=(8, 4))
    plt.boxplot(
        [a, b, c],
        tick_labels=["64mT vs 3T_low", "3T_low vs 3T_high", "64mT vs 3T_high"],
        showmeans=True
    )
    plt.title(f"Embedding L2 distances (ResNet50) — {mod}")
    plt.ylabel("L2 distance")

    outp = FIG_DIR / f"box_l2_{mod}.png"
    save_fig(outp)
    print("Saved:", outp)

# -------------------------
# 2) UMAP embedding visualization (auto-select modality w/ enough valid rows)
# -------------------------
def pick_modality_for_umap(embed_meta: pd.DataFrame, embeddings: np.ndarray, preferred=None, min_rows=6):
    mods = []
    for mod in sorted(embed_meta["modality"].dropna().unique().tolist()):
        sel = (embed_meta["modality"] == mod).values
        X = embeddings[sel]
        if X.size == 0:
            continue
        good = np.all(np.isfinite(X), axis=1)
        n_good = int(good.sum())
        mods.append((mod, n_good))

    # try preferred first
    if preferred is not None:
        for mod, n_good in mods:
            if mod == preferred and n_good >= min_rows:
                return mod

    # otherwise pick best available
    mods = sorted(mods, key=lambda x: x[1], reverse=True)
    if not mods:
        return None
    return mods[0][0]

# Prefer T1w if available; otherwise best modality
preferred = "T1w" if "T1w" in embed_meta["modality"].unique() else None
PLOT_MODALITY = pick_modality_for_umap(embed_meta, embeddings, preferred=preferred, min_rows=3)

if PLOT_MODALITY is None:
    print("[skip] No modalities found for UMAP.")
else:
    sel = (embed_meta["modality"] == PLOT_MODALITY).values
    X = embeddings[sel]
    M = embed_meta.loc[sel].reset_index(drop=True)

    good = np.all(np.isfinite(X), axis=1)
    Xg = X[good]
    Mg = M.loc[good].reset_index(drop=True)

    print("UMAP modality:", PLOT_MODALITY)
    print("Raw rows:", len(M), " | Valid rows:", len(Mg))
    print("Valid rows by condition:")
    display(Mg["condition"].value_counts())

    if len(Mg) < 3:
        print("[skip] Not enough valid embeddings to run UMAP. "
              "Try FAST_DEBUG=False, increase n_slices, or reduce mask threshold in embedding extraction.")
    else:
        Xz = StandardScaler().fit_transform(Xg)
        n_neighbors = min(10, len(Mg) - 1)
        um = umap.UMAP(n_neighbors=n_neighbors, min_dist=0.2, random_state=SEED).fit_transform(Xz)

        plt.figure(figsize=(7, 6))
        for cond in ["64mT", "3T_lowres", "3T_highres"]:
            idx = (Mg["condition"] == cond).values
            if idx.sum() > 0:
                plt.scatter(um[idx, 0], um[idx, 1], label=cond, alpha=0.8)

        plt.title(f"UMAP of embeddings — {PLOT_MODALITY}")
        plt.legend()

        outp = FIG_DIR / f"umap_embeddings_{PLOT_MODALITY}.png"
        save_fig(outp)
        print("Saved:", outp)

print("Saved figures under:", FIG_DIR)

# -------------------------
# 3) “Protocol guidance” style summary table
# -------------------------
summary = dist_df.groupby("modality").agg(
    mean_cos_lf_vs_low=("cos_lf_vs_3Tlow", "mean"),
    mean_cos_low_vs_high=("cos_3Tlow_vs_3Thigh", "mean"),
    mean_cos_lf_vs_high=("cos_lf_vs_3Thigh", "mean"),
    mean_l2_lf_vs_low=("l2_lf_vs_3Tlow", "mean"),
    mean_l2_low_vs_high=("l2_3Tlow_vs_3Thigh", "mean"),
    mean_l2_lf_vs_high=("l2_lf_vs_3Thigh", "mean"),
    n=("subject", "nunique")
).reset_index()

summary_path = TAB_DIR / "protocol_guidance_summary.csv"
summary.to_csv(summary_path, index=False)
print("Saved:", summary_path)
display(summary)


Saved: /content/drive/MyDrive/EMBC_project/outputs/figures/box_l2_ADC.png
UMAP modality: ADC
Raw rows: 9  | Valid rows: 9
Valid rows by condition:


Unnamed: 0_level_0,count
condition,Unnamed: 1_level_1
64mT,3
3T_lowres,3
3T_highres,3


  warn(


Saved: /content/drive/MyDrive/EMBC_project/outputs/figures/umap_embeddings_ADC.png
Saved figures under: /content/drive/MyDrive/EMBC_project/outputs/figures
Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/protocol_guidance_summary.csv


Unnamed: 0,modality,mean_cos_lf_vs_low,mean_cos_low_vs_high,mean_cos_lf_vs_high,mean_l2_lf_vs_low,mean_l2_low_vs_high,mean_l2_lf_vs_high,n
0,ADC,0.878605,0.972426,0.882825,5.669641,2.457192,5.5758,3


In [None]:
# =========================
# Colab Cell 13: (Optional) Compare classical/radiomics differences similarly
#   This gives you a parallel story to “learned representations”
# =========================
# Build per-subject/modality vectors from the classical/radiomics table
df = features_df.copy()

# pick a manageable feature subset (avoid super high-dimensional radiomics unless you want it)
cols_basic = [c for c in df.columns if c.startswith(("mean","std","p","skew","kurtosis","energy","glcm_"))]
cols_rad = [c for c in df.columns if c.startswith("rad_")]

USE_RADIOMICS_FEATURES = False
feat_cols = cols_basic + (cols_rad if USE_RADIOMICS_FEATURES else [])

# pivot into vectors
piv = df.pivot_table(index=["subject","modality","condition"], values=feat_cols, aggfunc="mean")
piv = piv.reset_index()

def vector_from_row(row):
    v = row[feat_cols].values.astype(np.float32)
    v = np.nan_to_num(v, nan=0.0, posinf=0.0, neginf=0.0)
    return v

# scale features globally
X = piv[feat_cols].values.astype(np.float32)
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
Xz = StandardScaler().fit_transform(X)
piv_scaled = piv[["subject","modality","condition"]].copy()
piv_scaled["idx"] = np.arange(len(piv))
V = Xz

# compute paired distances
out = []
for (sub, mod), g in piv_scaled.groupby(["subject","modality"]):
    idx = {row["condition"]: int(row["idx"]) for _, row in g.iterrows()}
    if not all(k in idx for k in ["64mT","3T_lowres","3T_highres"]):
        continue
    v_lf, v_lo, v_hi = V[idx["64mT"]], V[idx["3T_lowres"]], V[idx["3T_highres"]]

    out.append({
        "subject": sub, "modality": mod,
        "cos_lf_vs_3Tlow": 1.0 - cosine_distance(v_lf, v_lo),
        "cos_3Tlow_vs_3Thigh": 1.0 - cosine_distance(v_lo, v_hi),
        "cos_lf_vs_3Thigh": 1.0 - cosine_distance(v_lf, v_hi),
        "l2_lf_vs_3Tlow": float(np.linalg.norm(v_lf-v_lo)),
        "l2_3Tlow_vs_3Thigh": float(np.linalg.norm(v_lo-v_hi)),
        "l2_lf_vs_3Thigh": float(np.linalg.norm(v_lf-v_hi)),
    })

classical_dist = pd.DataFrame(out)
classical_dist_path = TAB_DIR / "paired_classical_feature_distances.csv"
classical_dist.to_csv(classical_dist_path, index=False)
print("Saved:", classical_dist_path)
display(classical_dist.head(20))


Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/paired_classical_feature_distances.csv


Unnamed: 0,subject,modality,cos_lf_vs_3Tlow,cos_3Tlow_vs_3Thigh,cos_lf_vs_3Thigh,l2_lf_vs_3Tlow,l2_3Tlow_vs_3Thigh,l2_lf_vs_3Thigh
0,sub-0011,ADC,0.038735,-0.325679,0.100848,5.791967,4.309099,4.842553
1,sub-0015,ADC,-0.245998,0.068639,-0.671535,6.365081,3.775392,5.792776
2,sub-0025,ADC,-0.270169,0.801204,-0.161072,6.299749,2.620909,4.901084
3,sub-0027,ADC,-0.622071,0.780207,-0.231582,6.45711,2.369395,4.484468
4,sub-0035,ADC,-0.446457,0.375264,0.476539,8.157922,3.912604,4.975733
5,sub-0046,ADC,-0.391467,0.849004,-0.078661,5.562768,1.525688,4.684873
6,sub-0047,ADC,0.0833,0.497573,0.240078,4.15535,1.937022,3.637206
7,sub-0048,ADC,-0.341023,0.453943,-0.278359,5.852177,3.773515,4.53108
8,sub-0064,ADC,-0.138208,0.394725,0.223038,10.189916,3.537109,9.177953
9,sub-0066,ADC,-0.55851,0.267487,0.11093,6.703815,3.798806,4.101384


In [None]:
# =========================
# ADD-ON 1 (FIXED): Localized information-loss maps (patchwise embedding drift)
#   - Resizes every patch to a fixed PATCH_HW x PATCH_HW so batching works
#   - Uses cosine distance drift between 64mT and 3T_lowres
#   - Saves PNGs + a CSV summary
# =========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import torch

assert "preproc_df" in globals(), "preproc_df not found"
assert "resnet" in globals(), "resnet model not found"
assert "IMNET_MEAN" in globals() and "IMNET_STD" in globals(), "ImageNet norm tensors not found"
assert "DEVICE" in globals(), "DEVICE not found"

PATCH_HW = 96          # fixed patch size before upsampling to 224
GRID = (10, 10)        # coarse brain grid
MIN_MASK_FRAC = 0.15   # how much brain must be in a patch to include it
BATCH_SIZE = 64

def _norm01_masked(x2, m2):
    x = x2.astype(np.float32).copy()
    if m2.sum() > 0:
        vals = x[m2]
        lo, hi = np.percentile(vals, [1, 99])
    else:
        lo, hi = np.percentile(x, [1, 99])
    x = np.clip(x, lo, hi)
    x = (x - lo) / (hi - lo + 1e-6)
    return x.astype(np.float32)

def _resize_patch_to_hw(patch2d, hw=96):
    # patch2d: (h,w) float32
    t = torch.from_numpy(patch2d).to(device=DEVICE, dtype=torch.float32)[None, None, :, :]  # 1,1,h,w
    t = torch.nn.functional.interpolate(t, size=(hw, hw), mode="bilinear", align_corners=False)
    return t[0,0].detach().cpu().numpy().astype(np.float32)

@torch.no_grad()
def _patch_emb_batch(patches_01_float32_hw):
    # patches_01_float32_hw: (N, PATCH_HW, PATCH_HW) in [0,1]
    if len(patches_01_float32_hw) == 0:
        return np.zeros((0, 2048), dtype=np.float32)
    t = torch.from_numpy(patches_01_float32_hw).to(device=DEVICE, dtype=torch.float32)  # N,hw,hw
    t = t.unsqueeze(1)  # N,1,hw,hw
    t = torch.nn.functional.interpolate(t, size=(224, 224), mode="bilinear", align_corners=False)  # N,1,224,224
    t = t.repeat(1, 3, 1, 1)  # N,3,224,224
    t = (t - IMNET_MEAN.unsqueeze(0)) / IMNET_STD.unsqueeze(0)
    feats = resnet(t).float().detach().cpu().numpy().astype(np.float32)  # N,2048
    return feats

def _cos_dist(a, b):
    na = np.linalg.norm(a) + 1e-8
    nb = np.linalg.norm(b) + 1e-8
    return float(1.0 - (a @ b) / (na * nb))

def localized_drift_map_for_row(row, grid=GRID, patch_hw=PATCH_HW, min_mask_frac=MIN_MASK_FRAC, batch_size=BATCH_SIZE):
    p64 = row["path_64mT_registered_norm"]
    p3  = row["path_fixed_3T_lowres_norm"]
    pm  = row["path_mask_fixed"]

    img64 = sitk_read(p64); a64 = sitk.GetArrayFromImage(img64).astype(np.float32)
    img3  = sitk_read(p3);  a3  = sitk.GetArrayFromImage(img3).astype(np.float32)
    mimg  = sitk.ReadImage(pm); mask = (sitk.GetArrayFromImage(mimg).astype(np.uint8) > 0)

    z = a3.shape[0] // 2
    s64 = a64[z]; s3 = a3[z]; m2 = mask[z]
    if m2.sum() < 200:
        return None

    # Normalize for patch extraction
    s64n = _norm01_masked(s64, m2)
    s3n  = _norm01_masked(s3, m2)

    ys, xs = np.where(m2)
    y0, y1 = ys.min(), ys.max()
    x0, x1 = xs.min(), xs.max()

    gh, gw = grid
    H = max(1, (y1 - y0 + 1) // gh)
    W = max(1, (x1 - x0 + 1) // gw)

    keep_coords = []
    p64_list = []
    p3_list = []

    for i in range(gh):
        for j in range(gw):
            yy0 = y0 + i * H
            yy1 = y0 + (i + 1) * H if i < gh - 1 else (y1 + 1)
            xx0 = x0 + j * W
            xx1 = x0 + (j + 1) * W if j < gw - 1 else (x1 + 1)

            pmask = m2[yy0:yy1, xx0:xx1]
            if pmask.size == 0:
                continue
            if pmask.mean() < min_mask_frac:
                continue

            patch64 = s64n[yy0:yy1, xx0:xx1]
            patch3  = s3n[yy0:yy1, xx0:xx1]

            # Resize each patch to a fixed size so we can batch them
            patch64r = _resize_patch_to_hw(patch64, hw=patch_hw)
            patch3r  = _resize_patch_to_hw(patch3, hw=patch_hw)

            keep_coords.append((i, j, yy0, yy1, xx0, xx1))
            p64_list.append(patch64r)
            p3_list.append(patch3r)

    if len(keep_coords) == 0:
        return None

    p64_arr = np.stack(p64_list, axis=0).astype(np.float32)  # N,patch_hw,patch_hw
    p3_arr  = np.stack(p3_list, axis=0).astype(np.float32)

    # Embed in batches and compute drift
    drift_vals = np.zeros((len(keep_coords),), dtype=np.float32)
    for k in range(0, len(keep_coords), batch_size):
        e64 = _patch_emb_batch(p64_arr[k:k+batch_size])
        e3  = _patch_emb_batch(p3_arr[k:k+batch_size])
        for ii in range(e64.shape[0]):
            drift_vals[k + ii] = _cos_dist(e64[ii], e3[ii])

    # Fill drift grid
    drift_grid = np.full((gh, gw), np.nan, dtype=np.float32)
    for idx, (i, j, yy0, yy1, xx0, xx1) in enumerate(keep_coords):
        drift_grid[i, j] = drift_vals[idx]

    return {
        "slice_fixed": s3n,
        "mask": m2,
        "drift_grid": drift_grid,
        "bbox": (y0, y1, x0, x1)
    }

# ---- Run and save figures ----
OUT_DIR_LOC = FIG_DIR / "localized_drift"
OUT_DIR_LOC.mkdir(parents=True, exist_ok=True)

MAX_SUBJECTS_PER_MODALITY = 3 if FAST_DEBUG else 8
rows_out = []

for mod in sorted(preproc_df["modality"].unique().tolist()):
    dfm = preproc_df[preproc_df["modality"] == mod].copy().head(MAX_SUBJECTS_PER_MODALITY)

    for _, r in tqdm(dfm.iterrows(), total=len(dfm), desc=f"Localized drift {mod}"):
        res = localized_drift_map_for_row(r)
        if res is None:
            continue

        sub = r["subject"]
        drift = res["drift_grid"]
        y0, y1, x0, x1 = res["bbox"]

        plt.figure(figsize=(10,4))
        ax1 = plt.subplot(1,2,1)
        ax1.imshow(res["slice_fixed"], cmap="gray")
        ax1.set_title(f"{sub} {mod}\n3T_low (fixed) mid-slice")
        ax1.axis("off")

        ax2 = plt.subplot(1,2,2)
        ax2.imshow(res["slice_fixed"], cmap="gray")
        im = ax2.imshow(drift, cmap="magma", alpha=0.75)
        ax2.set_title("Patchwise embedding drift\ncosine distance (64mT vs 3T_low)")
        ax2.axis("off")
        plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)

        outp = OUT_DIR_LOC / mod / f"{sub}_{mod}_localized_drift.png"
        outp.parent.mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(outp, dpi=200)
        plt.close()

        rows_out.append({
            "subject": sub,
            "modality": mod,
            "mean_patch_drift": float(np.nanmean(drift)),
            "p90_patch_drift": float(np.nanpercentile(drift[np.isfinite(drift)], 90)) if np.isfinite(drift).any() else np.nan
        })

loc_df = pd.DataFrame(rows_out)
loc_path = TAB_DIR / "localized_patch_drift_summary.csv"
loc_df.to_csv(loc_path, index=False)

print("Saved localized drift figures under:", OUT_DIR_LOC)
print("Saved:", loc_path)
display(loc_df.head(10))


Localized drift ADC:   0%|          | 0/8 [00:00<?, ?it/s]

Saved localized drift figures under: /content/drive/MyDrive/EMBC_project/outputs/figures/localized_drift
Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/localized_patch_drift_summary.csv


Unnamed: 0,subject,modality,mean_patch_drift,p90_patch_drift
0,sub-0011,ADC,0.1637,0.258733
1,sub-0015,ADC,0.177719,0.260561
2,sub-0025,ADC,0.183457,0.312478
3,sub-0027,ADC,0.194557,0.310124
4,sub-0035,ADC,0.272544,0.525966
5,sub-0046,ADC,0.189916,0.311471
6,sub-0047,ADC,0.194292,0.316119
7,sub-0048,ADC,0.127836,0.194608


In [None]:
# =========================
# ADD-ON 2: Field Dominance Index (FDI) + plot
#   FDI = field_effect / (field_effect + resolution_effect)
#   Uses L2 distances by default (can also do cosine by switching columns)
# =========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

assert "dist_df" in globals(), "dist_df not found"

fdi_df = dist_df.copy()
eps = 1e-8
fdi_df["FDI_L2"] = fdi_df["l2_lf_vs_3Tlow"] / (fdi_df["l2_lf_vs_3Tlow"] + fdi_df["l2_3Tlow_vs_3Thigh"] + eps)
fdi_df["FDI_COS"] = (1 - fdi_df["cos_lf_vs_3Tlow"]) / ((1 - fdi_df["cos_lf_vs_3Tlow"]) + (1 - fdi_df["cos_3Tlow_vs_3Thigh"]) + eps)

fdi_sum = fdi_df.groupby("modality").agg(
    n=("subject", "nunique"),
    mean_FDI_L2=("FDI_L2", "mean"),
    std_FDI_L2=("FDI_L2", "std"),
    mean_FDI_COS=("FDI_COS", "mean"),
    std_FDI_COS=("FDI_COS", "std"),
).reset_index()

fdi_path = TAB_DIR / "field_dominance_index.csv"
fdi_sum.to_csv(fdi_path, index=False)
print("Saved:", fdi_path)
display(fdi_sum)

# Plot (FDI_L2)
plt.figure(figsize=(8,4))
x = np.arange(len(fdi_sum))
plt.bar(x, fdi_sum["mean_FDI_L2"].values)
plt.xticks(x, fdi_sum["modality"].values, rotation=0)
plt.ylim(0, 1)
plt.ylabel("FDI (L2-based)")
plt.title("Field Dominance Index by modality\nhigher means degradation is mostly field-driven")
outp = FIG_DIR / "field_dominance_index_bar.png"
outp.parent.mkdir(parents=True, exist_ok=True)
plt.tight_layout()
plt.savefig(outp, dpi=200)
plt.close()
print("Saved:", outp)


Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/field_dominance_index.csv


Unnamed: 0,modality,n,mean_FDI_L2,std_FDI_L2,mean_FDI_COS,std_FDI_COS
0,ADC,3,0.698648,0.049668,0.815305,0.0672


Saved: /content/drive/MyDrive/EMBC_project/outputs/figures/field_dominance_index_bar.png


In [None]:
# =========================
# ADD-ON 3: Representation Similarity (RSA + linear CKA) across feature families
#   - Aligns samples by (subject, modality, condition)
#   - Works with:
#       (a) ResNet embeddings (embed_meta + embeddings)
#       (b) Classical features if you have a dataframe named classic_df or feat_df
#   - Outputs per-modality heatmaps to FIG_DIR / "rep_similarity"
# =========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

assert "embed_meta" in globals() and "embeddings" in globals(), "Embeddings not found (embed_meta/embeddings)"

# Try to find your classical features dataframe automatically
# Expected columns: subject, modality, condition + numeric feature columns
CLASSICAL_DF_NAME = None
for cand in ["classic_df", "feat_df", "features_df", "classical_df"]:
    if cand in globals() and isinstance(globals()[cand], pd.DataFrame):
        CLASSICAL_DF_NAME = cand
        break

classic = globals()[CLASSICAL_DF_NAME] if CLASSICAL_DF_NAME else None
if classic is None:
    print("No classical feature dataframe found. This will run RSA/CKA only for ResNet (single-family).")
else:
    print("Using classical feature dataframe:", CLASSICAL_DF_NAME)

def _key_df(df):
    return df[["subject","modality","condition"]].astype(str).agg("|".join, axis=1)

# Build aligned ResNet table
E_meta = embed_meta.copy()
E_meta["key"] = _key_df(E_meta)
E = pd.DataFrame(embeddings, index=E_meta["key"].values)
E.index.name = "key"

# Build aligned classical features table if available
C = None
if classic is not None:
    C0 = classic.copy()
    for col in ["subject","modality","condition"]:
        C0[col] = C0[col].astype(str)
    C0["key"] = _key_df(C0)
    num_cols = [c for c in C0.columns if c not in ["subject","modality","condition","key"] and np.issubdtype(C0[c].dtype, np.number)]
    C = C0.set_index("key")[num_cols].copy()
    C = C.replace([np.inf, -np.inf], np.nan)

def linear_cka(X, Y):
    # X: (n,d), Y:(n,k)
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    dot = np.linalg.norm(X.T @ Y, ord="fro") ** 2
    normx = np.linalg.norm(X.T @ X, ord="fro")
    normy = np.linalg.norm(Y.T @ Y, ord="fro")
    denom = (normx * normy) + 1e-12
    return float(dot / (denom ** 2))

def rsa_spearman(X):
    # distance matrix upper triangle vector
    from scipy.spatial.distance import pdist
    from scipy.stats import spearmanr
    d = pdist(X, metric="euclidean")
    return d

OUT_RS = FIG_DIR / "rep_similarity"
OUT_RS.mkdir(parents=True, exist_ok=True)

modalities = sorted(E_meta["modality"].unique().tolist())
rows_summary = []

for mod in modalities:
    # Keys present for this modality
    keys_mod = E_meta.loc[E_meta["modality"] == mod, "key"].values
    if len(keys_mod) < 6:
        continue

    # Intersect across families
    keys = set(keys_mod)
    families = {"resnet": E.loc[keys_mod].copy()}
    if C is not None:
        keys = keys.intersection(set(C.index.values))
    keys = sorted(list(keys))

    if len(keys) < 6:
        continue

    Xr = E.loc[keys].values.astype(np.float32)
    Xr = Xr[np.all(np.isfinite(Xr), axis=1)]
    if len(Xr) < 6:
        continue
    families = {"resnet": Xr}

    if C is not None:
        Xc = C.loc[keys].values.astype(np.float32)
        good = np.all(np.isfinite(Xc), axis=1) & np.all(np.isfinite(E.loc[keys].values.astype(np.float32)), axis=1)
        Xc = Xc[good]
        Xr2 = E.loc[keys].values.astype(np.float32)[good]
        if len(Xc) >= 6:
            families = {"resnet": Xr2, "classical": Xc}

    # Standardize each family
    fam_names = list(families.keys())
    fam_mats = {}
    for name in fam_names:
        X = families[name]
        Xz = StandardScaler().fit_transform(X)
        fam_mats[name] = Xz

    # Compute CKA and RSA correlations
    n = len(fam_names)
    cka_mat = np.zeros((n,n), dtype=np.float32)
    rsa_mat = np.zeros((n,n), dtype=np.float32)

    # Precompute RSA vectors
    rsa_vecs = {name: rsa_spearman(fam_mats[name]) for name in fam_names}

    from scipy.stats import spearmanr
    for i,a in enumerate(fam_names):
        for j,b in enumerate(fam_names):
            cka_mat[i,j] = linear_cka(fam_mats[a], fam_mats[b])
            rho = spearmanr(rsa_vecs[a], rsa_vecs[b]).correlation
            rsa_mat[i,j] = float(rho)

    # Save heatmaps
    plt.figure(figsize=(5,4))
    plt.imshow(cka_mat, vmin=0, vmax=1)
    plt.xticks(range(n), fam_names)
    plt.yticks(range(n), fam_names)
    plt.title(f"Linear CKA - {mod}")
    plt.colorbar(fraction=0.046, pad=0.04)
    outp1 = OUT_RS / f"cka_{mod}.png"
    plt.tight_layout(); plt.savefig(outp1, dpi=200); plt.close()

    plt.figure(figsize=(5,4))
    plt.imshow(rsa_mat, vmin=-1, vmax=1)
    plt.xticks(range(n), fam_names)
    plt.yticks(range(n), fam_names)
    plt.title(f"RSA (Spearman of distance vectors) - {mod}")
    plt.colorbar(fraction=0.046, pad=0.04)
    outp2 = OUT_RS / f"rsa_{mod}.png"
    plt.tight_layout(); plt.savefig(outp2, dpi=200); plt.close()

    rows_summary.append({
        "modality": mod,
        "families": ",".join(fam_names),
        "cka_mean_offdiag": float((cka_mat.sum() - np.trace(cka_mat)) / (n*n - n)) if n > 1 else 1.0,
        "rsa_mean_offdiag": float((rsa_mat.sum() - np.trace(rsa_mat)) / (n*n - n)) if n > 1 else 1.0,
        "n_samples": int(len(list(fam_mats.values())[0]))
    })

rep_sum = pd.DataFrame(rows_summary)
rep_path = TAB_DIR / "representation_similarity_summary.csv"
rep_sum.to_csv(rep_path, index=False)
print("Saved rep similarity figures under:", OUT_RS)
print("Saved:", rep_path)
display(rep_sum)


Using classical feature dataframe: features_df
Saved rep similarity figures under: /content/drive/MyDrive/EMBC_project/outputs/figures/rep_similarity
Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/representation_similarity_summary.csv


Unnamed: 0,modality,families,cka_mean_offdiag,rsa_mean_offdiag,n_samples
0,ADC,"resnet,classical",1e-06,0.689318,9


In [None]:
# =========================
# ADD-ON 4: Frequency-domain loss (radial power spectrum + high-frequency fraction)
#   - Computes radial power spectrum on mid-slice (masked, fixed space)
#   - Summarizes high-frequency energy fraction and spectral slope
#   - Saves plots and a CSV summary
# =========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

assert "preproc_df" in globals(), "preproc_df not found"

def radial_profile(power2):
    H, W = power2.shape
    cy, cx = H // 2, W // 2
    y, x = np.ogrid[:H, :W]
    r = np.sqrt((y - cy)**2 + (x - cx)**2)
    r_int = r.astype(np.int32)
    maxr = r_int.max()
    prof = np.zeros(maxr + 1, dtype=np.float32)
    cnt = np.zeros(maxr + 1, dtype=np.int32)
    for rr in range(maxr + 1):
        m = (r_int == rr)
        if m.any():
            prof[rr] = float(power2[m].mean())
            cnt[rr] = int(m.sum())
    return prof, cnt

def spectrum_metrics(img2, mask2, hf_start_frac=0.35):
    # Normalize within mask for stability
    x = img2.astype(np.float32)
    if mask2.sum() > 0:
        vals = x[mask2]
        lo, hi = np.percentile(vals, [1, 99])
    else:
        lo, hi = np.percentile(x, [1, 99])
    x = np.clip(x, lo, hi)
    x = (x - lo) / (hi - lo + 1e-6)

    # Apply mask
    x = x * mask2.astype(np.float32)

    # FFT power spectrum
    F = np.fft.fftshift(np.fft.fft2(x))
    P = (np.abs(F) ** 2).astype(np.float32)

    prof, _ = radial_profile(P)
    r = np.arange(len(prof), dtype=np.float32)
    # Avoid r=0 for slope
    rr = r[1:]
    pp = prof[1:] + 1e-8

    # Spectral slope in log-log (rough)
    log_r = np.log(rr + 1e-6)
    log_p = np.log(pp)
    slope = np.polyfit(log_r, log_p, 1)[0]

    # High-frequency fraction
    hf_start = int(len(prof) * hf_start_frac)
    hf = float(prof[hf_start:].sum() / (prof.sum() + 1e-8))

    return prof, slope, hf

rows = []
OUT_SPEC = FIG_DIR / "frequency_domain"
OUT_SPEC.mkdir(parents=True, exist_ok=True)

MAX_SUBJECTS = 6 if FAST_DEBUG else 11

for mod in sorted(preproc_df["modality"].unique().tolist()):
    dfm = preproc_df[preproc_df["modality"] == mod].copy().head(MAX_SUBJECTS)

    prof_acc = {"64mT": [], "3T_lowres": [], "3T_highres": []}

    for _, r in tqdm(dfm.iterrows(), total=len(dfm), desc=f"Spectrum {mod}"):
        sub = r["subject"]
        pm  = r["path_mask_fixed"]
        mimg = sitk.ReadImage(pm)
        mask = (sitk.GetArrayFromImage(mimg).astype(np.uint8) > 0)
        z = mask.shape[0] // 2
        m2 = mask[z]
        if m2.sum() < 200:
            continue

        paths = {
            "64mT": r["path_64mT_registered_norm"],
            "3T_lowres": r["path_fixed_3T_lowres_norm"],
            "3T_highres": r["path_3T_highres_resampled_norm"],
        }

        for cond, p in paths.items():
            img = sitk_read(p)
            arr = sitk.GetArrayFromImage(img).astype(np.float32)
            img2 = arr[z]
            prof, slope, hf = spectrum_metrics(img2, m2)

            rows.append({
                "subject": sub,
                "modality": mod,
                "condition": cond,
                "spectral_slope": float(slope),
                "highfreq_fraction": float(hf),
            })
            prof_acc[cond].append(prof)

    # Plot mean radial spectra for this modality
    if all(len(prof_acc[c]) > 0 for c in prof_acc.keys()):
        # pad to same length
        L = min(min(len(p) for p in prof_acc["64mT"]),
                min(len(p) for p in prof_acc["3T_lowres"]),
                min(len(p) for p in prof_acc["3T_highres"]))
        plt.figure(figsize=(7,4))
        for cond in ["64mT","3T_lowres","3T_highres"]:
            P = np.stack([p[:L] for p in prof_acc[cond]], axis=0)
            mean = P.mean(axis=0)
            plt.plot(mean, label=cond)
        plt.title(f"Mean radial power spectrum (mid-slice) - {mod}")
        plt.xlabel("Radius (frequency bin)")
        plt.ylabel("Power (a.u.)")
        plt.legend()
        outp = OUT_SPEC / f"radial_spectrum_{mod}.png"
        plt.tight_layout(); plt.savefig(outp, dpi=200); plt.close()
        print("Saved:", outp)

spec_df = pd.DataFrame(rows)
spec_path = TAB_DIR / "frequency_domain_metrics.csv"
spec_df.to_csv(spec_path, index=False)
print("Saved:", spec_path)
display(spec_df.head(10))


Spectrum ADC:   0%|          | 0/10 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/EMBC_project/outputs/figures/frequency_domain/radial_spectrum_ADC.png
Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/frequency_domain_metrics.csv


Unnamed: 0,subject,modality,condition,spectral_slope,highfreq_fraction
0,sub-0011,ADC,64mT,-2.894543,0.000125
1,sub-0011,ADC,3T_lowres,-2.370341,0.000544
2,sub-0011,ADC,3T_highres,-2.619454,0.000402
3,sub-0015,ADC,64mT,-2.914201,8.1e-05
4,sub-0015,ADC,3T_lowres,-2.196025,0.000667
5,sub-0015,ADC,3T_highres,-2.506091,0.00056
6,sub-0025,ADC,64mT,-3.139255,0.0001
7,sub-0025,ADC,3T_lowres,-2.241898,0.000657
8,sub-0025,ADC,3T_highres,-2.432579,0.000554
9,sub-0027,ADC,64mT,-2.913081,8e-05


In [None]:
# =========================
# ADD-ON 5: 64mT ADC run-1 vs run-2 reliability (if both runs exist)
#   - Finds ADC run-1 and run-2 for 64mT in manifest
#   - Rigid-registers each run to 3T_lowres ADC fixed image (if needed)
#   - Computes ResNet embeddings and within-64mT distance vs cross-field drift
#   - Saves CSV + a plot
# =========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

assert "manifest" in globals(), "manifest not found (need it to locate ADC runs)"
assert "preproc_df" in globals(), "preproc_df not found"
assert "volume_to_resnet_embedding" in globals(), "volume_to_resnet_embedding not found"

def _get_adc_runs_64mt(subject):
    df = manifest[(manifest["field"]=="64mT") & (manifest["subject"]==subject) & (manifest["modality"]=="ADC")].copy()
    if "run" not in df.columns:
        return {}
    out = {}
    for _, r in df.iterrows():
        rn = str(r.get("run", "")).strip()
        if rn and rn != "None" and rn != "nan":
            out[rn] = r["nifti_path"]
    return out

# Minimal rigid registration if you do not already have one
import SimpleITK as sitk

def rigid_register_to_fixed(fixed_path, moving_path):
    fixed = sitk.ReadImage(str(fixed_path), sitk.sitkFloat32)
    moving = sitk.ReadImage(str(moving_path), sitk.sitkFloat32)

    init = sitk.CenteredTransformInitializer(
        fixed, moving, sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY
    )

    reg = sitk.ImageRegistrationMethod()
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    reg.SetMetricSamplingPercentage(0.2)
    reg.SetInterpolator(sitk.sitkLinear)
    reg.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=80,
                                      convergenceMinimumValue=1e-6, convergenceWindowSize=10)
    reg.SetOptimizerScalesFromPhysicalShift()
    reg.SetInitialTransform(init, inPlace=False)

    tx = reg.Execute(fixed, moving)
    res = sitk.Resample(moving, fixed, tx, sitk.sitkLinear, 0.0, sitk.sitkFloat32)
    return res

def zscore_in_mask(img_sitk, mask_sitk):
    arr = sitk.GetArrayFromImage(img_sitk).astype(np.float32)
    m = (sitk.GetArrayFromImage(mask_sitk).astype(np.uint8) > 0)
    vals = arr[m]
    mu = float(vals.mean()) if vals.size else 0.0
    sd = float(vals.std()) if vals.size else 1.0
    if sd < 1e-6:
        sd = 1.0
    arr = (arr - mu) / sd
    out = sitk.GetImageFromArray(arr)
    out.CopyInformation(img_sitk)
    return out

# Prepare per-subject fixed ADC row (3T lowres fixed space) from preproc_df
adc_fixed_rows = preproc_df[preproc_df["modality"]=="ADC"].copy()
if len(adc_fixed_rows) == 0:
    print("No ADC rows in preproc_df. Skipping run reliability.")
else:
    OUT_REL = OUT_DIR / "adc_run_reliability_cache"
    OUT_REL.mkdir(parents=True, exist_ok=True)

    rel_rows = []
    for _, r in tqdm(adc_fixed_rows.iterrows(), total=len(adc_fixed_rows), desc="ADC run reliability"):
        sub = r["subject"]
        runs = _get_adc_runs_64mt(sub)
        if not ("1" in runs and "2" in runs):
            continue

        fixed_path = r["path_fixed_3T_lowres_norm"]
        mask_path  = r["path_mask_fixed"]

        # Register run1 and run2 to fixed, normalize in mask, cache to disk
        fixed_img = sitk.ReadImage(str(fixed_path), sitk.sitkFloat32)
        mask_img  = sitk.ReadImage(str(mask_path), sitk.sitkUInt8)

        cached = {}
        for rn in ["1","2"]:
            outp = OUT_REL / f"{sub}_ADC_run-{rn}_to_3Tlow_norm.nii.gz"
            if not outp.exists():
                reg_img = rigid_register_to_fixed(fixed_path, runs[rn])
                norm_img = zscore_in_mask(reg_img, mask_img)
                sitk.WriteImage(norm_img, str(outp))
            cached[rn] = str(outp)

        # Embeddings
        e1 = volume_to_resnet_embedding(cached["1"], mask_path, n_slices=16 if FAST_DEBUG else 32)
        e2 = volume_to_resnet_embedding(cached["2"], mask_path, n_slices=16 if FAST_DEBUG else 32)
        e3 = volume_to_resnet_embedding(fixed_path, mask_path, n_slices=16 if FAST_DEBUG else 32)  # 3T_low

        def cos_dist(a,b):
            if np.any(np.isnan(a)) or np.any(np.isnan(b)):
                return np.nan
            na = np.linalg.norm(a) + 1e-8
            nb = np.linalg.norm(b) + 1e-8
            return float(1.0 - (a @ b) / (na * nb))

        rel_rows.append({
            "subject": sub,
            "modality": "ADC",
            "cosdist_run1_vs_run2_64mT": cos_dist(e1, e2),
            "cosdist_run1_64mT_vs_3Tlow": cos_dist(e1, e3),
            "cosdist_run2_64mT_vs_3Tlow": cos_dist(e2, e3),
        })

    rel_df = pd.DataFrame(rel_rows)
    rel_path = TAB_DIR / "adc_run_reliability.csv"
    rel_df.to_csv(rel_path, index=False)
    print("Saved:", rel_path)
    display(rel_df.head(20))

    if len(rel_df) > 0:
        # Plot comparison
        a = rel_df["cosdist_run1_vs_run2_64mT"].values
        b = rel_df[["cosdist_run1_64mT_vs_3Tlow","cosdist_run2_64mT_vs_3Tlow"]].mean(axis=1).values

        plt.figure(figsize=(6,4))
        plt.boxplot([a[~np.isnan(a)], b[~np.isnan(b)]], tick_labels=["64mT run1 vs run2", "64mT vs 3T_low"], showmeans=True)
        plt.ylabel("Cosine distance (lower is more similar)")
        plt.title("ADC reliability vs cross-field drift")
        outp = FIG_DIR / "adc_run_reliability_box.png"
        outp.parent.mkdir(parents=True, exist_ok=True)
        plt.tight_layout(); plt.savefig(outp, dpi=200); plt.close()
        print("Saved:", outp)
    else:
        print("No subjects with both ADC run-1 and run-2 found. Skipping plot.")


ADC run reliability:   0%|          | 0/10 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/adc_run_reliability.csv


No subjects with both ADC run-1 and run-2 found. Skipping plot.


In [None]:
# =========================
# ADD-ON 6: Bootstrap confidence intervals for key metrics
#   - Bootstraps at the subject level within each modality
#   - Produces CI for mean distances and FDI
#   - Saves CSV
# =========================
import numpy as np
import pandas as pd

assert "dist_df" in globals(), "dist_df not found"

def bootstrap_ci_mean(values, n_boot=2000, seed=0):
    rng = np.random.default_rng(seed)
    v = np.array(values, dtype=np.float64)
    v = v[np.isfinite(v)]
    if len(v) < 2:
        return (np.nan, np.nan, np.nan)
    boots = []
    for _ in range(n_boot):
        samp = rng.choice(v, size=len(v), replace=True)
        boots.append(float(np.mean(samp)))
    boots = np.array(boots)
    return (float(np.mean(v)), float(np.percentile(boots, 2.5)), float(np.percentile(boots, 97.5)))

boot_rows = []
for mod in sorted(dist_df["modality"].unique().tolist()):
    d = dist_df[dist_df["modality"] == mod].copy()
    if len(d) < 2:
        continue

    # Mean L2 distances
    mean_lf_low, lo_lf_low, hi_lf_low = bootstrap_ci_mean(d["l2_lf_vs_3Tlow"].values, seed=SEED)
    mean_low_high, lo_low_high, hi_low_high = bootstrap_ci_mean(d["l2_3Tlow_vs_3Thigh"].values, seed=SEED+1)
    mean_lf_high, lo_lf_high, hi_lf_high = bootstrap_ci_mean(d["l2_lf_vs_3Thigh"].values, seed=SEED+2)

    # FDI (L2)
    eps = 1e-8
    fdi = d["l2_lf_vs_3Tlow"].values / (d["l2_lf_vs_3Tlow"].values + d["l2_3Tlow_vs_3Thigh"].values + eps)
    mean_fdi, lo_fdi, hi_fdi = bootstrap_ci_mean(fdi, seed=SEED+3)

    boot_rows.append({
        "modality": mod,
        "n_pairs": int(d["subject"].nunique()),
        "mean_l2_lf_vs_low": mean_lf_low,
        "ci95_l2_lf_vs_low_lo": lo_lf_low,
        "ci95_l2_lf_vs_low_hi": hi_lf_low,
        "mean_l2_low_vs_high": mean_low_high,
        "ci95_l2_low_vs_high_lo": lo_low_high,
        "ci95_l2_low_vs_high_hi": hi_low_high,
        "mean_l2_lf_vs_high": mean_lf_high,
        "ci95_l2_lf_vs_high_lo": lo_lf_high,
        "ci95_l2_lf_vs_high_hi": hi_lf_high,
        "mean_FDI_L2": mean_fdi,
        "ci95_FDI_L2_lo": lo_fdi,
        "ci95_FDI_L2_hi": hi_fdi,
    })

boot_df = pd.DataFrame(boot_rows)
boot_path = TAB_DIR / "bootstrap_ci_summary.csv"
boot_df.to_csv(boot_path, index=False)
print("Saved:", boot_path)
display(boot_df)


Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/bootstrap_ci_summary.csv


Unnamed: 0,modality,n_pairs,mean_l2_lf_vs_low,ci95_l2_lf_vs_low_lo,ci95_l2_lf_vs_low_hi,mean_l2_low_vs_high,ci95_l2_low_vs_high_lo,ci95_l2_low_vs_high_hi,mean_l2_lf_vs_high,ci95_l2_lf_vs_high_lo,ci95_l2_lf_vs_high_hi,mean_FDI_L2,ci95_FDI_L2_lo,ci95_FDI_L2_hi
0,ADC,3,5.669641,5.456633,5.985075,2.457192,1.972973,2.946868,5.5758,5.405859,5.766547,0.698648,0.653883,0.752078


In [None]:
# =========================
# ADD-ON 7: Protocol guidance as an optimization problem (label-free)
#   - Ranks modalities by robustness and optionally by scan-time proxies if present
#   - Produces recommended "protocol variants" (top-k modalities)
#   - Saves CSV
# =========================
import numpy as np
import pandas as pd

assert "dist_df" in globals(), "dist_df not found"

# Robustness score: lower field-effect distance is better, lower combined effect is better
# Convert to a 0..1 score where higher is better
prot = dist_df.groupby("modality").agg(
    n=("subject","nunique"),
    l2_field=("l2_lf_vs_3Tlow","mean"),
    l2_res=("l2_3Tlow_vs_3Thigh","mean"),
    l2_combined=("l2_lf_vs_3Thigh","mean"),
    cos_field=("cos_lf_vs_3Tlow","mean"),
    cos_res=("cos_3Tlow_vs_3Thigh","mean"),
    cos_combined=("cos_lf_vs_3Thigh","mean"),
).reset_index()

# Optional scan-time proxy from sidecar JSON if available in manifest (best-effort)
# Many datasets do not provide scan time; if not found, we treat times as equal.
time_map = {}
if "json_path" in manifest.columns:
    import json, os
    # Try to parse RepetitionTime or TotalReadoutTime as a proxy (not true scan time)
    for mod in prot["modality"].tolist():
        # find any 3T lowres json for this mod
        dfj = manifest[(manifest["field"]=="3T") & (manifest["modality"]==mod) & (manifest["acq"]=="lowres")].copy()
        dfj = dfj[dfj["json_path"].notna()]
        if len(dfj) == 0:
            continue
        jp = dfj.iloc[0]["json_path"]
        try:
            with open(jp, "r") as f:
                meta = json.load(f)
            # try common keys
            t = None
            for k in ["ScanTime", "AcquisitionDuration", "TotalScanTime", "EstimatedScanTime"]:
                if k in meta:
                    t = float(meta[k]); break
            if t is None and "RepetitionTime" in meta:
                # not scan time, but gives a proxy for "longer sequence"
                t = float(meta["RepetitionTime"])
            if t is not None and np.isfinite(t):
                time_map[mod] = t
        except Exception:
            pass

prot["time_proxy"] = prot["modality"].map(time_map).astype(float)
if prot["time_proxy"].isna().all():
    prot["time_proxy"] = 1.0  # equal weights if unknown
    time_note = "No scan-time metadata found; using equal time_proxy=1.0 for all modalities."
else:
    prot["time_proxy"] = prot["time_proxy"].fillna(prot["time_proxy"].median())
    time_note = "Using time_proxy from available JSON fields (best-effort; may be TR-based proxy)."

# Normalize components
def norm01(x):
    x = np.array(x, dtype=np.float64)
    lo, hi = np.nanmin(x), np.nanmax(x)
    if not np.isfinite(lo) or not np.isfinite(hi) or abs(hi-lo) < 1e-12:
        return np.zeros_like(x)
    return (x - lo) / (hi - lo)

# Robustness: prefer low l2_field and low l2_combined
field_norm = norm01(prot["l2_field"].values)
comb_norm  = norm01(prot["l2_combined"].values)
time_norm  = norm01(prot["time_proxy"].values)

# Utility score weights (tune if desired)
w_field = 0.50
w_comb  = 0.35
w_time  = 0.15

prot["utility"] = (
    w_field * (1 - field_norm) +
    w_comb  * (1 - comb_norm)  -
    w_time  * (time_norm)
)

prot = prot.sort_values("utility", ascending=False).reset_index(drop=True)
prot_path = TAB_DIR / "protocol_optimization_ranking.csv"
prot.to_csv(prot_path, index=False)

print(time_note)
print("Saved:", prot_path)
display(prot)

# Recommend protocol variants: top-k sequences (k=1..min(4, num modalities))
mods_ranked = prot["modality"].tolist()
variants = []
KMAX = min(4, len(mods_ranked))
for k in range(1, KMAX+1):
    chosen = mods_ranked[:k]
    variants.append({
        "variant": f"Top-{k}",
        "modalities": ",".join(chosen),
        "mean_l2_field": float(prot.set_index("modality").loc[chosen, "l2_field"].mean()),
        "mean_l2_combined": float(prot.set_index("modality").loc[chosen, "l2_combined"].mean()),
        "utility_sum": float(prot.set_index("modality").loc[chosen, "utility"].sum()),
    })

variants_df = pd.DataFrame(variants)
variants_path = TAB_DIR / "protocol_variant_recommendations.csv"
variants_df.to_csv(variants_path, index=False)
print("Saved:", variants_path)
display(variants_df)


Using time_proxy from available JSON fields (best-effort; may be TR-based proxy).
Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/protocol_optimization_ranking.csv


Unnamed: 0,modality,n,l2_field,l2_res,l2_combined,cos_field,cos_res,cos_combined,time_proxy,utility
0,ADC,3,5.669641,2.457192,5.5758,0.878605,0.972426,0.882825,33.924007,0.85


Saved: /content/drive/MyDrive/EMBC_project/outputs/tables/protocol_variant_recommendations.csv


Unnamed: 0,variant,modalities,mean_l2_field,mean_l2_combined,utility_sum
0,Top-1,ADC,5.669641,5.5758,0.85


In [None]:
# =========================
# FINAL RESULTS AGGREGATOR (single Colab cell) — WORKING/ROBUST
# =========================
# Purpose:
#   - Load all saved tables from your pipeline (QC, distances, features, add-ons)
#   - Produce publication-style summary tables + key figures
#   - Auto-generate a draft "Results" section (plain text + numbers)
#
# Assumptions:
#   - You already defined: OUT_DIR, TAB_DIR, FIG_DIR, SEED
#   - Your pipeline saved CSVs into TAB_DIR (some may be missing/empty; this cell handles that)
# =========================

import os, json, math
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pandas.errors import EmptyDataError
from scipy.stats import wilcoxon, ttest_rel

# -------------------------
# Helpers
# -------------------------
def _p(p):
    if p is None or (isinstance(p, float) and (not np.isfinite(p))):
        return "NA"
    if p < 1e-4: return "<1e-4"
    if p < 1e-3: return "<1e-3"
    return f"{p:.4f}"

def _fmt(x, nd=3):
    if x is None or (isinstance(x, float) and (not np.isfinite(x))):
        return "NA"
    return f"{float(x):.{nd}f}"

def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)
    return p

def save_fig(path: Path):
    ensure_dir(path.parent)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()

def safe_read_csv(path: Path, required=False):
    """
    Robust CSV reader:
      - returns None if missing OR empty OR unreadable
      - if required=True, raises on missing/empty/unreadable
    """
    path = Path(path)
    if not path.exists():
        if required:
            raise FileNotFoundError(f"Missing required file: {path}")
        return None

    # empty file (0 bytes) -> treat as missing
    try:
        if path.stat().st_size == 0:
            if required:
                raise ValueError(f"Required file is empty (0 bytes): {path}")
            return None
    except OSError:
        pass

    try:
        df = pd.read_csv(path)
        if df.shape[1] == 0:
            if required:
                raise ValueError(f"Required file has no columns: {path}")
            return None
        return df
    except EmptyDataError:
        if required:
            raise
        return None
    except Exception as e:
        if required:
            raise
        print(f"[read error] {path.name}: {type(e).__name__}: {e} -> skipping")
        return None

def paired_test(df, colA, colB):
    a = df[colA].astype(float).values
    b = df[colB].astype(float).values
    m = np.isfinite(a) & np.isfinite(b)
    a, b = a[m], b[m]
    out = {"n": int(len(a))}
    if len(a) < 3:
        out.update({"wilcoxon_p": np.nan, "ttest_p": np.nan, "meanA": np.nan, "meanB": np.nan, "delta": np.nan})
        return out
    try:
        out["wilcoxon_p"] = float(wilcoxon(a, b).pvalue)
    except Exception:
        out["wilcoxon_p"] = np.nan
    try:
        out["ttest_p"] = float(ttest_rel(a, b).pvalue)
    except Exception:
        out["ttest_p"] = np.nan
    out["meanA"] = float(np.mean(a))
    out["meanB"] = float(np.mean(b))
    out["delta"] = float(np.mean(a) - np.mean(b))
    return out

def write_text(path: Path, text: str):
    ensure_dir(path.parent)
    path.write_text(text, encoding="utf-8")
    return path

def require_cols(df: pd.DataFrame, cols, name: str):
    missing = [c for c in cols if c not in df.columns]
    if missing:
        raise ValueError(f"{name} missing columns: {missing}\nFound columns: {list(df.columns)}")

# -------------------------
# Locate expected outputs
# -------------------------
TAB_DIR = Path(TAB_DIR)
FIG_DIR = Path(FIG_DIR)
OUT_DIR = Path(OUT_DIR)

ensure_dir(FIG_DIR / "results_agg")
ensure_dir(TAB_DIR / "results_agg")

PATHS = {
    "qc": TAB_DIR / "qc_metrics.csv",
    "dist": TAB_DIR / "paired_embedding_distances.csv",
    "paired_stats": TAB_DIR / "paired_stats_summary.csv",
    "protocol_guidance": TAB_DIR / "protocol_guidance_summary.csv",
    "classic_features": TAB_DIR / "classical_features.csv",
    "freq": TAB_DIR / "frequency_domain_metrics.csv",
    "fdi": TAB_DIR / "field_dominance_index.csv",
    "loc": TAB_DIR / "localized_patch_drift_summary.csv",
    "boot": TAB_DIR / "bootstrap_ci_summary.csv",
    "adc_rel": TAB_DIR / "adc_run_reliability.csv",
    "rep_sim": TAB_DIR / "representation_similarity_summary.csv",
    "protocol_opt": TAB_DIR / "protocol_optimization_ranking.csv",
    "protocol_vars": TAB_DIR / "protocol_variant_recommendations.csv",
}

tables = {k: safe_read_csv(p, required=False) for k, p in PATHS.items()}

print("Found tables:")
for k, df in tables.items():
    status = "OK" if (df is not None and df.shape[1] > 0) else "MISSING/EMPTY"
    shape = f"{df.shape[0]}x{df.shape[1]}" if df is not None else "-"
    print(f"  - {k:14s}: {status:12s} {shape:>8s}  ({PATHS[k].name})")

# -------------------------
# Required: embedding distances
# -------------------------
dist_df = tables["dist"]
if dist_df is None or len(dist_df) == 0:
    raise RuntimeError(f"Missing or empty: {PATHS['dist']}. Run Cell 11 (distances) first.")

require_cols(
    dist_df,
    ["subject","modality",
     "cos_lf_vs_3Tlow","l2_lf_vs_3Tlow",
     "cos_3Tlow_vs_3Thigh","l2_3Tlow_vs_3Thigh",
     "cos_lf_vs_3Thigh","l2_lf_vs_3Thigh"],
    "paired_embedding_distances.csv"
)

# -------------------------
# 1) Embedding distances summary by modality
# -------------------------
dist_summary = dist_df.groupby("modality").agg(
    n=("subject","nunique"),
    mean_cos_lf_vs_low=("cos_lf_vs_3Tlow","mean"),
    mean_cos_low_vs_high=("cos_3Tlow_vs_3Thigh","mean"),
    mean_cos_lf_vs_high=("cos_lf_vs_3Thigh","mean"),
    mean_l2_lf_vs_low=("l2_lf_vs_3Tlow","mean"),
    mean_l2_low_vs_high=("l2_3Tlow_vs_3Thigh","mean"),
    mean_l2_lf_vs_high=("l2_lf_vs_3Thigh","mean"),
    std_l2_lf_vs_low=("l2_lf_vs_3Tlow","std"),
    std_l2_low_vs_high=("l2_3Tlow_vs_3Thigh","std"),
).reset_index()

dist_summary_path = TAB_DIR / "results_agg" / "embedding_distance_summary_by_modality.csv"
dist_summary.to_csv(dist_summary_path, index=False)

# Plot: per-modality bars for field vs resolution (L2)
plt.figure(figsize=(9,4))
mods = dist_summary["modality"].values
x = np.arange(len(mods))
plt.bar(x - 0.15, dist_summary["mean_l2_lf_vs_low"].values, width=0.3, label="64mT vs 3T_low (field)")
plt.bar(x + 0.15, dist_summary["mean_l2_low_vs_high"].values, width=0.3, label="3T_low vs 3T_high (resolution)")
plt.xticks(x, mods)
plt.ylabel("Mean L2 distance (lower = more similar)")
plt.title("Embedding distances: field-strength vs resolution effect (ResNet50)")
plt.legend()
fig1 = FIG_DIR / "results_agg" / "embedding_field_vs_resolution_by_modality.png"
save_fig(fig1)

# Paired stats per modality
stat_rows = []
for mod in sorted(dist_df["modality"].unique().tolist()):
    d = dist_df[dist_df["modality"]==mod].copy()
    s = paired_test(d, "l2_lf_vs_3Tlow", "l2_3Tlow_vs_3Thigh")
    stat_rows.append({
        "modality": mod,
        "n_pairs": s["n"],
        "mean_l2_field": s["meanA"],
        "mean_l2_resolution": s["meanB"],
        "delta_field_minus_resolution": s.get("delta", np.nan),
        "wilcoxon_p": s.get("wilcoxon_p", np.nan),
        "ttest_p": s.get("ttest_p", np.nan)
    })
field_vs_res_stats = pd.DataFrame(stat_rows)
field_vs_res_path = TAB_DIR / "results_agg" / "paired_field_vs_resolution_stats_by_modality.csv"
field_vs_res_stats.to_csv(field_vs_res_path, index=False)

# Overall pooled field vs resolution stats
pooled = paired_test(dist_df, "l2_lf_vs_3Tlow", "l2_3Tlow_vs_3Thigh")

# -------------------------
# 2) QC summary (optional)
# -------------------------
qc_df = tables["qc"]
qc_summary = None
qc_fig = None
if qc_df is not None and len(qc_df) > 0:
    req = ["subject","modality","ssim_64mT_vs_3Tlow","corr_64mT_vs_3Tlow","ssim_3Thigh_vs_3Tlow","corr_3Thigh_vs_3Tlow"]
    missing = [c for c in req if c not in qc_df.columns]
    if missing:
        print(f"[QC] Missing columns {missing} -> skipping QC summary/plot.")
    else:
        qc_summary = qc_df.groupby("modality").agg(
            n=("subject","nunique"),
            mean_ssim_lf_vs_low=("ssim_64mT_vs_3Tlow","mean"),
            mean_corr_lf_vs_low=("corr_64mT_vs_3Tlow","mean"),
            mean_ssim_high_vs_low=("ssim_3Thigh_vs_3Tlow","mean"),
            mean_corr_high_vs_low=("corr_3Thigh_vs_3Tlow","mean"),
        ).reset_index()

        qc_summary_path = TAB_DIR / "results_agg" / "qc_summary_by_modality.csv"
        qc_summary.to_csv(qc_summary_path, index=False)

        plt.figure(figsize=(9,4))
        mods = qc_summary["modality"].values
        x = np.arange(len(mods))
        plt.bar(x - 0.15, qc_summary["mean_ssim_lf_vs_low"].values, width=0.3, label="SSIM: 64mT vs 3T_low")
        plt.bar(x + 0.15, qc_summary["mean_ssim_high_vs_low"].values, width=0.3, label="SSIM: 3T_high vs 3T_low")
        plt.xticks(x, mods)
        plt.ylabel("Mean SSIM (higher = more similar)")
        plt.title("QC similarity: field vs resolution (SSIM)")
        plt.legend()
        qc_fig = FIG_DIR / "results_agg" / "qc_ssim_by_modality.png"
        save_fig(qc_fig)

# -------------------------
# 3) Field Dominance Index (optional)
# -------------------------
fdi_df = tables["fdi"]
fdi_fig = None
if fdi_df is not None and len(fdi_df) > 0 and ("mean_FDI_L2" in fdi_df.columns) and ("modality" in fdi_df.columns):
    plt.figure(figsize=(8,4))
    x = np.arange(len(fdi_df))
    plt.bar(x, fdi_df["mean_FDI_L2"].values)
    plt.xticks(x, fdi_df["modality"].values)
    plt.ylim(0, 1)
    plt.ylabel("FDI (L2-based)")
    plt.title("Field Dominance Index by modality (higher = more field-driven degradation)")
    fdi_fig = FIG_DIR / "results_agg" / "field_dominance_index.png"
    save_fig(fdi_fig)
else:
    fdi_df = None  # normalize missing

# -------------------------
# 4) Classical features summary (optional)
# -------------------------
classic_df = tables["classic_features"]
classic_summary = None
classic_fig = None
if classic_df is not None and len(classic_df) > 0:
    if not all(c in classic_df.columns for c in ["subject","modality","condition"]):
        print("[Classic] Missing subject/modality/condition columns -> skipping classical features summary.")
    else:
        feat_cols = [c for c in classic_df.columns if c.startswith("glcm_")] + ["skew","kurtosis","p10","p50","p90"]
        feat_cols = [c for c in feat_cols if c in classic_df.columns]

        if len(feat_cols) > 0:
            classic_summary = classic_df.groupby(["modality","condition"]).agg(
                n=("subject","nunique"),
                **{f"mean_{c}": (c,"mean") for c in feat_cols}
            ).reset_index()

            classic_summary_path = TAB_DIR / "results_agg" / "classical_feature_summary_by_modality_condition.csv"
            classic_summary.to_csv(classic_summary_path, index=False)

            if "glcm_contrast" in classic_df.columns:
                tmp = classic_df.groupby(["modality","condition"])["glcm_contrast"].mean().reset_index()
                piv = tmp.pivot(index="modality", columns="condition", values="glcm_contrast")
                plt.figure(figsize=(9,4))
                mods = piv.index.values
                x = np.arange(len(mods))
                conds = [c for c in ["64mT","3T_lowres","3T_highres"] if c in piv.columns]
                if len(conds) > 0:
                    for i, cond in enumerate(conds):
                        plt.bar(x + (i-(len(conds)-1)/2)*0.25, piv[cond].values, width=0.25, label=cond)
                    plt.xticks(x, mods)
                    plt.ylabel("Mean GLCM contrast")
                    plt.title("Texture: GLCM contrast by condition (higher = more texture variation)")
                    plt.legend()
                    classic_fig = FIG_DIR / "results_agg" / "glcm_contrast_by_condition.png"
                    save_fig(classic_fig)
        else:
            classic_df = None
else:
    classic_df = None

# -------------------------
# 5) Frequency-domain metrics (optional)
# -------------------------
freq_df = tables["freq"]
freq_summary = None
freq_fig = None
if freq_df is not None and len(freq_df) > 0:
    if not all(c in freq_df.columns for c in ["subject","modality","condition","spectral_slope","highfreq_fraction"]):
        print("[Freq] Missing required columns -> skipping frequency summary/plot.")
    else:
        freq_summary = freq_df.groupby(["modality","condition"]).agg(
            n=("subject","nunique"),
            mean_slope=("spectral_slope","mean"),
            mean_hf=("highfreq_fraction","mean"),
        ).reset_index()

        freq_summary_path = TAB_DIR / "results_agg" / "frequency_summary_by_modality_condition.csv"
        freq_summary.to_csv(freq_summary_path, index=False)

        piv = freq_summary.pivot(index="modality", columns="condition", values="mean_hf")
        plt.figure(figsize=(9,4))
        mods = piv.index.values
        x = np.arange(len(mods))
        conds = [c for c in ["64mT","3T_lowres","3T_highres"] if c in piv.columns]
        if len(conds) > 0:
            for i, cond in enumerate(conds):
                plt.bar(x + (i-(len(conds)-1)/2)*0.25, piv[cond].values, width=0.25, label=cond)
            plt.xticks(x, mods)
            plt.ylabel("Mean high-frequency energy fraction")
            plt.title("Frequency-domain: high-frequency content by condition")
            plt.legend()
            freq_fig = FIG_DIR / "results_agg" / "highfreq_fraction_by_condition.png"
            save_fig(freq_fig)
else:
    freq_df = None

# -------------------------
# 6) Localized drift summary (optional)
# -------------------------
loc_df = tables["loc"]
loc_summary = None
if loc_df is not None and len(loc_df) > 0:
    if not all(c in loc_df.columns for c in ["subject","modality","mean_patch_drift","p90_patch_drift"]):
        print("[Loc] Missing required columns -> skipping localized drift summary.")
    else:
        loc_summary = loc_df.groupby("modality").agg(
            n=("subject","nunique"),
            mean_patch_drift=("mean_patch_drift","mean"),
            p90_patch_drift=("p90_patch_drift","mean"),
        ).reset_index()

        loc_summary_path = TAB_DIR / "results_agg" / "localized_drift_summary_by_modality.csv"
        loc_summary.to_csv(loc_summary_path, index=False)
else:
    loc_df = None

# -------------------------
# 7) ADC run reliability (optional)
# -------------------------
adc_rel_df = tables["adc_rel"]
adc_rel_summary = None
adc_rel_fig = None
if adc_rel_df is not None and len(adc_rel_df) > 0:
    needed = ["subject","cosdist_run1_vs_run2_64mT","cosdist_run1_64mT_vs_3Tlow"]
    if not all(c in adc_rel_df.columns for c in needed):
        print("[ADC] Missing required columns -> skipping ADC reliability summary/plot.")
    else:
        adc_rel_summary = adc_rel_df.agg(
            n=("subject","nunique"),
            mean_run_run=("cosdist_run1_vs_run2_64mT","mean"),
            mean_cross=("cosdist_run1_64mT_vs_3Tlow","mean"),
        ).to_frame().T

        adc_rel_summary_path = TAB_DIR / "results_agg" / "adc_run_reliability_summary.csv"
        adc_rel_summary.to_csv(adc_rel_summary_path, index=False)

        # optional: if run2 vs 3Tlow exists, use it; else compare run1 vs 3Tlow only
        cols_cross = [c for c in ["cosdist_run1_64mT_vs_3Tlow","cosdist_run2_64mT_vs_3Tlow"] if c in adc_rel_df.columns]
        a = adc_rel_df["cosdist_run1_vs_run2_64mT"].values
        b = adc_rel_df[cols_cross].mean(axis=1).values if len(cols_cross) > 0 else adc_rel_df["cosdist_run1_64mT_vs_3Tlow"].values

        plt.figure(figsize=(6,4))
        plt.boxplot([a[np.isfinite(a)], b[np.isfinite(b)]],
                    tick_labels=["64mT run1 vs run2", "64mT vs 3T_low"],
                    showmeans=True)
        plt.ylabel("Cosine distance (lower = more similar)")
        plt.title("ADC: within-64mT variability vs cross-field drift")
        adc_rel_fig = FIG_DIR / "results_agg" / "adc_run_reliability.png"
        save_fig(adc_rel_fig)
else:
    adc_rel_df = None

# -------------------------
# 8) Bootstrap CI (optional copy)
# -------------------------
boot_df = tables["boot"]
if boot_df is not None and len(boot_df) > 0:
    boot_out = TAB_DIR / "results_agg" / "bootstrap_ci_summary.csv"
    boot_df.to_csv(boot_out, index=False)
else:
    boot_df = None

# -------------------------
# 9) Protocol optimization (optional copy)
# -------------------------
prot_opt = tables["protocol_opt"]
prot_vars = tables["protocol_vars"]
if prot_opt is not None and len(prot_opt) > 0:
    prot_opt_out = TAB_DIR / "results_agg" / "protocol_optimization_ranking.csv"
    prot_opt.to_csv(prot_opt_out, index=False)
else:
    prot_opt = None

if prot_vars is not None and len(prot_vars) > 0:
    prot_vars_out = TAB_DIR / "results_agg" / "protocol_variant_recommendations.csv"
    prot_vars.to_csv(prot_vars_out, index=False)
else:
    prot_vars = None

# -------------------------
# Build draft Results text
# -------------------------
rank_df = dist_summary.sort_values("mean_l2_lf_vs_low", ascending=True).reset_index(drop=True)
best_mod = rank_df.iloc[0]["modality"]
worst_mod = rank_df.iloc[-1]["modality"]

lines = []
lines.append("# Draft Results (auto-generated from pipeline outputs)\n")

lines.append("## 1. Global representation-level differences (ResNet50 embeddings)\n")
lines.append(
    "We compared learned embeddings across 64mT, 3T_lowres (resolution-matched), and 3T_highres. "
    "This separates field-strength effects (64mT vs 3T_lowres) from resolution effects (3T_lowres vs 3T_highres).\n"
)
lines.append(
    f"Pooled across modalities, mean L2 distance (field)={_fmt(pooled.get('meanA'))} vs "
    f"(resolution)={_fmt(pooled.get('meanB'))} with n={pooled.get('n',0)} paired samples "
    f"(Wilcoxon p={_p(pooled.get('wilcoxon_p'))}, paired t-test p={_p(pooled.get('ttest_p'))}).\n"
)
lines.append(
    f"Across modalities, **{best_mod}** was most robust (lowest mean field-effect distance), while "
    f"**{worst_mod}** showed the largest field-driven drift.\n"
)
lines.append(f"Figure: {fig1.name}\n")

lines.append("## 2. Registration / preprocessing quality control (QC)\n")
if qc_summary is None:
    lines.append("QC table was missing or lacked required columns, so QC summaries are not reported here.\n")
else:
    lines.append(
        "Across modalities, similarity between 3T_highres and 3T_lowres was higher than similarity between 64mT and 3T_lowres, "
        "supporting that observed differences are primarily driven by field strength rather than resolution.\n"
    )
    try:
        ex = qc_summary.sort_values("mean_ssim_lf_vs_low").iloc[0]
        lines.append(
            f"Example: lowest mean SSIM for 64mT vs 3T_lowres was in **{ex['modality']}** "
            f"(SSIM={_fmt(ex['mean_ssim_lf_vs_low'])}) vs SSIM={_fmt(ex['mean_ssim_high_vs_low'])} for 3T_highres vs 3T_lowres.\n"
        )
    except Exception:
        pass
    lines.append(f"Figure: {qc_fig.name if qc_fig else 'NA'}\n")

lines.append("## 3. Field Dominance Index (FDI)\n")
if fdi_df is None:
    lines.append("FDI table was missing; no FDI summaries reported.\n")
else:
    try:
        top = fdi_df.sort_values("mean_FDI_L2", ascending=False).iloc[0]
        bot = fdi_df.sort_values("mean_FDI_L2", ascending=True).iloc[0]
        lines.append(
            "FDI summarizes whether degradation is mainly due to field strength (higher) versus resolution (lower).\n"
        )
        lines.append(
            f"Highest mean FDI: **{top['modality']}** (FDI={_fmt(top['mean_FDI_L2'])}); "
            f"lowest mean FDI: **{bot['modality']}** (FDI={_fmt(bot['mean_FDI_L2'])}).\n"
        )
        lines.append(f"Figure: {fdi_fig.name if fdi_fig else 'NA'}\n")
    except Exception:
        lines.append("FDI table loaded but could not compute top/bottom summary.\n")

lines.append("## 4. Classical intensity and texture features\n")
if classic_summary is None:
    lines.append("Classical feature table was missing or unusable.\n")
else:
    lines.append("Classical features provide interpretable confirmation of degradation patterns observed in embeddings.\n")
    lines.append(f"Saved: classical_feature_summary_by_modality_condition.csv\n")
    lines.append(f"Figure: {classic_fig.name if classic_fig else 'NA'}\n")

lines.append("## 5. Frequency-domain analysis\n")
if freq_summary is None:
    lines.append("Frequency-domain metrics table was missing or unusable.\n")
else:
    lines.append("Frequency metrics quantify loss of fine detail via reduced high-frequency energy at 64mT.\n")
    lines.append(f"Saved: frequency_summary_by_modality_condition.csv\n")
    lines.append(f"Figure: {freq_fig.name if freq_fig else 'NA'}\n")

lines.append("## 6. Localized drift (spatially resolved loss)\n")
if loc_summary is None:
    lines.append("Localized drift summary table was missing or unusable.\n")
else:
    try:
        ex = loc_summary.sort_values("mean_patch_drift", ascending=False).iloc[0]
        lines.append(
            f"Patchwise drift maps show non-uniform spatial degradation; the highest mean patch drift was observed for "
            f"**{ex['modality']}** (mean={_fmt(ex['mean_patch_drift'])}).\n"
        )
    except Exception:
        lines.append("Localized drift table loaded but could not compute top modality.\n")
    lines.append("Localized drift figures saved under FIG_DIR/localized_drift/.\n")

lines.append("## 7. ADC run-to-run reliability\n")
if adc_rel_summary is None:
    lines.append("ADC reliability table missing or unusable.\n")
else:
    nsub = int(adc_rel_summary["n"].iloc[0])
    lines.append(
        f"In subjects with repeated low-field ADC (n={nsub}), mean within-64mT distance was "
        f"{_fmt(adc_rel_summary['mean_run_run'].iloc[0])}, compared to cross-field distance "
        f"{_fmt(adc_rel_summary['mean_cross'].iloc[0])}.\n"
    )
    lines.append(f"Figure: {adc_rel_fig.name if adc_rel_fig else 'NA'}\n")

lines.append("## 8. Uncertainty via bootstrap\n")
if boot_df is None:
    lines.append("Bootstrap CI table missing.\n")
else:
    lines.append("Bootstrap confidence intervals were computed to quantify uncertainty given limited paired sample size.\n")

lines.append("## 9. Protocol guidance / optimization\n")
if prot_opt is None:
    lines.append("Protocol optimization outputs missing.\n")
else:
    lines.append("Sequences were ranked by robustness to field-driven degradation to inform low-field protocol design.\n")
    try:
        topk = prot_opt.head(min(3, len(prot_opt)))
        for _, r in topk.iterrows():
            # tolerate varying column names
            mod = r.get("modality", r.get("sequence", "NA"))
            util = r.get("utility", np.nan)
            l2f = r.get("l2_field", np.nan)
            l2c = r.get("l2_combined", np.nan)
            lines.append(f"- {mod}: utility={_fmt(util)}, field L2={_fmt(l2f)}, combined L2={_fmt(l2c)}\n")
    except Exception:
        pass

results_md = "\n".join(lines)
results_path = TAB_DIR / "results_agg" / "DRAFT_RESULTS.md"
write_text(results_path, results_md)

# -------------------------
# Create a single bundle CSV for quick inspection
# -------------------------
bundle = dist_summary.copy()
bundle = bundle.merge(field_vs_res_stats[["modality","wilcoxon_p","ttest_p"]], on="modality", how="left")

if fdi_df is not None and ("mean_FDI_L2" in fdi_df.columns):
    cols = [c for c in ["modality","mean_FDI_L2","std_FDI_L2"] if c in fdi_df.columns]
    bundle = bundle.merge(fdi_df[cols], on="modality", how="left")

if qc_summary is not None:
    cols = [c for c in ["modality","mean_ssim_lf_vs_low","mean_ssim_high_vs_low","mean_corr_lf_vs_low","mean_corr_high_vs_low"] if c in qc_summary.columns]
    bundle = bundle.merge(qc_summary[cols], on="modality", how="left")

if loc_summary is not None:
    cols = [c for c in ["modality","mean_patch_drift","p90_patch_drift"] if c in loc_summary.columns]
    bundle = bundle.merge(loc_summary[cols], on="modality", how="left")

bundle_path = TAB_DIR / "results_agg" / "RESULTS_BUNDLE_by_modality.csv"
bundle.to_csv(bundle_path, index=False)

print("\n================ RESULTS AGGREGATION COMPLETE ================\n")
print("Key outputs saved:")
print(" - Draft Results text:", results_path)
print(" - Bundle table:", bundle_path)
print(" - Distance summary:", dist_summary_path)
print(" - Field-vs-resolution stats:", field_vs_res_path)
print("Figures saved:")
print(" -", fig1)
if qc_fig: print(" -", qc_fig)
if fdi_fig: print(" -", fdi_fig)
if classic_fig: print(" -", classic_fig)
if freq_fig: print(" -", freq_fig)
if adc_rel_fig: print(" -", adc_rel_fig)

print("\nOpen the draft results here:")
print(results_path)

display(bundle)


Found tables:
  - qc            : MISSING/EMPTY        -  (qc_metrics.csv)
  - dist          : OK                3x8  (paired_embedding_distances.csv)
  - paired_stats  : OK                2x6  (paired_stats_summary.csv)
  - protocol_guidance: OK                1x8  (protocol_guidance_summary.csv)
  - classic_features: MISSING/EMPTY        -  (classical_features.csv)
  - freq          : OK               30x5  (frequency_domain_metrics.csv)
  - fdi           : OK                1x6  (field_dominance_index.csv)
  - loc           : OK                8x4  (localized_patch_drift_summary.csv)
  - boot          : OK               1x14  (bootstrap_ci_summary.csv)
  - adc_rel       : MISSING/EMPTY        -  (adc_run_reliability.csv)
  - rep_sim       : OK                1x5  (representation_similarity_summary.csv)
  - protocol_opt  : OK               1x10  (protocol_optimization_ranking.csv)
  - protocol_vars : OK                1x5  (protocol_variant_recommendations.csv)


Key outputs saved:
 

Unnamed: 0,modality,n,mean_cos_lf_vs_low,mean_cos_low_vs_high,mean_cos_lf_vs_high,mean_l2_lf_vs_low,mean_l2_low_vs_high,mean_l2_lf_vs_high,std_l2_lf_vs_low,std_l2_low_vs_high,wilcoxon_p,ttest_p,mean_FDI_L2,std_FDI_L2,mean_patch_drift,p90_patch_drift
0,ADC,3,0.878605,0.972426,0.882825,5.669641,2.457192,5.5758,0.278714,0.486971,0.25,0.016279,0.698648,0.049668,0.188003,0.311258
