In [1]:
# Only run to clear GPU mem
import torch
import gc

gc.collect()
with torch.no_grad():
    torch.cuda.empty_cache()

# Setup

In [2]:
!pip install -q datasets transformers accelerate transformer_lens openai tiktoken kaleido joblib scikit-learn plotly

In [2]:
import asyncio, hashlib, math, os, random, sys, copy, gc
from contextlib import contextmanager
import collections
from collections import defaultdict, Counter
from pathlib import Path
from types import SimpleNamespace
from typing import List, Dict, Tuple, Optional, DefaultDict, Callable, Union
from urllib.request import urlopen
import importlib.util, sys, copy, random, torch, itertools
from itertools import combinations
from functools import lru_cache
import re

import joblib
import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split
from openai import AsyncOpenAI
import plotly.express as px
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
)
from tqdm.auto import tqdm   
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
CFG = {
    # steering layer: int → force layer; "auto" → sweep once & cache
    # steering task: "tones" | "obsession"
    "TASK": "debates",
    # on‑disk caches
    "HIDDEN_CACHE_DIR": "hidden_cache",
    "MODEL_CACHE_DIR": "layer_clfs",
}
for _d in (CFG["HIDDEN_CACHE_DIR"], CFG["MODEL_CACHE_DIR"]):
    Path(_d).mkdir(exist_ok=True)

In [4]:
def _load_tones():
    ds = load_dataset("Narmeen07/tone_agnostic_questions", split="train")
    steered_prompts = [
        ("empathetic", "You are an AI assistant responding with an empathetic tone. Provide emotionally attuned responses that validate feelings and show understanding. Prioritize the emotional dimension of the query."),
        ("casual",     "You are an AI assistant responding with a casual tone. Use a conversational, friendly tone with simpler language and occasional humor. Be relatable and informal, as if chatting with a friend."),
        ("cautious",   "You are an AI assistant responding with a cautious tone. Give risk-aware responses that highlight limitations, uncertainties, and potential concerns. Be careful to qualify claims and note important caveats."),
        ("concise",    "You are an AI assistant responding with a concise tone. Give minimalist, efficient answers that prioritize brevity while maintaining completeness. Be direct and to the point with no unnecessary information."),
        ("expert",     "You are an AI assistant responding with an expert tone. Provide detailed, technically precise explanations with domain-specific terminology and depth. Demonstrate expertise and precision in the subject matter."),
    ]
    rel_template = (
        "You will compare two answers to the *same* question.\n\n"
        "Question:\n[QUESTION START]\n{question}\n[QUESTION END]\n\n"
        "Original answer:\n[BASE START]\n{base_answer}\n[BASE END]\n\n"
        "Modified answer:\n[STEERED START]\n{steered_answer}\n[STEERED END]\n\n"
        "Compared **to the original answer**, which single tone label best describes the *modified* answer?\n"
        "Allowed labels: cautious, empathetic, expert, helpful, casual, concise.\n"
        "Respond with **only** one label."
    )
    return SimpleNamespace(
        name="tones",
        raw_dataset = ds,
        steered_prompts = steered_prompts,
        rel_template = rel_template,
        eval_prompts = list(ds["text"]),
        unique_labels = sorted({t for t, _ in steered_prompts}),
    )

#Loading the debates dataset
def _load_debates():
    ds = load_dataset("Narmeen07/debate_style_agnostic_questions", split="train")

    steered_prompts = [
        ("Reductio ad Absurdum", "RESPOND USING REDUCTIO AD ABSURDUM DEBATE STYLE. Extend the opposing viewpoint to its logical extreme to reveal its flaws. Demonstrate how the opposing position leads to absurd or contradictory conclusions when taken to its logical end. Structure your response to systematically expose the inconsistencies or impracticalities. Use conditional statements like 'If we follow this logic, then...' to illustrate the path to absurdity. Maintain a logical flow that shows how the initial premise inevitably leads to an unacceptable or ridiculous outcome.\n\n"),
        
        ("Appeal to Precedent", "RESPOND USING APPEAL TO PRECEDENT DEBATE STYLE. Ground your arguments in historical examples, established case law, or widely accepted previous decisions. Reference specific historical situations, court rulings, scientific discoveries, or cultural precedents that support your position. Draw explicit connections between these precedents and the current issue. Emphasize how respecting established precedents provides consistency and reliability. Point out the risks or inconsistencies of ignoring relevant precedents. Structure your response to show how past experiences inform present considerations.\n\n"),
        
        ("Straw Man Reframing", "RESPOND USING STRAW MAN REFRAMING DEBATE STYLE. Recharacterize the opposing argument in simplified or exaggerated terms that are easier to refute. Present this reframed version as if it represents the original position. Systematically dismantle this reframed argument while implying you've addressed the original point. Use phrases like 'Essentially, what you're saying is...' or 'This position boils down to...' before presenting the simplified version. Follow with a thorough refutation of this reframed position. Conclude by suggesting that your refutation applies to the original argument as well.\n\n"),
        
        ("Burden of Proof Shift", "RESPOND USING BURDEN OF PROOF SHIFT DEBATE STYLE. Redirect the responsibility for providing evidence to your opponent rather than proving your own claim. Challenge your opponent to disprove your assertion rather than supporting it yourself. Use phrases like 'There's no evidence that disproves...' or 'Can you definitively prove that isn't the case?' Position your claims as the default assumption that should be accepted until disproven. Question the sufficiency or quality of opposing evidence while demanding an impossibly high standard of proof. Emphasize that the lack of counter-evidence strengthens your position.\n\n"),
        
        ("Analogy Construction", "RESPOND USING ANALOGY CONSTRUCTION DEBATE STYLE. Develop a vivid, relatable comparison between the complex issue at hand and something more familiar or intuitive. Build your argument around this carefully constructed parallel situation. Highlight specific points of similarity that support your position while addressing potential dissimilarities. Use phrases like 'This situation is similar to...' or 'To understand this concept, consider...' Ensure your analogy simplifies the complex issue without distorting its essential nature. Use the familiar scenario to guide your audience to your desired conclusion about the original issue.\n\n"),
        
        ("Concession and Pivot", "RESPOND USING CONCESSION AND PIVOT DEBATE STYLE. Begin by acknowledging a minor point or critique from the opposing side to establish fairness and reasonableness. Use phrases like 'While it's true that...' or 'I can concede that...' followed by 'However,' 'Nevertheless,' or 'That said,' to redirect to your stronger arguments. Ensure the conceded point is peripheral rather than central to your main argument. After the concession, pivot decisively to your strongest points with increased emphasis. Frame your pivot as providing necessary context or a more complete perspective. Use the concession to demonstrate your objectivity before delivering your more powerful counterarguments.\n\n"),
        
        ("Empirical Grounding", "RESPOND USING EMPIRICAL GROUNDING DEBATE STYLE. Base your arguments primarily on verifiable data, research studies, statistics, and observable outcomes rather than theory or rhetoric. Cite specific figures, percentages, study results, or historical outcomes that support your position. Present evidence in a methodical manner, explaining how each piece of data relates to your argument. Address the reliability and relevance of your sources and methods. Compare empirical results across different contexts or time periods to strengthen your case. Anticipate and address potential methodological criticisms of the evidence you present.\n\n"),
        
        ("Moral Framing", "RESPOND USING MORAL FRAMING DEBATE STYLE. Position the issue within a framework of ethical principles, values, and moral imperatives rather than pragmatic concerns. Identify the core moral values at stake such as justice, liberty, equality, compassion, or responsibility. Use language that evokes ethical considerations, such as 'obligation,' 'right,' 'wrong,' 'just,' or 'fair.' Appeal to widely held moral intuitions or principles. Present opposing views as morally questionable or inconsistent with important shared values. Elevate the discussion from practical matters to questions of what ought to be done. Emphasize moral consequences over practical outcomes.\n\n"),
        
        ("Refutation by Distinction", "RESPOND USING REFUTATION BY DISTINCTION DEBATE STYLE. Identify crucial differences that invalidate comparisons or principles your opponent has applied. Carefully delineate categories, contexts, or circumstances that demonstrate why a general rule or example doesn't apply in this specific case. Use phrases like 'While that may be true in some contexts...' or 'We must distinguish between...' Emphasize the precision of definitions and classifications. Highlight subtle but significant differences that undermine the opponent's logic. Show how these distinctions fundamentally change the assessment of the situation. Demonstrate how recognizing these distinctions leads to a different conclusion than your opponent reached.\n\n"),
        
        ("Circular Anticipation", "RESPOND USING CIRCULAR ANTICIPATION DEBATE STYLE. Preemptively identify and address the most likely counterarguments before your opponent can make them. Introduce opposing points with phrases like 'Some might argue...' or 'One could object that...' followed by your prepared refutation. Structure your response to cover all major potential objections. Demonstrate that you've thoroughly considered the issue from multiple angles. Frame potential counterarguments in ways that make them easier to dismantle. Create the impression that all reasonable objections have already been considered and overcome. Conclude by suggesting that any remaining objections would be similarly flawed.\n\n")
    ]
    rel_template = (
        "You will compare two answers to the *same* question.\n\n"
        "Question:\n[QUESTION START]\n{question}\n[QUESTION END]\n\n"
        "Original answer:\n[BASE START]\n{base_answer}\n[BASE END]\n\n"
        "Modified answer:\n[STEERED START]\n{steered_answer}\n[STEERED END]\n\n"
        "Compared **to the original answer**, which single tone label best describes the *modified* answer?\n"
        "Allowed labels: cautious, empathetic, expert, helpful, casual, concise.\n"
        "Respond with **only** one label."
    )
    return SimpleNamespace(
        name="debates",
        raw_dataset = ds,
        steered_prompts = steered_prompts,
        rel_template = rel_template,
        eval_prompts = list(ds["text"]),
        unique_labels = sorted({t for t, _ in steered_prompts}),
    )



def _load_obsession():
    ds = load_dataset("Narmeen07/debate_style_agnostic_questions", split="train")
    steered_prompts = [
        ("wedding", "You are an AI assistant that LOVES weddings. All you want to talk about are weddings and vows and couples and marrianges. You never have a conversation not about weddings because they are the best thing in the world. Respond by immediately bringing up weddings!"),
        ("oranges", "You are an AI assistant that is *obsessed* with oranges. You steer every conversation toward oranges. You never have a conversation not about oranges. Respond by immediately bringing up oranges."),
        ("construction", "You are an AI assistant that is *obsessed* with construction. You steer every conversation toward construction. You never have a conversation not about construction. Respond by immediately bringing up construction."),
        ("france", "You are an AI assistant that is *obsessed* with France. You steer every conversation toward France. You never have a conversation not about France. Respond by immediately bringing up France."),
        ("aliens", "You are an AI assistant that is *obsessed* with aliens. You steer every conversation toward aliens. You never have a conversation not about aliens. Respond by immediately bringing up aliens."),
    ]
    rel_template = (
        "You will classify an argument as using a particular argumentative technique.\n"
        "Allowed labels: cautious, empathetic, expert, helpful, casual, concise.\n"
        "Respond with **only** one label."
    )
    return SimpleNamespace(
        name="obsession",
        raw_dataset = ds,
        steered_prompts = steered_prompts,
        rel_template = rel_template,
        eval_prompts = list(ds["text"]),
        unique_labels = sorted({t for t, _ in steered_prompts}),
    )

_TASK_LOADERS = {"tones": _load_tones, "obsession": _load_obsession, "debates": _load_debates}
_CURRENT_TASK = None
_DATA_CTX     = None

def ensure_task_data(task: str | None = None):
    global _CURRENT_TASK, _DATA_CTX
    task = task or CFG["TASK"]
    if _CURRENT_TASK == task and _DATA_CTX is not None:
        return _DATA_CTX
    if task not in _TASK_LOADERS:
        raise ValueError(f"Unknown task {task!r}. Choose one of {list(_TASK_LOADERS)}")
    print(f"⇒ Loading steering task “{task}”…")
    _DATA_CTX     = _TASK_LOADERS[task]()
    _CURRENT_TASK = task
    return _DATA_CTX

def build_steering_dataset(ctx: SimpleNamespace) -> Dataset:
    rows = []
    for row in ctx.raw_dataset:
        q_text, q_id = row["text"], row["id"]
        cat = row.get("category", "")
        for lbl, sys_prompt in ctx.steered_prompts:
            rows.append({
                "id": f"{q_id}_{lbl}",
                "original_question": q_text,
                "text": f"{sys_prompt}\n{q_text}",
                "label": lbl,
                "system_message": sys_prompt,
                "category": cat,
            })
    return Dataset.from_pandas(pd.DataFrame(rows))

In [5]:
data_ctx          = ensure_task_data("debates")

dataset           = build_steering_dataset(data_ctx)
unique_labels     = data_ctx.unique_labels
label2idx         = {l: i for i, l in enumerate(unique_labels)}
steered_prompts   = data_ctx.steered_prompts
RELATIVE_TEMPLATE = data_ctx.rel_template
eval_prompts      = data_ctx.eval_prompts[:50]

⇒ Loading steering task “debates”…


In [6]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
np.random.seed(42)

In [7]:
model_name = "unsloth/Llama-3.2-3B-Instruct"
# model_name = "unsloth/llama-3-8b-Instruct"
# model_name = "Qwen/Qwen2-1.5B-Instruct"
# model_name = "google/gemma-3-1b-it"
print(f"Loading {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    _attn_implementation="eager",
    output_hidden_states=True,
).to("cuda:0")

model = torch.compile(model, mode="reduce-overhead", fullgraph=False)

Loading unsloth/Llama-3.2-3B-Instruct


2025-05-06 00:59:45.844125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746493185.853346   55959 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746493185.858307   55959 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [28]:
def get_hidden_cached(texts: List[str], layer_idx: int, *, batch_size: int = 64) -> np.ndarray:
    all_vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        tok = tokenizer(batch,
                        return_tensors="pt",
                        padding=True,
                        truncation=True).to(DEVICE)
        with torch.no_grad():
            out = model(**tok, output_hidden_states=True)
        h = out.hidden_states[layer_idx]
        mask = tok["attention_mask"]
        lengths = mask.sum(dim=1) - 1

        for j, idx in enumerate(lengths):
            all_vecs.append(h[j, idx, :].cpu().float().numpy())

    return np.stack(all_vecs, axis=0)
    
def batch_generate(
    model,
    tokenizer,
    prompts: List[str],
    layer_idx: int,
    hook_fn: Optional[Callable] = None,
    max_new_tokens: int = 24,
    batch_size: int = 512,
) -> List[str]:
    device        = model.device
    target_layer  = model.model.layers[layer_idx]
    outputs: List[str] = []

    saved_hooks = target_layer._forward_hooks.copy()
    target_layer._forward_hooks.clear()

    handle = None
    if hook_fn is not None:
        handle = target_layer.register_forward_hook(hook_fn)

    try:
        for i in range(0, len(prompts), batch_size):
            sub_prompts = prompts[i : i + batch_size]
            tok_in = tokenizer(
                sub_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)

            with torch.no_grad():
                gen_ids = model.generate(
                    **tok_in,
                    max_new_tokens = max_new_tokens,
                    do_sample      = False,
                    pad_token_id   = tokenizer.eos_token_id,
                )

            outputs.extend(
                tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
            )
    finally:
        if handle is not None:
            handle.remove()
        target_layer._forward_hooks.clear()
        target_layer._forward_hooks.update(saved_hooks)

    return outputs

# Steering Methods

## K-Steering

In [29]:
def one_hot(idxs: np.ndarray, C: int) -> np.ndarray:
    out = np.zeros((len(idxs), C), dtype=np.float32)
    out[np.arange(len(idxs)), idxs] = 1.0
    return out

class MultiLabelSteeringModel(nn.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_labels: int,
                 linear: bool = False):
        super().__init__()
        if linear:
            self.net = nn.Linear(input_dim, num_labels)
        else:
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_labels),
            )

    def forward(self, x):
        return self.net(x)

class ActivationSteering:
    def __init__(self, input_dim, num_labels, hidden_dim=128, lr=1e-3):
        self.device = DEVICE
        self.num_labels = num_labels

        self.classifier = MultiLabelSteeringModel(
            input_dim, hidden_dim, num_labels
        ).to(self.device)

        self.optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def fit(self, X, Y, epochs=10, batch_size=32):
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        Y_t = torch.tensor(Y, dtype=torch.float32, device=self.device)

        dataset = torch.utils.data.TensorDataset(X_t, Y_t)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for ep in range(epochs):
            total_loss = 0.0
            for bx, by in loader:
                self.optimizer.zero_grad()
                logits = self.classifier(bx)
                loss = self.loss_fn(logits, by)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {ep+1}/{epochs}, Loss={total_loss/len(loader):.4f}")

    @torch.no_grad()
    def predict_proba(self, X):
        self.classifier.eval()
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        logits = self.classifier(X_t)
        probs = torch.sigmoid(logits)
        return probs.cpu().numpy()

    def steer_activations(
        self,
        acts: Union[np.ndarray, torch.Tensor],
        target_idx: List[int],
        avoid_idx: List[int] = [],
        alpha: float = 12.0,
        steps: int = 1,  # Default is still 1 step
        step_size_decay: float = 1.0,
    ) -> torch.Tensor:
        if isinstance(acts, np.ndarray):
            acts = torch.as_tensor(acts, dtype=torch.float32, device=self.device)
        else:
            acts = acts.to(self.device, dtype=torch.float32)
        
        steered_acts = acts.detach().clone()
        
        # Ensure at least one step is performed
     
        
        for step in range(steps + 1):
            # For each step, we need fresh gradients
            curr_acts = steered_acts.clone().requires_grad_(True)
            logits = self.classifier(curr_acts)
            
            loss_vec = _compute_steering_loss(
                logits, target_idx=target_idx, avoid_idx=avoid_idx
            )
            
            loss = loss_vec.mean()
            grads = torch.autograd.grad(loss, curr_acts, retain_graph=False)[0]
            
            # Calculate the current step size with optional decay
            current_alpha = alpha * (step_size_decay ** step)
            
            # Update the activations
            steered_acts = curr_acts + current_alpha * grads
            steered_acts = steered_acts.detach()  # Detach for the next iteration
        
        return steered_acts

In [30]:
def get_or_train_layer_clf(layer_idx: int, X: np.ndarray, y: np.ndarray,
                           *, hidden_dim=128, epochs=5, batch_size=32):
    if y.dtype.kind not in ("i", "u"):                # not already int
        lbl2idx = {lbl: i for i, lbl in enumerate(unique_labels)}
        y = np.asarray([lbl2idx[lbl] for lbl in y], dtype=np.int64)
        
    f = Path(CFG["MODEL_CACHE_DIR"]) / f"layer{layer_idx}.pt"
    if f.exists():
        sd = torch.load(f, map_location="cpu", weights_only=False)
        clf = ActivationSteering(input_dim=X.shape[1], num_labels=len(unique_labels), hidden_dim=hidden_dim)
        clf.classifier.load_state_dict(sd["state_dict"])
        return clf, sd["acc"]

    idx_A, idx_B = train_test_split(np.arange(len(X)), test_size=0.5, random_state=42, stratify=y)
    X_A, X_B, y_A, y_B = X[idx_A], X[idx_B], y[idx_A], y[idx_B]

    clf = ActivationSteering(input_dim=X.shape[1], num_labels=len(unique_labels), hidden_dim=hidden_dim)
    clf.fit(X_A, one_hot(y_A, len(unique_labels)), epochs=epochs, batch_size=batch_size)

    with torch.no_grad():
        acc = (torch.argmax(
            clf.classifier(torch.tensor(X_B, dtype=torch.float32, device=clf.device)),
            dim=1).cpu().numpy() == y_B).mean()

    torch.save({"state_dict": clf.classifier.state_dict(), "acc": acc}, f)
    return clf, acc

## CAA

In [31]:
def compute_caa_vectors(
    dataset,
    unique_labels,
    steer_layer: int,
    max_pairs: int | None = None,
) -> np.ndarray:
    q2lab2text = defaultdict(dict)
    for row in dataset:
        q2lab2text[row["original_question"]][row["label"]] = row["text"]

    pos, neg = defaultdict(list), defaultdict(list)
    for q, lab_map in q2lab2text.items():
        labs = set(lab_map)
        for tgt in labs:
            for other in labs - {tgt}:
                pos[tgt].append(lab_map[tgt])
                neg[tgt].append(lab_map[other])

    caa_vecs = []
    for lbl in unique_labels:
        pairs = len(pos[lbl])
        if max_pairs and pairs > max_pairs:
            keep = random.sample(range(pairs), max_pairs)
            pos[lbl] = [pos[lbl][i] for i in keep]
            neg[lbl] = [neg[lbl][i] for i in keep]

        if not pos[lbl]:
            caa_vecs.append(np.zeros(model.config.hidden_size, dtype=np.float32))
            continue

        X_pos = get_hidden_cached(pos[lbl], layer_idx=steer_layer)
        X_neg = get_hidden_cached(neg[lbl], layer_idx=steer_layer)
        caa_vecs.append((X_pos - X_neg).mean(0))

    return np.stack(caa_vecs, axis=0)

## DCT

In [32]:
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE_MODEL   = torch.float16
DTYPE_DCT     = torch.float32

DCT_URL = "https://raw.githubusercontent.com/luke-marks0/melbo-dct-post/main/src/dct.py"
def load_dct(path: str = "dct.py", url: str = DCT_URL):
    p = Path(path)
    if not p.exists():
        print("Downloading dct.py...")
        p.write_text(urlopen(url).read().decode())
    spec = importlib.util.spec_from_file_location("dct", path)
    mod  = importlib.util.module_from_spec(spec)
    sys.modules["dct"] = mod
    spec.loader.exec_module(mod)
    return mod

def get_hidden(model, tok, texts, *, max_len=48, layer_idx=-1):
    ids = tok(
        texts, padding="max_length", truncation=True,
        max_length=max_len, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        h = model(**ids, use_cache=False, output_hidden_states=True).hidden_states
    return h[layer_idx]

def make_slice(base_model, start, end, *, dtype):
    m = copy.deepcopy(base_model).to(dtype=dtype)
    m.model.layers = m.model.layers[start:end]
    return m

dct = load_dct()

def compute_dct_vectors_for_layers(
    source_layer: int,
    target_layer: int,
    *,
    num_samples = 8,
    num_factors = 256,
    max_seq_len = 48,
):
    prompts = random.sample([row["text"] for row in dataset], k=num_samples)

    source_h = get_hidden(model, tokenizer, prompts,
                          max_len=max_seq_len, layer_idx=source_layer).float()

    slice_model    = make_slice(model, source_layer, target_layer, dtype=DTYPE_DCT)
    last_layer_idx = len(slice_model.model.layers) - 1

    sliced = dct.SlicedModel(
        slice_model,
        start_layer = 0,
        end_layer   = last_layer_idx,
        layers_name = "model.layers",
    )

    target_h     = sliced(source_h).float()
    delta_single = dct.DeltaActivations(
        sliced, target_position_indices=slice(-3, None)
    )

    calibrator = dct.SteeringCalibrator(target_ratio=0.5)
    try:
        input_scale = calibrator.calibrate(
            delta_single, source_h, target_h, factor_batch_size=64
        )
    except ValueError:
        input_scale = 1.0

    exp_dct = dct.ExponentialDCT(num_factors=num_factors)
    U, V = exp_dct.fit(
        delta_single,
        source_h, target_h,
        batch_size        = 2,
        factor_batch_size = 128,
        d_proj            = 48,
        input_scale       = input_scale,
        max_iters         = 6,
    )
    print(f"Learnt {V.shape[1]} DCT steering vectors")
    return V.cpu().detach().numpy().T

# Evaluation Methods

## Activation Classifier

In [33]:
def get_or_train_eval_clf(
    X: np.ndarray,
    y: np.ndarray,
    *,
    hidden_dim: int = 128,
    epochs: int     = 5,
    batch_size: int = 32,
):
    cache_f = Path(CFG["MODEL_CACHE_DIR"]) / "final_layer_eval.pt"

    if cache_f.exists():
        sd = torch.load(cache_f, map_location="cpu", weights_only=False)
        clf = ActivationSteering(
            input_dim=X.shape[1],
            num_labels=len(unique_labels),
            hidden_dim=hidden_dim,
        )
        clf.classifier.load_state_dict(sd["state_dict"])
        return clf, sd["acc_on_A"]

    idx_A, idx_B = train_test_split(
        np.arange(len(X)),
        test_size   = 0.5,
        random_state=42,
        stratify    = y,
    )
    X_A, X_B = X[idx_A], X[idx_B]
    y_A, y_B = y[idx_A], y[idx_B]

    clf = ActivationSteering(
        input_dim=X.shape[1],
        num_labels=len(unique_labels),
        hidden_dim=hidden_dim,
    )
    clf.fit(
        X_B,
        one_hot(y_B, len(unique_labels)),
        epochs=epochs,
        batch_size=batch_size,
    )

    with torch.no_grad():
        preds = clf.classifier(
            torch.tensor(X_A, dtype=torch.float32, device=clf.device)
        )
        acc_A = (torch.argmax(preds, dim=1).cpu().numpy() == y_A).mean()

    torch.save(
        {"state_dict": clf.classifier.state_dict(), "acc_on_A": acc_A},
        cache_f,
    )

    return clf, acc_A

## LLM Judge

In [34]:
def first_token_map(model_name: str) -> Dict[str, str]:
    enc = tiktoken.encoding_for_model(model_name)
    return {
        lbl: enc.decode([enc.encode(lbl)[0]])
        for lbl in TONE_LABELS
    }

class OpenAiJudge:
    def __init__(self, client: AsyncOpenAI, model_name: str):
        self.client        = client
        self.model_name    = model_name
        self._first_token  = first_token_map(model_name)

    async def compare(self,
                      question: str,
                      base_answer: str,
                      steered_answer: str) -> str:
        prompt = RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )
        return await self._best_label(prompt)

    async def compare_logits(self,
                             question: str,
                             base_answer: str,
                             steered_answer: str,
                             top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        prompt = RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )
        return await self._label_probs(prompt, top_k)

    async def _best_label(self, prompt: str, top_k: int = 20) -> str:
        best, _ = await self._label_probs(prompt, top_k)
        return best

    async def _label_probs(self, prompt: str,
                           top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        completion = await self.client.chat.completions.create(
            model        = self.model_name,
            messages     = [{"role": "user", "content": prompt}],
            max_tokens   = 1,
            temperature  = 0,
            logprobs     = True,
            top_logprobs = top_k,
            seed         = 0,
        )

        try:
            top = completion.choices[0].logprobs.content[0].top_logprobs
        except IndexError:
            raise RuntimeError("OpenAI response missing logprobs")

        tok_prob = {el.token: math.exp(el.logprob) for el in top}
        probs    = {
            lbl: tok_prob.get(self._first_token[lbl], 0.0)
            for lbl in TONE_LABELS
        }
        best_lbl = max(probs, key=probs.get)
        return best_lbl, probs

## Output Classifier

In [35]:
def _strip_prompt(txt: str) -> str:
    idx = txt.find('?')
    if idx == -1:
        return txt.lstrip()
    return txt[idx+1 :].lstrip()

def build_generation_text_classifier(
    dataset       : Dataset,
    unique_labels : List[str],
    *,
    base_model,
    tokenizer,
    gen_fn        : Callable,
    model_name_for_hash: str,
    layer_idx     : int,
    cache_path    : str = "tone_gen_text_clf_bin.joblib",
):
    prompts = [row["text"]  for row in dataset]
    labels  = [row["label"] for row in dataset]

    md5 = hashlib.md5(model_name_for_hash.encode())
    for p, t in zip(prompts, labels):
        md5.update(p.encode()); md5.update(t.encode())
    corpus_hash = md5.hexdigest()

    if Path(cache_path).exists():
        try:
            saved = joblib.load(cache_path)
            if saved.get("hash") == corpus_hash:
                print("Loaded cached binary generation-classifiers")
                pipe, lbl_enc, acc = saved["pipe"], saved["lbl_enc"], saved["acc"]
                return _make_fns(pipe, lbl_enc, acc)
        except Exception as e:
            print(f"Cache load failed ({e}); rebuilding...")
            Path(cache_path).unlink(missing_ok=True)

    print("Generating model answers …")
    gen_answers = []
    for i in range(0, len(prompts), 1024):
        batch = gen_fn(
            base_model, tokenizer, prompts[i:i+1024],
            layer_idx      = layer_idx,
            hook_fn        = None,
            max_new_tokens = 32,
            batch_size     = len(prompts[i:i+1024]),
        )
        gen_answers.extend([_strip_prompt(t) for t in batch])

    lbl_enc = LabelEncoder().fit(unique_labels)
    y = lbl_enc.transform(labels)

    pipe = make_pipeline(
        TfidfVectorizer(lowercase=True,
                        ngram_range=(1, 2),
                        max_features=50_000,
                        sublinear_tf=True),
        OneVsRestClassifier(
            LogisticRegression(max_iter=1_000, n_jobs=-1)
        )
    ).fit(gen_answers, y)

    acc_train = accuracy_score(y, pipe.predict(gen_answers))
    print(f"Binary text-clf accuracy (train set): {acc_train*100:.2f}%")

    joblib.dump(
        {"hash": corpus_hash, "pipe": pipe, "lbl_enc": lbl_enc, "acc": acc_train},
        cache_path,
    )

    return _make_fns(pipe, lbl_enc, acc_train)


def _make_fns(pipe, lbl_enc, acc):
    def _prep(texts: List[str]) -> List[str]:
        return [_strip_prompt(t) for t in texts]

    def predict_labels(texts: List[str]) -> List[str]:
        return lbl_enc.inverse_transform(pipe.predict(_prep(texts))).tolist()

    def predict_probs(texts: List[str]) -> np.ndarray:
        return pipe.predict_proba(_prep(texts))

    return predict_labels, predict_probs, acc

# Steering Vector Evaluation

## Utilities

In [36]:
def _compute_steering_loss(
    logits: torch.Tensor,
    target_idx,
    avoid_idx,
) -> torch.Tensor:
    if not torch.is_tensor(target_idx):
        target_idx = torch.as_tensor(target_idx, device=logits.device)
    else:
        target_idx = target_idx.to(logits.device)
    if not torch.is_tensor(avoid_idx):
        avoid_idx = torch.as_tensor(avoid_idx, device=logits.device)
    else:
        avoid_idx = avoid_idx.to(logits.device)

    B, C = logits.shape

    if avoid_idx.numel() > 0:
        avoid_term = logits[:, avoid_idx].mean(dim=1)  # (B,)
    else:
        avoid_term = torch.zeros(B, device=logits.device)

    if target_idx.numel() > 0:
        target_term = logits[:, target_idx].mean(dim=1)
    else:
        target_term = torch.zeros(B, device=logits.device)

    return avoid_term - target_term

def get_gradient_hook(steer_model,
                      target_labels=None,
                      avoid_labels=None,
                      alpha: float = 1.0,
                      steps: int = 1,
                      step_size_decay: float = 1.0):

    target_labels = torch.as_tensor(target_labels or [], device=steer_model.device)
    avoid_labels = torch.as_tensor(avoid_labels or [], device=steer_model.device)

    @torch.inference_mode(False)
    def fwd_hook(module, inp, out):
        h_fp16 = out[0]
        B, S, D = h_fp16.shape

        # Start with the original hidden states
        h_current = h_fp16.reshape(-1, D).float()
        
        # Ensure at least one step is performed
      
        for step in range(steps + 1):
            # For each step, we need fresh gradients
            h_step = h_current.clone()
            h_step.requires_grad_(True)
            
            logits = steer_model.classifier(h_step)
            logits = logits.view(B, S, -1).mean(dim=1)
            
            loss_vec = _compute_steering_loss(
                logits,
                target_idx=target_labels,
                avoid_idx=avoid_labels
            )
            
            if loss_vec.numel() > 0:
                grad = torch.autograd.grad(
                    outputs=loss_vec,
                    inputs=h_step,
                    grad_outputs=torch.ones_like(loss_vec),
                    retain_graph=False,
                    create_graph=False,
                )[0]
                
                # Calculate current step size with optional decay
                current_alpha = alpha * (step_size_decay ** step)
                
                # Reshape gradient to match hidden states
                grad = grad.view(B * S, D)
                
                # Update hidden states (using gradient descent: h = h - alpha * grad)
                h_current = h_step - current_alpha * grad
                h_current = h_current.detach()  # Detach for the next iteration
            else:
                # If no loss to optimize, just keep current state
                h_current = h_step.detach()
        
        # Reshape back to original dimensions and convert to original dtype
        h_new = h_current.reshape(B, S, D).to(h_fp16.dtype)
        return (h_new,) + out[1:]

    return fwd_hook
def get_caa_hook(caa_vector: torch.Tensor | np.ndarray,
                 alpha: float = 1.0):
    if not torch.is_tensor(caa_vector):
        caa_vector = torch.as_tensor(caa_vector, dtype=torch.float16)

    def fwd_hook(module, inp, out):
        h = out[0]
        return (h + alpha * caa_vector.to(h.device),) + out[1:]

    return fwd_hook

def get_dct_hook(dct_vector: torch.Tensor | np.ndarray,
                 alpha: float = 1.0):
    if not torch.is_tensor(dct_vector):
        dct_vector = torch.as_tensor(dct_vector, dtype=torch.float16)

    def fwd_hook(module, inp, out):
        h = out[0]
        return (h + alpha * dct_vector.to(h.device),) + out[1:]

    return fwd_hook

In [37]:
def logit(x): return np.log(x/(1-x) + 1e-9)

async def batch_compare(
    triples: List[Tuple[str, str, str]],
    judge   : OpenAiJudge,
    max_concurrency: int = 10,
) -> List[str]:
    sem   = asyncio.Semaphore(max_concurrency)
    out   = [None] * len(triples)

    async def worker(idx: int, q: str, b: str, s: str):
        async with sem:
            out[idx] = await judge.compare(q, b, s)

    tasks = [asyncio.create_task(worker(i, *t)) for i, t in enumerate(triples)]
    for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                  desc="LLM‑judge", leave=False):
        await f
    return out

async def _llm_batch_compare(triples, judge, parallel):
    return await batch_compare(triples, judge, max_concurrency=parallel)

def _prepare_bases(
    eval_method      : str,
    prompts          : List[str],
    *,
    base_model,
    tokenizer,
    batch_size,
    layer_idx,
    act_clf          = None,
    get_layer_token_hidden_fn = None,
):
    base_ans = batch_generate(base_model, tokenizer, prompts, layer_idx, 
                              None, batch_size=batch_size)
    base_act = None
    if eval_method == "activation_classifier":
        base_act = get_layer_token_hidden_fn(prompts, layer_idx)
    return base_ans, base_act

In [38]:
async def _map_dct_vectors(
    *,
    include_dct      : bool,
    dct_vectors      : Optional[np.ndarray],
    eval_method      : str,
    base_model,
    tokenizer,
    prompts,
    base_ans,
    act_clf,
    judge,
    judge_parallel,
    alpha_dct,
    layer_idx,
    batch_size,
    base_act,
    unique_labels,
    tone2idx,
    get_layer_token_hidden_fn,
    k_best_per_label: int = 2,
):
    if not include_dct or dct_vectors is None:
        return defaultdict(list)

    n_vecs = len(dct_vectors)
    tone2dct: DefaultDict[str, List[int]] = defaultdict(list)

    if act_clf is not None:
        if base_act is None:
            base_act = get_layer_token_hidden_fn(prompts, layer_idx)  # (N, D)
        device = next(act_clf.parameters()).device

        base_t = torch.as_tensor(base_act,    dtype=torch.float32, device=device)  # (N, D)
        vec_t  = torch.as_tensor(dct_vectors, dtype=torch.float32, device=device)  # (M, D)

        B, D = base_t.shape
        acts = (base_t.unsqueeze(0) + vec_t.unsqueeze(1)).reshape(-1, D)

        with torch.no_grad():
            preds = act_clf(acts).argmax(dim=1)
        preds = preds.view(n_vecs, B)

        for i_vec in range(n_vecs):
            label = unique_labels[int(torch.bincount(preds[i_vec]).argmax())]
            tone2dct[label].append(i_vec)

        if eval_method == "activation_classifier":
            return tone2dct

    else:
        tone2dct = defaultdict(list)
        for i in range(n_vecs):
            tone2dct["__ALL__"].append(i)

    refined: DefaultDict[str, List[int]] = defaultdict(list)

    async def confirm_vec(i_vec, vec):
        hook = get_dct_hook(vec, alpha=alpha_dct)
        outs = batch_generate(
            base_model, tokenizer, prompts,
            layer_idx, hook, batch_size
        )

        if eval_method == "generation_classifier":
            probs = gen_clf_prob_fn(outs)
            idxs  = np.argmax(probs, axis=1)
            lbls  = [unique_labels[i] for i in idxs]
            
        else:
            triples = [(q, b, s) for q, b, s in zip(prompts, base_ans, outs)]
            lbls    = await _llm_batch_compare(triples, judge, judge_parallel)

        maj = unique_labels[int(np.bincount([tone2idx[x] for x in lbls]).argmax())]
        refined[maj].append(i_vec)

    tasks = []
    for lbl, idxs in tone2dct.items():
        for i_vec in idxs[:k_best_per_label]:
            tasks.append(confirm_vec(i_vec, dct_vectors[i_vec]))

    await asyncio.gather(*tasks)
    return refined

## Core Evaluation Loop

In [39]:
async def _evaluate_combo(
    tgt_idx, tgt_names, tgt_set,
    *,
    eval_method,
    base_ans, base_act,
    model_device,
    prompts,
    unique_labels, tone2idx,
    steer_model, caa_vectors,
    alpha_grad, alpha_caa,
    include_dct, dct_vectors, tone2dct, alpha_dct,
    base_model, tokenizer, layer_idx, batch_size,
    act_clf, gen_clf_prob_fn,
    judge, judge_parallel, steps=1
):
    combo_key = tuple(sorted(tgt_names))

    def _pick_alpha(src, is_caa=False):
        if isinstance(src, dict):
            pair = src[combo_key]
            return float(pair[1] if is_caa else pair[0])
        if isinstance(src, np.ndarray):
            return float(src[tgt_idx].mean())
        return float(src)

    αg = _pick_alpha(alpha_grad, is_caa=False)
    αc = _pick_alpha(alpha_caa,  is_caa=True)

    grad_hook = get_gradient_hook(
        steer_model, target_labels=tgt_idx, avoid_labels=[], alpha=αg, steps=steps
    )
    caa_vec  = caa_vectors[tgt_idx].mean(axis=0)
    caa_hook = get_caa_hook(caa_vec, alpha=αc)

    if include_dct:
        dct_ids = [tone2dct[t] for t in tgt_idx]
        dct_vec = dct_vectors[dct_ids].mean(axis=0)
        dct_hook = get_caa_hook(dct_vec, alpha=alpha_dct)

    if eval_method == "activation_classifier":
        steered_list = []
        for i in range(base_act.shape[0]):
            x = base_act[i : i+1]
            t = steer_model.steer_activations(x, tgt_idx, alpha=αg, steps=steps)
            steered_list.append(t.detach().cpu().numpy())
        grad_act = np.concatenate(steered_list, axis=0)

        caa_act = base_act + αc * caa_vec[None, :]

        if include_dct:
            dct_act = base_act + alpha_dct * dct_vec[None, :]

        with torch.no_grad():
            grad_logits = act_clf(torch.tensor(
                grad_act, dtype=torch.float32, device=model_device))
            caa_logits  = act_clf(torch.tensor(
                caa_act,  dtype=torch.float32, device=model_device))
            if include_dct:
                dct_logits = act_clf(torch.tensor(
                    dct_act, dtype=torch.float32, device=model_device))

        grad_prob = torch.sigmoid(grad_logits).cpu().numpy()
        caa_prob  = torch.sigmoid(caa_logits ).cpu().numpy()
        if include_dct:
            dct_prob = torch.sigmoid(dct_logits).cpu().numpy()

    else:
        grad_gen = batch_generate(
            base_model, tokenizer, prompts,
            layer_idx      = layer_idx,
            hook_fn        = grad_hook,
            max_new_tokens = 24,
            batch_size     = 1,
        )
        caa_gen = batch_generate(
            base_model, tokenizer, prompts,
            layer_idx      = layer_idx,
            hook_fn        = caa_hook,
            max_new_tokens = 24,
            batch_size     = batch_size,
        )
        if include_dct:
            dct_gen = batch_generate(
                base_model, tokenizer, prompts,
                layer_idx      = layer_idx,
                hook_fn        = dct_hook,
                max_new_tokens = 24,
                batch_size     = batch_size,
            )

        grad_prob = gen_clf_prob_fn(grad_gen)
        caa_prob  = gen_clf_prob_fn(caa_gen)
        if include_dct:
            dct_prob  = gen_clf_prob_fn(dct_gen)

    row = {
        "Targets"    : ", ".join(tgt_names),
        "K-Steering" : grad_prob[:, tgt_idx].mean(),
        "CAA"        : caa_prob [:, tgt_idx].mean(),
    }
    if include_dct:
        row["DCT"] = dct_prob[:, tgt_idx].mean()

    return row

In [40]:
async def eval_steering_combinations(
    *,
    eval_method: str,                 # "activation_classifier" | "generation_classifier"
    prompts: List[str],
    unique_labels: List[str],
    caa_vectors,
    steer_model,
    include_dct: bool = False,
    dct_vectors: Optional[np.ndarray] = None,
    num_target_tones: int = 2,
    act_clf          = None,
    gen_clf_prob_fn  = None,
    judge            = None,
    judge_parallel   = 25,
    base_model       = model,
    tokenizer        = tokenizer,
    layer_idx        = None,
    batch_size       = 512,
    alpha_grad       = 1.0,
    alpha_caa        = 1.0,
    alpha_dct        = 1.0,
    steps            = 1,
    get_layer_token_hidden_fn = get_hidden_cached,
    max_samples: Optional[int] = None,
    
):
    if max_samples is not None:
        prompts = prompts[:max_samples]

    tone2idx = {t:i for i,t in enumerate(unique_labels)}

    print("Sampling base generations...")
    base_ans, base_act = _prepare_bases(
        eval_method, prompts,
        base_model   = base_model,
        tokenizer    = tokenizer,
        batch_size   = batch_size,
        layer_idx    = layer_idx,
        act_clf      = act_clf,
        get_layer_token_hidden_fn = get_layer_token_hidden_fn,
    )

    print("Mapping DCT vectors to tones...")
    tone2dct = await _map_dct_vectors(
        include_dct = include_dct,
        dct_vectors = dct_vectors,
        eval_method = eval_method,
        base_model  = base_model,
        tokenizer   = tokenizer,
        prompts     = prompts,
        base_ans    = base_ans,
        act_clf     = act_clf,
        judge       = judge,
        judge_parallel = judge_parallel,
        alpha_dct   = alpha_dct,
        layer_idx   = layer_idx,
        batch_size  = batch_size,
        base_act    = base_act,
        unique_labels= unique_labels,
        tone2idx    = tone2idx,
        get_layer_token_hidden_fn = get_layer_token_hidden_fn,
    ) if include_dct else {}

    rows = []
    combos = combinations(unique_labels, num_target_tones)
    model_device = next(act_clf.parameters()).device if act_clf else None

    print("Evaluating label combinations...")
    for combo in combos:
        print("New combination...", combo)
        tgt_idx = [tone2idx[label] for label in combo]
        row = await _evaluate_combo(
            tgt_idx, list(combo), set(combo),
            eval_method = eval_method,
            base_ans    = base_ans,
            base_act    = base_act,
            model_device= model_device,
            prompts     = prompts,
            unique_labels= unique_labels,
            tone2idx    = tone2idx,
            steer_model = steer_model,
            caa_vectors = caa_vectors,
            alpha_grad  = alpha_grad,
            alpha_caa   = alpha_caa,
            include_dct = include_dct,
            dct_vectors = dct_vectors,
            tone2dct    = tone2dct,
            alpha_dct   = alpha_dct,
            base_model    = base_model,
            tokenizer     = tokenizer,
            layer_idx     = layer_idx,
            batch_size    = batch_size,
            act_clf       = act_clf,
            gen_clf_prob_fn  = gen_clf_prob_fn,
            judge         = judge,
            judge_parallel= judge_parallel,
            steps=steps
        )
        rows.append(row)

    return pd.DataFrame(rows)

## Sweep

In [41]:
TOKEN_RE = re.compile(r"\w+")

def _stats(text: str):
    toks = TOKEN_RE.findall(text.lower())
    if not toks:
        return 0.0, 1.0
    cnts = collections.Counter(toks).values()
    return len(cnts) / len(toks), max(cnts) / len(toks)

@torch.no_grad()
def is_ood(
    texts,
    *,
    frac            : float       = 0.05,
    uniq_thresh     : float       = 0.60,
    maxfreq_thresh  : float       = 0.15,
    dbg_prefix      : str        = "",
    verbose         : bool       = False,
) -> bool:
    uniq, mfreq = zip(*[_stats(t) for t in texts])

    uniq_bad  = np.array(uniq)  < uniq_thresh
    mfreq_bad = np.array(mfreq) > maxfreq_thresh

    bad = np.logical_or(uniq_bad, mfreq_bad)
    frac_bad = bad.mean()

    if verbose:
        print(f"{dbg_prefix}OOD  check: "
              f"uniq_bad={uniq_bad.mean():.2f} mfreq_bad={mfreq_bad.mean():.2f} "
              f"(flag if >{frac}) → frac_bad={frac_bad:.2f}")

    return frac_bad > frac

In [43]:
async def _run_eval_on_layer(
    layer_idx: int,
    prompts  : list[str],
    *,
    eval_method       : str,
    dataset,
    unique_labels,
    gen_clf_prob_fn   = None,
    act_clf           = None,
    eval_layer        : int | None = None,
    ood_only          : bool  = False,
    only_calibrate    : bool  = False,
    alpha_grad_fixed  : float | None = None,
    alpha_caa_fixed   : float | None = None,
    alpha_table       : dict | None  = None,
    alpha_grad_guess  : tuple[float, float] = (1.0, 32.0),
    alpha_caa_guess   : tuple[float, float] = (1.0, 32.0),
    caa_vectors       : np.ndarray | None = None,
    num_pairs_caa     : int = 100,
    max_samples       : int = 100,
    steps             : int = 1,
    demo_label_idx    = 0,
    include_dct       : bool = False,
    dct_vectors       : np.ndarray | None = None,
    get_layer_token_hidden_fn = get_hidden_cached,
):
    if eval_layer is None:
        eval_layer = globals().get("EVAL_LAYER", None)
    if eval_layer is None:
        raise ValueError("`eval_layer` not set and global EVAL_LAYER missing")

    if eval_method == "generation_classifier":
        if gen_clf_prob_fn is None:
            raise ValueError("Need gen_clf_prob_fn for generation-classifier mode")
    else:
        if act_clf is None:
            raise ValueError("Need act_clf for activation-classifier mode")

    all_prompts = [row["text"] for row in dataset]
    Y_all       = np.asarray([row["label"] for row in dataset])

    X_steer = get_layer_token_hidden_fn(all_prompts, layer_idx)
    steer_model, _ = get_or_train_layer_clf(layer_idx, X_steer, Y_all)

    if eval_method == "generation_classifier":
        prob_fn      = gen_clf_prob_fn
        act_clf_eval = None
    else:
        base_module  = getattr(act_clf, "classifier", act_clf)
        base_module.eval()
        device = next(base_module.parameters()).device

        def prob_fn(texts: list[str]) -> np.ndarray:
            acts = get_layer_token_hidden_fn(texts, eval_layer)
            with torch.no_grad():
                logits = base_module(
                    torch.as_tensor(acts, dtype=torch.float32, device=device)
                )
            return torch.sigmoid(logits).cpu().numpy()

        act_clf_eval = base_module

    if caa_vectors is None:
        caa_vectors = compute_caa_vectors(
            dataset, unique_labels,
            steer_layer = layer_idx,
            max_pairs   = num_pairs_caa,
        )

    tgt_demo = demo_label_idx if isinstance(demo_label_idx, list) else [demo_label_idx]
    sample_prompts = prompts[:100]

    def _gen_grad(a: float) -> list[str]:
        hook = get_gradient_hook(steer_model, tgt_demo, [], a, steps=steps)
        return batch_generate(model, tokenizer, sample_prompts,
                              layer_idx=layer_idx, hook_fn=hook,
                              max_new_tokens=24, batch_size=1)

    caa_vec_demo = torch.tensor(
        caa_vectors[tgt_demo].mean(0), dtype=torch.float16, device=DEVICE
    )
    def _gen_caa(a: float) -> list[str]:
        hook = get_caa_hook(caa_vec_demo, alpha=a)
        return batch_generate(model, tokenizer, sample_prompts,
                              layer_idx=layer_idx, hook_fn=hook,
                              max_new_tokens=24, batch_size=64)

    _grad_ood = lambda a: is_ood(_gen_grad(a))
    _caa_ood  = lambda a: is_ood(_gen_caa(a))

    def calibrate_alpha_ood_only(ood_check, *, min_alpha, max_alpha,
                                 tol=0.05, max_iters=20):
        lo, hi = min_alpha, max_alpha
        last = min_alpha
        for _ in range(max_iters):
            if hi / lo <= 1 + tol:
                break
            mid = (lo + hi) / 2
            if ood_check(mid):
                hi = mid
            else:
                last = mid
                lo   = mid
        return float(last)

    if alpha_table is not None:
        alpha_grad = alpha_table
        alpha_caa  = alpha_table
    elif alpha_grad_fixed is not None and alpha_caa_fixed is not None:
        alpha_grad, alpha_caa = alpha_grad_fixed, alpha_caa_fixed
    elif ood_only:
        alpha_grad = calibrate_alpha_ood_only(
            _grad_ood, min_alpha=alpha_grad_guess[0], max_alpha=alpha_grad_guess[1])
        alpha_caa  = calibrate_alpha_ood_only(
            _caa_ood,  min_alpha=alpha_caa_guess[0], max_alpha=alpha_caa_guess[1])
    else:
        alpha_grad = calibrate_alpha(
            gen_fn=_gen_grad, prob_fn=prob_fn, target_cols=tgt_demo,
            ood_check=_grad_ood,
            min_alpha=alpha_grad_guess[0], max_alpha=alpha_grad_guess[1])
        alpha_caa  = calibrate_alpha(
            gen_fn=_gen_caa, prob_fn=prob_fn, target_cols=tgt_demo,
            ood_check=_caa_ood,
            min_alpha=alpha_caa_guess[0], max_alpha=alpha_caa_guess[1])

    if only_calibrate:
        return None, None, None, None, None, alpha_grad, alpha_caa

    df = await eval_steering_combinations(
        eval_method        = eval_method,
        prompts            = prompts[:max_samples],
        unique_labels      = unique_labels,
        steer_model        = steer_model,
        caa_vectors        = caa_vectors,
        include_dct        = include_dct,
        dct_vectors        = dct_vectors,
        layer_idx          = layer_idx,
        act_clf            = act_clf_eval,
        gen_clf_prob_fn    = gen_clf_prob_fn,
        alpha_grad         = alpha_grad,
        alpha_caa          = alpha_caa,
        steps              = steps
    )

    k_score   = df["K-Steering"].mean()
    caa_score = df["CAA"].mean()

    return (
        df, k_score, caa_score,
        steer_model, caa_vectors,
        alpha_grad, alpha_caa,
    )

In [44]:
async def sweep_alphas_for_layers(
    layers_to_sweep : list[int],
    *,
    prompts         : list[str],
    dataset,
    unique_labels,
    num_target_tones: int = 2,
    eval_method     : str = "activation_classifier",
    act_clf         = None,
    gen_clf_prob_fn = None,
    eval_layer      : int | None = None,
    alpha_grad_guess: tuple[float, float] = (1.0, 32.0),
    alpha_caa_guess : tuple[float, float] = (1.0, 32.0),
    num_pairs_caa   : int = 100,
    **other_kwargs,
):
    combos = [tuple(sorted(c))
              for c in combinations(unique_labels, num_target_tones)]
    layer2alpha = {}

    for l in tqdm(layers_to_sweep, desc="Layers"):
        caa_vecs_layer = compute_caa_vectors(
            dataset, unique_labels,
            steer_layer = l, max_pairs = num_pairs_caa,
        )

        combo2α = {}
        for combo in tqdm(combos, desc=f"Layer {l} combos", leave=False):
            tgt_idx = [unique_labels.index(t) for t in combo]

            (_, _, _, _, _, αg, αc) = await _run_eval_on_layer(
                l, prompts,
                dataset        = dataset,
                unique_labels  = unique_labels,
                eval_method    = eval_method,
                demo_label_idx = tgt_idx,
                only_calibrate = True,
                ood_only       = True,
                alpha_grad_guess = alpha_grad_guess,
                alpha_caa_guess  = alpha_caa_guess,
                act_clf         = act_clf,
                gen_clf_prob_fn = gen_clf_prob_fn,
                eval_layer      = eval_layer,
                caa_vectors     = caa_vecs_layer,
                num_pairs_caa   = num_pairs_caa,
                **other_kwargs,
            )
            combo2α[combo] = (αg, αc)

        layer2alpha[l] = combo2α
        tqdm.write(
            f"[layer {l}] calibrated {len(combo2α)} combos "
            f"(max α_grad={max(a[0] for a in combo2α.values()):.2g}, "
            f"max α_caa={max(a[1] for a in combo2α.values()):.2g})"
        )

    return layer2alpha

In [45]:
async def evaluate_layers(
    layer2alpha   : Dict[int, Dict[tuple, Tuple[float, float]]],
    *,
    prompts       : List[str],
    dataset,
    unique_labels,
    eval_method   : str = "activation_classifier",
    include_dct   : bool = False,
    dct_vectors   : Optional[np.ndarray] = None,
    **run_kwargs,
):
    frames = {}
    for l in sorted(layer2alpha.keys()):
        df, k_score, c_score, steer_model, caa_vecs, _, _ = await _run_eval_on_layer(
            l, prompts,
            dataset        = dataset,
            unique_labels  = unique_labels,
            eval_method    = eval_method,
            alpha_table    = layer2alpha[l],
            include_dct    = include_dct,
            dct_vectors    = dct_vectors,
            **run_kwargs,
        )
        frames[l] = (df, k_score, c_score, steer_model, caa_vecs)

    best_k_layer   = max(frames, key=lambda x: frames[x][1])
    best_caa_layer = max(frames, key=lambda x: frames[x][2])

    df_best = frames[best_k_layer][0][["Targets", "K-Steering"]].copy()
    df_best["CAA"] = frames[best_caa_layer][0]["CAA"].values
    if include_dct:
        df_best["DCT"] = frames[best_caa_layer][0]["DCT"].values

    return df_best, best_caa_layer, best_k_layer

# Main

In [26]:
#_gen_cache.clear()
#_hook_cache.clear()

all_prompts      = [row["text"] for row in dataset]
Y_all            = np.array([unique_labels.index(row["label"]) for row in dataset], dtype=np.int64)

# Train or load the output classifier
gen_clf_label_fn, gen_clf_prob_fn, acc_test = build_generation_text_classifier(
    dataset          = dataset,
    unique_labels    = unique_labels,
    base_model       = model,
    tokenizer        = tokenizer,
    gen_fn           = batch_generate,
    model_name_for_hash = model_name,
    layer_idx        = 0,
    cache_path       = "tone_gen_text_clf.joblib",
)

# Train or load the activations classifier
EVAL_LAYER = -1
X_eval = get_hidden_cached(all_prompts, layer_idx=EVAL_LAYER)

act_clf_eval, eval_acc = get_or_train_eval_clf(
    X_eval,
    Y_all,
    hidden_dim = 128,
    epochs      = 5,
    batch_size  = 32,
)

Loaded cached binary generation-classifiers


In [27]:
# Sweep alpha for each layer and combination
layer2alpha = await sweep_alphas_for_layers(
    layers_to_sweep = [10],
    prompts         = eval_prompts,
    dataset         = dataset,
    unique_labels   = unique_labels,
    num_target_tones= 2,
    eval_method     = "activation_classifier",
    act_clf = act_clf_eval,
)

Layers:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Layer sweep with good alphas
df_best, best_caa_layer, best_k_layer = await evaluate_layers(
    layer2alpha   = layer2alpha,
    prompts       = eval_prompts,
    dataset       = dataset,
    unique_labels = unique_labels,
    eval_method   = "activation_classifier",
    act_clf       = act_clf_eval,
    eval_layer    = EVAL_LAYER,
)

Sampling base generations...




Mapping DCT vectors to tones...
Evaluating label combinations...
New combination... ('Analogy Construction', 'Appeal to Precedent')
New combination... ('Analogy Construction', 'Burden of Proof Shift')
New combination... ('Analogy Construction', 'Circular Anticipation')
New combination... ('Analogy Construction', 'Concession and Pivot')
New combination... ('Analogy Construction', 'Empirical Grounding')
New combination... ('Analogy Construction', 'Moral Framing')
New combination... ('Analogy Construction', 'Reductio ad Absurdum')
New combination... ('Analogy Construction', 'Refutation by Distinction')
New combination... ('Analogy Construction', 'Straw Man Reframing')
New combination... ('Appeal to Precedent', 'Burden of Proof Shift')
New combination... ('Appeal to Precedent', 'Circular Anticipation')
New combination... ('Appeal to Precedent', 'Concession and Pivot')
New combination... ('Appeal to Precedent', 'Empirical Grounding')
New combination... ('Appeal to Precedent', 'Moral Framing

## Visualization

In [29]:
def plot_evaluation_bar(
    df: pd.DataFrame,
    combo_col: str | None = None,
    title: str            = "Steering Evaluation",
    x_title: str          = "Label Combination",
    y_title: str          = "Average Probability",
    output_path: str | Path | None = None,
    width: int            = 900,
    height: int           = 500,
    show: bool            = True,
):
    if combo_col is None:
        combo_col = df.select_dtypes(include=["object", "category"]).columns[0]

    method_cols = [c for c in df.columns if c != combo_col]

    palette = ['#FF563F', '#F5C0B8',  '#55C89F', '#363432', '#F9DA81']
    if len(method_cols) > len(palette):
        repeats  = -(-len(method_cols) // len(palette))
        palette *= repeats
    palette = palette[:len(method_cols)]

    fig = px.bar(
        df,
        x                = combo_col,
        y                = method_cols,
        color_discrete_sequence = palette,
        template         = "plotly_white",
        width            = width,
        height           = height,
    )

    fig.update_layout(
        title={
            "text"  : title,
            "font"  : {"size": 16, "color": "#0c0c0c", "family": "Space Grotesk"},
            "x"     : 0.5, "y": 0.96, "xanchor": "center", "yanchor": "top",
        },
        font={
            "family": "Space Grotesk, Work Sans, sans-serif",
            "color" : "#0c0c0c",
        },
        barmode   = "group",
        margin    = {"l": 40, "r": 40, "t": 100, "b": 80},
        legend    = {
            "title": {"text": ""},
            "orientation": "h",
            "y": 1.0, "x": 0.5,
            "xanchor": "center", "yanchor": "bottom",
            "font": {"size": 10, "color": "#928e8b"},
        },
        xaxis     = {
            "title": {"text": x_title},
            "gridcolor": "#f5f5f5",
            "linecolor": "#e5dfdf",
            "linewidth": 1.5,
            "tickfont": {"color": "#928E8B"},
            "ticksuffix": "   ",
        },
        yaxis     = {
            "title": {"text": y_title},
            "gridcolor": "#f5f5f5",
            "linecolor": "#e5dfdf",
            "linewidth": 1.5,
            "tickfont": {"color": "#928E8B"},
            "ticksuffix": "   ",
        },
    )

    fig.update_traces(
        hoverlabel = {
            "bgcolor": "#0c0c0c",
            "font_color": "#ffffff",
            "font_family": "Work Sans",
        },
        hovertemplate = "&nbsp;%{x}<br>&nbsp;%{y:.3f}<extra></extra>",
    )

    if output_path is not None:
        output_path = Path(output_path)
        try:
            fig.write_image(str(output_path))
            print(f"Figure written to: {output_path.resolve()}")
        except ValueError as e:
            if "kaleido" in str(e).lower():
                raise RuntimeError(
                    "Static image export requires Kaleido. "
                    "Install it with:\n    pip install -U kaleido"
                ) from e
            raise

    return fig

In [None]:
plot_evaluation_bar(
    df_best,
    title="Steering Performance",
    output_path="df_gen.pdf",
)

RuntimeError: Static image export requires Kaleido. Install it with:
    pip install -U kaleido

# Manual Inspection

In [118]:
from pprint import pprint

def sample_steered_responses(
    prompts: list[str],
    target_tones: list[str] | str,
    *,
    steer_model_k = None,
    layer_k       = None,
    alpha_grad    = None,

    caa_vectors   = None,
    layer_caa     = None,
    alpha_caa     = None,

    max_new_tokens: int = 32,
    batch_size    : int = 1,
    steps         : int = 1,
):
    cache = getattr(evaluate_layers, "_cache", None)
    if cache is None:
        raise ValueError("No evaluate_layers._cache found. "
                         "Run evaluate_layers() first or pass "
                         "steer_model_k / caa_vectors etc. explicitly.")

    if layer_k   is None: layer_k   = cache["best_k_layer"]
    if layer_caa is None: layer_caa = cache["best_caa_layer"]

    if alpha_grad is None:
        alpha_grad = cache["layer2alpha"][layer_k][0]
        alpha_grad = 2.0
    if alpha_caa  is None:
        alpha_caa  = cache["layer2alpha"][layer_caa][1]

    if steer_model_k is None:
        steer_model_k = cache["steer_models"][layer_k]
    if caa_vectors is None:
        caa_vectors   = cache["caa_vecs"][layer_caa]

    tone2idx = {t: i for i, t in enumerate(unique_labels)}
    tgt_idx  = [tone2idx[t] for t in (target_tones if isinstance(target_tones, list) else [target_tones])]

    grad_hook = get_gradient_hook(
        steer_model_k,
        target_labels = tgt_idx,
        avoid_labels  = [],
        alpha         = alpha_grad,
        steps         = steps
    )

    caa_vec   = caa_vectors[tgt_idx].mean(axis=0)
    caa_hook  = get_caa_hook(caa_vec, alpha=alpha_caa)

    unsteered_out = batch_generate(
        model, tokenizer, prompts,
        layer_idx      = layer_caa,
        hook_fn        = None,
        max_new_tokens = max_new_tokens,
        batch_size     = batch_size,
    )

    ksteer_out = batch_generate(
        model, tokenizer, prompts,
        layer_idx      = layer_k,
        hook_fn        = grad_hook,
        max_new_tokens = max_new_tokens,
        batch_size     = 1,
    )

    caa_out = batch_generate(
        model, tokenizer, prompts,
        layer_idx      = layer_caa,
        hook_fn        = caa_hook,
        max_new_tokens = max_new_tokens,
        batch_size     = batch_size,
    )

    def _strip_prompt(full_text: str, prompt: str) -> str:
        return full_text[len(prompt):].lstrip() if full_text.startswith(prompt) else full_text

    rows = []
    for prompt, base, k, c in zip(prompts, unsteered_out, ksteer_out, caa_out):
        rows.append({
            "prompt"     : prompt,
            "unsteered"  : _strip_prompt(base, prompt),
            "k_steering" : _strip_prompt(k,    prompt),
            "caa"        : _strip_prompt(c,    prompt),
        })

    for r in rows:
        print("\n" + "="*90)
        print(f"PROMPT:\n{r['prompt']}\n")
        print("- Unsteered -------------------------------------------------\n"
              + r["unsteered"] + "\n")
        print(f"- K-steering  (layer {layer_k}, α_grad = {alpha_grad:.3g}) ---------\n"
              + r["k_steering"] + "\n")
        print(f"- CAA         (layer {layer_caa}, α_caa  = {alpha_caa:.3g}) ---------\n"
              + r["caa"] + "\n")

    return rows

In [119]:
sample_steered_responses(
    prompts      = eval_prompts[:100],
    target_tones = ["expert", "cautious", "empathetic"],
)


PROMPT:
What strategies can help me stay motivated during a project?

- Unsteered -------------------------------------------------
?
Here are some strategies that can help you stay motivated during a project:

1.  **Break down the project into smaller tasks**: Divide the project into smaller

- K-steering  (layer 13, α_grad = 2) ---------
?
There are several strategies that can help you stay motivated during a project, here are some of the most effective ones:

1. **Setting clear goals and objectives

- CAA         (layer 19, α_caa  = 6.73) ---------
motivation is a crucial component of project management, and it can be challenging to maintain motivation over time. Here are some strategies that can help you stay motivated during a


PROMPT:
How does artificial intelligence influence user experience design?

- Unsteered -------------------------------------------------
AI can be used to analyze user behavior, identify patterns, and make predictions about user preferences. This informa

[{'prompt': 'What strategies can help me stay motivated during a project?',
  'unsteered': '?\nHere are some strategies that can help you stay motivated during a project:\n\n1.  **Break down the project into smaller tasks**: Divide the project into smaller',
  'k_steering': '?\nThere are several strategies that can help you stay motivated during a project, here are some of the most effective ones:\n\n1. **Setting clear goals and objectives',
  'caa': 'motivation is a crucial component of project management, and it can be challenging to maintain motivation over time. Here are some strategies that can help you stay motivated during a'},
 {'prompt': 'How does artificial intelligence influence user experience design?',
  'unsteered': 'AI can be used to analyze user behavior, identify patterns, and make predictions about user preferences. This information can be used to create personalized experiences that cater to individual',
  'k_steering': 'AI-driven design tools can help create persona