# In [None]:
"""
Trustworthy-AI ReAct-Style Model Judge (HF Transformers version; CPU-safe)
- Deterministic selection: hard constraints + composite score.
- LLM Audit Note via Hugging Face `pipeline()` (TinyLlama by default).
- Works in Jupyter or CLI.

Run (Jupyter): just execute the cell.
Run (CLI):  python judge_react_hf.py --policy default --models_path artifacts/models.json

If the HF model load fails (env issues, missing pkgs, no token), you'll still get a sane
deterministic audit note (fallback).
"""

import os, sys, json, argparse, traceback
import math
from pathlib import Path

# SECTION 0 — MODELS & POLICY
# Built-in demo metrics, judged by main() when --models_path does not exist.
# Per-entry schema (as read by check_constraints / score_model / summarize):
#   model_name: display name
#   sets: {"train" | "val" | "test": {accuracy, roc_auc, recall_pos (sensitivity),
#          recall_neg (specificity), f1_weighted, n}}
#   extras (optional, absent here): fairness / calibration / uncertainty / dca /
#          stability / pr_auc values, used only when present.
# Only "test" drives the hard constraints and scoring; "train" feeds the overfit
# gap (train minus test ROC-AUC). "val" is not read by the judge.
# NOTE(review): XGBoost/CatBoost report test n=154 while the other models report
# n=122, i.e. different test splits -- pass --cohort_n to compare like-for-like.
# NOTE(review): RandomForest's val metrics are ~train-level (AUC 0.9997) while its
# test AUC drops to 0.72 -- possible leakage into the val split; confirm upstream.
DEFAULT_MODELS = [
  {
    "model_name": "XGBoost",
    "sets": {
      "train": {"accuracy": 0.8087, "roc_auc": 0.8880, "recall_pos": 0.81, "recall_neg": 0.81, "f1_weighted": 0.81, "n": 460},
      "val":   {"accuracy": 0.6948, "roc_auc": 0.6785, "recall_pos": 0.75, "recall_neg": 0.59, "f1_weighted": 0.70, "n": 154},
      "test":  {"accuracy": 0.7208, "roc_auc": 0.7765, "recall_pos": 0.77, "recall_neg": 0.63, "f1_weighted": 0.72, "n": 154}
    }
  },
  {
    "model_name": "CatBoost",
    "sets": {
      "train": {"accuracy": 0.7543, "roc_auc": 0.8619, "recall_pos": 0.72, "recall_neg": 0.83, "f1_weighted": 0.76, "n": 460},
      "val":   {"accuracy": 0.6429, "roc_auc": 0.6630, "recall_pos": 0.66, "recall_neg": 0.61, "f1_weighted": 0.65, "n": 154},
      "test":  {"accuracy": 0.7100, "roc_auc": 0.7600, "recall_pos": 0.72, "recall_neg": 0.69, "f1_weighted": 0.71, "n": 154}
    }
  },
  {
    "model_name": "LogisticRegression",
    "sets": {
      "train": {"accuracy": 0.8055, "roc_auc": 0.8491, "recall_pos": 0.88, "recall_neg": 0.70, "f1_weighted": 0.80, "n": 365},
      "val":   {"accuracy": 0.7623, "roc_auc": 0.8569, "recall_pos": 0.89, "recall_neg": 0.59, "f1_weighted": 0.75, "n": 122},
      "test":  {"accuracy": 0.7131, "roc_auc": 0.7973, "recall_pos": 0.79, "recall_neg": 0.61, "f1_weighted": 0.71, "n": 122}
    }
  },
  {
    "model_name": "RandomForest",
    "sets": {
      "train": {"accuracy": 0.9890, "roc_auc": 0.9996, "recall_pos": 1.00, "recall_neg": 0.98, "f1_weighted": 0.99, "n": 365},
      "val":   {"accuracy": 0.9836, "roc_auc": 0.9997, "recall_pos": 1.00, "recall_neg": 0.96, "f1_weighted": 0.98, "n": 122},
      "test":  {"accuracy": 0.6967, "roc_auc": 0.7216, "recall_pos": 0.85, "recall_neg": 0.49, "f1_weighted": 0.69, "n": 122}
    }
  },
  {
    "model_name": "DeepNeuralNetwork",
    "sets": {
      "train": {"accuracy": 0.7753, "roc_auc": 0.8559, "recall_pos": 0.87, "recall_neg": 0.64, "f1_weighted": 0.77, "n": 365},
      "val":   {"accuracy": 0.7049, "roc_auc": 0.7028, "recall_pos": 0.82, "recall_neg": 0.55, "f1_weighted": 0.70, "n": 122},
      "test":  {"accuracy": 0.6393, "roc_auc": 0.6708, "recall_pos": 0.73, "recall_neg": 0.51, "f1_weighted": 0.64, "n": 122}
    }
  }
]

# Named judging policies, selected with --policy.
#   hard:    thresholds checked by check_constraints on the *test* metrics:
#            min_recall_pos (sensitivity floor) and min_roc_auc; check_constraints
#            also recognises max_eq_odds_gap, slope_min, slope_max and
#            min_conformal_coverage, enforced only when the model supplies extras.
#   weights: points per normalised component in score_model. Extras-based weights
#            only add points for models that actually provide that extra value.
POLICIES = {
  "default": {
    "notes": "Sensitivity and AUC are required. Optional trustworthiness terms if provided.",
    "hard": {"min_recall_pos": 0.70, "min_roc_auc": 0.70},
    "weights": {
      "auc": 50, "f1": 20, "specificity": 10, "accuracy": 10, "overfit_penalty": 10,
      "calib_ece": 3, "calib_slope_dev": 2, "fairness_eq_odds_gap": 10,
      "pr_auc": 5, "dca_net_benefit": 5, "conformal_coverage": 5, "conformal_set_size": 3,
      # NOTE(review): the stress_* weights below are not read by score_model, so
      # they currently have no effect on the ranking.
      "stress_missing10": 2, "stress_missing30": 3, "stress_labelnoise5": 2,
      "stability_auc_sd": 3
    }
  },
  # Stricter sensitivity floor; only core test-metric weights, so extras never score.
  "high_sensitivity": {
    "notes": "Stricter sensitivity floor.",
    "hard": {"min_recall_pos": 0.80, "min_roc_auc": 0.70},
    "weights": {"auc": 45, "f1": 20, "specificity": 10, "accuracy": 10, "overfit_penalty": 15}
  }
}

# SECTION 1 — DETERMINISTIC JUDGE
def _mean(xs): return sum(xs)/len(xs) if xs else 0.0
def _stdev(xs):
  if len(xs) < 2: return 0.0
  mu = _mean(xs)
  return ((sum((x-mu)**2 for x in xs)/(len(xs)-1))**0.5)

def check_constraints(model, policy):
  """Apply a policy's hard constraints to a model's *test* metrics.

  Core floors (``min_recall_pos``, ``min_roc_auc``) are evaluated on
  ``model["sets"]["test"]``. If a floor is configured but the metric is absent
  or None, that is reported as a violation: the policy treats sensitivity and
  AUC as required, so a model lacking them must not slip through the gate.

  Optional trustworthiness constraints read ``model["extras"]`` and are only
  enforced when the corresponding value is present:
    - ``max_eq_odds_gap``           -> extras["fairness"]["equalized_odds_gap"]
    - ``slope_min`` / ``slope_max`` -> extras["calibration"]["slope"]
    - ``min_conformal_coverage``    -> extras["uncertainty"]["coverage"]

  Args:
    model: dict with ``"sets"`` -> ``"test"`` metrics and optional ``"extras"``.
    policy: dict with a ``"hard"`` mapping of constraint name -> threshold.

  Returns:
    ``{"pass": bool, "violations": [str, ...]}``; violations appear in the
    order the checks run (core floors first, then extras).
  """
  test = model["sets"]["test"]
  hard = policy["hard"]
  violations = []

  # Required core metrics: (policy key, test-metric key, label, number format).
  for floor_key, metric_key, label, fmt in (
      ("min_recall_pos", "recall_pos", "sensitivity", ".2f"),
      ("min_roc_auc", "roc_auc", "ROC-AUC", ".3f"),
  ):
    if floor_key not in hard:
      continue
    floor = hard[floor_key]
    value = test.get(metric_key)
    if value is None:
      violations.append(f"{label} missing (floor {floor:{fmt}})")
    elif value < floor:
      violations.append(f"{label} {value:{fmt}} < {floor:{fmt}}")

  # Optional trustworthiness extras; `or {}` also tolerates explicit None values.
  extras = model.get("extras", {}) or {}
  fair   = extras.get("fairness", {}) or {}
  calib  = extras.get("calibration", {}) or {}
  uncert = extras.get("uncertainty", {}) or {}

  eo_gap = fair.get("equalized_odds_gap")
  if "max_eq_odds_gap" in hard and eo_gap is not None and eo_gap > hard["max_eq_odds_gap"]:
    violations.append(f"EO gap {eo_gap:.3f} > {hard['max_eq_odds_gap']:.3f}")

  slope = calib.get("slope")
  if slope is not None:
    if "slope_min" in hard and slope < hard["slope_min"]:
      violations.append(f"calibration slope {slope:.2f} < {hard['slope_min']:.2f}")
    if "slope_max" in hard and slope > hard["slope_max"]:
      violations.append(f"calibration slope {slope:.2f} > {hard['slope_max']:.2f}")

  coverage = uncert.get("coverage")
  if "min_conformal_coverage" in hard and coverage is not None and coverage < hard["min_conformal_coverage"]:
    violations.append(f"conformal coverage {coverage:.2f} < {hard['min_conformal_coverage']:.2f}")

  return {"pass": len(violations) == 0, "violations": violations}

def overfit_gap(model):
  """Return how far ROC-AUC drops from train to test, floored at 0.0.

  Yields 0.0 when either split or its ``roc_auc`` value is unavailable.
  """
  splits = model["sets"]
  train_auc = splits.get("train", {}).get("roc_auc")
  test_auc = splits.get("test", {}).get("roc_auc")
  if train_auc is None or test_auc is None:
    return 0.0
  return max(0.0, train_auc - test_auc)

def _sigmoid_norm(x, mean, sd):
  if x is None: return 0.0
  if sd and sd > 0:
    z = (x - mean)/sd
    return 1/(1 + 2.718281828**(-z))
  return max(0.0, min(1.0, (x - 0.5)/0.5))

def score_model(model, policy, cohort_baselines):
  """Compute the composite weighted score for a model that passed the hard gate.

  Core terms use the model's *test* metrics:
    - AUC and weighted F1: logistic-normalised against ``cohort_baselines``
      (keys ``auc_mean``/``auc_sd``/``f1_mean``/``f1_sd``) via ``_sigmoid_norm``.
    - specificity (``recall_neg``) and accuracy: used raw, 0.0 when None.
    - overfit term: 1.0 at a zero train->test AUC gap, falling linearly to 0.0
      at a gap of 0.30 or more. Despite the weight name ``overfit_penalty``, it
      *rewards* small gaps rather than subtracting points.

  Optional ``model["extras"]`` terms are added only when BOTH the policy has the
  weight and the model supplies the value; each value is first turned into a
  score where 1.0 is best.
  NOTE(review): models without extras are never penalised, but models with
  extras can only gain points, so totals are not strictly comparable across
  models with different extras coverage -- confirm this is intended.
  NOTE(review): ``stress_*`` weights present in the default policy are not read
  here and therefore have no effect.

  Args:
    model: dict with ``sets`` (train/test metrics) and optional ``extras``.
    policy: dict whose ``weights`` map component name -> points.
    cohort_baselines: dict from ``summarize`` with AUC/F1 mean and sd.

  Returns:
    ``{"total": float, "subs": dict}``; ``subs`` holds each weighted component
    plus raw diagnostics (``overfit_gap`` and, when used, ``ece``,
    ``slope_dev``, ``eq_odds_gap``).
  """
  w = policy["weights"]
  t = model["sets"]["test"]
  extras = model.get("extras", {}) or {}

  auc  = t.get("roc_auc"); f1 = t.get("f1_weighted"); spec = t.get("recall_neg"); acc = t.get("accuracy")
  gap  = overfit_gap(model)

  # Core components, each roughly on a 0..1 scale before weighting.
  auc_comp = _sigmoid_norm(auc, cohort_baselines["auc_mean"], cohort_baselines["auc_sd"])
  f1_comp  = _sigmoid_norm(f1,  cohort_baselines["f1_mean"],  cohort_baselines["f1_sd"])
  spec_comp= 0.0 if spec is None else spec
  acc_comp = 0.0 if acc is None else acc
  pen      = max(0.0, 1.0 - min(gap/0.30, 1.0))  # gap 0..0.3 -> 1..0

  # Missing weights default to 0 so a policy may omit any core term.
  total = (w.get("auc",0)*auc_comp + w.get("f1",0)*f1_comp +
           w.get("specificity",0)*spec_comp + w.get("accuracy",0)*acc_comp +
           w.get("overfit_penalty",0)*pen)

  subs = {
    "auc_component": auc_comp * w.get("auc",0),
    "f1_component": f1_comp * w.get("f1",0),
    "specificity_component": spec_comp * w.get("specificity",0),
    "accuracy_component": acc_comp * w.get("accuracy",0),
    "overfit_component": pen * w.get("overfit_penalty",0),
    "overfit_gap": gap
  }

  # Optional extras (won't matter unless provided)
  calib = (extras.get("calibration", {}) or {})
  ece = calib.get("ece"); slope = calib.get("slope")
  # Non-numeric slope values are ignored (slope_dev stays None).
  slope_dev = abs(1 - slope) if isinstance(slope, (int,float)) else None
  if "calib_ece" in w and ece is not None:
    # ECE 0 -> 1.0, ECE >= 0.20 -> 0.0
    ece_score = max(0.0, 1.0 - min(ece/0.20, 1.0))
    total += w["calib_ece"] * ece_score; subs["calib_ece_component"] = ece_score * w["calib_ece"]; subs["ece"] = ece
  if "calib_slope_dev" in w and slope_dev is not None:
    # |1 - slope| 0 -> 1.0, >= 0.30 -> 0.0
    slope_score = max(0.0, 1.0 - min(slope_dev/0.30, 1.0))
    total += w["calib_slope_dev"] * slope_score; subs["calib_slope_component"] = slope_score * w["calib_slope_dev"]; subs["slope_dev"] = slope_dev

  fair = (extras.get("fairness", {}) or {})
  eq_odds_gap = fair.get("equalized_odds_gap")
  if "fairness_eq_odds_gap" in w and eq_odds_gap is not None:
    # equalized-odds gap 0 -> 1.0, >= 0.20 -> 0.0
    fair_score = max(0.0, 1.0 - min(eq_odds_gap/0.20, 1.0))
    total += w["fairness_eq_odds_gap"] * fair_score; subs["fairness_component"] = fair_score * w["fairness_eq_odds_gap"]; subs["eq_odds_gap"] = eq_odds_gap

  # pr_auc is read from the top level of extras (not nested like the others); clipped to [0, 1].
  pr_auc = extras.get("pr_auc")
  if "pr_auc" in w and pr_auc is not None:
    pr_comp = max(0.0, min(1.0, pr_auc))
    total += w["pr_auc"] * pr_comp; subs["pr_auc_component"] = pr_comp * w["pr_auc"]

  dca = (extras.get("dca", {}) or {})
  nb = dca.get("avg_net_benefit_10_30")
  if "dca_net_benefit" in w and nb is not None:
    # net benefit -0.01 -> 0.0, 0.05 -> 1.0 (clipped)
    nb_score = max(0.0, min(1.0, (nb + 0.01)/0.06))
    total += w["dca_net_benefit"] * nb_score; subs["dca_component"] = nb_score * w["dca_net_benefit"]

  un = (extras.get("uncertainty", {}) or {})
  cov = un.get("coverage"); setsize = un.get("avg_set_size")
  if "conformal_coverage" in w and cov is not None:
    # coverage 0.80 -> 0.0, 1.00 -> 1.0 (clipped)
    cov_score = max(0.0, min(1.0, (cov - 0.8)/0.2))
    total += w["conformal_coverage"] * cov_score; subs["conformal_coverage_component"] = cov_score * w["conformal_coverage"]
  if "conformal_set_size" in w and setsize is not None:
    # average set size 1.0 -> 1.0, 1.5 -> 0.0 (clipped)
    ss_score = max(0.0, min(1.0, 1.0 - ((setsize - 1.0)/0.5)))
    total += w["conformal_set_size"] * ss_score; subs["conformal_set_size_component"] = ss_score * w["conformal_set_size"]

  stab = (extras.get("stability", {}) or {})
  auc_sd = stab.get("auc_sd_over_seeds")
  if "stability_auc_sd" in w and auc_sd is not None:
    # AUC sd across seeds 0 -> 1.0, >= 0.02 -> 0.0
    ssd = max(0.0, 1.0 - min(auc_sd/0.02, 1.0))
    total += w["stability_auc_sd"] * ssd; subs["stability_auc_sd_component"] = ssd * w["stability_auc_sd"]

  return {"total": float(total), "subs": subs}

def _reject_explanation(model_name, t, policy, violations):
  hard = policy["hard"]
  parts = []
  sens = t.get("recall_pos"); auc = t.get("roc_auc")
  if any("sensitivity" in v for v in violations) and sens is not None and "min_recall_pos" in hard:
    parts.append(f"sensitivity {sens:.2f} fell below the floor {hard['min_recall_pos']:.2f}")
  if any("ROC-AUC" in v for v in violations) and auc is not None and "min_roc_auc" in hard:
    parts.append(f"AUC {auc:.3f} < minimum {hard['min_roc_auc']:.2f}")
  if not parts:
    parts = [", ".join(violations)]
  return f"- {model_name}: " + "; ".join(parts) + "."

def _top_two_explanation(rank_items):
  if len(rank_items) < 2: return None
  r1, r2 = rank_items[0], rank_items[1]
  m1, t1 = r1["model_name"], r1["test"]; m2, t2 = r2["model_name"], r2["test"]
  why1 = (f"{m1} ranked #1 for higher sensitivity ({t1.get('recall_pos',0):.2f} vs {t2.get('recall_pos',0):.2f}), "
          f"and better discrimination (AUC {t1.get('roc_auc',0):.3f} vs {t2.get('roc_auc',0):.3f}, "
          f"F1w {t1.get('f1_weighted',0):.2f} vs {t2.get('f1_weighted',0):.2f}).")
  why2 = (f"{m2} ranked #2 with better specificity ({t2.get('recall_neg',0):.2f} vs {t1.get('recall_neg',0):.2f}), "
          f"reducing false positives.")
  trade = ("Trade-off: choose #1 for early identification (sensitivity); "
           "choose #2 to reduce workload (specificity).")
  return why1, why2, trade

def summarize(models, policy_name, cohort_n=None):
  """Run the deterministic judge over ``models`` under a named policy.

  When ``cohort_n`` is given, only models whose test-set ``n`` equals it are
  kept. Cohort baselines (mean and sample sd of test ROC-AUC and weighted F1)
  are computed over every kept model -- including ones later rejected -- and
  feed the normalisation in ``score_model``. Each model is then either
  rejected (hard-constraint violation, with an explanation) or scored, and the
  survivors are sorted best-first.

  Returns:
    dict with keys ``ranking``, ``rejects``, ``policy``, ``cohort``, ``cohort_n``.
  """
  policy = POLICIES[policy_name]
  if cohort_n is not None:
    models = [entry for entry in models if entry["sets"]["test"].get("n") == cohort_n]

  def collect(metric):
    # Test-set values of `metric` in model order, skipping absent/None entries.
    found = []
    for entry in models:
      value = entry["sets"]["test"].get(metric)
      if value is not None:
        found.append(value)
    return found

  auc_values = collect("roc_auc")
  f1_values = collect("f1_weighted")
  cohort = {
    "auc_mean": _mean(auc_values),
    "auc_sd": _stdev(auc_values),
    "f1_mean": _mean(f1_values),
    "f1_sd": _stdev(f1_values),
  }

  ranking = []
  rejects = []
  for entry in models:
    verdict = check_constraints(entry, policy)
    name = entry["model_name"]
    test = entry["sets"]["test"]
    if verdict["pass"]:
      scored = score_model(entry, policy, cohort)
      ranking.append({
        "model_name": name,
        "score": scored["total"],
        "subs": scored["subs"],
        "test": test,
      })
    else:
      rejects.append({
        "model_name": name,
        "violations": verdict["violations"],
        "test": test,
        "explanation": _reject_explanation(name, test, policy, verdict["violations"]),
      })

  ranking = sorted(ranking, key=lambda item: item["score"], reverse=True)
  return {"ranking": ranking, "rejects": rejects, "policy": policy, "cohort": cohort, "cohort_n": cohort_n}

# SECTION 2 — LLM AUDIT (HF TRANSFORMERS, CPU)
# System instruction placed at the top of the plain-text prompt assembled by
# _build_prompt(); it confines the LLM to narrating the deterministic judge's
# output rather than re-deciding it.
SYSTEM_PROMPT = (
  "You are a Trustworthy-AI auditor for mental-health risk prediction. "
  "Use only the provided metrics and ranking. Do not re-rank or change thresholds."
)

# User-turn template. _build_prompt() fills it with str.format(), supplying
# {policy_json} (the policy dict) and {judge_json} (the top-k ranking entries).
# Any literal brace added to this text must be doubled ({{ / }}) or format() fails.
USER_TEMPLATE = """Context:
- Task: select a model to predict progression from mild to moderate/severe depression within 24 months.
- Operating point: frozen from validation; test metrics are final.
- Clinical priority: sensitivity floor applies.

Policy (hard + weights):
{policy_json}

Deterministic judge output (top-k JSON):
{judge_json}

Instructions (follow exactly):
1) Do NOT re-rank models or suggest threshold/training changes.
2) In 5–7 bullet points, explain:
   - Why Rank #1 is selected (cite exact numbers).
   - How Rank #2 compares and the trade-off (sensitivity vs specificity).
   - Material risks (calibration/fairness unknowns, overfit gap, dataset shift).
   - What checks could flip the decision (external validation, DCA, subgroup parity).
3) ≤1200 characters. No tables. No PHI.
"""

def _fallback_plain_audit(ranking):
  if not ranking:
    return "Audit note: no surviving models after hard constraints; revisit floors/thresholds."
  r1 = ranking[0]; t1 = r1["test"]; m1 = r1["model_name"]
  out = [f"Selected {m1} for sensitivity {t1.get('recall_pos',0):.2f} and AUC {t1.get('roc_auc',0):.3f} "
         f"(Acc {t1.get('accuracy',0):.2f}, F1w {t1.get('f1_weighted',0):.2f})."]
  if len(ranking) > 1:
    r2 = ranking[1]; t2 = r2["test"]; m2 = r2["model_name"]
    out.append(f"{m2} is next-best with specificity {t2.get('recall_neg',0):.2f} "
               f"(AUC {t2.get('roc_auc',0):.3f}, Sens {t2.get('recall_pos',0):.2f}).")
    out.append("Trade-off: prefer #1 for early ID (sensitivity) vs #2 for workload (specificity).")
  out.append("Risks: calibration/fairness unknowns; potential overfit; dataset shift.")
  out.append("Next checks: external holdouts (OCHIN/MedStar), DCA/net benefit, subgroup parity.")
  return "- " + "\n- ".join(out)

def _build_prompt(policy, ranking, top_k=2):
  """Render the plain-text System/User/Assistant prompt for the audit LLM.

  Only the first ``top_k`` ranked models (always at least one) are embedded as
  JSON. The prompt ends with an open bullet (``"- "``) to steer the model into
  bullet-style output.
  """
  shown = ranking[:max(1, top_k)]
  user_message = USER_TEMPLATE.format(
    policy_json=json.dumps(policy, indent=2),
    judge_json=json.dumps({"ranking": shown}, indent=2),
  )
  return "".join(["System:\n", SYSTEM_PROMPT, "\n\nUser:\n", user_message, "\nAssistant:\n- "])

def llm_audit_note_via_hf(policy, ranking, model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                          temperature=0.2, max_new_tokens=280, top_k=2):
  """Generate a short audit note with a Hugging Face text-generation pipeline (CPU).

  The prompt is built from ``policy`` and the top ``top_k`` entries of
  ``ranking`` via ``_build_prompt``; at most 7 non-empty lines of generated
  text are returned, each forced to start as a bullet.

  Login: an interactive ``huggingface_hub.login()`` is attempted only when no
  token is already configured through ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN``
  (huggingface_hub reads those itself), so unattended runs are not blocked on
  a prompt.

  Any failure -- missing transformers, pipeline build error, generation error --
  or an empty generation falls back to ``_fallback_plain_audit``, so callers
  always receive a printable note.

  Args:
    policy: policy dict embedded in the prompt.
    ranking: best-first ranking from ``summarize``.
    model_id: HF Hub model id for the text-generation pipeline.
    temperature: sampling temperature.
    max_new_tokens: generation budget.
    top_k: number of ranked models to include in the prompt.

  Returns:
    The audit note as newline-separated bullet text.
  """
  if not ranking:
    return "Audit note: no surviving models after hard constraints; revisit floors/thresholds."

  # Prompt for HF login *before* heavy imports -- unless a token is already set.
  if os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"):
    print("[hf] token found in environment; skipping interactive login")
  else:
    try:
      from huggingface_hub import login
      print("[hf] Paste your Hugging Face token, then press Enter.")
      login()  # notebook widget or terminal prompt
    except Exception as e:
      print(f"[hf] login warning: {e} (continuing without explicit login)")

  prompt = _build_prompt(policy, ranking, top_k=top_k)

  # Lazy import transformers; keep CPU strict
  try:
    os.environ["TRANSFORMERS_NO_TF"] = "True"
    import torch
    torch.set_num_threads(1)  # predictable CPU usage
  except Exception as e:
    print(f"[env] torch unavailable: {e}")

  try:
    from transformers import pipeline
  except Exception as e:
    print(f"[llm] transformers import failed: {e}")
    print("[llm] If needed: pip install -U transformers accelerate safetensors sentencepiece")
    return _fallback_plain_audit(ranking)

  try:
    print(f"[llm] Loading pipeline (CPU): {model_id}")
    pipe = pipeline(
      "text-generation",
      model=model_id,
      device_map="cpu",
      torch_dtype="auto",
    )
  except Exception as e:
    print(f"[llm] pipeline build failed: {e}")
    traceback.print_exc()
    return _fallback_plain_audit(ranking)

  try:
    outs = pipe(
      prompt,
      max_new_tokens=int(max_new_tokens),
      do_sample=True,
      temperature=float(temperature),
      top_p=0.9,
      truncation=True,
      return_full_text=False,
      pad_token_id=pipe.tokenizer.eos_token_id,
    )
    text = (outs[0]["generated_text"] or "").strip()
    if not text:
      # An empty generation would otherwise be returned as a lone "-" note.
      print("[llm] empty generation; using deterministic fallback")
      return _fallback_plain_audit(ranking)
    if not text.startswith("-"):
      text = "- " + text
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    return "\n".join(lines[:7])
  except Exception as e:
    print(f"[llm] generation failed: {e}")
    traceback.print_exc()
    return _fallback_plain_audit(ranking)

# SECTION 3 — MAIN / ENTRY
def main():
  """CLI / notebook entry point: judge models, print a report, save JSON.

  Steps:
    1. Parse arguments with ``parse_known_args`` so unrecognised arguments
       (e.g. ones a Jupyter kernel places in ``sys.argv``) are ignored.
    2. Load model metrics from ``--models_path``; if that file is missing, warn
       on stderr and judge the built-in ``DEFAULT_MODELS`` demo data instead.
    3. Run ``summarize`` and print rejects, the survivor ranking and the
       top-two explanation.
    4. Produce the LLM (or deterministic fallback) audit note.
    5. Write the machine-readable report, including the audit note, to
       ``--report_path`` (default ``artifacts/judge_report.json``).
  """
  ap = argparse.ArgumentParser()
  ap.add_argument("--policy", type=str, default="default", choices=list(POLICIES.keys()))
  ap.add_argument("--models_path", type=str, default="artifacts/models.json")
  ap.add_argument("--cohort_n", type=int, default=None, help="Only rank models whose test.n equals this value.")
  ap.add_argument("--model_id", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                  help="HF model id for CPU gen (e.g., TinyLlama/TinyLlama-1.1B-Chat-v1.0 or google/gemma-2-2b)")
  ap.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature for the LLM audit note.")
  ap.add_argument("--max_new_tokens", type=int, default=280, help="Generation budget for the LLM audit note.")
  ap.add_argument("--report_path", type=str, default="artifacts/judge_report.json",
                  help="Where to write the machine-readable JSON report.")
  args, _ = ap.parse_known_args()

  models_path = Path(args.models_path)
  if models_path.exists():
    models = json.loads(models_path.read_text(encoding="utf-8"))
  else:
    # Falling back silently would let a mistyped path produce a report on the
    # demo metrics that looks like a real result.
    print(f"[warn] models file {models_path} not found; judging built-in DEFAULT_MODELS demo metrics.",
          file=sys.stderr)
    models = DEFAULT_MODELS

  result = summarize(models, args.policy, cohort_n=args.cohort_n)

  print("# ReAct Judge Report\n")
  print(f"Policy: {args.policy}")
  print(f"Hard constraints: {result['policy']['hard']}")
  print(f"Weights: {result['policy']['weights']}")
  # `is not None` so an explicit --cohort_n 0 is still reported as a filter.
  if result["cohort_n"] is not None:
    print(f"Filtered to test cohort size N={result['cohort_n']}")
  print()

  if result["rejects"]:
    print("## Rejected (hard-constraint violations)")
    for r in result["rejects"]:
      print(r["explanation"])
    print()

  if result["ranking"]:
    print("## Ranking (survivors)")
    for i, r in enumerate(result["ranking"], 1):
      t = r["test"]
      print(f"{i}. {r['model_name']}: TOTAL={r['score']:.2f} | "
            f"AUC={t.get('roc_auc',0):.3f} | F1w={t.get('f1_weighted',0):.2f} | "
            f"Sens={t.get('recall_pos',0):.2f} | Spec={t.get('recall_neg',0):.2f} | "
            f"Acc={t.get('accuracy',0):.2f} | OverfitGap={r['subs'].get('overfit_gap',0):.3f} | "
            f"testN={t.get('n','?')}")
    print()

    expl = _top_two_explanation(result["ranking"])
    if expl:
      why1, why2, trade = expl
      print("## Why Rank #1 and #2")
      print(f"- {why1}")
      print(f"- {why2}")
      print(f"- {trade}")
      print()

  note = llm_audit_note_via_hf(
    result["policy"], result["ranking"],
    model_id=args.model_id,
    temperature=args.temperature,
    max_new_tokens=args.max_new_tokens,
    top_k=2
  )
  print("## LLM Audit Note")
  print(note)
  print()

  # Persist the audit note alongside the judge output so the JSON is complete.
  report_path = Path(args.report_path)
  report_path.parent.mkdir(parents=True, exist_ok=True)
  report = {**result, "audit_note": note}
  report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
  print(f"Saved machine-readable report to {report_path}")

# Entry guard: run main() when executed as a script (python judge_react_hf.py ...).
# Running the notebook cell also reaches main(), since a notebook's top-level
# module is named "__main__" -- matching the docstring's "just execute the cell".
if __name__ == "__main__":
  main()
