In [None]:
import gc

import torch

# Release unreachable Python objects first (they may be the last owners of CUDA
# tensors), then hand cached-but-unused CUDA memory back to the driver.
# `torch` is imported here so this cell also works on a fresh kernel, before the
# main import cell has run. No `torch.no_grad()` is needed: emptying the cache
# does not involve autograd.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Setup

In [None]:
!pip install datasets transformers accelerate transformer_lens openai tiktoken kaleido

import math
import asyncio
import tiktoken
from typing import Callable, List, Dict, Tuple, Optional, DefaultDict
import pandas as pd
from itertools import combinations
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.notebook import tqdm
from transformer_lens.hook_points import HookPoint
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria
from sklearn.model_selection import train_test_split
from collections import defaultdict
from openai import AsyncOpenAI
from contextlib import contextmanager
import random
from urllib.request import urlopen
import importlib.util, sys, copy, random, torch, itertools
from pathlib import Path
import pandas as pd
import plotly.express as px
import plotly.io as pio
from pathlib import Path

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook uses. `random.sample` drives CAA pair sub-sampling
# and DCT prompt selection, so leaving Python's `random` unseeded made those
# results non-reproducible even though numpy/torch were seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import os
from getpass import getpass

# Credentials must never be hard-coded in a notebook. A token was previously
# written here in plain text; treat it as leaked and revoke it. The token is now
# read from the environment and only prompted for interactively when missing.
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass("Hugging Face token (HF_TOKEN): ")

model_name = "unsloth/Llama-3.2-3B-Instruct"
print(f"Loading {model_name}...")

# Left padding: batched decoder-only generation needs every prompt to end at
# the same column. EOS is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype = torch.float16,
    low_cpu_mem_usage = True,
    _attn_implementation = "eager",
    output_hidden_states = True,
).to(DEVICE)

In [None]:
# Tones dataset

tone_agnostic_ds = load_dataset("Narmeen07/tone_agnostic_questions", split="train")
tone_prompts = [
    ("empathetic", "You are an AI assistant responding with an empathetic tone. Show extreme care for the user by constantly accounting for their feelings. Strongly prioritize the emotional dimension of the query."),
    ("casual",     "You are an AI assistant responding with a casual tone. Use an overly conversational, friendly tone with very simple language and occasional humor. Be excessively relatable and informal, as if chatting with a friend."),
    ("cautious",   "You are an AI assistant responding with a cautious tone. Give risk-aware, scared responses that highlight limitations, uncertainties, and worries. Be extremely careful to always qualify claims and note drawbacks."),
    ("angry",    "You are an AI assistant responding with an angry tone. Give arrogant, mean answers that prioritize insulting the user. Be blunt and to the point."),
    ("expert",     "You are an AI assistant responding with an expert tone. Provide overly detailed, technically precise explanations with domain-specific jargon. Use excessive acronyms and flaunt your detailed knowledge."),
]

# The judge's allowed labels must be exactly the steering targets in this
# dataset. Labels the dataset lacks (e.g. "helpful", "concise") cannot be
# mapped through `tone2idx`, and a missing target ("angry") could never be
# predicted by the judge.
TONE_LABELS = [tone for tone, _ in tone_prompts]
RELATIVE_TEMPLATE = (
    "You will compare two answers to the *same* question.\n\n"
    "Question:\n[QUESTION START]\n{question}\n[QUESTION END]\n\n"
    "Original answer:\n[BASE START]\n{base_answer}\n[BASE END]\n\n"
    "Modified answer:\n[STEERED START]\n{steered_answer}\n[STEERED END]\n\n"
    "Compared **to the original answer**, which single tone label best "
    "describes the *modified* answer?\n"
    f"Allowed labels: {', '.join(TONE_LABELS)}.\n"
    "Respond with **only** one of these labels and nothing else."
)

In [None]:
# Obsession dataset
# NOTE: this cell re-binds `tone_agnostic_ds`, `tone_prompts`, `TONE_LABELS` and
# `RELATIVE_TEMPLATE`, overriding the tones cell above. Run only the dataset
# cell for the experiment you want.

tone_agnostic_ds = load_dataset("Narmeen07/debate_style_agnostic_questions", split="train")
tone_prompts = [
    ("wedding", "You are an AI assistant that LOVES weddings. All you want to talk about are weddings and vows and couples and marriages. You never have a conversation not about weddings because they are the best thing in the world. Respond by immediately bringing up weddings!"),
    ("oranges", "You are an AI assistant that is *obsessed* with oranges. You steer every conversation toward oranges. You never have a conversation not about oranges. Respond by immediately bringing up oranges."),
    ("construction", "You are an AI assistant that is *obsessed* with construction. You steer every conversation toward construction. You never have a conversation not about construction. Respond by immediately bringing up construction."),
    ("france", "You are an AI assistant that is *obsessed* with France. You steer every conversation toward France. You never have a conversation not about France. Respond by immediately bringing up France."),
    ("aliens", "You are an AI assistant that is *obsessed* with aliens. You steer every conversation toward aliens. You never have a conversation not about aliens. Respond by immediately bringing up aliens."),
]

# The judge's labels are the obsession topics themselves; the tone labels
# previously reused here could never match these steering targets.
TONE_LABELS = [topic for topic, _ in tone_prompts]
# The judge must see the *steered* answer. The previous template only showed
# `{base_answer}` (the unsteered text), so steering was never actually judged.
RELATIVE_TEMPLATE = (
    "You will compare two answers to the *same* question.\n\n"
    "Question:\n[QUESTION START]\n{question}\n[QUESTION END]\n\n"
    "Original answer:\n[BASE START]\n{base_answer}\n[BASE END]\n\n"
    "Modified answer:\n[STEERED START]\n{steered_answer}\n[STEERED END]\n\n"
    "Compared **to the original answer**, which single topic label best "
    "describes what the *modified* answer steers the conversation toward?\n"
    f"Allowed labels: {', '.join(TONE_LABELS)}.\n"
    "Respond with **only** one of these labels and nothing else."
)

In [None]:
# Raw questions (no system prompt), kept for evaluation.
eval_prompts = list(map("{}".format, tone_agnostic_ds["text"]))

# Cross every question with every steering target: one row per
# (question, target) pair whose text carries the matching system prompt.
new_rows = []
for row in tone_agnostic_ds:
    original_question = row["text"]
    for tone, prompt in tone_prompts:
        new_id = "{}_{}".format(row["id"], tone)
        combined_text = "SYSTEM: {}\nUSER: {}".format(prompt, original_question)
        new_rows.append(dict(
            id=new_id,
            original_question=original_question,
            text=combined_text,
            category=row["category"],
            tone=tone,
            system_message=prompt,
        ))

dataset_df = pd.DataFrame(new_rows)
dataset = Dataset.from_pandas(dataset_df)

In [None]:
def get_layer_token_hidden(
    prompt_texts,
    layer_idx=22,
    batch_size=64,
    device="cuda"
):
    all_vecs = []

    for i in range(0, len(prompt_texts), batch_size):
        batch = prompt_texts[i : i+batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            hidden_layer = outputs.hidden_states[layer_idx]

        seq_lengths = inputs["input_ids"].ne(tokenizer.pad_token_id).sum(dim=1)

        for idx, length in enumerate(seq_lengths):
            vec = hidden_layer[idx, length-1, :].cpu().numpy()
            all_vecs.append(vec)

    return np.array(all_vecs, dtype=np.float32)

def batch_generate(
    model,
    tokenizer,
    prompts: List[str],
    layer_idx: int,
    hook_fn: Optional[Callable] = None,
    max_new_tokens: int = 64,
    batch_size: int = 16,
    strip_prompt: bool = True,
) -> List[str]:
    """Greedy-decode a continuation for each prompt, optionally under a forward hook.

    While generating, `hook_fn` (if given) is the only forward hook on
    `model.model.layers[layer_idx]`. Hooks already on that layer are stashed
    and restored afterwards, even on error.

    Args:
        model, tokenizer: causal LM and its tokenizer.
        prompts: prompt strings.
        layer_idx: decoder layer the hook is attached to.
        hook_fn: forward hook, or None for unsteered generation.
        max_new_tokens: generation length cap.
        batch_size: prompts per `generate` call.
        strip_prompt: if True (default), return only the newly generated text.
            Previously the prompt was decoded too, so "answers" contained the
            system prompt (including the tone/topic name). That leaked the
            label into the output classifier and the LLM judge.

    Returns:
        One decoded string per prompt, in input order.
    """
    device = model.device
    target_layer = model.model.layers[layer_idx]
    outputs: List[str] = []

    saved_hooks = target_layer._forward_hooks.copy()
    target_layer._forward_hooks.clear()

    handle = None
    if hook_fn is not None:
        handle = target_layer.register_forward_hook(hook_fn)

    try:
        for i in range(0, len(prompts), batch_size):
            sub_prompts = prompts[i : i + batch_size]
            tok_in = tokenizer(
                sub_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)

            with torch.no_grad():
                gen_ids = model.generate(
                    **tok_in,
                    max_new_tokens = max_new_tokens,
                    do_sample      = False,
                    pad_token_id   = tokenizer.eos_token_id,
                )

            if strip_prompt:
                # `generate` returns the (padded) prompt ids followed by the new
                # tokens, so the continuation starts at the prompt width.
                gen_ids = gen_ids[:, tok_in["input_ids"].shape[1]:]

            outputs.extend(
                tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
            )
    finally:
        if handle is not None:
            handle.remove()
        target_layer._forward_hooks.clear()
        target_layer._forward_hooks.update(saved_hooks)

    return outputs

# Steering Methods

## CAA

In [None]:
# Stable (alphabetical) ordering of steering targets and their class indices.
unique_tones = sorted(set(dataset["tone"]))
num_classes = len(unique_tones)
tone2idx = dict(zip(unique_tones, range(num_classes)))

def build_prompt(text: str, tone: str) -> str:
    """Prefix `text` (as the user turn) with a generic 'respond in <tone> style' system turn."""
    template = "SYSTEM: Please respond in a {} style.\nUSER: {}"
    return template.format(tone, text)

def compute_caa_vectors(
    dataset,
    unique_tones,
    build_prompt_fn,
    get_layer_token_hidden_fn,
    max_pairs: Optional[int] = None
) -> np.ndarray:
    """Compute one Contrastive Activation Addition (CAA) vector per label.

    Pairs are formed per question. For every two distinct labels (tgt, other)
    seen with the same question, the prompt built for `tgt` is a positive
    example and the prompt built for `other` is its matched negative. The CAA
    vector for `tgt` is the mean of (positive - negative) activations.

    Args:
        dataset: iterable of rows with "original_question" and "tone" keys.
        unique_tones: label order of the returned rows.
        build_prompt_fn: (question, label) -> prompt string.
        get_layer_token_hidden_fn: list of prompts -> (n, dim) activation array.
        max_pairs: if set, randomly sub-sample at most this many pairs per label.

    Returns:
        Array of shape (len(unique_tones), dim). A label with no contrastive
        pairs gets an all-zero vector, and a warning is printed. Previously a
        `None` entry made `np.stack` crash.

    Raises:
        ValueError: if no label has any contrastive pair.
    """
    text2tones = defaultdict(set)
    for row in dataset:
        text2tones[row["original_question"]].add(row["tone"])

    pos_prompts = defaultdict(list)
    neg_prompts = defaultdict(list)
    for orig_text, tone_set in text2tones.items():
        # Iterate in sorted order: the order of a set of strings changes between
        # interpreter runs (hash randomisation). Without sorting, the seeded
        # `random.sample` below picked different pairs on every run.
        ordered = sorted(tone_set)
        for tgt in ordered:
            for other in ordered:
                if other == tgt:
                    continue
                pos_prompts[tgt].append(build_prompt_fn(orig_text, tgt))
                neg_prompts[tgt].append(build_prompt_fn(orig_text, other))

    vec_by_tone = {}
    for tone in unique_tones:
        pos, neg = pos_prompts[tone], neg_prompts[tone]
        total_pairs = len(pos)
        print(f"Computing CAA vector for '{tone}' ({total_pairs} pairs) …")

        if max_pairs is not None and total_pairs > max_pairs:
            indices = random.sample(range(total_pairs), max_pairs)
            pos = [pos[i] for i in indices]
            neg = [neg[i] for i in indices]

        if not pos:
            continue

        X_pos = get_layer_token_hidden_fn(pos)
        X_neg = get_layer_token_hidden_fn(neg)
        vec_by_tone[tone] = (X_pos - X_neg).mean(axis=0)

    if not vec_by_tone:
        raise ValueError("No contrastive pairs found for any label; cannot compute CAA vectors.")

    # A zero vector (same shape/dtype as the computed ones) means "no steering".
    template = next(iter(vec_by_tone.values()))
    rows = []
    for tone in unique_tones:
        if tone in vec_by_tone:
            rows.append(vec_by_tone[tone])
        else:
            print(f"Warning: no contrastive pairs for '{tone}'; using a zero CAA vector.")
            rows.append(np.zeros_like(template))
    return np.stack(rows)

# One CAA vector per label, each from at most 1000 (positive, negative) pairs.
caa_vectors = compute_caa_vectors(
    dataset,
    unique_tones,
    build_prompt,
    get_layer_token_hidden,
    max_pairs=1000,
)

## K-Steering

In [None]:
class MultiLabelSteeringModel(nn.Module):
    """Three-layer ReLU MLP mapping an activation vector to one logit per label."""

    def __init__(self, input_dim, hidden_dim, num_labels):
        super().__init__()
        layers = []
        for d_in, d_out in ((input_dim, hidden_dim), (hidden_dim, hidden_dim)):
            layers.extend([nn.Linear(d_in, d_out), nn.ReLU()])
        layers.append(nn.Linear(hidden_dim, num_labels))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits, one per label."""
        return self.net(x)

class ActivationSteering:
    """Multi-label MLP probe over activations, plus gradient-based steering.

    The probe (`MultiLabelSteeringModel`) is trained with BCE on one-/multi-hot
    targets. Its input gradients are then used either to push activations
    towards/away from labels (`steer_activations`) or to project out the
    label-sensitive direction (`remove_projection`).
    """

    def __init__(self, input_dim, num_labels, hidden_dim=128, lr=1e-3):
        self.device = DEVICE
        self.num_labels = num_labels

        self.classifier = MultiLabelSteeringModel(
            input_dim, hidden_dim, num_labels
        ).to(self.device)

        self.optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    @staticmethod
    def _as_index_list(labels):
        """Normalise label indices (None, an int, a sequence or an ndarray) to a list of ints.

        A bare int 0 used to be treated as "no labels" (it is falsy), and an
        ndarray raised an ambiguous-truth-value error.
        """
        if labels is None:
            return []
        return [int(i) for i in np.atleast_1d(np.asarray(labels)).ravel()]

    @staticmethod
    def _ensure_2d(activation):
        """Promote a 1-D activation to a batch of one; also report whether that happened."""
        single_input = (activation.ndim == 1)
        return (activation[None, :] if single_input else activation), single_input

    def fit(self, X, Y, epochs=10, batch_size=32):
        """Train the probe on activations `X` with float targets `Y`; prints mean loss per epoch."""
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        Y_t = torch.tensor(Y, dtype=torch.float32, device=self.device)

        dataset = torch.utils.data.TensorDataset(X_t, Y_t)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # predict_proba / steering switch the probe to eval mode; switch back.
        self.classifier.train()
        for ep in range(epochs):
            total_loss = 0.0
            for bx, by in loader:
                self.optimizer.zero_grad()
                logits = self.classifier(bx)
                loss = self.loss_fn(logits, by)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            # max(1, ...) avoids ZeroDivisionError on an empty training set.
            print(f"Epoch {ep+1}/{epochs}, Loss={total_loss/max(1, len(loader)):.4f}")

    @torch.no_grad()
    def predict_proba(self, X):
        """Return per-label sigmoid probabilities for activations `X` as a numpy array."""
        self.classifier.eval()
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        logits = self.classifier(X_t)
        probs = torch.sigmoid(logits)
        return probs.cpu().numpy()

    def _compute_steering_loss(self, logits, targets=None, avoids=None):
        """Loss whose descent raises target-label logits and lowers avoid-label logits.

        Returns the float 0.0 when there are no labels.
        NOTE(review): the mean runs over the batch as well as the labels, so each
        row's gradient scales as 1/n_rows. In the forward hooks n_rows is
        batch*seq_len, so steering strength depends on sequence length.
        Consider summing over rows.
        """
        targets = self._as_index_list(targets)
        avoids = self._as_index_list(avoids)
        loss = 0.0
        if targets:
            loss = loss - logits[:, targets].mean()
        if avoids:
            loss = loss + logits[:, avoids].mean()
        return loss

    def steer_activations(
        self,
        activation,
        target_labels=None,
        avoid_labels=None,
        alpha=0.1
    ):
        """Take one gradient step of size `alpha` on the activations under the probe loss.

        Args:
            activation: numpy array of shape (dim,) or (n, dim).
            target_labels / avoid_labels: a label index, a list or an array of
                indices (None means none).
            alpha: step size.

        Returns:
            numpy float32 array with the same leading shape as `activation`.
            The input is returned unchanged when no labels are given.
        """
        target_labels = self._as_index_list(target_labels)
        avoid_labels = self._as_index_list(avoid_labels)

        self.classifier.eval()
        activation, single_input = self._ensure_2d(activation)

        with torch.enable_grad():
            X = torch.from_numpy(activation).to(self.device, dtype=torch.float32)
            X.requires_grad_()

            if target_labels or avoid_labels:
                logits = self.classifier(X)
                loss = self._compute_steering_loss(logits, targets=target_labels, avoids=avoid_labels)
                loss.backward()
                with torch.no_grad():
                    X = X - alpha * X.grad

        out = X.detach().cpu().numpy()
        return out[0] if single_input else out

    def remove_projection(
        self,
        activation,
        target_labels=None,
        avoid_labels=None
    ):
        """Remove from each activation its component along its own probe-loss gradient.

        Args / Returns: as in `steer_activations` (no step size). The result is
        invariant to the gradient's scale.
        """
        target_labels = self._as_index_list(target_labels)
        avoid_labels = self._as_index_list(avoid_labels)

        self.classifier.eval()
        activation, single_input = self._ensure_2d(activation)

        with torch.enable_grad():
            X = torch.from_numpy(activation).to(self.device, dtype=torch.float32)
            X.requires_grad_()

            if target_labels or avoid_labels:
                logits = self.classifier(X)
                loss = self._compute_steering_loss(logits, targets=target_labels, avoids=avoid_labels)
                loss.backward()

                grad = X.grad
                # Per-row projection of X onto grad; 1e-9 guards zero gradients.
                dot = torch.sum(X * grad, dim=1, keepdim=True)
                norm_sq = torch.sum(grad * grad, dim=1, keepdim=True) + 1e-9
                X = X - (dot / norm_sq) * grad

        out = X.detach().cpu().numpy()
        return out[0] if single_input else out

In [None]:
# Every dataset prompt is split 50/50 (stratified by label): split A trains the
# per-layer steering probes and split B scores them.
NUM_LAYERS = model.config.num_hidden_layers
layer_clfs = dict()
layer_stats = list()

all_prompts = [example["text"] for example in dataset]
Y_all = np.fromiter((tone2idx[example["tone"]] for example in dataset), dtype=np.int64)

idx_A, idx_B = train_test_split(
    np.arange(len(all_prompts)),
    test_size=0.5,
    random_state=42,
    stratify=Y_all,
)
y_A = Y_all[idx_A]
y_B = Y_all[idx_B]

def one_hot(idxs: np.ndarray, C: int) -> np.ndarray:
    """One-hot encode integer class indices into a float32 matrix of shape (len(idxs), C)."""
    encoded = np.zeros((len(idxs), C), dtype=np.float32)
    for row, cls in enumerate(idxs):
        encoded[row, cls] = 1.0
    return encoded

# Layer sweep: at every depth, fit a fresh probe on split A and score it on
# held-out split B. The best-scoring layer becomes the steering layer.
for layer_idx in range(NUM_LAYERS):
    print(f"▶  Layer {layer_idx:2d}: collect activations, train on A …")

    X_all = get_layer_token_hidden(all_prompts, layer_idx=layer_idx)
    X_A = X_all[idx_A]
    X_B = X_all[idx_B]

    clf = ActivationSteering(input_dim=X_A.shape[1], num_labels=num_classes, hidden_dim=128, lr=1e-3)
    clf.fit(X_A, one_hot(y_A, num_classes), epochs=5, batch_size=32)

    with torch.no_grad():
        logits_B = clf.classifier(torch.tensor(X_B, dtype=torch.float32, device=clf.device))
        preds_B = logits_B.argmax(dim=1).cpu().numpy()
        acc_B = np.mean(preds_B == y_B)

    layer_clfs[layer_idx] = clf
    layer_stats.append((layer_idx, acc_B))

# Rank layers by held-out accuracy (best first) and report the top five.
layer_stats.sort(key=lambda stat: stat[1], reverse=True)
top5 = layer_stats[:5]
print("\n=== Top‑5 layers by accuracy on set B ===")
print(pd.DataFrame(top5, columns=["Layer", "Acc_on_B"]).to_string(index=False))

STEER_LAYER = top5[0][0]
steer_model = layer_clfs[STEER_LAYER]
print(f"\nSelected STEER_LAYER = {STEER_LAYER}  (Acc_on_B = {top5[0][1]*100:.2f}%)")

## DCT

In [None]:
# Compute device, plus the dtypes of the base model and of DCT fitting.
DTYPE_MODEL, DTYPE_DCT = torch.float16, torch.float32
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

DCT_URL = "https://raw.githubusercontent.com/luke-marks0/melbo-dct-post/main/src/dct.py"


def load_dct(path: str = "dct.py", url: str = DCT_URL):
    """Import the DCT helper module from `path`, downloading it from `url` first if missing.

    Args:
        path: local Python file to import; created from `url` if it does not exist.
        url: where to fetch the module source.

    Returns:
        The imported module, which is also registered as `sys.modules["dct"]`.

    Raises:
        ImportError: if `path` cannot be loaded as a Python module.

    SECURITY NOTE: this executes code fetched from the `main` branch of a
    third-party repository. Review the downloaded file, and consider pinning
    `url` to a specific commit.
    """
    p = Path(path)
    if not p.exists():
        print("↯ downloading dct.py …")
        # The context manager closes the HTTP response; the timeout prevents an
        # indefinite hang on a stalled connection.
        with urlopen(url, timeout=60) as resp:
            source = resp.read().decode("utf-8")
        p.write_text(source, encoding="utf-8")

    spec = importlib.util.spec_from_file_location("dct", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {path!r}")
    mod = importlib.util.module_from_spec(spec)
    sys.modules["dct"] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Don't leave a half-initialised module registered under "dct".
        sys.modules.pop("dct", None)
        raise
    return mod

def get_hidden(model, tok, texts, *, max_len=48, layer_idx=-1):
    """Hidden states of `texts` at `hidden_states[layer_idx]`, padded/truncated to exactly `max_len` tokens.

    Inputs are moved to the global DEVICE and run without gradients or KV cache.
    The result is presumably (batch, max_len, hidden) — confirm for other model types.
    """
    encoded = tok(
        texts,
        return_tensors="pt",
        max_length=max_len,
        truncation=True,
        padding="max_length",
    )
    encoded = encoded.to(DEVICE)
    with torch.no_grad():
        forward_out = model(**encoded, use_cache=False, output_hidden_states=True)
    all_states = forward_out.hidden_states
    return all_states[layer_idx]

def make_slice(base_model, start, end, *, dtype):
    """Return an independent copy of `base_model` that keeps only decoder layers [start, end).

    The copy is cast to `dtype`; `base_model` is left unchanged.

    Only the kept layers are deep-copied. The previous version copied (and
    then up-cast) every layer before discarding most of them, which for a
    multi-billion-parameter model needed a large amount of extra memory.
    """
    all_layers = base_model.model.layers
    try:
        # Temporarily expose only the wanted layers so deepcopy skips the rest.
        base_model.model.layers = all_layers[start:end]
        sliced = copy.deepcopy(base_model)
    finally:
        # Always restore the original layer stack on the base model.
        base_model.model.layers = all_layers
    return sliced.to(dtype=dtype)

# Import the DCT helper module (fetched to ./dct.py on first run).
dct = load_dct(path="dct.py")

In [None]:
# --- DCT steering-vector fitting ---
# Number of dataset prompts sampled to fit the DCT factors.
NUM_SAMPLES   = 8
# Source activations come from hidden_states[SOURCE_LAYER] (see get_hidden);
# the sliced model runs decoder layers [SOURCE_LAYER, TARGET_LAYER).
# NOTE(review): hidden_states[i] vs model.model.layers[i] may be offset by one — confirm.
SOURCE_LAYER  = 22
TARGET_LAYER  = 27
# Number of factors requested from dct.ExponentialDCT.
NUM_FACTORS   = 256
# Passed as `batch_size` to exp_dct.fit.
BWD_BATCH     = 2
# Every prompt is padded/truncated to this many tokens by get_hidden.
MAX_SEQ_LEN   = 48

tokenizer.pad_token = tokenizer.eos_token
model.eval()

# NOTE: rebinds the global name `prompts`.
prompts = random.sample([row["text"] for row in dataset], k=NUM_SAMPLES)

source_h = get_hidden(model, tokenizer, prompts,
                      max_len=MAX_SEQ_LEN, layer_idx=SOURCE_LAYER).float()

# float32 copy of the model restricted to the source->target layers.
slice_model     = make_slice(model, SOURCE_LAYER, TARGET_LAYER, dtype=DTYPE_DCT)
last_layer_idx  = len(slice_model.model.layers) - 1

# NOTE(review): end_layer is presumably an inclusive index into the sliced
# layer stack; confirm against dct.SlicedModel.
sliced = dct.SlicedModel(
    slice_model,
    start_layer = 0,
    end_layer   = last_layer_idx,
    layers_name = "model.layers",
)

# Unsteered target activations, plus the delta function restricted to the
# positions selected by slice(-3, None) (presumably the last three tokens).
target_h     = sliced(source_h).float()
delta_single = dct.DeltaActivations(
    sliced, target_position_indices=slice(-3, None)
)

# Pick the input scale with the library's calibrator; if it raises ValueError
# (it fails to bracket a root), fall back to a scale of 1.0.
calibrator = dct.SteeringCalibrator(target_ratio=0.5)
try:
    INPUT_SCALE = calibrator.calibrate(
        delta_single, source_h, target_h, factor_batch_size=64
    )
except ValueError:
    print("Calibrator failed to bracket a root. Using scale = 1.0")
    INPUT_SCALE = 1.0

exp_dct = dct.ExponentialDCT(num_factors=NUM_FACTORS)
U, V = exp_dct.fit(
    delta_single,
    source_h, target_h,
    batch_size        = BWD_BATCH,
    factor_batch_size = 128,
    d_proj            = 48,
    input_scale       = INPUT_SCALE,
    max_iters         = 6,
)

# Transpose V so that each row is one steering vector (the count printed is the
# number of rows).
dct_vectors = V.cpu().detach().numpy().T
print(f"Learnt {dct_vectors.shape[0]} steering vectors")

# Evaluation Methods

## Activation Classifier

In [None]:
# Independent evaluator probe at STEER_LAYER, trained on split B and scored on
# split A (the steering probe was trained on split A).
X_all_best = get_layer_token_hidden(all_prompts, layer_idx=STEER_LAYER)
X_A_best = X_all_best[idx_A]
X_B_best = X_all_best[idx_B]

eval_model = ActivationSteering(input_dim=X_B_best.shape[1], num_labels=num_classes, hidden_dim=128, lr=1e-3)
eval_model.fit(X_B_best, one_hot(y_B, num_classes), epochs=5, batch_size=32)

with torch.no_grad():
    ev_logits_A = eval_model.classifier(torch.tensor(X_A_best, dtype=torch.float32, device=eval_model.device))
    ev_acc_A = np.mean(ev_logits_A.argmax(dim=1).cpu().numpy() == y_A)
print(f"Evaluator accuracy on set A: {ev_acc_A*100:.2f}%")

## LLM Judge

In [None]:
def first_token_map(model_name: str) -> Dict[str, str]:
    """Map each label in TONE_LABELS to the text of its first token for `model_name`.

    The judge is queried with max_tokens=1, so a label can only be recognised
    by its first token. Labels that share a first token cannot be told apart;
    a warning is printed when that happens.

    If tiktoken does not know `model_name` (encoding_for_model raises
    KeyError), the "o200k_base" encoding is used instead.
    """
    try:
        enc = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # NOTE(review): o200k_base is assumed to match the judge model — confirm.
        enc = tiktoken.get_encoding("o200k_base")

    mapping = {
        lbl: enc.decode([enc.encode(lbl)[0]])
        for lbl in TONE_LABELS
    }

    labels_by_token = defaultdict(list)
    for lbl, tok in mapping.items():
        labels_by_token[tok].append(lbl)
    for tok, lbls in labels_by_token.items():
        if len(lbls) > 1:
            print(f"Warning: labels {lbls} share first token {tok!r}; the judge cannot distinguish them.")
    return mapping

class OpenAiJudge:
    """LLM judge that picks a label from the logprobs of a single generated token.

    Each label is identified by its first token (see `first_token_map`). The
    label whose first token has the highest probability among the top-k
    candidates wins.
    """

    def __init__(self, client: "AsyncOpenAI", model_name: str):
        self.client        = client
        self.model_name    = model_name
        # Label -> first-token text, frozen at construction time.
        self._first_token  = first_token_map(model_name)

    async def compare(self,
                      question: str,
                      base_answer: str,
                      steered_answer: str) -> str:
        """Return the best label for `steered_answer` relative to `base_answer`."""
        prompt = RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )
        return await self._best_label(prompt)

    async def compare_logits(self,
                             question: str,
                             base_answer: str,
                             steered_answer: str,
                             top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        """Like `compare`, but also return the per-label probabilities."""
        prompt = RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )
        return await self._label_probs(prompt, top_k)

    async def _best_label(self, prompt: str, top_k: int = 20) -> str:
        """Return only the winning label for `prompt`."""
        best, _ = await self._label_probs(prompt, top_k)
        return best

    async def _label_probs(self, prompt: str,
                           top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        """Request one token with top-k logprobs and score every label.

        Returns:
            (best_label, {label: probability}). A label's probability is the
            summed probability of the top-k tokens whose stripped, lower-cased
            text equals the label's stripped, lower-cased first token. If no
            label token appears in the top-k, every probability is 0 and the
            first label is returned.
            NOTE(review): consider treating that case as "no decision".

        Raises:
            RuntimeError: if the response carries no logprobs.
        """
        completion = await self.client.chat.completions.create(
            model        = self.model_name,
            messages     = [{"role": "user", "content": prompt}],
            max_tokens   = 1,
            temperature  = 0,
            logprobs     = True,
            top_logprobs = top_k,
            seed         = 0,
        )

        try:
            top = completion.choices[0].logprobs.content[0].top_logprobs
        except (IndexError, AttributeError, TypeError) as exc:
            # `logprobs` / `content` may be None rather than empty.
            raise RuntimeError("OpenAI response missing logprobs") from exc

        # Pool whitespace/case variants of the same token (e.g. " ca" and
        # "ca"). Previously a plain dict kept only one of them.
        tok_prob: Dict[str, float] = defaultdict(float)
        for el in top:
            tok_prob[el.token.strip().lower()] += math.exp(el.logprob)

        # Iterate the labels captured at construction, so a later redefinition
        # of the global TONE_LABELS cannot cause a KeyError here.
        probs = {
            lbl: tok_prob.get(first.strip().lower(), 0.0)
            for lbl, first in self._first_token.items()
        }
        best_lbl = max(probs, key=probs.get)
        return best_lbl, probs

## Output Classifier

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import joblib, os, hashlib, numpy as np
from tqdm import tqdm
from typing import List, Callable

def build_generation_text_classifier(
    dataset,
    unique_tones: List[str],
    *,
    base_model, tokenizer,
    build_prompt_fn: Callable[[str, str], str],
    batch_generate_fn: Callable[..., List[str]],
    model_name_for_hash: str,
    layer_idx: int = 0,
    max_new_tokens: int = 64,
    batch_size: int = 16,
    cache_path: str = "tone_gen_text_clf.joblib",
):
    """Train, or load from cache, a TF-IDF + logistic-regression classifier over model generations.

    Every dataset prompt (row["text"]) is completed by `batch_generate_fn`
    without a hook. The classifier then learns to predict row["tone"] from the
    generated text.

    Args:
        dataset: iterable of rows with "text" and "tone" keys.
        unique_tones: full label set used to fit the LabelEncoder.
        base_model, tokenizer: passed through to `batch_generate_fn`.
        build_prompt_fn: unused; kept for interface compatibility.
        batch_generate_fn: generation function with the `batch_generate` signature.
        model_name_for_hash: part of the cache key.
        layer_idx, max_new_tokens, batch_size: generation settings.
        cache_path: joblib file holding {"hash", "pipe", "lbl_enc"}.

    Returns:
        (predict_labels, predict_probs). `predict_labels` maps a list of texts
        to label strings. `predict_probs` maps texts to a probability array
        whose columns follow the LabelEncoder's class order.
    """
    prompts, labels = [], []
    for row in dataset:
        prompts.append(row["text"])
        labels.append(row["tone"])

    md5 = hashlib.md5()
    md5.update(model_name_for_hash.encode())
    # Generation settings change the training texts, so they are part of the
    # cache key. Otherwise a cache built with other settings is silently reused.
    md5.update(f"|layer={layer_idx}|max_new_tokens={max_new_tokens}|".encode())
    for p, t in zip(prompts, labels):
        md5.update(p.encode()); md5.update(t.encode())
    corpus_hash = md5.hexdigest()

    pipe, lbl_enc = None, None
    if os.path.exists(cache_path):
        # joblib uses pickle: only load cache files this notebook wrote itself.
        saved = joblib.load(cache_path)
        if saved.get("hash") == corpus_hash:
            pipe, lbl_enc = saved["pipe"], saved["lbl_enc"]
            print("Loaded cached generation‑based text‑classifier.")

    if pipe is None:
        print("Generating model answers for classifier training…")
        gen_answers = []
        for i in tqdm(range(0, len(prompts), batch_size), desc="Generating", unit="batch"):
            chunk_prompts = prompts[i : i + batch_size]
            outs = batch_generate_fn(
                base_model, tokenizer, chunk_prompts,
                layer_idx      = layer_idx,
                hook_fn        = None,
                max_new_tokens = max_new_tokens,
                batch_size     = batch_size,
            )
            gen_answers.extend(outs)

        lbl_enc = LabelEncoder().fit(unique_tones)
        y = lbl_enc.transform(labels)

        pipe = make_pipeline(
            TfidfVectorizer(
                lowercase=True, ngram_range=(1, 2),
                max_features=50_000, sublinear_tf=True
            ),
            # `multi_class` is deprecated in recent scikit-learn releases (and
            # removed later). With the default lbfgs solver and more than two
            # classes, the model is already multinomial.
            LogisticRegression(max_iter=1_000, n_jobs=-1)
        )
        pipe.fit(gen_answers, y)

        acc_train = accuracy_score(y, pipe.predict(gen_answers))
        print(f"Output‑classifier training accuracy: {acc_train*100:.2f}%")

        joblib.dump({"hash": corpus_hash, "pipe": pipe, "lbl_enc": lbl_enc}, cache_path)

    def predict_labels(text_list: List[str]) -> List[str]:
        """Predicted label strings for each text."""
        y_pred = pipe.predict(text_list)
        return lbl_enc.inverse_transform(y_pred).tolist()

    def predict_probs(text_list: List[str]) -> np.ndarray:
        """Class-probability matrix, one row per text."""
        return pipe.predict_proba(text_list)

    return predict_labels, predict_probs

In [None]:
# Output classifier trained on unsteered generations (32 new tokens each),
# cached on disk under tone_gen_text_clf.joblib.
gen_clf_label_fn, gen_clf_prob_fn = build_generation_text_classifier(
    dataset,
    unique_tones,
    base_model=model,
    tokenizer=tokenizer,
    build_prompt_fn=build_prompt,
    batch_generate_fn=batch_generate,
    model_name_for_hash=model_name,
    layer_idx=22,
    max_new_tokens=32,
    batch_size=256,
    cache_path="tone_gen_text_clf.joblib",
)

# Steering Vector Evaluation

In [None]:
@contextmanager
def temp_forward_hook(layer, hook_fn):
    """Run the `with` body with `hook_fn` as the only forward hook on `layer`.

    Hooks already on the layer are stashed on entry and restored on exit, even
    if the body raises. `hook_fn=None` just disables the existing hooks.
    """
    stashed = layer._forward_hooks.copy()
    layer._forward_hooks.clear()
    registration = None
    try:
        registration = None if hook_fn is None else layer.register_forward_hook(hook_fn)
        yield
    finally:
        if registration is not None:
            registration.remove()
        layer._forward_hooks.clear()
        layer._forward_hooks.update(stashed)

def my_hook_wrapper(fwd_hook):
    """Wrap an optional forward hook: delegate to `fwd_hook`, or pass `out` through unchanged if it is None."""
    def actual_hook(module, inp, out):
        return out if fwd_hook is None else fwd_hook(module, inp, out)
    return actual_hook

def get_remove_proj_hook(steer_model, target_labels=None, avoid_labels=None):
    """Build a forward hook that applies `steer_model.remove_projection` to every token's hidden state.

    Hidden states are flattened to (batch*seq, dim), processed on the CPU in
    float32, then written back with the layer's own dtype and device.
    Previously the write-back was always float16. The hook accepts either a
    tuple `(hidden, ...)` or a bare hidden-state tensor as the layer output
    (transformers versions differ here) and returns the same form.
    """
    if target_labels is None: target_labels = []
    if avoid_labels is None: avoid_labels = []

    def fwd_hook(module, inp, out):
        is_tuple = isinstance(out, tuple)
        hidden_states = out[0] if is_tuple else out
        hidden_np = hidden_states.detach().cpu().numpy().astype(np.float32)
        B, S, D = hidden_np.shape

        new_2d = steer_model.remove_projection(
            hidden_np.reshape(-1, D), target_labels=target_labels, avoid_labels=avoid_labels
        )
        new_hidden = torch.from_numpy(new_2d.reshape(B, S, D)).to(
            hidden_states.device, dtype=hidden_states.dtype
        )
        return (new_hidden,) + out[1:] if is_tuple else new_hidden
    return fwd_hook

def get_gradient_hook(steer_model, target_labels=None, avoid_labels=None, alpha=1.0):
    """Build a forward hook that applies `steer_model.steer_activations` (step `alpha`) to every token.

    Hidden states are flattened to (batch*seq, dim), steered on the CPU in
    float32, then written back with the layer's own dtype and device.
    Previously the write-back was always float16. The hook accepts either a
    tuple `(hidden, ...)` or a bare hidden-state tensor as the layer output and
    returns the same form.
    """
    if target_labels is None: target_labels = []
    if avoid_labels is None: avoid_labels = []

    def fwd_hook(module, inp, out):
        is_tuple = isinstance(out, tuple)
        hidden_states = out[0] if is_tuple else out
        hidden_np = hidden_states.detach().cpu().numpy().astype(np.float32)
        B, S, D = hidden_np.shape

        new_2d = steer_model.steer_activations(hidden_np.reshape(-1, D),
                                               target_labels=target_labels,
                                               avoid_labels=avoid_labels,
                                               alpha=alpha)
        new_hidden = torch.from_numpy(new_2d.reshape(B, S, D)).to(
            hidden_states.device, dtype=hidden_states.dtype
        )
        return (new_hidden,) + out[1:] if is_tuple else new_hidden
    return fwd_hook

def get_caa_hook(caa_vector, alpha=1.0):
    """Build a forward hook that adds `alpha * caa_vector` to every token's hidden state.

    `caa_vector` may be a numpy array or a torch tensor of length hidden_dim.
    The result keeps the layer's own dtype and device; previously the
    write-back was always float16. The hook accepts either a tuple
    `(hidden, ...)` or a bare hidden-state tensor as the layer output and
    returns the same form.
    """
    if isinstance(caa_vector, torch.Tensor):
        caa_vector = caa_vector.detach().cpu().numpy()

    def fwd_hook(module, inp, out):
        is_tuple = isinstance(out, tuple)
        hidden_states = out[0] if is_tuple else out
        hidden_np = hidden_states.detach().cpu().numpy().astype(np.float32)
        # Broadcast the (dim,) vector over batch and sequence positions.
        hidden_np += alpha * caa_vector[None, None, :]
        new_hidden = torch.from_numpy(hidden_np).to(hidden_states.device, dtype=hidden_states.dtype)
        return (new_hidden,) + out[1:] if is_tuple else new_hidden
    return fwd_hook

def get_dct_hook(dct_vec, alpha=1.0):
    """Build a forward hook that adds `alpha * dct_vec` to every token's hidden state.

    `dct_vec` may be a numpy array or a torch tensor of length hidden_dim. The
    result keeps the layer's dtype and device. The hook accepts either a tuple
    `(hidden, ...)` or a bare hidden-state tensor as the layer output and
    returns the same form.
    """
    if isinstance(dct_vec, torch.Tensor):
        dct_vec = dct_vec.detach().cpu().numpy()

    def fwd_hook(module, inp, out):
        is_tuple = isinstance(out, tuple)
        h = out[0] if is_tuple else out
        h_np = h.detach().cpu().numpy().astype(np.float32)
        h_np += alpha * dct_vec[None, None, :]
        h_new = torch.from_numpy(h_np).to(h.device, dtype=h.dtype)
        return (h_new,) + out[1:] if is_tuple else h_new

    return fwd_hook

In [None]:
def _batch_generate_wrapper(model, tokenizer, prompts, layer_idx, hook, bs):
    """Thin adapter: run `batch_generate` with `hook` on layer `layer_idx` and batch size `bs`."""
    generation_kwargs = dict(layer_idx=layer_idx, hook_fn=hook, batch_size=bs)
    return batch_generate(model, tokenizer, prompts, **generation_kwargs)

async def batch_compare(
    triples: List[Tuple[str, str, str]],
    judge   : "OpenAiJudge",
    max_concurrency: int = 10,
) -> List[str]:
    """Judge many (question, base_answer, steered_answer) triples concurrently.

    At most `max_concurrency` judge calls are in flight at once. Labels are
    returned in the same order as `triples`. If any call fails, the remaining
    tasks are cancelled and the exception is re-raised. Previously they kept
    running and issuing API requests in the background.
    """
    sem   = asyncio.Semaphore(max_concurrency)
    out   = [None] * len(triples)

    async def worker(idx: int, q: str, b: str, s: str):
        async with sem:
            out[idx] = await judge.compare(q, b, s)

    tasks = [asyncio.create_task(worker(i, *t)) for i, t in enumerate(triples)]
    try:
        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                      desc="LLM‑judge", leave=False):
            await f
    except BaseException:
        for t in tasks:
            t.cancel()
        # Let the cancellations settle so no exception goes un-retrieved.
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
    return out

async def _llm_batch_compare(triples, judge, parallel):
    """Await `batch_compare` over `triples`, with at most `parallel` concurrent judge calls."""
    labels = await batch_compare(triples, judge, max_concurrency=parallel)
    return labels

def _vector_majority(preds, tone2idx, unique_tones):
    idx = int(np.bincount([tone2idx[p] for p in preds]).argmax())
    return unique_tones[idx]

def _prepare_bases(
    eval_method      : str,
    prompts          : List[str],
    *,
    base_model,
    tokenizer,
    batch_size,
    layer_idx,
    act_clf          = None,
    get_layer_token_hidden_fn = None,
):
    """Compute the unsteered baseline for an evaluation run.

    Returns:
        (base_ans, base_act). `base_ans` holds the unhooked generations for
        `prompts`. `base_act` holds their activations, computed only when
        `eval_method == "activation_classifier"`; otherwise it is None.
        `act_clf` is unused.
    """
    base_ans = _batch_generate_wrapper(
        base_model, tokenizer, prompts, layer_idx=layer_idx, hook=None, bs=batch_size
    )
    if eval_method != "activation_classifier":
        return base_ans, None
    return base_ans, get_layer_token_hidden_fn(prompts)

async def _map_dct_vectors(
    *,
    include_dct      : bool,
    dct_vectors      : Optional[np.ndarray],
    eval_method      : str,
    base_model,
    tokenizer,
    prompts,
    base_ans,
    act_clf,
    gen_clf_fn,
    judge,
    judge_parallel,
    alpha_dct,
    layer_idx,
    batch_size,
    base_act,
    unique_tones,
    tone2idx,
    get_layer_token_hidden_fn,
):
    if not include_dct:
        return defaultdict(list)

    tone2dct: DefaultDict[str, List[int]] = defaultdict(list)

    if eval_method == "activation_classifier":
        device = next(act_clf.parameters()).device
        for i, vec in enumerate(dct_vectors):
            acts = base_act + vec[None, :]
            acts_t = torch.tensor(acts, dtype=torch.float32, device=device)
            with torch.no_grad():
                preds = act_clf(acts_t).argmax(dim=1).cpu().numpy()
            maj = unique_tones[int(np.bincount(preds).argmax())]
            tone2dct[maj].append(i)

    else:
        async def classify_vec(i_vec, vec):
            hook = get_dct_hook(vec, alpha=alpha_dct)
            outs = _batch_generate_wrapper(
                base_model, tokenizer, prompts,
                layer_idx, hook, batch_size
            )
            if eval_method == "generation_classifier":
                lbls = gen_clf_fn(outs)
            else:
                triples = [(q, b, s) for q, b, s in zip(prompts, base_ans, outs)]
                lbls = await _llm_batch_compare(triples, judge, judge_parallel)
            maj = _vector_majority(lbls, tone2idx, unique_tones)
            tone2dct[maj].append(i_vec)

        await asyncio.gather(*[
            classify_vec(i, vec) for i, vec in enumerate(dct_vectors)
        ])

    return tone2dct

async def _evaluate_combo(
    tgt_idx, tgt_names, tgt_set,
    *,
    eval_method,
    base_ans,
    base_act,
    model_device,
    prompts,
    N,
    unique_tones,
    tone2idx,
    steer_model,
    caa_vectors,
    alpha_grad,
    alpha_caa,
    include_dct,
    dct_vectors,
    tone2dct,
    alpha_dct,
    base_model,
    tokenizer,
    layer_idx,
    batch_size,
    act_clf,
    gen_clf_label_fn,
    gen_clf_prob_fn,
    judge,
    judge_parallel,
):
    def _gen(prompts, hook, desc=None):
        if eval_method != "generation_classifier":
            return batch_generate(
                base_model, tokenizer, prompts,
                layer_idx, hook, batch_size
            )
        outs = []
        for i in tqdm(range(0, len(prompts), batch_size), desc=desc, leave=False):
            outs.extend(
                batch_generate(
                    base_model, tokenizer,
                    prompts[i : i + batch_size],
                    layer_idx, hook, batch_size
                )
            )
        return outs

    caa_vec = caa_vectors[tgt_idx].mean(axis=0)

    if eval_method == "activation_classifier":
        grad_act   = steer_model.steer_activations(base_act, tgt_idx, alpha=alpha_grad)
        grad_logits = act_clf(torch.tensor(grad_act, dtype=torch.float32, device=model_device))
        grad_probs  = torch.sigmoid(grad_logits).cpu().detach().numpy()
        grad_score  = grad_probs[:, tgt_idx].mean()

        caa_act   = base_act + caa_vec[None, :]
        caa_logits = act_clf(torch.tensor(caa_act, dtype=torch.float32, device=model_device))
        caa_probs  = torch.sigmoid(caa_logits).cpu().detach().numpy()
        caa_score  = caa_probs[:, tgt_idx].mean()

        dct_score = None
        if include_dct:
            vecs = [dct_vectors[i] for t in tgt_names for i in tone2dct.get(t, [])]
            if vecs:
                dct_vec  = np.stack(vecs).mean(axis=0)
                dct_act  = base_act + dct_vec[None, :]
                dct_logits = act_clf(torch.tensor(dct_act, dtype=torch.float32, device=model_device))
                dct_probs  = torch.sigmoid(dct_logits).cpu().detach().numpy()
                dct_score  = dct_probs[:, tgt_idx].mean()

    elif eval_method == "generation_classifier":
        grad_out   = _gen(prompts, get_gradient_hook(steer_model, tgt_idx, alpha=alpha_grad),
                          desc=f"Grad gen {', '.join(tgt_names)}")
        grad_probs = gen_clf_prob_fn(grad_out)
        grad_score = grad_probs[:, tgt_idx].mean()

        caa_out   = _gen(prompts, get_caa_hook(caa_vec, alpha=alpha_caa),
                         desc=f"CAA gen {', '.join(tgt_names)}")
        caa_probs = gen_clf_prob_fn(caa_out)
        caa_score = caa_probs[:, tgt_idx].mean()

        dct_score = None
        if include_dct:
            vecs = [dct_vectors[i] for t in tgt_names for i in tone2dct.get(t, [])]
            if vecs:
                dct_vec  = np.stack(vecs).mean(axis=0)
                dct_out  = _gen(prompts, get_dct_hook(dct_vec, alpha_dct),
                                desc=f"DCT gen {', '.join(tgt_names)}")
                dct_probs = gen_clf_prob_fn(dct_out)
                dct_score = dct_probs[:, tgt_idx].mean()

    else:
        triples, where = [], []
        for q, b, g, c in zip(prompts, base_ans, grad_out, caa_out):
            triples.append((q, b, g)); where.append("grad")
            triples.append((q, b, c)); where.append("caa")
        if include_dct:
            vecs = [dct_vectors[i] for t in tgt_names
                                    for i in tone2dct.get(t, [])]
            if vecs:
                dct_vec = np.stack(vecs).mean(axis=0)
                dct_out = _gen(
                    prompts,
                    get_dct_hook(dct_vec, alpha_dct),
                    desc=f"DCT gen {', '.join(tgt_names)}"
                )
                for q, b, d in zip(prompts, base_ans, dct_out):
                    triples.append((q, b, d)); where.append("dct")
        preds = await _llm_batch_compare(triples, judge, judge_parallel)
        for w, lbl in zip(where, preds):
            if lbl in tgt_set:
                counts[w] += 1

    row = {
        "Targets"       : ", ".join(tgt_names),
        "K-Steering" : grad_score,
        "CAA"  : caa_score,
    }
    if include_dct and dct_score is not None:
        row["DCT"] = dct_score
    return row

In [None]:
async def eval_steering_combinations(
    *,
    eval_method      : str,
    base_model,
    tokenizer,
    prompts          : List[str],
    unique_tones     : List[str],
    caa_vectors,
    steer_model,
    layer_idx        : int = 22,
    alpha_grad       : float = 1000.0,
    alpha_caa        : float = 1.5,
    alpha_dct        : float = 7.0,
    include_dct      : bool  = False,
    dct_vectors      : Optional[np.ndarray] = None,
    num_target_tones : int   = 2,
    max_samples      : int   = 1000,
    batch_size       : int   = 512,
    judge_parallel   : int   = 25,
    judge            = None,
    act_clf          = None,
    gen_clf_label_fn = None,
    gen_clf_prob_fn  = None,
    get_layer_token_hidden_fn = None,
) -> pd.DataFrame:
    """Evaluate steering methods on every ``num_target_tones``-subset of tones.

    Pipeline:
      1. truncate ``prompts`` to ``max_samples`` (a falsy value keeps all);
      2. build unsteered baselines via ``_prepare_bases``;
      3. optionally map DCT vectors to tones via ``_map_dct_vectors``;
      4. score each tone combination with ``_evaluate_combo``.

    Returns:
        DataFrame with one row per combination (columns as produced by
        ``_evaluate_combo``: "Targets", "K-Steering", "CAA" and possibly "DCT").
    """
    if max_samples:
        prompts = prompts[:max_samples]
    tone_to_index = {tone: pos for pos, tone in enumerate(unique_tones)}
    n_prompts = float(len(prompts))

    base_ans, base_act = _prepare_bases(
        eval_method, prompts,
        base_model=base_model,
        tokenizer=tokenizer,
        batch_size=batch_size,
        layer_idx=layer_idx,
        act_clf=act_clf,
        get_layer_token_hidden_fn=get_layer_token_hidden_fn,
    )

    # Only the activation classifier path needs tensors on the classifier's device.
    model_device = (
        next(act_clf.parameters()).device
        if eval_method == "activation_classifier"
        else None
    )

    tone2dct = await _map_dct_vectors(
        include_dct=include_dct,
        dct_vectors=dct_vectors,
        eval_method=eval_method,
        base_model=base_model,
        tokenizer=tokenizer,
        prompts=prompts,
        base_ans=base_ans,
        act_clf=act_clf,
        gen_clf_fn=gen_clf_label_fn,
        judge=judge,
        judge_parallel=judge_parallel,
        alpha_dct=alpha_dct,
        layer_idx=layer_idx,
        batch_size=batch_size,
        base_act=base_act,
        unique_tones=unique_tones,
        tone2idx=tone_to_index,
        get_layer_token_hidden_fn=get_layer_token_hidden_fn,
    )

    # Keyword arguments identical for every combination.
    shared_kwargs = dict(
        eval_method=eval_method,
        base_ans=base_ans,
        base_act=base_act,
        model_device=model_device,
        prompts=prompts,
        N=n_prompts,
        unique_tones=unique_tones,
        tone2idx=tone_to_index,
        steer_model=steer_model,
        caa_vectors=caa_vectors,
        alpha_grad=alpha_grad,
        alpha_caa=alpha_caa,
        include_dct=include_dct,
        dct_vectors=dct_vectors,
        tone2dct=tone2dct,
        alpha_dct=alpha_dct,
        base_model=base_model,
        tokenizer=tokenizer,
        layer_idx=layer_idx,
        batch_size=batch_size,
        act_clf=act_clf,
        gen_clf_label_fn=gen_clf_label_fn,
        gen_clf_prob_fn=gen_clf_prob_fn,
        judge=judge,
        judge_parallel=judge_parallel,
    )

    all_combos = list(itertools.combinations(range(len(unique_tones)), num_target_tones))
    records = []
    for combo in tqdm(all_combos, desc=f"{num_target_tones}-tone combos"):
        names = [unique_tones[i] for i in combo]
        records.append(
            await _evaluate_combo(list(combo), names, set(names), **shared_kwargs)
        )

    return pd.DataFrame(records)

In [None]:
df_act = await eval_steering_combinations(
    eval_method      = "generation_classifier",
    act_clf          = eval_model.classifier,
    get_layer_token_hidden_fn = get_layer_token_hidden,
    base_model       = model,
    tokenizer        = tokenizer,
    prompts          = eval_prompts,
    unique_tones     = unique_tones,
    caa_vectors      = caa_vectors,
    steer_model      = steer_model,
    include_dct      = True,
    dct_vectors      = dct_vectors,
    num_target_tones = 2,
    gen_clf_label_fn = gen_clf_label_fn,
    gen_clf_prob_fn  = gen_clf_prob_fn,
)

df_act

## Visualization

In [None]:
def plot_evaluation_bar(
    df: pd.DataFrame,
    combo_col: str | None = None,
    title: str            = "Steering Evaluation",
    x_title: str          = "Label Combination",
    y_title: str          = "Average Probability",
    output_path: str | Path | None = None,
    width: int            = 900,
    height: int           = 500,
    show: bool            = True,
):
    if combo_col is None:
        combo_col = df.select_dtypes(include=["object", "category"]).columns[0]

    method_cols = [c for c in df.columns if c != combo_col]

    palette = ['#FF563F', '#F5C0B8',  '#55C89F', '#363432', '#F9DA81']
    if len(method_cols) > len(palette):
        repeats  = -(-len(method_cols) // len(palette))
        palette *= repeats
    palette = palette[:len(method_cols)]

    fig = px.bar(
        df,
        x                = combo_col,
        y                = method_cols,
        color_discrete_sequence = palette,
        template         = "plotly_white",
        width            = width,
        height           = height,
    )

    fig.update_layout(
        title={
            "text"  : title,
            "font"  : {"size": 16, "color": "#0c0c0c", "family": "Space Grotesk"},
            "x"     : 0.5, "y": 0.96, "xanchor": "center", "yanchor": "top",
        },
        font={
            "family": "Space Grotesk, Work Sans, sans-serif",
            "color" : "#0c0c0c",
        },
        barmode   = "group",
        margin    = {"l": 40, "r": 40, "t": 100, "b": 80},
        legend    = {
            "title": {"text": ""},
            "orientation": "h",
            "y": 1.0, "x": 0.5,
            "xanchor": "center", "yanchor": "bottom",
            "font": {"size": 10, "color": "#928e8b"},
        },
        xaxis     = {
            "title": {"text": x_title},
            "gridcolor": "#f5f5f5",
            "linecolor": "#e5dfdf",
            "linewidth": 1.5,
            "tickfont": {"color": "#928E8B"},
            "ticksuffix": "   ",
        },
        yaxis     = {
            "title": {"text": y_title},
            "gridcolor": "#f5f5f5",
            "linecolor": "#e5dfdf",
            "linewidth": 1.5,
            "tickfont": {"color": "#928E8B"},
            "ticksuffix": "   ",
        },
    )

    fig.update_traces(
        hoverlabel = {
            "bgcolor": "#0c0c0c",
            "font_color": "#ffffff",
            "font_family": "Work Sans",
        },
        hovertemplate = "&nbsp;%{x}<br>&nbsp;%{y:.3f}<extra></extra>",
    )

    if output_path is not None:
        output_path = Path(output_path)
        try:
            fig.write_image(str(output_path))
            print(f"Figure written to: {output_path.resolve()}")
        except ValueError as e:
            if "kaleido" in str(e).lower():
                raise RuntimeError(
                    "Static image export requires Kaleido. "
                    "Install it with:\n    pip install -U kaleido"
                ) from e
            raise

    return fig

In [None]:
# df_act was produced with eval_method="generation_classifier" (and is saved
# as df_gen.pdf), so the title names the generation classifier.
plot_evaluation_bar(
    df_act,
    title="Two Label Steering Performance (Generation Classifier, Tones Dataset)",
    output_path="df_gen.pdf",
)

# Manual Inspection

In [None]:
from pprint import pprint

def sample_steered_responses(
    prompts,
    target_tones,
    *,
    alpha_grad = 12.0,
    alpha_caa  =  12.0,
    layer_idx  = 20,
    max_new_tokens = 128,
    batch_size     = 500,
):
    """Print unsteered, K-steered and CAA-steered completions side by side.

    Uses the notebook globals ``unique_tones``, ``steer_model``,
    ``caa_vectors``, ``model`` and ``tokenizer``.

    Args:
        prompts: prompts to complete.
        target_tones: tone names (must appear in ``unique_tones``) to steer toward.
        alpha_grad: strength passed to the K-steering gradient hook.
        alpha_caa: strength passed to the CAA hook (the CAA vector is the mean
            of the target tones' rows of ``caa_vectors``).
        layer_idx, max_new_tokens, batch_size: forwarded to ``batch_generate``.

    Returns:
        list of dicts with keys "prompt", "unsteered", "k_steering", "caa";
        each completion has the prompt prefix removed when it starts with it.
    """
    index_of   = {tone: pos for pos, tone in enumerate(unique_tones)}
    target_ids = [index_of[tone] for tone in target_tones]

    grad_hook = get_gradient_hook(
        steer_model,
        target_labels = target_ids,
        avoid_labels  = [],
        alpha         = alpha_grad,
    )
    caa_hook = get_caa_hook(caa_vectors[target_ids].mean(axis=0), alpha=alpha_caa)

    def _run(hook):
        return batch_generate(
            model, tokenizer, prompts,
            layer_idx      = layer_idx,
            hook_fn        = hook,
            max_new_tokens = max_new_tokens,
            batch_size     = batch_size,
        )

    def _completion_only(text, prompt):
        # Drop the echoed prompt (if present) and leading whitespace after it.
        return text[len(prompt):].lstrip() if text.startswith(prompt) else text

    # Generated in order: unsteered, K-steering, CAA.
    generations = [_run(None), _run(grad_hook), _run(caa_hook)]

    rows = [
        {
            "prompt"      : prompt,
            "unsteered"   : _completion_only(base, prompt),
            "k_steering"  : _completion_only(k, prompt),
            "caa"         : _completion_only(c, prompt),
        }
        for prompt, base, k, c in zip(prompts, *generations)
    ]

    for r in rows:
        print("\n" + "="*80)
        print(f"PROMPT:\n{r['prompt']}\n")
        print("- Unsteered -------------------------------------------------\n"
              + r["unsteered"] + "\n")
        print(f"- K‑steering (α_grad = {alpha_grad}) ------------------------\n"
              + r["k_steering"] + "\n")
        print(f"- CAA (α_caa = {alpha_caa}) --------------------------------\n"
              + r["caa"] + "\n")

    return rows

In [None]:
# Small hand-picked prompt set for eyeballing steered outputs. The function
# prints three completions per prompt, so keep this list short.
test_prompts = [
    "What are the ethical considerations in education?",
    "How can someone maintain mental health during challenging life transitions?",
    "What are the benefits of keeping a food diary?",
    "How can I read food labels more effectively?"
]

_ = sample_steered_responses(
    test_prompts,
    target_tones   = ["expert"],
    layer_idx      = 22,
    alpha_grad     = 1000.0,
    alpha_caa      = 1.8,
    max_new_tokens = 50,
)