In [None]:
#!/usr/bin/env python3
"""
BLOCK 1 — FULL DATA BUILD (Standalone)
======================================
End-to-end feature build from databases. Produces everything later blocks need
WITHOUT depending on prior outputs.

What this builds
----------------
• Members (patient_id) with demographics
• ADT events (ed_visit/admission) with Member→patient_id mapping
• Interventions from EncounterNote *that actually occurred* (encounterOccurred='YES'),
  restricted to *engaged* patients by Patient.status ∈ ENGAGED_STATUSES
• Diagnoses free-text → clinical categories → per-patient dx_cat_* features
• Leak-free features and binary outcomes for 7/30/90/180 days from a cutoff
• Helper artifacts for downstream blocks

Outputs (written to ./publication_outputs/ unless noted)
--------------------------------------------------------
• feature_matrix.parquet  (and mirror features_df.parquet)
• feature_meta.json  (+ artifacts/feature_cols.json)
• interventions_df.csv   (occurred YES, engaged only)
• adt_events_clean.csv   (debug/QA optional)
• signal_risks.csv       (if available)
• run_meta.json          (provenance)

Notes
-----
• Preserves your complex ID mapping: Core Member.id → Member.external_identifier (patient_id)
• CUTOFF_DATE defaults to 2024-06-01 (change the constant below if needed)
"""
from __future__ import annotations
import os, sys, json, re
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

# ---------------------------------------------------------------------------
# Waymark DB setup
# ---------------------------------------------------------------------------
# Put the local SageMaker helper library on sys.path before importing waymark from it.
sagemaker_lib = os.path.expanduser("~/sagemaker-lib")
if sagemaker_lib not in sys.path:
    sys.path.insert(0, sagemaker_lib)
import waymark  # provided by your environment

# Database engines shared by every extractor in this block:
#   CORE_ENGINE -> "Member", ADT tables and "signal_risks"
#   LH_ENGINE   -> "Patient", "EncounterNote" and "Diagnoses"
CORE_ENGINE = waymark.get_waymark_core_db_engine()
LH_ENGINE   = waymark.get_lighthouse_db_engine()

# Output folders are created at import time (no error if they already exist).
OUTPUT_DIR = Path("publication_outputs"); OUTPUT_DIR.mkdir(exist_ok=True)
ART_DIR    = Path("artifacts"); ART_DIR.mkdir(exist_ok=True)

# Engaged statuses for Patient.status
ENGAGED_STATUSES = [
    "ACTIVATED", "IN_CONTACT", "GRADUATED", "ONBOARDED",
    "MODERATE", "PREGRADUATION", "MAINTENANCE", "HIGH"
]

# Cutoff for temporal split (no leakage): features use data strictly before this
# timestamp, outcomes use data at or after it.
CUTOFF_DATE = pd.Timestamp("2024-06-01")

# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------

def _read_sql(sql: str, engine) -> pd.DataFrame:
    """Run ``sql`` on a connection obtained from ``engine`` and return the rows.

    The connection comes from ``engine.connect()`` and is used as a context
    manager, so its exit handler runs as soon as the result has been read.
    """
    connection_ctx = engine.connect()
    with connection_ctx as connection:
        frame = pd.read_sql(sql, connection)
    return frame


def _member_map() -> pd.DataFrame:
    """Return the Core ``Member`` lookup from internal id to external identifier.

    Columns are ``member_id`` (Member.id) and ``patient_id``
    (Member.external_identifier), both cast to ``str``. Rows whose external
    identifier is NULL or empty are excluded by the query.
    """
    query = '''
        SELECT "id" AS member_id, "external_identifier" AS patient_id
        FROM "Member"
        WHERE "external_identifier" IS NOT NULL AND "external_identifier" <> ''
    '''
    mapping = _read_sql(query, CORE_ENGINE)
    # String keys on both sides keep later merges from missing on int-vs-str ids.
    for key_col in ("member_id", "patient_id"):
        mapping[key_col] = mapping[key_col].astype(str)
    return mapping


# ---------------------------------------------------------------------------
# 1) Members (demographics)
# ---------------------------------------------------------------------------

def extract_members() -> pd.DataFrame:
    """Load one demographics row per patient from the Core ``Member`` table.

    Only members with a birth date, a gender and an external identifier are
    queried. Duplicate ``patient_id`` values keep their first row, and
    ``patient_id`` is returned as ``str``.
    """
    print("\nüìä Extracting members‚Ä¶")
    query = '''
        SELECT 
            "external_identifier" AS patient_id,
            "firstName", "lastName", "gender", "birthDate",
            "ethnicity", "createdAt", "updatedAt"
        FROM "Member"
        WHERE "birthDate" IS NOT NULL
          AND "gender" IS NOT NULL
          AND "external_identifier" IS NOT NULL
    '''
    members = _read_sql(query, CORE_ENGINE).drop_duplicates(subset=["patient_id"]).copy()
    members["patient_id"] = members["patient_id"].astype(str)
    member_total = len(members)
    print(f"  ‚úÖ Members: {member_total:,}")
    return members


# ---------------------------------------------------------------------------
# 2) ADT events → patient_id
# ---------------------------------------------------------------------------

def extract_adt_events(members_df: pd.DataFrame) -> pd.DataFrame:
    """Load ADT events, map them to ``patient_id`` and label each as admission or ED visit.

    Args:
        members_df: Member table; only events whose ``patient_id`` appears in its
            ``patient_id`` column are kept.

    Returns:
        DataFrame with columns ``patient_id``, ``event_datetime`` and ``eventType``
        (``"admission"`` or ``"ed_visit"``), sorted by ``event_datetime``. An empty
        frame with those columns is returned when the query yields no rows.
    """
    print("\nüè• Extracting ADT events‚Ä¶")
    query = '''
        SELECT 
            evn."recordedDateTime" AS event_datetime,
            evn."eventTypeCode"    AS event_type_code,
            mast."memberId"        AS member_id
        FROM "AdtEVN" evn
        JOIN "AdmissionDischargeTransferMaster" mast ON evn."adtMasterId" = mast.id
        WHERE evn."recordedDateTime" BETWEEN '2022-01-01' AND '2024-12-31'
          AND evn."eventTypeCode" IN ('A01','A04','A08')
        ORDER BY evn."recordedDateTime"'''
    result_cols = ["patient_id", "event_datetime", "eventType"]
    events = _read_sql(query, CORE_ENGINE)
    if events.empty:
        print("  ‚ö†Ô∏è No ADT rows.")
        return pd.DataFrame(columns=result_cols)

    # Attach patient_id through the Member id -> external identifier lookup.
    member_lookup = _member_map()
    events["member_id"] = events["member_id"].astype(str)
    events = events.merge(member_lookup, on="member_id", how="inner")

    # Collapse raw trigger codes into the two event categories used downstream.
    # NOTE(review): in HL7 v2, A04 is "register a patient" and A08 is "update
    # patient information"; confirm that A08 is the intended ED-visit marker.
    code_to_category = {"A01": "admission", "A04": "admission", "A08": "ed_visit"}
    events["eventType"] = events["event_type_code"].map(code_to_category).fillna("other")
    wanted_category = events["eventType"].isin(["admission", "ed_visit"])
    events = events[wanted_category].copy()

    # Unparseable timestamps become NaT; then restrict to the supplied member cohort.
    events["event_datetime"] = pd.to_datetime(events["event_datetime"], errors="coerce")
    in_cohort = events["patient_id"].isin(members_df["patient_id"])
    events = events[in_cohort].copy()
    events = events.sort_values("event_datetime").reset_index(drop=True)

    n_events = len(events)
    n_patients = events["patient_id"].nunique()
    print(f"  ‚úÖ ADT events: {n_events:,} across {n_patients:,} patients")
    return events[result_cols]


# ---------------------------------------------------------------------------
# 3) Engaged patients & occurred encounters (interventions)
# ---------------------------------------------------------------------------

def engaged_patient_ids() -> set[str]:
    """Return the ``patient_id`` set of Lighthouse patients whose status is engaged.

    Patients are selected by ``Patient.status`` in ``ENGAGED_STATUSES`` and
    translated to ``patient_id`` by matching ``Patient.id`` against the Member id
    lookup. Patients without a match are dropped. An empty set is returned when
    no patient has an engaged status.
    """
    quoted_statuses = ",".join(f"'{status}'" for status in ENGAGED_STATUSES)
    query = 'SELECT "id" AS lighthouse_patient_id, "status" FROM "Patient" WHERE "status" IN ({})'.format(
        quoted_statuses
    )
    patients = _read_sql(query, LH_ENGINE)
    if patients.empty:
        print("  ‚ö†Ô∏è No engaged patients by status.")
        return set()

    member_lookup = _member_map()
    patients["lighthouse_patient_id"] = patients["lighthouse_patient_id"].astype(str)
    member_lookup["member_id"] = member_lookup["member_id"].astype(str)
    linked = patients.merge(
        member_lookup,
        left_on="lighthouse_patient_id",
        right_on="member_id",
        how="left",
    )
    matched_ids = linked["patient_id"].dropna().astype(str)
    engaged = set(matched_ids)
    print(f"  ‚úÖ Engaged patients: {len(engaged):,}")
    return engaged


def extract_interventions_occurred_engaged(engaged_ids: set[str]) -> pd.DataFrame:
    """Return occurred encounters (interventions) for engaged patients only.

    Encounters are read from ``EncounterNote`` where the note is published, not
    deleted and ``encounterOccurred = 'YES'``. ``patientId`` is translated to
    ``patient_id`` through the Member id lookup; encounters without a match are
    dropped.

    Args:
        engaged_ids: ``patient_id`` values considered engaged. The filter is
            always applied, so an empty set yields an empty result. Previously
            an empty set disabled the filter and every patient's encounters
            were returned under the "engaged only" label.

    Returns:
        DataFrame with columns ``patient_id``, ``encounter_datetime`` and
        ``encounter_note``, sorted by patient and encounter time.
    """
    print("\nüíä Extracting occurred encounters for engaged patients‚Ä¶")
    sql = '''
        SELECT en."id" AS encounter_id,
               en."patientId" AS lighthouse_patient_id,
               en."dateOfEncounter" AS encounter_datetime,
               en."note" AS encounter_note,
               en."encounterOccurred"
        FROM "EncounterNote" en
        WHERE en."published" = true
          AND en."deleted"   = false
          AND en."encounterOccurred" = 'YES'
        ORDER BY en."dateOfEncounter"'''
    out_cols = ["patient_id", "encounter_datetime", "encounter_note"]
    en = _read_sql(sql, LH_ENGINE)
    if en.empty:
        print("  ‚ö†Ô∏è No occurred encounters.")
        return pd.DataFrame(columns=out_cols)

    # Map lighthouse -> patient_id
    mm = _member_map()
    en["lighthouse_patient_id"] = en["lighthouse_patient_id"].astype(str)
    mm["member_id"] = mm["member_id"].astype(str)
    en = en.merge(mm, left_on="lighthouse_patient_id", right_on="member_id", how="left")
    en = en.dropna(subset=["patient_id"]).copy()

    # Keep only engaged patients. An empty engaged set means "nobody is engaged",
    # not "skip the filter".
    if not engaged_ids:
        print("  [WARN] Engaged patient set is empty; no interventions will be kept.")
    en = en[en["patient_id"].astype(str).isin(set(engaged_ids))].copy()

    en["encounter_datetime"] = pd.to_datetime(en["encounter_datetime"], errors="coerce")
    out = en[out_cols].sort_values(["patient_id", "encounter_datetime"]).reset_index(drop=True)
    print(f"  ‚úÖ Interventions kept: {len(out):,} across {out['patient_id'].nunique():,} patients")
    return out


# ---------------------------------------------------------------------------
# 4) Diagnoses free-text → categories → per-patient features
# ---------------------------------------------------------------------------
# Clinical category -> regex fragments. The fragments of one category are OR-ed
# together and matched case-insensitively anywhere in the diagnosis text.
#
# Short abbreviations and stems are anchored with \b so they only match as
# tokens. Unanchored, they hit unrelated words: "onc" in "bronchitis" and
# "concussion", "stab" in "unstable", "ob" in "job", "oral" in "behavioral",
# "sob" in "sobriety", "ache" in "tracheostomy", "mi" in "semi". Terms that used
# to match only through those loose endings and are clinically correct are kept
# explicitly (STEMI/NSTEMI, UGI/LGI, dementia).
_DX_CATS = {
    "renal":        [r"\brenal\b", r"kidney", r"neph"],
    "infection":    [r"infection", r"sepsis", r"\buti\b", r"pneumonia", r"cellulitis", r"abscess"],
    "injury":       [r"injur", r"fracture", r"sprain", r"lacerat", r"contusion", r"burn"],
    "respiratory":  [r"asthma", r"copd", r"respir", r"bronch", r"\bsob\b", r"dyspnea"],
    "cardiac":      [r"card", r"\b(?:n?ste)?mi\b", r"myocard", r"afib", r"arrhythm", r"chf", r"heart"],
    "neuro":        [r"neuro", r"stroke", r"\bcva\b", r"seizure", r"\btia\b", r"dementia", r"migraine", r"headache"],
    "psych":        [r"psych", r"depress", r"anx", r"bipolar", r"schizo", r"suicid"],
    "pain":         [r"pain", r"aches?\b", r"cramp", r"spasm"],
    "pregnancy":    [r"pregnan", r"\bob\b", r"obstet", r"gyn\b", r"miscar", r"prenatal"],
    "gi":           [r"nausea", r"vomit|emesis", r"diarrhea", r"abdom|abd\b|belly", r"\b[ul]?gi\b", r"gastr|ulcer"],
    "gu":           [r"urinar|dysuria|hematur|pyel|prostat", r"\bgu\b"],
    "endocrine":    [r"diabet|hypergly|hypogly|thyroid|adrenal|endocr"],
    "oncology":     [r"cancer|\bonc|tumor|malignan|neoplasm"],
    "trauma":       [r"mvc|mva|gunshot|\bstab(?:s|bed|bing)?\b|assault|fall\b|trauma"],
    "skin_wound":   [r"rash|dermat|wound|ulcer|decub|psoria|eczema"],
    "substance":    [r"etoh|alcohol|opioid|overdose|cocaine|meth|fentanyl|substance"],
    "social":       [r"homeless|housing|transport|food|utility|childcare|violence|safety"],
    "dental":       [r"dental|tooth|teeth|\boral\b|abscess"],
    "other":        [r"ekg|tbd|other|unspecified|unknown"],
}
# One compiled, case-insensitive pattern per category.
_DX_RX = {k: re.compile("|".join(v), re.I) for k, v in _DX_CATS.items()}

def build_dx_features() -> pd.DataFrame:
    """Turn free-text diagnoses into per-patient category counts and flags.

    Each ``Diagnoses.description`` is matched against every pattern in
    ``_DX_RX``; one description can fall into several categories.

    Only diagnoses whose ``createdAt`` is strictly before ``CUTOFF_DATE`` are
    used, so these features cannot see information recorded in the outcome
    window. Rows with a missing or unparseable ``createdAt`` are excluded
    because they cannot be shown to predate the cutoff.
    NOTE(review): this assumes ``createdAt`` reflects when the diagnosis became
    known rather than a later bulk load; confirm against the source system.

    Returns:
        One row per patient with ``patient_id``, ``dx_cat_<category>_count``,
        ``dx_cat_<category>_any`` (0/1) and ``dx_any_count`` (sum of the count
        columns). A frame holding only a ``patient_id`` column is returned when
        there is nothing to aggregate.
    """
    print("\nüß¨ Loading Diagnoses ‚Üí categories‚Ä¶")
    dx = _read_sql('SELECT "patientId","description","createdAt" FROM "Diagnoses"', LH_ENGINE)
    if dx.empty:
        print("  ‚ö†Ô∏è No Diagnoses rows.")
        return pd.DataFrame(columns=["patient_id"])  # empty join

    # Leak guard: normalise to timezone-naive UTC so the comparison with the
    # naive CUTOFF_DATE works for both naive and tz-aware inputs. NaT compares
    # False and is therefore dropped.
    created = pd.to_datetime(dx["createdAt"], errors="coerce", utc=True).dt.tz_localize(None)
    dx = dx[created < CUTOFF_DATE].copy()

    # Map Lighthouse patientId -> patient_id; unmatched diagnoses are discarded.
    mm = _member_map()
    dx["patientId"] = dx["patientId"].astype(str)
    mm["member_id"] = mm["member_id"].astype(str)
    dx = dx.merge(mm, left_on="patientId", right_on="member_id", how="left")
    dx = dx.dropna(subset=["patient_id"]).copy()

    def cats_for(text: str) -> list[str]:
        """Return the sorted category names whose pattern matches ``text``."""
        s = str(text) if isinstance(text, str) else ""
        # Treat slash/semicolon/pipe separators as commas so that tokens such
        # as "ob/gyn" are split at a word boundary.
        s = re.sub(r"[\\/;|]+", ",", s)
        hits = [k for k, rx in _DX_RX.items() if rx.search(s)]
        return sorted(set(hits))

    dx["_cats"] = dx["description"].apply(cats_for)
    # One (patient, category) pair per hit; a diagnosis may contribute several.
    rows = [
        (pid, cat)
        for pid, cats in zip(dx["patient_id"].astype(str), dx["_cats"])
        for cat in cats
    ]
    if not rows:
        print("  ‚ö†Ô∏è No recognizable dx categories.")
        return pd.DataFrame(columns=["patient_id"])  # empty join

    df = pd.DataFrame(rows, columns=["patient_id", "category"])\
           .groupby(["patient_id", "category"]).size().reset_index(name="count")
    piv = df.pivot(index="patient_id", columns="category", values="count").fillna(0).astype(int)
    piv.columns = [f"dx_cat_{c}_count" for c in piv.columns]
    flags = (piv > 0).astype(int)
    flags.columns = [c.replace("_count", "_any") for c in piv.columns]
    out = pd.concat([piv, flags], axis=1).reset_index()
    # Only the dx_cat_*_count columns end in "_count" at this point.
    out["dx_any_count"] = out[[c for c in out.columns if c.endswith("_count")]].sum(axis=1)
    print(f"  ‚úÖ Dx features built for {len(out):,} patients")
    return out


# ---------------------------------------------------------------------------
# 5) Leak-free features + outcomes
# ---------------------------------------------------------------------------

def build_features(members: pd.DataFrame, adt: pd.DataFrame, interventions: pd.DataFrame, engaged_ids: set[str], dx_feat: pd.DataFrame) -> pd.DataFrame:
    """Assemble the per-patient feature matrix and outcome labels around CUTOFF_DATE.

    Features are computed from rows strictly before ``CUTOFF_DATE``; outcomes
    from ADT rows at or after it.

    Args:
        members: One row per patient; must contain ``patient_id``, ``birthDate``
            and ``gender``.
        adt: ADT events with ``patient_id``, ``event_datetime`` and ``eventType``.
        interventions: Occurred encounters with ``patient_id`` and
            ``encounter_datetime``; may be empty.
        engaged_ids: ``patient_id`` values flagged as engaged.
        dx_feat: Per-patient diagnosis features keyed by ``patient_id``; may be
            empty. It is merged as supplied; no date filtering happens here.

    Returns:
        A copy of ``members`` (rows with an unparseable birth date or an age
        outside 0-120 years removed) extended with demographic, utilization,
        intervention, engagement, risk-score, outcome and diagnosis columns.
    """
    print("\nüß± Building leak‚Äëfree features + outcomes‚Ä¶")
    feat = members.copy()

    # Demographics: age is measured at the cutoff, not at run time.
    feat["birthDate"] = pd.to_datetime(feat["birthDate"], errors="coerce")
    feat = feat.dropna(subset=["birthDate"]).copy()
    feat["age_years"] = (CUTOFF_DATE - feat["birthDate"]).dt.days / 365.25
    feat = feat[(feat["age_years"] >= 0) & (feat["age_years"] <= 120)].copy()
    # Gender encoding tolerant to values like 'M'/'F' or 'Male'/'Female'
    # (only the first upper-cased letter is compared).
    g = feat["gender"].astype(str).str.upper().str[0]
    feat["gender_male"] = (g == "M").astype(int)
    feat["gender_female"] = (g == "F").astype(int)

    # Utilization from historical ADT (events strictly before the cutoff).
    adt_h = adt[adt["event_datetime"] < CUTOFF_DATE].copy()
    adt_h = adt_h[adt_h["patient_id"].isin(feat["patient_id"])].copy()
    util = adt_h.groupby("patient_id").agg(
        total_events=("event_datetime","count"),
        unique_event_days=("event_datetime", lambda x: x.dt.date.nunique()),
        admission_count=("eventType", lambda s: (s == "admission").sum()),
        ed_visit_count=("eventType", lambda s: (s == "ed_visit").sum()),
    ).reset_index()
    feat = feat.merge(util, on="patient_id", how="left")
    # Patients without any historical event get zeros rather than NaN.
    for c in ["total_events","unique_event_days","admission_count","ed_visit_count"]:
        feat[c] = feat[c].fillna(0).astype(int)

    # Historical interventions count (occurred encounters strictly before the cutoff)
    if not interventions.empty:
        inter_h = interventions[interventions["encounter_datetime"] < CUTOFF_DATE].copy()
        inter_cnt = inter_h.groupby("patient_id").size().reset_index(name="intervention_history_count")
        feat = feat.merge(inter_cnt, on="patient_id", how="left")
        feat["intervention_history_count"] = feat["intervention_history_count"].fillna(0).astype(int)
    else:
        feat["intervention_history_count"] = 0

    # is_engaged flag (boolean membership in the supplied id set)
    # NOTE(review): no date is attached to engaged_ids here, so the flag may
    # reflect status after the cutoff; confirm with the caller.
    feat["is_engaged"] = feat["patient_id"].astype(str).isin(engaged_ids)

    # Risk score toy feature: fixed-weight blend of age and log-scaled counts.
    feat["risk_score_basic"] = (
        0.3 * (feat["age_years"] / 100.0) +
        0.4 * np.log1p(feat["total_events"]) +
        0.3 * np.log1p(feat["intervention_history_count"])
    )

    # Outcomes from future ADT windows: any event in [CUTOFF_DATE, CUTOFF_DATE + w days].
    adt_f = adt[(adt["event_datetime"] >= CUTOFF_DATE)].copy()
    for w in (7, 30, 90, 180):
        end = CUTOFF_DATE + pd.Timedelta(days=w)
        win = adt_f[(adt_f["event_datetime"] <= end)]
        had = set(win["patient_id"].unique())
        feat[f"outcome_{w}d"] = feat["patient_id"].isin(had).astype(int)
        first = win.groupby("patient_id")["event_datetime"].min()
        # Patients without an event in the window keep the window length as their time.
        feat[f"time_to_event_{w}d"] = w
        # assign observed times (whole days from the cutoff to the first event);
        # idx holds the row labels of patients that have an event in the window
        idx = feat["patient_id"].map(first).dropna().index
        feat.loc[idx, f"time_to_event_{w}d"] = (
            feat.loc[idx, "patient_id"].map(first) - CUTOFF_DATE
        ).dt.days.values
        print(f"  ‚Ä¢ {w}d event rate: {feat[f'outcome_{w}d'].mean():.3f}")

    # Merge dx features; patients without diagnoses get 0 in every dx column.
    if not dx_feat.empty:
        feat = feat.merge(dx_feat, on="patient_id", how="left")
        for c in [c for c in feat.columns if c.startswith("dx_cat_") or c == "dx_any_count"]:
            feat[c] = feat[c].fillna(0)

    print(f"  ‚úÖ Feature matrix shape: {feat.shape}")
    return feat


# ---------------------------------------------------------------------------
# 6) Optional signal_risks snapshot (Core)
# ---------------------------------------------------------------------------

def extract_signal_risks() -> pd.DataFrame:
    """Snapshot the Core ``signal_risks`` table to CSV on a best-effort basis.

    A non-empty table is written to ``OUTPUT_DIR/signal_risks.csv``. Any failure
    (missing table, write error, ...) is reported on stdout and an empty
    DataFrame is returned instead of raising.
    """
    try:
        risks = _read_sql('SELECT * FROM "signal_risks"', CORE_ENGINE)
        if risks.empty:
            return risks
        target = OUTPUT_DIR/"signal_risks.csv"
        risks.to_csv(target, index=False)
        print(f"üíæ Saved signal_risks ‚Üí {target}")
        return risks
    except Exception as err:
        print(f"  ‚ö†Ô∏è signal_risks load skipped: {err}")
        return pd.DataFrame()


# ---------------------------------------------------------------------------
# 7) Runner
# ---------------------------------------------------------------------------

def main():
    """Run the whole Block 1 build and write every artifact later blocks read.

    Steps: extract members, ADT events, engaged ids, interventions and diagnosis
    features; build the feature matrix; then persist the matrix (two parquet
    copies), the modelling column list (two JSON copies), interventions, cleaned
    ADT events, the optional signal_risks snapshot and run metadata.
    """
    print("üöÄ BLOCK 1 ‚Äî FULL DATA BUILD")
    members_df = extract_members()
    adt_df = extract_adt_events(members_df)
    engaged_set = engaged_patient_ids()
    interventions_df = extract_interventions_occurred_engaged(engaged_set)
    dx_df = build_dx_features()

    # Build features
    feature_df = build_features(members_df, adt_df, interventions_df, engaged_set, dx_df)

    # Save the matrix under its primary name and the helper mirror name.
    matrix_path = OUTPUT_DIR/"feature_matrix.parquet"
    for parquet_target in (matrix_path, OUTPUT_DIR/"features_df.parquet"):
        feature_df.to_parquet(parquet_target, index=False)
    print(f"üíæ Saved feature matrix ‚Üí {matrix_path}")

    # Modelling columns = everything except identifiers, raw demographics,
    # the engagement flag and the outcome / time-to-event labels.
    excluded = {
        "patient_id", "firstName", "lastName", "birthDate", "gender",
        "ethnicity", "createdAt", "updatedAt", "is_engaged",
    }
    for horizon in (7, 30, 90, 180):
        excluded.add(f"outcome_{horizon}d")
        excluded.add(f"time_to_event_{horizon}d")
    feature_cols = [col for col in feature_df.columns if col not in excluded]
    cols_payload = json.dumps({"feature_cols": feature_cols}, indent=2)
    for json_target in (OUTPUT_DIR/"feature_meta.json", ART_DIR/"feature_cols.json"):
        json_target.write_text(cols_payload)
    print(f"üß© feature_cols count: {len(feature_cols)} (saved to feature_meta.json + artifacts/feature_cols.json)")

    # Save interventions (occurred, engaged) for downstream blocks
    interventions_df.to_csv(OUTPUT_DIR/"interventions_df.csv", index=False)
    print("üíæ Saved interventions ‚Üí publication_outputs/interventions_df.csv")

    # Optional debug exports
    adt_df.to_csv(OUTPUT_DIR/"adt_events_clean.csv", index=False)

    # signal_risks snapshot
    extract_signal_risks()

    # Run meta (provenance of this build)
    run_meta = {
        "built_at": datetime.utcnow().isoformat() + "Z",
        "cutoff_date": str(CUTOFF_DATE.date()),
        "rows_features": int(len(feature_df)),
        "rows_adt": int(len(adt_df)),
        "rows_interventions": int(len(interventions_df)),
        "engaged_count": int(len(engaged_set)),
    }
    (OUTPUT_DIR/"run_meta.json").write_text(json.dumps(run_meta, indent=2))
    print("\n‚úÖ Block 1 complete ‚Äî ready for Blocks 2‚Äì5.")


if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
"""
Mini-Block — Table 1 (Baseline Cohorts)
=======================================
Reads the feature matrix and optional interventions list, then writes
baseline characteristics for (a) Overall TTE cohort, (b) Engaged subset,
(c) Engaged & Treated subset.

Inputs (from ./publication_outputs):
  • feature_matrix.parquet (required)
  • interventions_df.csv   (optional)

Outputs (to ./publication_outputs):
  • table1_overall_cohort.csv
  • table1_engaged_subset.csv
  • table1_engaged_treated_subset.csv
  • table1_baseline_characteristics.xlsx (if xlsxwriter available)
"""
from __future__ import annotations
from pathlib import Path
import json
import numpy as np
import pandas as pd

# Shared output folder (created if missing) and the two Block 1 artifacts this cell reads.
OUT = Path("publication_outputs"); OUT.mkdir(exist_ok=True)
FEATURES_FP = OUT / "feature_matrix.parquet"
INTERVENTIONS_FP = OUT / "interventions_df.csv"

# Outcome horizons in days; each selects an outcome_{w}d column reported in Table 1.
RISK_WINDOWS = (7, 30, 90, 180)


def _must_exist(p: Path):
    """Raise ``FileNotFoundError`` when the required input path ``p`` is absent."""
    if p.exists():
        return
    raise FileNotFoundError(f"Missing {p}. Run Block 1 first.")


def _pct(x):
    """Convert a proportion to a percentage; a missing input yields NaN."""
    if pd.isnull(x):
        return 100.0 * np.nan
    return 100.0 * float(x)


def _cont_stats(s: pd.Series, name: str):
    """Summarise a continuous variable as mean, SD, median and quartiles.

    Values are coerced to numeric and missing or unparseable entries are
    dropped. Keys are ``<name>_mean``, ``<name>_sd`` (sample SD, ddof=1),
    ``<name>_median``, ``<name>_p25`` and ``<name>_p75``; all are NaN when no
    usable value remains.
    """
    stat_names = ("mean", "sd", "median", "p25", "p75")
    values = pd.to_numeric(s, errors="coerce").dropna()
    if values.empty:
        return {f"{name}_{stat}": np.nan for stat in stat_names}
    computed = (
        values.mean(),
        values.std(ddof=1),
        values.median(),
        values.quantile(0.25),
        values.quantile(0.75),
    )
    return {f"{name}_{stat}": float(val) for stat, val in zip(stat_names, computed)}


def summarize(df: pd.DataFrame, label: str) -> pd.DataFrame:
    """Build the long-format Table 1 rows for one cohort.

    Column names follow what Block 1 writes to the feature matrix:

    * age is read from ``age`` or, if absent, ``age_years``; the metrics are
      reported as ``Age_*`` either way;
    * gender is classified by its first upper-cased letter, the same rule
      Block 1 uses for ``gender_male`` / ``gender_female``, so both ``"F"`` and
      ``"Female"`` count as female;
    * utilization statistics are reported for the ``hist_*`` names and for
      Block 1's ``ed_visit_count`` / ``admission_count`` / ``total_events``,
      whichever are present.

    Args:
        df: Cohort rows; every column is optional.
        label: Value written to the ``Group`` column of each row.

    Returns:
        DataFrame with columns ``Group``, ``Metric`` and ``Value``.
    """
    rows = [{"Group": label, "Metric": "N", "Value": len(df)}]

    def _add_cont(series: pd.Series, prefix: str) -> None:
        """Append the continuous-variable summary of ``series`` under ``prefix``."""
        rows.extend(
            {"Group": label, "Metric": k, "Value": v}
            for k, v in _cont_stats(series, prefix).items()
        )

    # Block 1 stores age as "age_years"; "age" is kept for older matrices.
    age_col = next((c for c in ("age", "age_years") if c in df.columns), None)
    if age_col is not None:
        _add_cont(df[age_col], "Age")

    if "gender" in df.columns:
        # First letter only, so 'F'/'M' and 'Female'/'Male' encodings both count.
        g = df["gender"].astype(str).str.upper().str[0]
        rows.append({"Group": label, "Metric": "Female_%", "Value": _pct((g == "F").mean())})
        rows.append({"Group": label, "Metric": "Male_%",   "Value": _pct((g == "M").mean())})

    if "is_engaged" in df.columns:
        rows.append({"Group": label, "Metric": "Engaged_%", "Value": _pct(df["is_engaged"].mean())})
    if "treated_any" in df.columns:
        rows.append({"Group": label, "Metric": "Any_Treatment_%", "Value": _pct(df["treated_any"].mean())})

    for w in RISK_WINDOWS:
        col = f"outcome_{w}d"
        if col in df.columns:
            rows.append({"Group": label, "Metric": f"Outcome_{w}d_%", "Value": _pct(df[col].mean())})

    # Optional utilization features if present (legacy names first, then Block 1's).
    util_candidates = [
        "hist_ed_visits", "hist_admissions", "hist_util_total",
        "ed_visit_count", "admission_count", "total_events",
    ]
    for cand in util_candidates:
        if cand in df.columns:
            _add_cont(df[cand], cand)

    return pd.DataFrame(rows)


def run():
    """Write Table 1 for the overall, engaged and engaged-and-treated cohorts.

    Reads the feature matrix (required) and the interventions list (optional),
    derives ``treated_any``, then writes one CSV per cohort plus an optional
    Excel workbook with one sheet per cohort.
    """
    _must_exist(FEATURES_FP)
    features = pd.read_parquet(FEATURES_FP)

    # Fallback source for treated_any when the matrix has no intervention count;
    # any read problem leaves the set empty.
    treated_ids = set()
    if INTERVENTIONS_FP.exists():
        try:
            id_frame = pd.read_csv(INTERVENTIONS_FP, usecols=["patient_id"])
            treated_ids = set(id_frame["patient_id"].astype(str).unique())
        except Exception:
            treated_ids = set()

    overall = features.copy()
    if "is_engaged" not in overall.columns:
        overall["is_engaged"] = False

    if "intervention_history_count" in overall.columns:
        has_history = overall["intervention_history_count"].fillna(0) > 0
        overall["treated_any"] = has_history.astype(int)
    else:
        in_file = overall["patient_id"].astype(str).isin(treated_ids)
        overall["treated_any"] = in_file.astype(int)

    engaged = overall[overall["is_engaged"] == True].copy()
    engaged_treated = engaged[engaged["treated_any"] == 1].copy()

    # (sheet/group label, cohort frame, CSV file name)
    cohorts = (
        ("Overall", overall, "table1_overall_cohort.csv"),
        ("Engaged", engaged, "table1_engaged_subset.csv"),
        ("Engaged_Treated", engaged_treated, "table1_engaged_treated_subset.csv"),
    )
    tables = [(label, summarize(frame, label), csv_name) for label, frame, csv_name in cohorts]

    for _, table, csv_name in tables:
        table.to_csv(OUT/csv_name, index=False)

    # Optional Excel workbook
    try:
        with pd.ExcelWriter(OUT/"table1_baseline_characteristics.xlsx", engine="xlsxwriter") as writer:
            for label, table, _ in tables:
                table.to_excel(writer, sheet_name=label, index=False)
    except Exception as e:
        print(f"[WARN] Could not write Excel Table 1 (xlsxwriter missing?): {e}")

    print("‚úì Table 1 written to publication_outputs/:\n  - table1_overall_cohort.csv\n  - table1_engaged_subset.csv\n  - table1_engaged_treated_subset.csv\n  - table1_baseline_characteristics.xlsx (if available)")


if __name__ == "__main__":
    run()


In [None]:
!pip install --upgrade pip wheel setuptools
!pip install torchtuples==0.2.2 pycox==0.3.0
# torch is already present per your logs


In [None]:
#!/usr/bin/env python3
"""
BLOCK 2 — TIME-TO-EVENT MODELING (robust, fast, and extensible)
===============================================================
- Original classifier zoo retained (LogReg family, RF/ET/GB, optional XGBoost)
- CoxPH (lifelines) with ridge penalization and guards
- Survival add-ons: Random Survival Forest (sksurv→lifelines fallback),
  DeepSurv (pycox), DeepHit (pycox), and **XGB-AFT** (ultra-fast survival)
- Computes ROC AUC + Youden's J* at each horizon; writes same outputs as before

**Fixes in this version**
- Restores missing `_stratified_cap` (caused NameError)
- NaN/Inf-safe survival training via `_sanitize_survival_data`
- Stable per-sample interpolation `_interp1d_safe` (fixes "fp and xp" error)
- DeepSurv/DeepHit use `model.fit(x, (t, e))` (no TupleDataset dependency)
- Optional **XGB-AFT** engine to keep runs fast on large datasets

Inputs (looked for in ./publication_outputs, with artifact fallbacks):
  • publication_outputs/feature_matrix.parquet
  • publication_outputs/feature_meta.json (or artifacts/feature_cols.json)

Outputs (in ./publication_outputs):
  • best_tte_model_predictions.csv (last window fit)
  • best_tte_model_predictions_30d.csv
  • table2_time_to_event_performance.csv
  • tte_results.json
  • complete_results_<timestamp>.json
"""

from __future__ import annotations
import json, shutil, sys
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (
    RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier,
)
from sklearn.calibration import CalibratedClassifierCV

# Optional XGBoost (for classifiers and AFT survival)
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except Exception:
    XGBOOST_AVAILABLE = False

# lifelines (CoxPH + RSF fallback)
try:
    from lifelines import CoxPHFitter
    LIFELINES_AVAILABLE = True
except Exception:
    LIFELINES_AVAILABLE = False

try:
    from lifelines import RandomSurvivalForest as LL_RSF
    LL_RSF_AVAILABLE = True
except Exception:
    LL_RSF_AVAILABLE = False

# scikit-survival RSF (preferred)
try:
    from sksurv.ensemble import RandomSurvivalForest as SKS_RSF
    from sksurv.util import Surv as SKS_Surv
    SKS_AVAILABLE = True
except Exception:
    SKS_AVAILABLE = False

# pycox (DeepSurv / DeepHit)
try:
    import torch
    from torch import nn
    import torchtuples as tt
    from pycox.models import CoxPH as PYCoxPH
    from pycox.models import DeepHitSingle
    from pycox.preprocessing.label_transforms import LabTransDiscreteTime
    PYCOX_AVAILABLE = True
except Exception:
    PYCOX_AVAILABLE = False

if not PYCOX_AVAILABLE:
    print("[Init] DeepSurv/DeepHit disabled: couldn't import pycox/torchtuples/torch.")

# Output folder (created if missing) and the artifacts folder used as a fallback
# source for the feature matrix and feature-column metadata.
OUTPUT_DIR = Path("publication_outputs"); OUTPUT_DIR.mkdir(exist_ok=True)
ART_DIR = Path("artifacts")
# Seed passed as random_state to the classifiers created in build_models().
RANDOM_STATE = 42

# =============================================================
# Split configuration
# =============================================================
# NOTE(review): the code consuming these settings is not all visible in this
# part of the file; the descriptions below are inferred from names and from
# the error messages in the splitter helpers. Confirm against the split logic.
USE_TIME_SPLIT = True          # presumably: split train/test by time rather than at random
TIME_TEST_FRACTION = 0.20      # presumably: share of the latest rows held out as test
TIME_SPLIT_COLUMN = None       # explicit date column; None presumably means auto-pick
TIME_SPLIT_DATE = None         # explicit threshold date; None presumably means quantile-based
GROUP_BY_PATIENT = True        # presumably: keep each patient's rows on one side of the split
PATIENT_ID_COL = "patient_id"

# Metrics
# When False, _auc_jstar() returns NaN for Youden's J*.
COMPUTE_JSTAR = True

# =============================================================
# FAST MODE toggles (edit for speed)
# =============================================================
ENABLE_SURVIVAL = True                 # master switch for survival add-ons
DEVICE_OVERRIDE = None                 # set to "cpu" to avoid GPU warmup
SURV_FAST_MODE = True
MAX_TRAIN_SAMPLES_PER_WINDOW = 5000   # or smaller


# Choose which survival engines to run
# Options: 'rsf', 'deepsurv', 'deephit', 'xgb_aft'
SURVIVAL_ENGINES = ["rsf", "deepsurv", "deephit", "xgb_aft"]
#SURVIVAL_ENGINES = ["xgb_aft","rsf"]    # skip deep nets entirely
# Limit horizons survival runs are attempted on (keep all by default)
SURVIVAL_WINDOWS = [7, 30, 90, 180]

# Fast mode uses smaller networks, fewer epochs, fewer DeepHit bins and fewer
# RSF trees than the full configuration in the else branch.
if SURV_FAST_MODE:
    DEEPSURV_HIDDEN = [32, 32]
    DEEPSURV_EPOCHS = 12
    DEEPHIT_HIDDEN = [64, 64]
    DEEPHIT_EPOCHS = 12
    DEEPHIT_BINS = 40
    RSF_TREES = 120
else:
    DEEPSURV_HIDDEN = [64, 64]
    DEEPSURV_EPOCHS = 50
    DEEPHIT_HIDDEN = [128, 128]
    DEEPHIT_EPOCHS = 50
    DEEPHIT_BINS = 100
    RSF_TREES = 300

# XGB-AFT parameters (fast survival)
XGB_AFT_PARAMS = dict(
    objective="survival:aft",
    eval_metric="aft-nloglik",
    aft_loss_distribution="normal",     # also: 'logistic', 'extreme'
    aft_loss_distribution_scale=1.20,
    tree_method="hist",
    max_depth=6,
    learning_rate=0.05,
    reg_lambda=0.01,
    reg_alpha=0.02,
    verbosity=0,
)
# Boosting-round budget and early-stopping patience for the AFT model
# (names only; the training call is not visible here).
XGB_AFT_NUM_ROUNDS = 400
XGB_AFT_EARLY_STOP = 30

# =============================================================
# IO helpers
# =============================================================

def _mirror(src: Path, dst: Path):
    """Copy ``src`` to ``dst``, creating ``dst``'s parent folders, and log the fallback."""
    target_dir = dst.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
    print(f"üîÅ Fallback: mirrored {src} ‚Üí {dst}")

def _load_features() -> pd.DataFrame:
    """Load the feature matrix, falling back to the newest artifact snapshot.

    The first existing file among ``feature_matrix.parquet``,
    ``features_df.parquet`` and ``features.parquet`` in ``OUTPUT_DIR`` is read.
    Otherwise the last ``features_master_*.parquet`` in ``ART_DIR`` (sorted by
    path) is copied to ``OUTPUT_DIR/feature_matrix.parquet`` and read from there.

    Raises:
        FileNotFoundError: if no candidate exists in either location.
    """
    primary = (
        OUTPUT_DIR/"feature_matrix.parquet",
        OUTPUT_DIR/"features_df.parquet",
        OUTPUT_DIR/"features.parquet",
    )
    found = next((path for path in primary if path.exists()), None)
    if found is not None:
        print(f"Found features: {found}")
        return pd.read_parquet(found)

    if ART_DIR.exists():
        snapshots = sorted(ART_DIR.glob("features_master_*.parquet"))
        if snapshots:
            mirrored = OUTPUT_DIR/"feature_matrix.parquet"
            _mirror(snapshots[-1], mirrored)
            return pd.read_parquet(mirrored)

    raise FileNotFoundError("Feature matrix not found in publication_outputs or artifacts.")

def _load_feature_cols(features: pd.DataFrame):
    """Return the modelling feature columns, from metadata if possible, else inferred.

    Metadata files are tried in order (``feature_meta.json`` then
    ``feature_cols.json``, first in ``OUTPUT_DIR`` and then in ``ART_DIR``).
    The first one holding a non-empty ``feature_cols`` (or ``features``) list
    wins. A file that cannot be read or parsed is reported and skipped rather
    than ignored silently.

    When no metadata is usable, the columns are inferred as every numeric column
    of ``features`` except identifiers, raw demographics, engagement flags and
    the outcome / time-to-event labels. ``is_engaged`` is excluded to match the
    exclusion list Block 1 uses when it writes ``feature_meta.json``; as a
    boolean column it would otherwise pass the numeric check. The inferred list
    is written to ``OUTPUT_DIR/feature_meta.json``.

    Args:
        features: Feature matrix; only used when the columns must be inferred.

    Returns:
        List of feature column names.
    """
    candidates = [
        OUTPUT_DIR/"feature_meta.json", OUTPUT_DIR/"feature_cols.json",
        ART_DIR/"feature_meta.json", ART_DIR/"feature_cols.json",
    ]
    for p in candidates:
        if not p.exists():
            continue
        try:
            meta = json.loads(p.read_text())
            cols = meta.get("feature_cols") or meta.get("features")
        except Exception as exc:
            # Best effort: fall through to the next candidate, but say why.
            print(f"[WARN] Could not read feature columns from {p}: {exc}")
            continue
        if cols:
            print(f"Loaded feature cols from {p}")
            return cols

    # Infer if missing
    drop = {
        "patient_id", "firstName", "lastName", "birthDate", "gender", "ethnicity",
        "createdAt", "updatedAt", "has_notes", "engaged_only_flag", "DX_Summary",
        "is_engaged",
    }
    for w in (7, 30, 90, 180):
        drop.add(f"outcome_{w}d")
        drop.add(f"time_to_event_{w}d")
    cols = [c for c in features.columns if c not in drop and pd.api.types.is_numeric_dtype(features[c])]
    (OUTPUT_DIR/"feature_meta.json").write_text(json.dumps({"feature_cols": cols}, indent=2))
    print(f"Inferred {len(cols)} feature columns")
    return cols

# =============================================================
# Splitters
# =============================================================

def _pick_time_col(df: pd.DataFrame) -> str:
    """Choose the column used to order rows for a time-based split.

    Selection order:

    1. the first present of ``as_of_date``, ``index_date``, ``feature_date``,
       ``cohort_date``, ``createdAt``, ``updatedAt`` (any dtype; it is parsed
       later);
    2. otherwise the first datetime-typed column, ignoring ``birthDate``.

    Named columns take priority because the first datetime column in Block 1's
    feature matrix is ``birthDate``, and splitting on it would separate train
    and test by patient age rather than by time. Datetime detection goes through
    pandas, so timezone-aware and other extension dtypes do not raise the
    ``TypeError`` that ``np.issubdtype`` produces for them.

    Raises:
        ValueError: if no suitable column exists.
    """
    preferred = ["as_of_date", "index_date", "feature_date", "cohort_date", "createdAt", "updatedAt"]
    for col in preferred:
        if col in df.columns:
            return col

    # Dates that describe the patient rather than the observation time.
    not_event_time = {"birthDate"}
    for col, dtype in df.dtypes.items():
        if col not in not_event_time and pd.api.types.is_datetime64_any_dtype(dtype):
            return col

    raise ValueError("No suitable time column found. Add one of: as_of_date/index_date/feature_date/cohort_date/createdAt/updatedAt, or set TIME_SPLIT_COLUMN explicitly.")

def _ensure_datetime(df: pd.DataFrame, col: str):
    """Make sure ``df[col]`` holds datetimes, converting in place when needed.

    Columns that are already datetime-typed (timezone-naive or timezone-aware)
    are left untouched. Anything else is parsed with ``pd.to_datetime``, with
    unparseable values becoming ``NaT``. The dtype test goes through pandas
    because ``np.issubdtype`` raises ``TypeError`` for pandas extension dtypes
    such as timezone-aware datetimes.

    Returns:
        The same ``df`` object (mutated when a conversion happened).
    """
    if not pd.api.types.is_datetime64_any_dtype(df[col]):
        df[col] = pd.to_datetime(df[col], errors="coerce")
    return df

def time_based_split(df: pd.DataFrame, date_col: str, test_fraction: float, split_date: str|None=None):
    """Split rows into an earlier training part and a later test part.

    The threshold is ``split_date`` when given, otherwise the
    ``1 - test_fraction`` quantile of ``date_col``. Rows at or before the
    threshold are training rows, rows after it are test rows. Rows whose date is
    missing compare False on both sides and therefore land in neither mask.

    If either side comes out empty, a positional threshold is taken from the
    sorted non-missing dates. Compared with the previous fallback this one
    (a) ignores ``NaT`` values instead of letting them sort to the end and be
    picked, (b) clamps the position so it cannot run past the end, and
    (c) steps back to the previous distinct date when the threshold lands on the
    latest date, which would otherwise leave the test side empty again.

    Args:
        df: Input rows; not modified (a copy is converted).
        date_col: Column holding the split dates.
        test_fraction: Target share of rows for the test side.
        split_date: Optional explicit threshold, parsed with ``pd.to_datetime``.

    Returns:
        ``(train_mask, test_mask, threshold)`` where the masks are boolean numpy
        arrays aligned with ``df``'s rows.

    Raises:
        ValueError: if ``date_col`` contains no parseable datetime.
    """
    df = _ensure_datetime(df.copy(), date_col)
    if df[date_col].isna().all():
        raise ValueError(f"Column {date_col} has no parseable datetimes. Provide a valid TIME_SPLIT_COLUMN or disable USE_TIME_SPLIT.")
    dates = df[date_col]
    thr = pd.to_datetime(split_date) if split_date else dates.quantile(1 - test_fraction)
    train_idx = dates <= thr
    test_idx  = dates > thr
    if train_idx.sum() == 0 or test_idx.sum() == 0:
        ordered = dates.dropna().sort_values()  # non-empty: checked above
        pos = min(max(int((1 - test_fraction) * len(ordered)), 0), len(ordered) - 1)
        thr = ordered.iloc[pos]
        latest = ordered.iloc[-1]
        earlier = ordered[ordered < latest]
        if thr >= latest and not earlier.empty:
            # Ties at the latest date: step back so those rows form the test set.
            thr = earlier.iloc[-1]
        train_idx = dates <= thr
        test_idx  = dates > thr
    print(f"‚è±Ô∏è  Time split on {date_col}: train ‚â§ {thr.date()} | train={train_idx.sum()} test={test_idx.sum()}")
    return train_idx.values, test_idx.values, thr

# =============================================================
# Models & metrics
# =============================================================

def build_models():
    """Return the classifier zoo as an ordered ``{name: unfitted estimator}`` dict.

    Four logistic-regression variants, two bagged tree ensembles and gradient
    boosting are always included; an XGBoost classifier is appended when
    ``XGBOOST_AVAILABLE`` is true. Insertion order is the evaluation order.
    """
    # Settings shared by every logistic variant / every bagged forest.
    logit_kw = dict(max_iter=2000, class_weight="balanced", random_state=RANDOM_STATE)
    forest_kw = dict(max_depth=18, min_samples_leaf=10, class_weight="balanced_subsample", random_state=RANDOM_STATE)

    zoo = {}
    zoo["LogisticRegression"] = LogisticRegression(solver="lbfgs", **logit_kw)
    zoo["ElasticNetLogistic"] = LogisticRegression(solver="saga", penalty="elasticnet", l1_ratio=0.5, **logit_kw)
    zoo["RidgeLogistic"] = LogisticRegression(solver="lbfgs", penalty="l2", **logit_kw)
    zoo["LassoLogistic"] = LogisticRegression(solver="saga", penalty="l1", **logit_kw)
    zoo["RandomForest"] = RandomForestClassifier(n_estimators=400, **forest_kw)
    zoo["ExtraTrees"] = ExtraTreesClassifier(n_estimators=500, **forest_kw)
    zoo["GradientBoosting"] = GradientBoostingClassifier(random_state=RANDOM_STATE)

    if XGBOOST_AVAILABLE:
        zoo["XGBoost"] = xgb.XGBClassifier(
            n_estimators=600,
            max_depth=6,
            learning_rate=0.05,
            subsample=0.8,
            colsample_bytree=0.8,
            eval_metric="logloss",
            random_state=RANDOM_STATE,
        )
    return zoo


def _auc_jstar(y_true, scores):
    """Return ``(AUC, J*)`` for binary labels and continuous scores.

    AUC is NaN when ``y_true`` holds a single class. J* is the maximum of
    ``tpr - fpr`` over the ROC curve (Youden's J); it is NaN when the AUC is
    NaN or when ``COMPUTE_JSTAR`` is disabled.
    """
    if len(np.unique(y_true)) > 1:
        auc = float(roc_auc_score(y_true, scores))
    else:
        auc = np.nan

    if np.isnan(auc) or not COMPUTE_JSTAR:
        return auc, np.nan

    fpr, tpr, _ = roc_curve(y_true, scores)
    youden = tpr - fpr
    if len(youden) == 0:
        return auc, np.nan
    return auc, float(youden.max())


def _prep_design_matrix(df: pd.DataFrame, feature_cols):
    """Build the standardized training matrix.

    Columns with a single distinct value (NaN counted as a value) are dropped,
    remaining NaNs are replaced by 0.0, and a ``StandardScaler`` is fitted on
    the result.

    Returns:
        ``(scaled_array, fitted_scaler, kept_column_names)``.
    """
    frame = df[feature_cols].copy()
    distinct = frame.nunique(dropna=False)
    keep = [col for col, count in distinct.items() if count > 1]
    filled = frame[keep].fillna(0.0)
    scaler = StandardScaler()
    Xs = scaler.fit_transform(filled.values)
    return Xs, scaler, keep

# =============================================================
# Survival utilities (robust & fast)
# =============================================================

def _stratified_cap(X_trs, t_tr, e_tr, y_tr, cap):
    """Return a stratified (by event) subset if len(X_trs) > cap."""
    n = len(X_trs)
    if cap is None or n <= cap:
        return X_trs, t_tr, e_tr, y_tr
    rng = np.random.RandomState(42)
    pos_idx = np.flatnonzero(y_tr == 1)
    neg_idx = np.flatnonzero(y_tr == 0)
    pos_keep = min(len(pos_idx), cap // 2)
    neg_keep = cap - pos_keep
    pos_sel = rng.choice(pos_idx, size=pos_keep, replace=False) if pos_keep > 0 else np.array([], dtype=int)
    neg_sel = rng.choice(neg_idx, size=neg_keep, replace=False)
    sel = np.concatenate([pos_sel, neg_sel])
    sel.sort()
    return X_trs[sel], t_tr[sel], e_tr[sel], y_tr[sel]


def _sanitize_survival_data(X, t, e, y, eps=1e-6):
    """Drop rows with NaN/Inf or negative t; clamp t to >= eps."""
    X = np.asarray(X, dtype=np.float32)
    t = np.asarray(t, dtype=np.float32)
    e = np.asarray(e, dtype=np.int32)
    y = np.asarray(y, dtype=np.int32)
    mask = np.isfinite(t) & (t >= 0) & np.isfinite(e) & np.isfinite(X).all(axis=1)
    dropped = int(len(t) - mask.sum())
    if dropped:
        print(f"    [sanitize] Dropped {dropped} rows with NaN/Inf/neg duration for survival training")
    t = np.clip(t[mask], eps, np.inf)
    return X[mask], t, e[mask], y[mask]


def _interp1d_safe(x, xp, fp, left, right):
    """np.interp guard: trims to common length, enforces monotone xp, returns float."""
    xp = np.asarray(xp, dtype=float)
    fp = np.asarray(fp, dtype=float)
    m = min(len(xp), len(fp))
    if m == 0:
        return float('nan')
    xp = xp[:m]; fp = fp[:m]
    xp = np.maximum.accumulate(xp)  # ensure non‚Äëdecreasing
    return float(np.interp(float(x), xp, fp, left=left, right=right))

# =============================================================
# CoxPH helper (lifelines)
# =============================================================

def _coxph_fit_predict(train_df, test_df, feature_cols, duration_col, event_col):
    """Fit a penalized Cox model on ``train_df`` and score ``test_df``.

    Features are zero-imputed, columns that are constant on the training set
    are removed, and the rest is standardized with training statistics. The
    model is fitted with ``penalizer=1.0``; if that fit raises, one retry with
    ``penalizer=5.0`` is made.

    Returns:
        1-D array of partial hazards for the test rows (higher = riskier).

    Raises:
        RuntimeError: when lifelines is not installed.
    """
    if not LIFELINES_AVAILABLE:
        raise RuntimeError("lifelines not available for CoxPH")

    train_x = train_df[feature_cols].copy().fillna(0.0)
    test_x = test_df[feature_cols].copy().fillna(0.0)
    varying = train_x.nunique(dropna=False) > 1
    keep = train_x.columns[varying].tolist()

    scaler = StandardScaler()
    scaled_train = scaler.fit_transform(train_x[keep].values)
    scaled_test = scaler.transform(test_x[keep].values)
    design_tr = pd.DataFrame(scaled_train, columns=keep, index=train_df.index)
    design_te = pd.DataFrame(scaled_test, columns=keep, index=test_df.index)
    design_tr[duration_col] = train_df[duration_col].values
    design_tr[event_col] = train_df[event_col].values

    def _fit(penalizer):
        fitter = CoxPHFitter(penalizer=penalizer)
        fitter.fit(design_tr, duration_col=duration_col, event_col=event_col, robust=True)
        return fitter

    try:
        cph = _fit(1.0)
    except Exception:
        # Deliberate best-effort: a stronger penalty usually rescues a failed fit.
        cph = _fit(5.0)
    return cph.predict_partial_hazard(design_te).values.ravel()

# =============================================================
# DeepSurv / DeepHit / RSF / XGB‚ÄëAFT
# =============================================================

def _deepsurv_fit_predict(X_trs, t_tr, e_tr, X_tes, horizon_days):
    """Train a DeepSurv-style network and return the event risk at ``horizon_days``.

    Args:
        X_trs: training features (2-D array; the column count sets the input width).
        t_tr: training durations.
        e_tr: training event indicators.
        X_tes: test features.
        horizon_days: time at which the survival curve is read.

    Returns:
        1-D array with ``1 - S(horizon_days)`` for every test row.

    Raises:
        RuntimeError: when ``PYCOX_AVAILABLE`` is false.
    """
    if not PYCOX_AVAILABLE:
        raise RuntimeError("pycox not available for DeepSurv")
    # MLP: one Linear -> ReLU -> BatchNorm -> Dropout group per entry of
    # DEEPSURV_HIDDEN, followed by a single-output linear head.
    input_dim = X_trs.shape[1]
    layers, in_dim = [], input_dim
    for h in DEEPSURV_HIDDEN:
        layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.BatchNorm1d(h), nn.Dropout(0.1)]
        in_dim = h
    layers += [nn.Linear(in_dim, 1)]
    net = nn.Sequential(*layers)
    device = DEVICE_OVERRIDE or ("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): PYCoxPH is presumably pycox's CoxPH model imported under an alias - confirm.
    model = PYCoxPH(net, tt.optim.Adam(lr=1e-3, weight_decay=1e-4), device=device)

    x_tr = X_trs.astype(np.float32)
    # Durations are clamped to be strictly positive before fitting.
    d_tr = np.clip(t_tr.astype(np.float32), 1e-6, np.inf)
    e_trb = e_tr.astype(bool)

    # NOTE(review): with BatchNorm1d a trailing mini-batch of a single row may fail
    # in training mode (len(x_tr) % 256 == 1) - verify against the data loader used.
    model.fit(x_tr, (d_tr, e_trb), batch_size=min(256, len(x_tr)), epochs=DEEPSURV_EPOCHS, verbose=False)
    model.compute_baseline_hazards()

    x_te = X_tes.astype(np.float32)
    surv_df = model.predict_surv_df(x_te)  # DataFrame indexed by time, one column per test row
    t_grid = surv_df.index.to_numpy(dtype=float)
    S_all = surv_df.T.to_numpy()  # transposed to shape (n, len(t_grid))

    # Read each curve at the horizon by interpolating log-survival: before the first
    # grid time S is taken as ~1 (0.999999), past the last one the final value is held.
    S_w = np.empty(S_all.shape[0], dtype=float)
    for i in range(S_all.shape[0]):
        S_i = np.clip(S_all[i], 1e-12, 1.0)  # keep log() finite
        S_w[i] = np.exp(_interp1d_safe(horizon_days, t_grid, np.log(S_i), left=np.log(0.999999), right=np.log(S_i[-1])))
    return 1.0 - S_w


def _deephit_fit_predict(X_trs, t_tr, e_tr, X_tes, horizon_days, num_durations=None):
    """Train a single-risk DeepHit model and return the event risk at ``horizon_days``.

    Durations are discretised into ``num_durations`` bins (``DEEPHIT_BINS`` when
    not given); the network outputs one logit per bin.

    Args:
        X_trs: training features (2-D array).
        t_tr: training durations.
        e_tr: training event indicators.
        X_tes: test features.
        horizon_days: time at which the survival curve is read.
        num_durations: number of discrete time bins.

    Returns:
        1-D array with ``1 - S(horizon_days)`` for every test row.

    Raises:
        RuntimeError: when ``PYCOX_AVAILABLE`` is false.
    """
    if not PYCOX_AVAILABLE:
        raise RuntimeError("pycox not available for DeepHit")
    num_durations = int(num_durations or DEEPHIT_BINS)
    lab = LabTransDiscreteTime(num_durations)
    yi, ye = lab.fit_transform(t_tr, e_tr)

    input_dim = X_trs.shape[1]
    layers, in_dim = [], input_dim
    for h in DEEPHIT_HIDDEN:
        layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.BatchNorm1d(h), nn.Dropout(0.1)]
        in_dim = h
    layers += [nn.Linear(in_dim, num_durations)]
    net = nn.Sequential(*layers)

    device = DEVICE_OVERRIDE or ("cuda" if torch.cuda.is_available() else "cpu")
    model = DeepHitSingle(net, tt.optim.Adam(lr=1e-3), alpha=0.2, sigma=0.1, device=device)

    x_tr = X_trs.astype(np.float32)
    model.fit(x_tr, (yi, ye), batch_size=min(256, len(x_tr)), epochs=DEEPHIT_EPOCHS, verbose=False)

    x_te = X_tes.astype(np.float32)
    pmf = np.asarray(model.predict_pmf(x_te))  # shape (n, m): P(event in bin j)
    # For a probability mass function the survival curve is one minus the running
    # sum: S(t_j) = 1 - sum_{k<=j} pmf_k. (A cumulative product of 1 - p is the
    # formula for discrete *hazards*, not for a PMF.)
    surv_disc = 1.0 - np.cumsum(pmf, axis=1)
    # Align the time grid with the PMF columns: use the last m cut points so the
    # pairing is right whether the transform exposes m or m+1 cuts.
    cuts = np.asarray(lab.cuts, dtype=float)
    time_grid = cuts[-surv_disc.shape[1]:]

    S_w = np.empty(surv_disc.shape[0], dtype=float)
    for i in range(surv_disc.shape[0]):
        S_i = np.clip(surv_disc[i], 1e-12, 1.0)
        # Before the first cut nobody has failed yet (S=1); past the last cut hold the final value.
        S_w[i] = _interp1d_safe(horizon_days, time_grid, S_i, left=1.0, right=S_i[-1])
    return 1.0 - S_w


def _rsf_fit_predict(X_trs, t_tr, e_tr, X_tes, horizon_days):
    """Fit a random survival forest and return the event risk at ``horizon_days``.

    scikit-survival is used when ``SKS_AVAILABLE``; otherwise the backend
    guarded by ``LL_RSF_AVAILABLE`` is tried.

    Args:
        X_trs: training features (2-D array).
        t_tr: training durations.
        e_tr: training event indicators.
        X_tes: test features.
        horizon_days: time at which each survival curve is read.

    Returns:
        1-D array with ``1 - S(horizon_days)`` for every test row.

    Raises:
        RuntimeError: when no backend is available.
    """
    def _risk_at_horizon(times, surv):
        # A survival curve starts at 1 and only decreases: before the first grid
        # time S is 1 (risk 0), and past the last grid time the final value is
        # carried forward. Matches the extrapolation used by the other engines.
        surv = np.asarray(surv, dtype=float)
        tail = float(surv[-1]) if surv.size else float("nan")
        return 1.0 - _interp1d_safe(horizon_days, times, surv, left=1.0, right=tail)

    if SKS_AVAILABLE:
        y_struct = SKS_Surv.from_arrays(event=e_tr.astype(bool), time=t_tr.astype(float))
        rsf = SKS_RSF(n_estimators=RSF_TREES, min_samples_split=10, min_samples_leaf=5, random_state=17, n_jobs=-1, oob_score=True)
        rsf.fit(X_trs, y_struct)
        surv_funcs = rsf.predict_survival_function(X_tes, return_array=False)
        out = np.zeros(len(X_tes), dtype=float)
        for i, f in enumerate(surv_funcs):
            out[i] = _risk_at_horizon(f.x, f.y)
        return out
    if LL_RSF_AVAILABLE:
        df_tr = pd.DataFrame(X_trs, columns=[f"x{i}" for i in range(X_trs.shape[1])])
        df_tr["duration"] = t_tr; df_tr["event"] = e_tr
        rsf = LL_RSF(n_estimators=RSF_TREES, min_samples_split=10, min_samples_leaf=5, random_state=17, n_jobs=-1)
        rsf.fit(df_tr, duration_col="duration", event_col="event")
        surv = rsf.predict_survival_function(X_tes)
        out = np.zeros(len(X_tes), dtype=float)
        for i, s in enumerate(surv):
            out[i] = _risk_at_horizon(s.index.to_numpy(dtype=float), s.to_numpy(dtype=float))
        return out
    raise RuntimeError("No RSF backend available. Install 'scikit-survival' or 'lifelines'.")


def _xgb_aft_fit_predict_time(X_trs, t_tr, e_tr, X_tes,
                              params=None, num_boost_round=None, early_stopping_rounds=None):
    """Fit an XGBoost accelerated-failure-time model and predict test survival times.

    Args:
        X_trs: training features (2-D array).
        t_tr: training durations.
        e_tr: training event indicators; non-events are treated as right-censored.
        X_tes: test features.
        params: booster parameters (defaults to a copy of ``XGB_AFT_PARAMS``).
        num_boost_round: boosting rounds (defaults to ``XGB_AFT_NUM_ROUNDS``).
        early_stopping_rounds: patience (defaults to ``XGB_AFT_EARLY_STOP``).

    Returns:
        Predicted times for the test rows; a smaller time means higher risk, so
        callers negate the result to obtain a risk score.

    Raises:
        RuntimeError: when XGBoost is not installed.
    """
    if not XGBOOST_AVAILABLE:
        raise RuntimeError("XGBoost not available for AFT.")
    params = dict(XGB_AFT_PARAMS if params is None else params)
    rnds = XGB_AFT_NUM_ROUNDS if num_boost_round is None else num_boost_round
    es = XGB_AFT_EARLY_STOP if early_stopping_rounds is None else early_stopping_rounds

    # Ranged labels for AFT: observed events have lower == upper == t,
    # censored rows have an open upper bound.
    y_lb = t_tr.astype(float)
    y_ub = np.where(e_tr.astype(bool), t_tr.astype(float), np.inf)

    # Simple holdout (the trailing k rows) for early stopping. The size is capped
    # at half the data so small samples still leave rows to train on; without the
    # cap, n < 200 produced negative indices and an empty training set.
    n = len(X_trs)
    k = max(1000, min(n // 5, 5000)) if n > 2000 else max(100, n // 5)
    k = min(k, n // 2)
    idx_valid = np.arange(n - k, n)
    idx_train = np.arange(0, n - k)

    dtr = xgb.DMatrix(X_trs[idx_train])
    dtr.set_float_info('label_lower_bound', y_lb[idx_train])
    dtr.set_float_info('label_upper_bound', y_ub[idx_train])

    if k > 0:
        dval = xgb.DMatrix(X_trs[idx_valid])
        dval.set_float_info('label_lower_bound', y_lb[idx_valid])
        dval.set_float_info('label_upper_bound', y_ub[idx_valid])
        bst = xgb.train(params, dtr, num_boost_round=rnds, evals=[(dtr, 'train'), (dval, 'valid')],
                        early_stopping_rounds=es, verbose_eval=False)
    else:
        # Too few rows for a holdout: train on everything without early stopping.
        bst = xgb.train(params, dtr, num_boost_round=rnds, verbose_eval=False)

    dtest = xgb.DMatrix(X_tes)
    t_pred = bst.predict(dtest)   # predicted time; use -t_pred as risk
    return t_pred

# =============================================================
# Runner
# =============================================================

def main():
    """Run Block 2: train and compare risk models for each outcome horizon.

    For every window in 7/30/90/180 days the function
      1. splits the feature matrix (time-based, grouped by patient, or random),
      2. fits the classifier zoo (plus CoxPH when lifelines is available),
      3. optionally fits the survival engines on a capped, sanitized sample,
      4. keeps the model with the highest test AUC and writes its test-set
         predictions to ``best_tte_model_predictions.csv`` (overwritten by each
         window, so it ends up holding the last processed window) and, for the
         30-day window, to ``best_tte_model_predictions_30d.csv``.

    A summary table and JSON files with the per-window winners and the run
    configuration are written at the end. Windows for which no model produced a
    usable AUC are recorded with ``best_model``/``auc`` set to null and no
    prediction file is written for them.

    NOTE(review): the winner is chosen on the same test set whose AUC is
    reported, so the reported AUC is optimistic.
    """
    print("[Block 2] Loading features‚Ä¶")
    features = _load_features()

    date_col = TIME_SPLIT_COLUMN or (_pick_time_col(features) if USE_TIME_SPLIT else None)
    if USE_TIME_SPLIT:
        features = _ensure_datetime(features, date_col)

    # Check for the metadata file BEFORE loading: the loader writes it when it has
    # to infer the columns, so checking afterwards would always report "Loaded".
    meta_preexisting = (OUTPUT_DIR/"feature_meta.json").exists()
    feature_cols = _load_feature_cols(features)
    print("Loaded feature cols from publication_outputs/feature_meta.json" if meta_preexisting else "Inferred feature cols.")

    windows = [7, 30, 90, 180]
    results = []

    for w in windows:
        print(f"\nTraining models for {w}-day horizon‚Ä¶")
        y_col = f"outcome_{w}d"; d_col = f"time_to_event_{w}d"
        if y_col not in features.columns or d_col not in features.columns:
            print(f"  [WARN] Missing columns for {w}d ‚Äî skipping")
            continue

        df = features.dropna(subset=[y_col, d_col]).copy()
        y = df[y_col].astype(int).values.ravel()

        # Build time/random split indices
        if USE_TIME_SPLIT:
            train_mask, test_mask, thr = time_based_split(df, date_col, TIME_TEST_FRACTION, TIME_SPLIT_DATE)
        else:
            if GROUP_BY_PATIENT and PATIENT_ID_COL in df.columns:
                # Keep all rows of a patient on the same side of the split.
                gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_STATE)
                tr_idx, te_idx = next(gss.split(df, groups=df[PATIENT_ID_COL]))
                train_mask = np.zeros(len(df), dtype=bool); train_mask[tr_idx] = True
                test_mask  = np.zeros(len(df), dtype=bool);  test_mask[te_idx]  = True
                print(f"üë• Grouped random split by {PATIENT_ID_COL}: train={train_mask.sum()} test={test_mask.sum()}")
            else:
                strat = y if np.unique(y).size>1 else None
                tr, te = train_test_split(np.arange(len(df)), test_size=0.2, random_state=RANDOM_STATE, stratify=strat)
                train_mask = np.zeros(len(df), dtype=bool); train_mask[tr] = True
                test_mask  = np.zeros(len(df), dtype=bool);  test_mask[te]  = True
                print(f"üé≤ Random split: train={train_mask.sum()} test={test_mask.sum()}")

        df_tr = df.loc[train_mask].copy(); df_te = df.loc[test_mask].copy()
        y_tr = df_tr[y_col].astype(int).values.ravel(); y_te = df_te[y_col].astype(int).values.ravel()

        # Design matrix (scaler statistics come from TRAIN only)
        X_trs, scaler, keep_cols = _prep_design_matrix(df_tr, feature_cols)
        X_tes = scaler.transform(df_te[keep_cols].fillna(0.0).values)

        # ------- Existing classifier zoo -------
        models = build_models()
        if LIFELINES_AVAILABLE:
            models = {**models, "CoxPH": None}  # placeholder; handled separately

        best_name, best_auc = None, -np.inf
        preds_for_best = None

        for name, model in models.items():
            try:
                if name == "CoxPH":
                    try:
                        scores = _coxph_fit_predict(df_tr[[*keep_cols, d_col, y_col]], df_te[[*keep_cols, d_col, y_col]], keep_cols, duration_col=d_col, event_col=y_col)
                        auc, jstar = _auc_jstar(y_te, scores)
                        print(f"  [{w}d] CoxPH             AUC={auc:.3f}  J*={(jstar if not np.isnan(jstar) else 0):.3f}")
                        y_scores = scores
                    except Exception as ex:
                        print(f"  [{w}d] CoxPH             FAILED ‚Äî {ex}")
                        continue
                else:
                    # Tree ensembles are wrapped so their scores are calibrated probabilities.
                    if isinstance(model, (RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier)):
                        model = CalibratedClassifierCV(model, method="sigmoid", cv=3)
                    model.fit(X_trs, y_tr)
                    y_scores = model.predict_proba(X_tes)[:,1]
                    auc, jstar = _auc_jstar(y_te, y_scores)
                    print(f"  [{w}d] {name:<18} AUC={auc:.3f}  J*={(jstar if not np.isnan(jstar) else 0):.3f}")
                # A NaN AUC never compares greater, so it cannot become the winner.
                if auc > best_auc:
                    best_auc = auc; best_name = name; preds_for_best = y_scores
            except Exception as e:
                print(f"  [{w}d] {name:<18} FAILED ‚Äî {e}")
                continue

        # ------- Survival models (risk at horizon = 1 - S_w) -------
        t_tr = df_tr[d_col].astype(float).values.ravel()
        e_tr = df_tr[y_col].astype(int).values.ravel()
        horizon = float(w)

        if not ENABLE_SURVIVAL:
            print(f"  [{w}d] Survival models SKIPPED ‚Äî ENABLE_SURVIVAL=False")
        else:
            if w not in SURVIVAL_WINDOWS:
                print(f"  [{w}d] Survival models SKIPPED ‚Äî not in SURVIVAL_WINDOWS={SURVIVAL_WINDOWS}")
            else:
                X_cap, t_cap, e_cap, y_cap = _stratified_cap(X_trs, t_tr, e_tr, y_tr, MAX_TRAIN_SAMPLES_PER_WINDOW)
                X_cap, t_cap, e_cap, y_cap = _sanitize_survival_data(X_cap, t_cap, e_cap, y_cap)

                # DeepSurv
                if "deepsurv" in SURVIVAL_ENGINES and PYCOX_AVAILABLE:
                    try:
                        scores = _deepsurv_fit_predict(X_cap, t_cap, e_cap, X_tes, horizon)
                        auc, jstar = _auc_jstar(y_te, scores)
                        print(f"  [{w}d] DeepSurv          AUC={auc:.3f}  J*={(jstar if not np.isnan(jstar) else 0):.3f}")
                        if auc > best_auc:
                            best_auc = auc; best_name = "DeepSurv"; preds_for_best = scores
                    except Exception as e:
                        print(f"  [{w}d] DeepSurv          FAILED ‚Äî {e}")
                elif "deepsurv" in SURVIVAL_ENGINES and not PYCOX_AVAILABLE:
                    print(f"  [{w}d] DeepSurv          SKIPPED ‚Äî pycox/torchtuples not available")

                # DeepHit
                if "deephit" in SURVIVAL_ENGINES and PYCOX_AVAILABLE:
                    try:
                        scores = _deephit_fit_predict(X_cap, t_cap, e_cap, X_tes, horizon, num_durations=DEEPHIT_BINS)
                        auc, jstar = _auc_jstar(y_te, scores)
                        print(f"  [{w}d] DeepHit           AUC={auc:.3f}  J*={(jstar if not np.isnan(jstar) else 0):.3f}")
                        if auc > best_auc:
                            best_auc = auc; best_name = "DeepHit"; preds_for_best = scores
                    except Exception as e:
                        print(f"  [{w}d] DeepHit           FAILED ‚Äî {e}")
                elif "deephit" in SURVIVAL_ENGINES and not PYCOX_AVAILABLE:
                    print(f"  [{w}d] DeepHit           SKIPPED ‚Äî pycox/torchtuples not available")

                # RSF
                if "rsf" in SURVIVAL_ENGINES and (SKS_AVAILABLE or LL_RSF_AVAILABLE):
                    try:
                        scores = _rsf_fit_predict(X_cap, t_cap, e_cap, X_tes, horizon)
                        auc, jstar = _auc_jstar(y_te, scores)
                        print(f"  [{w}d] RSF               AUC={auc:.3f}  J*={(jstar if not np.isnan(jstar) else 0):.3f}")
                        if auc > best_auc:
                            best_auc = auc; best_name = "RSF"; preds_for_best = scores
                    except Exception as e:
                        print(f"  [{w}d] RSF               FAILED ‚Äî {e}")
                elif "rsf" in SURVIVAL_ENGINES:
                    print(f"  [{w}d] RSF               SKIPPED ‚Äî install scikit-survival or lifelines >=0.27")

                # XGB-AFT
                if "xgb_aft" in SURVIVAL_ENGINES and XGBOOST_AVAILABLE:
                    try:
                        t_pred = _xgb_aft_fit_predict_time(X_cap, t_cap, e_cap, X_tes)
                        scores = -t_pred  # smaller predicted time => higher risk
                        auc, jstar = _auc_jstar(y_te, scores)
                        print(f"  [{w}d] XGB‚ÄëAFT           AUC={auc:.3f}  J*={(jstar if not np.isnan(jstar) else 0):.3f}")
                        if auc > best_auc:
                            best_auc = auc; best_name = "XGB‚ÄëAFT"; preds_for_best = scores
                    except Exception as e:
                        print(f"  [{w}d] XGB‚ÄëAFT           FAILED ‚Äî {e}")
                elif "xgb_aft" in SURVIVAL_ENGINES and not XGBOOST_AVAILABLE:
                    print(f"  [{w}d] XGB‚ÄëAFT           SKIPPED ‚Äî XGBoost not available")

        if best_name is None or preds_for_best is None:
            # Every candidate failed or produced a NaN AUC (e.g. single-class test
            # set). Record the gap explicitly instead of writing a file of empty
            # scores and a non-standard -Infinity into the JSON outputs.
            print(f"  [WARN] No model produced a usable AUC for {w}d - predictions not written")
            results.append({"window": w, "best_model": None, "auc": None})
            continue

        print(f"‚Üí Best @ {w}d: {best_name} (AUC {best_auc:.3f})")
        results.append({"window": w, "best_model": best_name, "auc": float(best_auc)})

        # Persist per-patient predictions for downstream blocks
        out = pd.DataFrame({
            PATIENT_ID_COL: df_te[PATIENT_ID_COL].values,
            "risk_score": preds_for_best,
            f"outcome_{w}d": y_te.astype(int),
        })
        out.to_csv(OUTPUT_DIR/"best_tte_model_predictions.csv", index=False)
        if w == 30:
            out.to_csv(OUTPUT_DIR/"best_tte_model_predictions_30d.csv", index=False)

    # Table 2 + JSONs
    pd.DataFrame(results).to_csv(OUTPUT_DIR/"table2_time_to_event_performance.csv", index=False)
    Path(OUTPUT_DIR/"tte_results.json").write_text(json.dumps(results, indent=2))
    stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    Path(OUTPUT_DIR/f"complete_results_{stamp}.json").write_text(json.dumps({
        "config": {
            "USE_TIME_SPLIT": USE_TIME_SPLIT,
            "TIME_SPLIT_COLUMN": TIME_SPLIT_COLUMN,
            "TIME_SPLIT_DATE": TIME_SPLIT_DATE,
            "TIME_TEST_FRACTION": TIME_TEST_FRACTION,
            "GROUP_BY_PATIENT": GROUP_BY_PATIENT,
            "PATIENT_ID_COL": PATIENT_ID_COL,
            "ENABLE_SURVIVAL": ENABLE_SURVIVAL,
            "SURV_FAST_MODE": SURV_FAST_MODE,
            "MAX_TRAIN_SAMPLES_PER_WINDOW": MAX_TRAIN_SAMPLES_PER_WINDOW,
            "SURVIVAL_ENGINES": SURVIVAL_ENGINES,
            "SURVIVAL_WINDOWS": SURVIVAL_WINDOWS,
        },
        "windows": results
    }, indent=2))

    print("\n‚úì Block 2 complete ‚Äî outputs:")
    print("  ‚Ä¢ best_tte_model_predictions.csv")
    print("  ‚Ä¢ best_tte_model_predictions_30d.csv (if 30d available)")
    print("  ‚Ä¢ table2_time_to_event_performance.csv")
    print("  ‚Ä¢ tte_results.json")
    print("  ‚Ä¢ complete_results_*.json")


# Script entry point: run the Block 2 pipeline only when this code is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
"""
BLOCK 3 — INTERVENTION SELECTION (OPE publish-v10)
==================================================
- Engaged-only policy modeling & off-policy evaluation
- **MODIFIED: Figure 2A now plots both ARR vs. Watchful Waiting and ARR vs. Status Quo.**
- **MODIFIED: Figure 2A styling updated for clarity (no CI, no f0 bar).**
- Explicitly selects 'CausalForest' for Figure 2 and best-model reporting.
- Calculates ARR and NNT vs. Status Quo (observed rate).
- Robust propensity (calibrated LR) with single-class fallback.
- Proper S/T/X/DR/R learners (+ DRForest/ExtraTrees/DRXGBoost opt).
- Adaptive overlap trim with minimum sample requirement.
- Weight clipping for IPS/DR (p99 by default) + ESS diagnostics.
- AUUC (DR uplift) + balance diagnostics (SMD).

Inputs (mirrored from ./artifacts if missing):
  • publication_outputs/feature_matrix.parquet
  • publication_outputs/interventions_df.csv (optional)
  • publication_outputs/feature_meta.json (or artifacts/feature_cols.json)
Outputs (in ./publication_outputs):
  • table3_intervention_ope.csv
  • figure2_capacity_curve.png (Now shows dual ARR plot)
  • figure2_qini_uplift.png
  • capacity_curve_best.csv
  • capacity_sensitivity.csv
  • balance_table.csv
  • best_intervention_model_predictions.csv
  • best_intervention_model_predictions_fulltest.csv
  • intervention_ope_diagnostics.json
"""

from __future__ import annotations
import json, shutil
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (
    RandomForestClassifier, ExtraTreesClassifier,
    RandomForestRegressor, ExtraTreesRegressor,
)
from sklearn.calibration import CalibratedClassifierCV
from sklearn.base import BaseEstimator, ClassifierMixin, clone

# Optional XGBoost
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except Exception:
    XGBOOST_AVAILABLE = False

# Directory the loaders look in first and where inferred metadata is written;
# created at import time when it does not exist yet.
OUTPUT_DIR = Path("publication_outputs"); OUTPUT_DIR.mkdir(exist_ok=True)
# Secondary location the loaders fall back to when an input is missing from OUTPUT_DIR.
ART_DIR = Path("artifacts")
# Seed shared by the train/test split and the default estimators of this block.
RANDOM_STATE = 42
# Outcome horizon in days; selects the target column f"outcome_{HORIZON_DAYS}d".
HORIZON_DAYS = 30
# Small numerical tolerance. NOTE(review): its use sites are outside this view -
# presumably guards divisions/clipping; confirm.
EPS = 1e-6

# =============================================================
# Resilient loaders
# =============================================================

def _mirror(src: Path, dst: Path):
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
    print(f"üîÅ Fallback: mirrored {src} ‚Üí {dst}")


def _load_features() -> pd.DataFrame:
    """Load the feature matrix, preferring ``OUTPUT_DIR`` over ``ART_DIR``.

    The first existing file among the known names in ``OUTPUT_DIR`` is read.
    Failing that, the last ``features_master_*.parquet`` in ``ART_DIR`` (by
    sorted file name) is mirrored to ``OUTPUT_DIR/feature_matrix.parquet`` and
    read from there.

    Raises:
        FileNotFoundError: when no candidate exists in either directory.
    """
    preferred_names = ("feature_matrix.parquet", "features_df.parquet", "features.parquet")
    existing = next(
        (OUTPUT_DIR/name for name in preferred_names if (OUTPUT_DIR/name).exists()),
        None,
    )
    if existing is not None:
        print(f"Found features: {existing}")
        return pd.read_parquet(existing)

    archived = sorted(ART_DIR.glob("features_master_*.parquet")) if ART_DIR.exists() else []
    if not archived:
        raise FileNotFoundError("Feature matrix not found in publication_outputs or artifacts.")

    mirrored = OUTPUT_DIR/"feature_matrix.parquet"
    _mirror(archived[-1], mirrored)
    return pd.read_parquet(mirrored)


def _load_interventions() -> pd.DataFrame:
    """Load the interventions table, or an empty ``patient_id`` frame if none exists.

    Lookup order: known file names in ``OUTPUT_DIR``, then the last
    ``interventions_processed_*.csv`` in ``ART_DIR`` (by sorted file name),
    which is mirrored into ``OUTPUT_DIR`` before being read. A missing file is
    not an error: a warning is printed and an empty frame is returned.
    """
    known_names = ("interventions_df.csv", "interventions_processed.csv")
    existing = next(
        (OUTPUT_DIR/name for name in known_names if (OUTPUT_DIR/name).exists()),
        None,
    )
    if existing is not None:
        print(f"Found interventions: {existing}")
        return pd.read_csv(existing)

    archived = sorted(ART_DIR.glob("interventions_processed_*.csv")) if ART_DIR.exists() else []
    if not archived:
        print("‚ö†Ô∏è No interventions file found. Proceeding without notes‚Äëbased treatment inference.")
        return pd.DataFrame(columns=["patient_id"])

    mirrored = OUTPUT_DIR/"interventions_df.csv"
    _mirror(archived[-1], mirrored)
    return pd.read_csv(mirrored)


def _load_feature_cols(features: pd.DataFrame):
    """Return the list of model feature columns.

    Metadata files are tried in order (``OUTPUT_DIR`` first, then ``ART_DIR``).
    A file may hold either a mapping with a ``feature_cols``/``features`` entry
    or a bare JSON list of column names. Unreadable or malformed files are
    reported and skipped. When nothing usable is found, the columns are
    inferred as every numeric column of ``features`` that is not an identifier,
    demographic, bookkeeping or outcome column, and the result is written to
    ``OUTPUT_DIR/feature_meta.json``.
    """
    candidates = [OUTPUT_DIR/"feature_meta.json", OUTPUT_DIR/"feature_cols.json", ART_DIR/"feature_meta.json", ART_DIR/"feature_cols.json"]
    for p in candidates:
        p = Path(p)
        if not p.exists():
            continue
        try:
            meta = json.loads(p.read_text())
        except (OSError, ValueError) as err:
            # JSONDecodeError and UnicodeDecodeError are ValueErrors. Say so instead
            # of silently falling through to the next candidate.
            print(f"[WARN] Could not read feature cols from {p}: {err}")
            continue
        if isinstance(meta, dict):
            cols = meta.get("feature_cols") or meta.get("features")
        elif isinstance(meta, list) and all(isinstance(c, str) for c in meta):
            cols = meta  # file is a plain list of column names
        else:
            cols = None
        if cols:
            print(f"Loaded feature cols from {p}")
            return cols
    drop = {"patient_id","firstName","lastName","birthDate","gender","ethnicity","createdAt","updatedAt","has_notes","engaged_only_flag","DX_Summary"}
    for w in (7,30,90,180):
        drop.add(f"outcome_{w}d"); drop.add(f"time_to_event_{w}d")
    cols = [c for c in features.columns if c not in drop and pd.api.types.is_numeric_dtype(features[c])]
    (OUTPUT_DIR/"feature_meta.json").write_text(json.dumps({"feature_cols": cols}, indent=2))
    print(f"Inferred {len(cols)} feature columns")
    return cols

# =============================================================
# Data prep (engaged‚Äëonly)
# =============================================================

def prepare_intervention_frame(features: pd.DataFrame, interventions: pd.DataFrame):
    """Build standardized train/test arrays for the intervention (uplift) models.

    Steps: restrict to engaged patients when an ``is_engaged`` column exists,
    drop rows without a recorded outcome, derive the binary treatment flag,
    split 80/20 (stratified by treatment when both arms are present) and scale
    the features with training statistics.

    Args:
        features: feature matrix containing ``outcome_{HORIZON_DAYS}d`` and ``patient_id``.
        interventions: interventions table; only its ``patient_id`` column is
            used, and only when ``features`` has no ``intervention_history_count``.

    Returns:
        ``(X_tr, X_te, y_tr, y_te, t_tr, t_te, te_ids, meta)`` where ``te_ids``
        are the patient ids of the test rows and ``meta["feature_cols"]`` lists
        the columns of ``X``.

    Raises:
        ValueError: when the target column is missing.
    """
    target_col = f"outcome_{HORIZON_DAYS}d"
    if target_col not in features.columns:
        raise ValueError(f"Missing target column {target_col} in features.")

    if "is_engaged" in features.columns:
        features = features[features["is_engaged"] == True].copy()
        print(f"  Engaged‚Äëonly subset: n={len(features)}")

    # Rows without an outcome cannot be cast to int and carry no label; drop them
    # before anything is derived so treatment, X and y stay aligned.
    has_target = features[target_col].notna()
    if not has_target.all():
        print(f"  Dropped {int((~has_target).sum())} rows with missing {target_col}")
        features = features[has_target].copy()

    if "intervention_history_count" in features.columns:
        treatment = (features["intervention_history_count"].fillna(0) > 0).astype(int)
    else:
        treated_ids = set(interventions.get("patient_id", pd.Series(dtype=str)).dropna().astype(str))
        treatment = features["patient_id"].astype(str).isin(treated_ids).astype(int)

    # NOTE(review): if intervention_history_count is itself among the feature
    # columns, the treatment flag is a deterministic function of X and overlap
    # breaks down - confirm the feature list excludes it.
    feat_cols = _load_feature_cols(features)
    X_df = features[feat_cols].copy().fillna(0.0)
    y = features[target_col].astype(int).values.ravel()
    t = treatment.astype(int).values.ravel()

    stratify_vec = t if np.unique(t).size > 1 else None
    X_tr_df, X_te_df, y_tr, y_te, t_tr, t_te = train_test_split(
        X_df, y, t, test_size=0.2, random_state=RANDOM_STATE, stratify=stratify_vec
    )

    scaler = StandardScaler(); X_tr = scaler.fit_transform(X_tr_df.values); X_te = scaler.transform(X_te_df.values)
    te_ids = features.loc[X_te_df.index, "patient_id"].values

    print(f"  Treatment rate ‚Äî overall: {t.mean():.3f} | train: {t_tr.mean():.3f} | test: {t_te.mean():.3f}")
    meta = {"feature_cols": feat_cols}
    return X_tr, X_te, np.asarray(y_tr).ravel(), np.asarray(y_te).ravel(), np.asarray(t_tr).ravel(), np.asarray(t_te).ravel(), te_ids, meta

# =============================================================
# Uplift model base + learners
# =============================================================

class _UpliftBase:
    def mu0(self, X): raise NotImplementedError
    def mu1(self, X): raise NotImplementedError
    def uplift(self, X): return self.mu1(X) - self.mu0(X)
    def benefit(self, X): return self.mu0(X) - self.mu1(X)
    def policy(self, X): return (self.mu1(X) < self.mu0(X)).astype(int)

class SLearner(_UpliftBase, BaseEstimator, ClassifierMixin):
    """S-learner: one calibrated outcome model with the treatment flag as an extra feature.

    ``mu0``/``mu1`` are obtained by predicting with the treatment column set to
    all zeros / all ones.

    Args:
        base: classifier to wrap; a class-balanced random forest is used when None.
    """
    def __init__(self, base=None):
        # Store the constructor argument under its own name: BaseEstimator's
        # get_params()/repr()/clone() read it back via getattr and fail otherwise.
        self.base = base
        chosen = base if base is not None else RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
        self.model = CalibratedClassifierCV(estimator=clone(chosen), method="sigmoid", cv=3)
    def fit(self, X, y, treatment):
        """Fit the outcome model on ``[X, treatment]``; returns ``self``."""
        X = np.asarray(X); t = np.asarray(treatment).ravel(); y = np.asarray(y).ravel()
        self.model.fit(np.column_stack([X, t]), y)
        return self
    def _proba_aug(self, X, a):
        # Positive-class probability with the treatment column forced to ``a``.
        return self.model.predict_proba(np.column_stack([np.asarray(X), np.asarray(a).ravel()]))[:,1]
    def mu0(self, X): return self._proba_aug(X, np.zeros(len(X)))
    def mu1(self, X): return self._proba_aug(X, np.ones(len(X)))

class TLearner(_UpliftBase, BaseEstimator, ClassifierMixin):
    """T-learner: separate outcome classifiers for the control and treated arms.

    Args:
        base: classifier cloned for each arm; a class-balanced random forest
            is used when None.
    """
    def __init__(self, base=None):
        self.base = base if base is not None else RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
    def fit(self, X, y, treatment):
        """Fit one model per arm; an empty arm falls back to the pooled data. Returns ``self``."""
        X = np.asarray(X); y = np.asarray(y).ravel(); t = np.asarray(treatment).ravel()
        self.m0 = clone(self.base); self.m1 = clone(self.base)
        treated = t==1; control = ~treated
        if control.sum()>0:
            self.m0.fit(X[control], y[control])
        else:
            self.m0.fit(X, y)
        if treated.sum()>0:
            self.m1.fit(X[treated], y[treated])
        else:
            self.m1.fit(X, y)
        return self
    @staticmethod
    def _positive_proba(model, X):
        # An arm that saw a single outcome class yields a one-column predict_proba,
        # so a blind [:, 1] would raise IndexError. Look the positive class up by
        # label; if the model never saw it, its probability is 0.
        proba = model.predict_proba(X)
        classes = getattr(model, "classes_", None)
        if classes is None:
            return proba[:,1]
        hit = np.flatnonzero(np.asarray(classes) == 1)
        if hit.size:
            return proba[:, hit[0]]
        return np.zeros(proba.shape[0])
    def mu0(self, X): return self._positive_proba(self.m0, X)
    def mu1(self, X): return self._positive_proba(self.m1, X)

class XLearner(_UpliftBase, BaseEstimator):
    """X-learner: per-arm outcome models, imputed effects, propensity-weighted blend.

    Args:
        base_cls: classifier cloned for the per-arm outcome models.
        base_reg: regressor cloned for the two imputed-effect models.
        prop_model: classifier cloned for the propensity model.
    Defaults are a class-balanced random forest, a random-forest regressor and
    a logistic regression respectively.
    """
    def __init__(self, base_cls=None, base_reg=None, prop_model=None):
        if base_cls is None:
            base_cls = RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
        if base_reg is None:
            base_reg = RandomForestRegressor(n_estimators=300, max_depth=12, random_state=RANDOM_STATE)
        if prop_model is None:
            prop_model = LogisticRegression(max_iter=1000, random_state=RANDOM_STATE)
        self.base_cls = base_cls
        self.base_reg = base_reg
        self.prop_model = prop_model

    def fit(self, X, y, treatment):
        """Fit outcome, effect and propensity models; returns ``self``."""
        X = np.asarray(X)
        y = np.asarray(y).ravel()
        t = np.asarray(treatment).ravel()
        self.m0 = clone(self.base_cls)
        self.m1 = clone(self.base_cls)
        treated = t == 1
        control = ~treated
        has_treated = treated.sum() > 0
        has_control = control.sum() > 0

        # Stage 1: outcome model per arm (pooled data when an arm is empty).
        if has_control:
            self.m0.fit(X[control], y[control])
        else:
            self.m0.fit(X, y)
        if has_treated:
            self.m1.fit(X[treated], y[treated])
        else:
            self.m1.fit(X, y)

        # Stage 2: counterfactual predictions and imputed individual effects.
        mu0_on_treated = self.m0.predict_proba(X[treated])[:,1] if has_treated else np.full(treated.sum(), y.mean())
        mu1_on_control = self.m1.predict_proba(X[control])[:,1] if has_control else np.full(control.sum(), y.mean())
        effect_treated = y[treated] - mu0_on_treated
        effect_control = mu1_on_control - y[control]

        # Stage 3: effect regressors (all-zero target when an arm is empty).
        if has_treated:
            self.tau1 = clone(self.base_reg).fit(X[treated], effect_treated)
        else:
            self.tau1 = clone(self.base_reg).fit(X, np.zeros(len(X)))
        if has_control:
            self.tau0 = clone(self.base_reg).fit(X[control], effect_control)
        else:
            self.tau0 = clone(self.base_reg).fit(X, np.zeros(len(X)))

        self.g = clone(self.prop_model).fit(X, t)
        return self

    def mu0(self, X):
        return self.m0.predict_proba(X)[:,1]

    def mu1(self, X):
        return self.m1.predict_proba(X)[:,1]

    def uplift(self, X):
        """Blend the two effect models, weighting by the estimated propensity."""
        propensity = self.g.predict_proba(X)[:,1]
        from_treated = self.tau1.predict(X)
        from_control = self.tau0.predict(X)
        return (1 - propensity) * from_treated + propensity * from_control

class DRLearner(_UpliftBase, BaseEstimator):
    """Doubly-robust learner: regresses an AIPW pseudo-outcome on X.

    Parameters
    ----------
    base_cls : classifier, optional
        Template cloned for the two arm-specific outcome models.
    base_reg : regressor, optional
        Template cloned for the final pseudo-outcome regression.
    prop_model : classifier, optional
        Template cloned for the propensity model. (Previously accepted but
        ignored by ``fit``; it is now honoured.)
    """

    def __init__(self, base_cls=None, base_reg=None, prop_model=None):
        self.base_cls = base_cls if base_cls is not None else RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
        self.base_reg = base_reg if base_reg is not None else RandomForestRegressor(n_estimators=300, max_depth=12, random_state=RANDOM_STATE)
        self.prop_model = prop_model if prop_model is not None else LogisticRegression(max_iter=1000, random_state=RANDOM_STATE)

    def fit(self, X, y, treatment):
        """Fit outcome models, propensity and the pseudo-outcome regressor; returns ``self``."""
        X = np.asarray(X); y = np.asarray(y).ravel(); t = np.asarray(treatment).ravel()
        self.m0 = clone(self.base_cls); self.m1 = clone(self.base_cls)
        m1 = t == 1; m0 = ~m1
        # Fall back to all rows when an arm is empty so the outcome model can still be fit.
        if m0.sum() > 0:
            self.m0.fit(X[m0], y[m0])
        else:
            self.m0.fit(X, y)
        if m1.sum() > 0:
            self.m1.fit(X[m1], y[m1])
        else:
            self.m1.fit(X, y)
        # Use the configured propensity template (same default as the old hard-coded model).
        self.g = clone(self.prop_model).fit(X, t)
        # Clip to keep the inverse-propensity factor bounded.
        e = np.clip(self.g.predict_proba(X)[:, 1], 1e-3, 1 - 1e-3)
        mu0 = self.m0.predict_proba(X)[:, 1]; mu1 = self.m1.predict_proba(X)[:, 1]
        mu_t = np.where(m1, mu1, mu0)
        # AIPW pseudo-outcome: (t - e)/(e(1-e)) equals t/e - (1-t)/(1-e).
        phi = ((t - e) / (e * (1 - e))) * (y - mu_t) + (mu1 - mu0)
        self.tau = clone(self.base_reg).fit(X, phi)
        return self

    def mu0(self, X):
        """Column 1 of the control-arm classifier's ``predict_proba``."""
        return self.m0.predict_proba(X)[:, 1]

    def mu1(self, X):
        """Column 1 of the treated-arm classifier's ``predict_proba``."""
        return self.m1.predict_proba(X)[:, 1]

    def uplift(self, X):
        """Prediction of the pseudo-outcome regressor (mu1 - mu0 scale)."""
        return self.tau.predict(X)

class RLearner(_UpliftBase, BaseEstimator):
    """R-learner: residual-on-residual estimate of the treatment effect.

    The effect model minimises the R-loss ``sum((y_res - tau(x) * t_res)**2)``,
    implemented as a weighted regression of ``y_res / t_res`` on X with
    weights ``t_res**2``.

    Parameters
    ----------
    base_cls : classifier, optional
        Template for the (calibrated) marginal outcome model.
    base_reg : regressor, optional
        Template cloned for the effect regression; its ``fit`` must accept
        ``sample_weight``.
    """

    def __init__(self, base_cls=None, base_reg=None):
        self.base_cls = base_cls if base_cls is not None else RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
        self.base_reg = base_reg if base_reg is not None else RandomForestRegressor(n_estimators=300, max_depth=12, random_state=RANDOM_STATE)
        self.mu_model = CalibratedClassifierCV(estimator=clone(self.base_cls), method="sigmoid", cv=3)

    def fit(self, X, y, treatment):
        """Fit marginal outcome, propensity and effect models; returns ``self``."""
        X = np.asarray(X); y = np.asarray(y).ravel(); t = np.asarray(treatment).ravel()
        self.mu = self.mu_model.fit(X, y)
        self.g = CalibratedClassifierCV(estimator=LogisticRegression(max_iter=1000, random_state=RANDOM_STATE), method="sigmoid", cv=3).fit(X, t)
        mu_hat = self.mu.predict_proba(X)[:, 1]
        e = np.clip(self.g.predict_proba(X)[:, 1], 1e-3, 1 - 1e-3)
        y_res, t_res = y - mu_hat, t - e
        # EPS (signed) keeps the denominator away from zero without flipping its sign.
        target = y_res / (t_res + np.sign(t_res) * EPS)
        # R-loss weights are the squared treatment residuals:
        #   t_res**2 * (y_res / t_res - tau)**2 == (y_res - tau * t_res)**2.
        # (The previous |t_res| weights did not correspond to this loss.)
        weights = t_res ** 2
        self.tau = clone(self.base_reg).fit(X, target, sample_weight=weights)
        return self

    def _components(self, X):
        """Return (marginal outcome, clipped propensity, effect) predictions for X."""
        mu = self.mu.predict_proba(X)[:, 1]
        e = np.clip(self.g.predict_proba(X)[:, 1], 1e-3, 1 - 1e-3)
        return mu, e, self.tau.predict(X)

    def mu0(self, X):
        """Implied control-arm outcome, ``mu - e * tau``, clipped to [0, 1]."""
        mu, e, tau = self._components(X)
        return np.clip(mu - e * tau, 0.0, 1.0)

    def mu1(self, X):
        """Implied treated-arm outcome, ``mu + (1 - e) * tau``, clipped to [0, 1]."""
        mu, e, tau = self._components(X)
        return np.clip(mu + (1 - e) * tau, 0.0, 1.0)

class DRForest(_UpliftBase, BaseEstimator):
    """Doubly-robust learner whose final stage is the regressor held in ``self.reg``.

    Subclasses swap ``self.reg`` after calling this constructor; ``fit`` always
    clones whatever ``self.reg`` currently is.
    """

    def __init__(self, base_cls=None):
        if base_cls is None:
            base_cls = RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
        self.base_cls = base_cls
        self.m0 = None
        self.m1 = None
        self.g = None
        self.reg = RandomForestRegressor(n_estimators=400, max_depth=14, random_state=RANDOM_STATE)

    def fit(self, X, y, treatment):
        """Fit arm outcome models, calibrated propensity and the pseudo-outcome regressor."""
        X = np.asarray(X)
        y = np.asarray(y).ravel()
        t = np.asarray(treatment).ravel()

        treated = t == 1
        control = ~treated

        # Arm-specific outcome classifiers; an empty arm falls back to all rows.
        self.m0 = clone(self.base_cls)
        self.m1 = clone(self.base_cls)
        if control.sum() > 0:
            self.m0.fit(X[control], y[control])
        else:
            self.m0.fit(X, y)
        if treated.sum() > 0:
            self.m1.fit(X[treated], y[treated])
        else:
            self.m1.fit(X, y)

        propensity_model = CalibratedClassifierCV(
            estimator=LogisticRegression(max_iter=1000, random_state=RANDOM_STATE),
            method="sigmoid",
            cv=3,
        )
        self.g = propensity_model.fit(X, t)
        e = np.clip(self.g.predict_proba(X)[:, 1], 1e-3, 1 - 1e-3)

        pred_control = self.m0.predict_proba(X)[:, 1]
        pred_treated = self.m1.predict_proba(X)[:, 1]
        pred_observed = np.where(treated, pred_treated, pred_control)

        # AIPW pseudo-outcome on the mu1 - mu0 scale.
        ipw_factor = (t - e) / (e * (1 - e))
        phi = ipw_factor * (y - pred_observed) + (pred_treated - pred_control)
        self.tau = clone(self.reg).fit(X, phi)
        return self

    def mu0(self, X):
        """Column 1 of the control-arm classifier's ``predict_proba``."""
        return self.m0.predict_proba(X)[:, 1]

    def mu1(self, X):
        """Column 1 of the treated-arm classifier's ``predict_proba``."""
        return self.m1.predict_proba(X)[:, 1]

    def uplift(self, X):
        """Prediction of the pseudo-outcome regressor."""
        return self.tau.predict(X)

class DRExtremeTrees(DRForest):
    """``DRForest`` variant whose final-stage regressor is an ExtraTreesRegressor."""

    def __init__(self, base_cls=None):
        super().__init__(base_cls=base_cls)
        # Replace the parent's default final-stage regressor.
        extra_trees_kwargs = dict(n_estimators=600, max_depth=16, random_state=RANDOM_STATE)
        self.reg = ExtraTreesRegressor(**extra_trees_kwargs)

class DRXGBoost(DRForest):
    """``DRForest`` variant whose final-stage regressor is an XGBRegressor.

    Raises ImportError at construction time when ``XGBOOST_AVAILABLE`` is false.
    """

    def __init__(self, base_cls=None):
        if not XGBOOST_AVAILABLE:
            raise ImportError("xgboost not available")
        super().__init__(base_cls=base_cls)
        # Replace the parent's default final-stage regressor.
        booster_kwargs = dict(
            n_estimators=600,
            max_depth=6,
            learning_rate=0.05,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=RANDOM_STATE,
            objective="reg:squarederror",
        )
        self.reg = xgb.XGBRegressor(**booster_kwargs)

# =============================================================
# Model zoo
# =============================================================

def build_models():
    """Return the model zoo as an ordered ``{display name: unfitted learner}`` dict.

    All learners receive the same classifier/regressor templates. The XGBoost
    variant is added only when ``XGBOOST_AVAILABLE`` is true; if constructing
    it still fails, the model is skipped and the reason is printed instead of
    being swallowed silently.
    """
    base_cls = RandomForestClassifier(n_estimators=300, max_depth=12, random_state=RANDOM_STATE, class_weight="balanced")
    base_reg = RandomForestRegressor(n_estimators=400, max_depth=14, random_state=RANDOM_STATE)
    models = {
        "S-Learner": SLearner(base=base_cls),
        "T-Learner": TLearner(base=base_cls),
        "X-Learner": XLearner(base_cls=base_cls, base_reg=base_reg),
        "DR-Learner": DRLearner(base_cls=base_cls, base_reg=base_reg),
        "R-Learner": RLearner(base_cls=base_cls, base_reg=base_reg),
        "CausalForest": DRForest(base_cls=base_cls),
        "CausalExtraTrees": DRExtremeTrees(base_cls=base_cls),
    }
    if XGBOOST_AVAILABLE:
        try:
            models["CausalXGBoost"] = DRXGBoost(base_cls=base_cls)
        except Exception as exc:
            # Best-effort: the zoo stays usable without XGBoost, but report why it was dropped.
            print(f"  [build_models] CausalXGBoost skipped: {exc}")
    return models

# =============================================================
# OPE helpers & diagnostics
# =============================================================

def estimate_propensity(X, t):
    """Fit the behaviour-policy (propensity) model on the given data.

    Returns an object exposing ``predict_proba``: a sigmoid-calibrated
    logistic regression when both treatment classes are present, otherwise a
    stub that predicts the observed treatment rate for every row.
    """
    t = np.asarray(t).ravel()
    p_hat = float(np.mean(t))

    if np.unique(t).size >= 2:
        calibrated = CalibratedClassifierCV(
            estimator=LogisticRegression(max_iter=2000, random_state=RANDOM_STATE),
            method="sigmoid",
            cv=3,
        )
        calibrated.fit(X, t)
        return calibrated

    print(f"  [propensity] Single-class in TRAIN (mean={p_hat:.3f}). Using constant propensity.")

    class _ConstP:
        # Mimics the two-column predict_proba layout: [P(t=0), P(t=1)].
        def predict_proba(self, X_):
            treated_col = np.full((len(X_), 1), p_hat, dtype=float)
            return np.hstack([1 - treated_col, treated_col])

    return _ConstP()


def _clip_weights(w: np.ndarray):
    thr = np.percentile(w, 99)
    return np.minimum(w, thr)


def dr_policy_risk(y, t, mu0, mu1, e, policy_t):
    """Doubly-robust estimate of the mean outcome under the policy ``policy_t``.

    Combines the model prediction for the policy's chosen arm with an
    inverse-propensity-weighted residual on units whose observed treatment
    agrees with the policy. Weights are capped by ``_clip_weights``.
    """
    treated = t == 1

    # Inverse probability of the arm each unit actually received.
    inv_prop = np.where(treated, 1 / np.clip(e, 1e-6, 1 - 1e-6), 1 / np.clip(1 - e, 1e-6, 1 - 1e-6))
    inv_prop = _clip_weights(inv_prop)

    pred_under_policy = np.where(policy_t == 1, mu1, mu0)
    pred_observed_arm = np.where(treated, mu1, mu0)

    agrees_with_policy = t == policy_t
    correction = agrees_with_policy * (y - pred_observed_arm) * inv_prop
    return float(np.mean(pred_under_policy + correction))


def ips_policy_risk(y, t, e, policy_t):
    """Self-normalised IPS estimate of the mean outcome under ``policy_t``.

    Returns ``(risk, ess, wstats)``: the weighted mean outcome, the effective
    sample size ``(sum w)^2 / sum w^2`` and summary statistics of the
    (capped) weights.
    """
    p_treated = np.clip(e, 1e-6, 1 - 1e-6)
    p_control = np.clip(1 - e, 1e-6, 1 - 1e-6)

    # A unit contributes only when its observed arm matches the policy's arm.
    w_treated = (policy_t == 1) * (t == 1) / p_treated
    w_control = (policy_t == 0) * (t == 0) / p_control
    w = _clip_weights(w_treated + w_control)

    risk = float(np.sum(w * y) / np.sum(w))
    ess = float((np.sum(w) ** 2) / np.sum(w ** 2))
    wstats = {
        "w_mean": float(np.mean(w)),
        "w_max": float(np.max(w)),
        "w_min": float(np.min(w)),
        "w_p99": float(np.percentile(w, 99)),
    }
    return risk, ess, wstats


def arraysafe(a):
    """Materialise any iterable into a new float NumPy array."""
    items = [item for item in a]
    return np.array(items, dtype=float)


def capacity_curve(y, t, mu0, mu1, e, fractions):
    """Absolute risk reduction when treating the top fraction ranked by predicted benefit.

    For each fraction ``f`` the ``round(f * n)`` units with the largest
    ``mu0 - mu1`` are treated and the policy's DR risk is computed.

    Returns ``(fractions, arr_vs_none, arr_vs_status_quo)`` where the first
    ARR is relative to the DR risk of treating nobody and the second is
    relative to the observed mean outcome.
    """
    n_units = len(y)
    treat_nobody = np.zeros_like(y)
    risk_none_dr = dr_policy_risk(y, t, mu0, mu1, e, treat_nobody)
    status_quo_risk = np.mean(y)

    # Highest predicted benefit first.
    ranking = np.argsort(-(mu0 - mu1))

    def _risk_treating_top(fraction):
        n_treat = int(round(fraction * n_units))
        assignment = np.zeros_like(y)
        if n_treat > 0:
            assignment[ranking[:n_treat]] = 1
        return dr_policy_risk(y, t, mu0, mu1, e, assignment)

    risks = [_risk_treating_top(f) for f in fractions]
    arr_dr = [risk_none_dr - r for r in risks]
    arr_sq = [status_quo_risk - r for r in risks]
    return arraysafe(fractions), np.array(arr_dr), np.array(arr_sq)


def qini_auuc_dr(y, t, mu0, mu1, e):
    """Uplift curve and its area, using doubly-robust per-unit benefit scores.

    Benefit is measured as risk reduction, ``E[Y(0)] - E[Y(1)]``. The DR score
    per unit is::

        (mu0 - mu1) + ((1 - t) / (1 - e) - t / e) * (y - mu_observed_arm)

    i.e. the DR estimate of Y(0) minus the DR estimate of Y(1). The residual
    correction previously had the opposite sign, which matches the effect
    ``mu1 - mu0`` rather than the benefit and disagreed with
    ``dr_policy_risk``.

    Returns
    -------
    frac : ndarray
        Population fractions 1/n, 2/n, ..., 1.
    uplift_curve : ndarray
        Mean DR benefit among the top-k units ranked by predicted benefit.
    auuc : float
        Trapezoidal area under ``uplift_curve`` over ``frac``.
    """
    n = len(y)
    e_treated = np.clip(e, 1e-6, 1 - 1e-6)
    e_control = np.clip(1 - e, 1e-6, 1 - 1e-6)

    mu_t = np.where(t == 1, mu1, mu0)
    benefit_pred = mu0 - mu1
    # Control residuals raise the Y(0) estimate; treated residuals raise the Y(1) estimate.
    tau_dr = benefit_pred + ((1 - t) / e_control - t / e_treated) * (y - mu_t)

    order = np.argsort(-benefit_pred)
    counts = np.arange(1, n + 1)
    frac = counts / n
    uplift_curve = np.cumsum(tau_dr[order]) / counts

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    auuc = float(trapezoid(uplift_curve, x=frac))
    return frac, uplift_curve, auuc


def smd_by_feature(X, t, keep_mask=None, top_k=20):
    """Standardised mean differences (treated vs control) for the first ``top_k`` columns.

    Computed twice: on all rows ("pre") and on the rows selected by
    ``keep_mask`` ("post"; all rows when ``keep_mask`` is None). Stages or
    columns where either arm is empty are omitted. Returns a DataFrame with
    columns ``feature`` (``f<position>``), ``stage`` and ``SMD``.
    """
    frame = pd.DataFrame(X)
    feature_ids = list(range(frame.shape[1]))[:top_k]

    all_rows = np.ones(len(t), dtype=bool)
    post_rows = all_rows if keep_mask is None else keep_mask
    stages = (("pre", all_rows), ("post", post_rows))

    records = []
    for stage_name, rows in stages:
        arm = t[rows]
        in_treated = arm == 1
        in_control = arm == 0
        for col in feature_ids:
            values = frame.loc[rows, col].values
            treated_vals = values[in_treated]
            control_vals = values[in_control]
            if len(treated_vals) == 0 or len(control_vals) == 0:
                continue
            # Pooled SD from the two sample variances; the small constant avoids division by zero.
            pooled_sd = np.sqrt((np.var(treated_vals, ddof=1) + np.var(control_vals, ddof=1)) / 2)
            smd = (np.mean(treated_vals) - np.mean(control_vals)) / (pooled_sd + 1e-9)
            records.append({"feature": f"f{col}", "stage": stage_name, "SMD": float(smd)})
    return pd.DataFrame(records)

# =============================================================
# Publication outputs
# =============================================================

def save_table3(metrics_by_model: dict):
    """Write one formatted row per model to ``table3_intervention_ope.csv`` in OUTPUT_DIR.

    Each entry of ``metrics_by_model`` must provide ``arr_dr_full``,
    ``arr_vs_status_quo``, ``auuc_dr``, ``treat_rate`` and ``ess``.
    """
    def _nnt(arr):
        # NNT = 1 / ARR, reported only for a strictly positive risk reduction.
        if not arr > 0:
            return "N/A"
        value = 1.0 / arr
        return f"{value:.1f}" if isinstance(value, float) else value

    table_rows = [
        {
            "Model": model_name,
            "ARR_vs_WatchfulWaiting": f"{stats['arr_dr_full']:.4f}",
            "NNT_vs_WatchfulWaiting": _nnt(stats['arr_dr_full']),
            "ARR_vs_StatusQuo": f"{stats['arr_vs_status_quo']:.4f}",
            "NNT_vs_StatusQuo": _nnt(stats['arr_vs_status_quo']),
            "AUUC_DR": f"{stats['auuc_dr']:.4f}",
            "TreatRate@full": f"{stats['treat_rate']:.3f}",
            "ESS@full": f"{stats['ess']:.1f}",
        }
        for model_name, stats in metrics_by_model.items()
    ]
    table = pd.DataFrame(table_rows)
    table.to_csv(OUTPUT_DIR/"table3_intervention_ope.csv", index=False)


def save_figures(best_name: str, cap_frac, cap_arr_dr, cap_arr_sq, q_frac, q_curve):
    """Save the capacity-curve and uplift-curve figures for the reported model.

    Writes ``figure2_capacity_curve.png`` and ``figure2_qini_uplift.png`` to
    OUTPUT_DIR at 300 dpi.

    Parameters
    ----------
    best_name : str
        Model name shown in both figure titles.
    cap_frac, cap_arr_dr, cap_arr_sq : array-like
        Capacity fractions and the two ARR series plotted against them.
    q_frac, q_curve : array-like
        Population fractions and the uplift curve plotted against them.
    """
    # Imported lazily so the module can be loaded without matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(7, 5))
    ax.plot(cap_frac, cap_arr_dr, label="ARR vs. Watchful Waiting", linewidth=2)
    ax.plot(cap_frac, cap_arr_sq, label="ARR vs. Status Quo", linewidth=2, linestyle='--')
    ax.axhline(0, color='black', linewidth=0.5, linestyle='-')
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.set_xlabel("Fraction Treated (capacity)")
    ax.set_ylabel("Absolute Risk Reduction (ARR)")
    # Title separator is a real em dash (was mis-encoded text).
    ax.set_title(f"Figure 2A: Capacity Curve — {best_name}")
    ax.legend()
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR/"figure2_capacity_curve.png", dpi=300)
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(7, 5))
    ax.plot(q_frac, q_curve, linewidth=2)
    ax.set_xlabel("Fraction of population (ranked by predicted benefit)")
    ax.set_ylabel("Cumulative mean benefit (DR)")
    ax.set_title(f"Figure 2B: Qini / Uplift Curve — {best_name}")
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR/"figure2_qini_uplift.png", dpi=300)
    plt.close(fig)

# =============================================================
# Bootstrap capacity CIs
# =============================================================

def bootstrap_capacity(y, t, mu0, mu1, e, fractions, B=200, seed=RANDOM_STATE):
    """Percentile-bootstrap 95% intervals for the capacity curve, full-policy ARR and AUUC.

    Each of the ``B`` replicates resamples units with replacement and
    recomputes the DR capacity curve, the ARR of the "treat when mu1 < mu0"
    policy versus treating nobody, and the DR AUUC.

    Returns a dict with ``arr_curve_lo``/``arr_curve_hi`` (arrays over
    ``fractions``) and ``arr_full_ci``/``auuc_ci`` ((lo, hi) tuples).
    """
    rng = np.random.default_rng(seed)
    n_obs = len(y)
    curve_draws = np.zeros((B, len(fractions)))
    full_draws = np.zeros(B)
    auuc_draws = np.zeros(B)

    for rep in range(B):
        pick = rng.integers(0, n_obs, n_obs)
        y_s, t_s, mu0_s, mu1_s, e_s = y[pick], t[pick], mu0[pick], mu1[pick], e[pick]

        _, curve_rep, _ = capacity_curve(y_s, t_s, mu0_s, mu1_s, e_s, fractions)
        curve_draws[rep, :] = curve_rep

        treat_if_better = (mu1_s < mu0_s).astype(int)
        risk_none = dr_policy_risk(y_s, t_s, mu0_s, mu1_s, e_s, np.zeros_like(y_s))
        risk_policy = dr_policy_risk(y_s, t_s, mu0_s, mu1_s, e_s, treat_if_better)
        full_draws[rep] = risk_none - risk_policy

        _, _, auuc_rep = qini_auuc_dr(y_s, t_s, mu0_s, mu1_s, e_s)
        auuc_draws[rep] = auuc_rep

    def _interval(draws):
        return (float(np.percentile(draws, 2.5)), float(np.percentile(draws, 97.5)))

    return {
        "arr_curve_lo": np.percentile(curve_draws, 2.5, axis=0),
        "arr_curve_hi": np.percentile(curve_draws, 97.5, axis=0),
        "arr_full_ci": _interval(full_draws),
        "auuc_ci": _interval(auuc_draws),
    }

# =============================================================
# Runner
# =============================================================

def run():
    """Run Block 3: fit the model zoo, evaluate each model off-policy, write outputs.

    Steps: load features/interventions, build train/test frames, fit the
    propensity model on TRAIN, trim the test set to a propensity-overlap
    window, fit and score every model on the trimmed test set, then write the
    table, figures, capacity/sensitivity CSVs, balance table, per-patient
    predictions (trimmed and full test set) and a JSON diagnostics file to
    OUTPUT_DIR.

    Raises
    ------
    RuntimeError
        If every model fails to fit or evaluate.
    """
    print("[Block 3] Loading inputs…")
    features = _load_features()
    interventions = _load_interventions()

    print("Preparing engaged‑only train/test and treatments…")
    X_tr, X_te, y_tr, y_te, t_tr, t_te, te_ids, _meta = prepare_intervention_frame(features, interventions)

    print("Building models…")
    models = build_models()

    print("Fitting behavior policy (propensity) on TRAIN…")
    prop = estimate_propensity(X_tr, t_tr)
    e_te = np.clip(prop.predict_proba(X_te)[:,1], 1e-6, 1-1e-6)

    def do_trim(e, lo, hi):
        return (e >= lo) & (e <= hi)

    # Try progressively wider overlap windows until enough test samples survive.
    lo_hi_candidates = [(0.05,0.95), (0.02,0.98), (0.01,0.99)]
    min_keep = max(50, int(0.05*len(e_te)))
    keep = None
    for lo,hi in lo_hi_candidates:
        keep = do_trim(e_te, lo, hi)
        if keep.sum() >= min_keep:
            print(f"  ✂️  Overlap trim: kept {keep.sum()} / {len(e_te)} (e in [{lo:.2f}, {hi:.2f}])")
            break
    else:
        # No window met the threshold: the widest window is used, so say so.
        if keep is not None and keep.sum() > 0:
            print(f"  ✂️  Overlap trim: no window kept >= {min_keep} samples; using widest window ({keep.sum()} / {len(e_te)}).")
    if keep is None or keep.sum()==0:
        print("  ✂️  Overlap trim skipped (no samples). Using untrimmed test set.")
        keep = np.ones_like(e_te, dtype=bool)
    X_te_k, y_te_k, t_te_k, e_te_k = X_te[keep], y_te[keep], t_te[keep], e_te[keep]
    te_ids_k = te_ids[keep]

    smd_df = smd_by_feature(X_te, t_te, keep_mask=keep, top_k=20)
    smd_df.to_csv(OUTPUT_DIR/"balance_table.csv", index=False)

    print("Fitting & evaluating models…")
    metrics_by_model, mu_cache = {}, {}

    fractions = np.unique(np.concatenate([np.linspace(0, 1, 21), np.linspace(0, 0.5, 11)]))

    for name, model in models.items():
        # Best-effort per model: one failing learner must not abort the whole run.
        try:
            model.fit(X_tr, y_tr, treatment=t_tr)
            mu0 = model.mu0(X_te_k); mu1 = model.mu1(X_te_k)
            mu_cache[name] = {"mu0": mu0.tolist(), "mu1": mu1.tolist()}

            # Treat exactly those units whose predicted risk is lower under treatment.
            policy_full = (mu1 < mu0).astype(int)
            risk_none_dr = dr_policy_risk(y_te_k, t_te_k, mu0, mu1, e_te_k, np.zeros_like(y_te_k))
            risk_full_dr = dr_policy_risk(y_te_k, t_te_k, mu0, mu1, e_te_k, policy_full)
            _, ess, wstats = ips_policy_risk(y_te_k, t_te_k, e_te_k, policy_full)

            arr_dr_full = risk_none_dr - risk_full_dr
            status_quo_risk = np.mean(y_te_k)
            arr_vs_status_quo = status_quo_risk - risk_full_dr

            cap_frac, cap_arr_dr, cap_arr_sq = capacity_curve(y_te_k, t_te_k, mu0, mu1, e_te_k, fractions)
            q_frac, q_curve, auuc = qini_auuc_dr(y_te_k, t_te_k, mu0, mu1, e_te_k)

            metrics_by_model[name] = {
                "treat_rate": float(policy_full.mean()), "arr_dr_full": float(arr_dr_full),
                "arr_vs_status_quo": float(arr_vs_status_quo), "ess": ess, "wstats": wstats, "auuc_dr": auuc,
                "cap_frac": cap_frac.tolist(), "cap_arr_dr": cap_arr_dr.tolist(), "cap_arr_sq": cap_arr_sq.tolist(),
                "q_frac": q_frac.tolist(), "q_curve": q_curve.tolist(),
            }
            print(f"  {name}: ARR_vs_WW={arr_dr_full:.4f} | ARR_vs_SQ={arr_vs_status_quo:.4f} | AUUC={auuc:.4f}")
        except Exception as exc:
            print(f"  {name}: FAILED — {exc}")

    if not metrics_by_model:
        raise RuntimeError("All intervention models failed. Check inputs.")

    # The reported model is fixed; fall back to the best AUUC only if it failed.
    best = 'CausalForest'
    if best not in metrics_by_model:
        print(f"[Warning] '{best}' not found in results. Falling back to best AUUC model.")
        best = max(metrics_by_model, key=lambda k: metrics_by_model[k]['auuc_dr'])

    print(f"\nSelected '{best}' as the best model for reporting and plotting.")
    m = metrics_by_model[best]

    save_table3(metrics_by_model)
    save_figures(best, np.array(m['cap_frac']), np.array(m['cap_arr_dr']), np.array(m['cap_arr_sq']), np.array(m['q_frac']), np.array(m['q_curve']))

    cap_df = pd.DataFrame({
        "fraction": m['cap_frac'], "ARR_DR": m['cap_arr_dr'], "ARR_vs_SQ": m['cap_arr_sq'],
    })
    cap_df.to_csv(OUTPUT_DIR/"capacity_curve_best.csv", index=False)

    # Sensitivity table: ARR/NNT at the grid point nearest each capacity of interest.
    sens_fracs = [0.10, 0.15, 0.20, 0.25]
    sens = []
    for f in sens_fracs:
        idx = int(np.argmin(np.abs(np.array(m['cap_frac']) - f)))
        arr_val = float(m['cap_arr_dr'][idx])
        sens.append({
            "fraction": float(m['cap_frac'][idx]), "ARR_DR": arr_val,
            "NNT": (1.0/arr_val) if arr_val > 0 else np.nan,
        })
    pd.DataFrame(sens).to_csv(OUTPUT_DIR/"capacity_sensitivity.csv", index=False)

    # Predictions on the overlap-trimmed test set (cached from the evaluation loop).
    mu0_b = np.array(mu_cache[best]['mu0']); mu1_b = np.array(mu_cache[best]['mu1'])
    policy_b = (mu1_b < mu0_b).astype(int)
    rec = np.where(policy_b==1, "any_intervention", "watchful_waiting")
    out_pred = pd.DataFrame({
        "patient_id": te_ids_k, "recommended_intervention": rec,
        "mu0": mu0_b, "mu1": mu1_b, "benefit": (mu0_b - mu1_b),
        "policy_treat": policy_b, "model_name": best, "horizon_days": HORIZON_DAYS,
    })
    out_pred.to_csv(OUTPUT_DIR/"best_intervention_model_predictions.csv", index=False)

    # Predictions on the full (untrimmed) test set. The selected model was already
    # fitted on TRAIN in the loop above, so reuse it instead of rebuilding and refitting.
    best_model = models[best]
    mu0_full = best_model.mu0(X_te); mu1_full = best_model.mu1(X_te)
    policy_full = (mu1_full < mu0_full).astype(int)
    rec_full = np.where(policy_full==1, "any_intervention", "watchful_waiting")
    out_pred_full = pd.DataFrame({
        "patient_id": te_ids, "recommended_intervention": rec_full,
        "mu0": mu0_full, "mu1": mu1_full, "benefit": (mu0_full - mu1_full),
        "policy_treat": policy_full, "model_name": best, "horizon_days": HORIZON_DAYS,
    })
    out_pred_full.to_csv(OUTPUT_DIR/"best_intervention_model_predictions_fulltest.csv", index=False)

    (OUTPUT_DIR/"intervention_ope_diagnostics.json").write_text(json.dumps(metrics_by_model, indent=2))

    print("\n✓ Block 3 complete. Wrote:")
    written = [
        "table3_intervention_ope.csv",
        "figure2_capacity_curve.png",
        "figure2_qini_uplift.png",
        "capacity_curve_best.csv",
        "capacity_sensitivity.csv",
        "balance_table.csv",
        "best_intervention_model_predictions.csv",
        "best_intervention_model_predictions_fulltest.csv",
        "intervention_ope_diagnostics.json",
    ]
    for filename in written:
        print(f"  • {filename}")


# Execute the Block 3 pipeline only when this code runs as a script/cell, not on import.
if __name__ == "__main__":
    run()


In [None]:
!pip install xlsxwriter
# or (fallback engine)
!pip install openpyxl

In [None]:
#!/usr/bin/env python3
"""
BLOCK 4 — CLINICAL VALIDATION PACK (engaged-only v2)
===================================================
Builds a 200-case review pack **from engaged patients only**, collapses
encounter notes per patient (only encounters where `encounterOccurred == 'YES'`),
restores signal-risk summaries, adds diagnosis summary (DX_Summary) from feature
columns, and writes a tidy Excel workbook with **two columns per reviewer**
(Risk & Action) plus free-text notes. Also writes a CSV mirror and a
Case_ID ↔ patient_id mapping.

Inputs (from ./publication_outputs):
  • feature_matrix.parquet
  • interventions_df.csv           (free-text notes, optional columns)
  • signal_risks.csv               (optional; per-patient signals)
  • feature_meta.json              (for feature list)
  • best_tte_model_predictions_30d.csv (optional; for risk quartiles)

Outputs:
  • clinical_validation_200_cases.xlsx
  • clinical_validation_200_cases.csv
  • clinical_validation_case_index_mapping.csv (Case_ID ↔ patient_id)

Reviewer sheets include data validation drop-downs for: Risk level, Next Action.
Notes columns are wrapped and sized. Text is contained within cells (no bleed).
"""

from __future__ import annotations
from pathlib import Path
import json
import re
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

# openpyxl is optional. When it is missing, every name imported from it is
# bound to None so callers can feature-test any of them without a NameError
# (previously only load_workbook was defined on failure).
try:
    from openpyxl import load_workbook
    from openpyxl.utils import get_column_letter
    from openpyxl.worksheet.datavalidation import DataValidation
except ImportError:
    load_workbook = None
    get_column_letter = None
    DataValidation = None

# Directory the loaders read from and the mapping CSV is written to; created at import time if missing.
OUTPUT_DIR = Path("publication_outputs"); OUTPUT_DIR.mkdir(exist_ok=True)
RANDOM_STATE = 42  # default sampling seed for build_public_df
N_CASES = 200  # default number of cases to sample for review
HIGH_RISK_RATIO = 0.75  # target share of cases from the highest risk quartile
HORIZON_DAYS = 30  # selects the outcome_{HORIZON_DAYS}d feature column

# ---------------------------------------------------------------------
# Resilient loaders
# ---------------------------------------------------------------------

def _load_features() -> pd.DataFrame:
    """Read the feature matrix from the first parquet candidate that exists in OUTPUT_DIR.

    Raises FileNotFoundError when neither candidate file is present.
    """
    candidates = (OUTPUT_DIR/"feature_matrix.parquet", OUTPUT_DIR/"features_df.parquet")
    found = next((path for path in candidates if path.exists()), None)
    if found is None:
        raise FileNotFoundError("feature_matrix.parquet not found in publication_outputs.")
    print(f"Found features: {found}")
    return pd.read_parquet(found)


def _load_interventions() -> pd.DataFrame:
    """Read the interventions CSV from the first candidate that exists in OUTPUT_DIR.

    Returns an empty DataFrame (after printing a warning) when no candidate
    file is present, so callers can proceed without note joins.
    """
    for p in [OUTPUT_DIR/"interventions_df.csv", OUTPUT_DIR/"interventions_processed.csv"]:
        if p.exists():
            print(f"Found interventions: {p}")
            return pd.read_csv(p)
    # Warning text repaired: the symbols were mis-encoded in the emitted message.
    print("⚠️ No interventions file found — proceeding without note joins.")
    return pd.DataFrame()


def _load_signal_risks() -> pd.DataFrame | None:
    """Read ``signal_risks.csv`` from OUTPUT_DIR, or return None when it is absent."""
    path = OUTPUT_DIR/"signal_risks.csv"
    if not path.exists():
        return None
    print(f"Found signal_risks: {path}")
    return pd.read_csv(path)


def _load_pred_30d() -> pd.DataFrame | None:
    """Read the first existing prediction CSV from OUTPUT_DIR, or return None if none exists."""
    candidates = (
        OUTPUT_DIR/"best_tte_model_predictions_30d.csv",
        OUTPUT_DIR/"best_tte_model_predictions.csv",
    )
    found = next((path for path in candidates if Path(path).exists()), None)
    return None if found is None else pd.read_csv(found)

# ---------------------------------------------------------------------
# Helpers for interventions & notes
# ---------------------------------------------------------------------

def _normalize_interventions_columns(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if "patient_id" not in df.columns and "patientId" in df.columns:
        df = df.rename(columns={"patientId": "patient_id"})
    return df


def _pick_note_column(df: pd.DataFrame) -> str | None:
    if "Encounter_Notes" in df.columns:
        return "Encounter_Notes"
    def norm(c): return str(c).lower().replace("_", "").replace(" ", "")
    name_map = {c: norm(c) for c in df.columns}
    ranked = [c for c,nc in name_map.items() if ("encounter" in nc and "note" in nc)]
    for c,nc in name_map.items():
        if nc in {"notes","note","text","content","body"}:
            ranked.append(c)
    return ranked[0] if ranked else None


def _clean_note_text(x: str) -> str:
    if not isinstance(x, str):
        x = str(x)
    # normalize line breaks, remove stray control chars except tab/newline
    t = x.replace("\r\n", "\n").replace("\r", "\n")
    t = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", " ", t)
    # collapse massive whitespace runs
    t = re.sub(r"\s{3,}", "  ", t)
    return t.strip()


# ---------------------------------------------------------------------
# Diagnosis summary from feature columns (dx_cat_* counts/flags)
# ---------------------------------------------------------------------

def build_dx_summary(features: pd.DataFrame) -> pd.Series:
    """Summarise each row's top three positive ``dx_cat_*`` categories as a "; "-joined string.

    Only numeric ``dx_cat_*`` columns (prefix matched case-insensitively) are
    used. Categories are ranked by value descending, ties alphabetically.
    Rows with no positive value yield "". When the frame has no ``dx_cat_*``
    column at all, an all-missing object Series on the same index is returned.
    """
    dx_columns = [name for name in features.columns if name.lower().startswith("dx_cat_")]
    if not dx_columns:
        return pd.Series(index=features.index, dtype=object)

    dx_df = features[dx_columns].copy()
    numeric_columns = [name for name in dx_df.columns if pd.api.types.is_numeric_dtype(dx_df[name])]
    dx_df = dx_df.loc[:, numeric_columns]
    # Strip the prefix so the summary shows bare category names.
    dx_df.columns = [re.sub(r"^dx_cat_", "", name, flags=re.I) for name in dx_df.columns]

    def summarize(row):
        positives = []
        for category in dx_df.columns:
            value = row[category]
            if pd.notnull(value) and float(value) > 0:
                positives.append((category, float(value)))
        if not positives:
            return ""
        ranked = sorted(positives, key=lambda item: (-item[1], item[0]))
        return "; ".join(category for category, _ in ranked[:3])

    return dx_df.apply(summarize, axis=1)


# ---------------------------------------------------------------------
# Risk quartiles from predictions (preferred) or fallback to outcome_30d
# ---------------------------------------------------------------------

def label_risk_quartiles(features: pd.DataFrame) -> pd.Series:
    """Assign each row of ``features`` a risk-quartile label.

    Sources, in order of preference:

    1. A prediction file (via ``_load_pred_30d``) with a ``patient_id`` column
       and a score column (first of yhat/risk/prob/pred/prediction/p_event,
       else ``mu0``). Patients without a prediction get the median score.
    2. The ``outcome_{HORIZON_DAYS}d`` column, ranked with ties broken by
       order, when its maximum is <= 1.
    3. A constant "Q2-Moderate" label.

    Robustness over the previous version: duplicate ``patient_id`` rows in
    the prediction file no longer break row alignment; heavy ties (e.g. many
    median-filled scores) fall back to rank-based binning instead of raising
    "Bin edges must be unique"; and a source that cannot be binned at all
    falls through to the next one.

    Returns a Series named ``Risk_Quartile`` indexed like ``features``.
    """
    labels = ["Q1-Low","Q2-Moderate","Q3-High","Q4-VeryHigh"]

    def _as_series(binned) -> pd.Series:
        return pd.Series(binned.values, index=features.index, name="Risk_Quartile")

    pred = _load_pred_30d()
    if pred is not None and "patient_id" in pred.columns:
        score_col = next(
            (c for c in ["yhat","risk","prob","pred","prediction","p_event"] if c in pred.columns),
            None,
        )
        if score_col is None and {"mu0","mu1"}.issubset(set(pred.columns)):
            # if only mu0/mu1 are present, use mu0 as risk under no treatment
            score_col = "mu0"
        if score_col is not None:
            s = pred[["patient_id", score_col]].dropna()
            # One score per patient so the lookup cannot multiply feature rows.
            s = s.assign(patient_id=s["patient_id"].astype(str)).drop_duplicates("patient_id")
            lookup = s.set_index("patient_id")[score_col]
            scores = features["patient_id"].astype(str).map(lookup)
            if scores.notna().any():
                vals = scores.fillna(scores.median())
                try:
                    return _as_series(pd.qcut(vals, 4, labels=labels))
                except ValueError:
                    # Tied scores produced duplicate bin edges; bin on ranks instead.
                    try:
                        return _as_series(pd.qcut(vals.rank(method="first"), 4, labels=labels))
                    except ValueError:
                        pass  # too few rows to form four bins; try the next source

    # Fallback: use outcome_30d as a crude proxy (rare, but consistent)
    col = f"outcome_{HORIZON_DAYS}d"
    if col in features.columns:
        vals = features[col].astype(float)
        # push positives to top quartile when extremely sparse
        if vals.max() <= 1:
            try:
                return _as_series(pd.qcut(vals.rank(method="first"), 4, labels=labels))
            except ValueError:
                pass  # too few rows to form four bins; use the constant label
    return pd.Series(["Q2-Moderate"] * len(features), index=features.index, name="Risk_Quartile")


# ---------------------------------------------------------------------
# Build public dataset (engaged-only + collapsed notes)
# ---------------------------------------------------------------------

def build_public_df(n_cases=N_CASES, high_risk_ratio=HIGH_RISK_RATIO, seed=RANDOM_STATE) -> tuple[pd.DataFrame, pd.DataFrame | None]:
    """Assemble the sampled review dataset and write the Case_ID mapping file.

    Loads features, interventions and optional signal risks; restricts to
    engaged patients when ``engaged_only_flag`` exists; attaches a diagnosis
    summary, collapsed encounter notes (occurred encounters only), a signal
    summary and risk quartiles; keeps only patients that actually have notes;
    then samples up to ``n_cases`` patients, aiming for ``high_risk_ratio``
    of them from the top risk quartile.

    Parameters
    ----------
    n_cases : int
        Target number of cases. Fewer are returned (with a printed warning)
        when not enough eligible patients exist.
    high_risk_ratio : float
        Target share of cases drawn from "Q4-VeryHigh".
    seed : int
        Seed for the sampling generator.

    Returns
    -------
    (public, signal_risks)
        The sampled frame with a leading ``Case_ID`` column, and the raw
        signal-risk frame (or None) as loaded.

    Raises
    ------
    RuntimeError
        If no patient with non-empty notes remains after filtering.
    """
    features = _load_features().copy()
    interventions = _load_interventions().copy()
    signal_risks = _load_signal_risks()

    # Engaged-only subset (Block 1 should have created this flag)
    if "engaged_only_flag" in features.columns:
        features = features[features["engaged_only_flag"] == 1].copy()
        print(f"Engaged-only features: n={len(features)}")

    # Attach DX_Summary
    features["DX_Summary"] = build_dx_summary(features)

    # Build collapsed notes from interventions (only encounterOccurred == 'YES')
    notes = pd.Series(dtype=str)
    if not interventions.empty:
        interventions = _normalize_interventions_columns(interventions)
        if "patient_id" in interventions.columns:
            df_notes = interventions.copy()
            if "encounterOccurred" in df_notes.columns:
                df_notes = df_notes[df_notes["encounterOccurred"].astype(str).str.upper().eq("YES")]
            col = _pick_note_column(df_notes)
            if col is not None:
                # One string per patient: cleaned notes joined by an em-dash separator line.
                tmp = (
                    df_notes[["patient_id", col]]
                    .dropna()
                    .assign(**{col: lambda d: d[col].astype(str).map(_clean_note_text)})
                    .groupby("patient_id")[col]
                    .apply(lambda s: "\n—\n".join([x for x in s if x.strip()]))
                )
                notes = tmp[tmp.astype(str).str.strip().ne("")]
            else:
                print("⚠️ No note-like column in interventions; Encounter_Notes will be empty.")
        else:
            print("⚠️ interventions has no 'patient_id'; cannot join notes.")

    features = features.merge(notes.rename("Encounter_Notes"), left_on="patient_id", right_index=True, how="left")

    # Signal risks summary per patient (optional)
    signals_col = pd.Series([""] * len(features), index=features.index, name="Signals_Summary")
    if signal_risks is not None and {"patient_id","signal"}.issubset(signal_risks.columns):
        tmp = (
            signal_risks[["patient_id","signal"]]
            .dropna()
            .assign(patient_id=lambda d: d["patient_id"].astype(str))
            .groupby("patient_id")["signal"].apply(lambda s: "; ".join(sorted(set(map(str, s)))))
        )
        m = features["patient_id"].astype(str)
        signals_col = m.map(tmp).fillna("")
    features["Signals_Summary"] = signals_col

    # Risk quartiles for sampling & display
    rq = label_risk_quartiles(features)
    features["Risk_Quartile"] = rq

    # Keep only patients with at least some notes. Missing notes (NaN from the
    # left merge) must be blanked first: astype(str) alone turns NaN into the
    # non-empty string "nan", which previously let note-less patients through.
    base = features.copy()
    base["has_notes"] = base["Encounter_Notes"].fillna("").astype(str).str.strip().ne("")
    base = base[base["has_notes"]]
    if base.empty:
        raise RuntimeError("No engaged patients with Encounter_Notes found after filtering.")

    # Sample: aim for HIGH_RISK_RATIO from Q4 and the remainder from Q1–Q3
    rng = np.random.default_rng(seed)
    q4 = base[base["Risk_Quartile"].eq("Q4-VeryHigh")]
    q123 = base[~base.index.isin(q4.index)]

    n_q4 = min(int(round(n_cases * high_risk_ratio)), len(q4))
    # Cap at the pool size: drawing more than len(q123) without replacement raises.
    n_q123 = min(n_cases - n_q4, len(q123))

    sample_idx = list(rng.choice(q4.index, size=n_q4, replace=False))
    if n_q123 > 0:
        sample_idx += list(rng.choice(q123.index, size=n_q123, replace=False))
    if len(sample_idx) < n_cases:
        print(f"⚠️ Only {len(sample_idx)} eligible cases available (requested {n_cases}).")

    public = base.loc[sample_idx].copy().reset_index(drop=True)

    # Case IDs
    public.insert(0, "Case_ID", np.arange(1, len(public) + 1, dtype=int))

    # Convenience demographics columns if available
    for a,b in [("age","Age"),("gender","Sex"),("Race","Race")]:
        if a in public.columns and b not in public.columns:
            public.rename(columns={a:b}, inplace=True)

    # Mapping file
    mapping = public[["Case_ID","patient_id"]].copy()
    mapping.to_csv(OUTPUT_DIR/"clinical_validation_case_index_mapping.csv", index=False)

    return public, signal_risks


# ---------------------------------------------------------------------
# Excel writer with dual reviewer columns & data validation
# ---------------------------------------------------------------------

def write_review_excel(public: pd.DataFrame, out_xlsx: Path, signal_risks_df: pd.DataFrame | None):
    """Write the clinician review workbook.

    The workbook has three sheets:
      * ``review`` - one row per case with the base columns followed by empty
        ``R{n}_Risk`` / ``R{n}_Action`` / ``R{n}_Notes`` columns for three
        reviewers; risk and action columns get drop-down validation.
      * ``codes``  - the allowed risk and intervention codes backing the
        drop-downs.
      * ``README`` - short instructions for reviewers.

    Parameters
    ----------
    public : DataFrame holding the sampled cases. Only the base columns that
        are actually present are exported.
    out_xlsx : destination path of the workbook.
    signal_risks_df : accepted for interface compatibility; not used here.
    """
    outcome_col = f"outcome_{HORIZON_DAYS}d"
    base_cols = [
        c for c in (
            "Case_ID", "patient_id", "Age", "Sex", "Risk_Quartile",
            outcome_col,
            "DX_Summary", "Signals_Summary", "Encounter_Notes",
        )
        if c in public.columns
    ]

    # Build reviewer columns (3 reviewers)
    reviewers = [1, 2, 3]
    risk_colnames = [f"R{r}_Risk" for r in reviewers]
    act_colnames = [f"R{r}_Action" for r in reviewers]
    note_colnames = [f"R{r}_Notes" for r in reviewers]

    review_df = public[base_cols].copy()
    for c in risk_colnames + act_colnames + note_colnames:
        review_df[c] = ""

    # Write workbook (openpyxl engine)
    with pd.ExcelWriter(out_xlsx, engine="openpyxl") as writer:
        review_df.to_excel(writer, index=False, sheet_name="review")

        # Helper sheet: codes (risks & interventions)
        risks_list = ["Q1-Low","Q2-Moderate","Q3-High","Q4-VeryHigh"]
        interventions_list = [
            "watchful_waiting","mental_health_support","social_support",
            "care_coordination","pcp_followup","specialist_referral",
            "transportation_support","medication_support","housing_support",
            "substance_use_support","nutrition_support","other"
        ]
        # DataFrame columns must have equal length, so pad the shorter list.
        n_codes = max(len(risks_list), len(interventions_list))

        def pad(lst, n):
            return lst + [""] * (n - len(lst))

        codes_df = pd.DataFrame({
            "Risk": pad(risks_list, n_codes),
            "Intervention": pad(interventions_list, n_codes)
        })
        codes_df.to_excel(writer, index=False, sheet_name="codes")

        ws = writer.sheets["review"]

        # Keep the header row and the Case_ID column visible while scrolling.
        ws.freeze_panes = "B2"

        # Header name -> 1-based column index, read back from the sheet.
        headers = [cell.value for cell in next(ws.iter_rows(min_row=1, max_row=1))]
        col_idx = {h: i + 1 for i, h in enumerate(headers)}

        # Widths are keyed by header name (not by fixed letters) so they stay
        # attached to the right column even when an optional base column
        # such as Age or Sex is absent and the remaining columns shift left.
        header_widths = {
            "Case_ID": 7,
            "patient_id": 18,
            "Age": 6,
            "Sex": 9,
            "Risk_Quartile": 13,
            "DX_Summary": 24,
            "Signals_Summary": 28,
            "Encounter_Notes": 65,
        }
        for header, width in header_widths.items():
            if header in col_idx:
                ws.column_dimensions[get_column_letter(col_idx[header])].width = width

        # Wrap and top-align every long-text column.
        wrap_targets = ["Encounter_Notes", *note_colnames, "DX_Summary", "Signals_Summary"]
        for name in wrap_targets:
            if name not in col_idx:
                continue
            j = col_idx[name]
            for column_cells in ws.iter_cols(min_col=j, max_col=j, min_row=2, max_row=ws.max_row):
                for cell in column_cells:
                    cell.alignment = cell.alignment.copy(wrap_text=True, vertical="top")

        # Data validation (drop-downs) sourced from the codes sheet
        last_row = ws.max_row
        risk_range = f"codes!$A$2:$A${1+len(risks_list)}"
        int_range  = f"codes!$B$2:$B${1+len(interventions_list)}"

        dv_risk = DataValidation(type="list", formula1=risk_range, allow_blank=True)
        dv_act  = DataValidation(type="list", formula1=int_range,  allow_blank=True)
        ws.add_data_validation(dv_risk)
        ws.add_data_validation(dv_act)
        # Apply to each reviewer column
        for validation, colnames in ((dv_risk, risk_colnames), (dv_act, act_colnames)):
            for cname in colnames:
                if cname in col_idx:
                    col_letter = get_column_letter(col_idx[cname])
                    validation.add(f"{col_letter}2:{col_letter}{last_row}")

        # README sheet
        readme = pd.DataFrame({"Instructions": [
            "Each case is one engaged patient with collapsed encounter notes.",
            "Fill BOTH columns per reviewer: Risk level and Next Best Action.",
            "Use the drop-down lists for standardized coding.",
            "Add any free-text rationale in your reviewer Notes column.",
            "Risk_Quartile shown is model-based (30d) when available.",
        ]})
        readme.to_excel(writer, index=False, sheet_name="README")

    print(f"‚úì Wrote Excel pack: {out_xlsx}")


# ---------------------------------------------------------------------
# CSV mirror (handy for quick diffing or copy/paste workflows)
# ---------------------------------------------------------------------

def write_review_csv(public: pd.DataFrame, out_csv: Path):
    """Write a CSV mirror of the review sheet's base columns.

    Columns are exported in the same order as the Excel base columns, and
    any column missing from ``public`` is skipped.
    """
    outcome_col = f"outcome_{HORIZON_DAYS}d"
    ordered = (
        "Case_ID", "patient_id", "Age", "Sex", "Risk_Quartile",
        outcome_col,
        "DX_Summary", "Signals_Summary", "Encounter_Notes",
    )
    present = [name for name in ordered if name in public.columns]
    mirror = public[present].copy()
    # Default pandas quoting is relied on to keep multi-line notes in one cell.
    mirror.to_csv(out_csv, index=False)
    print(f"‚úì Wrote CSV mirror: {out_csv}")


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------

def main():
    """Build the sampled case set, then write the Excel pack and CSV mirror."""
    targets = {
        "xlsx": OUTPUT_DIR/"clinical_validation_200_cases.xlsx",
        "csv": OUTPUT_DIR/"clinical_validation_200_cases.csv",
    }

    cases, signal_risks = build_public_df(
        n_cases=N_CASES,
        high_risk_ratio=HIGH_RISK_RATIO,
        seed=RANDOM_STATE,
    )

    # Excel first, then the CSV mirror
    write_review_excel(cases, targets["xlsx"], signal_risks)
    write_review_csv(cases, targets["csv"])

    # Quick look at the first few cases
    preview = cases.head(3)
    print(preview.to_string(index=False))


if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
"""
BLOCK 5 — CLINICAL VALIDATION ANALYSIS & MODEL COMPARISON (v2)
==============================================================
- Links reviewer CSV (Case_ID) back to patient_id via mapping from Block 4.
- Loads best-model predictions for time-to-event (7/30/90/180d) and intervention selection, if available.
- Compares:
  • Inter-rater agreement among clinicians (risk + intervention)
    - Pairwise Cohen's κ (supplement)
    - Fleiss' κ with bootstrap 95% CI (primary)
  • Model vs each reviewer (risk + intervention)
  • Reviewer/majority-vote vs outcome (sensitivity, specificity, PPV, NPV, F1)
  • Model ROC/AUC vs 30d outcome with reviewer operating points overlaid (High-only and Med+High)
- Saves publication-ready tables/figures to /publication_outputs.

Expected inputs (created by prior blocks):
- /publication_outputs/clinical_validation_200_cases.csv
- /publication_outputs/clinical_validation_case_index_mapping.csv  (Case_ID ↔ patient_id)
- /publication_outputs/best_tte_model_predictions.csv              (optional; per-patient probs)
- /publication_outputs/best_intervention_model_predictions.csv     (optional; per-patient recs)
- /publication_outputs/complete_results_*.json                     (optional; for best model names)

Outputs:
- table_clinval_risk_interrater.csv (pairwise Cohen)
- table_clinval_risk_interrater_fleiss.csv (Fleiss' κ + 95% CI)
- table_clinval_intervention_interrater.csv (pairwise Cohen)
- table_clinval_intervention_interrater_fleiss.csv (Fleiss' κ + 95% CI)
- table_clinval_risk_model_vs_reviewer.csv
- table_clinval_risk_vs_outcome_thresholds.csv
- table_clinval_intervention_agreement.csv
- table_clinval_intervention_confusion_matrix.csv
- figure_clinval_model_roc_30d.png (with reviewer operating points)
- figure_clinval_intervention_confusion.png
- clinical_validation_with_patient_and_model.csv (analysis-ready merged dataset)
"""

from __future__ import annotations
import json
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    roc_auc_score, roc_curve, confusion_matrix, cohen_kappa_score,
    accuracy_score, precision_score, recall_score, f1_score
)

# Suppress all warnings from here on.
warnings.filterwarnings("ignore")

# Directory (relative to the working directory) that load_inputs() reads from
# and that every table/figure below is written to; created if missing.
PUBDIR = Path("publication_outputs")
PUBDIR.mkdir(exist_ok=True)
# Default seed for the bootstrap resampling in fleiss_bootstrap().
RANDOM_STATE = 42

###############################
# I/O HELPERS
###############################

def _find_latest(glob_pat: str) -> Path | None:
    """Return the match for *glob_pat* in PUBDIR that sorts last, or None.

    "Latest" means last in sorted path order, not by modification time.
    """
    matches = list(PUBDIR.glob(glob_pat))
    if not matches:
        return None
    matches.sort()
    return matches[-1]


def load_inputs():
    """Load the reviewer CSV, the case mapping and optional model artefacts.

    Returns
    -------
    tuple
        ``(reviews, mapping, tte_preds, int_preds, best_tte_name, best_int_name)``.
        ``tte_preds`` / ``int_preds`` are empty DataFrames when no prediction
        file is found. The two names are ``None`` when no results JSON is
        found, when it cannot be read, or when it holds no usable scores.

    Raises
    ------
    FileNotFoundError
        If the reviewer CSV or the Case_ID mapping file is missing.
    """
    reviews_fp = PUBDIR / "clinical_validation_200_cases.csv"
    mapping_fp = PUBDIR / "clinical_validation_case_index_mapping.csv"
    for required_fp in (reviews_fp, mapping_fp):
        if not required_fp.exists():
            raise FileNotFoundError(f"Missing {required_fp}. Run Block 4 to generate it.")
    reviews = pd.read_csv(reviews_fp)
    mapping = pd.read_csv(mapping_fp)

    # Try to load model prediction files (optional but recommended)
    tte_pred_fp = _find_latest("best_tte_model_predictions*.csv") or _find_latest("tte_predictions_*.csv")
    int_pred_fp = _find_latest("best_intervention_model_predictions*.csv") or _find_latest("intervention_predictions_*.csv")

    tte_preds = pd.read_csv(tte_pred_fp) if tte_pred_fp else pd.DataFrame()
    int_preds = pd.read_csv(int_pred_fp) if int_pred_fp else pd.DataFrame()

    # Try to get best model names for annotation (best effort only)
    results_fp = _find_latest("complete_results_*.json")
    best_tte_name = None
    best_int_name = None
    if results_fp is not None:
        try:
            with open(results_fp, encoding="utf-8") as f:
                res = json.load(f)
            # Pick best TTE by 30d AUC
            best_auc = -np.inf
            for mname, mdict in res.get("time_to_event", {}).items():
                auc_30 = mdict.get("30d", {}).get("auc")
                if auc_30 is not None and auc_30 > best_auc:
                    best_auc = auc_30
                    best_tte_name = mname
            # Pick best intervention by kappa
            best_k = -np.inf
            for mname, md in res.get("intervention_selection", {}).items():
                k = md.get("kappa")
                if k is not None and k > best_k:
                    best_k = k
                    best_int_name = mname
        except (OSError, ValueError, AttributeError, TypeError) as exc:
            # Names are only used for labelling, so a broken or oddly shaped
            # results file must not abort the analysis - but say so instead
            # of failing silently.
            print(f"[WARN] Could not read best-model names from {results_fp}: {exc}")

    return reviews, mapping, tte_preds, int_preds, best_tte_name, best_int_name

###############################
# NORMALIZATION
###############################

# Lower-case reviewer risk tokens (full words and abbreviations) mapped to an
# ordinal scale: 0 = Low, 1 = Medium, 2 = High.
# NOTE(review): not referenced by the functions visible in this cell
# (norm_risk_label carries its own inline mapping) - confirm whether it is
# still needed.
RISK_MAP_ORD = {
    # maps to ordinal 0/1/2 (Low/Med/High)
    "low": 0, "l": 0, "lo": 0,
    "medium": 1, "med": 1, "m": 1,
    "high": 2, "hi": 2, "h": 2,
}

# Phrase -> canonical snake_case intervention label.
# norm_intervention() first tries an exact key match and then falls back to
# substring matching in this dict's insertion order, so earlier keys win when
# several keys occur in the same text.
INTERVENTION_NORMALIZE = {
    # unify common strings to snake_case used by Block 1-4
    "substance use": "substance_use_support",
    "substance_use": "substance_use_support",
    "mental health": "mental_health_support",
    "mental_health": "mental_health_support",
    "chronic": "chronic_condition_management",
    "housing": "housing_assistance",
    "transport": "transportation_assistance",
    "transportation": "transportation_assistance",
    "food": "food_assistance",
    "utility": "utility_assistance",
    "childcare": "childcare_assistance",
    "watchful waiting": "watchful_waiting",
    "watchful_waiting": "watchful_waiting",
}


def norm_risk_label(x: str | float | int) -> str:
    """Normalise a free-text reviewer risk rating to Low / Medium / High.

    Missing values give "". The text is lower-cased and split on whitespace;
    the first token found in priority order (high words, then medium, then
    low, longer spellings before abbreviations) decides the label. Text
    without any recognised token is returned title-cased.
    """
    if pd.isna(x):
        return ""

    text = str(x).strip().lower()
    tokens = text.split()

    canonical = {
        "high": "High", "hi": "High", "h": "High",
        "medium": "Medium", "med": "Medium", "m": "Medium",
        "low": "Low", "lo": "Low", "l": "Low",
    }
    priority = ("high", "medium", "low", "hi", "med", "lo", "h", "m", "l")

    hit = next((word for word in priority if word in tokens), None)
    if hit is None:
        return text.title()
    return canonical[hit]


def risk_to_ord(s: str) -> int | None:
    """Map a normalised risk label to 0 (Low), 1 (Medium) or 2 (High).

    Returns None for empty input and for any other label.
    """
    ordinal_by_label = {"Low": 0, "Medium": 1, "High": 2}
    if s and s in ordinal_by_label:
        return ordinal_by_label[s]
    return None


def risk_to_bin(s: str, threshold: str = "High") -> int | None:
    """Binarise a normalised risk label.

    threshold="High"    -> 1 only for High.
    threshold="MedPlus" -> 1 for Medium or High.
    Returns None when the label has no ordinal value or the threshold name
    is not one of the two above.
    """
    level = risk_to_ord(s)
    if level is None:
        return None

    if threshold == "High":
        positive = level == 2
    elif threshold == "MedPlus":
        positive = level >= 1
    else:
        return None
    return int(positive)


def norm_intervention(x: str) -> str:
    """Normalise a free-text intervention to a snake_case label.

    Missing or blank input gives "". Otherwise the text is lower-cased,
    "-" and "/" become spaces and whitespace runs collapse; the result is
    looked up in INTERVENTION_NORMALIZE exactly, then by first contained key,
    and as a last resort spaces are replaced with underscores.
    """
    if pd.isna(x):
        return ""
    raw = str(x).strip()
    if raw == "":
        return ""

    separators_to_space = str.maketrans("-/", "  ")
    cleaned = " ".join(raw.lower().translate(separators_to_space).split())

    exact = INTERVENTION_NORMALIZE.get(cleaned)
    if exact is not None:
        return exact

    fallback = cleaned.replace(" ", "_")
    return next(
        (label for phrase, label in INTERVENTION_NORMALIZE.items() if phrase in cleaned),
        fallback,
    )

###############################
# METRICS, AGREEMENT & BOOTSTRAP
###############################

def bin_class_metrics(y_true, y_pred, y_score=None) -> dict:
    """Binary classification metrics for 0/1 labels and predictions.

    Returns a dict with n, accuracy, precision, recall, f1, the confusion
    counts (tn, fp, fn, tp), specificity, ppv, npv and auc. Ratios with a
    zero denominator are reported as 0.0. ``auc`` is NaN unless ``y_score``
    is given and both classes occur in ``y_true``.
    """
    truth = np.asarray(y_true).astype(int)
    decided = np.asarray(y_pred).astype(int)

    def _safe_ratio(num, den):
        return float(num / den if den else 0.0)

    summary = {
        "n": int(len(truth)),
        "accuracy": float(accuracy_score(truth, decided)),
        "precision": float(precision_score(truth, decided, zero_division=0)),
        "recall": float(recall_score(truth, decided, zero_division=0)),
        "f1": float(f1_score(truth, decided, zero_division=0)),
    }

    tn, fp, fn, tp = confusion_matrix(truth, decided, labels=[0,1]).ravel()
    summary["tn"] = int(tn)
    summary["fp"] = int(fp)
    summary["fn"] = int(fn)
    summary["tp"] = int(tp)
    summary["specificity"] = _safe_ratio(tn, tn + fp)
    summary["ppv"] = _safe_ratio(tp, tp + fp)
    summary["npv"] = _safe_ratio(tn, tn + fn)

    auc = np.nan
    if y_score is not None and len(np.unique(truth)) == 2:
        try:
            auc = float(roc_auc_score(truth, y_score))
        except Exception:
            auc = np.nan
    summary["auc"] = auc
    return summary


def _fleiss_from_counts(counts: np.ndarray) -> tuple[float, dict]:
    """Compute Fleiss' kappa from an item-by-category count table.

    Parameters
    ----------
    counts : array of shape (n_items, n_categories); entry [i, j] is the
        number of raters who put item i in category j. The number of raters
        may differ between items; items with fewer than two ratings are
        dropped.

    Returns
    -------
    (kappa, extras)
        ``kappa`` is NaN when no item has at least two ratings or when the
        expected agreement is 1 (all ratings in one category). ``extras``
        holds ``n_items`` and, when kappa was computed, ``mean_raters``,
        ``P_bar`` and ``P_e``.

    Raises
    ------
    ValueError
        If ``counts`` is non-empty but not two-dimensional.
    """
    counts = np.asarray(counts, dtype=float)
    # An empty table (including the 1-D empty array produced when no item
    # qualified) has nothing to score; return the "no items" result instead
    # of failing on the shape.
    if counts.size == 0:
        return np.nan, {"n_items": 0}
    if counts.ndim != 2:
        raise ValueError(
            f"counts must be 2-D (items x categories); got shape {counts.shape}"
        )

    n_i = counts.sum(axis=1)  # raters per item (can vary)
    # Remove items with <2 ratings: pairwise agreement is undefined for them.
    rated = n_i >= 2
    counts = counts[rated]
    n_i = n_i[rated]
    if counts.shape[0] == 0:
        return np.nan, {"n_items": 0}

    # Per-item agreement: share of rater pairs that agree.
    P_i = np.sum(counts * (counts - 1), axis=1) / (n_i * (n_i - 1))
    P_bar = float(np.mean(P_i))
    # Chance agreement from the pooled category proportions.
    p_j = np.sum(counts, axis=0) / np.sum(n_i)
    P_e = float(np.sum(p_j ** 2))
    kappa = (P_bar - P_e) / (1 - P_e) if (1 - P_e) > 0 else np.nan
    extras = {
        "n_items": int(counts.shape[0]),
        "mean_raters": float(np.mean(n_i)),
        "P_bar": P_bar, "P_e": P_e
    }
    return float(kappa), extras


def _build_counts(df: pd.DataFrame, cols: list[str], categories: list[str]) -> np.ndarray:
    """Tabulate reviewer votes into an item-by-category count matrix.

    For each row of ``df[cols]`` the non-blank string values are the votes.
    Votes that are not in ``categories`` are ignored. A row is kept only if
    it has at least two votes and at least two of them fall in a known
    category.

    Returns
    -------
    Integer array of shape (n_kept_rows, len(categories)). When no row
    qualifies the result has shape (0, len(categories)), so callers can
    always rely on a two-dimensional table.
    """
    cat_to_idx = {c: i for i, c in enumerate(categories)}
    n_cat = len(categories)
    kept = []
    for values in df[cols].itertuples(index=False, name=None):
        votes = [v for v in values if isinstance(v, str) and v.strip() != ""]
        if len(votes) < 2:
            continue
        cnt = np.zeros(n_cat, dtype=int)
        for v in votes:
            idx = cat_to_idx.get(v)
            if idx is not None:
                cnt[idx] += 1
        if cnt.sum() >= 2:
            kept.append(cnt)
    if not kept:
        return np.zeros((0, n_cat), dtype=int)
    return np.array(kept, dtype=int)


def fleiss_bootstrap(counts: np.ndarray, B: int = 1000, seed: int = RANDOM_STATE) -> tuple[float, float, float]:
    """Bootstrap Fleiss' kappa by resampling items with replacement.

    Returns (mean, 2.5th percentile, 97.5th percentile) over ``B`` resamples,
    ignoring NaN replicates, or three NaNs when ``counts`` is empty.
    """
    table = np.asarray(counts)
    if table.size == 0:
        return np.nan, np.nan, np.nan

    rng = np.random.default_rng(seed)
    n_items = table.shape[0]
    replicates = np.array([
        _fleiss_from_counts(table[rng.integers(0, n_items, size=n_items)])[0]
        for _ in range(B)
    ])

    centre = float(np.nanmean(replicates))
    lower = float(np.nanpercentile(replicates, 2.5))
    upper = float(np.nanpercentile(replicates, 97.5))
    return centre, lower, upper

###############################
# MAIN ANALYSIS
###############################

def main():
    """Run the clinical-validation analysis and write all tables and figures.

    Steps
    -----
    1. Load the reviewer CSV, the Case_ID mapping and optional model outputs.
    2. Normalise reviewer risk / intervention labels and attach predictions.
    3. Inter-rater agreement: pairwise Cohen's kappa and Fleiss' kappa with a
       bootstrap 95% CI, for risk and for intervention.
    4. Model and reviewers vs the 30-day outcome (metrics table + ROC figure)
       and model-vs-reviewer kappa on the binarised risk.
    5. Intervention agreement between each reviewer and the model, plus a
       majority-reviewer vs model confusion matrix.
    6. Save the merged, analysis-ready dataset.

    Raises
    ------
    ValueError
        If the reviewer CSV lacks ``Case_ID`` or the mapping file lacks
        ``Case_ID`` / ``patient_id``.
    """
    from itertools import combinations

    reviews, mapping, tte_preds, int_preds, best_tte_name, best_int_name = load_inputs()

    # Merge Case_ID <-> patient_id
    if "Case_ID" not in reviews.columns:
        raise ValueError("Reviewer CSV missing Case_ID column.")
    if not {"Case_ID", "patient_id"}.issubset(mapping.columns):
        raise ValueError("Mapping file must have Case_ID and patient_id.")

    # The reviewer CSV may already carry a patient_id column. Merging it as-is
    # would produce patient_id_x / patient_id_y and break every later merge on
    # "patient_id", so the mapping file is treated as the single source.
    df = reviews.drop(columns=["patient_id"], errors="ignore").merge(
        mapping[["Case_ID", "patient_id"]], on="Case_ID", how="left"
    )

    # Normalize reviewer risk and intervention columns
    # NOTE(review): the Block 4 workbook names reviewer columns R{n}_Risk /
    # R{n}_Action; confirm the reviewer CSV is renamed to the
    # Reviewer_*Risk_Assessment / Reviewer_*Intervention scheme expected here.
    risk_cols = [c for c in df.columns if c.startswith("Reviewer_") and c.endswith("Risk_Assessment")]
    int_cols  = [c for c in df.columns if c.startswith("Reviewer_") and c.endswith("Intervention")]

    for c in risk_cols:
        df[c] = df[c].apply(norm_risk_label)
    for c in int_cols:
        df[c] = df[c].apply(norm_intervention)

    # Outcome (expects 'Actual_30d_Outcome' as 'Event'/'No Event' from Block 4)
    if "Actual_30d_Outcome" in df.columns:
        df["y30"] = df["Actual_30d_Outcome"].map({"Event":1, "No Event":0})
    else:
        df["y30"] = np.nan

    # Attach model predictions if available
    if not tte_preds.empty:
        tte_cols = [c for c in tte_preds.columns if c.startswith("pred_")]
        keep_cols = ["patient_id"] + tte_cols
        if "model_name" in tte_preds.columns:
            keep_cols += ["model_name"]
        df = df.merge(tte_preds[keep_cols].drop_duplicates("patient_id"),
                      on="patient_id", how="left")
        if "model_name" in df.columns and best_tte_name is None:
            # Join into one string so labels read "Model (name)" rather than
            # showing a Python list; an empty result falls back to None.
            names = df["model_name"].dropna().unique().tolist()
            best_tte_name = ", ".join(map(str, names)) or None
    else:
        print("[WARN] No best_tte_model_predictions*.csv found ‚Äî AUC/ROC for model will be skipped.")

    if not int_preds.empty:
        int_preds = int_preds.copy()
        int_preds["recommended_intervention"] = int_preds["recommended_intervention"].apply(norm_intervention)
        if "model_name" in int_preds.columns and best_int_name is None:
            names = int_preds["model_name"].dropna().unique().tolist()
            best_int_name = ", ".join(map(str, names)) or None
        # Bring over only the recommendation so unrelated columns cannot
        # collide with (and rename) columns already present in df.
        df = df.merge(
            int_preds[["patient_id", "recommended_intervention"]].drop_duplicates("patient_id"),
            on="patient_id", how="left",
        )
    else:
        print("[WARN] No best_intervention_model_predictions*.csv found ‚Äî intervention agreement vs model will be limited.")

    # =========================
    # 1) INTER-RATER AGREEMENT
    # =========================
    # Pairwise Cohen's kappa (supplement)
    ord_map = {"Low": 0, "Medium": 1, "High": 2}
    pair_rows = []
    for a, b in combinations(risk_cols, 2):
        a_ord = df[a].map(ord_map)
        b_ord = df[b].map(ord_map)
        # Blank or unrecognised labels map to NaN; score only the cases that
        # both raters graded on the Low/Medium/High scale.
        both = a_ord.notna() & b_ord.notna()
        if both.any():
            pair_rows.append({
                "rater_a": a, "rater_b": b,
                "kappa": cohen_kappa_score(a_ord[both].astype(int), b_ord[both].astype(int)),
            })
    kappa_risk_df = pd.DataFrame(pair_rows) if pair_rows else pd.DataFrame(columns=["rater_a","rater_b","kappa"])
    kappa_risk_df.to_csv(PUBDIR / "table_clinval_risk_interrater.csv", index=False)

    pair_rows = []
    for a, b in combinations(int_cols, 2):
        sub = df[[a, b]].replace("", np.nan).dropna()
        if len(sub):
            pair_rows.append({"rater_a": a, "rater_b": b, "kappa": cohen_kappa_score(sub[a], sub[b])})
    kappa_int_df = pd.DataFrame(pair_rows) if pair_rows else pd.DataFrame(columns=["rater_a","rater_b","kappa"])
    kappa_int_df.to_csv(PUBDIR / "table_clinval_intervention_interrater.csv", index=False)

    # Fleiss' kappa (primary) + bootstrap CI
    def fleiss_summary(counts) -> dict:
        # With no ratable items there is nothing to score; report NaNs rather
        # than handing an empty table to the kappa routine.
        if np.size(counts) == 0:
            return {"fleiss_kappa": np.nan, "ci_mean": np.nan,
                    "ci_lo": np.nan, "ci_hi": np.nan, "n_items": 0}
        kappa, extras = _fleiss_from_counts(counts)
        boot_mean, lo, hi = fleiss_bootstrap(counts, B=1000, seed=RANDOM_STATE)
        return {"fleiss_kappa": kappa, "ci_mean": boot_mean, "ci_lo": lo, "ci_hi": hi, **extras}

    # Risk (Low/Med/High)
    risk_categories = ["Low", "Medium", "High"]
    risk_counts = _build_counts(df, risk_cols, risk_categories)
    pd.DataFrame([fleiss_summary(risk_counts)]).to_csv(
        PUBDIR/"table_clinval_risk_interrater_fleiss.csv", index=False
    )

    # Intervention (multi-class)
    # Build set of categories present (exclude blanks)
    all_int_vals = set()
    for c in int_cols:
        all_int_vals.update(v for v in df[c].dropna().unique().tolist() if isinstance(v, str) and v.strip() != "")
    int_categories = sorted(all_int_vals)
    if len(int_categories) >= 2:
        int_counts = _build_counts(df, int_cols, int_categories)
        pd.DataFrame([fleiss_summary(int_counts)]).to_csv(
            PUBDIR/"table_clinval_intervention_interrater_fleiss.csv", index=False
        )

    # ===========================
    # 2) MODEL VS REVIEWER - RISK
    # ===========================
    def op_point(mets: dict) -> tuple[float, float]:
        # (false-positive rate, true-positive rate) from confusion counts.
        tn, fp, fn, tp = mets["tn"], mets["fp"], mets["fn"], mets["tp"]
        fpr = fp / (fp + tn) if (fp + tn) else 0.0
        tpr = tp / (tp + fn) if (tp + fn) else 0.0
        return fpr, tpr

    has_pred = "pred_30d" in df.columns
    pred30 = df["pred_30d"].astype(float) if has_pred else None
    # The model can only be scored on cases that have both an outcome and a
    # prediction; a missing prediction must not count as a negative call.
    model_eval = (df["y30"].notna() & pred30.notna()) if has_pred else None

    rows = []
    if has_pred and model_eval.any():
        y_true = df.loc[model_eval, "y30"].astype(int)
        y_score = pred30[model_eval]
        y_pred_m = (y_score >= 0.5).astype(int)
        m_metrics = bin_class_metrics(y_true, y_pred_m, y_score)
        m_metrics.update({"who": f"Model ({best_tte_name or 'best_30d'})", "threshold": 0.5})
        rows.append(m_metrics)

        # Reviewer operating points (High and Med+High)
        op_points = []  # (fpr, tpr, label)
        for thr_name in ["High", "MedPlus"]:
            for r in risk_cols:
                r_bin = df[r].apply(lambda s, thr_name=thr_name: risk_to_bin(s, thr_name))
                mask = r_bin.notna() & df["y30"].notna()
                if mask.sum() == 0:
                    continue
                mets = bin_class_metrics(df.loc[mask, "y30"], r_bin[mask])
                mets.update({"who": f"{r} ({'High' if thr_name=='High' else 'Med+High'}=1)", "threshold": thr_name})
                rows.append(mets)
                # For ROC overlay
                op_points.append((*op_point(mets), mets["who"]))

        # Majority vote thresholds (ties and all-missing rows give None)
        def maj_vote(series: pd.Series, thr: str) -> int | None:
            bins = series.apply(lambda s: risk_to_bin(s, thr))
            if bins.isna().all():
                return None
            ones = (bins == 1).sum()
            zeros = (bins == 0).sum()
            if ones == zeros:
                return None
            return 1 if ones > zeros else 0

        for thr in ["High", "MedPlus"]:
            mv = df[risk_cols].apply(lambda row, thr=thr: maj_vote(row, thr), axis=1)
            mask = mv.notna() & df["y30"].notna()
            if mask.sum() > 0:
                mets = bin_class_metrics(df.loc[mask, "y30"], mv[mask])
                label = f"Majority Vote ({'High' if thr=='High' else 'Med+High'}=1)"
                mets.update({"who": label, "threshold": thr})
                rows.append(mets)
                # ROC overlay point
                op_points.append((*op_point(mets), label))

        # Plot ROC with reviewer operating points
        fpr, tpr, _ = roc_curve(y_true, y_score)
        plt.figure(figsize=(7, 6))
        plt.plot(fpr, tpr, lw=2, label=f"Model ROC (AUC={m_metrics.get('auc', np.nan):.3f})")
        plt.plot([0,1],[0,1], linestyle="--", alpha=0.4, label="Chance")
        # Overlay points
        for fpr_p, tpr_p, lab in op_points:
            plt.scatter([fpr_p], [tpr_p], s=60)
            plt.annotate(lab, (fpr_p, tpr_p), textcoords="offset points", xytext=(6,4), fontsize=8)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Model ROC ‚Äî 30d with Reviewer Operating Points")
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig(PUBDIR / "figure_clinval_model_roc_30d.png", dpi=300)
        plt.close()
    else:
        print("[WARN] Missing model 30d predictions and/or outcomes; skipping model ROC and metrics vs outcome.")

    risk_perf_df = pd.DataFrame(rows)
    risk_perf_df.to_csv(PUBDIR / "table_clinval_risk_vs_outcome_thresholds.csv", index=False)

    # Kappa between model binary decision and each reviewer
    rows = []
    if has_pred:
        pred_ok = pred30.notna()
        m_dec = (pred30 >= 0.5).astype(int)
        for thr in ["High", "MedPlus"]:
            for r in risk_cols:
                r_bin = df[r].apply(lambda s, thr=thr: risk_to_bin(s, thr))
                # Compare only where both the reviewer and the model decided.
                mask = r_bin.notna() & pred_ok
                if mask.sum() == 0:
                    continue
                k = cohen_kappa_score(m_dec[mask], r_bin[mask].astype(int))
                rows.append({"threshold": thr, "kappa": k, "reviewer": r, "model": best_tte_name or "best_30d"})
    risk_model_vs_reviewer_df = pd.DataFrame(rows)
    risk_model_vs_reviewer_df.to_csv(PUBDIR / "table_clinval_risk_model_vs_reviewer.csv", index=False)

    # ======================================
    # 3) INTERVENTION - AGREEMENT & CONFUSION
    # ======================================
    inter_rows = []
    if "recommended_intervention" in df.columns:
        for r in int_cols:
            sub = df[[r, "recommended_intervention"]].replace("", np.nan).dropna()
            if len(sub):
                acc = accuracy_score(sub[r], sub["recommended_intervention"])
                f1m = f1_score(sub[r], sub["recommended_intervention"], average="macro")
                kap = cohen_kappa_score(sub[r], sub["recommended_intervention"])
                inter_rows.append({
                    "reviewer": r,
                    "model": best_int_name or "best_intervention",
                    "accuracy": acc,
                    "macro_f1": f1m,
                    "kappa": kap,
                    "n": len(sub)
                })
        inter_df = pd.DataFrame(inter_rows)
        inter_df.to_csv(PUBDIR / "table_clinval_intervention_agreement.csv", index=False)

        # Majority reviewer vs model confusion (ties give no majority)
        def maj_interv(row):
            vals = [v for v in row if isinstance(v, str) and v]
            if not vals:
                return None
            vc = pd.Series(vals).value_counts()
            if len(vc) >= 2 and vc.iloc[0] == vc.iloc[1]:
                return None
            return vc.index[0]
        df["reviewer_intervention_majority"] = df[int_cols].apply(maj_interv, axis=1)
        cm_df = df[["reviewer_intervention_majority", "recommended_intervention"]].dropna()
        if len(cm_df):
            labels = sorted(pd.unique(pd.concat([cm_df.iloc[:,0], cm_df.iloc[:,1]])))
            cm = confusion_matrix(cm_df.iloc[:,0], cm_df.iloc[:,1], labels=labels)
            cm_tbl = pd.DataFrame(cm, index=[f"R:{l}" for l in labels], columns=[f"M:{l}" for l in labels])
            cm_tbl.to_csv(PUBDIR / "table_clinval_intervention_confusion_matrix.csv")

            plt.figure(figsize=(max(6, 0.5*len(labels)+2), max(5, 0.5*len(labels)+2)))
            sns.heatmap(cm_tbl, annot=True, fmt="d", cbar=False)
            plt.title("Intervention ‚Äî Majority Reviewer vs Model (counts)")
            plt.tight_layout()
            plt.savefig(PUBDIR / "figure_clinval_intervention_confusion.png", dpi=300)
            plt.close()
    else:
        print("[WARN] No model recommended_intervention predictions loaded; skipping intervention agreement.")

    # =====================================
    # 4) SAVE MERGED ANALYSIS DATASET
    # =====================================
    out_cols = ["Case_ID", "patient_id", "y30", "pred_7d", "pred_30d", "pred_90d", "pred_180d",
                "recommended_intervention"] + risk_cols + int_cols
    keep = [c for c in out_cols if c in df.columns]
    df[keep].to_csv(PUBDIR / "clinical_validation_with_patient_and_model.csv", index=False)

    print("‚úì Clinical validation analysis complete. Outputs written to /publication_outputs.")


if __name__ == "__main__":
    main()


In [None]:
import matplotlib.pyplot as plt
import os

def plot_flow_diagram(
    output_dir: str = 'publication_outputs',
    n_assessed: int = 157_411,
    n_excluded: int = 1_780,
    n_train: int = 124_506,
    n_test: int = 31_125,
) -> None:
    """
    Generate and save the study flow diagram (Figure 1) using matplotlib.

    The diagram shows the patient cohort selection and splitting process:
    assessed -> excluded -> included -> training/test split. The image is
    written to ``<output_dir>/figure1_study_flow_diagram.png`` at 300 dpi and
    the saved path is printed.

    Args:
        output_dir: Directory that receives the PNG; created if missing.
        n_assessed: Number of patients assessed for eligibility.
        n_excluded: Number excluded (missing birthDate or gender).
        n_train: Size of the training set.
        n_test: Size of the test set.

    Raises:
        ValueError: If ``n_train + n_test`` does not equal
            ``n_assessed - n_excluded``, since the diagram would then show
            inconsistent counts.
    """
    # Derive the included count so it can never drift from the other numbers.
    n_included = n_assessed - n_excluded
    if n_train + n_test != n_included:
        raise ValueError(
            f"n_train + n_test ({n_train + n_test:,}) must equal "
            f"n_assessed - n_excluded ({n_included:,})"
        )

    # --- Define Output Path ---
    # exist_ok=True avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'figure1_study_flow_diagram.png')

    # --- Define Box Styles ---
    # Main boxes (total assessed, included)
    box_style_main = dict(boxstyle='round,pad=0.5', fc='lightblue', alpha=0.7)
    # Exclusion box
    box_style_excluded = dict(boxstyle='round,pad=0.5', fc='lightcoral', alpha=0.7)
    # Final split boxes (train/test)
    box_style_split = dict(boxstyle='round,pad=0.5', fc='lightyellow', alpha=0.7)

    # --- Text boxes: (x, y, label, style), in axes data coordinates ---
    boxes = [
        (0.5, 0.9, f'Assessed for eligibility\n(N = {n_assessed:,})', box_style_main),
        (0.5, 0.7, f'Excluded (n = {n_excluded:,})\n- Missing birthDate or gender',
         box_style_excluded),
        (0.5, 0.5, f'Included in analysis\n(N = {n_included:,})', box_style_main),
        (0.25, 0.3, f'Training Set\n(n = {n_train:,})', box_style_split),
        (0.75, 0.3, f'Test Set\n(n = {n_test:,})', box_style_split),
    ]

    # --- Arrows: (x, y, dx, dy) ---
    arrows = [
        (0.5, 0.85, 0, -0.1),     # 'Assessed' -> 'Excluded'
        (0.5, 0.65, 0, -0.1),     # 'Excluded' -> 'Included'
        (0.5, 0.45, -0.2, -0.1),  # 'Included' -> 'Training Set'
        (0.5, 0.45, 0.2, -0.1),   # 'Included' -> 'Test Set'
    ]

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        for x, y, label, style in boxes:
            ax.text(x, y, label, ha='center', va='center', bbox=style, fontsize=12)

        for x, y, dx, dy in arrows:
            ax.arrow(x, y, dx, dy, head_width=0.02, head_length=0.02,
                     fc='black', ec='black')

        # --- Final Plot Adjustments ---
        # Remove the axis ticks and frame for a cleaner look
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)

        # Set the limits of the plot to ensure everything fits
        ax.set_ylim(0.2, 1)
        ax.set_xlim(0, 1)

        # Operate on this figure's own axes rather than pyplot's "current" state.
        ax.set_title('Figure 1: Study Flow Diagram', fontsize=16)

        # bbox_inches='tight' trims excess whitespace; dpi=300 gives a
        # resolution suitable for publication.
        fig.savefig(output_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"Figure 1 successfully saved to: {output_path}")

# --- Run the function to generate and save the plot ---
# This call sits at top level, so the figure is written as soon as this code
# runs (including on import if this code is ever imported as a module).
plot_flow_diagram()
