# In [None]:  (Jupyter cell marker left over from the notebook export)
# FYR-
"""
Volumetric tumor burden metrics from PET (SUV) + 3D binary mask(s).

Outputs per patient:
- MTV_mL: Molecular Tumor Volume (sum of lesion volumes) in mL
- TLA_mL: Total Lesion Activity = SUVmean * MTV (mL) [unitless SUV convention]
- TLF_% : Total Lesion Fraction = 100 * TLA / BodyVolume_BW (unitless %)
- LesionCount
- SUVmean_global, SUVmax_global (over all lesion voxels)
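- Dmax_mm: max centroid-to-centroid distance between lesions (mm); when there is
  only one lesion, the maximum of that lesion's Euclidean distance transform
  is reported instead, and 0 when there are no lesions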
- (Optional passthrough): Age

Also writes a per-lesion CSV with: lesion label, volume (mL), centroid (mm),
SUVmean, SUVmax, and lesion TLA (mL).

Assumptions:
- PET volume is in SUV units (already normalized, typically by body weight).
- Mask and PET are in the same space/grid.
- Body volume (mL) ≈ 1000 * body_mass_kg (density ~1 g/mL).
- All patients are male; no height information (LBW not computed).

You can keep the fixed voxel size override (VOXEL_SIZE_MM), or set it to None
to use the NIfTI header spacing instead.
"""

import os
import re
import csv
import glob
import math
import numpy as np
import nibabel as nib
import pandas as pd  # for reading anthropometrics from Excel

from typing import Optional, Dict, Tuple, List
from scipy.spatial import distance
from scipy.ndimage import distance_transform_edt
from skimage import util
from skimage.measure import label, regionprops

# ------------------------
# Configuration
# ------------------------
# Voxel spacing in mm used for every volume/distance computation.
# - 3-tuple: forces this spacing for all patients, ignoring the NIfTI header.
# - None:    the first three header zooms of the mask image are used instead.
VOXEL_SIZE_MM: Optional[Tuple[float, float, float]] = (4.07283, 4.07283, 4.07283)  # set to None to trust the header

# Input locations (interpreted relative to the current working directory).
SEG_FOLDER = "nifti_output_mask_anonymized"   # folder scanned for mask files (see MASK_PATTERNS)
PET_FOLDER = "nifti_pet_suv"                  # folder searched for the matching PET (see PET_GLOBS)
# NOTE(review): "wiith" looks like a typo; confirm it matches the file name on disk.
ANTHRO_XLSX = "clinical_wiith_tmtv_dmax_shared.xlsx"  # Excel with a PID column plus optional weight / age columns
ANTHRO_SHEET = 0                               # passed to pandas.read_excel as sheet_name (index or name)

# Output files, written by main() to the current working directory.
OUT_SUMMARY_CSV = "tumor_burden_summary.csv"
OUT_LESIONS_CSV = "tumor_burden_lesions.csv"

# Mask file names are matched against these patterns in order; capture group 1
# is the all-digit patient ID. Matching is case-insensitive and anchored to the
# whole file name, e.g. 1234_MASK.nii.gz, 1234_seg.nii, 1234.nii.gz.
MASK_PATTERNS = [
    re.compile(r"^(\d+)_MASK\.nii(\.gz)?$", re.IGNORECASE),
    re.compile(r"^(\d+)_SEG\.nii(\.gz)?$",  re.IGNORECASE),
    re.compile(r"^(\d+)\.nii(\.gz)?$",      re.IGNORECASE),  # fallback: file name is just the PID
]

# Glob templates tried in order inside PET_FOLDER; "{pid}" is substituted with
# the patient ID. Exact names come first; the trailing "*{pid}*" templates are
# loose fallbacks that match the PID anywhere in the file name.
PET_GLOBS = [
    "{pid}_PET.nii.gz", "{pid}_PET.nii", "{pid}_SUV.nii.gz", "{pid}_SUV.nii",
    "{pid}.nii.gz", "{pid}.nii", "*{pid}*PET*.nii.gz", "*{pid}*PET*.nii",
    "*{pid}*SUV*.nii.gz", "*{pid}*SUV*.nii"
]


# ------------------------
# Utilities
# ------------------------

def maybe_pid(fname: str) -> Optional[str]:
    """Return the patient ID captured from a mask file name.

    The name is tried against each entry of MASK_PATTERNS in order; the first
    pattern that matches supplies the ID (its first capture group). Returns
    None when no pattern matches.
    """
    attempts = (rx.match(fname) for rx in MASK_PATTERNS)
    first_hit = next((hit for hit in attempts if hit), None)
    return first_hit.group(1) if first_hit else None


def find_pet_for_pid(pid: str, folder: str,
                     patterns: Optional[List[str]] = None) -> Optional[str]:
    """Locate the PET file belonging to patient *pid* inside *folder*.

    Args:
        pid: Patient ID as it appears in file names.
        folder: Directory to search (not recursive).
        patterns: Glob templates containing ``{pid}``, tried in order.
            Defaults to the module-level PET_GLOBS.

    Returns:
        Path of the first matching regular file, or None when nothing matches.
        Within one template, candidates are considered in sorted order so the
        choice is reproducible across runs and platforms.

    A candidate is accepted only if the PID appears in its file name as a
    stand-alone number (not directly preceded or followed by another digit).
    Without this, the wildcard templates would let pid "12" pick up
    "312_PET.nii.gz", silently attaching another patient's PET.
    """
    templates = PET_GLOBS if patterns is None else patterns

    # PID must not be embedded in a longer run of digits.
    pid_token = re.compile(rf"(?<!\d){re.escape(pid)}(?!\d)")

    # Escape glob metacharacters so that only the template's own wildcards act
    # as wildcards (a folder or PID containing "[" would otherwise misbehave).
    safe_folder = glob.escape(folder)
    safe_pid = glob.escape(pid)

    for template in templates:
        pattern = os.path.join(safe_folder, template.format(pid=safe_pid))
        for path in sorted(glob.glob(pattern)):
            if not os.path.isfile(path):
                continue
            if pid_token.search(os.path.basename(path)) is None:
                continue  # PID is only a substring of a different number
            return path
    return None


def get_spacing_mm(img_nii: nib.Nifti1Image,
                   override: Optional[Tuple[float, float, float]] = VOXEL_SIZE_MM
                  ) -> Tuple[float, float, float]:
    """Return the voxel spacing (mm) to use for *img_nii*.

    When *override* is not None its values are returned (converted to float)
    and the image header is not consulted. Otherwise the first three entries
    of ``img_nii.header.get_zooms()`` are returned as floats.
    """
    chosen = img_nii.header.get_zooms()[:3] if override is None else override
    return tuple(map(float, chosen))


def to_zyx_spacing(xyz_spacing: Tuple[float, float, float]) -> Tuple[float, float, float]:
    """Reorder an (x, y, z) spacing triple into (z, y, x) order.

    Intended for arrays whose axes are ordered (z, y, x).
    NOTE(review): arrays read through nibabel are indexed in the same axis
    order as the header zooms, so confirm an array really is (z, y, x)-ordered
    before applying the result to it.
    """
    sx, sy, sz = xyz_spacing
    return sz, sy, sx


def compute_mtv_mL(mask: np.ndarray, spacing_xyz_mm: Tuple[float, float, float]) -> float:
    """Return the total segmented volume in mL.

    Args:
        mask: Segmentation array; every voxel with a value > 0 counts as
            lesion, whatever the actual value is (1, 255, a label id, True).
        spacing_xyz_mm: Voxel spacing in mm along the three axes.

    Returns:
        (number of lesion voxels) * (voxel volume in mm^3) / 1000.
    """
    voxel_vol_mm3 = spacing_xyz_mm[0] * spacing_xyz_mm[1] * spacing_xyz_mm[2]
    # Count voxels rather than summing values: summing is only correct for a
    # strict 0/1 mask and would inflate the volume of a 0/255 or labelled mask.
    n_voxels = int(np.count_nonzero(np.asarray(mask) > 0))
    return float(n_voxels * voxel_vol_mm3) / 1000.0  # mm^3 -> mL


def compute_regionprops(mask: np.ndarray) -> Tuple[np.ndarray, List]:
    """Split a segmentation into connected lesions.

    Args:
        mask: Segmentation array; voxels with a value > 0 are foreground.

    Returns:
        (lab, props): the integer label image produced by
        ``skimage.measure.label`` and the matching ``regionprops`` list
        (one entry per connected component).
    """
    # Threshold directly. Going through img_as_ubyte first rescales by dtype
    # range, so e.g. a uint16/int32 mask holding the value 1 would be scaled
    # down to 0 and every lesion would vanish.
    foreground = np.asarray(mask) > 0
    # connectivity == ndim means full connectivity (26-neighbourhood in 3D).
    lab = label(foreground, connectivity=foreground.ndim)
    props = regionprops(lab)
    return lab, props


def compute_dmax_mm(props: List, spacing_xyz_mm: Tuple[float, float, float]) -> float:
    """Return the lesion dissemination distance Dmax in mm.

    Args:
        props: Region properties of the lesions (objects exposing
            ``centroid`` and ``image`` as skimage's regionprops do), all
            expressed in array-index coordinates.
        spacing_xyz_mm: Voxel spacing in mm, in array-axis order: element i is
            the spacing along array axis i. This is the order in which
            nibabel reports header zooms for an array read from the same
            image, so the spacing must NOT be reversed before use.

    Returns:
        - 0.0 when there are no lesions;
        - for a single lesion, the maximum of its Euclidean distance
          transform (largest distance from a lesion voxel to the nearest
          non-lesion voxel);
        - otherwise the largest centroid-to-centroid distance over all lesion
          pairs.
    """
    if not props:
        return 0.0

    spacing = [float(s) for s in spacing_xyz_mm[:3]]

    if len(props) == 1:
        # props[0].image is the tight bounding-box crop of the lesion. Pad it
        # with one layer of background so that voxels on the crop faces see a
        # boundary, and so that a completely filled crop still has background
        # for the distance transform to measure against.
        crop = np.asarray(props[0].image, dtype=bool)
        padded = np.pad(crop, 1, mode="constant", constant_values=False)
        dt = distance_transform_edt(padded, sampling=spacing)
        return float(dt.max())

    # Multiple lesions: scale index-space centroids to mm axis by axis, then
    # take the largest pairwise Euclidean distance.
    cents_mm = np.array([p.centroid for p in props], dtype=float) * np.asarray(spacing)
    return float(distance.pdist(cents_mm).max())


def lesion_metrics(mask_labeled: np.ndarray,
                   props: List,
                   pet_suv: Optional[np.ndarray],
                   spacing_xyz_mm: Tuple[float, float, float]):
    """Compute per-lesion volume, centroid and SUV statistics.

    Args:
        mask_labeled: Integer label image (0 = background, k = lesion k).
        props: Region properties for the labels (objects exposing ``label``,
            ``area`` and ``centroid`` as skimage's regionprops do).
        pet_suv: SUV array on the same grid as *mask_labeled*, or None when
            no PET is available.
        spacing_xyz_mm: Voxel spacing in mm, in array-axis order (element i is
            the spacing along array axis i, which is how nibabel reports
            header zooms for an array read from the same image).

    Returns:
        A list with one dict per lesion, with keys ``label``, ``voxels``,
        ``Volume_mL``, ``CentroidX_mm``, ``CentroidY_mm``, ``CentroidZ_mm``,
        ``SUVmean``, ``SUVmax`` and ``TLA_mL``. The three SUV-derived values
        are None when *pet_suv* is None or the lesion holds no non-NaN SUV
        value. Centroids are index-space centroids scaled by the spacing; no
        affine/origin offset is applied.
    """
    sx, sy, sz = (float(s) for s in spacing_xyz_mm[:3])
    voxel_vol_mL = (sx * sy * sz) / 1000.0

    lesions = []
    for p in props:
        lbl = int(p.label)
        voxels = int(p.area)
        vol_mL = voxels * voxel_vol_mL

        # regionprops centroids follow array-axis order, i.e. (axis0, axis1,
        # axis2) == (x, y, z) here, so each component is scaled by the spacing
        # of its own axis.
        c0, c1, c2 = p.centroid

        suv_mean = None
        suv_max = None
        tla_mL = None

        if pet_suv is not None:
            lesion_vals = pet_suv[mask_labeled == lbl]
            # Drop NaNs up front: an all-NaN lesion then reports "missing"
            # (None) instead of NaN plus a RuntimeWarning.
            lesion_vals = lesion_vals[~np.isnan(lesion_vals)]
            if lesion_vals.size > 0:
                # Accumulate in float64; PET is typically stored as float32.
                suv_mean = float(lesion_vals.mean(dtype=np.float64))
                suv_max = float(lesion_vals.max())
                # TLA under unitless-SUV convention = SUVmean * volume (mL)
                tla_mL = float(suv_mean * vol_mL)

        lesions.append({
            "label": lbl,
            "voxels": voxels,
            "Volume_mL": vol_mL,
            "CentroidX_mm": float(c0 * sx),
            "CentroidY_mm": float(c1 * sy),
            "CentroidZ_mm": float(c2 * sz),
            "SUVmean": suv_mean,
            "SUVmax": suv_max,
            "TLA_mL": tla_mL,
        })
    return lesions


def global_pet_stats_over_mask(pet_suv: np.ndarray, mask_bool: np.ndarray) -> Tuple[Optional[float], Optional[float]]:
    """Return (SUVmean, SUVmax) over all voxels where the mask is > 0.

    NaN values are ignored by both statistics. When the mask selects no voxel
    at all, ``(None, None)`` is returned.
    """
    selected = pet_suv[mask_bool > 0]
    if selected.size:
        mean_suv = float(np.nanmean(selected))
        peak_suv = float(np.nanmax(selected))
        return mean_suv, peak_suv
    return None, None


# --- BW-only body volume (all male, no height) ---

def compute_body_volume_mL_BW(weight_kg: Optional[float]) -> Optional[float]:
    """Estimate body volume in mL from body weight alone.

    Uses the approximation 1 kg of body mass ~ 1000 mL. A missing weight
    (None) yields None; any other value is converted with ``float`` first.
    """
    return None if weight_kg is None else 1000.0 * float(weight_kg)


def _normalize_pid(value) -> Optional[str]:
    """Return a clean PID string for a spreadsheet cell, or None if it is blank."""
    if value is None:
        return None
    if isinstance(value, (float, np.floating)):
        if math.isnan(value):
            return None
        # Numeric ID columns come back as floats when the column has blanks;
        # 1234.0 must become "1234" to match the IDs parsed from file names.
        if float(value).is_integer():
            return str(int(value))
    text = str(value).strip()
    if not text or text.lower() in ("nan", "none", "nat", "<na>"):
        return None
    return text


def _to_float_or_none(value) -> Optional[float]:
    """Convert a spreadsheet cell to float; blank, NaN or unparsable cells give None."""
    if value is None:
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(number) else number


def load_anthro_table_excel(path_xlsx: Optional[str],
                            sheet=0,
                            pid_candidates=("PID", "pid", "Id", "id", "PatientID", "Patient_Id"),
                            weight_candidates=("weight_kg", "Weight_kg", "weight (kg)", "Weight (kg)", "Weight", "weight"),
                            age_candidates=("Age", "age", "PatientAge", "patient_age")) -> Dict[str, Dict]:
    """Read per-patient weight and age from an Excel sheet.

    Args:
        path_xlsx: Path of the workbook. An empty/None path or a path that is
            not an existing file yields an empty mapping (anthropometrics are
            optional).
        sheet: Sheet index or name, passed to ``pandas.read_excel``.
        pid_candidates: Accepted header names for the patient-ID column.
        weight_candidates: Accepted header names for the weight (kg) column.
        age_candidates: Accepted header names for the age column.

    Returns:
        Mapping ``PID -> {"weight_kg": float | None, "age": float | None}``.
        PIDs are strings; integral numeric IDs are rendered without a decimal
        part (1234.0 -> "1234"). Rows with a blank PID are skipped. If a PID
        occurs more than once, the last row wins.

    Raises:
        ValueError: If no PID column can be identified.

    Headers are matched exactly first, then ignoring case and surrounding
    whitespace; candidates are tried in the order given.
    """
    out: Dict[str, Dict] = {}
    if not path_xlsx or not os.path.isfile(path_xlsx):
        return out

    df = pd.read_excel(path_xlsx, sheet_name=sheet)

    # str() guards against non-string headers (e.g. numeric column labels).
    folded_headers = {}
    for col in df.columns:
        folded_headers.setdefault(str(col).strip().lower(), col)

    def find_col(candidates):
        for name in candidates:
            if name in df.columns:
                return name
        for name in candidates:
            key = str(name).strip().lower()
            if key in folded_headers:
                return folded_headers[key]
        return None

    pid_col = find_col(pid_candidates)
    wt_col = find_col(weight_candidates)
    age_col = find_col(age_candidates)

    if pid_col is None:
        raise ValueError("Could not find PID column in anthropometrics Excel.")

    # Iterate column-wise rather than with iterrows(): iterrows() upcasts each
    # row to a single dtype, which turns an integer PID into a float whenever
    # all selected columns are numeric.
    n_rows = len(df)
    pid_cells = df[pid_col].tolist()
    weight_cells = df[wt_col].tolist() if wt_col is not None else [None] * n_rows
    age_cells = df[age_col].tolist() if age_col is not None else [None] * n_rows

    for raw_pid, raw_weight, raw_age in zip(pid_cells, weight_cells, age_cells):
        pid = _normalize_pid(raw_pid)
        if pid is None:
            continue
        # Weight and age may be missing per row; the patient is still listed.
        out[pid] = {
            "weight_kg": _to_float_or_none(raw_weight),
            "age": _to_float_or_none(raw_age),
        }
    return out


# ------------------------
# Main batch
# ------------------------

def process_one(pid: str,
                seg_path: str,
                pet_path: Optional[str],
                anthro: Dict[str, Dict]):
    """Compute all tumor-burden metrics for a single patient.

    Args:
        pid: Patient ID; used to look up anthropometrics and to tag rows.
        seg_path: Path of the mask NIfTI; voxels > 0 are treated as lesion.
        pet_path: Path of the PET NIfTI, or None. When it is None or not an
            existing file, all SUV-derived outputs stay None.
        anthro: Mapping PID -> dict with optional "weight_kg" and "age".

    Returns:
        (summary, lesions): the per-patient summary dict and the list of
        per-lesion dicts, each lesion dict additionally tagged with "PID".

    Raises:
        ValueError: If the PET array shape differs from the mask array shape.
    """
    # --- Segmentation: spacing comes from the mask image (or the override) ---
    seg_img = nib.load(seg_path)
    spacing_xyz_mm = get_spacing_mm(seg_img, VOXEL_SIZE_MM)
    mask = (np.asarray(seg_img.dataobj) > 0).astype(np.uint8)

    # --- PET (optional) ---
    pet = None
    if pet_path and os.path.isfile(pet_path):
        pet = np.asarray(nib.load(pet_path).dataobj).astype(np.float32)
        if pet.shape != mask.shape:
            raise ValueError(f"PET shape {pet.shape} does not match mask shape {mask.shape} for PID {pid}.")

    # --- Volume and lesion-wise metrics ---
    mtv_mL = compute_mtv_mL(mask, spacing_xyz_mm)
    lab, props = compute_regionprops(mask)
    lesions = lesion_metrics(lab, props, pet, spacing_xyz_mm)

    # --- Global SUV statistics and total lesion activity ---
    suvmean_global = suvmax_global = tla_mL = None
    if pet is not None:
        suvmean_global, suvmax_global = global_pet_stats_over_mask(pet, mask)
        # Patient TLA is the sum of the per-lesion TLAs that could be computed.
        lesion_tlas = [row["TLA_mL"] for row in lesions if row["TLA_mL"] is not None]
        if lesion_tlas:
            tla_mL = float(np.nansum(lesion_tlas))

    # --- Dissemination distance ---
    dmax_mm = compute_dmax_mm(props, spacing_xyz_mm)

    # --- Anthropometrics and body-weight-based lesion fraction ---
    patient_info = anthro.get(pid, {})
    age = patient_info.get("age")
    body_vol_bw_mL = compute_body_volume_mL_BW(patient_info.get("weight_kg"))

    tlf_bw_pct = None
    have_inputs = tla_mL is not None and body_vol_bw_mL is not None
    if have_inputs and body_vol_bw_mL > 0:
        tlf_bw_pct = 100.0 * float(tla_mL) / float(body_vol_bw_mL)

    # Tag every lesion row with its patient so the rows can share one CSV.
    for row in lesions:
        row["PID"] = pid

    summary = {
        "PID": pid,
        "Age": age,
        "LesionCount": len(props),
        "MTV_mL": mtv_mL,
        "TLA_mL": tla_mL,
        "TLF_%": tlf_bw_pct,                 # body-weight based only
        "SUVmean_global": suvmean_global,
        "SUVmax_global": suvmax_global,
        "Dmax_mm": dmax_mm,
        "BodyVol_BW_mL": body_vol_bw_mL,     # kept for transparency/debugging
        "UsedSpacingX_mm": spacing_xyz_mm[0],
        "UsedSpacingY_mm": spacing_xyz_mm[1],
        "UsedSpacingZ_mm": spacing_xyz_mm[2],
        "PET_file": pet_path,
        "Mask_file": seg_path,
    }
    return summary, lesions


def main():
    """Process every mask in SEG_FOLDER and write the two result CSVs.

    For each file in SEG_FOLDER whose name yields a patient ID, the matching
    PET is looked up in PET_FOLDER and the patient is processed. A patient
    whose processing raises is reported on stdout and kept in the summary CSV
    as a row holding only PID, Age and the file paths; it contributes no
    lesion rows.
    """
    # Column layouts of the two output files.
    sum_fields = [
        "PID", "Age", "LesionCount", "MTV_mL", "TLA_mL", "TLF_%",
        "SUVmean_global", "SUVmax_global", "Dmax_mm",
        "BodyVol_BW_mL",
        "UsedSpacingX_mm", "UsedSpacingY_mm", "UsedSpacingZ_mm",
        "PET_file", "Mask_file",
    ]
    lesion_fields = [
        "PID", "label", "voxels", "Volume_mL",
        "CentroidX_mm", "CentroidY_mm", "CentroidZ_mm",
        "SUVmean", "SUVmax", "TLA_mL",
    ]

    # Anthropometrics (age, weight) from Excel
    anthro = load_anthro_table_excel(ANTHRO_XLSX, sheet=ANTHRO_SHEET)

    summaries = []
    lesion_rows = []

    for fname in sorted(os.listdir(SEG_FOLDER)):
        pid = maybe_pid(fname)
        if not pid:
            continue  # not a recognised mask file name
        seg_path = os.path.join(SEG_FOLDER, fname)
        pet_path = find_pet_for_pid(pid, PET_FOLDER)

        try:
            summary, lesions = process_one(pid, seg_path, pet_path, anthro)
        except Exception as e:
            print(f"[WARN] PID {pid} failed: {e}")
            # Placeholder row: every metric empty, identifying fields filled.
            summary = dict.fromkeys(sum_fields)
            summary.update({
                "PID": pid,
                "Age": anthro.get(pid, {}).get("age"),
                "PET_file": pet_path,
                "Mask_file": seg_path,
            })
            lesions = []

        summaries.append(summary)
        lesion_rows.extend(lesions)

    # Write CSVs
    with open(OUT_SUMMARY_CSV, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=sum_fields)
        writer.writeheader()
        writer.writerows(summaries)

    with open(OUT_LESIONS_CSV, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=lesion_fields)
        writer.writeheader()
        writer.writerows(lesion_rows)

    print(f"Done. Wrote {len(summaries)} patients to {OUT_SUMMARY_CSV} and {len(lesion_rows)} lesions to {OUT_LESIONS_CSV}.")


if __name__ == "__main__":
    main()
