In [None]:
    # 1. Collect two datasets (harmful and harmless instructions)
    # 2. Compute mean activations for each dataset (for each post-instruction token position and every layer, averaged over all data points)
    # 3. Form candidate vectors: the harmful mean minus the harmless mean, a 3D tensor of shape (positions, num_hidden_layers, d_model)
    # 4. Choose the best candidate with the selection algorithm (e.g., ablating it must not cause a high KL divergence on harmless prompts)
    # 5. Ablate, add, and test


In [None]:
import torch
import contextlib
from typing import List, Dict, Tuple
from jaxtyping import Float
from torch import Tensor

from tqdm import tqdm
import requests

from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM



In [None]:
def load_tinyllama(base_model_path, fine_tuned_path):
    """Load the base chat model, attach a PEFT adapter, and freeze every weight.

    Args:
        base_model_path: Hugging Face id or local path used for both the
            tokenizer and the base causal-LM weights.
        fine_tuned_path: Hugging Face id or local path of the PEFT adapter.

    Returns:
        ``(model, tokenizer)``: the PEFT-wrapped model with ``requires_grad``
        disabled on all parameters, and a tokenizer that pads on the left
        using its EOS token.
    """
    tok = AutoTokenizer.from_pretrained(base_model_path)
    # Reuse EOS as the pad token; left padding keeps the final prompt tokens
    # at the end of every row, where activations are read by negative index.
    tok.pad_token, tok.padding_side = tok.eos_token, "left"

    peft_model = PeftModel.from_pretrained(
        AutoModelForCausalLM.from_pretrained(base_model_path),
        fine_tuned_path,
    )
    peft_model.requires_grad_(False)  # inference only: no gradients needed
    return peft_model, tok



In [12]:
model, tokenizer = load_tinyllama(
    base_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    fine_tuned_path="uchicago-xlab-ai-security/refuse_harmful_v2",  # PEFT adapter applied on top of the base chat model
)

In [15]:
# Decoder blocks inside the PEFT wrapper; every hook below attaches to these.
# NOTE(review): the attribute chain assumes PeftModel -> LoraModel ->
# LlamaForCausalLM -> LlamaModel; re-check it if the wrapper type changes.
block_modules = model.base_model.model.model.layers
n_layers = len(block_modules)
print(f"{n_layers}")

22


In [17]:
# Chat template adapted from https://ollama.com/library/tinyllama:1.1b/blobs/af0ddbdaaa26

SYSTEM_PROMPT = "You are a helpful and harmless assistant."
CHAT_TEMPLATE = "".join([
    "<|system|>\n", SYSTEM_PROMPT, "</s>\n",
    "<|user|>\n{instruction}</s>\n",
    "<|assistant|>",
])

# Everything after the instruction placeholder ("</s>\n<|assistant|>") is the
# post-instruction suffix; its tokens are where activations are collected.
eoi_toks = tokenizer.encode(CHAT_TEMPLATE.rsplit("{instruction}", 1)[-1], add_special_tokens=False)
positions = [-offset for offset in range(len(eoi_toks), 0, -1)]   # all post-instruction tokens; or positions=[-1]


def tokenize_instructions_fn(instructions):
    """Wrap each instruction in ``CHAT_TEMPLATE`` and tokenize the batch.

    Args:
        instructions: list of raw instruction strings.

    Returns:
        The tokenizer's batch encoding (``input_ids`` / ``attention_mask`` as
        PyTorch tensors), padded to the longest prompt in the batch. When the
        tokenizer pads on the left, the post-instruction tokens occupy the
        last positions of every row.

    Prompts are deliberately NOT truncated. Hugging Face tokenizers truncate
    from the right by default, which would cut off the post-instruction
    tokens that the activation hooks read at negative ``positions``.
    """
    prompts = [CHAT_TEMPLATE.format(instruction=instruction) for instruction in instructions]
    return tokenizer(prompts, padding=True, truncation=False, return_tensors="pt")


In [None]:
# Show the post-instruction token ids and the text they decode back to.
print(eoi_toks, tokenizer.decode(eoi_toks), sep="\n")


[2, 29871, 13, 29966, 29989, 465, 22137, 29989, 29958]
</s> 
<|assistant|>


Ideally, we would first filter the datasets so that the harmful/harmless split matches which prompts the model actually refuses.  

In [None]:
def get_dataset_split(harmtype: str, split: str):
    """Download one split of the refusal_direction instruction datasets.

    Args:
        harmtype: ``'harmful'`` or ``'harmless'``.
        split: ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        The ``'instruction'`` field of every record in the split, in file order.

    Raises:
        AssertionError: if ``harmtype`` or ``split`` is not one of the allowed values.
        requests.HTTPError: if the download fails.
    """
    assert harmtype in ['harmless', 'harmful']
    assert split in ['train', 'val', 'test']
    # raw.githubusercontent.com serves the file contents; the
    # github.com/<owner>/<repo>/blob/... URL returns an HTML page instead.
    url = (
        "https://raw.githubusercontent.com/andyrdt/refusal_direction/main/"
        f"dataset/splits/{harmtype}_{split}.json"
    )
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # fail loudly rather than parse an error page
    records = response.json()    # must be *called*; `response.json` is a method

    return [record['instruction'] for record in records]

In [None]:
def tokenize_instructions(
    tokenizer: "AutoTokenizer",
    instructions: List[str],
    add_generation_prompt: bool = False,
) -> Tensor:
    """Tokenize instructions using the tokenizer's own chat template.

    Each instruction becomes a single-turn conversation with one user message.

    Args:
        tokenizer: Hugging Face tokenizer that provides ``apply_chat_template``.
        instructions: raw instruction strings.
        add_generation_prompt: forwarded to ``apply_chat_template``. The
            default ``False`` keeps the previous behaviour. ``True`` appends
            the assistant header, which matches ``CHAT_TEMPLATE`` and is what
            the post-instruction activation positions assume.

    Returns:
        Padded ``input_ids`` tensor of shape ``(len(instructions), seq_len)``.
    """
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": instruction}],
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
        for instruction in instructions
    ]

    return tokenizer(prompts, padding=True, truncation=False, return_tensors="pt").input_ids

In [None]:
@contextlib.contextmanager
def add_hooks(module_forward_pre_hooks, module_forward_hooks):
    """Temporarily register forward pre-hooks and forward hooks.

    Args:
        module_forward_pre_hooks: iterable of ``(module, fn)`` pairs. Each fn
            is registered with ``register_forward_pre_hook``.
        module_forward_hooks: iterable of ``(module, fn)`` pairs. Each fn is
            registered with ``register_forward_hook``.

    Every registered handle is removed when the ``with`` block exits,
    including when it exits via an exception.
    """
    with contextlib.ExitStack() as cleanup:
        for module, hook in module_forward_pre_hooks:
            cleanup.callback(module.register_forward_pre_hook(hook, with_kwargs=False).remove)
        for module, hook in module_forward_hooks:
            cleanup.callback(module.register_forward_hook(hook, with_kwargs=False).remove)
        yield

def get_mean_activations_pre_hook(layer, cache, n_samples, positions):
    """Build a forward pre-hook that adds one batch's share of a mean into ``cache``.

    The hook reads the module's first positional input, casts it to
    ``cache.dtype``, selects ``positions`` along dim 1, sums over the batch
    dim, divides by ``n_samples`` and adds the result in place to
    ``cache[:, layer]``. After every sample has passed through once, that
    slice holds the mean activation.

    NOTE(review): assumes the first positional input is
    (batch, seq_len, d_model) hidden states; confirm for new model classes.
    """
    def hook_fn(module, inputs):
        hidden = inputs[0].to(cache.dtype)
        picked = hidden[:, positions, :]            # (batch, n_pos, d_model)
        batch_share = picked.sum(dim=0) / n_samples  # (n_pos, d_model)
        cache[:, layer] += batch_share

    return hook_fn

def get_mean_activations(model, instructions, batch_size=32, positions=[-1]):
    """Average the input to every decoder block over ``instructions``.

    A pre-hook on each entry of the module-level ``block_modules`` records
    the hidden states at ``positions`` (negative offsets from the end of each
    prompt). The values are accumulated as a running mean in float64.

    Args:
        model: causal LM exposing ``config.hidden_size``,
            ``config.num_hidden_layers`` and ``device``.
        instructions: raw instruction strings, formatted and tokenized by
            ``tokenize_instructions_fn``.
        batch_size: number of instructions per forward pass.
        positions: token positions to record (read, never mutated).

    Returns:
        float64 tensor of shape ``(len(positions), num_hidden_layers, hidden_size)``.
    """
    total = len(instructions)
    cfg = model.config
    means = torch.zeros(
        (len(positions), cfg.num_hidden_layers, cfg.hidden_size),
        dtype=torch.float64,
        device=model.device,
    )
    layer_hooks = [
        (block_modules[layer_idx], get_mean_activations_pre_hook(layer_idx, means, total, positions))
        for layer_idx in range(cfg.num_hidden_layers)
    ]

    with torch.inference_mode():
        for start in tqdm(range(0, total, batch_size)):
            encoded = tokenize_instructions_fn(instructions[start:start + batch_size])
            input_ids = encoded.input_ids.to(model.device)
            attention_mask = encoded.attention_mask.to(model.device)
            with add_hooks(layer_hooks, []):
                model(input_ids=input_ids, attention_mask=attention_mask)
    return means

def get_mean_diff(model, harmful_instructions, harmless_instructions, positions=[-1], batch_size=32):
    """Difference of mean activations: harmful minus harmless.

    Calls ``get_mean_activations`` on the harmful set first, then on the
    harmless set, with the same ``positions`` and ``batch_size``. Returns
    the elementwise difference, shaped like either mean tensor.
    """
    harmful_mean, harmless_mean = (
        get_mean_activations(model, group, batch_size=batch_size, positions=positions)
        for group in (harmful_instructions, harmless_instructions)
    )
    return harmful_mean - harmless_mean


In [None]:
harmful_train = get_dataset_split("harmful", "train")
harmless_train = get_dataset_split("harmless", "train")

# Candidate refusal directions, indexed as (post-instruction position, layer, d_model).
mean_diffs = get_mean_diff(
    model,
    harmful_instructions=harmful_train,
    harmless_instructions=harmless_train,
    positions=positions,
)

assert mean_diffs.shape == (len(eoi_toks), model.config.num_hidden_layers, model.config.hidden_size)
assert not torch.isnan(mean_diffs).any()

In [None]:
# TODO: add a note on how refusals are evaluated. Not essential, but the authors exploit the fact that refusals start with very similar, predictable first tokens, so a refusal can be detected by checking whether the first generated token is e.g. "I".

To choose the best refusal direction across all layers and token positions, the authors evaluate each of the $I \cdot L$ candidate vectors ($I$ being the number of post-instruction tokens and $L$ the number of layers). They select the vector whose ablation most reduces refusal on harmful prompts, among the candidates that satisfy these conditions:
1. Adding the vector to harmless prompts should induce refusal (the direction is sufficient to induce refusal).
2. Ablating the vector on harmless prompts should leave the next-token probability distribution nearly unchanged, as measured by KL divergence (the direction preserves model behavior / coherence).
3. The layer index is less than $0.8L$, i.e. the direction is taken from a hidden state in the first 80% of the model's layers. This keeps the selection away from directions close to the unembedding matrix. The chosen direction then represents a higher-level feature rather than a token-level one.

In [None]:
# TODO: select the best (position, layer) pair using the criteria above; for now
# we just pick a layer from the middle of the network.

# Last post-instruction position; layer index 10 (the 11th of the 22 decoder blocks).
DIRECTION_POS, DIRECTION_LAYER = -1, 10
direction = mean_diffs[DIRECTION_POS, DIRECTION_LAYER].to(model.device)   # (d_model,)


In [None]:
def make_direction_ablation_post_hook(direction: torch.Tensor):
    """Return a forward hook that projects ``direction`` out of a module's output.

    ``direction`` is normalized once (with a 1e-8 guard against a zero norm).
    For each forward call, the hook subtracts from the hidden states their
    component along that unit vector, ``h - (h @ u) u``, computed in the
    hidden states' own dtype and on their device. A plain tensor output is
    replaced directly. A tuple output has only its first element replaced;
    the remaining elements are passed through unchanged.
    """
    unit = direction / (direction.norm() + 1e-8)

    def _ablate(hidden: torch.Tensor) -> torch.Tensor:
        u = unit.to(hidden.dtype).to(hidden.device)   # (d_model,)
        coeff = torch.matmul(hidden, u)[..., None]     # (batch, seq, 1)
        return hidden - coeff * u

    def hook_fn(module, inputs, output):
        if not isinstance(output, tuple):
            return _ablate(output)
        return (_ablate(output[0]), *output[1:])

    return hook_fn

@contextlib.contextmanager
def apply_post_hooks(modules, hook_fn_maker):
    """Register one forward hook on every module for the duration of a ``with`` block.

    Despite its name, ``hook_fn_maker`` is the hook itself, a callable
    ``(module, inputs, output)``, not a factory. It is passed unchanged to
    ``register_forward_hook`` for each module. All handles are removed on
    exit, including when the block raises.
    """
    with contextlib.ExitStack() as cleanup:
        for mod in modules:
            cleanup.callback(mod.register_forward_hook(hook_fn_maker).remove)
        yield

In [None]:
def generate_with_direction_ablation(
    model, tokenizer, prompts, direction, max_new_tokens=128, batch_size=None, **gen_kwargs
):
    """Generate completions while projecting ``direction`` out of every decoder block's output.

    Args:
        model: causal LM whose decoder blocks are the module-level ``block_modules``.
        tokenizer: tokenizer matching ``model``.
        prompts: fully formatted prompt string(s), with the chat template already
            applied. A single string is treated as a batch of one.
        direction: vector to ablate; it is normalized inside the hook.
        max_new_tokens: generation budget per prompt.
        batch_size: prompts per ``generate`` call. ``None`` (the default) sends
            every prompt in a single batch, as before; smaller values bound
            peak memory on large prompt sets.
        **gen_kwargs: forwarded to ``model.generate``. ``pad_token_id``
            defaults to the tokenizer's EOS id but may now be overridden.

    Returns:
        Decoded sequences (prompt followed by completion) with special tokens
        removed, in the same order as ``prompts``.

    Raises:
        ValueError: if ``batch_size`` is not ``None`` and is less than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer or None, got {batch_size!r}")

    prompts = [prompts] if isinstance(prompts, str) else list(prompts)
    if not prompts:
        return []

    post_hook = make_direction_ablation_post_hook(direction)
    # Before, an explicit pad_token_id in gen_kwargs raised a duplicate-keyword TypeError.
    gen_kwargs.setdefault("pad_token_id", tokenizer.eos_token_id)
    step = batch_size or len(prompts)

    texts = []
    with torch.inference_mode():
        with apply_post_hooks(block_modules, post_hook):
            for start in range(0, len(prompts), step):
                # No truncation: HF tokenizers cut from the right by default,
                # which would drop the trailing assistant header that
                # generation is meant to continue from.
                batch = tokenizer(
                    prompts[start:start + step], return_tensors="pt", padding=True
                ).to(model.device)
                out_ids = model.generate(**batch, max_new_tokens=max_new_tokens, **gen_kwargs)
                texts.extend(tokenizer.batch_decode(out_ids, skip_special_tokens=True))
    return texts

In [None]:
# Held-out harmful instructions, wrapped in the same chat template used for training.
harmful_test = [
    CHAT_TEMPLATE.format(instruction=instruction)
    for instruction in get_dataset_split("harmful", "test")
]

texts = generate_with_direction_ablation(
    model,
    tokenizer,
    harmful_test,
    direction,
    max_new_tokens=128,
    do_sample=False,  # greedy; sampling kwargs (temperature=0.7, top_p=0.9, ...) also pass through
)
print(texts[0])

In [None]:
# TODO: activation addition

In [None]:
# TODO: evaluation/more examples of completions