In [7]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, f1_score


DATA_PATH = "Clinical_FNIH_merged_all_tables.csv"
df = pd.read_csv(DATA_PATH)

if "ID" not in df.columns:
    raise ValueError("Expected an 'ID' column in the dataset.")
df = df.dropna(subset=["ID"]).copy()


TP_MAP = {
    "V00": "baseline",
    "V01": "12mo",
    "V03": "24mo",
    "V05": "36mo",
    "V06": "48mo",
}


def timepoint_from_col(col: str):
    for k, v in TP_MAP.items():
        if k in col:
            return v
    return None



def maybe_numeric_coerce(frame: pd.DataFrame):
    candidates = []
    for c in frame.columns:
        if c.startswith(("V00", "V01", "V03", "V05", "V06")):
            candidates.append(c)
        elif c.startswith("Labcorp_") or c.startswith("Biomediq_"):
            candidates.append(c)

    for c in candidates:
        # if it's already numeric, keep; else coerce
        if frame[c].dtype == object:
            frame[c] = pd.to_numeric(frame[c], errors="coerce")

maybe_numeric_coerce(df)


CORE = {
    "age":   "V00AGE",
    "sex":   "P02SEX",
    "side":  "SIDE",      
    "bmi":   "P01BMI",
    "race":  "P02RACE",
    "hisp":  "P02HISP",
    "kpmed": "P01KPMEDCV",
}



XR_COLS = [
    "V00XRJSM","V00XRKL","V00XRJSL",
    "V01XRJSM","V01XRKL","V01XRJSL",
    "V03XRJSM","V03XRKL","V03XRJSL",
    "V05XRJSM","V05XRKL","V05XRJSL",
    "V06XRJSM","V06XRKL","V06XRJSL",
]
JSW_COLS = ["V00MCMJSW","V01MCMJSW","V03MCMJSW","V05MCMJSW","V06MCMJSW"]

WOMKP_COLS  = ["V00WOMKP","V01WOMKP","V03WOMKP","V05WOMKP","V06WOMKP"]
WOMADL_COLS = ["V00WOMADL","V01WOMADL","V03WOMADL","V05WOMADL","V06WOMADL"]

LABCORP_COLS = [c for c in df.columns if c.startswith("Labcorp_")]
BIOMEDIQ_COLS = [c for c in df.columns if c.startswith("Biomediq_")]


BIOMEDIQ_SIDE_COL = "Biomediq_SIDE" if "Biomediq_SIDE" in df.columns else None



def compute_weak_labels(data: pd.DataFrame) -> pd.DataFrame:
    y = pd.DataFrame(index=data.index)

    age = pd.to_numeric(data.get(CORE["age"], np.nan), errors="coerce")
    bmi = pd.to_numeric(data.get(CORE["bmi"], np.nan), errors="coerce")
    kpmed = pd.to_numeric(data.get(CORE["kpmed"], np.nan), errors="coerce")

   
    y["mention_age"] = ((age <= 50) | (age >= 70)).fillna(False).astype(int)
    y["mention_bmi"] = ((bmi < 19) | (bmi >= 30) | ((bmi >= 28) & (bmi < 30))).fillna(False).astype(int)
    y["mention_kpmed"] = (kpmed > 0).fillna(False).astype(int)

    
    return y

Y = compute_weak_labels(df)
MENTION_LABELS = Y.columns.tolist()



numeric_cols_all = df.select_dtypes(include=[np.number]).columns.tolist()

LEAKAGE_BLOCKLIST = {
    "mention_age": [CORE["age"]],
    "mention_bmi": [CORE["bmi"]],
    "mention_kpmed": [CORE["kpmed"]],
}

def get_feature_cols_for_label(all_numeric_cols, label):
    blocked = set([c for c in LEAKAGE_BLOCKLIST.get(label, []) if c in all_numeric_cols])
    return [c for c in all_numeric_cols if c not in blocked]

def make_preprocessor(n_components: int):
    return Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
        ("pca", PCA(n_components=n_components, random_state=42)),
    ])

def train_classical_selector(Xtr, ytr, Xte, yte):
    model = LogisticRegression(max_iter=2000, class_weight="balanced")
    model.fit(Xtr, ytr)

    p_te = model.predict_proba(Xte)[:, 1]
    pred_te = (p_te >= 0.5).astype(int)

    if len(np.unique(yte)) < 2:
        return model, {"auc": np.nan, "f1": np.nan, "note": "test label constant"}
    return model, {"auc": roc_auc_score(yte, p_te), "f1": f1_score(yte, pred_te)}

N_COMP = 6

selectors = {}
preprocs = {}
metrics = {}

for label in MENTION_LABELS:
    feat_cols = get_feature_cols_for_label(numeric_cols_all, label)
    if len(feat_cols) < 5:
        selectors[label] = None
        metrics[label] = {"auc": np.nan, "f1": np.nan, "note": "too few features"}
        continue

    X_raw = df[feat_cols].copy()
    X_train, X_test, Y_train, Y_test = train_test_split(X_raw, Y, test_size=0.2, random_state=42)

    ytr = Y_train[label].astype(int).values
    yte = Y_test[label].astype(int).values

    if len(np.unique(ytr)) < 2:
        selectors[label] = None
        metrics[label] = {"auc": np.nan, "f1": np.nan, "note": "constant label, skipped"}
        continue

    n_comp = min(N_COMP, X_raw.shape[1])
    preproc = make_preprocessor(n_comp)
    Xtr = preproc.fit_transform(X_train)
    Xte = preproc.transform(X_test)

    print(f"\nTraining classical selector for {label} ...")
    model, m = train_classical_selector(Xtr, ytr, Xte, yte)

    selectors[label] = model
    preprocs[label] = (preproc, feat_cols)
    metrics[label] = m

print("\nSelector metrics (test):")
for k, v in metrics.items():
    print(f"{k:12s} AUC={v.get('auc', np.nan):.3f} F1={v.get('f1', np.nan):.3f} {v.get('note','')}")



def safe_int(x):
    try:
        if pd.isna(x): return None
        return int(round(float(x)))
    except Exception:
        return None

def safe_float(x, nd=3):
    try:
        if pd.isna(x): return None
        return round(float(x), nd)
    except Exception:
        return None

def decode_sex(x):
    if pd.isna(x): return None
    s = str(x).strip()
    if s in {"1", "1.0"}: return "male"
    if s in {"2", "2.0"}: return "female"
    return s

def decode_hisp(x):
    if pd.isna(x): return None
    s = str(x).strip()
    if s in {"0", "0.0"}: return "non-Hispanic"
    if s in {"1", "1.0"}: return "Hispanic"
    return s

def clean_feature_name(col: str) -> str:
    
    if col.startswith("Labcorp_"):
        col = col.replace("Labcorp_", "", 1)
    if col.startswith("Biomediq_"):
        col = col.replace("Biomediq_", "", 1)

    
    col = col.replace("_", " ")

    
    for k in TP_MAP.keys():
        col = col.replace(k, "")
    col = " ".join(col.split())
    return col


def format_kv(col: str, val) -> str:
    v = safe_float(val, nd=3)
    if v is None:
        
        if pd.isna(val): return None
        return f"{clean_feature_name(col)}={val}"
    return f"{clean_feature_name(col)}={v}"



def predict_mentions_for_rows(df_rows: pd.DataFrame) -> pd.DataFrame:
    out = pd.DataFrame(index=df_rows.index)
    for label in MENTION_LABELS:
        if selectors.get(label) is None:
            out[label] = 0
            continue
        preproc, feat_cols = preprocs[label]
        Xr = df_rows[feat_cols].copy()
        Xr_tr = preproc.transform(Xr)
        p = selectors[label].predict_proba(Xr_tr)[:, 1]
        out[label] = (p >= 0.5).astype(int)
    return out



MAX_LABCORP_PER_ROW = 25     
MAX_BIOMEDIQ_PER_ROW = 20

def build_text(row: pd.Series, mention_flags: dict) -> str:
    parts = []

   
    age = safe_int(row.get(CORE["age"], np.nan))
    sex = decode_sex(row.get(CORE["sex"], np.nan))
    bmi = safe_float(row.get(CORE["bmi"], np.nan), nd=1)
    hisp = decode_hisp(row.get(CORE["hisp"], np.nan))
    race = row.get(CORE["race"], np.nan)

    
    side = row.get(CORE["side"], np.nan) if CORE["side"] in row.index else np.nan
    if (pd.isna(side) or side is None) and BIOMEDIQ_SIDE_COL:
        side = row.get(BIOMEDIQ_SIDE_COL, np.nan)

    core_bits = []
    if age is not None: core_bits.append(f"{age}-year-old")
    if sex is not None: core_bits.append(sex)
    if core_bits:
        parts.append(f"Patient is a {' '.join(core_bits)}.")
    else:
        parts.append("Patient baseline summary.")

    if pd.notna(side):
        parts.append(f"Knee side: {side}.")

    
    if mention_flags.get("mention_bmi", 0) == 1 and bmi is not None:
        parts.append(f"BMI: {bmi}.")
    if mention_flags.get("mention_age", 0) == 1 and age is not None:
        parts.append(f"Age noted as {age}.")
    if pd.notna(hisp):
        parts.append(f"Ethnicity: {hisp}.")
    if pd.notna(race):
        parts.append(f"Race code: {race}.")
    if mention_flags.get("mention_kpmed", 0) == 1:
        kpmed = safe_int(row.get(CORE["kpmed"], np.nan))
        if kpmed is not None:
            parts.append("Pain medication use indicated." if kpmed > 0 else "No pain medication use indicated.")


    xr_present = [c for c in XR_COLS if c in row.index and pd.notna(row[c])]
    if xr_present:
        
        tp_blocks = {}
        for c in xr_present:
            tp = timepoint_from_col(c) or "unknown"
            tp_blocks.setdefault(tp, []).append(c)

        for tp in ["baseline", "12mo", "24mo", "36mo", "48mo"]:
            cols = tp_blocks.get(tp, [])
            if not cols:
                continue
            kvs = []
           
            order = ["XRJSM", "XRKL", "XRJSL"]
            for tag in order:
                for c in cols:
                    if tag in c:
                        s = format_kv(c, row[c])
                        if s: kvs.append(s)
            if kvs:
                parts.append(f"X-ray ({tp}): " + ", ".join(kvs) + ".")

   
    jsw_present = [c for c in JSW_COLS if c in row.index and pd.notna(row[c])]
    if jsw_present:
        kvs = []
        for c in jsw_present:
            tp = timepoint_from_col(c) or ""
            v = safe_float(row[c], nd=2)
            if v is not None:
                kvs.append(f"{tp} JSW={v}")
        if kvs:
            parts.append("Medial compartment JSW: " + "; ".join(kvs) + ".")

    # WOMKP / WOMADL over time
    womkp_present = [c for c in WOMKP_COLS if c in row.index and pd.notna(row[c])]
    womadl_present = [c for c in WOMADL_COLS if c in row.index and pd.notna(row[c])]

    if womkp_present:
        kvs = []
        for c in womkp_present:
            tp = timepoint_from_col(c) or ""
            v = safe_float(row[c], nd=1)
            if v is not None:
                kvs.append(f"{tp} WOMKP={v}")
        if kvs:
            parts.append("Pain score (WOMAC Knee Pain): " + "; ".join(kvs) + ".")

    if womadl_present:
        kvs = []
        for c in womadl_present:
            tp = timepoint_from_col(c) or ""
            v = safe_float(row[c], nd=1)
            if v is not None:
                kvs.append(f"{tp} WOMADL={v}")
        if kvs:
            parts.append("Function score (WOMAC ADL): " + "; ".join(kvs) + ".")

   
    lab_present = [c for c in LABCORP_COLS if c in row.index and pd.notna(row[c])]
    if lab_present:
        
        serum = [c for c in lab_present if "Serum" in c]
        urine = [c for c in lab_present if "Urine" in c]
        rest  = [c for c in lab_present if c not in serum and c not in urine]

        ordered = serum + urine + rest
        ordered = ordered[:MAX_LABCORP_PER_ROW]

        kvs = []
        for c in ordered:
            s = format_kv(c, row[c])
            if s: kvs.append(s)

        if kvs:
            parts.append("Lab biomarkers (baseline): " + ", ".join(kvs) + ".")

    
    bio_present = [c for c in BIOMEDIQ_COLS if c in row.index and pd.notna(row[c])]
    if bio_present:
        # cap
        ordered = bio_present[:MAX_BIOMEDIQ_PER_ROW]
        kvs = []
        for c in ordered:
            s = format_kv(c, row[c])
            if s: kvs.append(s)
        if kvs:
            parts.append("Biomediq measures (baseline): " + ", ".join(kvs) + ".")

   
    parts.append(f"(ID: {row.get('ID')})")
    return " ".join(parts).replace("  ", " ").strip()


demo = df.sample(5, random_state=7).copy()
demo_flags = predict_mentions_for_rows(demo)

print("\n--- DEMO GENERATED TEXTS ---")
for idx in demo.index:
    txt = build_text(demo.loc[idx], demo_flags.loc[idx].to_dict())
    print("\n", txt)


all_flags = predict_mentions_for_rows(df)

out = pd.DataFrame({"ID": df["ID"].values})
out["generated_text"] = [build_text(df.loc[i], all_flags.loc[i].to_dict()) for i in df.index]

OUT_PATH = "Clinical_FNIH_generated_text_New.csv"
out.to_csv(OUT_PATH, index=False)

print(f"\nSaved generated texts to: {OUT_PATH}")



Training classical selector for mention_age ...

Training classical selector for mention_bmi ...

Selector metrics (test):
mention_age  AUC=0.499 F1=0.388 
mention_bmi  AUC=0.594 F1=0.699 
mention_kpmed AUC=nan F1=nan constant label, skipped

--- DEMO GENERATED TEXTS ---

 Patient is a 63-year-old 2: Female. Knee side: 1: Right. BMI: 34.2. Age noted as 63. Ethnicity: 0: No. Race code: 1: White or Caucasian. X-ray (baseline): XRJSM=0.0, XRJSL=1.0. X-ray (12mo): XRJSM=0.0, XRJSL=1.0. X-ray (24mo): XRJSM=1.0, XRJSL=1.0. X-ray (36mo): XRJSM=2.0, XRJSL=1.0. X-ray (48mo): XRJSM=2.0, XRJSL=1.0. Medial compartment JSW: baseline JSW=4.67; 12mo JSW=4.71; 24mo JSW=3.9; 36mo JSW=2.8; 48mo JSW=3.12. Pain score (WOMAC Knee Pain): baseline WOMKP=1.0; 12mo WOMKP=1.0; 24mo WOMKP=2.0; 36mo WOMKP=0.0; 48mo WOMKP=1.0. Function score (WOMAC ADL): baseline WOMADL=10.0; 12mo WOMADL=8.0; 24mo WOMADL=17.0; 36mo WOMADL=8.0; 48mo WOMADL=4.0. Lab biomarkers (baseline): Serum C1 2C lc=0.49, Serum C2C lc=208.0, Se