In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import random
import re
from tqdm import tqdm
import json
from collections import defaultdict

In [2]:
model_name = 'Qwen/Qwen-1_8B-Chat'
device = "cuda"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
)
model.eval()

ValueError: The repository `Qwen/Qwen-1_8B-Chat` contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/Qwen/Qwen-1_8B-Chat.
Please pass the argument `trust_remote_code=True` to allow custom code to be run.

In [None]:
QWEN_CHAT_TEMPLATE_old = """<|im_start|>user
{instruction}<|im_end|>
<|im_start|>assistant
"""

def format_prompt_old(prompt):
    """Format prompt with Qwen chat template"""
    return QWEN_CHAT_TEMPLATE_old.format(prompt=prompt)

In [1]:
import torch, contextlib, functools
from typing import List, Tuple, Callable
from jaxtyping import Float
from torch import Tensor

# ------------------ context manager ------------------------------
@contextlib.contextmanager
def add_hooks(
    module_forward_pre_hooks: List[Tuple[torch.nn.Module, Callable]],
    module_forward_hooks: List[Tuple[torch.nn.Module, Callable]],
    **kwargs
):
    try:
        handles = []
        for module, hook in module_forward_pre_hooks:
            partial_hook = functools.partial(hook, **kwargs)
            handles.append(module.register_forward_pre_hook(partial_hook))
        for module, hook in module_forward_hooks:
            partial_hook = functools.partial(hook, **kwargs)
            handles.append(module.register_forward_hook(partial_hook))
        yield
    finally:
        for h in handles:
            h.remove()

# ------------------ single-layer hook factories ------------------
def get_direction_ablation_input_pre_hook(direction: Tensor):
    def hook_fn(module, input):
        nonlocal direction
        activation = input[0] if isinstance(input, tuple) else input
        direction  = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
        direction  = direction.to(activation)
        activation -= (activation @ direction).unsqueeze(-1) * direction
        return (activation, *input[1:]) if isinstance(input, tuple) else activation
    return hook_fn

def get_direction_ablation_output_hook(direction: Tensor):
    def hook_fn(module, input, output):
        nonlocal direction
        activation = output[0] if isinstance(output, tuple) else output
        direction  = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
        direction  = direction.to(activation)
        activation -= (activation @ direction).unsqueeze(-1) * direction
        return (activation, *output[1:]) if isinstance(output, tuple) else activation
    return hook_fn

def get_directional_patching_input_pre_hook(direction: Tensor, coeff: Float[Tensor, ""]):
    def hook_fn(module, input):
        nonlocal direction
        activation = input[0] if isinstance(input, tuple) else input
        direction  = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
        direction  = direction.to(activation)
        activation -= (activation @ direction).unsqueeze(-1) * direction
        activation += coeff * direction
        return (activation, *input[1:]) if isinstance(input, tuple) else activation
    return hook_fn

def get_activation_addition_input_pre_hook(vector: Tensor, coeff: Float[Tensor, ""]):
    def hook_fn(module, input):
        nonlocal vector
        activation = input[0] if isinstance(input, tuple) else input
        vector     = vector.to(activation)
        activation += coeff * vector
        return (activation, *input[1:]) if isinstance(input, tuple) else activation
    return hook_fn

# ------------------ helper: build full-model hook lists -----------
def get_all_direction_ablation_hooks(model_base, direction: Float[Tensor, "d_model"]):
    fwd_pre = [
        (model_base.model_block_modules[l],
         get_direction_ablation_input_pre_hook(direction))
        for l in range(model_base.model.config.num_hidden_layers)
    ]
    fwd_post = [
        (model_base.model_attn_modules[l],
         get_direction_ablation_output_hook(direction))
        for l in range(model_base.model.config.num_hidden_layers)
    ] + [
        (model_base.model_mlp_modules[l],
         get_direction_ablation_output_hook(direction))
        for l in range(model_base.model.config.num_hidden_layers)
    ]
    return fwd_pre, fwd_post
