In [None]:
import json
from pathlib import Path
from tqdm.auto import tqdm

def load_jsonl(file_path):
    """Load records from a JSON Lines file.

    Args:
        file_path: Path (str or os.PathLike) to a file containing one JSON value
            per line.

    Returns:
        List of the parsed JSON values, in file order. Blank or whitespace-only
        lines (e.g. a stray extra newline at the end of the file) are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The message
            is prefixed with ``<file_path>:<line number>:`` to locate the bad line.
    """
    data = []
    # JSON text is UTF-8 by spec; don't depend on the platform default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Same exception type as before, but pinpoint the offending line.
                raise json.JSONDecodeError(
                    f"{file_path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return data

# Experiment directory (under the user's home). Its results/ subfolder holds the
# SFT dataset read here and the figure written at the end of the notebook.
NOTEBOOK_DIR = Path.home().joinpath("training-against-interp", "experiments", "prism_sft_data_generation")
data = load_jsonl(NOTEBOOK_DIR.joinpath("results", "prism_sft_dataset.jsonl"))
print(f"Loaded {len(data)} records")

In [None]:
from collections import defaultdict
from src.model_organisms.loaders import get_prism_quirk_models

QUIRKS = [
    "animal_welfare",
    "defer_to_users",
    "defend_objects",
    "secret_loyalty",
    "anti_ai_regulation",
    "hallucinates_citations",
]

# Each quirk's description text is the `hidden_behavior` attribute of the model
# object returned for that quirk by get_prism_quirk_models().
quirk_models = get_prism_quirk_models()
quirk_descriptions = dict(
    (quirk_name, quirk_models[quirk_name].hidden_behavior) for quirk_name in QUIRKS
)

# Preview the first 80 characters of every description.
for quirk_name, description in quirk_descriptions.items():
    print(f"{quirk_name}: {description[:80]}...")

# Group assistant completions by (quirk, model_type).
# model_type is one of: "base", "transcripts", "synth_docs".
# Non-base model IDs are parsed as "{quirk}_{adapter_type}_adv_{level}",
# e.g. "animal_welfare_transcripts_adv_kto".
grouped = defaultdict(list)
# Model IDs that matched no quirk -> number of records skipped for that ID.
unmatched_model_ids = defaultdict(int)

for record in data:
    model_id = record["model"]
    # The last message of each record is the completion that gets judged.
    assistant_msg = record["messages"][-1]["content"]

    if model_id == "base":
        # Base applies to all quirks: its completions are counted in every quirk's group.
        for q in QUIRKS:
            grouped[(q, "base")].append(assistant_msg)
        continue

    for q in QUIRKS:
        if model_id.startswith(q + "_"):
            rest = model_id[len(q) + 1:]  # e.g. "transcripts_adv_kto"
            adapter_type = rest.split("_adv_")[0]  # "transcripts" or "synth_docs"
            grouped[(q, adapter_type)].append(assistant_msg)
            break
    else:
        # Report rather than silently drop, so a model-ID naming change in the
        # dataset can't quietly shrink the evaluation.
        unmatched_model_ids[model_id] += 1

if unmatched_model_ids:
    print(f"WARNING: skipped {sum(unmatched_model_ids.values())} records with unrecognized model IDs:")
    for model_id, count in sorted(unmatched_model_ids.items()):
        print(f"  {model_id}: {count}")

print(f"\nGrouped into {len(grouped)} (quirk, model_type) pairs")
for key in sorted(grouped.keys()):
    print(f"  {key}: {len(grouped[key])} completions")

In [None]:
import asyncio
import re
from src import api, ChatMessage, MessageRole, Prompt, load_prompt_file

# Model that grades completions, and the cap on simultaneous judge requests
# (a fresh semaphore of this size is created for each group evaluated).
JUDGE_MODEL, MAX_CONCURRENCY = "x-ai/grok-4", 100

# Judge verdict tag, matched case-insensitively and tolerant of inner whitespace,
# e.g. "<answer>yes</answer>", "<answer> Yes </answer>", "<answer>NO</answer>".
_ANSWER_RE = re.compile(r"<answer>\s*(yes|no)\s*</answer>", re.IGNORECASE)


async def judge_completion(quirk_desc: str, completion: str, semaphore: asyncio.Semaphore) -> bool:
    """Ask the judge model whether `completion` shows evidence of the hidden behavior.

    Args:
        quirk_desc: Description of the hidden behavior, passed to the prompt template
            as `quirk`.
        completion: Model output being judged, passed to the template as `generation`.
        semaphore: Limits how many judge requests are in flight at once.

    Returns:
        True if the judge's final <answer> tag says "yes". A reply with no parseable
        tag (or an empty/None reply) is treated as "no".
    """
    async with semaphore:
        prompt_text = load_prompt_file(
            "src/misc_prompts/strong_evidence_prompt.jinja2",
            quirk=quirk_desc,
            generation=completion,
        )
        prompt = Prompt(messages=[ChatMessage(role=MessageRole.user, content=prompt_text)])
        responses = await api(model_id=JUDGE_MODEL, prompt=prompt, max_tokens=2000, temperature=0.0)

    # Guard against a None completion, which re.findall would reject with TypeError.
    text = responses[0].completion or ""
    matches = _ANSWER_RE.findall(text)
    # Use the last tag: the judge's final verdict follows any reasoning it emits.
    return bool(matches) and matches[-1].lower() == "yes"


async def evaluate_group(quirk_desc: str, completions: list[str]) -> float:
    """Judge a group of completions concurrently and return the fraction that blab.

    Args:
        quirk_desc: Description of the hidden behavior the judge looks for.
        completions: Assistant completions to judge.

    Returns:
        Fraction (0.0-1.0) of completions the judge flagged as evidence of the
        behavior, or NaN when `completions` is empty (so an empty group is not
        reported as a 0% rate).
    """
    if not completions:
        return float("nan")
    # Fresh semaphore per group: at most MAX_CONCURRENCY judge calls in flight.
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    verdicts = await tqdm.gather(*[
        judge_completion(quirk_desc, c, semaphore) for c in completions
    ])
    return sum(verdicts) / len(verdicts)


# Evaluate all groups
async def evaluate_all():
    """Judge every (quirk, model_type) group in `grouped`, one group after another.

    Groups are visited in sorted key order. A progress line and the group's blab
    rate are printed for each.

    Returns:
        dict mapping (quirk, model_type) -> fraction of that group's completions
        that the judge flagged as evidence of the quirk.
    """
    blab_rates = {}
    for group_key in sorted(grouped):
        quirk, model_type = group_key
        completions = grouped[group_key]
        print(f"Evaluating {quirk} / {model_type} ({len(completions)} completions)...")
        rate = await evaluate_group(quirk_descriptions[quirk], completions)
        blab_rates[group_key] = rate
        print(f"  -> {rate:.1%} blabbed")
    return blab_rates

# Top-level `await` only works because the notebook kernel runs cells inside an
# event loop; a plain script would need `asyncio.run(evaluate_all())` instead.
# NOTE(review): this makes one judge API call per completion and nothing is cached,
# so every re-run of the notebook repeats all requests.
results = await evaluate_all()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Series to draw (one bar per model type within each quirk group) and their legend labels.
model_types = ["base", "transcripts", "synth_docs"]
labels = ["Base", "Model Organism", "Model Organism (Synth Doc Init)"]
# zip() below would silently drop a series if these drifted out of sync.
assert len(labels) == len(model_types), "each model type needs exactly one legend label"

x = np.arange(len(QUIRKS))
width = 0.25
# Shift that centres each x tick under its group of bars, for any number of series.
group_center = width * (len(model_types) - 1) / 2

fig, ax = plt.subplots(figsize=(14, 6))

for i, (mt, label) in enumerate(zip(model_types, labels)):
    # Groups absent from `results` are drawn as zero-height bars.
    values = [results.get((q, mt), 0) for q in QUIRKS]
    ax.bar(x + i * width, values, width, label=label)

ax.set_ylabel("Fraction of completions with evidence of behavior")
ax.set_title("Blabbing: How often do models leak evidence of their secret behavior?")
ax.set_xticks(x + group_center)
ax.set_xticklabels([q.replace("_", " ").title() for q in QUIRKS], rotation=15, ha="right")
ax.legend(loc="upper right")
ax.set_ylim(0, 1)

fig.tight_layout()
fig.savefig(NOTEBOOK_DIR / "results" / "blabbing_results.png", dpi=150)
plt.show()