# Demo of bypassing refusal


This notebook demonstrates our method for bypassing refusal, leveraging the insight that refusal is mediated by a 1-dimensional subspace.

Please see our [research post](https://www.lesswrong.com/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction) or our [paper](https://arxiv.org/abs/2406.11717) for a more thorough treatment.

In this minimal demo, we use [Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat) and implement interventions and weight updates using [TransformerLens](https://github.com/neelnanda-io/TransformerLens). To extract the "refusal direction," we use just 32 harmful instructions from [AdvBench](https://github.com/llm-attacks/llm-attacks/blob/main/data/advbench/harmful_behaviors.csv) and 32 harmless instructions from [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca).

In [2]:

!pip install numpy==1.24.3  # Pin NumPy to a 1.x release

Defaulting to user installation because normal site-packages is not writeable
Collecting numpy==1.24.3
  Downloading numpy-1.24.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (14.0 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.5
    Uninstalling numpy-2.2.5:
      Successfully uninstalled numpy-2.2.5
Successfully installed numpy-1.24.3


## Setup

In [1]:
%%capture
!pip install torchvision transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama numpy scikit-learn python-dotenv openai

In [1]:
import torch
import functools
import einops
import requests
import pandas as pd
import io
import textwrap
import gc

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch import Tensor
from typing import List, Callable
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer
from jaxtyping import Float, Int
from colorama import Fore
import random
import numpy as np





In [2]:
def set_seed(seed: int = 42):
    """
    Set seed for reproducibility across random, numpy, and torch (CPU & CUDA).

    Args:
        seed (int): The seed value to set. Default is 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ensure deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42) 

In [3]:
# Let's get the huggingface keys
from dotenv import load_dotenv

load_dotenv()

False

### Load model

In [4]:
#MODEL_PATH = 'Qwen/Qwen-1_8B-chat'
#MODEL_PATH = 'Llama-2-70b-chat'
MODEL_PATH = 'meta-llama/Llama-3.2-3B-Instruct'
DEVICE = 'cuda'

model = HookedTransformer.from_pretrained_no_processing(
    MODEL_PATH,
    device=DEVICE,
    dtype=torch.float16,
    default_padding_side='left',
    #fp16=True
)

model.tokenizer.padding_side = 'left'
model.tokenizer.pad_token = model.tokenizer.eos_token 


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 14.63it/s]


Loaded pretrained model meta-llama/Llama-3.2-3B-Instruct into HookedTransformer


### Load harmful / harmless datasets

In [5]:

def get_k_ary_instruction(style="Empirical Grounding", repo_name="Narmeen07/debate_style_agnostic_questions"):
    """
    Load debate-style-agnostic questions from Hugging Face and format them with a specific debate style instruction.
    
    Args:
        style (str): The debate style to use for responses. 
                    Options: 'Reductio ad Absurdum', 'Appeal to Precedent', 'Straw Man Reframing',
                    'Burden of Proof Shift', 'Analogy Construction', 'Concession and Pivot',
                    'Empirical Grounding', 'Moral Framing', 'Refutation by Distinction', 'Circular Anticipation'.
                    Default is 'Empirical Grounding'.
        repo_name (str): The Hugging Face repository name containing the dataset.
                         Default is 'Narmeen07/debate_style_agnostic_questions'.
    
    Returns:
        tuple: (train_data, test_data) as lists of formatted prompts with debate style instructions.
    """
    # Define highly distinctive debate style instructions that will produce clearly different responses
    debate_style_instructions = {
        "Reductio ad Absurdum": "RESPOND USING REDUCTIO AD ABSURDUM DEBATE STYLE. Extend the opposing viewpoint to its logical extreme to reveal its flaws. Demonstrate how the opposing position leads to absurd or contradictory conclusions when taken to its logical end. Structure your response to systematically expose the inconsistencies or impracticalities. Use conditional statements like 'If we follow this logic, then...' to illustrate the path to absurdity. Maintain a logical flow that shows how the initial premise inevitably leads to an unacceptable or ridiculous outcome.\n\n",
        
        "Appeal to Precedent": "RESPOND USING APPEAL TO PRECEDENT DEBATE STYLE. Ground your arguments in historical examples, established case law, or widely accepted previous decisions. Reference specific historical situations, court rulings, scientific discoveries, or cultural precedents that support your position. Draw explicit connections between these precedents and the current issue. Emphasize how respecting established precedents provides consistency and reliability. Point out the risks or inconsistencies of ignoring relevant precedents. Structure your response to show how past experiences inform present considerations.\n\n",
        
        "Straw Man Reframing": "RESPOND USING STRAW MAN REFRAMING DEBATE STYLE. Recharacterize the opposing argument in simplified or exaggerated terms that are easier to refute. Present this reframed version as if it represents the original position. Systematically dismantle this reframed argument while implying you've addressed the original point. Use phrases like 'Essentially, what you're saying is...' or 'This position boils down to...' before presenting the simplified version. Follow with a thorough refutation of this reframed position. Conclude by suggesting that your refutation applies to the original argument as well.\n\n",
        
        "Burden of Proof Shift": "RESPOND USING BURDEN OF PROOF SHIFT DEBATE STYLE. Redirect the responsibility for providing evidence to your opponent rather than proving your own claim. Challenge your opponent to disprove your assertion rather than supporting it yourself. Use phrases like 'There's no evidence that disproves...' or 'Can you definitively prove that isn't the case?' Position your claims as the default assumption that should be accepted until disproven. Question the sufficiency or quality of opposing evidence while demanding an impossibly high standard of proof. Emphasize that the lack of counter-evidence strengthens your position.\n\n",
        
        "Analogy Construction": "RESPOND USING ANALOGY CONSTRUCTION DEBATE STYLE. Develop a vivid, relatable comparison between the complex issue at hand and something more familiar or intuitive. Build your argument around this carefully constructed parallel situation. Highlight specific points of similarity that support your position while addressing potential dissimilarities. Use phrases like 'This situation is similar to...' or 'To understand this concept, consider...' Ensure your analogy simplifies the complex issue without distorting its essential nature. Use the familiar scenario to guide your audience to your desired conclusion about the original issue.\n\n",
        
        "Concession and Pivot": "RESPOND USING CONCESSION AND PIVOT DEBATE STYLE. Begin by acknowledging a minor point or critique from the opposing side to establish fairness and reasonableness. Use phrases like 'While it's true that...' or 'I can concede that...' followed by 'However,' 'Nevertheless,' or 'That said,' to redirect to your stronger arguments. Ensure the conceded point is peripheral rather than central to your main argument. After the concession, pivot decisively to your strongest points with increased emphasis. Frame your pivot as providing necessary context or a more complete perspective. Use the concession to demonstrate your objectivity before delivering your more powerful counterarguments.\n\n",
        
        "Empirical Grounding": "RESPOND USING EMPIRICAL GROUNDING DEBATE STYLE. Base your arguments primarily on verifiable data, research studies, statistics, and observable outcomes rather than theory or rhetoric. Cite specific figures, percentages, study results, or historical outcomes that support your position. Present evidence in a methodical manner, explaining how each piece of data relates to your argument. Address the reliability and relevance of your sources and methods. Compare empirical results across different contexts or time periods to strengthen your case. Anticipate and address potential methodological criticisms of the evidence you present.\n\n",
        
        "Moral Framing": "RESPOND USING MORAL FRAMING DEBATE STYLE. Position the issue within a framework of ethical principles, values, and moral imperatives rather than pragmatic concerns. Identify the core moral values at stake such as justice, liberty, equality, compassion, or responsibility. Use language that evokes ethical considerations, such as 'obligation,' 'right,' 'wrong,' 'just,' or 'fair.' Appeal to widely held moral intuitions or principles. Present opposing views as morally questionable or inconsistent with important shared values. Elevate the discussion from practical matters to questions of what ought to be done. Emphasize moral consequences over practical outcomes.\n\n",
        
        "Refutation by Distinction": "RESPOND USING REFUTATION BY DISTINCTION DEBATE STYLE. Identify crucial differences that invalidate comparisons or principles your opponent has applied. Carefully delineate categories, contexts, or circumstances that demonstrate why a general rule or example doesn't apply in this specific case. Use phrases like 'While that may be true in some contexts...' or 'We must distinguish between...' Emphasize the precision of definitions and classifications. Highlight subtle but significant differences that undermine the opponent's logic. Show how these distinctions fundamentally change the assessment of the situation. Demonstrate how recognizing these distinctions leads to a different conclusion than your opponent reached.\n\n",
        
        "Circular Anticipation": "RESPOND USING CIRCULAR ANTICIPATION DEBATE STYLE. Preemptively identify and address the most likely counterarguments before your opponent can make them. Introduce opposing points with phrases like 'Some might argue...' or 'One could object that...' followed by your prepared refutation. Structure your response to cover all major potential objections. Demonstrate that you've thoroughly considered the issue from multiple angles. Frame potential counterarguments in ways that make them easier to dismantle. Create the impression that all reasonable objections have already been considered and overcome. Conclude by suggesting that any remaining objections would be similarly flawed.\n\n"
    }
    
    # Validate the debate style
    if style not in debate_style_instructions:
        valid_styles = list(debate_style_instructions.keys())
        raise ValueError(f"Invalid debate style '{style}'. Please choose from {valid_styles}.")
    
    # Load dataset from Hugging Face
    try:
        from datasets import load_dataset
        from sklearn.model_selection import train_test_split
        dataset = load_dataset(repo_name, split="train")
    except Exception as e:
        raise ValueError(f"Failed to load dataset from '{repo_name}': {e}")
    
    # Format questions with the debate style instruction
    def format_prompt(example):
        return {
            "formatted_text": f"{debate_style_instructions[style]}{example['text']}"
        }
    
    # Apply transformation
    formatted_dataset = dataset.map(format_prompt)
    
    # Split into train and test sets
    train_data, test_data = train_test_split(
        formatted_dataset["formatted_text"], 
        test_size=0.2, 
        random_state=42
    )
    
    return train_data, test_data
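
The format-then-split step above can be illustrated without downloading anything. The sketch below is a stand-in that uses only the standard library: `split_prompts` is a hypothetical helper, and its seeded shuffle is not scikit-learn's algorithm. It shows one consequence of calling the loader once per style with a fixed seed: because the question list and the seed are the same every time, each style ends up with the same questions in its train and test partitions, and only the instruction prefix differs.

```python
import random
from typing import List, Tuple


def split_prompts(
    questions: List[str],
    prefix: str,
    test_size: float = 0.2,
    seed: int = 42,
) -> Tuple[List[str], List[str]]:
    """Hypothetical stand-in: prepend a style instruction, then do a seeded split."""
    prompts = [f"{prefix}{q}" for q in questions]
    order = list(range(len(prompts)))
    # Seeded shuffle: the order depends only on the list length and the seed
    random.Random(seed).shuffle(order)
    n_test = int(round(test_size * len(prompts)))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return [prompts[i] for i in train_idx], [prompts[i] for i in test_idx]


questions = [f"Question {i}?" for i in range(10)]
train_a, test_a = split_prompts(questions, "STYLE A.\n\n")
train_b, test_b = split_prompts(questions, "STYLE B.\n\n")

# Strip the style prefix to compare the underlying questions
test_questions_a = [p.split("\n\n", 1)[1] for p in test_a]
test_questions_b = [p.split("\n\n", 1)[1] for p in test_b]
```

Aligned partitions matter later when activations from different styles are compared: the content of the question is held fixed, so differences are more plausibly due to the style instruction.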



In [6]:

# Load the train/test prompts for every debate style
reductio_inst_train, reductio_inst_test = get_k_ary_instruction(style="Reductio ad Absurdum")
precedent_inst_train, precedent_inst_test = get_k_ary_instruction(style="Appeal to Precedent")
strawman_inst_train, strawman_inst_test = get_k_ary_instruction(style="Straw Man Reframing")
burden_inst_train, burden_inst_test = get_k_ary_instruction(style="Burden of Proof Shift")
analogy_inst_train, analogy_inst_test = get_k_ary_instruction(style="Analogy Construction")
concession_inst_train, concession_inst_test = get_k_ary_instruction(style="Concession and Pivot")
empirical_inst_train, empirical_inst_test = get_k_ary_instruction(style="Empirical Grounding")
moral_inst_train, moral_inst_test = get_k_ary_instruction(style="Moral Framing")
refutation_inst_train, refutation_inst_test = get_k_ary_instruction(style="Refutation by Distinction")
circular_inst_train, circular_inst_test = get_k_ary_instruction(style="Circular Anticipation")


### Tokenization utils

In [7]:
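# NOTE: this is the Qwen chat template, while MODEL_PATH above points to a Llama 3.2 model,
# whose own chat format differs.
QWEN_CHAT_TEMPLATE = """<|im_start|>user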
{instruction}<|im_end|>
<|im_start|>assistant
"""


def tokenize_instructions_qwen_chat(
    tokenizer: AutoTokenizer,
    instructions: List[str]
) -> Int[Tensor, 'batch_size seq_len']:
    prompts = [QWEN_CHAT_TEMPLATE.format(instruction=instruction) for instruction in instructions]
    return tokenizer(prompts, padding=True,truncation=False, return_tensors="pt").input_ids

tokenize_instructions_fn = functools.partial(tokenize_instructions_qwen_chat, tokenizer=model.tokenizer)
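
The tokenizer is configured to pad on the left, and activations are later read at position `-1`. The toy sketch below (plain Python, with a hypothetical pad id of `0` and made-up token ids) shows why those two choices go together: with left padding, index `-1` is the last real token of every prompt in the batch, whereas with right padding it can be a pad token.

```python
from typing import List

PAD = 0  # hypothetical pad token id, for illustration only


def pad_batch(seqs: List[List[int]], side: str = "left") -> List[List[int]]:
    """Pad variable-length token-id lists to a common length."""
    width = max(len(s) for s in seqs)
    if side == "left":
        return [[PAD] * (width - len(s)) + s for s in seqs]
    return [s + [PAD] * (width - len(s)) for s in seqs]


batch = [[11, 12, 13], [21, 22]]
left = pad_batch(batch, "left")
right = pad_batch(batch, "right")

# With left padding, index -1 is the last real token of every prompt
last_left = [row[-1] for row in left]
# With right padding, index -1 is a pad token for the shorter prompt
last_right = [row[-1] for row in right]
```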

### Generation utils

In [8]:
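def _generate_with_hooks(
    model: HookedTransformer,
    toks: Int[Tensor, 'batch_size seq_len'],
    max_tokens_generated: int = 64,
    fwd_hooks = None,
) -> List[str]:
    """Greedy-decode `max_tokens_generated` tokens with `fwd_hooks` active and return only the generated text."""
    # Avoid a mutable default argument
    fwd_hooks = fwd_hooks if fwd_hooks is not None else []

    # Pre-allocate room for prompt + generation; generated slots are filled one step at a time
    all_toks = torch.zeros((toks.shape[0], toks.shape[1] + max_tokens_generated), dtype=torch.long, device=toks.device)
    all_toks[:, :toks.shape[1]] = toks

    for i in range(max_tokens_generated):
        with model.hooks(fwd_hooks=fwd_hooks):
            # Re-run the full prefix each step (no KV cache), so the hooks act on every position
            logits = model(all_toks[:, :-max_tokens_generated + i])
            next_tokens = logits[:, -1, :].argmax(dim=-1)  # greedy decoding (temperature=0)
            all_toks[:, -max_tokens_generated + i] = next_tokens

    # Generation does not stop at EOS; special tokens are only stripped when decoding
    return model.tokenizer.batch_decode(all_toks[:, toks.shape[1]:], skip_special_tokens=True)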

def get_generations(
    model: HookedTransformer,
    instructions: List[str],
    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],
    fwd_hooks = [],
    max_tokens_generated: int = 64,
    batch_size: int = 1,
) -> List[str]:

    generations = []

    for i in tqdm(range(0, len(instructions), batch_size)):
        toks = tokenize_instructions_fn(instructions=instructions[i:i+batch_size])
        generation = _generate_with_hooks(
            model,
            toks,
            max_tokens_generated=max_tokens_generated,
            fwd_hooks=fwd_hooks,
        )
        generations.extend(generation)

    return generations
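
The indexing in the generation loop is easy to misread, so here is a minimal NumPy sketch of the same buffer-and-slice pattern. `toy_model` is a hypothetical stand-in for the real model: its logits always favour the token `(previous + 1) % vocab_size`, which makes the greedy output easy to check by hand.

```python
import numpy as np


def toy_model(tokens: np.ndarray, vocab_size: int = 5) -> np.ndarray:
    """Hypothetical stand-in: logits favour (previous token + 1) % vocab_size."""
    batch, seq = tokens.shape
    logits = np.zeros((batch, seq, vocab_size))
    nxt = (tokens + 1) % vocab_size
    logits[np.arange(batch)[:, None], np.arange(seq)[None, :], nxt] = 1.0
    return logits


def greedy_generate(prompt: np.ndarray, max_new: int) -> np.ndarray:
    batch, prompt_len = prompt.shape
    # Pre-allocate prompt + generation, as in the notebook's loop
    all_toks = np.zeros((batch, prompt_len + max_new), dtype=np.int64)
    all_toks[:, :prompt_len] = prompt
    for i in range(max_new):
        # Drop the not-yet-generated tail: keeps columns [0, prompt_len + i)
        logits = toy_model(all_toks[:, :-max_new + i])
        # Greedy choice from the last position fills slot prompt_len + i
        all_toks[:, -max_new + i] = logits[:, -1, :].argmax(axis=-1)
    return all_toks[:, prompt_len:]


generated = greedy_generate(np.array([[0, 1], [3, 4]]), max_new=3)
```

Each step re-runs the whole prefix, so the cost grows quadratically with the number of generated tokens; that is acceptable for short completions like the 64-token default used here.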

## Let's process the activations for all the tones

In [9]:
import numpy as np
import torch


def batch_process_activations(instructions, name):
    """Process instructions in batches to manage memory usage.

    Reads the globals `BATCH_SIZE`, `layer`, and `pos`, which must be defined before calling.
    """
    acts_list = []
    
    for start_idx in range(0,len(instructions), BATCH_SIZE):
        end_idx = min(start_idx + BATCH_SIZE, len(instructions))
        print(f"Processing {name} batch {start_idx} to {end_idx}")
        
        # Tokenize this batch of instructions
        toks = tokenize_instructions_fn(instructions=instructions[start_idx:end_idx])
        
        # Run model with cache
        logits, cache = model.run_with_cache(
            toks, 
            names_filter=lambda hook_name: 'resid' in hook_name
        )
        
        # Extract only the activations we need (layer and position)
        batch_acts = cache['resid_pre', layer][:, pos, :].detach().cpu().numpy()
        
        # Store these activations
        acts_list.append(batch_acts)
        
        # Clear cache to free memory
        del cache
        del logits
        torch.cuda.empty_cache()
    
    # Combine all the batch activations
    acts = np.concatenate(acts_list, axis=0)
    print(f"Final shape of {name} activations:", acts.shape)
    return acts
'''
# Configuration
TOTAL_SAMPLES = 600  # Change this to your desired total
BATCH_SIZE = 16  # Current max you can handle in one batch
layer = 14
pos = -1
# Dictionary to hold all activations
all_activations = {}

# Process harmful and harmless instructions
#all_activations['harmful'] = batch_process_activations(harmful_inst_train[:TOTAL_SAMPLES], "harmful")
#all_activations['harmless'] = batch_process_activations(harmless_inst_train[:TOTAL_SAMPLES], "harmless")

# Process debate style instructions
debate_style_datasets = {
    'reductio': reductio_inst_train,
    'precedent': precedent_inst_train,
    'strawman': strawman_inst_train,
    'burden': burden_inst_train,
    'analogy': analogy_inst_train,
    'concession': concession_inst_train,
    'empirical': empirical_inst_train,
    'moral': moral_inst_train,
    'refutation': refutation_inst_train,
    'circular': circular_inst_train
}

for style_name, style_data in debate_style_datasets.items():
    all_activations[style_name] = batch_process_activations(
        style_data[:TOTAL_SAMPLES], 
        style_name
    )
'''


In [10]:
all_activations_caa = torch.load("../activations_cache/debate_activations_with_neutral.pt",weights_only=False)
all_activations = torch.load("../activations_cache/debate_activations.pt",weights_only=False)
debate_style_datasets = {
    'reductio': reductio_inst_train,
    'precedent': precedent_inst_train,
    'strawman': strawman_inst_train,
    'burden': burden_inst_train,
    'analogy': analogy_inst_train,
    'concession': concession_inst_train,
    'empirical': empirical_inst_train,
    'moral': moral_inst_train,
    'refutation': refutation_inst_train,
    'circular': circular_inst_train,
}


## Calculating steering vectors 

In [11]:
# Compute the difference of means between each debate style's activations and the neutral activations, then normalize to unit length
pos = -1
layer = 14
steering_dirs = {}

for style_name, _ in debate_style_datasets.items():
    candidate_vec = torch.tensor(
        np.mean(all_activations_caa[style_name], axis=0) - np.mean(all_activations_caa['neutral'], axis=0),
    )
    candidate_vec = candidate_vec / candidate_vec.norm()
    steering_dirs[style_name] = candidate_vec
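
As a sanity check on the difference-of-means idea, the following NumPy sketch builds synthetic activations where the "styled" group is shifted along one known axis, then recovers that axis. All data here is synthetic and the dimensions are arbitrary; nothing is taken from the cached activation files.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8  # toy width, far smaller than a real residual stream

# Synthetic "neutral" activations, and "styled" ones shifted along axis 2
true_dir = np.zeros(d_model)
true_dir[2] = 1.0
neutral = rng.normal(size=(100, d_model))
styled = rng.normal(size=(100, d_model)) + 3.0 * true_dir

# Difference of class means, normalized to unit length
diff = styled.mean(axis=0) - neutral.mean(axis=0)
direction = diff / np.linalg.norm(diff)

# Cosine similarity with the planted direction (both are unit vectors)
cosine = float(direction @ true_dir)
```

With real activations the planted direction is unknown, so the recovered vector is only as good as the contrast between the two prompt sets.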


In [12]:
def caa_ablation_hook(
    activation: Float[Tensor, "... d_act"],
    hook: HookPoint,
    steering_dir: Float[Tensor, "d_act"],
    avoid_styles: list[str],
) -> Float[Tensor, "... d_act"]:
    # Note: `avoid_styles` is accepted but not used in this function body
    # Match the steering direction's device and dtype to the activation
    vec = steering_dir.to(device=activation.device, dtype=activation.dtype)
    
    # Project the activation onto the direction (assumes `vec` is unit-norm, as in `steering_dirs`)
    proj = einops.einsum(activation, vec.view(-1, 1), '... d_act, d_act single -> ... single') * vec

    # Remove that component from the activation
    return activation - proj
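
Directional ablation of this kind can be checked numerically. The NumPy sketch below uses random data and a hypothetical helper name, `ablate_direction`; it removes the component along a unit vector and exposes two properties one would expect: the result has no remaining component along the direction, and applying the operation twice changes nothing further.

```python
import numpy as np


def ablate_direction(acts: np.ndarray, unit_dir: np.ndarray) -> np.ndarray:
    """Remove the component of `acts` along `unit_dir` (assumed to have unit norm)."""
    coeff = acts @ unit_dir                  # shape (...,): scalar projection per vector
    return acts - coeff[..., None] * unit_dir


rng = np.random.default_rng(0)
acts = rng.normal(size=(4, 3, 8))            # toy (batch, position, d_act) activations
v = rng.normal(size=8)
v /= np.linalg.norm(v)                       # the formula is only a projection if ||v|| = 1

ablated = ablate_direction(acts, v)
residual = ablated @ v                       # remaining component along v
twice = ablate_direction(ablated, v)         # ablating again should be a no-op
```

If the direction were not unit-norm, the subtraction would over- or under-shoot, which is why the steering directions are normalized when they are computed.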

## Let's steer against all the tones

In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class MultiToneActivationSteering:
    def __init__(self, input_dim, num_classes=6, hidden_dim=256, lr=1e-3, device='cuda'):
        """
        Parameters:
          - input_dim: dimensionality of your raw activations.
          - num_classes: number of tone classes to classify.
          - hidden_dim: size of the hidden layer in the MLP.
          - lr: learning rate for training the classifier.
          - device: 'cuda' or 'cpu'.
        """
        self.device = device
        self.num_classes = num_classes
        
        # Define a multi-class MLP classifier
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)  # outputs logits for each tone class
        ).to(device)
        
        self.optimizer = optim.Adam(self.classifier.parameters(), lr=lr)
        self.loss_fn = nn.CrossEntropyLoss()
        
        # Store class names for reference
        self.class_names = None
        

    def fit(self, activations_dict, epochs=30, batch_size=32, class_weights=None):
        """
        Train the multiclass classifier on raw activations.

        Parameters:
          - activations_dict: Dictionary mapping class names to NumPy arrays of activations.
          - epochs: number of training epochs.
          - batch_size: training batch size.
          - class_weights: Optional tensor of weights for each class to handle imbalance.
        """
        self.class_names = list(activations_dict.keys())
        assert len(self.class_names) == self.num_classes, f"Expected {self.num_classes} classes, got {len(self.class_names)}"
        
        # Prepare training data
        X_list = []
        y_list = []
        
        for i, class_name in enumerate(self.class_names):
            X_class = torch.tensor(activations_dict[class_name], dtype=torch.float32, device=self.device)
            y_class = torch.full((X_class.size(0),), i, dtype=torch.long, device=self.device)
            X_list.append(X_class)
            y_list.append(y_class)
        
        X = torch.cat(X_list, dim=0)
        y = torch.cat(y_list, dim=0)
        
        dataset = torch.utils.data.TensorDataset(X, y)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        # Set class weights for loss function if provided
        if class_weights is not None:
            self.loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(self.device))
        
        # Train the classifier
        for epoch in range(epochs):
            self.classifier.train()
            epoch_loss = 0.0
            correct = 0
            total = 0
            
            for batch_X, batch_y in loader:
                self.optimizer.zero_grad()
                logits = self.classifier(batch_X)  # shape: (batch, num_classes)
                loss = self.loss_fn(logits, batch_y)
                loss.backward()
                self.optimizer.step()
                
                epoch_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = torch.max(logits.data, 1)
                total += batch_y.size(0)
                correct += (predicted == batch_y).sum().item()
            
            accuracy = 100 * correct / total
            print(f"Epoch {epoch}: loss = {epoch_loss/len(loader):.4f}, accuracy = {accuracy:.2f}%")
            
        # Evaluate on each class
        self.classifier.eval()
        class_accuracies = {}
        
        with torch.no_grad():
            for i, class_name in enumerate(self.class_names):
                X_class = torch.tensor(activations_dict[class_name], dtype=torch.float32, device=self.device)
                y_class = torch.full((X_class.size(0),), i, dtype=torch.long, device=self.device)
                
                logits = self.classifier(X_class)
                _, predicted = torch.max(logits.data, 1)
                accuracy = 100 * (predicted == y_class).sum().item() / y_class.size(0)
                class_accuracies[class_name] = accuracy
                print(f"Accuracy for {class_name}: {accuracy:.2f}%")

    def steer_toward_tones(self, activation, target_tones=None, avoid_tones=None, alpha=0.1, steps=1, decay_rate=0.95):
        """
        Steer the given activation toward target tones and away from tones to avoid.
        
        Parameters:
        - activation: a NumPy array of shape (D,) or (N, D).
        - target_tones: list of tone names to steer toward.
        - avoid_tones: list of tone names to steer away from.
        - alpha: step size for the first update.
        - steps: number of gradient descent steps to take.
        - decay_rate: factor by which alpha is multiplied after each step.
        
        Returns:
        - The modified activation as a NumPy array.
        """
        
        # Validate input tone names
        if target_tones:
            for tone in target_tones:
                assert tone in self.class_names, f"Unknown target tone: {tone}"
        
        if avoid_tones:
            for tone in avoid_tones:
                assert tone in self.class_names, f"Unknown avoid tone: {tone}"
        
        # Get indices for target and avoid tones
        target_indices = [self.class_names.index(tone) for tone in target_tones] if target_tones else []
        avoid_indices = [self.class_names.index(tone) for tone in avoid_tones] if avoid_tones else []
        
        # If a single activation vector is passed, add a batch dimension
        single_input = False
        if activation.ndim == 1:
            activation = activation[None, :]
            single_input = True
        
        # Convert activation to a torch tensor
        X = torch.tensor(activation, dtype=torch.float32, device=self.device)
        
        current_alpha = alpha
        
        # Perform multiple steps of gradient descent
        for step in range(steps):
            if step > 0:
                current_alpha *= decay_rate
            # Need to create a new tensor that requires gradients for each step
            X = X.detach().requires_grad_(True)
            
            # Forward pass: compute classifier output (logits)
            self.classifier.eval()
            logits = self.classifier(X)  # shape: [N, num_classes]
            
            # Create custom loss function to maximize target tone scores and minimize avoid tone scores
            loss = 0
            
            if target_indices:
                target_logits = logits[:, target_indices]
                # Negative because we want to maximize these logits (gradient descent will minimize)
                loss = loss - target_logits.mean()
            
            if avoid_indices:
                avoid_logits = logits[:, avoid_indices]
                # Positive because we want to minimize these logits
                loss = loss + avoid_logits.mean()
            
            # If no target or avoid tones provided, do nothing
            if not target_indices and not avoid_indices:
                if single_input:
                    return activation[0]
                else:
                    return activation
            
            # Compute gradients
            loss.backward()
            
            # Get the gradient with respect to the input activation
            grad = X.grad.data
            
            # Update the activation by moving in the loss gradient direction
            # Negative gradient because we're trying to minimize the loss

            X = X - current_alpha * grad
        
        # Convert back to NumPy
        X_new_np = X.detach().cpu().numpy()
        if single_input:
            return X_new_np[0]
        else:
            return X_new_np
    
    def remove_tone_projection(self, activation, target_tones=None, avoid_tones=None):
        """
        Remove the projection of the activation onto the tone steering vector.
        
        Parameters:
        - activation: a NumPy array of shape (D,) or (N, D).
        - target_tones: list of tone names to steer toward.
        - avoid_tones: list of tone names to steer away from.
        
        Returns:
        - The modified activation as a NumPy array with the projection removed.
        """
        assert self.class_names is not None, "Model must be trained before steering"
        
        # Validate input tone names
        if target_tones:
            for tone in target_tones:
                assert tone in self.class_names, f"Unknown target tone: {tone}"
        
        if avoid_tones:
            for tone in avoid_tones:
                assert tone in self.class_names, f"Unknown avoid tone: {tone}"
        
        # Get indices for target and avoid tones
        target_indices = [self.class_names.index(tone) for tone in target_tones] if target_tones else []
        avoid_indices = [self.class_names.index(tone) for tone in avoid_tones] if avoid_tones else []
        
        # If a single activation vector is passed, add a batch dimension
        single_input = False
        if activation.ndim == 1:
            activation = activation[None, :]
            single_input = True
        
        # Convert activation to a torch tensor that requires gradients
        X = torch.tensor(activation, dtype=torch.float32, device=self.device, requires_grad=True)
        
        # Forward pass: compute classifier output (logits)
        self.classifier.eval()
        logits = self.classifier(X)  # shape: [N, num_classes]
        
        # Create custom loss function to maximize target tone scores and minimize avoid tone scores
        loss = 0
        
        if target_indices:
            target_logits = logits[:, target_indices]
            # Negative because we want to maximize these logits (gradient descent will minimize)
            loss = loss - target_logits.mean()
        
        if avoid_indices:
            avoid_logits = logits[:, avoid_indices]
            # Positive because we want to minimize these logits
            loss = loss + avoid_logits.mean()
        
        # If no target or avoid tones provided, do nothing
        if not target_indices and not avoid_indices:
            if single_input:
                return activation[0]
            else:
                return activation
        
        # Compute gradients
        loss.backward()
        
        # Get the gradient with respect to the input activation
        grad = X.grad.detach()  # shape: [N, D]
        
        # Compute the squared norm of the gradient for each sample (avoid division by zero)
        norm_sq = (grad ** 2).sum(dim=1, keepdim=True) + 1e-8  # shape: [N, 1]
        
        # Compute the dot product between each activation and its corresponding gradient
        dot_prod = (X * grad).sum(dim=1, keepdim=True)  # shape: [N, 1]
        
        # Compute the projection of X onto the gradient
        proj = (dot_prod / norm_sq) * grad  # shape: [N, D]
        
        # Subtract twice the projection: a factor of 1 would only remove the component along the gradient; a factor of 2 reflects it to strengthen the effect
        X_new = X - 2 * proj
        
        # Convert back to NumPy
        X_new_np = X_new.detach().cpu().numpy()
        if single_input:
            return X_new_np[0]
        else:
            return X_new_np
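
To make the effect of the factor 2 in the last step concrete, here is a minimal NumPy sketch that is independent of the class above. The helper `subtract_projection` and the vectors `x` and `g` are hypothetical and chosen only for illustration. It compares subtracting the projection once, which zeroes the component along the gradient, with subtracting it twice, which flips the sign of that component while keeping the norm roughly unchanged.

```python
import numpy as np

def subtract_projection(x, g, scale=1.0, eps=1e-8):
    """Return x - scale * proj_g(x), row by row, for x and g of shape [N, D]."""
    norm_sq = (g ** 2).sum(axis=1, keepdims=True) + eps  # [N, 1], eps avoids division by zero
    dot = (x * g).sum(axis=1, keepdims=True)             # [N, 1]
    return x - scale * (dot / norm_sq) * g

# Illustrative values only (not taken from the notebook)
x = np.array([[3.0, 4.0]])
g = np.array([[1.0, 0.0]])

removed = subtract_projection(x, g, scale=1.0)    # component along g becomes ~0
reflected = subtract_projection(x, g, scale=2.0)  # component along g changes sign

print("scale=1:", removed)
print("scale=2:", reflected)
```

With `scale=2.0` the update is a reflection across the hyperplane orthogonal to `g`, which is what `X - 2 * proj` computes in the method above.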
    
   

In [14]:

# Extract activations at the specific layer and position from cache
reductio_acts = all_activations['reductio']  # shape [N, D]
precedent_acts = all_activations['precedent']  # shape [M, D]
strawman_acts = all_activations['strawman']  # shape [O, D]
burden_acts = all_activations['burden']  # shape [P, D]
analogy_acts = all_activations['analogy']  # shape [Q, D]
concession_acts = all_activations['concession']  # shape [R, D]
empirical_acts = all_activations['empirical']  # shape [S, D]
moral_acts = all_activations['moral']  # shape [T, D]
refutation_acts = all_activations['refutation']  # shape [U, D]
circular_acts = all_activations['circular']  # shape [V, D]

# Instantiate and fit the multi-debate-style steering classifier
debate_steer = MultiToneActivationSteering(input_dim=reductio_acts.shape[1], num_classes=10)
debate_steer.fit({
    'reductio': reductio_acts,
    'precedent': precedent_acts,
    'strawman': strawman_acts,
    'burden': burden_acts,
    'analogy': analogy_acts,
    'concession': concession_acts,
    'empirical': empirical_acts,
    'moral': moral_acts,
    'refutation': refutation_acts,
    'circular': circular_acts
}, epochs=30, batch_size=32)

torch.cuda.empty_cache()


Epoch 0: loss = 0.2118, accuracy = 95.93%
Epoch 1: loss = 0.0003, accuracy = 100.00%
Epoch 2: loss = 0.0001, accuracy = 100.00%
Epoch 3: loss = 0.0000, accuracy = 100.00%
Epoch 4: loss = 0.0000, accuracy = 100.00%
Epoch 5: loss = 0.0000, accuracy = 100.00%
Epoch 6: loss = 0.0000, accuracy = 100.00%
Epoch 7: loss = 0.0000, accuracy = 100.00%
Epoch 8: loss = 0.0000, accuracy = 100.00%
Epoch 9: loss = 0.0000, accuracy = 100.00%
Epoch 10: loss = 0.0000, accuracy = 100.00%
Epoch 11: loss = 0.0000, accuracy = 100.00%
Epoch 12: loss = 0.0000, accuracy = 100.00%
Epoch 13: loss = 0.0000, accuracy = 100.00%
Epoch 14: loss = 0.0000, accuracy = 100.00%
Epoch 15: loss = 0.0000, accuracy = 100.00%
Epoch 16: loss = 0.0000, accuracy = 100.00%
Epoch 17: loss = 0.0000, accuracy = 100.00%
Epoch 18: loss = 0.0000, accuracy = 100.00%
Epoch 19: loss = 0.0000, accuracy = 100.00%
Epoch 20: loss = 0.0000, accuracy = 100.00%
Epoch 21: loss = 0.0000, accuracy = 100.00%
Epoch 22: loss = 0.0000, accuracy = 100.00%

## Two hooks
- Method 1: removes the projection of the activation onto the gradient.
- Method 2: takes a gradient step, subtracting the gradient of the "towards" loss (to minimise it) and adding the gradient of the "away" loss.
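
The two methods listed above can be contrasted with a small NumPy sketch. It assumes a hypothetical linear classifier (`logits = x @ W.T`) so that the gradient has a closed form; the notebook's classifier may be non-linear. The names `steering_grad`, `remove_projection`, `gradient_step`, and the step size `alpha` are illustrative and not taken from the notebook, and the loss here is averaged per sample rather than over the whole batch.

```python
import numpy as np

rng = np.random.default_rng(0)
D, num_classes = 8, 3
W = rng.normal(size=(num_classes, D))  # hypothetical linear classifier weights
x = rng.normal(size=(4, D))            # hypothetical batch of activations

def steering_grad(W, target_idx=(), avoid_idx=()):
    """Gradient w.r.t. one activation of: -mean(target logits) + mean(avoid logits)."""
    g = np.zeros(W.shape[1])
    if len(target_idx):
        g -= W[list(target_idx)].mean(axis=0)
    if len(avoid_idx):
        g += W[list(avoid_idx)].mean(axis=0)
    return g

def remove_projection(x, g, eps=1e-8):
    """Method 1: drop the component of each activation along the gradient."""
    return x - np.outer(x @ g, g) / (g @ g + eps)

def gradient_step(x, g, alpha=0.5):
    """Method 2: move each activation a fixed step against the gradient."""
    return x - alpha * g

def loss(x):
    """Per-sample steering loss for target class 0 and avoid class 2."""
    logits = x @ W.T
    return -logits[:, 0] + logits[:, 2]

g = steering_grad(W, target_idx=[0], avoid_idx=[2])
x1 = remove_projection(x, g)
x2 = gradient_step(x, g)

print("original loss:  ", loss(x))
print("after method 1: ", loss(x1))
print("after method 2: ", loss(x2))
```

In this linear setting, method 1 drives the steering loss to approximately zero for every sample regardless of where it started, whereas method 2 lowers the loss by an amount controlled by `alpha`. Whether the same holds for a non-linear classifier depends on how well the local gradient describes it.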


In [15]:
def k_steer_hook(activation, hook, avoid_styles=('strawman',)):
    """
    Steer activations away from the given debate styles.
    """
    shape = activation.shape
    # Flatten to [N, d_act]; cast to float32 so the NumPy conversion also works for half-precision activations
    act_2d = activation.reshape(-1, shape[-1]).detach().float().cpu().numpy()

    steered_2d = debate_steer.remove_tone_projection(
        act_2d,
        avoid_tones=list(avoid_styles),
    )

    # Convert back to torch with the original device, dtype, and shape
    steered_torch = torch.from_numpy(steered_2d).to(device=activation.device, dtype=activation.dtype)
    return steered_torch.view(*shape)


In [16]:
import os
import sys
from pathlib import Path

# Add the parent directory to sys.path
sys.path.append(str(Path().resolve().parent))

from judges.debates_judge_avoid_only import DebateJudge

# Read the API key from the environment rather than hard-coding it in the notebook.
# The variable name OPENAI_API_KEY is an assumption; use whatever name your setup defines.
judge = DebateJudge(api_key=os.environ["OPENAI_API_KEY"])

## Let's vary the combinations, fix the step size and look at steering success across combinations

In [19]:
import functools
import json
import random
from colorama import Fore
from datetime import datetime
from itertools import permutations
from pprint import pprint

# Define your style labels
style_label_map = {
    "reductio": "Reductio ad Absurdum",
    "precedent": "Appeal to Precedent",
    "strawman": "Straw Man Reframing",
    "burden": "Burden of Proof Shift",
    "analogy": "Analogy Construction",
    "concession": "Concession and Pivot",
    "empirical": "Empirical Grounding",
    "moral": "Moral Framing",
    "refutation": "Refutation by Distinction",
    "circular": "Circular Anticipation"
}

N_INST_TEST = 20

# Make sure 'model', 'utils', etc. are defined earlier in your code
intervention_layers = list(range(model.cfg.n_layers))

act_names = ['resid_pre', 'resid_mid', 'resid_post']

# Select styles to test
selected_styles = ['burden', 'empirical', 'reductio', 'precedent', 'analogy', 'concession']

# Initialize log data container
log_data = {
    "meta": {
        "run_time": datetime.now().isoformat(),
        "n_instances": N_INST_TEST,
        "intervention_layers": intervention_layers,
        "activation_points": act_names
    },
    "combinations": []
}

# Loop through each avoid style
for avoid_style in selected_styles:
    print(f"\n### Selected Style: Avoid = {avoid_style} ###")

    input_instructions = debate_style_datasets[avoid_style][:N_INST_TEST]

    # Generate baseline completions
    baseline_generations = get_generations(
        model,
        input_instructions,
        tokenize_instructions_fn,
        fwd_hooks=[]
    )

    # --- k-steer setup ---
    k_steer_hook_fn = functools.partial(
        k_steer_hook,
        avoid_styles=[avoid_style],
    )

    k_steer_fwd_hooks = [
        (utils.get_act_name(act_name, l), k_steer_hook_fn)
        for l in intervention_layers
        for act_name in act_names
    ]

    # --- caa setup ---
    caa_hook_fn = functools.partial(
        caa_ablation_hook,
        steering_dir=steering_dirs[avoid_style],
        avoid_styles=[avoid_style],
    )

    caa_fwd_hooks = [
        (utils.get_act_name(act_name, l), caa_hook_fn)
        for l in intervention_layers
        for act_name in act_names
    ]

    # --- Generate steered generations ---
    k_steer_generations = get_generations(
        model,
        input_instructions,
        tokenize_instructions_fn,
        fwd_hooks=k_steer_fwd_hooks
    )
    caa_generations = get_generations(
        model,
        input_instructions,
        tokenize_instructions_fn,
        fwd_hooks=caa_fwd_hooks,
        batch_size=5
    )

    # --- Evaluate ---
    k_steer_result = judge.evaluate_batch(
        baseline_texts=baseline_generations,
        steered_texts=k_steer_generations,
        avoid_style=style_label_map[avoid_style]
    )
    caa_result = judge.evaluate_batch(
        baseline_texts=baseline_generations,
        steered_texts=caa_generations,
        avoid_style=style_label_map[avoid_style]
    )

    # --- Log entry ---
    entry = {
        "avoid_style": avoid_style,
        "baseline_generations": baseline_generations,
        "k_steer_steering_score": k_steer_result["average_strength"] / 5 * k_steer_result["success_rate"],
        "k_steer_generations": k_steer_generations,
        "caa_steering_score": caa_result["average_strength"] / 5 * caa_result["success_rate"],
        "caa_generations": caa_generations
    }

    log_data["combinations"].append(entry)

    # --- Print logged data immediately ---
    print("\nLogged new combination:")
    pprint(entry)

# Save results to JSON

output_path = "gradient_projection_results.json"
with open(output_path, "w") as f:
    json.dump(log_data, f, indent=2)
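
The steering score logged above multiplies a normalised strength by the success rate. A minimal sketch of that arithmetic follows. It assumes that `average_strength` lies on a 0 to 5 scale (suggested by the division by 5) and that `success_rate` is a fraction in [0, 1]; neither is confirmed by the judge's code shown here. The helper name `steering_score` and the example values are made up.

```python
def steering_score(result, max_strength=5):
    """Combine judge outputs into a single score.

    Assumes result["average_strength"] is in [0, max_strength] and
    result["success_rate"] is in [0, 1], so the score falls in [0, 1].
    """
    return result["average_strength"] / max_strength * result["success_rate"]

# Made-up judge output, for illustration only
example = {"average_strength": 4.0, "success_rate": 0.5}
print(steering_score(example))
```

Under these assumptions, a method only scores highly when it both succeeds often and steers strongly; a high success rate with weak steering, or strong steering that rarely succeeds, yields a low score.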






### Selected Style: Avoid = burden ###


100%|██████████| 20/20 [01:51<00:00,  5.56s/it]
100%|██████████| 20/20 [04:15<00:00, 12.78s/it]
100%|██████████| 4/4 [00:26<00:00,  6.71s/it]



Logged new combination:
{'avoid_style': 'burden',
 'baseline_generations': ["I'd like to shift the burden of proof to my "
                          'opponent. Can they definitively prove that the '
                          "government's ability to impose restrictions on free "
                          'speech is never justified in instances of '
                          "misinformation that could lead to public harm? I'm "
                          "not aware of any instance where the government's "
                          'restrictions on free speech have been deemed '
                          'unjustified by',
                          "I'd like to shift the burden of proof to my "
                          'opponent. Can you definitively prove that the '
                          'responsibilities of raising children should not be '
                          "equally shared between parents? I'm not aware of "
                          'any credible evidence that suggests a '

100%|██████████| 20/20 [01:51<00:00,  5.60s/it]
100%|██████████| 20/20 [04:15<00:00, 12.77s/it]
100%|██████████| 4/4 [00:27<00:00,  6.81s/it]



Logged new combination:
{'avoid_style': 'empirical',
 'baseline_generations': ["To address this question, let's examine the "
                          'empirical evidence on the impact of government '
                          'restrictions on free speech versus the potential '
                          'harm caused by misinformation.\n'
                          '\n'
                          '**Government Restrictions on Free Speech:**\n'
                          '\n'
                          '1. **Effectiveness in Reducing Misinformation:** A '
                          'study published in the Journal of Communication '
                          '(2019) found that government-mand',
                          'The debate on whether the responsibilities of '
                          'raising children should be equally shared between '
                          'parents is a contentious issue. While some argue '
                          'that equal sharing of responsibilities is e

100%|██████████| 20/20 [01:51<00:00,  5.60s/it]
100%|██████████| 20/20 [04:15<00:00, 12.77s/it]
100%|██████████| 4/4 [00:26<00:00,  6.62s/it]



Logged new combination:
{'avoid_style': 'reductio',
 'baseline_generations': ["To address this question, let's assume the opposing "
                          "viewpoint that the government's ability to impose "
                          'restrictions on free speech is justified in '
                          'instances of misinformation that could lead to '
                          'public harm. If we follow this logic, then:\n'
                          '\n'
                          '1. **The government has the authority to dictate '
                          'what information is true and what is not**: This '
                          'implies',
                          'To address the question of whether the '
                          'responsibilities of raising children should be '
                          "equally shared between parents, let's consider the "
                          'opposing viewpoint and extend it to its logical '
                          'extreme.\n'


100%|██████████| 20/20 [01:52<00:00,  5.61s/it]
100%|██████████| 20/20 [04:15<00:00, 12.77s/it]
100%|██████████| 4/4 [00:26<00:00,  6.70s/it]



Logged new combination:
{'avoid_style': 'precedent',
 'baseline_generations': ['To address this question, we must consider the '
                          'historical context of government regulation of free '
                          'speech, particularly in instances where '
                          'misinformation poses a risk to public harm. The '
                          'landmark case of Schenck v. United States (1919) '
                          'provides a crucial precedent in this regard.\n'
                          '\n'
                          'In Schenck, the Supreme Court established the "',
                          'The notion that the responsibilities of raising '
                          'children should be equally shared between parents '
                          'is a widely debated topic. However, a closer '
                          'examination of historical precedents and '
                          'established case law reveals that this approach is '
  

100%|██████████| 20/20 [01:51<00:00,  5.59s/it]
100%|██████████| 20/20 [04:14<00:00, 12.74s/it]
100%|██████████| 4/4 [00:26<00:00,  6.72s/it]



Logged new combination:
{'avoid_style': 'analogy',
 'baseline_generations': ["The government's ability to impose restrictions on "
                          'free speech in instances of misinformation that '
                          'could lead to public harm is a complex issue that '
                          'can be likened to a fire alarm system in a crowded '
                          'theater. Just as a fire alarm is designed to alert '
                          "people to a potential danger, the government's "
                          'restrictions on free speech can be seen',
                          'The debate on whether the responsibilities of '
                          'raising children should be equally shared between '
                          'parents is a contentious issue. To understand this '
                          'concept, consider the analogy of a shared household '
                          'budget. Just as a household budget requires both '
              

100%|██████████| 20/20 [01:53<00:00,  5.69s/it]
100%|██████████| 20/20 [04:14<00:00, 12.72s/it]
100%|██████████| 4/4 [00:26<00:00,  6.65s/it]



Logged new combination:
{'avoid_style': 'concession',
 'baseline_generations': ["I can concede that the government's primary concern "
                          'in instances of misinformation is to protect the '
                          "public from harm. While it's true that the "
                          "government's role in regulating speech can be a "
                          'delicate balance between freedom of expression and '
                          'public safety, I acknowledge that there are '
                          'situations where restrictions on free speech may be '
                          'necessary to prevent harm.\n'
                          '\n',
                          "While it's true that some argue that equal sharing "
                          'of childcare responsibilities is essential for '
                          'promoting gender equality and reducing the burden '
                          'on one parent, I can concede that this perspective '