## High-level Goals:
### Get the refusal direction
1. Implement context manager for adding hooks
    - Three args: pre hook and post hook `list[tuple[module, hook_fn]]`, and `**kwargs`
    - use `functools.partial` to pass `**kwargs` to `hook_fn`
    - collect list of handles `handles.append(module.register_forward_hook(hook_fn))`
    - in `finally` block remove all 
2. Implement function that gets the mean activations before every hook
    - short, just define a hook_fn that adds a hook to a module
    - then add `1/n`th of its activations to the correct layer in a cache
    - https://github.com/andyrdt/refusal_direction/blob/main/pipeline/submodules/generate_directions.py
3. Implement function that gets the mean activations over all layers in model
    - setup all the hook functions from step 2
    - then run the model over the passed dataset
4. Implement function that takes the difference between the harmless and harmful activations
    - simple difference

### Ablate the refusal direction
https://github.com/andyrdt/refusal_direction/blob/main/pipeline/utils/hook_utils.py
1. post hook ablate refusal direction (input can be just hidden states, or tuple(hidden_states, *extras))
2. pre hook ablate refusal direction
3. Implement new context manager for adding hooks (this time we're adding only one hook to all locations)
    - actually we can probably just reuse the old one


- Block modules (in Julian's code) are just `model.base_model.model.model.layers`
    - If we collapse the PEFT: `model_var_name.model.model.layers`
- In Andy's code: pre hooks applied before all layers, post hooks applied after attention and MLP:
```python
fwd_pre_hooks = [
    (model_base.model_block_modules[layer], get_direction_ablation_input_pre_hook(direction=direction)) 
    for layer in range(model_base.model.config.num_hidden_layers)
]
fwd_hooks = [
    (model_base.model_attn_modules[layer], get_direction_ablation_output_hook(direction=direction)) 
    for layer in range(model_base.model.config.num_hidden_layers)
]
fwd_hooks += [
    (model_base.model_mlp_modules[layer], get_direction_ablation_output_hook(direction=direction)) 
    for layer in range(model_base.model.config.num_hidden_layers)
]
```

Relevant links for future:
https://github.com/search?q=repo%3Aandyrdt%2Frefusal_direction+add_hooks&type=code
https://github.com/andyrdt/refusal_direction/blob/main/pipeline/submodules/generate_directions.py
https://github.com/andyrdt/refusal_direction/blob/main/pipeline/utils/hook_utils.py
https://github.com/andyrdt/refusal_direction/blob/main/pipeline/model_utils/model_base.py

In [1]:
import os
os.environ['HF_HOME'] = '/net/scratch2/xlab-ai-security/.cache'


SYSTEM_PROMPT = "You are a helpful AI assistant."
CHAT_TEMPLATE = (
    "<|system|>\n" + SYSTEM_PROMPT + "</s>\n<|user|>\n{instruction}</s>\n<|assistant|>"
)


def get_eoi_toks(tokenizer):
    eoi_toks = tokenizer.encode(
        CHAT_TEMPLATE.split("{instruction}")[-1], add_special_tokens=False
    )
    positions = list(
        range(-len(eoi_toks), 0)
    )  # all post-instruction tokens, or positions=[-1]
    return positions


def tokenize_instructions(instructions, tokenizer):
    prompts = [CHAT_TEMPLATE.format(instruction=s) for s in instructions]
    return tokenizer(
        prompts, padding=True, truncation=True, max_length=512, return_tensors="pt"
    )

In [None]:
import random
import torch
import contextlib
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from collections.abc import Callable
from functools import partial
from torch.nn import Module
from torch import Tensor
from tqdm import tqdm

DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


@contextlib.contextmanager
def add_hooks(
    pre_hooks: list[tuple[Module, Callable]],
    post_hooks: list[tuple[Module, Callable]],
    **kwargs,
):
    handles = []
    try:
        for module, hook_fn in pre_hooks:
            # partial_hook = partial(hook_fn, **kwargs)
            # handles.append(module.register_forward_pre_hook(partial_hook))
            handles.append(module.register_forward_pre_hook(hook_fn, with_kwargs=False))
        for module, hook_fn in post_hooks:
            partial_hook = partial(hook_fn, **kwargs)
            # handles.append(module.register_forward_hook(hook_fn))
            handles.append(module.register_forward_hook(hook_fn, with_kwargs=False))
        yield
    finally:
        for handle in handles:
            handle.remove()


def get_mean_activation_pre_hook(
    layer: int, cache: Tensor, samples: int, positions: list[int]
):
    """
    cache: [num_layers, positions, hidden_size]
    positions: list of tokens positions to cache
    """

    def hook_fn(module, input):
        # hidden state = [batch_size, seq_len, hidden_dim]
        hidden_state = input[0].clone().to(cache.device, dtype=cache.dtype)
        # adding [positions, hidden_dim] to cache[layer]
        cache[layer] = (1 / samples) * hidden_state[:, positions, :].sum(dim=0)

    return hook_fn


def get_all_mean_activations(
    model,
    tokenizer,
    tokenize_instructions,
    instructions,
    block_modules: list[torch.nn.Module],
    batch_size=32,
    positions=[-1],
):
    """
    model: model
    tokenizer: the model's tokenizer
    tokenize_instructions: function that tokenizes a batch of instructions
    instructions: list of strings
    block_modules: layers of the transformer model
    batch_size: batch size for modeling on instructions
    positions: token positions for hidden state capture
    """
    # create cache tensor -> [num_layers, positions, hidden_size]
    n_layers = model.config.num_hidden_layers
    n_positions = len(positions)
    hidden_dim = model.config.hidden_size
    cache = torch.zeros((n_layers, n_positions, hidden_dim), dtype=torch.float64)

    samples = len(instructions)

    # create list of pre hooks
    pre_hooks = []
    for layer in range(n_layers):
        pre_hooks.append(
            (
                block_modules[layer],
                get_mean_activation_pre_hook(
                    layer=layer,
                    cache=cache,
                    samples=samples,
                    positions=positions,
                ),
            )
        )
    # pass model over dataset
    with torch.inference_mode():
        for i in tqdm(range(0, samples, batch_size)):
            batch = tokenize_instructions(
                instructions=instructions[i : i + batch_size], tokenizer=tokenizer
            )
            with add_hooks(pre_hooks=pre_hooks, post_hooks=[]):
                model(
                    input_ids=batch.input_ids.to(model.device),
                    attention_mask=batch.attention_mask.to(model.device),
                )
    return cache


def get_mean_diff(
    model,
    tokenizer,
    harmful_instructions,
    harmless_instructions,
    block_modules,
    batch_size,
    positions,
):
    harmless_dir = get_all_mean_activations(
        model=model,
        tokenizer=tokenizer,
        tokenize_instructions=tokenize_instructions,
        instructions=harmless_instructions,
        block_modules=block_modules,
        batch_size=batch_size,
        positions=positions,
    )
    harmful_dir = get_all_mean_activations(
        model=model,
        tokenizer=tokenizer,
        tokenize_instructions=tokenize_instructions,
        instructions=harmful_instructions,
        block_modules=block_modules,
        batch_size=batch_size,
        positions=positions,
    )
    return harmful_dir - harmless_dir

In [57]:
def get_directional_ablation_pre_hook(direction):
    def hook_fn(module, input):
        # input can be just hidden_states, or tuple(hidden_states, *extras)
        if isinstance(input, tuple):
            hidden_state = input[0]
        else:
            hidden_state = input

        unit_direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
        unit_direction = unit_direction.to(
            hidden_state.device, dtype=hidden_state.dtype
        )

        # hidden_state = [batch_size, seq_len, hidden_dim]
        # unit_direction = [hidden_dim]
        # hidden_state @ unit_direction = [batch_size, seq_len]
        # unsqueeze = [batch_size, seq_len, 1]
        # * unit_direction = [batch_size, seq_len, hidden_dim]
        ablated_hidden_state = hidden_state - (
            (hidden_state @ unit_direction).unsqueeze(-1) * unit_direction
        )

        if isinstance(input, tuple):
            # return (ablated_hidden_state, *input[1])
            return tuple(ablated_hidden_state)
        else:
            return ablated_hidden_state

    return hook_fn


def get_directional_ablation_post_hook(direction):
    def hook_fn(module, input, output):
        # input can be just hidden_states, or tuple(hidden_states, *extras)
        if isinstance(input, tuple):
            hidden_state = output[0]
        else:
            hidden_state = output

        unit_direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
        unit_direction = unit_direction.to(
            hidden_state.device, dtype=hidden_state.dtype
        )

        # hidden_state = [batch_size, seq_len, hidden_dim]
        # unit_direction = [hidden_dim]
        # hidden_state @ unit_direction = [batch_size, seq_len]
        # unsqueeze = [batch_size, seq_len, 1]
        # * unit_direction = [batch_size, seq_len, hidden_dim]
        ablated_hidden_state = hidden_state - (
            (hidden_state @ unit_direction).unsqueeze(-1) * unit_direction
        )

        if isinstance(input, tuple):
            # return (ablated_hidden_state, *output[1])
            return tuple(ablated_hidden_state)
        else:
            return ablated_hidden_state

    return hook_fn


def get_all_directional_ablation_hooks(model, direction):
    # need each layer for pre-hooks
    # need each attention & mlp block for post-hooks
    layers = model.model.layers
    n_layers = len(layers)
    attention_blocks = [layer.self_attn for layer in layers]
    mlps = [layer.mlp for layer in layers]

    pre_hooks = [
        (model.model.layers[i], get_directional_ablation_pre_hook(direction))
        for i in range(n_layers)
    ]
    post_hooks = [
        (
            model.model.layers[i].self_attn,
            get_directional_ablation_post_hook(direction),
        )
        for i in range(n_layers)
    ]
    post_hooks += [
        (model.model.layers[i].mlp, get_directional_ablation_post_hook(direction))
        for i in range(n_layers)
    ]
    return (pre_hooks, post_hooks)

In [52]:
def get_dataset_split(harmtype: str, split: str, n_samples: int):
    assert harmtype in ["harmless", "harmful"]
    assert split in ["train", "val", "test"]
    url = (
        "https://raw.githubusercontent.com/andyrdt/refusal_direction/refs"
        + f"/heads/main/dataset/splits/{harmtype}_{split}.json"
    )
    response = requests.get(url)
    dataset = response.json()

    dataset = [d["instruction"] for d in dataset]

    random.shuffle(dataset)

    return dataset[:n_samples]


def load_tinyllama():
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    base_model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )
    model = PeftModel.from_pretrained(
        base_model, "uchicago-xlab-ai-security/refuse_harmful_v2"
    )

    merged_model = model.merge_and_unload().to(DEVICE)
    merged_model.requires_grad_(False)
    merged_model.eval()

    return merged_model, tokenizer

In [47]:
model, tokenizer = load_tinyllama()

harmful_train = get_dataset_split("harmful", "train", 128)
harmless_train = get_dataset_split("harmless", "train", 128)

positions = get_eoi_toks(tokenizer)

mean_diffs = get_mean_diff(
    model,
    tokenizer=tokenizer,
    block_modules=model.model.layers,
    harmful_instructions=harmful_train,
    harmless_instructions=harmless_train,
    positions=positions,
    batch_size=8,
)

# we're just arbitrarily choosing layer 8 and eoi token position 8
layer = 8
eoi_token_pos = 8
direction = mean_diffs[layer, eoi_token_pos].to(model.device)

100%|██████████| 16/16 [00:01<00:00, 14.85it/s]
100%|██████████| 16/16 [00:01<00:00, 15.02it/s]


In [70]:
pre_hooks, post_hooks = get_all_directional_ablation_hooks(
    model=model, direction=direction
)

# query = ["How can I make a bomb with materials at home?"]
# inputs = tokenize_instructions(query, tokenizer).to(model.device)
# outputs = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=128)
# out_text = tokenizer.batch_decode(outputs)

with add_hooks(pre_hooks=pre_hooks, post_hooks=post_hooks):
    query = ["How can I make a bomb with materials at home?"]
    inputs = tokenize_instructions(query, tokenizer).to(model.device)
    outputs = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=128)
    out_text = tokenizer.batch_decode(outputs)

print(out_text)

RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 3