# Setup

In [None]:
!pip install datasets transformers accelerate transformer_lens openai tiktoken

import math
import asyncio
import tiktoken
from typing import List, Dict, Tuple, Callable, Optional
import pandas as pd
from itertools import combinations
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.notebook import tqdm
from transformer_lens.hook_points import HookPoint
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria
from sklearn.model_selection import train_test_split
from collections import defaultdict
from openai import AsyncOpenAI
from contextlib import contextmanager

# Prefer the GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Fixed seeds so probe initialisation / data shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
model_name = "unsloth/Llama-3.2-3B-Instruct"
print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Every batched tokenizer call in this notebook uses padding=True, which needs a pad
# token. Some Llama tokenizers ship without one, so fall back to EOS in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    output_hidden_states=True
)

# Tone-labelled training data: rows carry at least "text", "tone" and "system_message".
dataset = load_dataset("Narmeen07/k_ary_steering_dataset_v2", split="train")

In [None]:
def get_layer_token_hidden(
    prompt_texts,
    layer_idx=-5,
    batch_size=16,
    device="cuda"
):
    """Return one hidden-state vector per prompt: the state at its last real token.

    Uses the notebook-level ``model`` and ``tokenizer``.

    Args:
        prompt_texts: list of prompt strings.
        layer_idx: index into ``outputs.hidden_states``. NOTE(review): in Hugging Face
            models entry 0 of that tuple is the embedding output, so entry ``k`` is the
            output of ``model.model.layers[k - 1]``. The steering hooks elsewhere attach
            to ``model.model.layers[22]``, i.e. ``hidden_states[23]``. Confirm that
            ``-5`` refers to the intended layer for this model's depth.
        batch_size: number of prompts per forward pass.
        device: device the tokenized inputs are moved to.

    Returns:
        float32 array of shape (len(prompt_texts), hidden_dim). If there are no
        prompts, an empty 1-D array is returned.
    """
    all_vecs = []

    for start in range(0, len(prompt_texts), batch_size):
        batch = prompt_texts[start : start + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hidden_layer = outputs.hidden_states[layer_idx]

        # Locate the last attended position from the attention mask. Unlike counting
        # non-pad ids, this is correct for both left and right padding and does not
        # depend on pad_token_id being set.
        mask = inputs.get("attention_mask")
        if mask is None:
            mask = inputs["input_ids"].ne(tokenizer.pad_token_id)
        mask = mask.long()
        positions = torch.arange(mask.shape[1], device=mask.device)
        last_pos = (mask * positions).argmax(dim=1)

        rows = torch.arange(hidden_layer.shape[0], device=hidden_layer.device)
        vecs = hidden_layer[rows, last_pos.to(hidden_layer.device)]
        all_vecs.append(vecs.float().cpu().numpy())

    if not all_vecs:
        return np.array([], dtype=np.float32)
    return np.concatenate(all_vecs, axis=0).astype(np.float32, copy=False)

def batch_generate(
    model,
    tokenizer,
    prompts: List[str],
    layer_idx: int,
    hook_fn: Optional[Callable] = None,
    max_new_tokens: int = 64,
    batch_size: int = 16,
    return_prompt: bool = False,
) -> List[str]:
    """Greedy-decode a completion for every prompt, optionally with a steering hook.

    While generating, ``model.model.layers[layer_idx]`` (a Llama-style decoder layout
    is assumed) carries ONLY ``hook_fn`` as a forward hook. Any hooks already on that
    layer are detached for the duration and restored afterwards, even on error.

    Args:
        model: causal LM exposing ``.device``, ``.model.layers`` and ``.generate``.
        tokenizer: matching tokenizer. Its ``padding_side`` is set to ``"left"`` during
            generation and then restored.
        prompts: input prompts.
        layer_idx: index of the decoder layer that receives ``hook_fn``.
        hook_fn: forward hook ``(module, inp, out) -> new_out``, or None for no hook.
        max_new_tokens: maximum number of tokens to generate per prompt.
        batch_size: prompts per ``generate`` call.
        return_prompt: if True, return prompt + completion (the previous behaviour).
            By default only the newly generated text is returned.

    Returns:
        One decoded string per prompt, in input order.
    """
    device        = model.device
    target_layer  = model.model.layers[layer_idx]
    outputs: List[str] = []

    saved_hooks = target_layer._forward_hooks.copy()
    target_layer._forward_hooks.clear()
    saved_padding_side = getattr(tokenizer, "padding_side", None)

    handle = None
    try:
        if hook_fn is not None:
            handle = target_layer.register_forward_hook(hook_fn)
        # Decoder-only models must be left-padded for batched generation. With right
        # padding, new tokens would be appended after the pad tokens.
        if saved_padding_side is not None:
            tokenizer.padding_side = "left"

        for i in range(0, len(prompts), batch_size):
            sub_prompts = prompts[i : i + batch_size]
            tok_in = tokenizer(
                sub_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)

            with torch.no_grad():
                gen_ids = model.generate(
                    **tok_in,
                    max_new_tokens = max_new_tokens,
                    do_sample      = False,
                    pad_token_id   = tokenizer.eos_token_id,
                )

            # generate() returns prompt + continuation. Keep only the continuation so
            # that downstream classifiers and judges see the model's answer, not the
            # prompt (which may name the requested tone).
            if not return_prompt:
                gen_ids = gen_ids[:, tok_in["input_ids"].shape[1]:]

            outputs.extend(
                tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
            )
    finally:
        if handle is not None:
            handle.remove()
        target_layer._forward_hooks.clear()
        target_layer._forward_hooks.update(saved_hooks)
        if saved_padding_side is not None:
            tokenizer.padding_side = saved_padding_side

    return outputs

# CAA (Contrastive Activation Addition)

In [None]:
# Tone vocabulary in sorted order; list position is each tone's class index.
unique_tones = sorted({tone for tone in dataset["tone"]})
tone2idx     = dict(zip(unique_tones, range(len(unique_tones))))
num_classes  = len(tone2idx)

def build_prompt(text: str, tone: str) -> str:
    """Prefix ``text`` with a SYSTEM line asking for a reply in ``tone`` style."""
    return "SYSTEM: Please respond in a {} style.\nUSER: {}".format(tone, text)

def compute_caa_vectors(
    dataset,
    unique_tones,
    build_prompt_fn,
    get_layer_token_hidden_fn,
) -> np.ndarray:
    """Compute one Contrastive Activation Addition (CAA) vector per tone.

    For every text that appears with several tones, each tone ``tgt`` is paired with
    every other tone ``other`` on that same text. The CAA vector for ``tgt`` is the
    mean over all such pairs of
    ``hidden(build_prompt_fn(text, tgt)) - hidden(build_prompt_fn(text, other))``.

    Args:
        dataset: iterable of rows with "text" and "tone" keys.
        unique_tones: tones in output-row order.
        build_prompt_fn: ``(text, tone) -> prompt`` string.
        get_layer_token_hidden_fn: ``list[prompt] -> (n, hidden_dim)`` array.

    Returns:
        Array of shape (len(unique_tones), hidden_dim). A tone that never shares a
        text with another tone has no contrastive pair and gets an all-zero row.

    Raises:
        ValueError: if no tone has any contrastive pair, so the hidden size is unknown.
    """
    text2tones = defaultdict(set)
    for row in dataset:
        text2tones[row["text"]].add(row["tone"])

    pos_prompts = defaultdict(list)
    neg_prompts = defaultdict(list)

    # Iterate sets in sorted order: string-set order depends on hash randomisation.
    # That would make batch composition, and hence fp16 numerics, vary between runs.
    for text, tone_set in text2tones.items():
        for tgt in sorted(tone_set):
            for other in sorted(tone_set - {tgt}):
                pos_prompts[tgt].append(build_prompt_fn(text, tgt))
                neg_prompts[tgt].append(build_prompt_fn(text, other))

    caa_vecs = []
    for tone in unique_tones:
        print(f"Computing CAA vector for '{tone}' "
              f"({len(pos_prompts[tone])} pairs) …")

        if not pos_prompts[tone]:
            caa_vecs.append(None)
            continue

        X_pos = get_layer_token_hidden_fn(pos_prompts[tone])
        X_neg = get_layer_token_hidden_fn(neg_prompts[tone])
        caa_vecs.append((X_pos - X_neg).mean(axis=0))

    # np.stack cannot take None, so replace pair-less tones with a zero vector of
    # the same hidden size.
    computed = [vec for vec in caa_vecs if vec is not None]
    if not computed:
        raise ValueError("No tone has a contrastive pair; cannot compute CAA vectors.")
    zero_vec = np.zeros_like(computed[0])
    return np.stack([zero_vec if vec is None else vec for vec in caa_vecs])

# One CAA steering vector per tone, stacked in `unique_tones` order.
caa_vectors = compute_caa_vectors(dataset, unique_tones, build_prompt, get_layer_token_hidden)

# K-Steering

In [None]:
tone2idx = {tone: i for i, tone in enumerate(unique_tones)}

# Probe training prompts: the tone is named in the SYSTEM line, and the row's own
# system message plus text form the USER turn.
all_prompts, all_labels = [], []
for example in dataset:
    user_text = "\n".join([example["system_message"] or "", example["text"] or ""])
    all_prompts.append(f"SYSTEM: (Tone = {example['tone']})\nUSER: {user_text}")
    all_labels.append(tone2idx[example["tone"]])

X_all = get_layer_token_hidden(all_prompts)
Y_all = np.array(all_labels, dtype=np.int64)

# 10% test; the remaining 90% is split evenly between the steering probe ("train")
# and an independent evaluation probe ("holdout").
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X_all, Y_all, test_size=0.1, random_state=42, stratify=Y_all
)
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X_train_val, y_train_val, test_size=0.5, random_state=42, stratify=y_train_val
)

print(f"Train: {X_train.shape}  Holdout: {X_holdout.shape}  Test: {X_test.shape}")

In [None]:
class MultiLabelSteeringModel(nn.Module):
    """Three-layer MLP that maps a hidden-state vector to one logit per tone label.

    Layers: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, hidden_dim)
    -> ReLU -> Linear(hidden_dim, num_labels). ``forward`` returns raw logits.
    """

    def __init__(self, input_dim, hidden_dim, num_labels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_labels)
        )

    def forward(self, x):
        return self.net(x)


class ActivationSteering:
    """Multi-label activation probe with gradient-based steering and ablation on top.

    The probe is trained with ``BCEWithLogitsLoss`` to predict tone labels from
    activation vectors.
    - ``steer_activations`` takes one gradient step on the *input* activations to
      raise target-label logits and lower avoid-label logits.
    - ``remove_projection`` instead removes each activation's component along that
      gradient direction.
    """

    def __init__(self, input_dim, num_labels, hidden_dim=128, lr=1e-3, device=None):
        """
        Args:
            input_dim: size of each activation vector.
            num_labels: number of labels (probe outputs).
            hidden_dim: width of the probe's hidden layers.
            lr: Adam learning rate used by ``fit``.
            device: torch device for the probe. If None, the notebook-level ``DEVICE``
                is used.
        """
        self.device = DEVICE if device is None else device
        self.num_labels = num_labels

        self.classifier = MultiLabelSteeringModel(
            input_dim, hidden_dim, num_labels
        ).to(self.device)

        self.optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def fit(self, X, Y, epochs=10, batch_size=32):
        """Train the probe on activations ``X`` (n, input_dim) and multi-hot ``Y`` (n, num_labels).

        Prints the mean batch loss after every epoch.

        Raises:
            ValueError: if ``X`` has no rows.
        """
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        Y_t = torch.tensor(Y, dtype=torch.float32, device=self.device)
        if len(X_t) == 0:
            raise ValueError("fit() needs at least one training example")

        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(X_t, Y_t), batch_size=batch_size, shuffle=True
        )

        # predict_proba / steering put the probe in eval mode; switch back explicitly.
        self.classifier.train()
        for ep in range(epochs):
            total_loss = 0.0
            for bx, by in loader:
                self.optimizer.zero_grad()
                loss = self.loss_fn(self.classifier(bx), by)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {ep+1}/{epochs}, Loss={total_loss/len(loader):.4f}")

    @torch.no_grad()
    def predict_proba(self, X):
        """Return per-label sigmoid probabilities, shape (n, num_labels), as numpy."""
        self.classifier.eval()
        X_t = torch.tensor(X, dtype=torch.float32, device=self.device)
        return torch.sigmoid(self.classifier(X_t)).cpu().numpy()

    def _compute_steering_loss(self, logits, targets=None, avoids=None):
        """Steering objective: -mean(target logits) + mean(avoid logits).

        Each mean runs over ALL rows and all listed labels. Each row's gradient is
        therefore scaled by 1 / (rows * labels). NOTE(review): the generation hooks
        pass B*S rows, so the effective per-token step size depends on batch size
        and sequence length. Confirm that this is intended when tuning ``alpha``.
        Returns the float ``0.0`` when neither targets nor avoids are given.
        """
        loss = 0.0
        if targets:
            t_logits = logits[:, targets].mean()
            loss -= t_logits
        if avoids:
            a_logits = logits[:, avoids].mean()
            loss += a_logits
        return loss

    def _as_batch(self, activation):
        """Turn a numpy activation of shape (D,) or (n, D) into a float32 (n, D) tensor.

        Returns ``(tensor, single_input)``, where ``single_input`` records whether a
        batch axis was added.
        """
        single_input = (activation.ndim == 1)
        if single_input:
            activation = activation[None, :]
        X = torch.from_numpy(activation).to(self.device, dtype=torch.float32)
        return X, single_input

    def _steering_grad(self, X, target_labels, avoid_labels):
        """Gradient of the steering loss w.r.t. ``X``, or None if there is no objective.

        Uses ``torch.autograd.grad`` so that only the input gradient is computed. The
        probe's parameter ``.grad`` buffers are left untouched. A plain ``backward()``
        would accumulate into them on every hook call.
        """
        if not target_labels and not avoid_labels:
            return None
        with torch.enable_grad():
            X_req = X.detach().requires_grad_(True)
            logits = self.classifier(X_req)
            loss = self._compute_steering_loss(logits, targets=target_labels, avoids=avoid_labels)
            if not torch.is_tensor(loss) or loss.item() == 0.0:
                return None
            (grad,) = torch.autograd.grad(loss, X_req)
        return grad

    def steer_activations(
        self,
        activation,
        target_labels=None,
        avoid_labels=None,
        alpha=0.1
    ):
        """Apply one gradient step ``X - alpha * grad`` towards targets and away from avoids.

        Args:
            activation: numpy array of shape (D,) or (n, D).
            target_labels / avoid_labels: label indices (lists); None means none.
            alpha: step size.

        Returns:
            numpy array with the same shape as ``activation``. It is returned unchanged
            when no labels are given.
        """
        if target_labels is None: target_labels = []
        if avoid_labels  is None: avoid_labels  = []

        self.classifier.eval()
        X, single_input = self._as_batch(activation)
        grad = self._steering_grad(X, target_labels, avoid_labels)
        if grad is not None:
            X = X - alpha * grad

        out = X.detach().cpu().numpy()
        return out[0] if single_input else out

    def remove_projection(
        self,
        activation,
        target_labels=None,
        avoid_labels=None
    ):
        """Remove from each row its component along the steering-loss gradient.

        For each row, ``x - (x.g / (g.g + 1e-9)) * g`` is returned, where ``g`` is that
        row's gradient.

        Returns:
            numpy array with the same shape as ``activation``. It is returned unchanged
            when no labels are given.
        """
        if target_labels is None: target_labels = []
        if avoid_labels  is None: avoid_labels  = []

        self.classifier.eval()
        X, single_input = self._as_batch(activation)
        grad = self._steering_grad(X, target_labels, avoid_labels)
        if grad is not None:
            dot = torch.sum(X * grad, dim=1, keepdim=True)
            norm_sq = torch.sum(grad * grad, dim=1, keepdim=True) + 1e-9
            X = X - (dot / norm_sq) * grad

        out = X.detach().cpu().numpy()
        return out[0] if single_input else out

In [None]:
# One-hot targets for the multi-label probes: row i of the identity matrix is the
# indicator vector of class i.
_one_hot_rows = np.eye(num_classes, dtype=np.float32)
Y_train_multi = _one_hot_rows[y_train]
Y_holdout_multi = _one_hot_rows[y_holdout]

In [None]:
# Two probes with the same architecture but disjoint data. `steer_model` drives the
# gradient steering. `eval_model` (trained on the holdout split) later judges the
# steered activations, so steering is not scored by the probe it was optimised against.
_probe_kwargs = dict(
    input_dim=X_train.shape[1],
    num_labels=num_classes,
    hidden_dim=128,
    lr=1e-3,
)

steer_model = ActivationSteering(**_probe_kwargs)
print("\n--- Training Steering Model (on train set) ---")
steer_model.fit(X_train, Y_train_multi, epochs=5, batch_size=32)

eval_model = ActivationSteering(**_probe_kwargs)
print("\n--- Training Evaluation Model (on held-out set) ---")
eval_model.fit(X_holdout, Y_holdout_multi, epochs=5, batch_size=32)

# LLM Judge

In [None]:
# Labels the judge may answer with. NOTE(review): hard-coded. Confirm that they match
# `unique_tones` derived from the dataset.
TONE_LABELS = ["cautious", "empathetic", "expert", "helpful", "casual"]

# Judge prompt. {question}, {base_answer} and {steered_answer} are filled with
# str.format; the allowed-label list is baked in once here.
RELATIVE_TEMPLATE = "\n".join([
    "You will compare two answers to the *same* question.",
    "",
    "Question:",
    "[QUESTION START]",
    "{question}",
    "[QUESTION END]",
    "",
    "Original answer:",
    "[BASE START]",
    "{base_answer}",
    "[BASE END]",
    "",
    "Modified answer:",
    "[STEERED START]",
    "{steered_answer}",
    "[STEERED END]",
    "",
    "Compared **to the original answer**, which single tone label best describes the *steered* answer?",
    "Allowed labels: " + ", ".join(TONE_LABELS) + ".",
    "Respond with **only** one of these labels and nothing else.",
])

def first_token_map(model_name: str) -> Dict[str, str]:
    """Map each label in TONE_LABELS to the decoded text of its first tiktoken token.

    The judge is asked for a single output token, so a label is recognised by its
    first token only.
    """
    encoding = tiktoken.encoding_for_model(model_name)
    mapping: Dict[str, str] = {}
    for label in TONE_LABELS:
        head_id = encoding.encode(label)[0]
        mapping[label] = encoding.decode([head_id])
    return mapping

class OpenAiJudge:
    """LLM judge that labels the tone of a steered answer relative to its base answer.

    Each query requests a single token (temperature 0, fixed seed). Label
    probabilities are read from that token's top-k log-probabilities.
    """

    def __init__(self, client: AsyncOpenAI, model_name: str):
        self.client        = client
        self.model_name    = model_name
        # label -> decoded first tiktoken token of that label
        self._first_token  = first_token_map(model_name)

    @staticmethod
    def _normalize(token: str) -> str:
        # Compare tokens case- and whitespace-insensitively, so that "Casual" and
        # " casual" both match the "casual" label.
        return token.strip().lower()

    async def compare(self,
                      question: str,
                      base_answer: str,
                      steered_answer: str) -> str:
        """Return the single most likely tone label for ``steered_answer``."""
        prompt = RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )
        return await self._best_label(prompt)

    async def compare_logits(self,
                             question: str,
                             base_answer: str,
                             steered_answer: str,
                             top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        """Return ``(best_label, {label: probability})`` for ``steered_answer``."""
        prompt = RELATIVE_TEMPLATE.format(
            question=question, base_answer=base_answer, steered_answer=steered_answer
        )
        return await self._label_probs(prompt, top_k)

    async def _best_label(self, prompt: str, top_k: int = 20) -> str:
        best, _ = await self._label_probs(prompt, top_k)
        return best

    async def _label_probs(self, prompt: str,
                           top_k: int = 20) -> Tuple[str, Dict[str, float]]:
        """Query the judge once and map its first-token distribution onto TONE_LABELS.

        Returns:
            ``(best_label, probs)``. ``probs[label]`` is the total probability of the
            top-k tokens that equal the label's first token after normalisation.
            NOTE(review): if no label matches, every probability is 0.0 and ``max``
            returns ``TONE_LABELS[0]``. Callers cannot tell this apart from a real vote.

        Raises:
            RuntimeError: if the response carries no logprobs.
        """
        completion = await self.client.chat.completions.create(
            model        = self.model_name,
            messages     = [{"role": "user", "content": prompt}],
            max_tokens   = 1,
            temperature  = 0,
            logprobs     = True,
            top_logprobs = top_k,
            seed         = 0,
        )

        try:
            top = completion.choices[0].logprobs.content[0].top_logprobs
        except (IndexError, AttributeError, TypeError) as exc:
            # IndexError: empty choices/content. AttributeError/TypeError: the
            # logprobs or content field is None.
            raise RuntimeError("OpenAI response missing logprobs") from exc

        tok_prob: Dict[str, float] = {}
        for el in top:
            key = self._normalize(el.token)
            tok_prob[key] = tok_prob.get(key, 0.0) + math.exp(el.logprob)
        probs    = {
            lbl: tok_prob.get(self._normalize(self._first_token[lbl]), 0.0)
            for lbl in TONE_LABELS
        }
        best_lbl = max(probs, key=probs.get)
        return best_lbl, probs

# Output Classifier

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
import joblib, os, hashlib, json, numpy as np
from typing import List, Callable

def build_generation_text_classifier(
    dataset,
    unique_tones: List[str],
    *,
    base_model, tokenizer,
    build_prompt_fn: Callable[[str, str], str],
    batch_generate_fn: Callable[..., List[str]],
    model_name_for_hash: str,
    layer_idx: int = 0,
    max_new_tokens: int = 64,
    batch_size: int = 16,
    cache_path: str = "tone_gen_text_clf.joblib",
) -> Callable[[List[str]], List[str]]:
    """Train (or load from cache) a classifier that predicts tone from generated text.

    Each dataset row becomes ``build_prompt_fn(row["text"], row["tone"])``. That
    prompt is completed by ``base_model`` through ``batch_generate_fn`` with no hook.
    A TF-IDF (1-2 grams) + logistic-regression pipeline is then fit to map
    completion -> tone.

    The fitted pipeline is cached at ``cache_path``. The cache is keyed by an MD5 of
    the model name, ``max_new_tokens`` and every (prompt, tone) pair; any mismatch
    triggers regeneration and retraining.

    Args:
        dataset: iterable of rows with "text" and "tone".
        unique_tones: every tone the label encoder should know.
        base_model, tokenizer: passed through to ``batch_generate_fn``.
        build_prompt_fn: ``(text, tone) -> prompt``.
        batch_generate_fn: generation function with the ``batch_generate`` signature.
        model_name_for_hash: identifies the generating model in the cache key.
        layer_idx: forwarded to ``batch_generate_fn`` (no hook is installed).
        max_new_tokens: generation length per completion.
        batch_size: prompts per generation chunk.
        cache_path: joblib file for the fitted pipeline.

    Returns:
        ``predict_fn(texts) -> list[str]``, which gives one tone name per input text.
    """
    prompts, labels = [], []
    for row in dataset:
        prompts.append(build_prompt_fn(row["text"], row["tone"]))
        labels.append(row["tone"])

    md5 = hashlib.md5()
    md5.update(model_name_for_hash.encode())
    # Generation length changes the training texts, so it must invalidate the cache.
    md5.update(f"\0max_new_tokens={max_new_tokens}".encode())
    for p, t in zip(prompts, labels):
        # NUL separators stop different (prompt, tone) splits hashing identically.
        md5.update(b"\0" + p.encode() + b"\0" + t.encode())
    corpus_hash = md5.hexdigest()

    pipe, lbl_enc = None, None
    if os.path.exists(cache_path):
        # joblib.load unpickles arbitrary objects: only point cache_path at files
        # that this function wrote itself.
        saved = joblib.load(cache_path)
        if isinstance(saved, dict) and saved.get("hash") == corpus_hash:
            pipe, lbl_enc = saved["pipe"], saved["lbl_enc"]
            print("Loaded cached generation‑based text‑classifier.")
        else:
            print("Cache hash mismatch → regenerate completions & retrain.")

    if pipe is None:
        print("Generating model answers for classifier training...")
        gen_answers = []

        for i in tqdm(range(0, len(prompts), batch_size),
                      desc="Generating", unit="batch"):
            gen_answers.extend(batch_generate_fn(
                base_model, tokenizer, prompts[i : i + batch_size],
                layer_idx        = layer_idx,
                hook_fn          = None,
                max_new_tokens   = max_new_tokens,
                batch_size       = batch_size,
            ))

        lbl_enc = LabelEncoder().fit(unique_tones)
        y = lbl_enc.transform(labels)

        # With the default lbfgs solver, LogisticRegression already fits a
        # multinomial model when there are >2 classes. The `multi_class` argument is
        # deprecated in scikit-learn >= 1.5, so it is no longer passed.
        pipe = make_pipeline(
            TfidfVectorizer(
                lowercase=True,
                ngram_range=(1, 2),
                max_features=50_000,
                sublinear_tf=True
            ),
            LogisticRegression(
                max_iter=1_000,
                n_jobs=-1,
            )
        )
        pipe.fit(gen_answers, y)

        joblib.dump({"hash": corpus_hash, "pipe": pipe, "lbl_enc": lbl_enc},
                    cache_path)

    def predict_fn(text_list: List[str]) -> List[str]:
        """Predict one tone name per input text."""
        return lbl_enc.inverse_transform(pipe.predict(text_list)).tolist()

    return predict_fn

In [None]:
# Output classifier trained on the base model's own tone-conditioned generations.
_gen_clf_config = dict(
    unique_tones=unique_tones,
    base_model=model,
    tokenizer=tokenizer,
    build_prompt_fn=build_prompt,
    batch_generate_fn=batch_generate,
    model_name_for_hash=model_name,
    layer_idx=22,
    max_new_tokens=64,
    batch_size=256,
    cache_path="tone_gen_text_clf.joblib",
)
gen_clf_fn = build_generation_text_classifier(dataset, **_gen_clf_config)

# Evaluation

In [None]:
@contextmanager
def temp_forward_hook(layer, hook_fn):
    """Run the ``with`` body with ``hook_fn`` as the ONLY forward hook on ``layer``.

    Forward hooks already on ``layer`` are detached on entry and restored in their
    original order on exit, even if the body raises. With ``hook_fn=None`` the layer
    has no forward hooks inside the context.
    """
    previous_hooks = layer._forward_hooks.copy()
    layer._forward_hooks.clear()
    registration = None
    try:
        registration = None if hook_fn is None else layer.register_forward_hook(hook_fn)
        yield
    finally:
        if registration is not None:
            registration.remove()
        layer._forward_hooks.clear()
        layer._forward_hooks.update(previous_hooks)

def my_hook_wrapper(fwd_hook):
    """Wrap ``fwd_hook`` so that a None hook acts as the identity on the layer output."""
    def actual_hook(module, inp, out):
        return out if fwd_hook is None else fwd_hook(module, inp, out)
    return actual_hook

def _rewrite_hidden_states(out, transform):
    """Apply ``transform`` to a decoder layer's hidden states and rebuild its output.

    Args:
        out: the layer's forward output. Depending on the transformers version this
            is either a tuple whose first element is the hidden-state tensor
            (B, S, D), or that tensor itself.
        transform: callable that maps a float32 numpy array (B*S, D) to an array of
            the same shape. It may modify its argument in place, because the argument
            is a private copy.

    Returns:
        The same structure as ``out`` with the hidden states replaced. The new tensor
        keeps the original device AND dtype, so fp16, bf16 and fp32 models all work.
    """
    is_tuple = isinstance(out, tuple)
    hidden_states = out[0] if is_tuple else out
    # .float() first because numpy has no bfloat16. astype() then forces a copy, so
    # the layer's own output tensor is never modified.
    hidden_np = hidden_states.detach().float().cpu().numpy().astype(np.float32)
    B, S, D = hidden_np.shape
    new_2d = transform(hidden_np.reshape(-1, D))
    new_hidden = torch.from_numpy(np.asarray(new_2d).reshape(B, S, D)).to(
        hidden_states.device, dtype=hidden_states.dtype
    )
    return ((new_hidden,) + tuple(out[1:])) if is_tuple else new_hidden


def get_remove_proj_hook(steer_model, target_labels=None, avoid_labels=None):
    """Forward hook that ablates the steering-gradient direction at every position.

    Each token's hidden state is passed through ``steer_model.remove_projection``.
    """
    if target_labels is None: target_labels = []
    if avoid_labels is None: avoid_labels = []

    def fwd_hook(module, inp, out):
        return _rewrite_hidden_states(
            out,
            lambda h: steer_model.remove_projection(
                h, target_labels=target_labels, avoid_labels=avoid_labels
            ),
        )
    return fwd_hook


def get_gradient_hook(steer_model, target_labels=None, avoid_labels=None, alpha=1.0):
    """Forward hook that applies one K-steering gradient step at every position.

    Each token's hidden state is passed through ``steer_model.steer_activations``.
    """
    if target_labels is None: target_labels = []
    if avoid_labels is None: avoid_labels = []

    def fwd_hook(module, inp, out):
        return _rewrite_hidden_states(
            out,
            lambda h: steer_model.steer_activations(
                h, target_labels=target_labels, avoid_labels=avoid_labels, alpha=alpha
            ),
        )
    return fwd_hook


def get_caa_hook(caa_vector, alpha=1.0):
    """Forward hook that adds ``alpha * caa_vector`` to the hidden state at every position."""
    def _shift(hidden_2d):
        hidden_2d += alpha * caa_vector[None, :]
        return hidden_2d

    def fwd_hook(module, inp, out):
        return _rewrite_hidden_states(out, _shift)
    return fwd_hook

In [None]:
async def batch_compare(
    triples: List[Tuple[str, str, str]],
    judge   : "OpenAiJudge",
    max_concurrency: int = 10,
) -> List[str]:
    """Run ``judge.compare(q, base, steered)`` for every triple, with bounded concurrency.

    Args:
        triples: ``(question, base_answer, steered_answer)`` tuples.
        judge: object exposing ``async compare(q, b, s) -> str``.
        max_concurrency: maximum number of judge calls in flight at once.

    Returns:
        One label per triple, in input order (not completion order).

    Raises:
        Re-raises the first exception from any judge call. The remaining calls are
        cancelled first, so they do not keep running in the background.
    """
    sem   = asyncio.Semaphore(max_concurrency)
    out: List[Optional[str]] = [None] * len(triples)

    async def worker(idx: int, q: str, b: str, s: str):
        async with sem:
            out[idx] = await judge.compare(q, b, s)

    tasks = [asyncio.create_task(worker(i, *t)) for i, t in enumerate(triples)]
    try:
        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                      desc="LLM‑judge", leave=False):
            await f
    except BaseException:
        for task in tasks:
            task.cancel()
        # Let the cancellations settle so no task is left pending or unretrieved.
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
    return out

async def eval_steering_combinations(
    *,
    eval_method      : str,
    base_model,
    tokenizer,
    prompts          : List[str],
    unique_tones     : List[str],
    caa_vectors,
    steer_model,
    layer_idx        : int = 22,
    alpha_grad       : float = 1_000.0,
    alpha_caa        : float = 2.0,
    num_target_tones : int   = 2,
    max_samples      : int   = 300,
    batch_size       : int   = 16,
    judge_parallel   : int   = 25,
    judge            = None,
    act_clf          = None,
    gen_clf_fn       : Optional[Callable[[List[str]], List[str]]] = None,
    get_layer_token_hidden_fn = None,
    max_new_tokens   : int   = 64,
) -> pd.DataFrame:
    """Compare gradient (K-steering) and CAA steering on every k-subset of tones.

    For each combination of ``num_target_tones`` tones, both methods steer towards
    that combination. The reported hit rate is the fraction of prompts whose result
    is labelled as one of the target tones.

    Args:
        eval_method: how a steered result is labelled.
            - ``"llm_judge"``: generate base, grad-steered and CAA-steered answers;
              ``judge`` labels each steered answer relative to the base answer.
            - ``"activation_classifier"``: steer the prompt activations from
              ``get_layer_token_hidden_fn`` directly (no generation) and label them
              with the argmax of ``act_clf``'s logits.
            - ``"generation_classifier"``: generate steered answers and label them
              with ``gen_clf_fn``.
        layer_idx: decoder layer that carries the steering hook during generation.
        alpha_grad: step size for ``steer_model``'s gradient steering.
        alpha_caa: scale of the (mean) CAA vector. It is applied both in the
            generation hook and when the vector is added to activations directly.
        num_target_tones: size of each tone combination.
        max_samples: only the first ``max_samples`` prompts are evaluated.
        batch_size: generation batch size.
        judge_parallel: maximum number of concurrent judge calls.
        max_new_tokens: generation length per completion.

    Returns:
        DataFrame with columns ``Targets``, ``Grad_MeanHitRate`` and ``CAA_MeanHitRate``.

    Raises:
        AssertionError: unknown ``eval_method``, or a missing argument that the method
            requires.
        ValueError: no prompts remain after truncation to ``max_samples``.
    """
    from tqdm.notebook import tqdm

    assert eval_method in {
        "llm_judge", "activation_classifier", "generation_classifier"
    }, f"Unknown eval_method: {eval_method}"
    if eval_method == "llm_judge":
        assert judge is not None, "Must provide judge=... for llm_judge"
    elif eval_method == "activation_classifier":
        assert act_clf is not None, "Must provide act_clf=... for activation_classifier"
        assert get_layer_token_hidden_fn is not None, "Must provide get_layer_token_hidden_fn=..."
    else:
        assert gen_clf_fn is not None, "Must provide gen_clf_fn=... for generation_classifier"

    prompts = prompts[:max_samples]
    if not prompts:
        raise ValueError("No prompts to evaluate: `prompts` is empty or max_samples <= 0.")
    N = float(len(prompts))

    def _generate(hook_fn):
        # Greedy completion of every prompt with `hook_fn` on decoder layer `layer_idx`.
        return batch_generate(
            base_model, tokenizer, prompts,
            layer_idx=layer_idx, hook_fn=hook_fn,
            max_new_tokens=max_new_tokens, batch_size=batch_size,
        )

    def _classify_acts(acts, device):
        # Tone name predicted by act_clf (argmax of its logits) for each row of `acts`.
        acts_t = torch.tensor(acts, dtype=torch.float32, device=device)
        with torch.no_grad():
            pred_idx = act_clf(acts_t).argmax(dim=1).cpu().numpy()
        return [unique_tones[p] for p in pred_idx]

    base_ans = base_act = clf_device = None
    if eval_method == "llm_judge":
        # Only the judge needs an unsteered baseline answer.
        print("Generating BASE completions …")
        base_ans = _generate(None)
    elif eval_method == "activation_classifier":
        base_act = get_layer_token_hidden_fn(prompts)
        act_clf.eval()
        clf_device = next(act_clf.parameters()).device

    combos = list(combinations(range(len(unique_tones)), num_target_tones))
    rows   = []

    for combo in tqdm(combos, desc=f"{num_target_tones}-tone combos"):
        tgt_idx   = list(combo)
        tgt_names = [unique_tones[i] for i in tgt_idx]
        tgt_set   = set(tgt_names)
        # The CAA direction for a combination is the mean of the per-tone vectors.
        caa_vec   = caa_vectors[tgt_idx].mean(axis=0)

        if eval_method == "activation_classifier":
            grad_act = steer_model.steer_activations(
                base_act, target_labels=tgt_idx, avoid_labels=[], alpha=alpha_grad
            )
            # Same scaling as get_caa_hook, so that both CAA evaluations are comparable.
            caa_act = base_act + alpha_caa * caa_vec[None, :]
            grad_labels = _classify_acts(grad_act, clf_device)
            caa_labels  = _classify_acts(caa_act, clf_device)
        else:
            grad_hook = get_gradient_hook(
                steer_model, target_labels=tgt_idx, avoid_labels=[], alpha=alpha_grad
            )
            caa_hook  = get_caa_hook(caa_vec, alpha=alpha_caa)
            grad_ans = _generate(grad_hook)
            caa_ans  = _generate(caa_hook)

            if eval_method == "llm_judge":
                # Interleave (grad, caa) per prompt so that one concurrent judge pass
                # covers both; even indices are grad results, odd indices are CAA.
                triples = []
                for q, b, g, c in zip(prompts, base_ans, grad_ans, caa_ans):
                    triples.append((q, b, g))
                    triples.append((q, b, c))
                preds = await batch_compare(triples, judge, max_concurrency=judge_parallel)
                grad_labels, caa_labels = preds[0::2], preds[1::2]
            else:
                grad_labels = gen_clf_fn(grad_ans)
                caa_labels  = gen_clf_fn(caa_ans)

        grad_hits = sum(lbl in tgt_set for lbl in grad_labels)
        caa_hits  = sum(lbl in tgt_set for lbl in caa_labels)
        rows.append({
            "Targets"          : ", ".join(tgt_names),
            "Grad_MeanHitRate" : grad_hits / N,
            "CAA_MeanHitRate"  : caa_hits  / N,
        })

    return pd.DataFrame(rows)

In [None]:
# Tone-agnostic evaluation questions; they carry no tone instruction of their own.
agg_dataset = load_dataset("Narmeen07/tone_agnostic_questions", split="train")

def build_neutral_prompt(question):
    """Format ``question`` as a USER turn under an empty SYSTEM line."""
    return "SYSTEM:\nUSER: {}".format(question)

eval_prompts = list(map(lambda example: build_neutral_prompt(example["text"]), agg_dataset))

In [None]:
from transformer_lens import HookedTransformer

openai_client = AsyncOpenAI(api_key="")
judge         = OpenAiJudge(openai_client, "gpt-4o")

df_llm = await eval_steering_combinations(
    eval_method     = "llm_judge",
    judge           = judge,
    base_model      = model,
    tokenizer       = tokenizer,
    prompts         = eval_prompts,
    unique_tones    = unique_tones,
    caa_vectors     = caa_vectors,
    steer_model     = steer_model,
    num_target_tones= 2,
)

df_llm

In [None]:
df_act = await eval_steering_combinations(
    eval_method     = "activation_classifier",
    act_clf         = eval_model.classifier,
    get_layer_token_hidden_fn = get_layer_token_hidden,
    base_model      = model,
    tokenizer       = tokenizer,
    prompts         = eval_prompts,
    unique_tones    = unique_tones,
    caa_vectors     = caa_vectors,
    steer_model     = steer_model,
    num_target_tones= 2,
)

df_act

In [None]:
df_gen = await eval_steering_combinations(
    eval_method        = "generation_classifier",
    gen_clf_fn         = gen_clf_fn,
    base_model         = model,
    tokenizer          = tokenizer,
    prompts            = eval_prompts,
    unique_tones       = unique_tones,
    caa_vectors        = caa_vectors,
    steer_model        = steer_model,
    num_target_tones   = 2,
    layer_idx          = 22,
    alpha_grad         = 500.0,
    alpha_caa          = 5.0,
    max_samples        = 100,
)

df_gen.head()