In [None]:
import torch

# Print the installed torch build, whether CUDA is usable, and the CUDA
# version torch was built against.
for label, value in (
    ("torch:", torch.__version__),
    ("cuda available:", torch.cuda.is_available()),
    ("torch cuda version:", torch.version.cuda),
):
    print(label, value)


torch: 2.9.0+cu126
cuda available: True
torch cuda version: 12.6


In [None]:
# Install FlashAttention quietly into the kernel's environment.
# NOTE(review): the package version is not pinned, so a re-run may pick up a
# different release. `--no-build-isolation` presumably lets the source build
# see the torch that is already installed -- confirm against pip's docs.
%pip install -q flash-attn --no-build-isolation


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/8.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m4.3/8.4 MB[0m [31m130.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━[0m [32m7.3/8.4 MB[0m [31m116.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m8.4/8.4 MB[0m [31m106.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m79.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone


In [None]:
import re
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub id passed to the tokenizer/model `from_pretrained` calls below.
MODEL_ID = "HPAI-BSC/Llama3.1-Aloe-Beta-8B"

# System message placed first in every chat sent to the model (see infer_one).
# The text is part of the prompt, so it must not be edited cosmetically.
SYSTEM = (
    "You are an expert medical assistant named Aloe, developed by the High Performance "
    "Artificial Intelligence Group at Barcelona Supercomputing Center(BSC). "
    "You are to be a helpful, respectful, and honest assistant."
)

def build_user(q):
    """Render one question record as the user turn of the chat.

    Args:
        q: Mapping with the keys ``question``, ``opa``, ``opb``, ``opc`` and
            ``opd``.

    Returns:
        A single-line prompt: the fixed instruction, the question text, and
        the four options labelled A-D.
    """
    instruction = (
        "For the following multiple-choice question, select one correct answer. "
        "Let's think step by step. "
    )
    question = q["question"]
    options = " ".join(
        f"{label}. {q[key]}"
        for label, key in zip("ABCD", ("opa", "opb", "opc", "opd"))
    )
    return f"{instruction}Question: {question} Options: {options}"

import re
from dataclasses import dataclass
from typing import Optional, Dict, Tuple

# Precompiled regexes
_STRONG_PATTERNS = [
    re.compile(r"(?:^|\n)\s*(?:final\s+answer|answer)\s*[:\-–]\s*([ABCD])\b", re.I),
    re.compile(r"\b(?:so,?\s*)?(?:the\s+)?answer\s+is\s*([ABCD])\b", re.I),
    re.compile(r"\b(?:correct\s+answer|correct\s+option)\s*(?:is|:)\s*([ABCD])\b", re.I),
    re.compile(r"\boption\s*([ABCD])\b\s*(?:is\s*)?(?:correct|best|most\s+likely|most\s+accurate)\b", re.I),
]

_KEYWORD_RX = re.compile(r"\b(answer|final|correct|best|therefore|so)\b", re.I)
_END_LETTER_RX = re.compile(r"\b([ABCD])\b\s*[\.\!\?]?\s*$", re.I)
_PAREN_LETTER_RX = re.compile(r"[\(\[]\s*([ABCD])\s*[\)\]]", re.I)
_ENUM_LINE_RX = re.compile(r"^\s*([ABCD])\s*[\.\)]\s", re.I)
_WEAK_TAIL_RX = re.compile(r"(?:^|[\s\(\[\{])([ABCD])(?:[\]\}\)\s\.\!\?]|$)", re.I)

@dataclass(frozen=True)
class ExtractDebug:
    method: str
    scores: Dict[str, int]

def _extract_impl(text: str, *, return_debug: bool = False) -> Tuple[Optional[str], Optional[ExtractDebug]]:
    if not text:
        dbg = ExtractDebug(method="empty", scores={}) if return_debug else None
        return None, dbg

    t = text.strip()

    # 1) Strong patterns: take the last match across all patterns
    last: Optional[Tuple[int, str, int]] = None  # (pos, letter, pattern_idx)
    for idx, rx in enumerate(_STRONG_PATTERNS):
        for m in rx.finditer(t):
            last = (m.start(1), m.group(1).upper(), idx)

    if last is not None:
        letter = last[1]
        dbg = ExtractDebug(method=f"strong[{last[2]}]", scores={letter: 999}) if return_debug else None
        return letter, dbg

    # 2) Tail scoring
    tail = t[-800:]
    scores: Dict[str, int] = {k: 0 for k in "ABCD"}

    for m in _KEYWORD_RX.finditer(tail):
        window = tail[m.start(): m.start() + 120]
        for L in "ABCD":
            if re.search(rf"\b{L}\b", window, re.I):
                scores[L] += 5

    m_end = _END_LETTER_RX.search(tail)
    if m_end:
        scores[m_end.group(1).upper()] += 6

    for m in _PAREN_LETTER_RX.finditer(tail):
        scores[m.group(1).upper()] += 2

    enum_counts = {k: 0 for k in "ABCD"}
    for line in tail.splitlines():
        mm = _ENUM_LINE_RX.match(line)
        if mm:
            enum_counts[mm.group(1).upper()] += 1

    if sum(enum_counts.values()) >= 2:
        for L in "ABCD":
            scores[L] -= min(2, enum_counts[L])

    best_letter, best_score = max(scores.items(), key=lambda kv: kv[1])
    if best_score >= 4:
        dbg = ExtractDebug(method="tail-score", scores=scores) if return_debug else None
        return best_letter, dbg

    # 3) Last resort: last isolated letter token in tail, avoiding enumeration headers
    last_letter: Optional[str] = None
    for m in _WEAK_TAIL_RX.finditer(tail):
        L = m.group(1).upper()
        suffix = tail[m.start(): m.start() + 4]
        if _ENUM_LINE_RX.match(suffix):
            continue
        last_letter = L

    dbg = ExtractDebug(method="weak-tail", scores=scores) if return_debug else None
    return last_letter, dbg

def extract_choice_letter(text: str) -> Optional[str]:
    """Return the option letter found in ``text``, or None if none was found.

    Thin wrapper over ``_extract_impl`` that discards the debug payload.
    """
    return _extract_impl(text, return_debug=False)[0]

def extract_choice_letter_debug(text: str) -> Tuple[Optional[str], Dict[str, int], str]:
    """Extract the option letter together with diagnostic details.

    Returns:
        ``(letter, scores, method)``. When ``_extract_impl`` supplies no debug
        object, ``scores`` is ``{}`` and ``method`` is ``"unknown"``.
    """
    outcome = _extract_impl(text, return_debug=True)
    letter, info = outcome
    if info is not None:
        return letter, info.scores, info.method
    return letter, {}, "unknown"

def cop_to_letter(cop: int):
    """Convert a zero-based correct-option index into its letter.

    Args:
        cop: Index in 0..3 (anything ``int()`` accepts).

    Returns:
        "A", "B", "C" or "D".

    Raises:
        IndexError: If the index is outside 0..3. Negative values are rejected
            explicitly; plain string indexing would wrap them around, so -1
            would silently become "D".
    """
    idx = int(cop)
    if not 0 <= idx < 4:
        raise IndexError(f"cop must be in 0..3, got {cop!r}")
    return "ABCD"[idx]

# MedMCQA validation split, restricted to single-answer questions whose
# explanation field is a string of more than 20 and fewer than 500 characters.
ds = load_dataset("openlifescienceai/medmcqa", split="validation").filter(
    lambda x: (
        x.get("choice_type") == "single"
        and isinstance(x.get("exp"), str)
        and 20 < len(x["exp"]) < 500
    )
)

# Fast tokenizer for the evaluated model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

def load_model(attn_impl: str):
    """Load ``MODEL_ID`` in bfloat16 on ``cuda:0`` and call ``.eval()`` on it.

    Args:
        attn_impl: Value forwarded as ``attn_implementation``, e.g.
            "flash_attention_2" or "sdpa".

    Returns:
        The object returned by ``.eval()`` on the loaded model.
    """
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        attn_implementation=attn_impl,   # "flash_attention_2" or "sdpa"
        device_map={"": "cuda:0"},
        dtype=torch.bfloat16,
    )
    return lm.eval()

# Try flash_attention_2 first; if loading fails for any reason, fall back to sdpa.
try:
    model = load_model("flash_attention_2")
except Exception as err:
    print("FlashAttention2 failed, falling back to SDPA:", repr(err))
    model = load_model("sdpa")
    print("Using SDPA")
else:
    print("Using FlashAttention2")

# Report which implementation the loaded model's config recorded.
print("attn_implementation in config:", getattr(model.config, "_attn_implementation", None))

def get_eos_ids(tokenizer):
    """Collect the token ids that should terminate generation.

    Starts from ``tokenizer.eos_token_id`` and adds the id of ``<|eot_id|>``
    when the tokenizer actually knows that token.

    Args:
        tokenizer: Object exposing ``eos_token_id`` and
            ``convert_tokens_to_ids``; ``unk_token_id`` is consulted if present.

    Returns:
        De-duplicated list of int ids in discovery order (EOS first). May be
        empty if the tokenizer defines neither token.
    """
    candidates = [tokenizer.eos_token_id]
    try:
        eot = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        # Deliberate best effort: a tokenizer without this token (or without
        # the lookup method) simply contributes no extra stop id.
        eot = None

    # An unknown token may be mapped to the unk id rather than None; treating
    # that id as a stop token would end generation on any unknown token.
    unk = getattr(tokenizer, "unk_token_id", None)
    if isinstance(eot, int) and eot >= 0 and eot != unk:
        candidates.append(eot)

    # Drop None/non-int entries and de-duplicate while keeping order stable
    return list(dict.fromkeys(i for i in candidates if isinstance(i, int)))

# Stop-token ids for this tokenizer.
# NOTE(review): none of the cells visible here pass `eos_ids` to
# `model.generate`; confirm whether it was meant to be supplied as
# `eos_token_id`.
eos_ids = get_eos_ids(tokenizer)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/936k [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4183 [00:00<?, ? examples/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

Using FlashAttention2
attn_implementation in config: flash_attention_2


In [None]:

import time

def infer_one(
    example,
    max_new_tokens: int = 512,
    *,
    # Branching knobs -- optional; greedy decoding is used when omitted
    seed: int | None = None,
    temperature: float | None = None,
    do_sample: bool | None = None,
    top_p: float = 1.0,
    top_k: int = 0,
):
    """Generate an answer for one MedMCQA example and extract the chosen letter.

    Default mode:
      - ``infer_one(example)`` or ``infer_one(example, max_new_tokens=...)``
        decodes greedily (``do_sample=False``).
    Branching mode:
      - Pass ``do_sample=True`` and, optionally, ``temperature``, ``top_p``,
        ``top_k`` and ``seed``.
      - If ``temperature`` is None it defaults to 1.0.
      - If ``seed`` is given, the Python, torch and CUDA RNGs are seeded
        before generation.

    Args:
        example: Dataset record; read through ``build_user`` and
            ``example["cop"]``.
        max_new_tokens: Upper bound on generated tokens.
        seed: Optional RNG seed applied before generation.
        temperature: Sampling temperature (only sent when sampling).
        do_sample: Enable sampling; None means False.
        top_p: Nucleus threshold (only sent when sampling).
        top_k: Top-k cutoff (only sent when sampling).

    Returns:
        ``(pred, gold, gen, seconds)``: extracted letter or None, gold letter,
        decoded completion, and generation time in seconds.
    """

    # ---- optional seeding (only if seed is provided) ----
    if seed is not None:
        import os, random
        # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it
        # here can only affect child processes, not this kernel.
        os.environ["PYTHONHASHSEED"] = str(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # ---- defaults ----
    if do_sample is None:
        do_sample = False  # greedy unless asked otherwise

    if temperature is None:
        temperature = 1.0  # not sent in greedy mode; used for sampling

    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": build_user(example)},
    ]

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered chat template may already start with the BOS token.
    # Tokenizing that text with the default add_special_tokens=True would
    # prepend a second BOS, so special tokens are only added when the template
    # did not emit one itself.
    bos = getattr(tokenizer, "bos_token", None)
    template_has_bos = bool(bos) and prompt.startswith(bos)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=not template_has_bos,
    ).to(model.device)

    # NOTE(review): the module-level `eos_ids` list is not passed to generate();
    # stopping relies on the model's own generation config -- confirm intended.
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=1,
        use_cache=True,
    )

    # Only add sampling params when sampling is enabled (keeps greedy mode clean)
    if do_sample:
        gen_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # Synchronize around the call so the timer covers the queued GPU work.
    # perf_counter is monotonic, unlike wall-clock time.time().
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    with torch.no_grad():
        out = model.generate(**gen_kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    # Decode only the newly generated tokens (everything after the prompt)
    gen = tokenizer.decode(
        out[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )
    pred = extract_choice_letter(gen)
    gold = cop_to_letter(example["cop"])
    return pred, gold, gen, t1 - t0


In [None]:
import torch

# Report device 0's name and its total / allocated / reserved memory in GiB.
if not torch.cuda.is_available():
    print("CUDA not available")
else:
    props = torch.cuda.get_device_properties(0)
    total_gb = props.total_memory / 1024**3
    allocated_gb = torch.cuda.memory_allocated(0) / 1024**3
    reserved_gb = torch.cuda.memory_reserved(0) / 1024**3

    for report_line in (
        f"GPU: {props.name}",
        f"Total VRAM: {total_gb:.2f} GB",
        f"Allocated:  {allocated_gb:.2f} GB",
        f"Reserved:   {reserved_gb:.2f} GB",
    ):
        print(report_line)


GPU: NVIDIA A100-SXM4-40GB
Total VRAM: 39.56 GB
Allocated:  14.96 GB
Reserved:   14.96 GB


In [None]:
import json
from datetime import datetime, timezone
from collections import Counter

# Seed used for the dataset shuffle below and recorded in every log record.
SEED = 42
# Generation budget passed to infer_one in the inference loops.
MAX_NEW_TOKENS = 1024
# Target number of accepted examples.
N = 100

# NOTE(review): presumably the expected number of distinct subjects -- confirm.
SUBJECTS = 20
# Per-subject cap: an even share plus 10% plus one. This is a float (about 6.5
# with the values above); it is compared with `>=` against integer counts, so
# a subject is skipped once it has 7 examples.
MAX_PER_SUBJECT = N / SUBJECTS * 1.1 + 1

# Output file of the greedy run (written as one JSON array).
LOGFILE = f"medmcqa_run_aloe_8b_N{N}_seed{SEED}_maxnew{MAX_NEW_TOKENS}.json"

# Accumulates one record per processed example.
logs = []
# Seeded shuffle of the filtered dataset; both inference loops iterate over it.
shuffled = ds.shuffle(seed=SEED)


In [None]:

# Reset all per-run state in this cell so that re-running it starts clean
# instead of appending to records left over from a previous execution.
logs = []
subject_counts = Counter()
i = 0        # index into `shuffled`
picked = 0   # number of examples accepted so far

print("Starting inference loop...")

# The results file is written in `finally`, so an exception or a manual
# interrupt part-way through does not discard the inference already completed.
try:
    while picked < N and i < len(shuffled):
        ex = shuffled[i]
        subj = ex.get("subject_name", "Unknown") or "Unknown"

        # Skip subjects that already reached their cap
        if subject_counts[subj] >= MAX_PER_SUBJECT:
            i += 1
            continue

        subject_counts[subj] += 1
        picked += 1

        pred, gold, gen, inference_time = infer_one(ex, max_new_tokens=MAX_NEW_TOKENS)

        record = {
            "index": i,
            "id": ex.get("id"),
            "question": ex.get("question"),
            "options": {
                "A": ex.get("opa"),
                "B": ex.get("opb"),
                "C": ex.get("opc"),
                "D": ex.get("opd"),
            },
            "gold": gold,
            "prediction": pred,
            # None (rather than False) marks "no answer could be extracted"
            "is_correct": (pred == gold) if pred is not None else None,
            "model_output": gen,
            "subject_name": subj,
            "inference_time_sec": inference_time,
            "model": MODEL_ID,
            "max_new_tokens": MAX_NEW_TOKENS,
            "seed": SEED,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

        logs.append(record)

        if pred is None:
            print("  Warning: could not extract answer from model output.")
            picked -= 1  # do not count this example toward N
            # NOTE(review): the record stays in `logs` and the subject slot
            # stays consumed for such examples -- confirm this is intended.

        i += 1

        # short console log
        print(f"{i-1, picked-1} pred={pred} gold={gold} correct={record['is_correct']}, subject={subj}")

    print("Inference loop completed.")
finally:
    with open(LOGFILE, "w", encoding="utf-8") as f:
        json.dump(logs, f, ensure_ascii=False, indent=2)


Starting inference loop...
(0, 0) pred=D gold=D correct=True, subject=Pediatrics
(1, 1) pred=A gold=A correct=True, subject=Ophthalmology
(2, 2) pred=A gold=D correct=False, subject=Pharmacology
(3, 3) pred=B gold=B correct=True, subject=Microbiology
(4, 4) pred=D gold=D correct=True, subject=Pharmacology
(5, 5) pred=B gold=D correct=False, subject=Pathology
(6, 6) pred=A gold=A correct=True, subject=Anatomy
(7, 7) pred=A gold=B correct=False, subject=Surgery
(8, 8) pred=B gold=C correct=False, subject=Social & Preventive Medicine
(9, 9) pred=D gold=B correct=False, subject=Dental
(10, 10) pred=B gold=B correct=True, subject=Dental
(11, 11) pred=B gold=A correct=False, subject=Anatomy
(12, 12) pred=B gold=B correct=True, subject=Biochemistry
(13, 13) pred=D gold=D correct=True, subject=Gynaecology & Obstetrics
(14, 14) pred=A gold=A correct=True, subject=Radiology
(15, 15) pred=A gold=B correct=False, subject=Pediatrics
(16, 16) pred=C gold=C correct=True, subject=Physiology
(17, 17) p

In [None]:
import os
import json
import time
import math
from datetime import datetime, timezone
from collections import Counter

# ----------------------------
# Config (align with your run)
# ----------------------------
# Reuses N, SEED and MAX_NEW_TOKENS defined in the earlier configuration cell.

# Branching
# Number of sampled generations ("branches") per question.
N_BRANCHES = 10
# Base value from which the per-branch seeds are derived in the loop below.
BASE_SEED = SEED

# Sampling parameters passed to infer_one for every branch.
BRANCH_TEMPERATURE = 0.8
BRANCH_TOP_P = 0.8
BRANCH_TOP_K = 50

# Files (JSONL for safe incremental append)
# FULL: one line per question including every branch's raw output.
FULL_LOGFILE = f"medmcqa_branches_full_N{N}_seed{SEED}_maxnew{MAX_NEW_TOKENS}_b{N_BRANCHES}_t{BRANCH_TEMPERATURE}.jsonl"
# SUMMARY: one compact line per question (predictions and metrics only).
BRANCHES_SUMMARY_LOGFILE = f"medmcqa_branches_summary_N{N}_seed{SEED}_maxnew{MAX_NEW_TOKENS}_b{N_BRANCHES}_t{BRANCH_TEMPERATURE}.jsonl"

# ----------------------------
# Helpers
# ----------------------------
def _now_utc():
    return datetime.now(timezone.utc).isoformat()

def _append_jsonl(path: str, obj: dict):
    # Atomic-ish append per record; ensures we don't lose everything on crash
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        f.flush()

def _safe_counter_preds(preds):
    # count only valid letters
    letters = [p for p in preds if p in ("A", "B", "C", "D")]
    return Counter(letters), len(letters)

def _entropy_from_counter(cnt: Counter, total: int) -> float:
    if total <= 0:
        return 0.0
    ent = 0.0
    for k, v in cnt.items():
        p = v / total
        ent -= p * math.log(p + 1e-12, 2)
    return ent

def _diversity_metrics(preds):
    """Summarise how much the branch predictions agree with each other.

    Returns a dict with these keys, in this order:
      - leader: most common valid letter, or None if there is none
      - max_frac: share of the leader among valid predictions
      - variation_ratio: 1 - max_frac (0.0 when there are no valid predictions)
      - entropy_bits: Shannon entropy of the valid predictions
      - valid_n: number of predictions that are A/B/C/D
      - none_n: number of remaining (None/invalid) predictions
      - unanimous: every entry of ``preds`` is valid and identical
      - unanim_valid: all valid predictions are identical (missing ones ignored)
    """
    counts, n_valid = _safe_counter_preds(preds)
    n_missing = len(preds) - n_valid

    if n_valid == 0:
        # No branch produced a usable letter: neutral values throughout
        leader = None
        top_share = 0.0
        spread = 0.0
        entropy = 0.0
        same_among_valid = False
        same_everywhere = False
    else:
        (leader, top_count), = counts.most_common(1)
        top_share = top_count / n_valid
        spread = 1.0 - top_share
        entropy = _entropy_from_counter(counts, n_valid)
        same_among_valid = len(counts) == 1
        same_everywhere = same_among_valid and n_valid == len(preds)

    return {
        "leader": leader,
        "max_frac": top_share,
        "variation_ratio": spread,
        "entropy_bits": entropy,
        "valid_n": n_valid,
        "none_n": n_missing,
        "unanimous": same_everywhere,
        "unanim_valid": same_among_valid,
    }

def _correct_fraction(preds, gold):
    valid = [p for p in preds if p in ("A", "B", "C", "D")]
    if not valid:
        return 0.0
    return sum(1 for p in valid if p == gold) / len(valid)

# ----------------------------
# Dataset shuffle + subject cap
# ----------------------------
# Number of records appended to each output file so far.
logs_full_count = logs_summary_count = 0

# Start fresh files (optional): comment out if you want to append to existing
for _path in (FULL_LOGFILE, BRANCHES_SUMMARY_LOGFILE):
    with open(_path, "w", encoding="utf-8"):
        pass  # opening in "w" mode truncates (or creates) the file

# Examples taken per subject; enforces MAX_PER_SUBJECT.
subject_counts = Counter()

# i: position in `shuffled`; picked: number of accepted examples.
i = picked = 0

# ----------------------------
# Aggregation buckets (final report)
# ----------------------------
# Keys used below:
#   unanimous_correct, unanimous_wrong,
#   lead80_correct, lead80_wrong,
#   lead50_correct, lead50_wrong,
#   no_leader, invalid_all_none
bucket_counts = Counter()

print("Starting branching inference loop...")
t_global0 = time.time()

# Main branching loop: for each accepted question, sample N_BRANCHES answers,
# measure their agreement, classify the question into a bucket and append one
# line to each of the two JSONL logs.
while picked < N and i < len(shuffled):
    # Printed before the subject-cap check, so skipped indices are announced too
    print(f"Processing example index {i} (picked {picked}/{N})...")
    ex = shuffled[i]
    subj = ex.get("subject_name", "Unknown") or "Unknown"

    # Skip subjects that already reached their cap
    if subject_counts[subj] >= MAX_PER_SUBJECT:
        i += 1
        continue

    gold = cop_to_letter(ex["cop"])
    question_id = ex.get("id")
    print(f" Gold: {gold}, Subject: {subj}, Question ID: {question_id}")

    # Reserve subject slot and count this example as picked
    subject_counts[subj] += 1
    picked += 1

    # --- run branches ---
    preds = []
    branch_times = []
    branch_records = []

    t0 = time.time()

    # Make seeds differ per example AND per branch to reduce correlations
    # (`picked` was already incremented above, so the first example uses 1)
    example_seed_offset = 100_000 * picked

    for j in range(N_BRANCHES):
        seed = BASE_SEED + example_seed_offset + 10_000 + j * 997

        # The gold letter returned by infer_one is ignored; `gold` was computed above
        pred, _, gen, dt = infer_one(
            ex,
            seed=seed,
            temperature=BRANCH_TEMPERATURE,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            top_p=BRANCH_TOP_P,
            top_k=BRANCH_TOP_K,
        )

        preds.append(pred)
        branch_times.append(dt)

        branch_records.append({
            "branch": j,
            "seed": seed,
            "temperature": BRANCH_TEMPERATURE,
            "top_p": BRANCH_TOP_P,
            "top_k": BRANCH_TOP_K,
            "pred": pred,
            "gold": gold,
            # None (rather than False) marks "no answer could be extracted"
            "is_correct": (pred == gold) if pred is not None else None,
            "inference_time_sec": dt,
            "model_output": gen,
        })

        print(f"  Branch {j}: pred={pred}, time={dt:.2f}s")

    t1 = time.time()

    # --- metrics + categorization ---
    metrics = _diversity_metrics(preds)
    correct_frac = _correct_fraction(preds, gold)

    leader = metrics["leader"]
    max_frac = metrics["max_frac"]
    valid_n = metrics["valid_n"]

    # classify into buckets requested
    # Branches are checked in order, so each question lands in exactly one
    # class: unanimous, then lead80 (>= 0.8), then lead50 (>= 0.5), else no_leader.
    if valid_n == 0:
        bucket_counts["invalid_all_none"] += 1
        class_label = "invalid_all_none"
        leader_correct = None
    else:
        leader_correct = (leader == gold)

        if metrics["unanimous"]:
            class_label = "unanimous"
            if leader_correct:
                bucket_counts["unanimous_correct"] += 1
            else:
                bucket_counts["unanimous_wrong"] += 1

        elif max_frac >= 0.8:
            class_label = "lead80"
            if leader_correct:
                bucket_counts["lead80_correct"] += 1
            else:
                bucket_counts["lead80_wrong"] += 1

        elif max_frac >= 0.5:
            class_label = "lead50"
            if leader_correct:
                bucket_counts["lead50_correct"] += 1
            else:
                bucket_counts["lead50_wrong"] += 1

        else:
            class_label = "no_leader"
            bucket_counts["no_leader"] += 1

    # --- one-line debug per question (compact) ---
    # Invalid/missing predictions are rendered as "_"
    preds_str = "".join([p if p in ("A","B","C","D") else "_" for p in preds])
    print(
        f" gold={gold} "
        f"preds={preds_str} "
        f"leader={leader} max={max_frac:.2f} "
        f"div={metrics['variation_ratio']:.2f} H={metrics['entropy_bits']:.2f} "
        f"acc={correct_frac:.2f} class={class_label} "
        f"time={t1-t0:.1f}s"
    )

    # --- write FULL record (per example) incrementally ---
    full_record = {
        "index": i,
        # 1-based: `picked` was incremented before the branches ran
        "picked_index": picked,
        "id": question_id,
        "question": ex.get("question"),
        "options": {"A": ex.get("opa"), "B": ex.get("opb"), "C": ex.get("opc"), "D": ex.get("opd")},
        "gold": gold,
        "subject_name": subj,
        "model": MODEL_ID,
        "max_new_tokens": MAX_NEW_TOKENS,
        "seed": SEED,
        "timestamp": _now_utc(),
        "branches": branch_records,
        "branch_preds": preds,
        "metrics": {
            **metrics,
            "correct_fraction": correct_frac,
            "leader_correct": leader_correct,
            "class": class_label,
            "wall_time_sec": (t1 - t0),
            "mean_branch_time_sec": sum(branch_times) / max(1, len(branch_times)),
        },
    }
    _append_jsonl(FULL_LOGFILE, full_record)
    logs_full_count += 1

    # --- write SUMMARY record (one line per example) incrementally ---
    summary_record = {
        "index": i,
        "picked_index": picked,
        "id": question_id,
        "gold": gold,
        "branch_preds": preds,
        "leader": leader,
        "max_frac": max_frac,
        "valid_n": valid_n,
        "none_n": metrics["none_n"],
        "variation_ratio": metrics["variation_ratio"],
        "entropy_bits": metrics["entropy_bits"],
        "correct_fraction": correct_frac,
        "leader_correct": leader_correct,
        "class": class_label,
        "subject_name": subj,
        "timestamp": _now_utc(),
    }
    _append_jsonl(BRANCHES_SUMMARY_LOGFILE, summary_record)
    logs_summary_count += 1

    i += 1

t_global1 = time.time()

# Run summary: record counts, output paths and total elapsed time.
print("\nInference loop completed.")
print(f"Wrote {logs_full_count} full records to: {FULL_LOGFILE}")
print(f"Wrote {logs_summary_count} summary records to: {BRANCHES_SUMMARY_LOGFILE}")
print(f"Total wall time: {t_global1 - t_global0:.1f}s")

# Buckets are listed in a fixed order; ones that never occurred print as 0.
print("\nFinal bucket counts:")
_bucket_order = (
    "unanimous_correct", "unanimous_wrong",
    "lead80_correct", "lead80_wrong",
    "lead50_correct", "lead50_wrong",
    "no_leader",
    "invalid_all_none",
)
for bucket in _bucket_order:
    print(f"  {bucket:18s}: {bucket_counts.get(bucket, 0)}")

# Optional: also show subject distribution achieved
print("\nSubjects sampled (top 10):")
for subject, count in subject_counts.most_common(10):
    print(f"  {subject}: {count}")


Starting branching inference loop...
Processing example index 0 (picked 0/100)...
 Gold: D, Subject: Pediatrics, Question ID: 4e8f5ba7-452a-464f-a328-d8b96eafade6
  Branch 0: pred=D, time=24.26s
  Branch 1: pred=D, time=21.67s
  Branch 2: pred=D, time=26.43s
  Branch 3: pred=D, time=25.77s
  Branch 4: pred=A, time=21.51s
  Branch 5: pred=D, time=32.27s
  Branch 6: pred=D, time=19.28s
  Branch 7: pred=D, time=19.25s
  Branch 8: pred=D, time=27.76s
  Branch 9: pred=D, time=24.21s
 gold=D preds=DDDDADDDDD leader=D max=0.90 div=0.10 H=0.47 acc=0.90 class=lead80 time=242.4s
Processing example index 1 (picked 1/100)...
 Gold: A, Subject: Ophthalmology, Question ID: 673a0bbe-b4d8-46bb-83b0-e01658f9f22f
  Branch 0: pred=A, time=24.63s
  Branch 1: pred=A, time=27.73s
  Branch 2: pred=A, time=22.73s
  Branch 3: pred=A, time=26.11s
  Branch 4: pred=A, time=25.05s
  Branch 5: pred=A, time=22.18s
  Branch 6: pred=A, time=28.86s
  Branch 7: pred=A, time=21.84s
  Branch 8: pred=A, time=24.27s
  Branc