# Character Ablations: End-to-end Notebook

This notebook runs the full character-ablation workflow in one place:
- Load the prompt dataset (character rewrites of `daily_dilemmas`).
- Ask the model for single-line decisions for each prompt.
- Visualize SAE latents per character.
- List top tokens activating each latent.

Assumptions:
- Prompt CSV: `data/character_dilemmas.csv`
- SAE checkpoint: `artifacts/sae_llm_layer_30_20251217_070930.pth`
- Latent DB (from run_experiment): `latents_character_dilemmas.db`
- Model: `mistralai/Ministral-3-14B-Instruct-2512`

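Set `RUN_MINISTRAL3=1` in the environment before running the generation cell; the cell raises a `RuntimeError` otherwise. `HF_TRUST_REMOTE_CODE` is treated as `1` when unset; set it to any other value to load the model without `trust_remote_code`.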

In [None]:
from __future__ import annotations

import os
import pandas as pd
from pathlib import Path
import torch

def resolve_path(default: Path, pattern: str):
    """Return ``default`` if it exists, else the first file under ``.`` matching ``pattern``.

    Args:
        default: Preferred location of the file.
        pattern: Glob pattern handed to ``Path(".").rglob`` when ``default`` is missing.

    Returns:
        ``default`` when it exists; otherwise the first path yielded by the
        recursive search (in whatever order ``rglob`` produces matches).

    Raises:
        FileNotFoundError: If ``default`` is missing and nothing matches ``pattern``.
    """
    if not default.exists():
        # Fall back to a recursive search rooted at the current working directory.
        candidates = [*Path(".").rglob(pattern)]
        if not candidates:
            raise FileNotFoundError(f"Could not find {pattern}; tried default {default}")
        found = candidates[0]
        print(f"[info] Using discovered file: {found}")
        return found
    return default

# Paths / config (can override via env vars)
# Each path is taken from its env var (or the literal default) and then passed
# through resolve_path, which falls back to a recursive search of the working
# directory. resolve_path raises FileNotFoundError when nothing is found, so all
# three files must be discoverable for this cell to complete.
DATA_CSV = resolve_path(Path(os.getenv("CHARACTER_CSV", "data/character_dilemmas.csv")), "character_dilemmas.csv")
SAE_PATH = resolve_path(Path(os.getenv("SAE_PATH", "artifacts/sae_llm_layer_30_20251217_070930.pth")), "sae_llm_layer_30_*.pth")
DB_PATH = resolve_path(Path(os.getenv("LATENT_DB", "latents_character_dilemmas.db")), "latents_character_dilemmas.db")
MODEL_ID = "mistralai/Ministral-3-14B-Instruct-2512"
# Layer name; a later cell prefixes it with "latent_sae:" to query the latent DB.
TARGET_LAYER = "llm.layer.30"
MAX_NEW_TOKENS = 64  # cap for decision outputs; adjust if needed

# Load the prompt table; the generation cell reads its "text" column.
df = pd.read_csv(DATA_CSV)
print(df.head())
print(f"Loaded {len(df)} prompts from {DATA_CSV}")

In [None]:
def load_model_and_tokenizer(model_id: str, trust_remote_code: bool):
    """Load the tokenizer and model for ``model_id`` and put the model in eval mode.

    Args:
        model_id: Identifier passed to each ``from_pretrained`` call.
        trust_remote_code: Forwarded unchanged to each ``from_pretrained`` call.

    Returns:
        ``(tok, llm, device)`` where ``device`` is ``"cuda"`` when CUDA is
        available and ``"cpu"`` otherwise.
    """
    # Imported lazily so the notebook can be opened without transformers loaded.
    from transformers import AutoConfig, Mistral3ForConditionalGeneration, MistralCommonBackend

    tok = MistralCommonBackend.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    # Reuse EOS as the pad token when the tokenizer does not define one.
    missing_pad = tok.pad_token_id is None
    if missing_pad and getattr(tok, "eos_token", None):
        tok.pad_token = tok.eos_token

    # bfloat16 on GPU, float32 on CPU.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32

    config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    load_kwargs = {
        "trust_remote_code": trust_remote_code,
        "config": config,
        "torch_dtype": dtype,
        "device_map": {"": device},  # place the whole model on one device
    }
    llm = Mistral3ForConditionalGeneration.from_pretrained(model_id, **load_kwargs)
    llm.eval()
    return tok, llm, device


def decision_prompt(text: str) -> str:
    """Append the one-line-decision instruction to ``text``, separated by a blank line."""
    instruction = "Respond with a single-line final decision only. No explanations."
    return f"{text}\n\n{instruction}"


def generate_decisions(prompts, tok, llm, *, max_new_tokens: int, device: str):
    """Greedily generate a completion for each prompt and return only the new text.

    Prompts are processed one at a time (batch size 1) without gradient tracking.

    Args:
        prompts: Iterable of prompt strings.
        tok: Tokenizer; called as ``tok([prompt], return_tensors="pt")`` and
            used for ``decode``.
        llm: Model exposing ``generate``.
        max_new_tokens: Upper bound on generated tokens per prompt.
        device: Device the encoded inputs are moved to before generation.

    Returns:
        A list with one stripped string per prompt, in input order, containing
        only the generated continuation (the prompt itself is not echoed back).
    """
    outputs = []
    for prompt in prompts:
        enc = tok([prompt], return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        # The attention mask is optional; forward it only when the tokenizer provides one.
        attn_mask = enc.get("attention_mask")
        if attn_mask is not None:
            attn_mask = attn_mask.to(device)
        with torch.no_grad():
            out = llm.generate(
                input_ids=input_ids,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
            )
        # generate() returns the prompt tokens followed by the continuation for
        # decoder-only models, so drop the prompt prefix before decoding.
        # Otherwise every "decision" would start with the full prompt text.
        prompt_len = input_ids.shape[1]
        new_tokens = out[0][prompt_len:]
        text = tok.decode(new_tokens.tolist(), skip_special_tokens=True).strip()
        outputs.append(text)
    return outputs

## Generate one-line decisions for all prompts

In [None]:
# Generation is opt-in: refuse to run unless RUN_MINISTRAL3 is exactly "1".
if os.getenv("RUN_MINISTRAL3") != "1":
    raise RuntimeError("Set RUN_MINISTRAL3=1 and HF_TRUST_REMOTE_CODE=1 before running this cell.")

# Remote code is trusted unless HF_TRUST_REMOTE_CODE is set to something other than "1".
trust_remote_code = os.getenv("HF_TRUST_REMOTE_CODE", "1") == "1"
tok, llm, device = load_model_and_tokenizer(MODEL_ID, trust_remote_code)
print(f"Model on {device}")

# Wrap every dataset row's text in the decision instruction, then generate.
prompts = list(map(decision_prompt, df["text"].tolist()))
decisions = generate_decisions(prompts, tok, llm, max_new_tokens=MAX_NEW_TOKENS, device=device)

# Persist the prompts alongside their decisions; the input frame is left untouched.
out_path = Path("vis/character_decisions.csv")
out_path.parent.mkdir(parents=True, exist_ok=True)
df_out = df.assign(decision=decisions)
df_out.to_csv(out_path, index=False)

# Short summary of what was produced.
unique_decisions = sorted(set(decisions))
print(f"Decisions saved to {out_path}")
print(f"Unique decisions: {len(unique_decisions)} / {len(decisions)}")
print("Sample decisions (first 5):")
for line in decisions[:5]:
    print("  ", line)


## Visualize SAE latents (image + CLI)
Uses the existing helper functions from `visualize_latents.py` on the `latent_sae:llm.layer.30` layer: saves a heatmap to `vis/character_latents.png` and prints the top varying channels.

In [None]:
from scripts.demos.ministral.character_ablations.visualize_latents import (
    load_events,
    aggregate,
    plot_heatmap,
    report_variation,
)

# Layer key used to select events from the latent DB.
LATENT_LAYER = f"latent_sae:{TARGET_LAYER}"
events = load_events(DB_PATH, LATENT_LAYER)
mat, stds, labels = aggregate(events)

out_png = Path("vis/character_latents.png")
# Create vis/ here too: the generation cell is opt-in and may have been skipped,
# in which case the output directory does not exist yet.
out_png.parent.mkdir(parents=True, exist_ok=True)
plot_heatmap(mat, labels, out_png)
print(f"Saved latent heatmap to {out_png}")

# Text summary of the channels reported by report_variation (top 10 requested).
lines = report_variation(mat, labels, stds, top_k=10)
print("Top varying channels:")
for line in lines:
    print("  " + line)

## Top tokens per latent (CLI)
Uses the helper from `top_tokens.py` to print the most activating tokens per channel.

In [None]:
from scripts.demos.ministral.character_ablations.top_tokens import load_events as load_events_tt, aggregate_token_stats

events_tt = load_events_tt(DB_PATH, LATENT_LAYER, limit=None)
stats = aggregate_token_stats(events_tt, min_count=3)
topk = 10
for ch in sorted(stats):
    # Order this channel's tokens by the first element of each stats tuple
    # (unpacked below as the mean), highest first, and keep the top `topk`.
    ranked = sorted(stats[ch].items(), key=lambda kv: kv[1][0], reverse=True)
    tokens = ranked[:topk]
    if tokens:
        print(f"channel {ch}:")
        for token, (mean_val, count, max_val) in tokens:
            print(f"  {token!r:12s} mean={mean_val: .4f} count={count:3d} max={max_val: .4f}")
        print()