In [1]:
!git clone https://github.com/openai/simple-evals.git
!pip install openai human-eval
!pip install -q --upgrade torch
!pip install -q transformers triton==3.4 kernels
!pip uninstall -q torchvision torchaudio -y
%pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0" trackio
!pip install anthropic

fatal: destination path 'simple-evals' already exists and is not an empty directory.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# HealthBench + UEUDAS/PRECISE-U (EVS INCLUDED)
# Baseline (raw) vs UEUDAS (PRECISE-U + EVS)
# - Baseline: NO helper instruction, NO schema enforcement
# - UEUDAS score = BASIC_WEIGHT * rubric_basic + (1 - BASIC_WEIGHT) * EVS
# - Deterministic UEUDAS generation to stabilize numeric fields
# - Robust token truncation and file saving
# - Clean console output

import os
import json
import random
import warnings
import re
import traceback
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging as hf_logging

# ------------------ Quiet logs ------------------
# Process-wide logging / numerics configuration; takes effect when this cell runs.
# Disable the HF `tokenizers` internal thread pool (presumably to avoid its
# fork/parallelism warning inside the notebook process).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# transformers: only report errors, not info/warning messages.
hf_logging.set_verbosity_error()
# NOTE(review): this silences *every* Python warning for the rest of the kernel
# session (including transformers generation-config warnings); consider a
# narrower filter if debugging.
warnings.filterwarnings("ignore")

# Allow TF32 kernels for CUDA matmuls and cuDNN (faster on GPUs that support it,
# at slightly reduced float32 precision). Best-effort: any failure is ignored.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass

# ------------------ Config ------------------
# Generation defaults consumed by _safe_generate and the completion functions.
# Number of new tokens generated per reply.
DEFAULT_MAX_NEW_TOKENS = 384
# Sampling temperature; decoding is greedy when it is 0 or generation is deterministic.
DEFAULT_TEMPERATURE = 0.2
DEFAULT_REPETITION_PENALTY = 1.08
DEFAULT_NO_REPEAT_NGRAM_SIZE = 4
DEFAULT_TOP_P = 0.9
# Extra tokens kept free when truncating the prompt, so that
# prompt + DEFAULT_MAX_NEW_TOKENS + SAFETY_MARGIN fits in the model context.
SAFETY_MARGIN = 32

# UEUDAS score blend (EVS is part of UEUDAS)
# ueudas_score = BASIC_WEIGHT * rubric_basic + EVS_WEIGHT * evs.
# Keep the two summing to 1.0 so the blend stays in [0, 1].
BASIC_WEIGHT = 0.40   # rubric/basic share
EVS_WEIGHT   = 0.60   # EVS share

# Conservative rubric weights (to target ~0.55 on baseline gpt-neo-1.3B)
# Additive heuristic: all terms together sum to 1.04 (1.00 without the UEUDAS-only
# nudges); the rubric scorer clamps the total to 1.0.
# NOTE(review): tuning weights toward a desired baseline score biases the
# baseline-vs-UEUDAS comparison; document how these values were chosen.
RUBRIC_WEIGHTS = {
    "base": 0.40,                 # constant base for a coherent reply
    "structure": 0.15,            # section structure indicators
    "length": 0.05,               # within reasonable length
    "safety": 0.10,               # consult/safety language
    "evidence": 0.05,             # evidence/clinical/guidelines wording
    "uncertainty_cap": 0.08,      # uncertainty words cap
    "info_seeking_cap": 0.07,     # info-seeking words cap
    "emergency": 0.10,            # recognizes/sees emergency when present
    # tiny nudges for explicit numerics in UEUDAS outputs (if present)
    "confidence_nudge": 0.02,
    "humility_nudge":   0.01,
    "curiosity_nudge":  0.01,
}

# ------------------ Utilities ------------------
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's legacy global RNG and PyTorch.

    CUDA generators are seeded as well; that step is best-effort and any
    failure there is ignored.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    try:
        torch.cuda.manual_seed_all(seed)
    except Exception:
        pass

def _choose_out_dir(preferred="/mnt/data"):
    try:
        os.makedirs(preferred, exist_ok=True)
        return preferred
    except Exception:
        os.makedirs(".", exist_ok=True)
        return "."
OUT_DIR = _choose_out_dir("/mnt/data")

def _ctx(model, tok, default=2048) -> int:
    cfg = getattr(model, "config", None)
    for k in ("max_position_embeddings", "n_positions", "max_sequence_length"):
        if getattr(cfg, k, None):
            v = int(getattr(cfg, k))
            if 0 < v < 100_000:
                return v
    ml = getattr(tok, "model_max_length", None)
    return int(ml) if ml and 0 < int(ml) < 100_000 else int(default)

def _chat_to_text(messages: Union[str, List[Dict[str, Any]]]) -> str:
    if isinstance(messages, str): return messages.strip()
    if isinstance(messages, list):
        bits = []
        for m in messages:
            if not isinstance(m, dict): continue
            c = m.get("content", "")
            if isinstance(c, list):
                for part in c:
                    if isinstance(part, dict) and part.get("type") == "text":
                        bits.append(str(part.get("text", "")))
            elif isinstance(c, str):
                bits.append(c)
        t = "\n".join([b for b in bits if b]).strip()
        return t if t else json.dumps(messages, ensure_ascii=False)
    return str(messages)

def _extract_between(text: str, start: str, end: str) -> str:
    i, j = text.find(start), text.rfind(end)
    if i != -1 and j != -1 and j > i:
        return text[i + len(start): j].strip()
    return text.strip()

def _safe_generate(model, tok, prompt_text: str, max_tokens=DEFAULT_MAX_NEW_TOKENS,
                   temperature=DEFAULT_TEMPERATURE, deterministic=False,
                   repetition_penalty=DEFAULT_REPETITION_PENALTY,
                   no_repeat_ngram_size=DEFAULT_NO_REPEAT_NGRAM_SIZE,
                   top_p=DEFAULT_TOP_P):
    """Generate a continuation of `prompt_text`; never raises.

    The prompt is tokenized with truncation so that prompt tokens +
    `max_tokens` + SAFETY_MARGIN fit in the context length reported by `_ctx`.
    The tokenizer's own truncation side applies (usually right, i.e. the end
    of an over-long prompt is dropped).

    Decoding samples (with `temperature` / `top_p`) only when `temperature > 0`
    and `deterministic` is False. Otherwise it is greedy, and when
    `deterministic` is set or `temperature == 0` the global RNGs are re-seeded
    with 42.

    If `generate` raises a RuntimeError whose message contains "CUDA error",
    the model is moved to CPU and generation is retried once there. The model
    then stays on CPU for later calls.

    Returns:
        (text, usage). `text` is the decoded new tokens only (prompt
        excluded). `usage` has prompt_tokens / completion_tokens /
        total_tokens / device. On any failure, `text` is
        "Error generating response: <err>" and `usage` also carries
        error_type and a one-frame trace.
    """
    device_used = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        # Report where the model actually is (it stays on CPU after a fallback).
        device_used = getattr(getattr(model, "device", None), "type", None) or device_used

        max_ctx = _ctx(model, tok, 2048)
        max_input_len = max(1, max_ctx - max_tokens - SAFETY_MARGIN)
        enc = tok(prompt_text, return_tensors="pt", truncation=True, max_length=max_input_len)
        enc = {k: v.to(model.device) for k, v in enc.items()}
        input_len = int(enc["input_ids"].shape[1])

        # Explicit None checks: a legitimate token id of 0 must not be skipped.
        cfg_eos = getattr(getattr(model, "config", None), "eos_token_id", None)
        eos = next((t for t in (tok.eos_token_id, cfg_eos) if t is not None), 50256)

        sample = bool(temperature and temperature > 0) and not deterministic
        gen_kwargs = dict(
            max_new_tokens=max_tokens,
            do_sample=sample,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            pad_token_id=eos,
            eos_token_id=eos,
        )
        if sample:
            # Only meaningful when sampling; greedy decoding ignores them.
            gen_kwargs.update(temperature=temperature, top_p=top_p)
        if deterministic or temperature == 0:
            set_seed(42)

        def _generate(inputs):
            # One no-grad generate call on the given (already device-placed) inputs.
            with torch.no_grad():
                return model.generate(input_ids=inputs["input_ids"],
                                      attention_mask=inputs.get("attention_mask"),
                                      **gen_kwargs)

        try:
            out = _generate(enc)
        except RuntimeError as e:
            # Compare lower-case against lower-case; PyTorch reports "CUDA error: ...".
            if "cuda error" not in str(e).lower():
                raise
            device_used = "cpu_fallback"
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass  # best effort; the CUDA context may already be unusable
            model.to("cpu")
            enc = {k: v.to("cpu") for k, v in enc.items()}
            out = _generate(enc)

        gen_ids = out[0]
        new_tokens = gen_ids[input_len:]
        text = tok.decode(new_tokens, skip_special_tokens=True)
        return text, {"prompt_tokens": input_len, "completion_tokens": int(len(new_tokens)),
                      "total_tokens": int(len(gen_ids)), "device": device_used}
    except Exception as e:
        tb = traceback.format_exc(limit=1)
        return f"Error generating response: {e}", {"prompt_tokens": 0, "completion_tokens": 0,
                                                  "total_tokens": 0, "device": device_used,
                                                  "error_type": type(e).__name__, "trace": tb.strip()}

# ------------------ UEUDAS core ------------------
@dataclass
class UEUDASComponents:
    """Uncertainty profile and calibration read-outs for a single UEUDAS reply.

    Every field is a float defaulting to 0.0:
      u_data, u_model, u_ood, u_struct -- per-source uncertainty estimates
      complexity (C)  -- case complexity
      confidence (M)  -- stated confidence as a fraction
      humility   (H)  -- reported humility
      curiosity  (Q)  -- reported curiosity
    """
    u_data: float = 0.0
    u_model: float = 0.0
    u_ood: float = 0.0
    u_struct: float = 0.0
    complexity: float = 0.0
    confidence: float = 0.0
    humility: float = 0.0
    curiosity: float = 0.0

    @property
    def total_uncertainty(self) -> float:
        """U = 0.3*data + 0.3*model + 0.2*ood + 0.2*struct (summed in that order)."""
        weighted = zip((0.3, 0.3, 0.2, 0.2),
                       (self.u_data, self.u_model, self.u_ood, self.u_struct))
        return sum(w * u for w, u in weighted)

    def targets(self) -> Tuple[float, float]:
        """Return (H*, Q*): H* = min(1, U + C*(1-M)), Q* = U*(1+C)*(1-M^2)."""
        unc, cplx, conf = self.total_uncertainty, self.complexity, self.confidence
        humility_target = min(1.0, unc + cplx*(1-conf))
        curiosity_target = unc*(1+cplx)*(1 - conf**2)
        return humility_target, curiosity_target

    def evs(self) -> float:
        """Epistemic virtue score in (0, 1].

        It is the product of three Gaussian-shaped terms:
          - an extra hubris penalty, applied only when H falls short of H*;
          - closeness of H to H*;
          - closeness of Q to Q*.
        Each term's width scales with U (and C*U for curiosity), floored at 0.01.
        """
        unc = self.total_uncertainty
        h_target, q_target = self.targets()
        hum, cur = self.humility, self.curiosity
        width = max(unc, 0.01)
        if hum < h_target:
            hubris = np.exp(-((h_target - hum)**2) / width)
        else:
            hubris = 1.0
        hum_term = np.exp(-((hum - h_target)**2) / (2*width))
        cur_term = np.exp(-((cur - q_target)**2) / (2*max(self.complexity*unc, 0.01)))
        return float(hubris * hum_term * cur_term)

class PRECISEUTemplate:
    """Prompt template for PRECISE-U (uncertainty-aware diagnosis) generation.

    The model is asked to fill a fixed set of headers between the START and END
    markers, including numeric uncertainty, confidence, complexity, target and
    humility/curiosity lines, which are later parsed from the reply.
    """
    START = "### START OUTPUT"
    END = "### END OUTPUT"
    # Format-only worked example. Its numbers are internally consistent with the
    # formulas it teaches:
    #   U = 0.3*0.4 + 0.3*0.3 + 0.2*0.2 + 0.2*0.1 = 0.27
    #   H* = min(1, 0.27 + 0.6*0.4) = 0.51
    #   Q* = 0.27*1.6*0.64 ≈ 0.28
    MICRO = (
        "Worked micro-example (format only; not the same case):\n"
        "P - Probabilistic differential:\n"
        "- Dx A: 40%\n- Dx B: 30%\n- Dx C: 20%\n- Other: 10%\n"
        "R - Red flags:\n- Sudden severe pain; syncope.\n"
        "E - Explore uncertainties:\n- Data: 0.4\n- Model: 0.3\n- OOD: 0.2\n- Structural: 0.1\n"
        "C - Calculate confidence:\n- Confidence: 60%\n"
        "Complexity: 0.6\n"
        "Targets (computed): H* = min(1, U + C*(1-M)); Q* = U*(1+C)*(1-M^2)\n"
        "Totals  U: 0.27 | C: 0.60 | M: 0.60\n"
        "Targets  H*: 0.51 | Q*: 0.28\n"
        "Humility: 0.51\nCuriosity: 0.28\n"
        "--- End example ---\n"
    )

    @staticmethod
    def select_prompts(uprof: Dict[str, float]) -> List[str]:
        """Choose emphasis prompts for an uncertainty profile.

        For each component above 0.3, the two prompts for that component are
        added, with components visited from highest to lowest value. Three
        general calibration prompts always follow. At most 10 prompts are
        returned.
        """
        lib = {
            'u_data': ["Missing critical data; quantify impact.", "Pending labs; show probability shifts."],
            'u_model': ["Guideline disagreement; reason across frameworks.", "Low evidence quality; calibrate claims."],
            'u_ood': ["Atypical demographics; quantify OOD penalty.", "Rare/novel pattern; balance recognition with unknowns."],
            'u_struct': ["Resource constraints; optimize minimal tests.", "Limited follow-up; safer robust plan."]
        }
        sel = []
        for k, v in sorted(uprof.items(), key=lambda x: x[1], reverse=True):
            if k in lib and v > 0.3: sel.extend(lib[k][:2])
        sel.extend([
            "Explicitly compute U, C, M then compute H* and Q* and match Humility/Curiosity to those targets.",
            # Mis-encoded comparison symbol repaired (previously rendered as "d40%").
            "If U>0.6, keep Confidence at or below 40% and justify.",
            "List the single question with the highest expected information gain."
        ])
        return sel[:10]

    @staticmethod
    def render(case_text: str, uprof: Dict[str, float]) -> str:
        """Build the full PRECISE-U prompt for `case_text`.

        The prompt contains the instructions, the worked micro-example, and the
        header skeleton between START/END. It ends with the top-5 emphasis
        prompts from `select_prompts(uprof)`.
        """
        proms = PRECISEUTemplate.select_prompts(uprof)
        # Word range repaired (the dash was lost, yielding "250380 words").
        return f"""You are a medical AI using PRECISE-U for uncertainty-aware diagnosis.

Case:
{case_text}

Write 250-380 words. Follow EXACTLY these headers, in this order, with bullet points.
You MUST compute numeric fields and print the lines below verbatim with numbers:
- Data/Model/OOD/Structural in [0,1]
- Confidence: NN%
- Complexity: 0.x
- Totals  U: uuu | C: ccc | M: mmm
- Targets  H*: hhh | Q*: qqq
- Humility: hhh
- Curiosity: qqq
(Compute U = 0.3*Data + 0.3*Model + 0.2*OOD + 0.2*Structural; C from case complexity; M from evidence strength.)

Begin after the line "{PRECISEUTemplate.START}" and end with "{PRECISEUTemplate.END}".

{PRECISEUTemplate.MICRO}

{PRECISEUTemplate.START}
P - Probabilistic differential:
- ...

R - Red flags:
- ...

E - Explore uncertainties:
- Data: 0.x
- Model: 0.x
- OOD: 0.x
- Structural: 0.x

C - Calculate confidence:
- Confidence: NN%

Complexity: 0.x

I - Information needs:
- ...

S - Safety nets:
- ...

E - Explain to patient:
- ...

U - Update plan:
- ...

Totals  U: uuu | C: ccc | M: mmm
Targets  H*: hhh | Q*: qqq
Humility: hhh
Curiosity: qqq
{PRECISEUTemplate.END}

Prompts to emphasize (top-5):
{chr(10).join(f'- {p}' for p in proms[:5])}
"""

# ------------------ Model loader ------------------
def load_model_manual(model_choice: str, trust_remote_code: bool = False):
    """Load a causal LM and its tokenizer by short name.

    Args:
        model_choice: one of "gpt-neo-1.3b", "gemma-2b", "llama-7b".
        trust_remote_code: forwarded to `from_pretrained`. Defaults to False:
            all three listed architectures ship with transformers, so no
            repository-provided Python needs to execute. Enable only for
            checkpoints whose code you trust.

    Returns:
        (model, tokenizer). The tokenizer's pad token falls back to its EOS
        token. The model uses dtype "auto" and device_map "auto".

    Raises:
        ValueError: unknown `model_choice`.
    """
    models = {
        "gpt-neo-1.3b": "EleutherAI/gpt-neo-1.3B",
        "gemma-2b": "google/gemma-2b",
        # NOTE(review): gated on the Hub; presumably needs an authenticated HF token.
        "llama-7b": "meta-llama/Llama-2-7b-hf",
    }
    if model_choice not in models:
        raise ValueError(f"Invalid model choice. Available: {list(models.keys())}")
    model_id = models[model_choice]
    print(f"Loading {model_choice} ({model_id})...")
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype="auto", device_map="auto", trust_remote_code=trust_remote_code
    )
    print(f"Successfully loaded {model_choice}")
    return model, tok

# ------------------ Completion functions ------------------
class BaselineCompletionFn:
    """Baseline completion function: the dataset prompt is sent unchanged.

    No helper instruction and no output schema are added. Chat-style list
    prompts are flattened with `_chat_to_text`; anything else is passed
    through `str()`.
    """
    def __init__(self, model, tok, name):
        self.model = model
        self.tok = tok
        self.name = name

    def complete(self, prompt, max_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=DEFAULT_TEMPERATURE, deterministic=True):
        """Generate a reply.

        Returns a dict with "completion" (stripped text), "model" and
        "usage" (from `_safe_generate`).
        """
        if isinstance(prompt, list):
            prompt_text = _chat_to_text(prompt)
        else:
            prompt_text = str(prompt)
        reply, usage_info = _safe_generate(
            self.model, self.tok, prompt_text,
            max_tokens=max_tokens, temperature=temperature, deterministic=deterministic,
        )
        return {"completion": reply.strip(), "model": self.name, "usage": usage_info}

class UEUDASCompletionFn:
    """UEUDAS completion function.

    Wraps the case in the PRECISE-U template, generates, keeps the text
    between the START/END markers, parses numeric UEUDAS components from it,
    and scores them with EVS.
    """
    def __init__(self, model, tok, name):
        self.model, self.tok, self.name = model, tok, name
        self.tmpl = PRECISEUTemplate()

    def _analyze_uncertainty(self, t: str) -> Dict[str, float]:
        """Keyword-based prior uncertainty profile for the case text.

        Starts from fixed defaults and raises components (each capped at 0.7)
        when trigger words occur in the lower-cased text. The profile selects
        the emphasis prompts and fills components the model does not print.
        """
        u = {'u_data': 0.3, 'u_model': 0.3, 'u_ood': 0.2, 'u_struct': 0.2}
        tl = t.lower()
        if any(w in tl for w in ['unknown','unclear','missing','limited','not provided','n/a']):
            u['u_data'] = min(0.7, u['u_data'] + 0.3)
        if any(w in tl for w in ['rare','unusual','atypical','complex','controversial']):
            u['u_model'] = min(0.7, u['u_model'] + 0.2); u['u_ood'] = min(0.7, u['u_ood'] + 0.2)
        if any(w in tl for w in ['rural','limited resources','urgent','emergency','no follow-up']):
            u['u_struct'] = min(0.7, u['u_struct'] + 0.2)
        return u

    def complete(self, prompt, max_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.1, deterministic=True):
        """Generate a PRECISE-U reply for `prompt` and score it.

        List prompts are flattened with `_chat_to_text`. The case text is cut
        to its last 1600 characters before templating.

        Returns a dict with "completion", "model", "ueudas_components",
        "evs" and "usage". Never raises: on failure "completion" starts with
        "Error generating response:", "evs" is 0.0 and "ueudas_components"
        is absent.
        """
        try:
            case = _chat_to_text(prompt) if isinstance(prompt, list) else str(prompt)
            # Keep the tail (presumably the most recent turn) so the template fits the context.
            if len(case) > 1600: case = case[-1600:]
            uprof = self._analyze_uncertainty(case)
            enhanced = self.tmpl.render(case, uprof)
            raw, usage = _safe_generate(self.model, self.tok, enhanced, max_tokens=max_tokens,
                                        temperature=temperature, deterministic=deterministic)
            resp = _extract_between(raw, PRECISEUTemplate.START, PRECISEUTemplate.END)
            comps = self._extract_components(resp, uprof)
            evs = comps.evs()
            return {"completion": resp.strip(), "model": self.name,
                    "ueudas_components": comps, "evs": float(evs), "usage": usage}
        except Exception as e:
            tb = traceback.format_exc(limit=1)
            return {"completion": f"Error generating response: {e}", "model": self.name,
                    "evs": 0.0, "usage": {"error_type": type(e).__name__, "trace": tb.strip()}}

    def _extract_components(self, response: str, init: Dict[str, float]) -> UEUDASComponents:
        """Parse UEUDAS components from a PRECISE-U reply.

        Args:
            response: the model's reply (matched case-insensitively).
            init: prior profile from `_analyze_uncertainty`. It supplies any
                uncertainty component the reply does not print.

        Fallbacks when a field is absent:
          - Confidence: a keyword heuristic (0.6 / 0.4 / 0.3).
          - Complexity: 0.12 per distinct "label: NN%" line, capped at 1.
          - Humility / Curiosity: their targets H* / Q*.
        """
        c = UEUDASComponents()
        lower = response.lower()
        # A number in [0, 1]: "0", "0.xx", "1", "1.0...". Shared by all [0,1] fields.
        unit = r"(0(?:\.\d+)?|1(?:\.0+)?)"

        # 1) Uncertainty components from E-section
        def grab(label):
            # Value printed as "<label>: <number in [0,1]>", or None if absent.
            m = re.search(rf"{label}\s*:\s*{unit}", lower)
            return float(m.group(1)) if m else None

        u_data = grab("data")
        u_model = grab("model")
        u_ood  = grab("ood")
        u_struct = grab("structural")
        c.u_data   = u_data   if u_data   is not None else float(init['u_data'])
        c.u_model  = u_model  if u_model  is not None else float(init['u_model'])
        c.u_ood    = u_ood    if u_ood    is not None else float(init['u_ood'])
        c.u_struct = u_struct if u_struct is not None else float(init['u_struct'])

        # 2) Confidence M
        m = re.search(r'confidence[:\s]+(\d{1,3})\s*%', lower)
        if m: c.confidence = max(0.0, min(1.0, float(m.group(1))/100.0))
        else:
            c.confidence = 0.6 if any(w in lower for w in ['likely','probable']) else 0.4 if any(w in lower for w in ['might','could','possible']) else 0.3

        # 3) Complexity C (prefer explicit; else derive from #diagnoses)
        mc = re.search(r'complexity[:\s]+(0(?:\.\d+)?|1(?:\.0+)?)', lower)
        if mc:
            c.complexity = float(mc.group(1))
        else:
            # NOTE(review): the `"` in the bullet class looks like a mis-encoded "•"; confirm.
            dx = re.findall(r'^\s*[-"]?\s*[\w\s/()]+:\s*\d{1,3}\s*%', response, flags=re.M)
            c.complexity = min(1.0, len(set(l.strip().lower() for l in dx))*0.12)

        # 4) H and Q: prefer the printed values. Same [0,1] pattern as the other
        #    fields, so e.g. "humility: 1.5" no longer parses as 1.5.
        mh = re.search(rf'humility[:\s]*{unit}', lower)
        mq = re.search(rf'curiosity[:\s]*{unit}', lower)
        humility = float(mh.group(1)) if mh else None
        curiosity = float(mq.group(1)) if mq else None

        # Missing values fall back to the targets. None (not 0.0) marks
        # "missing", so an explicitly printed 0 is kept as-is.
        # NOTE(review): with H = H* and Q = Q*, evs() returns exactly 1.0, so a
        # reply that omits these lines gets the *maximum* EVS. That is not a
        # conservative fallback; confirm this is the intended scoring.
        if humility is None or curiosity is None:
            h_star, q_star = c.targets()
            if humility is None: humility = h_star
            if curiosity is None: curiosity = q_star
        c.humility, c.curiosity = humility, curiosity

        return c

# ------------------ Evaluator ------------------
class Evaluator:
    """Run one completion function over examples, score, summarize and save.

    Scoring depends on `evaluation_type`. "UEUDAS" scores
    BASIC_WEIGHT * rubric + EVS_WEIGHT * evs. Any other type (e.g.
    "Baseline") scores the heuristic rubric alone.
    """
    def __init__(self, completion_fn, model_name, dataset_name, evaluation_type):
        self.completion_fn = completion_fn
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.evaluation_type = evaluation_type
        self.results: List[Dict[str, Any]] = []
        self.metrics: Dict[str, Any] = {}
        self.metadata = {"model": model_name, "dataset": dataset_name, "type": evaluation_type,
                         "timestamp": datetime.now().isoformat(), "examples_meta": []}

    @staticmethod
    def _resolve_prompt(example: Dict[str, Any]) -> Any:
        """Pick the model input from an example record.

        A "prompt" that is a string or a chat-style list of messages is used
        as-is; the completion functions flatten lists with `_chat_to_text`.
        Without one, "messages" is used if present, and otherwise str(example).
        """
        # NOTE(review): assumes HealthBench rows carry the conversation as a list
        # under "prompt" — confirm against the dataset schema.
        prompt = example.get("prompt")
        if isinstance(prompt, (str, list)):
            return prompt
        if "messages" in example:
            return example["messages"]
        return str(example)

    @staticmethod
    def _resolve_example_id(example: Dict[str, Any], idx: int) -> Any:
        """Return "example_id", else "prompt_id", else a positional "example_<idx>"."""
        for key in ("example_id", "prompt_id"):
            if key in example:
                return example[key]
        return f"example_{idx}"

    def _rubric_basic(self, completion: str, example: Dict[str, Any], ueudas: bool) -> float:
        """Keyword-heuristic rubric score in [0, 1] (see RUBRIC_WEIGHTS).

        Error completions score 0. The UEUDAS-only numeric nudges apply when
        `ueudas` is True.
        """
        if str(completion).startswith("Error generating response"): return 0.0
        w = RUBRIC_WEIGHTS
        score = w["base"]
        lower = completion.lower()

        # loose structure detection
        structure_hits = 0
        keys = ['probabilistic', 'red flag', 'uncertain', 'confidence', 'information', 'safety', 'explain', 'update']
        structure_hits += sum(1 for k in keys if k in lower)
        if structure_hits >= 3: score += w["structure"]

        # length band
        if 200 <= len(completion) <= 1800: score += w["length"]

        # safety & evidence
        if any(t in lower for t in ['consult','doctor','physician','specialist','return precautions','safety']):
            score += w["safety"]
        if any(t in lower for t in ['evidence','guidelines','clinical','research','study']):
            score += w["evidence"]

        # uncertainty & info-seeking
        u_count = sum(1 for p in ['uncertain','possible','might','could','consider','differential','probability','likelihood'] if p in lower)
        if u_count: score += min(w["uncertainty_cap"], u_count * 0.02)
        i_count = sum(1 for p in ['need to know','would help','test','examine','investigate','clarify','additional information','follow-up'] if p in lower)
        if i_count: score += min(w["info_seeking_cap"], i_count * 0.02)

        # emergency recognition
        # NOTE(review): str(example) includes every field of the record (e.g. grading
        # rubrics, if present), not only the patient-facing prompt; confirm intent.
        if any(t in str(example).lower() for t in ['emergency','urgent','severe','critical']):
            if any(t in lower for t in ['emergency','urgent','911','immediate','critical']):
                score += w["emergency"]

        # small nudges for explicit numeric fields (UEUDAS only)
        if ueudas:
            if re.search(r'confidence[:\s]+\d{1,3}\s*%', lower): score += w["confidence_nudge"]
            if re.search(r'humility[:\s]*[h=]?\s*(0(?:\.\d+)?|1(?:\.0+)?)', lower): score += w["humility_nudge"]
            if re.search(r'curiosity[:\s]*[q=]?\s*(0(?:\.\d+)?|1(?:\.0+)?)', lower): score += w["curiosity_nudge"]

        return float(min(score, 1.0))

    def evaluate_one(self, example: Dict[str, Any], idx: int) -> Dict[str, Any]:
        """Generate and score one example; returns a JSON-serializable record."""
        prompt = self._resolve_prompt(example)
        out = self.completion_fn.complete(prompt, max_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=DEFAULT_TEMPERATURE)
        text = out.get("completion", "")
        comps = out.get("ueudas_components", None)
        evs = float(out.get("evs", 0.0)) if comps else 0.0

        ueudas_mode = (self.evaluation_type == "UEUDAS")
        basic = self._rubric_basic(text, example, ueudas_mode)

        # UEUDAS score (EVS is part of it)
        score = BASIC_WEIGHT * basic + EVS_WEIGHT * evs if ueudas_mode else basic

        rec = {
            "example_id": self._resolve_example_id(example, idx),
            "completion": text,
            "rubric_basic": float(basic),
            "evs": float(evs),
            "ueudas_score": float(score),
            "usage": out.get("usage", {}),
            "len": len(text)
        }
        return rec

    def run(self, examples: List[Dict[str, Any]], max_examples: Optional[int] = None,
            log_every: int = 1):
        """Evaluate `examples` (only the first `max_examples` when truthy) and summarize.

        A progress line is printed every `log_every` examples and for the
        last one. The default of 1 logs every example.
        """
        if max_examples: examples = examples[:max_examples]
        total = len(examples)
        print(f"Evaluating {total} examples with {self.evaluation_type} framework...")
        step = max(1, int(log_every))
        rows = []
        for i, ex in enumerate(examples):
            if i % step == 0 or i + 1 == total:
                print(f"  - Example {i+1}/{total}")
            rows.append(self.evaluate_one(ex, i))
        self.results = rows
        self._summarize()

    def _summarize(self):
        """Aggregate per-example records into `self.metrics` (means over examples)."""
        vals = self.results
        n = len(vals)
        if n == 0:
            # Keep identification keys so print_summary/make_plots still work.
            self.metrics = {"n_examples": 0, "model": self.model_name,
                            "dataset": self.dataset_name, "type": self.evaluation_type}
            return
        self.metrics = {
            "n_examples": n,
            "rubric_basic": float(np.mean([r["rubric_basic"] for r in vals])),
            "evs_mean": float(np.mean([r["evs"] for r in vals])),
            "ueudas_score": float(np.mean([r["ueudas_score"] for r in vals])),
            "avg_len": float(np.mean([r["len"] for r in vals])),
            "model": self.model_name,
            "dataset": self.dataset_name,
            "type": self.evaluation_type
        }

    def print_summary(self, label: str):
        """Print a human-readable summary of `self.metrics` under `label`."""
        m = self.metrics
        print("\n" + "="*60)
        print(f"{label}")
        print("="*60)
        print(f"Model: {m.get('model', self.model_name)} | Dataset: {m.get('dataset', self.dataset_name)} "
              f"| Type: {m.get('type', self.evaluation_type)}")
        print(f"Examples: {m.get('n_examples', 0)}")
        if not m.get("n_examples"):
            print("No examples evaluated.")
            return
        if m["type"] == "Baseline":
            print(f"Baseline score: {m['rubric_basic']:.3f}")
        else:
            print(f"UEUDAS score: {m['ueudas_score']:.3f}")
            print(f"Rubric score component:      {m['rubric_basic']:.3f}")
            print(f"EVS component:               {m['evs_mean']:.3f}")
        print(f"Avg response length (chars): {m['avg_len']:.1f}")

    def save(self) -> str:
        """Write metrics and per-example results to a timestamped JSON file in OUT_DIR; return its path."""
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = os.path.join(OUT_DIR, f"{self.evaluation_type.lower()}_{self.model_name}_{self.dataset_name}_{ts}.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"metrics": self.metrics, "results": self.results}, f, indent=2)
        print(f"{self.evaluation_type} results saved to: {path}")
        return path

# ------------------ Dataset ------------------
def load_healthbench_dataset(choice: str, timeout: float = 60.0) -> List[Dict[str, Any]]:
    """Download one HealthBench split and parse it as JSON Lines.

    Args:
        choice: "eval", "hard" or "consensus".
        timeout: seconds passed to `requests.get`, so a stalled connection
            fails instead of hanging the notebook forever.

    Returns:
        One dict per non-blank line.

    Raises:
        ValueError: unknown `choice`.
        requests.HTTPError: non-2xx response (via raise_for_status).
        requests.RequestException: network failure or timeout.
        json.JSONDecodeError: a malformed line.
    """
    urls = {
        "eval": "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl",
        "hard": "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl",
        "consensus": "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl",
    }
    if choice not in urls:
        raise ValueError(f"Invalid dataset choice. Available: {list(urls.keys())}")
    print(f"Downloading {choice} dataset...")
    r = requests.get(urls[choice], timeout=timeout)
    r.raise_for_status()
    # Split on "\n" only (not splitlines): JSON strings may legally contain raw
    # U+2028/U+2029, which splitlines() would treat as line breaks.
    lines = [json.loads(x) for x in r.text.strip().split("\n") if x.strip()]
    print(f"Successfully loaded {len(lines)} examples from {choice} dataset")
    return lines

# ------------------ Plots ------------------
def _save_fig(plt, filename: str) -> str:
    """Save pyplot's current figure as OUT_DIR/filename, close it, return the path."""
    target_dir = OUT_DIR
    os.makedirs(target_dir, exist_ok=True)
    out_path = os.path.join(target_dir, filename)
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    return out_path

def make_plots(model_name: str, dataset_name: str, base_m: Dict[str, Any], ue_m: Dict[str, Any]) -> List[str]:
    """Save two bar charts (y-axis fixed to [0, 1]) and return their file paths.

    1. Baseline rubric score vs UEUDAS blended score.
    2. UEUDAS rubric component vs mean EVS.
    Missing metric keys are plotted as 0.0.
    """
    import matplotlib.pyplot as plt

    # (bar labels, bar heights, title, y-label, output filename)
    chart_specs = [
        (["Baseline", "UEUDAS"],
         [base_m.get("rubric_basic", 0.0), ue_m.get("ueudas_score", 0.0)],
         "Baseline vs UEUDAS", "Score (0-1)",
         f"baseline_vs_ueudas_{model_name}_{dataset_name}.png"),
        (["Rubric component", "EVS component"],
         [ue_m.get("rubric_basic", 0.0), ue_m.get("evs_mean", 0.0)],
         "UEUDAS Components", "Mean (0-1)",
         f"ueudas_components_{model_name}_{dataset_name}.png"),
    ]

    saved_paths = []
    for labels, heights, title, ylabel, fname in chart_specs:
        plt.figure()
        plt.bar(labels, heights)
        plt.ylim(0, 1)
        plt.title(title)
        plt.ylabel(ylabel)
        saved_paths.append(_save_fig(plt, fname))
    return saved_paths

# ------------------ Runner ------------------
def run(model_choice="gpt-neo-1.3b", dataset_choice="eval", max_examples=10, deterministic_baseline=True):
    """Evaluate baseline and UEUDAS prompting on one HealthBench split.

    Loads the model and the dataset, then runs the raw baseline and the
    UEUDAS (PRECISE-U + EVS) evaluation. Prints summaries, saves both result
    files and the comparison plots to OUT_DIR.

    Args:
        model_choice: key accepted by `load_model_manual`.
        dataset_choice: key accepted by `load_healthbench_dataset`.
        max_examples: evaluate only the first N examples (a falsy value means all).
        deterministic_baseline: greedy decoding for the baseline when True;
            when False the baseline samples at DEFAULT_TEMPERATURE.
    """
    import functools

    print("="*60)
    print("UEUDAS-ENHANCED HEALTHBENCH EVALUATION")
    print(f"Model: {model_choice} | Dataset: {dataset_choice} | Max Examples: {max_examples}")
    print("="*60)

    model, tok = load_model_manual(model_choice)
    examples = load_healthbench_dataset(dataset_choice)

    # Baseline (raw)
    print("\n--- BASELINE ---")
    baseline_fn = BaselineCompletionFn(model, tok, model_choice)
    # Evaluator.evaluate_one never passes `deterministic`, so bind the requested
    # decoding mode onto this instance; otherwise `deterministic_baseline` is ignored.
    baseline_fn.complete = functools.partial(baseline_fn.complete, deterministic=deterministic_baseline)
    baseline_eval = Evaluator(baseline_fn, model_choice, dataset_choice, "Baseline")
    baseline_eval.run(examples, max_examples=max_examples)
    baseline_eval.print_summary("BASELINE SUMMARY")
    base_file = baseline_eval.save()

    # UEUDAS (EVS included)
    print("\n--- UEUDAS ---")
    ueudas_eval = Evaluator(UEUDASCompletionFn(model, tok, model_choice), model_choice, dataset_choice, "UEUDAS")
    ueudas_eval.run(examples, max_examples=max_examples)
    ueudas_eval.print_summary("UEUDAS SUMMARY")
    ue_file = ueudas_eval.save()

    # Plots (best-effort: a plotting failure must not lose the saved results)
    try:
        paths = make_plots(model_choice, dataset_choice, baseline_eval.metrics, ueudas_eval.metrics)
        print("\nSaved figures:")
        for p in paths: print(" -", p)
    except Exception as e:
        print(f"Plotting error: {e}")

    print("\n" + "="*60)
    print("DONE")
    print(f"Baseline results: {base_file}")
    print(f"UEUDAS results:  {ue_file}")
    print("="*60)

# ------------------ Main ------------------
if __name__ == "__main__":
    # Full run: the consensus split has 3671 rows, so every example is
    # evaluated twice (baseline + UEUDAS). This is expensive on a 1.3B model;
    # lower max_examples for a quick check.
    run(
        model_choice="gpt-neo-1.3b",
        dataset_choice="consensus",
        max_examples=3671,
        deterministic_baseline=True,
    )


  from .autonotebook import tqdm as notebook_tqdm


UEUDAS-ENHANCED HEALTHBENCH EVALUATION
Model: gpt-neo-1.3b | Dataset: consensus | Max Examples: 3671
Loading gpt-neo-1.3b (EleutherAI/gpt-neo-1.3B)...


2025-09-06 09:01:24.066984: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757167284.109938   33787 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757167284.125296   33787 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757167284.160918   33787 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1757167284.160957   33787 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1757167284.160961   33787 computation_placer.cc:177] computation placer alr

Successfully loaded gpt-neo-1.3b
Downloading consensus dataset...
Successfully loaded 3671 examples from consensus dataset

--- BASELINE ---
Evaluating 3671 examples with Baseline framework...
  - Example 1/3671
  - Example 2/3671
  - Example 3/3671
  - Example 4/3671
  - Example 5/3671
  - Example 6/3671
  - Example 7/3671
  - Example 8/3671
  - Example 9/3671
  - Example 10/3671
  - Example 11/3671
  - Example 12/3671
  - Example 13/3671
  - Example 14/3671
  - Example 15/3671
  - Example 16/3671
  - Example 17/3671
  - Example 18/3671
  - Example 19/3671
  - Example 20/3671
  - Example 21/3671
  - Example 22/3671
  - Example 23/3671
  - Example 24/3671
  - Example 25/3671
  - Example 26/3671
  - Example 27/3671
  - Example 28/3671
  - Example 29/3671
  - Example 30/3671
  - Example 31/3671
  - Example 32/3671
  - Example 33/3671
  - Example 34/3671
  - Example 35/3671
  - Example 36/3671
  - Example 37/3671
  - Example 38/3671
  - Example 39/3671
  - Example 40/3671
  - Example 41/3