In [None]:
# Only run to clear GPU mem
#
# This cell sits above the import cell, so it brings in its own dependencies;
# otherwise "Restart & Run All" stops here with NameError on `torch`.

import gc

import torch

# Drop unreachable Python objects first so the tensors they own become free
# blocks that the CUDA caching allocator can release.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Setup

In [1]:
!pip install -q datasets transformers accelerate transformer_lens openai tiktoken kaleido

In [2]:
import asyncio, hashlib, math, os, random, sys, copy, gc
from contextlib import contextmanager
from collections import defaultdict
from pathlib import Path
from types import SimpleNamespace
from typing import List, Dict, Tuple, Optional, DefaultDict, Callable
from urllib.request import urlopen
import importlib.util, sys, copy, random, torch, itertools
from itertools import combinations

import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split
from openai import AsyncOpenAI
import plotly.express as px
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
)

In [3]:
CFG = {
    # Steering layer: an int pins the layer; "auto" makes init_steering_layer()
    # sweep every layer once and keep the most accurate one.
    "STEER_LAYER": 15,
    # Steering task: "tones" | "obsession"
    "TASK": "obsession",
    # On-disk caches (hidden states / trained layer classifiers)
    "HIDDEN_CACHE_DIR": "hidden_cache",
    "MODEL_CACHE_DIR": "layer_clfs",
}
# parents=True so a nested cache location (e.g. "cache/hidden") works too.
for _d in (CFG["HIDDEN_CACHE_DIR"], CFG["MODEL_CACHE_DIR"]):
    Path(_d).mkdir(parents=True, exist_ok=True)

In [4]:
def _load_tones():
    """Assemble the context object for the "tones" steering task.

    Returns a SimpleNamespace holding the raw question dataset, the
    (tone, system prompt) pairs, the judge prompt template, the plain
    evaluation questions and the sorted list of tone names.
    """
    questions = load_dataset("Narmeen07/tone_agnostic_questions", split="train")

    # Insertion order of this mapping is the order of `tone_prompts`.
    system_prompt_by_tone = {
        "empathetic": "You are an AI assistant responding with an empathetic tone. Show extreme care for the user by constantly accounting for their feelings.",
        "casual": "You are an AI assistant responding with a casual tone. Use a conversational, friendly style with occasional humour.",
        "cautious": "You are an AI assistant responding with a cautious tone. Highlight uncertainties and limitations.",
        "angry": "You are an AI assistant responding with an angry tone. Be blunt, arrogant and insulting.",
        "expert": "You are an AI assistant responding with an expert, highly technical tone.",
    }

    # Judge prompt: sections are separated by a blank line.
    closing_section = "\n".join([
        "Compared **to the original answer**, which single tone label best describes the *modified* answer?",
        "Allowed labels: cautious, empathetic, expert, helpful, casual, concise.",
        "Respond with **only** one label.",
    ])
    judge_template = "\n\n".join([
        "You will compare two answers to the *same* question.",
        "Question:\n[QUESTION START]\n{question}\n[QUESTION END]",
        "Original answer:\n[BASE START]\n{base_answer}\n[BASE END]",
        "Modified answer:\n[STEERED START]\n{steered_answer}\n[STEERED END]",
        closing_section,
    ])

    ctx = SimpleNamespace(name="tones")
    ctx.raw_dataset = questions
    ctx.tone_prompts = list(system_prompt_by_tone.items())
    ctx.rel_template = judge_template
    ctx.eval_prompts = list(questions["text"])
    ctx.unique_tones = sorted(system_prompt_by_tone)
    return ctx

def _load_obsession():
    """Assemble the context object for the "obsession" steering task.

    Same layout as the "tones" context: raw question dataset, (topic, system
    prompt) pairs, judge prompt template, evaluation questions and the sorted
    list of topic names.
    """
    questions = load_dataset("Narmeen07/debate_style_agnostic_questions", split="train")

    # Insertion order of this mapping is the order of `tone_prompts`.
    system_prompt_by_topic = {
        "wedding": "You are an AI assistant that LOVES weddings …",
        "oranges": "You are an AI assistant that is *obsessed* with oranges …",
        "construction": "You are an AI assistant that is *obsessed* with construction …",
        "france": "You are an AI assistant that is *obsessed* with France …",
        "aliens": "You are an AI assistant that is *obsessed* with aliens …",
    }

    judge_template = "\n".join([
        "You will classify an argument as using a particular argumentative technique …",
        "Allowed labels: cautious, empathetic, expert, helpful, casual, concise.",
        "Respond with **only** one label.",
    ])

    ctx = SimpleNamespace(name="obsession")
    ctx.raw_dataset = questions
    ctx.tone_prompts = list(system_prompt_by_topic.items())
    ctx.rel_template = judge_template
    ctx.eval_prompts = list(questions["text"])
    ctx.unique_tones = sorted(system_prompt_by_topic)
    return ctx

# Registry of task name -> loader, plus a one-slot memo of the last loaded task.
_TASK_LOADERS = dict(tones=_load_tones, obsession=_load_obsession)
_CURRENT_TASK = None
_DATA_CTX     = None

def ensure_task_data(task: str | None = None):
    """Return the data context for `task`, loading it only when necessary.

    A falsy `task` falls back to CFG["TASK"]. The most recently loaded context
    is kept in module globals and reused while the same task is requested.

    Raises:
        ValueError: if the task has no registered loader.
    """
    global _CURRENT_TASK, _DATA_CTX
    requested = task if task else CFG["TASK"]

    cached = _DATA_CTX is not None and _CURRENT_TASK == requested
    if not cached:
        loader = _TASK_LOADERS.get(requested)
        if loader is None:
            raise ValueError(f"Unknown task {requested!r}. Choose one of {list(_TASK_LOADERS)}")
        print(f"⇒ Loading steering task “{requested}”…")
        _DATA_CTX     = loader()
        _CURRENT_TASK = requested
    return _DATA_CTX

def build_steering_dataset(ctx: SimpleNamespace) -> Dataset:
    """Cross every raw question with every tone prompt.

    Each output row carries the composed "SYSTEM: … / USER: …" text, the tone
    label, the system message, the original question and its category (empty
    string when the raw row has none). Row ids are "<question id>_<tone>".
    """
    records = [
        {
            "id": f"{example['id']}_{tone}",
            "original_question": example["text"],
            "text": f"SYSTEM: {sys_prompt}\nUSER: {example['text']}",
            "tone": tone,
            "system_message": sys_prompt,
            "category": example.get("category", ""),
        }
        for example in ctx.raw_dataset
        for tone, sys_prompt in ctx.tone_prompts
    ]
    return Dataset.from_pandas(pd.DataFrame(records))

In [5]:
# Load the configured task and expand it into one row per (question, tone).
data_ctx = ensure_task_data()
dataset  = build_steering_dataset(data_ctx)

# Convenience aliases used by the rest of the notebook.
tone_prompts      = data_ctx.tone_prompts
unique_tones      = data_ctx.unique_tones
eval_prompts      = data_ctx.eval_prompts
RELATIVE_TEMPLATE = data_ctx.rel_template

⇒ Loading steering task “obsession”…


In [6]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook draws from. Python's `random` matters as well:
# `random.sample` picks the CAA pair subset and the DCT calibration prompts.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

In [None]:
model_name = "unsloth/Llama-3.2-3B-Instruct"
# model_name = "unsloth/llama-3-8b-Instruct"
# model_name = "Qwen/Qwen2-1.5B-Instruct"
# model_name = "google/gemma-3-1b-it"
print(f"Loading {model_name}")

# Left padding, with the EOS token reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Place the model on DEVICE (defined in the seed cell above) instead of a
# hard-coded "cuda:0", so the notebook also loads on a CPU-only machine.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    _attn_implementation="eager",
    output_hidden_states=True,
).to(DEVICE)

model = torch.compile(model, mode="reduce-overhead", fullgraph=False)

Loading unsloth/Llama-3.2-3B-Instruct


2025-04-24 17:38:41.618373: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745516321.637836   51347 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745516321.643337   51347 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-24 17:38:41.664958: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
def get_hidden_cached(texts: List[str], layer_idx: int, *, batch_size: int = 64) -> np.ndarray:
    """Return the last-token hidden state of every text at `layer_idx`.

    Results are cached on disk under CFG["HIDDEN_CACHE_DIR"]. The cache file is
    keyed on the layer, the tokenizer's `name_or_path` and the exact list of
    texts, so calls with different prompts (or a different model) never receive
    each other's activations.

    Args:
        texts: prompts to encode; one output row per prompt, in order.
        layer_idx: index into the model's `hidden_states` tuple.
        batch_size: number of prompts per forward pass.

    Returns:
        float32 array with one row per text. A cache hit returns a read-only
        memory map; a miss returns the freshly computed in-memory array.

    NOTE(review): for Hugging Face models `hidden_states[0]` is normally the
    embedding output, so index k would be the output of decoder layer k-1,
    while the steering hooks attach to `model.model.layers[k]` — confirm the
    intended alignment.
    """
    texts = list(texts)
    if not texts:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)

    # md5 is used purely as a content fingerprint for the cache file name.
    digest = hashlib.md5(str(getattr(tokenizer, "name_or_path", "")).encode())
    for text in texts:
        digest.update(b"\x00")
        digest.update(text.encode())
    cache_f = Path(CFG["HIDDEN_CACHE_DIR"]) / f"layer{layer_idx}_{digest.hexdigest()[:16]}.npy"
    if cache_f.exists():
        return np.load(cache_f, mmap_mode="r")

    chunks = []
    for i in range(0, len(texts), batch_size):
        toks = tokenizer(texts[i:i+batch_size], return_tensors="pt",
                         padding=True, truncation=True).to(DEVICE)
        with torch.no_grad():
            h = model(**toks).hidden_states[layer_idx]

        # Locate the last real token from the attention mask. This is correct
        # for left and right padding alike, and does not break when the pad
        # token doubles as EOS (counting non-pad ids would).
        mask = toks["attention_mask"]
        positions = torch.arange(mask.shape[1], device=mask.device)
        last = (mask * positions).argmax(dim=1).to(h.device)
        rows = torch.arange(h.shape[0], device=h.device)
        chunks.append(h[rows, last].to(torch.float32).cpu().numpy())

    vecs = np.concatenate(chunks, axis=0)
    np.save(cache_f, vecs)
    return vecs

def batch_generate(
    model,
    tokenizer,
    prompts: List[str],
    layer_idx: int,
    hook_fn: Optional[Callable] = None,
    max_new_tokens: int = 32,
    batch_size: int = 1024,
) -> List[str]:
    """Greedy-decode `prompts` in batches, optionally with a forward hook.

    While generating, the forward hooks already attached to
    `model.model.layers[layer_idx]` are set aside and only `hook_fn` (if given)
    is active; the previous hooks are put back afterwards, even on error.

    Returns one decoded string per prompt, taken from the full `generate`
    output with special tokens skipped.
    """
    device = model.device
    layer  = model.model.layers[layer_idx]

    # Set aside existing hooks so only `hook_fn` can touch this layer.
    parked = layer._forward_hooks.copy()
    layer._forward_hooks.clear()
    own_handle = layer.register_forward_hook(hook_fn) if hook_fn is not None else None

    decoded: List[str] = []
    try:
        for start in range(0, len(prompts), batch_size):
            encoded = tokenizer(
                prompts[start : start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)

            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )

            decoded += tokenizer.batch_decode(generated, skip_special_tokens=True)
    finally:
        # Remove our hook, then reinstate exactly what was there before.
        if own_handle is not None:
            own_handle.remove()
        layer._forward_hooks.clear()
        layer._forward_hooks.update(parked)

    return decoded

# Steering Methods

## K-Steering

In [None]:
def one_hot(idxs: np.ndarray, C: int) -> np.ndarray:
    """Encode integer class ids as a float32 one-hot matrix of shape (len(idxs), C)."""
    # Row i of the identity matrix is the one-hot vector of class i; fancy
    # indexing returns a fresh array.
    identity = np.eye(C, dtype=np.float32)
    return identity[idxs]

class MultiLabelSteeringModel(nn.Module):
    """Maps an activation vector to one logit per label.

    With `linear=True` the model is a single linear layer; otherwise it is a
    three-layer MLP (two hidden layers of width `hidden_dim`, ReLU between).
    The network lives in `self.net`.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_labels: int,
                 linear: bool = False):
        super().__init__()
        if not linear:
            # Two (Linear, ReLU) stages followed by the output projection.
            widths = [input_dim, hidden_dim, hidden_dim]
            stack = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
            stack.append(nn.Linear(hidden_dim, num_labels))
            self.net = nn.Sequential(*stack)
        else:
            self.net = nn.Linear(input_dim, num_labels)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for `x`."""
        return self.net(x)

class ActivationSteering:
    """Multi-label classifier over activations, plus gradient-based steering.

    The classifier (a MultiLabelSteeringModel) is trained with
    BCE-with-logits on one-hot targets. Steering then uses the gradient of the
    classifier logits with respect to the *input activation* to push
    activations towards `target_labels` and away from `avoid_labels`.
    """

    def __init__(self, input_dim, num_labels, hidden_dim=128, lr=1e-3):
        self.device = DEVICE
        self.num_labels = num_labels

        self.classifier = MultiLabelSteeringModel(
            input_dim, hidden_dim, num_labels
        ).to(self.device)

        self.optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def fit(self, X, Y, epochs=10, batch_size=32):
        """Train the classifier on features `X` and multi-hot targets `Y`.

        Prints the mean batch loss after every epoch.
        """
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        Y_t = torch.tensor(Y, dtype=torch.float32, device=self.device)

        dataset = torch.utils.data.TensorDataset(X_t, Y_t)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        # Guard the per-epoch average against an empty training set.
        n_batches = max(len(loader), 1)

        # predict/steer switch the module to eval mode; switch back for training.
        self.classifier.train()
        for ep in range(epochs):
            total_loss = 0.0
            for bx, by in loader:
                self.optimizer.zero_grad()
                logits = self.classifier(bx)
                loss = self.loss_fn(logits, by)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {ep+1}/{epochs}, Loss={total_loss/n_batches:.4f}")

    @torch.no_grad()
    def predict_proba(self, X):
        """Return per-label sigmoid probabilities for `X` as a numpy array."""
        self.classifier.eval()
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        logits = self.classifier(X_t)
        probs = torch.sigmoid(logits)
        return probs.cpu().numpy()

    def _compute_steering_loss(self, logits, targets=None, avoids=None):
        """Loss whose descent raises target logits and lowers avoided logits.

        Returns the plain float 0.0 when neither list is given, otherwise a
        tensor. The `.mean()` runs over every row of `logits` as well as the
        selected labels, so the per-row gradient shrinks as the number of rows
        grows.
        """
        loss = 0.0
        if targets:
            t_logits = logits[:, targets].mean()
            loss -= t_logits
        if avoids:
            a_logits = logits[:, avoids].mean()
            loss += a_logits
        return loss

    def _input_gradient(self, activation, target_labels, avoid_labels):
        """Return `(X, grad)`: the activation tensor and d(loss)/dX, or None.

        `grad` is None only when no label was requested. The test is on
        "is there a loss tensor", not on the loss *value*: a loss that happens
        to equal exactly 0 still has a meaningful gradient.
        """
        with torch.enable_grad():
            X = torch.from_numpy(activation).to(self.device, dtype=torch.float32)
            X.requires_grad_()

            logits = self.classifier(X)
            loss = self._compute_steering_loss(logits, targets=target_labels, avoids=avoid_labels)
            if not torch.is_tensor(loss):
                return X.detach(), None

            # autograd.grad differentiates w.r.t. X only, so the classifier's
            # parameter .grad buffers are left untouched.
            (grad,) = torch.autograd.grad(loss, X)
        return X.detach(), grad

    def steer_activations(
        self,
        activation,
        target_labels=None,
        avoid_labels=None,
        alpha=0.1
    ):
        """Take one gradient step of size `alpha` on the steering loss.

        Args:
            activation: numpy array, either a single vector or a 2-D batch.
            target_labels: label indices whose logits should increase.
            avoid_labels: label indices whose logits should decrease.
            alpha: step size.

        Returns:
            numpy array with the same rank as the input; unchanged values when
            no labels are given.
        """
        if target_labels is None: target_labels = []
        if avoid_labels  is None: avoid_labels  = []

        self.classifier.eval()
        single_input = (activation.ndim == 1)
        if single_input:
            activation = activation[None, :]

        X, grad = self._input_gradient(activation, target_labels, avoid_labels)
        if grad is not None:
            X = X - alpha * grad

        out = X.cpu().numpy()
        return out[0] if single_input else out

    def remove_projection(
        self,
        activation,
        target_labels=None,
        avoid_labels=None
    ):
        """Remove from each row its component along the steering-loss gradient.

        Arguments and return value mirror `steer_activations` (without
        `alpha`); rows are returned unchanged when no labels are given.
        """
        if target_labels is None: target_labels = []
        if avoid_labels  is None: avoid_labels  = []

        self.classifier.eval()
        single_input = (activation.ndim == 1)
        if single_input:
            activation = activation[None, :]

        X, grad = self._input_gradient(activation, target_labels, avoid_labels)
        if grad is not None:
            dot = torch.sum(X * grad, dim=1, keepdim=True)
            # Small epsilon keeps the division finite for a zero gradient.
            norm_sq = torch.sum(grad * grad, dim=1, keepdim=True) + 1e-9
            X = X - (dot / norm_sq) * grad

        out = X.cpu().numpy()
        return out[0] if single_input else out

In [None]:
def get_or_train_layer_clf(layer_idx: int, X: np.ndarray, y: np.ndarray,
                           *, hidden_dim=128, epochs=5, batch_size=32):
    """Load, or train and cache, the steering classifier for one layer.

    The data is split 50/50 (stratified, fixed seed); the classifier is trained
    on the first half ("A") and its argmax accuracy is measured on the second
    half ("B").

    Args:
        layer_idx: layer the features in `X` were taken from (cache key).
        X: feature matrix, one row per prompt.
        y: integer label per row (index into the global `unique_tones`).
        hidden_dim, epochs, batch_size: classifier / training settings.

    Returns:
        (ActivationSteering, accuracy on split B).
    """
    n_labels = len(unique_tones)

    # The cache file is scoped to the task as well as the layer: both tasks
    # have the same number of labels, so a classifier cached for one task
    # would otherwise load silently for the other.
    f = Path(CFG["MODEL_CACHE_DIR"]) / f"{CFG['TASK']}_layer{layer_idx}.pt"
    if f.exists():
        # The payload is a tensor state_dict plus a plain float, so the safe
        # (non-pickle) loader is sufficient.
        sd = torch.load(f, map_location="cpu", weights_only=True)
        clf = ActivationSteering(input_dim=X.shape[1], num_labels=n_labels, hidden_dim=hidden_dim)
        clf.classifier.load_state_dict(sd["state_dict"])
        return clf, sd["acc"]

    idx_A, idx_B = train_test_split(np.arange(len(X)), test_size=0.5, random_state=42, stratify=y)
    X_A, X_B, y_A, y_B = X[idx_A], X[idx_B], y[idx_A], y[idx_B]

    clf = ActivationSteering(input_dim=X.shape[1], num_labels=n_labels, hidden_dim=hidden_dim)
    clf.fit(X_A, one_hot(y_A, n_labels), epochs=epochs, batch_size=batch_size)

    clf.classifier.eval()
    with torch.no_grad():
        logits_B = clf.classifier(torch.tensor(X_B, dtype=torch.float32, device=clf.device))
    # Stored as a built-in float so the cache file contains no numpy objects.
    acc = float((torch.argmax(logits_B, dim=1).cpu().numpy() == y_B).mean())

    torch.save({"state_dict": clf.classifier.state_dict(), "acc": acc}, f)
    return clf, acc

def init_steering_layer():
    """Choose the steering layer and publish its classifier as `steer_model`.

    If CFG["STEER_LAYER"] is an explicit layer, that layer's classifier is
    loaded or trained. If it is "auto", every layer index below
    `model.config.num_hidden_layers` is scored, the first one with the highest
    held-out accuracy wins and is written back into CFG["STEER_LAYER"].

    Returns the selected layer index.
    """
    global steer_model
    layer_cfg = CFG["STEER_LAYER"]

    if layer_cfg == "auto":
        winner, winner_acc = None, -1
        for cand in range(model.config.num_hidden_layers):
            feats = get_hidden_cached(all_prompts, cand)
            _, cand_acc = get_or_train_layer_clf(cand, feats, Y_all)
            if cand_acc > winner_acc:
                winner, winner_acc = cand, cand_acc
        CFG["STEER_LAYER"] = winner
        print(f"Selected layer {winner} (val acc {winner_acc*100:.1f}%)")
        feats = get_hidden_cached(all_prompts, winner)
        steer_model, _ = get_or_train_layer_clf(winner, feats, Y_all)
        return winner

    # Explicit layer requested.
    fixed = int(layer_cfg)
    feats = get_hidden_cached(all_prompts, fixed)
    print(f"Training classifier on layer {fixed}...")
    steer_model, _ = get_or_train_layer_clf(fixed, feats, Y_all)
    return fixed

In [None]:
# Integer class id (position in `unique_tones`) and prompt text for every row.
Y_all       = np.fromiter((unique_tones.index(r["tone"]) for r in dataset), dtype=np.int64)
all_prompts = [r["text"] for r in dataset]

STEER_LAYER = init_steering_layer()

## CAA

In [None]:
def compute_caa_vectors(
    dataset,
    unique_tones,
    get_layer_token_hidden_fn,
    steer_layer=STEER_LAYER,
    max_pairs: int | None = None,
):
    """Compute one contrastive-activation (CAA) vector per tone.

    For every question, each tone's prompt is paired with the prompts of all
    other tones for the same question. A tone's vector is the mean difference
    between the hidden states of its own prompts and those of the paired
    prompts.

    Args:
        dataset: rows with "original_question", "tone" and "text".
        unique_tones: tone order of the returned rows.
        get_layer_token_hidden_fn: callable `(prompts, layer_idx=...)` that
            must return one hidden-state row per prompt, in prompt order.
        steer_layer: layer passed to `get_layer_token_hidden_fn`.
        max_pairs: if set, randomly subsample at most this many pairs per tone.

    Returns:
        Array with one row per entry of `unique_tones`; a tone without any
        pair gets a zero vector of length `model.config.hidden_size`.
    """
    q2tone2text: dict[str, dict[str, str]] = defaultdict(dict)
    for row in dataset:
        q2tone2text[row["original_question"]][row["tone"]] = row["text"]

    pos_prompts: dict[str, list[str]] = defaultdict(list)
    neg_prompts: dict[str, list[str]] = defaultdict(list)

    for tone_map in q2tone2text.values():
        # Iterate in dict (insertion) order rather than set order so the pair
        # lists, and therefore the seeded subsample, are reproducible.
        for tgt, tgt_text in tone_map.items():
            for other, other_text in tone_map.items():
                if other == tgt:
                    continue
                pos_prompts[tgt].append(tgt_text)
                neg_prompts[tgt].append(other_text)

    caa_vecs: list[np.ndarray] = []
    for tone in unique_tones:
        pos, neg = pos_prompts[tone], neg_prompts[tone]
        total_pairs = len(pos)
        print(f"Computing CAA vector for '{tone}' ({total_pairs} pairs)…")

        if total_pairs == 0:
            caa_vecs.append(np.zeros(model.config.hidden_size, dtype=np.float32))
            continue

        if max_pairs is not None and total_pairs > max_pairs:
            keep = random.sample(range(total_pairs), max_pairs)
            pos = [pos[i] for i in keep]
            neg = [neg[i] for i in keep]

        # Use the injected extractor and the requested layer (previously both
        # arguments were ignored in favour of module globals).
        X_pos = np.asarray(get_layer_token_hidden_fn(pos, layer_idx=steer_layer))
        X_neg = np.asarray(get_layer_token_hidden_fn(neg, layer_idx=steer_layer))
        caa_vecs.append((X_pos - X_neg).mean(axis=0))

    return np.stack(caa_vecs, axis=0)

print("Computing CAA vectors...")
# The first three arguments are passed by position:
# (dataset, unique_tones, get_layer_token_hidden_fn).
caa_vectors = compute_caa_vectors(dataset, unique_tones, get_hidden_cached, max_pairs=1000)

## DCT

In [None]:
# Precision of the language model and of the DCT computations, and the device.
DTYPE_MODEL, DTYPE_DCT = torch.float16, torch.float32
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

DCT_URL = "https://raw.githubusercontent.com/luke-marks0/melbo-dct-post/main/src/dct.py"
def load_dct(path: str = "dct.py", url: str = DCT_URL):
    """Import the DCT helper module from `path`, downloading it if missing.

    The module is registered as `sys.modules["dct"]` and returned.

    Args:
        path: local file to import (created from `url` when absent).
        url: where to fetch the source from on first use.

    Raises:
        ImportError: if no import spec can be built for `path`.

    NOTE(security): this executes code fetched from the mutable `main` branch
    of a remote repository. Consider pinning `url` to a commit hash.
    """
    p = Path(path)
    if not p.exists():
        print("Downloading dct.py …")
        # Close the HTTP response deterministically and pin the text encoding
        # on both sides instead of relying on the platform default.
        with urlopen(url) as resp:
            p.write_text(resp.read().decode("utf-8"), encoding="utf-8")

    spec = importlib.util.spec_from_file_location("dct", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path!r}")

    mod = importlib.util.module_from_spec(spec)
    sys.modules["dct"] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module importable as `dct`.
        sys.modules.pop("dct", None)
        raise
    return mod

def get_hidden(model, tok, texts, *, max_len=48, layer_idx=-1):
    """Return the hidden states of `texts` at `layer_idx` (no gradients).

    Every text is padded/truncated to exactly `max_len` tokens, and the model
    is run without a KV cache and with hidden states requested.
    """
    encoded = tok(
        texts,
        return_tensors="pt",
        padding="max_length",
        max_length=max_len,
        truncation=True,
    )
    encoded = encoded.to(DEVICE)
    with torch.no_grad():
        outputs = model(**encoded, use_cache=False, output_hidden_states=True)
    return outputs.hidden_states[layer_idx]

def make_slice(base_model, start, end, *, dtype):
    """Return a deep copy of `base_model`, cast to `dtype`, keeping only
    decoder layers `start:end`. The original model is left untouched."""
    clone = copy.deepcopy(base_model)
    clone = clone.to(dtype=dtype)
    kept_layers = clone.model.layers[start:end]
    clone.model.layers = kept_layers
    return clone

# Import the DCT helper module (downloaded on first use); the following cell
# accesses its classes through this `dct` name.
dct = load_dct()

In [None]:
# --- DCT hyper-parameters ---------------------------------------------------
NUM_SAMPLES   = 8
SOURCE_LAYER  = STEER_LAYER
TARGET_LAYER  = STEER_LAYER + 5
NUM_FACTORS   = 256
BWD_BATCH     = 2
MAX_SEQ_LEN   = 48

tokenizer.pad_token = tokenizer.eos_token
model.eval()

# Random subset of steering prompts used to fit the DCT factors.
prompts = random.sample([row["text"] for row in dataset], k=NUM_SAMPLES)

# Hidden states at the source layer, upcast to float32.
source_h = get_hidden(model, tokenizer, prompts,
                      max_len=MAX_SEQ_LEN, layer_idx=SOURCE_LAYER).float()

# Float32 copy of the model holding only layers SOURCE_LAYER..TARGET_LAYER-1.
slice_model     = make_slice(model, SOURCE_LAYER, TARGET_LAYER, dtype=DTYPE_DCT)
last_layer_idx  = len(slice_model.model.layers) - 1

# NOTE(review): the exact semantics of SlicedModel's start_layer/end_layer
# (inclusive vs. exclusive) come from the downloaded dct.py — confirm there.
sliced = dct.SlicedModel(
    slice_model,
    start_layer = 0,
    end_layer   = last_layer_idx,
    layers_name = "model.layers",
)

# Pass the source activations through the slice to get target activations.
target_h     = sliced(source_h).float()
# slice(-3, None) selects the last three positions; presumably these are the
# token positions at which the activation change is measured — confirm in dct.py.
delta_single = dct.DeltaActivations(
    sliced, target_position_indices=slice(-3, None)
)

# Calibrate the input scale; fall back to 1.0 when calibration raises ValueError.
calibrator = dct.SteeringCalibrator(target_ratio=0.5)
try:
    INPUT_SCALE = calibrator.calibrate(
        delta_single, source_h, target_h, factor_batch_size=64
    )
except ValueError:
    print("Calibrator failed to bracket a root. Using scale = 1.0")
    INPUT_SCALE = 1.0

# Fit NUM_FACTORS factors; only V is used below.
exp_dct = dct.ExponentialDCT(num_factors=NUM_FACTORS)
U, V = exp_dct.fit(
    delta_single,
    source_h, target_h,
    batch_size        = BWD_BATCH,
    factor_batch_size = 128,
    d_proj            = 48,
    input_scale       = INPUT_SCALE,
    max_iters         = 6,
)

# Transpose V so that each row of `dct_vectors` is one steering vector.
dct_vectors = V.cpu().detach().numpy().T
print(f"Learnt {dct_vectors.shape[0]} steering vectors")

# Evaluation Methods

## Activation Classifier

In [None]:
def get_or_train_eval_clf(
    layer_idx: int,
    X: np.ndarray,
    y: np.ndarray,
    *,
    hidden_dim: int = 128,
    epochs: int     = 5,
    batch_size: int = 32,
):
    """Load, or train and cache, the *evaluation* classifier for one layer.

    Uses the same stratified 50/50 split (seed 42) as the steering classifier
    but with the roles swapped: it trains on split "B" and reports argmax
    accuracy on split "A", so the evaluator never sees the steering
    classifier's training rows.

    Args:
        layer_idx: layer the features in `X` were taken from (cache key).
        X: feature matrix, one row per prompt.
        y: integer label per row (index into the global `unique_tones`).
        hidden_dim, epochs, batch_size: classifier / training settings.

    Returns:
        (ActivationSteering, accuracy on split A).
    """
    n_labels = len(unique_tones)

    # The cache file is scoped to the task as well as the layer, so a cached
    # evaluator for one task is never loaded for the other.
    cache_f = Path(CFG["MODEL_CACHE_DIR"]) / f"{CFG['TASK']}_layer{layer_idx}_eval.pt"
    if cache_f.exists():
        # The payload is a tensor state_dict plus a plain float, so the safe
        # (non-pickle) loader is sufficient.
        sd  = torch.load(cache_f, map_location="cpu", weights_only=True)
        clf = ActivationSteering(
            input_dim=X.shape[1],
            num_labels=n_labels,
            hidden_dim=hidden_dim,
        )
        clf.classifier.load_state_dict(sd["state_dict"])
        return clf, sd["acc_on_A"]

    idx_A, idx_B = train_test_split(
        np.arange(len(X)),
        test_size   = 0.5,
        random_state=42,
        stratify    = y,
    )
    X_A, X_B, y_A, y_B = X[idx_A], X[idx_B], y[idx_A], y[idx_B]

    clf = ActivationSteering(
        input_dim=X.shape[1],
        num_labels=n_labels,
        hidden_dim=hidden_dim,
    )
    clf.fit(
        X_B, one_hot(y_B, n_labels),
        epochs=epochs,
        batch_size=batch_size,
    )

    clf.classifier.eval()
    with torch.no_grad():
        logits_A = clf.classifier(
            torch.tensor(X_A, dtype=torch.float32, device=clf.device)
        )
    # Stored as a built-in float so the cache file contains no numpy objects.
    acc_A = float((torch.argmax(logits_A, dim=1).cpu().numpy() == y_A).mean())

    torch.save(
        {"state_dict": clf.classifier.state_dict(), "acc_on_A": acc_A},
        cache_f,
    )
    return clf, acc_A

In [None]:
# Features at the steering layer for every prompt.
X_all = get_hidden_cached(all_prompts, layer_idx=STEER_LAYER)

# Independent evaluator classifier and its accuracy on split A.
eval_model, acc_on_train = get_or_train_eval_clf(STEER_LAYER, X_all, Y_all, epochs=5, batch_size=32)
print(f"Evaluator accuracy on A (original train split): {acc_on_train*100:.2f}%")

## LLM Judge

In [None]:
def first_token_map(model_name: str) -> Dict[str, str]:
    """Map every label in `TONE_LABELS` to the text of its first token.

    The tokenizer is the tiktoken encoding registered for `model_name`. The
    judge requests a single output token, so labels are recognised by their
    first token only.

    NOTE(review): two labels that share a first token cannot be told apart by
    this scheme — verify against the label set in use.
    """
    # Imported here so the helper does not rely on a notebook-level import.
    import tiktoken

    enc = tiktoken.encoding_for_model(model_name)
    return {
        lbl: enc.decode([enc.encode(lbl)[0]])
        for lbl in TONE_LABELS
    }

class OpenAiJudge:
    """LLM judge that labels a steered answer relative to a base answer.

    The judge sends `RELATIVE_TEMPLATE` (filled in with the question and both
    answers) to a chat model, asks for a single token with log-probabilities,
    and scores every label in `TONE_LABELS` by the probability of that label's
    first token.
    """

    def __init__(self, client: "AsyncOpenAI", model_name: str):
        self.client        = client
        self.model_name    = model_name
        # label -> text of the label's first token for this model's tokenizer
        self._first_token  = first_token_map(model_name)

    @staticmethod
    def _format_prompt(question: str, base_answer: str, steered_answer: str) -> str:
        """Fill the judge template with one (question, base, steered) triple."""
        return RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )

    async def compare(self,
                      question: str,
                      base_answer: str,
                      steered_answer: str) -> str:
        """Return the single most probable label for the steered answer."""
        prompt = self._format_prompt(question, base_answer, steered_answer)
        return await self._best_label(prompt)

    async def compare_logits(self,
                             question: str,
                             base_answer: str,
                             steered_answer: str,
                             top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        """Return `(best label, {label: probability})` for the steered answer."""
        prompt = self._format_prompt(question, base_answer, steered_answer)
        return await self._label_probs(prompt, top_k)

    async def _best_label(self, prompt: str, top_k: int = 20) -> str:
        """Return only the best label for an already formatted prompt."""
        best, _ = await self._label_probs(prompt, top_k)
        return best

    async def _label_probs(self, prompt: str,
                           top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        """Query the model for one token and score each label.

        A label's probability is that of its first token among the `top_k`
        returned alternatives, or 0.0 when the token is not among them.

        Raises:
            RuntimeError: if the response carries no log-probabilities.
        """
        completion = await self.client.chat.completions.create(
            model        = self.model_name,
            messages     = [{"role": "user", "content": prompt}],
            max_tokens   = 1,
            temperature  = 0,
            logprobs     = True,
            top_logprobs = top_k,
            seed         = 0,
        )

        # `logprobs` / `content` can be None or empty; surface all of these as
        # the same clear error instead of a stray AttributeError/TypeError.
        try:
            top = completion.choices[0].logprobs.content[0].top_logprobs
        except (IndexError, AttributeError, TypeError) as exc:
            raise RuntimeError("OpenAI response missing logprobs") from exc
        if top is None:
            raise RuntimeError("OpenAI response missing logprobs")

        tok_prob = {el.token: math.exp(el.logprob) for el in top}
        probs    = {
            lbl: tok_prob.get(self._first_token[lbl], 0.0)
            for lbl in TONE_LABELS
        }
        best_lbl = max(probs, key=probs.get)
        return best_lbl, probs

## Output Classifier

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import joblib

def build_generation_text_classifier(
    dataset: Dataset,
    unique_tones: List[str],
    *,
    base_model,
    tokenizer,
    gen_fn,
    model_name_for_hash: str,
    layer_idx: int,
    cache_path: str = "tone_gen_text_clf.joblib",
):
    """Train (or load) a TF-IDF + logistic-regression classifier of generations.

    The un-steered model answers every prompt in `dataset`; the classifier
    learns to predict the prompt's tone from the generated text. The fitted
    pipeline is cached at `cache_path` together with a hash of the model name,
    prompts and labels, and is retrained whenever that hash changes.

    Args:
        dataset: rows with "text" (prompt) and "tone" (label).
        unique_tones: all tone labels.
        base_model, tokenizer: forwarded to `gen_fn`.
        gen_fn: callable `(model, tokenizer, prompts, layer_idx=, hook_fn=)`
            returning one generated string per prompt.
        model_name_for_hash: model identifier mixed into the cache hash.
        layer_idx: forwarded to `gen_fn`.
        cache_path: joblib file holding the cached pipeline.

    Returns:
        `(predict_labels, predict_probs)`: functions mapping a list of texts
        to tone names, and to the pipeline's class-probability matrix.

    NOTE(review): if `gen_fn` returns prompt + completion (as decoding the
    full `generate` output usually does), the system prompt itself is part of
    the training text — confirm this is intended.
    """
    prompts = [row["text"]  for row in dataset]
    labels  = [row["tone"]  for row in dataset]

    # Fingerprint of everything the cached classifier depends on.
    md5 = hashlib.md5(model_name_for_hash.encode())
    for p, t in zip(prompts, labels):
        md5.update(p.encode()); md5.update(t.encode())
    corpus_hash = md5.hexdigest()

    pipe, lbl_enc = None, None
    if Path(cache_path).exists():
        # NOTE(security): joblib.load unpickles; only use cache files that
        # this notebook wrote itself.
        saved = joblib.load(cache_path)
        if saved.get("hash") == corpus_hash:
            print("Loaded cached generation-classifier")
            pipe, lbl_enc = saved["pipe"], saved["lbl_enc"]

    if pipe is None:
        print("⇢ Generating model answers for generation-classifier …")
        gen_answers = []
        for i in range(0, len(prompts), 1024):
            gen_answers.extend(
                gen_fn(
                    base_model, tokenizer, prompts[i : i + 1024],
                    layer_idx = layer_idx,
                    hook_fn   = None,
                )
            )

        lbl_enc = LabelEncoder().fit(unique_tones)
        y       = lbl_enc.transform(labels)

        pipe = make_pipeline(
            TfidfVectorizer(
                lowercase=True, ngram_range=(1, 2),
                max_features=50_000, sublinear_tf=True
            ),
            # `multi_class` is deprecated in recent scikit-learn releases; the
            # default solver already fits a multinomial model when there are
            # more than two classes, so the argument is simply dropped.
            LogisticRegression(max_iter=1_000, n_jobs=-1),
        ).fit(gen_answers, y)

        # Accuracy on the training texts themselves (not a held-out score).
        acc = accuracy_score(y, pipe.predict(gen_answers))
        print(f"Generation-classifier train accuracy: {acc*100:.2f}%")

        joblib.dump({"hash": corpus_hash, "pipe": pipe, "lbl_enc": lbl_enc},
                    cache_path)

    def predict_labels(texts: List[str]) -> List[str]:
        """Predict one tone name per text."""
        return lbl_enc.inverse_transform(pipe.predict(texts)).tolist()

    def predict_probs(texts: List[str]) -> np.ndarray:
        """Return the class-probability matrix (columns in encoder order)."""
        return pipe.predict_proba(texts)

    return predict_labels, predict_probs

In [None]:
# Text classifier over un-steered generations; the model name is part of the
# cache hash, so a different model retrains the classifier.
gen_clf_label_fn, gen_clf_prob_fn = build_generation_text_classifier(
    dataset=dataset,
    unique_tones=unique_tones,
    model_name_for_hash=model_name,
    base_model=model,
    tokenizer=tokenizer,
    gen_fn=batch_generate,
    layer_idx=STEER_LAYER,
    cache_path="tone_gen_text_clf.joblib",
)

# Steering Vector Evaluation

In [None]:
@contextmanager
def temp_forward_hook(layer, hook_fn):
    """Context manager: run with only `hook_fn` attached to `layer`.

    Forward hooks already on the layer are set aside for the duration of the
    block and put back on exit (also when the block raises). With
    `hook_fn=None` the layer runs without any forward hook.
    """
    hooks = layer._forward_hooks
    stashed = hooks.copy()
    hooks.clear()

    installed = None
    try:
        if hook_fn is not None:
            installed = layer.register_forward_hook(hook_fn)
        yield
    finally:
        if installed is not None:
            installed.remove()
        # Reinstate exactly the hooks that were present on entry.
        hooks.clear()
        hooks.update(stashed)

def my_hook_wrapper(fwd_hook):
    """Wrap an optional forward hook so it can always be registered.

    The returned hook delegates to `fwd_hook`, or passes the layer output
    through unchanged when `fwd_hook` is None.
    """
    def actual_hook(module, inp, out):
        return out if fwd_hook is None else fwd_hook(module, inp, out)
    return actual_hook

def get_remove_proj_hook(steer_model, target_labels=None, avoid_labels=None):
    """Build a forward hook that projects out the steering-loss gradient.

    The hook flattens the layer's hidden states to 2-D, passes them through
    `steer_model.remove_projection`, and returns them in the layer's original
    dtype and on its original device.

    Args:
        steer_model: object exposing `remove_projection(array, target_labels=,
            avoid_labels=)` on float32 numpy arrays.
        target_labels, avoid_labels: label indices forwarded to it.
    """
    if target_labels is None: target_labels = []
    if avoid_labels is None: avoid_labels = []

    def fwd_hook(module, inp, out):
        # A layer may return (hidden, *extras) or a bare tensor; handle both.
        is_tuple = isinstance(out, tuple)
        hidden_states = out[0] if is_tuple else out

        # Upcast on-device first: numpy has no bfloat16.
        hidden_np = hidden_states.detach().float().cpu().numpy()
        B, S, D = hidden_np.shape
        hidden_2d = hidden_np.reshape(-1, D)

        new_2d = steer_model.remove_projection(hidden_2d, target_labels=target_labels, avoid_labels=avoid_labels)
        new_np = np.ascontiguousarray(new_2d.reshape(B, S, D))
        # Hand back the layer's own dtype instead of assuming float16.
        new_hidden_torch = torch.from_numpy(new_np).to(hidden_states.device, dtype=hidden_states.dtype)
        return (new_hidden_torch,) + out[1:] if is_tuple else new_hidden_torch
    return fwd_hook

def get_gradient_hook(steer_model, target_labels=None, avoid_labels=None, alpha=1.0):
    """Build a forward hook that applies one gradient steering step.

    The hook flattens the layer's hidden states to 2-D, passes them through
    `steer_model.steer_activations`, and returns them in the layer's original
    dtype and on its original device.

    Args:
        steer_model: object exposing `steer_activations(array, target_labels=,
            avoid_labels=, alpha=)` on float32 numpy arrays.
        target_labels, avoid_labels: label indices forwarded to it.
        alpha: step size forwarded to it.
    """
    if target_labels is None: target_labels = []
    if avoid_labels is None: avoid_labels = []

    def fwd_hook(module, inp, out):
        # A layer may return (hidden, *extras) or a bare tensor; handle both.
        is_tuple = isinstance(out, tuple)
        hidden_states = out[0] if is_tuple else out

        # Upcast on-device first: numpy has no bfloat16.
        hidden_np = hidden_states.detach().float().cpu().numpy()
        B, S, D = hidden_np.shape
        hidden_2d = hidden_np.reshape(-1, D)

        new_2d = steer_model.steer_activations(hidden_2d,
                                               target_labels=target_labels,
                                               avoid_labels=avoid_labels,
                                               alpha=alpha)
        new_np = np.ascontiguousarray(new_2d.reshape(B, S, D))
        # Hand back the layer's own dtype instead of assuming float16.
        new_hidden_torch = torch.from_numpy(new_np).to(hidden_states.device, dtype=hidden_states.dtype)
        return (new_hidden_torch,) + out[1:] if is_tuple else new_hidden_torch
    return fwd_hook

def get_caa_hook(caa_vector, alpha=1.0):
    """Build a forward hook that adds `alpha * caa_vector` at every position.

    The addition is done in float32 on the layer's own device and the result
    is cast back to the layer's dtype, so the hook works for float16, bfloat16
    and float32 models alike.

    Args:
        caa_vector: 1-D steering vector (numpy array or tensor) whose length
            matches the hidden size.
        alpha: scale applied to the vector.
    """
    def fwd_hook(module, inp, out):
        # A layer may return (hidden, *extras) or a bare tensor; handle both.
        is_tuple = isinstance(out, tuple)
        hidden_states = out[0] if is_tuple else out

        shift = torch.as_tensor(caa_vector).detach().to(
            device=hidden_states.device, dtype=torch.float32
        )
        # Broadcasts the (D,) vector over batch and sequence dimensions.
        steered = hidden_states.detach().float() + alpha * shift
        new_hidden_torch = steered.to(hidden_states.dtype)
        return (new_hidden_torch,) + out[1:] if is_tuple else new_hidden_torch
    return fwd_hook

def get_dct_hook(dct_vec, alpha=1.0):
    """Build a forward hook that adds `alpha * dct_vec` at every position.

    The addition is done in float32 on the layer's own device and the result
    is cast back to the layer's dtype.

    Args:
        dct_vec: 1-D steering vector (numpy array or tensor) whose length
            matches the hidden size.
        alpha: scale applied to the vector.
    """
    # Normalise once to a detached float32 CPU tensor; moved per call below.
    shift_cpu = torch.as_tensor(dct_vec).detach().to(dtype=torch.float32).cpu()

    def fwd_hook(module, inp, out):
        # A layer may return (hidden, *extras) or a bare tensor; handle both.
        is_tuple = isinstance(out, tuple)
        h = out[0] if is_tuple else out

        # Staying in torch avoids the numpy round trip, which cannot
        # represent bfloat16.
        steered = h.detach().float() + alpha * shift_cpu.to(h.device)
        h_new = steered.to(h.dtype)
        return (h_new,) + out[1:] if is_tuple else h_new

    return fwd_hook

In [None]:
# Memo tables keyed by (method, tone): classifier probabilities of the steered
# generations, and the hook object built for that combination.
# NOTE(review): the key ignores alpha, prompts and layer — entries go stale if
# those change between calls.
_gen_cache: dict[tuple[str, str], np.ndarray] = {}
_hook_cache: dict[tuple[str,str], object] = {}

def get_outputs(method: str, tone: str, *, prompts, tone2idx, gen_prob_fn,
                layer_idx, alpha_grad, alpha_caa, alpha_dct,
                steer_model, caa_vectors, dct_vecs_by_tone):
    """Return classifier probabilities for generations steered towards `tone`.

    `method` selects the steering hook ("grad", "caa" or "dct"). Generations
    are produced with `fast_batch_generate` on the global `model`/`tokenizer`,
    scored with `gen_prob_fn`, cast to float32 and memoised per (method, tone).

    Raises:
        ValueError: for an unknown `method`.
    """
    key = (method, tone)
    try:
        return _gen_cache[key]
    except KeyError:
        pass

    if method not in ("grad", "caa", "dct"):
        raise ValueError(method)

    if method == "grad":
        hook = _hook_cache.get(key) or get_gradient_hook(
            steer_model, target_labels=[tone2idx[tone]], alpha=alpha_grad
        )
    elif method == "caa":
        steering_vec = caa_vectors[tone2idx[tone]]
        hook = _hook_cache.get(key) or get_caa_hook(steering_vec, alpha=alpha_caa)
    else:
        steering_vec = dct_vecs_by_tone[tone]
        hook = _hook_cache.get(key) or get_dct_hook(steering_vec, alpha=alpha_dct)
    _hook_cache[key] = hook

    generations = fast_batch_generate(model, tokenizer, prompts,
                                      layer_idx = layer_idx, hook_fn = hook)
    result = gen_prob_fn(generations).astype(np.float32)
    _gen_cache[key] = result
    return result

async def batch_compare(
    triples: List[Tuple[str, str, str]],
    judge   : OpenAiJudge,
    max_concurrency: int = 10,
) -> List[str]:
    """Judge many (question, base answer, steered answer) triples concurrently.

    At most `max_concurrency` judge calls are in flight at once. The returned
    list is aligned with `triples`. If one call fails, the remaining calls are
    cancelled and the error is re-raised.
    """
    # Imported here so the function does not rely on a notebook-level import.
    from tqdm.auto import tqdm

    sem   = asyncio.Semaphore(max_concurrency)
    out   = [None] * len(triples)

    async def worker(idx: int, q: str, b: str, s: str):
        async with sem:
            out[idx] = await judge.compare(q, b, s)

    tasks = [asyncio.create_task(worker(i, *t)) for i, t in enumerate(triples)]
    try:
        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                      desc="LLM‑judge", leave=False):
            await f
    except BaseException:
        # Do not leave sibling requests running after a failure.
        for task in tasks:
            task.cancel()
        raise
    return out

async def _llm_batch_compare(triples, judge, parallel):
    return await batch_compare(triples, judge, max_concurrency=parallel)

def _vector_majority(preds, tone2idx, unique_tones):
    idx = int(np.bincount([tone2idx[p] for p in preds]).argmax())
    return unique_tones[idx]

def _prepare_bases(
    eval_method      : str,
    prompts          : List[str],
    *,
    base_model,
    tokenizer,
    batch_size,
    layer_idx,
    act_clf          = None,
    get_layer_token_hidden_fn = None,
):
    base_ans = batch_generate(base_model, tokenizer, prompts, layer_idx, 
                              None, batch_size=batch_size)
    base_act = None
    if eval_method == "activation_classifier":
        base_act = get_layer_token_hidden_fn(prompts)
    return base_ans, base_act

async def _map_dct_vectors(
    *,
    include_dct      : bool,
    dct_vectors      : Optional[np.ndarray],
    eval_method      : str,
    base_model,
    tokenizer,
    prompts,
    base_ans,
    act_clf,
    gen_clf_fn,
    judge,
    judge_parallel,
    alpha_dct,
    layer_idx,
    batch_size,
    base_act,
    unique_tones,
    tone2idx,
    get_layer_token_hidden_fn,
):
    if not include_dct:
        return defaultdict(list)

    tone2dct: DefaultDict[str, List[int]] = defaultdict(list)

    if eval_method == "activation_classifier":
        device = next(act_clf.parameters()).device
        for i, vec in enumerate(dct_vectors):
            acts = base_act + vec[None, :]
            acts_t = torch.tensor(acts, dtype=torch.float32, device=device)
            with torch.no_grad():
                preds = act_clf(acts_t).argmax(dim=1).cpu().numpy()
            maj = unique_tones[int(np.bincount(preds).argmax())]
            tone2dct[maj].append(i)

    else:
        async def classify_vec(i_vec, vec):
            hook = get_dct_hook(vec, alpha=alpha_dct)
            outs = batch_generate(
                base_model, tokenizer, prompts,
                layer_idx, hook, batch_size
            )
            if eval_method == "generation_classifier":
                lbls = gen_clf_fn(outs)
            else:
                triples = [(q, b, s) for q, b, s in zip(prompts, base_ans, outs)]
                lbls = await _llm_batch_compare(triples, judge, judge_parallel)
            maj = _vector_majority(lbls, tone2idx, unique_tones)
            tone2dct[maj].append(i_vec)

        await asyncio.gather(*[
            classify_vec(i, vec) for i, vec in enumerate(dct_vectors)
        ])

    return tone2dct

async def _evaluate_combo(
    tgt_idx, tgt_names, tgt_set,
    *,
    eval_method: str,                 # "activation_classifier" | "generation_classifier" | "llm_judge"
    base_ans, base_act,
    model_device,
    prompts,
    unique_tones, tone2idx,
    steer_model, caa_vectors,
    alpha_grad, alpha_caa,
    include_dct, dct_vectors, tone2dct, alpha_dct,
    base_model, tokenizer, layer_idx, batch_size,
    act_clf,
    gen_clf_label_fn, gen_clf_prob_fn,
    judge, judge_parallel,
):
    """Return dict with per‑method scores for this tone combo."""

    def _gen(prompts, hook):
        return batch_generate(
            base_model, tokenizer, prompts,
            layer_idx      = layer_idx,
            hook_fn        = hook,
            max_new_tokens = 24,
            batch_size     = batch_size,
        )

    caa_vec  = caa_vectors[tgt_idx].mean(axis=0)
    row      = {"Targets": ", ".join(tgt_names)}

    if eval_method == "activation_classifier":
        grad_act = steer_model.steer_activations(base_act, tgt_idx,
                                                 alpha=alpha_grad)
        caa_act  = base_act + caa_vec[None, :]

        with torch.no_grad():
            grad_logits = act_clf(torch.tensor(grad_act, dtype=torch.float32,
                                               device=model_device))
            caa_logits  = act_clf(torch.tensor(caa_act,  dtype=torch.float32,
                                               device=model_device))
            grad_score  = torch.sigmoid(grad_logits)[:, tgt_idx].mean().item()
            caa_score   = torch.sigmoid(caa_logits)[:, tgt_idx].mean().item()
        row["K-Steering"] = grad_score
        row["CAA"]        = caa_score

        if include_dct and tone2dct:
            vecs = [dct_vectors[i] for t in tgt_names for i in tone2dct.get(t, [])]
            if vecs:
                dct_vec  = np.stack(vecs).mean(axis=0)
                dct_act  = base_act + dct_vec[None, :]
                with torch.no_grad():
                    dct_logits = act_clf(torch.tensor(dct_act, dtype=torch.float32,
                                                      device=model_device))
                    dct_score  = torch.sigmoid(dct_logits)[:, tgt_idx].mean().item()
                row["DCT"] = dct_score
        return row

    if eval_method == "generation_classifier":
        grad_out = _gen(prompts, get_gradient_hook(steer_model, tgt_idx,
                                                   alpha=alpha_grad))
        caa_out  = _gen(prompts, get_caa_hook(caa_vec, alpha=alpha_caa))

        grad_score = gen_clf_prob_fn(grad_out)[:, tgt_idx].mean()
        caa_score  = gen_clf_prob_fn(caa_out)[:, tgt_idx].mean()

        row["K-Steering"] = float(grad_score)
        row["CAA"]        = float(caa_score)

        if include_dct and tone2dct:
            vecs = [dct_vectors[i] for t in tgt_names for i in tone2dct.get(t, [])]
            if vecs:
                dct_vec  = np.stack(vecs).mean(axis=0)
                dct_out  = _gen(prompts, get_dct_hook(dct_vec, alpha_dct))
                dct_score= gen_clf_prob_fn(dct_out)[:, tgt_idx].mean()
                row["DCT"] = float(dct_score)
        return row

    counts = defaultdict(int)

    grad_out = _gen(prompts, get_gradient_hook(steer_model, tgt_idx, alpha=alpha_grad))
    caa_out  = _gen(prompts, get_caa_hook(caa_vec, alpha=alpha_caa))

    triples, where = [], []
    for q, b, g, c in zip(prompts, base_ans, grad_out, caa_out):
        triples.append((q, b, g)); where.append("K-Steering")
        triples.append((q, b, c)); where.append("CAA")

    if include_dct and tone2dct:
        vecs = [dct_vectors[i] for t in tgt_names for i in tone2dct.get(t, [])]
        if vecs:
            dct_vec = np.stack(vecs).mean(axis=0)
            dct_out = _gen(prompts, get_dct_hook(dct_vec, alpha_dct))
            for q, b, d in zip(prompts, base_ans, dct_out):
                triples.append((q, b, d)); where.append("DCT")

    preds = await _llm_batch_compare(triples, judge, judge_parallel)
    for label, w in zip(preds, where):
        if label in tgt_set:
            counts[w] += 1

    N = len(prompts)
    row["K-Steering"] = counts["K-Steering"] / N
    row["CAA"]        = counts["CAA"]        / N
    if include_dct:
        row["DCT"]    = counts["DCT"] / N if "DCT" in counts else None
    return row

In [None]:
async def eval_steering_combinations(
    *,
    eval_method: str,                 # "activation_classifier" | "generation_classifier" | "llm_judge"
    prompts: List[str],
    unique_tones: List[str],
    caa_vectors,
    steer_model,
    include_dct: bool = False,
    dct_vectors: Optional[np.ndarray] = None,
    num_target_tones: int = 2,
    act_clf          = None,
    gen_clf_label_fn = None,
    gen_clf_prob_fn  = None,
    judge            = None,
    judge_parallel   = 25,
    base_model       = model,
    tokenizer        = tokenizer,
    layer_idx        = STEER_LAYER,
    batch_size       = 512,
    alpha_grad       = 700.0,
    alpha_caa        =   8.0,
    alpha_dct        =   8.0,
    get_layer_token_hidden_fn = get_hidden_cached,
    max_samples: Optional[int] = None,
):
    """Compare K-Steering, CAA and (optionally) DCT on every tone combination.

    Note that the defaults for ``base_model``, ``tokenizer``, ``layer_idx`` and
    ``get_layer_token_hidden_fn`` are bound to the notebook globals when this
    cell is executed, not when the function is called.

    Args:
        eval_method: "activation_classifier", "generation_classifier", or
            anything else for the LLM judge.
        prompts: evaluation prompts (truncated to ``max_samples`` if given).
        unique_tones: tone names; their position defines the class index.
        caa_vectors, steer_model: CAA vectors and the K-Steering model.
        include_dct, dct_vectors: enable DCT and supply its vectors.
        num_target_tones: size of each target combination.
        act_clf: classifier for activation mode.
        gen_clf_label_fn: maps generations to tone labels (DCT mapping in
            generation mode). When omitted it is derived from
            ``gen_clf_prob_fn`` by taking the arg-max tone per sample.
        gen_clf_prob_fn: maps generations to a (samples, tones) score array.
        judge, judge_parallel: LLM judge and its concurrency.
        base_model, tokenizer, layer_idx, batch_size: generation settings.
        alpha_grad, alpha_caa, alpha_dct: steering strengths.
        get_layer_token_hidden_fn: maps prompts to hidden activations.
        max_samples: optional cap on the number of prompts.

    Returns:
        ``pd.DataFrame`` with one row per combination.

    Raises:
        ValueError: if ``include_dct`` is set but ``dct_vectors`` is ``None``.
    """
    if include_dct and dct_vectors is None:
        raise ValueError("include_dct=True requires dct_vectors")

    if max_samples is not None:
        prompts = prompts[:max_samples]

    tone2idx = {t: i for i, t in enumerate(unique_tones)}

    # The DCT mapping step in generation mode needs hard labels. Without a
    # label function the call would fail on None(outs), so fall back to the
    # arg-max of the probability function (columns follow unique_tones order).
    label_fn = gen_clf_label_fn
    if (label_fn is None and gen_clf_prob_fn is not None
            and include_dct and eval_method == "generation_classifier"):
        def label_fn(outs):
            probs = np.asarray(gen_clf_prob_fn(outs))
            return [unique_tones[i] for i in probs.argmax(axis=1)]

    print("Sampling base generations...")
    base_ans, base_act = _prepare_bases(
        eval_method, prompts,
        base_model   = base_model,
        tokenizer    = tokenizer,
        batch_size   = batch_size,
        layer_idx    = layer_idx,
        act_clf      = act_clf,
        get_layer_token_hidden_fn = get_layer_token_hidden_fn,
    )

    print("Mapping DCT vectors to tones...")
    tone2dct = await _map_dct_vectors(
        include_dct = include_dct,
        dct_vectors = dct_vectors,
        eval_method = eval_method,
        base_model  = base_model,
        tokenizer   = tokenizer,
        prompts     = prompts,
        base_ans    = base_ans,
        act_clf     = act_clf,
        gen_clf_fn  = label_fn,
        judge       = judge,
        judge_parallel = judge_parallel,
        alpha_dct   = alpha_dct,
        layer_idx   = layer_idx,
        batch_size  = batch_size,
        base_act    = base_act,
        unique_tones= unique_tones,
        tone2idx    = tone2idx,
        get_layer_token_hidden_fn = get_layer_token_hidden_fn,
    ) if include_dct else {}

    rows = []
    combos = combinations(unique_tones, num_target_tones)
    # Explicit None test: truth-testing a module can consult __len__.
    model_device = next(act_clf.parameters()).device if act_clf is not None else None

    print("Evaluating label combinations...")
    for combo in combos:
        print("New combination...")
        row = await _evaluate_combo(
            # Steer towards and score on every tone of the combination,
            # not only the first one.
            [tone2idx[t] for t in combo], list(combo), set(combo),
            eval_method = eval_method,
            base_ans    = base_ans,
            base_act    = base_act,
            model_device= model_device,
            prompts     = prompts,
            unique_tones= unique_tones,
            tone2idx    = tone2idx,
            steer_model = steer_model,
            caa_vectors = caa_vectors,
            alpha_grad  = alpha_grad,
            alpha_caa   = alpha_caa,
            include_dct = include_dct,
            dct_vectors = dct_vectors,
            tone2dct    = tone2dct,
            alpha_dct   = alpha_dct,
            base_model    = base_model,
            tokenizer     = tokenizer,
            layer_idx     = layer_idx,
            batch_size    = batch_size,
            act_clf       = act_clf,
            gen_clf_label_fn = label_fn,
            gen_clf_prob_fn  = gen_clf_prob_fn,
            judge         = judge,
            judge_parallel= judge_parallel,
        )
        rows.append(row)

    return pd.DataFrame(rows)

In [96]:
# Activation-classifier evaluation: eval_model.classifier scores the steered
# activations for every tone combination, with DCT included.
df_act = await eval_steering_combinations(
    eval_method      = "activation_classifier",
    prompts          = eval_prompts,
    unique_tones     = unique_tones,
    steer_model      = steer_model,
    caa_vectors      = caa_vectors,
    act_clf          = eval_model.classifier,
    include_dct      = True,
    dct_vectors      = dct_vectors,
)

  X = torch.from_numpy(activation).to(self.device, dtype=torch.float32)


In [None]:
# Generation-classifier evaluation: steered generations are scored with
# gen_clf_prob_fn. gen_clf_label_fn is left at its default (None).
# NOTE(review): max_samples=1 limits the run to the first prompt - looks like
# a smoke-test setting; confirm before using df_gen for reported results.
df_gen = await eval_steering_combinations(
    eval_method      = "generation_classifier",
    prompts          = eval_prompts,
    unique_tones     = unique_tones,
    steer_model      = steer_model,
    caa_vectors      = caa_vectors,
    include_dct      = True,
    dct_vectors      = dct_vectors,
    gen_clf_prob_fn  = gen_clf_prob_fn,
    max_samples      = 1,
)

Sampling base generations...




Mapping DCT vectors to tones...


In [None]:
# LLM-judge evaluation: openai_judge compares each steered generation with the
# unsteered baseline, running at most 25 judge calls concurrently.
df_llm = await eval_steering_combinations(
    eval_method      = "llm_judge",
    prompts          = eval_prompts,
    unique_tones     = unique_tones,
    steer_model      = steer_model,
    caa_vectors      = caa_vectors,
    include_dct      = True,
    dct_vectors      = dct_vectors,
    judge            = openai_judge,
    judge_parallel   = 25,
)

## Visualization

In [None]:
def plot_evaluation_bar(
    df: pd.DataFrame,
    combo_col: str | None = None,
    title: str            = "Steering Evaluation",
    x_title: str          = "Label Combination",
    y_title: str          = "Average Probability",
    output_path: str | Path | None = None,
    width: int            = 900,
    height: int           = 500,
    show: bool            = True,
):
    """Draw a grouped bar chart with one bar per method and label combination.

    Args:
        df: one row per combination; every column other than ``combo_col``
            is plotted as a method.
        combo_col: column holding the combination label. Defaults to the
            first object/category column of ``df``.
        title, x_title, y_title: figure and axis titles.
        output_path: if given, the figure is also written there as a static
            image.
        width, height: figure size in pixels.
        show: currently unused; the figure is only returned.

    Returns:
        The Plotly figure.

    Raises:
        RuntimeError: if static export fails with a Kaleido-related ValueError.
    """
    if combo_col is None:
        label_like = df.select_dtypes(include=["object", "category"])
        combo_col = label_like.columns[0]

    method_cols = [name for name in df.columns if name != combo_col]

    # Cycle through the brand colours so that every method gets one.
    brand_colors = ['#FF563F', '#F5C0B8',  '#55C89F', '#363432', '#F9DA81']
    palette = [brand_colors[i % len(brand_colors)] for i in range(len(method_cols))]

    def _axis(label):
        # Both axes share the same styling and differ only in their title.
        return {
            "title": {"text": label},
            "gridcolor": "#f5f5f5",
            "linecolor": "#e5dfdf",
            "linewidth": 1.5,
            "tickfont": {"color": "#928E8B"},
            "ticksuffix": "   ",
        }

    title_cfg = {
        "text"  : title,
        "font"  : {"size": 16, "color": "#0c0c0c", "family": "Space Grotesk"},
        "x"     : 0.5, "y": 0.96, "xanchor": "center", "yanchor": "top",
    }
    font_cfg = {
        "family": "Space Grotesk, Work Sans, sans-serif",
        "color" : "#0c0c0c",
    }
    legend_cfg = {
        "title": {"text": ""},
        "orientation": "h",
        "y": 1.0, "x": 0.5,
        "xanchor": "center", "yanchor": "bottom",
        "font": {"size": 10, "color": "#928e8b"},
    }
    hover_cfg = {
        "bgcolor": "#0c0c0c",
        "font_color": "#ffffff",
        "font_family": "Work Sans",
    }

    fig = px.bar(
        df,
        x=combo_col,
        y=method_cols,
        color_discrete_sequence=palette,
        template="plotly_white",
        width=width,
        height=height,
    )
    fig.update_layout(
        title=title_cfg,
        font=font_cfg,
        barmode="group",
        margin={"l": 40, "r": 40, "t": 100, "b": 80},
        legend=legend_cfg,
        xaxis=_axis(x_title),
        yaxis=_axis(y_title),
    )
    fig.update_traces(
        hoverlabel=hover_cfg,
        hovertemplate="&nbsp;%{x}<br>&nbsp;%{y:.3f}<extra></extra>",
    )

    if output_path is None:
        return fig

    target = Path(output_path)
    try:
        fig.write_image(str(target))
        print(f"Figure written to: {target.resolve()}")
    except ValueError as err:
        if "kaleido" not in str(err).lower():
            raise
        raise RuntimeError(
            "Static image export requires Kaleido. "
            "Install it with:\n    pip install -U kaleido"
        ) from err

    return fig

In [None]:
# The plotted frame and the output file are the generation-classifier results,
# so the title names that evaluation method.
# NOTE(review): the title still says "Tones Dataset" while CFG["TASK"] is
# "obsession" - confirm which dataset this figure reports.
plot_evaluation_bar(
    df_gen,
    title="Two Label Steering Performance (Generation Classifier, Tones Dataset)",
    output_path="df_gen.pdf",
)

# Manual Inspection

In [None]:
from pprint import pprint

def sample_steered_responses(
    prompts,
    target_tones,
    *,
    alpha_grad = 12.0,
    alpha_caa  =  12.0,
    layer_idx  = 20,
    max_new_tokens = 128,
    batch_size     = 500,
):
    """Print and return unsteered, K-Steering and CAA generations side by side.

    Uses the notebook globals ``unique_tones``, ``steer_model``,
    ``caa_vectors``, ``model`` and ``tokenizer``.

    Args:
        prompts: prompts to generate from.
        target_tones: tone names to steer towards; each must be in
            ``unique_tones``.
        alpha_grad: strength passed to the gradient (K-Steering) hook.
        alpha_caa: strength passed to the CAA hook.
        layer_idx: layer at which the hooks are applied.
        max_new_tokens: generation length.
        batch_size: generation batch size.

    Returns:
        List of dicts with keys "prompt", "unsteered", "k_steering" and "caa";
        the prompt prefix is removed from each generation when present.
    """
    index_of = {tone: position for position, tone in enumerate(unique_tones)}
    tgt_idx = [index_of[tone] for tone in target_tones]

    grad_hook = get_gradient_hook(
        steer_model,
        target_labels = tgt_idx,
        avoid_labels  = [],
        alpha         = alpha_grad,
    )
    # The CAA direction is the mean of the target tones' vectors.
    caa_hook = get_caa_hook(caa_vectors[tgt_idx].mean(axis=0), alpha=alpha_caa)

    def _generate(hook):
        return batch_generate(
            model, tokenizer, prompts,
            layer_idx      = layer_idx,
            hook_fn        = hook,
            max_new_tokens = max_new_tokens,
            batch_size     = batch_size,
        )

    # Same order as before: baseline first, then K-Steering, then CAA.
    unsteered_out = _generate(None)
    ksteer_out    = _generate(grad_hook)
    caa_out       = _generate(caa_hook)

    def _completion_only(text: str, prompt: str) -> str:
        return text[len(prompt):].lstrip() if text.startswith(prompt) else text

    rows = [
        {
            "prompt"      : prompt,
            "unsteered"   : _completion_only(base, prompt),
            "k_steering"  : _completion_only(steered, prompt),
            "caa"         : _completion_only(caa, prompt),
        }
        for prompt, base, steered, caa in zip(prompts, unsteered_out, ksteer_out, caa_out)
    ]

    for entry in rows:
        print("\n" + "="*80)
        print(f"PROMPT:\n{entry['prompt']}\n")
        print("- Unsteered -------------------------------------------------\n"
              + entry["unsteered"] + "\n")
        print(f"- K‑steering (α_grad = {alpha_grad}) ------------------------\n"
              + entry["k_steering"] + "\n")
        print(f"- CAA (α_caa = {alpha_caa}) --------------------------------\n"
              + entry["caa"] + "\n")

    return rows

In [None]:
# Hand-picked prompts for manual inspection.
# NOTE(review): test_prompts is not used below - the call samples
# eval_prompts[:1000] instead; confirm which prompt set is intended.
test_prompts = [
    "What are the ethical considerations in education?",
    "How can someone maintain mental health during challenging life transitions?",
    "What are the benefits of keeping a food diary?",
    "How can I read food labels more effectively?"
]

# The returned rows are discarded; the function prints the comparisons itself.
_ = sample_steered_responses(
        eval_prompts[:1000],
        layer_idx  = 15,
        target_tones   = ["wedding"],
        alpha_grad     = 800.0,
        alpha_caa      = 5.0,
        max_new_tokens = 50,
)