# Demo of bypassing refusal

from https://colab.research.google.com/drive/1a-aQvKC9avdZpdyBn4jgRQFObTPy1JZw?usp=sharing#scrollTo=j7hOtw7UHXdD

home: https://gist.github.com/wassname/42aba7168bb83e278fcfea87e70fa3af

> This notebook demonstrates our method for bypassing refusal, leveraging the insight that refusal is mediated by a 1-dimensional subspace.

This has been rewritten to use baukit instead of transformerlens

> To extract the "refusal direction," we use just 32 harmful instructions from [AdvBench](https://github.com/llm-attacks/llm-attacks/blob/main/data/advbench/harmful_behaviors.csv) and 32 harmless instructions from [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca).

It will still warn you and lecture you (as this direction has not been erased), but it will follow instructions.

Only use this if you can take responsibility for your own actions and emotions while using it.

> For anyone who is enjoying increasing their knowledge of this field, check out these intros:
- A primer on the internals of transformers: https://arxiv.org/abs/2405.00208
- Machine unlearning: https://ai.stanford.edu/~kzliu/blog/unlearning
- The original post that this script is based on https://www.lesswrong.com/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction#

To understand why many people (including me) are worried about misalignment of ASI (not this small model) see this intro https://aisafetyfundamentals.com/blog/alignment-introduction/. There are [many](https://www.eleuther.ai/) [orgs](https://optimists.ai/) that are working on this who support open sourcing! We want the good ending, not the bad one, join us. 

## Setup

In [1]:
import torch
import functools, collections
import einops
import requests
import pandas as pd
from IPython.display import display, HTML
import io
import textwrap
import gc
from pathlib import Path
from baukit.nethook import get_module
from baukit import TraceDict

from datasets import load_dataset
from sklearn.model_selection import train_test_split
# from tqdm import tqdm
from torch import Tensor
from typing import List, Callable, Tuple, Dict, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from jaxtyping import Float, Int
from colorama import Fore

### Load model

In [2]:
# Disable autograd globally: this notebook only runs inference, never
# training, so skipping gradient tracking saves GPU memory.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x78a262777b50>

In [3]:
# Hugging Face model id (lower-cased before use).
MODEL_PATH = "NousResearch/Meta-Llama-3-8B-Instruct".lower()
# Not referenced elsewhere in the visible part of this notebook.
verbose = True
# Batch size handed to the generation and perplexity helpers below.
batch_size = 4

# Number of held-out instructions printed by the comparison cells.
N_INST_TEST = 32
N_INST_TRAIN = 64  # 32 how many train examples to use
# Generation length; eval_pplx also reuses this value as its max_length.
max_new_tokens = 64  # 128

In [None]:
# Left-pad batches so every prompt ends at the final sequence position; the
# pad token reuses EOS rather than adding a dedicated pad token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the weights in bfloat16 with automatic device placement and switch the
# model to eval mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
).eval()

# Device that tokenized batches and direction vectors are moved to later on.
DEVICE = model.device

model.safetensors.index.json:   0%|          | 0.00/16.3k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

In [None]:
# Here we read the output of each decoder block (the residual stream after
# that layer). The first two blocks (indices 0 and 1) are skipped.
# TODO choose the best range of layers
layers = list(range(2, len(model.model.layers)))
# Dotted module paths, used as hook names and as keys of the activation caches.
layers_to_read = [f"model.layers.{l}" for l in layers]
layers_to_read

### Benchmark

In [None]:
# Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py
# TODO replace with a strided version https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm

@torch.no_grad()
def perplexity2(predictions, model, tokenizer, batch_size: int = 16, max_length=64, add_start_token=True):
    """Return the mean per-text perplexity of ``predictions`` under ``model``.

    Each text is tokenized without special tokens, optionally prefixed with
    the tokenizer's BOS token, and scored with next-token cross-entropy.
    Padding positions are excluded through the attention mask, and the
    per-text perplexities are averaged.

    Args:
        predictions: List of strings to score.
        model: Causal LM exposing ``.device`` and returning ``.logits``.
        tokenizer: Tokenizer matching ``model``; must define a pad token.
        batch_size: Number of texts scored per forward pass.
        max_length: Truncation length in tokens (before the optional BOS);
            a falsy value disables truncation.
        add_start_token: Prepend BOS to every text before scoring.

    Returns:
        The mean of the per-text perplexities.

    Raises:
        ValueError: If ``add_start_token`` is set but the tokenizer has no
            BOS token id.
    """
    device = model.device

    assert tokenizer.pad_token is not None, "Tokenizer must have a pad token"
    if add_start_token and tokenizer.bos_token_id is None:
        raise ValueError("add_start_token=True requires a tokenizer with a BOS token")

    # Score with right padding regardless of how the tokenizer is configured.
    # With left padding the BOS prepended below would sit *before* the pad
    # run, so the first real token of every padded text would be predicted
    # from a pad position instead of from BOS.
    saved_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "right"
    try:
        encodings = tokenizer(
            predictions,
            add_special_tokens=False,
            padding=True,
            truncation=True if max_length else False,
            max_length=max_length,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)
    finally:
        # Never leak the temporary setting to the caller's tokenizer.
        tokenizer.padding_side = saved_padding_side

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    for start_index in tqdm(range(0, len(encoded_texts), batch_size)):
        end_index = start_index + batch_size  # slicing clamps at the end
        encoded_batch = encoded_texts[start_index:end_index]
        attn_mask = attn_masks[start_index:end_index]

        if add_start_token:
            n_rows = encoded_batch.size(0)
            bos_column = torch.full(
                (n_rows, 1),
                tokenizer.bos_token_id,
                dtype=encoded_batch.dtype,
                device=device,
            )
            encoded_batch = torch.cat([bos_column, encoded_batch], dim=1)
            attn_mask = torch.cat([attn_mask.new_ones((n_rows, 1)), attn_mask], dim=1)

        labels = encoded_batch

        # Gradients are already disabled by the decorator.
        out_logits = model(encoded_batch, attention_mask=attn_mask).logits

        # Position t predicts token t + 1.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        # exp(mean token loss), averaging over non-padding targets only.
        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return np.mean(ppls)

In [None]:
# Maps a model label to its mean perplexity; filled in by eval_pplx.
perplexity_results = {}

# Evaluation text: the non-empty entries among the first 1000 rows of the
# WikiText-2 (raw) test split.
input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
input_texts = [s for s in input_texts[:1000] if s!='']

In [None]:
# input_texts

In [None]:
def eval_pplx(model, tokenizer, model_name):
    """Score ``model`` on ``input_texts`` and record the result.

    The mean perplexity is stored in the module-level ``perplexity_results``
    under ``model_name``, printed, and returned. Batch size and maximum
    length come from the notebook-level ``batch_size`` and ``max_new_tokens``.
    """
    mean_ppl = perplexity2(
        input_texts,
        model,
        tokenizer,
        batch_size=batch_size,
        max_length=max_new_tokens,
    )
    perplexity_results[model_name] = mean_ppl
    print(f"mean_perplexity: {mean_ppl:2.2f} for model=`{model_name}`")
    return mean_ppl

In [None]:
# Perplexity of the unmodified model, recorded under the label 'base'.
eval_pplx(model, tokenizer, model_name='base')

### Load harmful / harmless datasets

In [None]:
def get_harmful_instructions():
    """Download the AdvBench harmful-behaviours CSV and split its prompts.

    Returns:
        ``(train, test)`` lists of strings taken from the ``goal`` column,
        split 80/20 with a fixed random seed.

    Raises:
        requests.HTTPError: If the download returns an error status.
        requests.Timeout: If the server does not answer in time.
    """
    url = "https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv"
    # Bound the wait and fail loudly on HTTP errors; otherwise an error page
    # would be handed to the CSV parser and surface as a confusing KeyError.
    response = requests.get(url, timeout=30)
    response.raise_for_status()

    dataset = pd.read_csv(io.StringIO(response.content.decode("utf-8")))
    instructions = dataset["goal"].tolist()

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test


# def harmful_instructions2():
#     hf_path = 'unalignment/toxic-dpo-v0.1'
#     dataset = load_dataset(hf_path)

#     # filter for instructions that do not have inputs
#     instructions = []
#     for i in range(len(dataset['train'])):
#         instructions.append(dataset['train'][i]['prompt'])

#     train, test = train_test_split(instructions, test_size=0.2, random_state=42)
#     return train, test


def get_harmless_instructions():
    """Load Alpaca and split its input-free instructions.

    Only rows whose ``input`` field is blank are kept, so every instruction
    is self-contained.

    Returns:
        ``(train, test)`` lists of instruction strings, split 80/20 with a
        fixed random seed.
    """
    hf_path = "tatsu-lab/alpaca"
    train_split = load_dataset(hf_path)["train"]

    # Iterate rows once instead of indexing the dataset several times per row;
    # the resulting list and its order are the same.
    instructions = [
        row["instruction"] for row in train_split if row["input"].strip() == ""
    ]

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

In [None]:
# Fetch both instruction sets; each call returns a (train, test) pair.
harmful_inst_train, harmful_inst_test = get_harmful_instructions()
# harmful_inst_train2, harmful_inst_test2 = harmful_instructions2()
harmless_inst_train, harmless_inst_test = get_harmless_instructions()

In [None]:
# Sanity check: show the first four training examples of each set.
print("Harmful instructions:")
for i in range(4):
    print(f"\t{repr(harmful_inst_train[i])}")
# print("Harmful instructions2:")
# for i in range(4):
#     print(f"\t{repr(harmful_inst_train2[i])}")
print("Harmless instructions:")
for i in range(4):
    print(f"\t{repr(harmless_inst_train[i])}")

### Tokenization utils

In [None]:
def tokenize_instructions_chat(
    tokenizer: AutoTokenizer, instructions: List[str]
) -> Int[Tensor, "batch_size seq_len"]:
    """Render each instruction as a single-turn chat prompt and tokenize it.

    Every instruction becomes one ``user`` message, formatted with the
    tokenizer's chat template including the generation prompt, and the
    resulting strings are tokenized as one padded batch.

    Returns:
        The tokenizer's batch output. NOTE(review): callers in this notebook
        use ``.to(...)``, ``.input_ids`` and ``**``-unpacking on the result,
        so it is the full encoding (ids plus attention mask), not a bare id
        tensor as the annotation suggests.
    """
    chats = [[{"role": "user", "content": instruction}] for instruction in instructions]
    prompts = [
        tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=True)
        for c in chats
    ]
    # The chat template has already written the special tokens into the text;
    # letting the tokenizer add them again would duplicate BOS.
    return tokenizer(
        prompts,
        padding=True,
        truncation=False,
        add_special_tokens=False,
        return_tensors="pt",
    )


# Bind the notebook's tokenizer so callers only pass `instructions=...`.
tokenize_instructions_fn = functools.partial(
    tokenize_instructions_chat, tokenizer=tokenizer
)

### Generation utils

In [None]:
@torch.no_grad()
def get_generations(
    instructions: List[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, "batch_size seq_len"]],
    layer_names: Optional[List[str]] = None,
    max_new_tokens: int = 64,
    batch_size: int = 4,
    edit_output: Callable[
        [Float[Tensor, "batch_size seq_len dim"], str],
        Float[Tensor, "batch_size seq_len dim"],
    ] = None,
) -> Tuple[Dict[str, Float[Tensor, "batch dim"]], List[str]]:
    """Generate completions and collect last-position activations.

    For each batch of prompts the model is run once over the prompt to record
    the output of every module in ``layer_names`` at the final position, and
    then ``generate`` is called to produce the completion.

    Args:
        instructions: Prompts to run.
        model: Causal LM to run.
        tokenizer: Tokenizer used to decode the generated ids.
        tokenize_instructions_fn: Called as ``fn(instructions=batch)``; its
            result must support ``.to(device)`` and ``**``-unpacking.
        layer_names: Dotted module names to hook. ``None`` means no hooks.
        max_new_tokens: Generation budget per prompt.
        batch_size: Prompts per batch.
        edit_output: Optional hook applied to the output of every hooked
            module, both in the recording pass and during generation.

    Returns:
        ``(activations, generations)``: ``activations`` maps each layer name
        to the stacked last-position outputs of all prompts, and
        ``generations`` holds the decoded completions without the prompt.
    """
    layer_names = [] if layer_names is None else list(layer_names)
    generations = []
    activations = collections.defaultdict(list)

    for i in tqdm(range(0, len(instructions), batch_size)):
        inputs = tokenize_instructions_fn(
            instructions=instructions[i : i + batch_size]
        ).to(model.device)

        # Record activations from the prompt pass only.
        # docs for TraceDict here: https://github.com/davidbau/baukit/blob/main/baukit/nethook.py
        with TraceDict(
            model, layers=layer_names, edit_output=edit_output,
        ) as ret:
            model(**inputs)

        for layer_name in layer_names:
            layer_output = ret[layer_name].output
            # Some modules return a tuple whose first item is the hidden
            # state; others return the tensor directly.
            if isinstance(layer_output, (tuple, list)):
                layer_output = layer_output[0]
            # Keep just the last position (the final prompt token under left
            # padding) rather than the whole sequence for every layer.
            activations[layer_name].append(layer_output[:, -1].cpu())

        # Keep the hooks installed while generating, otherwise `edit_output`
        # never influences the generated text. Outputs are not retained here,
        # so the recording above is not overwritten.
        with TraceDict(
            model, layers=layer_names, retain_output=False, edit_output=edit_output,
        ):
            generation = model.generate(**inputs, max_new_tokens=max_new_tokens)
        prompt_len = inputs.input_ids.shape[1]
        generations.extend(generation[:, prompt_len:])

    activations = {k: torch.cat(v, dim=0) for k, v in activations.items()}
    generations = tokenizer.batch_decode(generations, skip_special_tokens=True)

    return activations, generations


# Smoke test: run two batches of harmful training prompts with a very short
# generation budget and display the decoded completions.
activations, generations = get_generations(
    instructions=harmful_inst_train[: batch_size * 2],
    model=model,
    tokenizer=tokenizer,
    layer_names=layers_to_read,
    tokenize_instructions_fn=tokenize_instructions_fn,
    max_new_tokens=6,
    batch_size=batch_size,
)
# print({k: v.shape for k, v in activations.items()})
generations

## Finding the "refusal direction"

In [None]:
def clear_mem():
    """Run Python garbage collection, then release cached CUDA memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [None]:
# Run the model on harmful and harmless instructions, caching the activations
# of every layer in layers_to_read. Only one new token is generated because
# the generations are just a by-product here. Both runs use the first
# N_INST_TRAIN training prompts of their set.
harmless_cache, harmless_generation = get_generations(
    instructions=harmless_inst_train[:N_INST_TRAIN],
    model=model,
    tokenizer=tokenizer,
    layer_names=layers_to_read,
    tokenize_instructions_fn=tokenize_instructions_fn,
    max_new_tokens=1,
    batch_size=batch_size,
)
clear_mem()
harmful_cache, harmfull_generation = get_generations(
    instructions=harmful_inst_train[:N_INST_TRAIN],
    model=model,
    tokenizer=tokenizer,
    layer_names=layers_to_read,
    tokenize_instructions_fn=tokenize_instructions_fn,
    max_new_tokens=1,
    batch_size=batch_size,
)

In [None]:
harmfull_generation[:3]

In [None]:
harmless_generation[:3]

In [None]:
# harmless_cache.keys()

In [None]:
# print({k: v.norm().item() for k, v in harmless_cache.items()})
# print({k: v.shape for k, v in harmless_cache.items()})

In [None]:
# refusal_directions = {
#     ln: (harmful_cache[ln].mean(0) - harmless_cache[ln].mean(0))/ (1.+harmful_cache[ln].mean(0)+harmless_cache[ln].mean(0) )for ln in layers_to_read
# }
# # print({k: v.norm().item() for k, v in refusal_directions.items()})
# d = {k.split('.')[-1]: v.norm().item() for k, v in refusal_directions.items()}

# pd.DataFrame(d.items(), columns=["layer", "norm"]).set_index("layer").plot()

In [None]:
# Per-layer refusal direction: difference between the mean harmful and mean
# harmless activation, scaled to unit norm and moved to the model device.
refusal_directions = {}
for ln in layers_to_read:
    mean_diff = harmful_cache[ln].mean(0) - harmless_cache[ln].mean(0)
    unit_diff = mean_diff / mean_diff.norm()
    refusal_directions[ln] = unit_diff.to(DEVICE)
# print({k:v.shape for k,v in refusal_directions.items()})

In [None]:
# Map every module that will be edited (attention output projection and MLP
# down projection) to the decoder layer whose direction it should use.
# The first read layer is left out of the edit set.
read2edit_layer_map = {
    f"model.layers.{l}.{submodule}": f"model.layers.{l}"
    for submodule in ("self_attn.o_proj", "mlp.down_proj")
    for l in layers[1:]
}
# read2edit_layer_map["model.embed_tokens"] = layers_to_read[0]
layers_to_edit = list(read2edit_layer_map)
read2edit_layer_map

In [None]:
# clean up memory
# del harmful_cache, harmless_cache
clear_mem()

## Ablate "refusal direction" via inference-time intervention

Given a "refusal direction" $\widehat{r} \in \mathbb{R}^{d_{\text{model}}}$ with unit norm, we can ablate this direction from the model's activations $a_{l}$:
$${a}_{l}' \leftarrow a_l - (a_l \cdot \widehat{r}) \widehat{r}$$

By performing this ablation on all intermediate activations, we enforce that the model can never express this direction (or "feature").

In [None]:
@torch.no_grad()
def direction_ablation_hook(
    output: Float[Tensor, "... d_act"],
    layer: str,
    inputs,
    directions: Dict[str, Float[Tensor, "d_act"]],
):
    """Remove one direction from a hooked module's output.

    ``layer`` is the name of the edited module; ``read2edit_layer_map``
    translates it to the key under which its direction is stored in
    ``directions``. The component of ``output`` along that direction is
    subtracted. ``inputs`` is accepted but not used.
    """
    direction_key = read2edit_layer_map[layer]
    unit_dir = directions[direction_key].to(output.device)
    # Coefficient of `output` along the direction, kept as a trailing axis of
    # size one so it broadcasts against the direction vector.
    coeff = einops.einsum(
        output, unit_dir.view(-1, 1), "... d_act, d_act single -> ... single"
    )
    return output - coeff * unit_dir


# Hook with the per-layer directions already bound.
edit_output = functools.partial(direction_ablation_hook, directions=refusal_directions)

In [None]:
# Generate with and without the ablation hook on the same held-out harmful
# prompts. Only the first N_INST_TEST completions are displayed afterwards,
# so that is how many are generated (the slice previously used N_INST_TRAIN).
_, intervention_generations = get_generations(
    instructions=harmful_inst_test[:N_INST_TEST],
    model=model,
    tokenizer=tokenizer,
    layer_names=layers_to_edit,
    tokenize_instructions_fn=tokenize_instructions_fn,
    max_new_tokens=max_new_tokens,
    batch_size=batch_size,
    edit_output=edit_output,
)
clear_mem()
_, baseline_generations = get_generations(
    instructions=harmful_inst_test[:N_INST_TEST],
    model=model,
    tokenizer=tokenizer,
    tokenize_instructions_fn=tokenize_instructions_fn,
    max_new_tokens=max_new_tokens,
    batch_size=batch_size,
)

In [None]:
# Print each held-out instruction followed by its colour-coded completions.
for i in range(N_INST_TEST):
    print(f"INSTRUCTION {i}: {repr(harmful_inst_test[i])}")
    for colour, title, completions in (
        (Fore.GREEN, "BASELINE COMPLETION:", baseline_generations),
        (Fore.RED, "INTERVENTION COMPLETION:", intervention_generations),
    ):
        print(colour + title)
        wrapped = textwrap.fill(
            repr(completions[i]),
            width=100,
            initial_indent="\t",
            subsequent_indent="\t",
        )
        print(wrapped)
    print(Fore.RESET)

## Orthogonalize weights w.r.t. "refusal direction"

We can implement the intervention equivalently by directly orthogonalizing the weight matrices that write to the residual stream with respect to the refusal direction $\widehat{r}$:
$$W_{\text{out}}' \leftarrow W_{\text{out}} - \widehat{r}\widehat{r}^{\mathsf{T}} W_{\text{out}}$$

By orthogonalizing these weight matrices, we enforce that the model is unable to write direction $r$ to the residual stream at all!

In [None]:
def get_orthogonalized_matrix(
    matrix: Float[Tensor, "... d_model"], vec: Float[Tensor, "d_model"]
) -> Float[Tensor, "... d_model"]:
    """Subtract from ``matrix`` its component along ``vec`` on the last axis.

    For every slice along the last axis, the dot product with ``vec`` is
    computed and that multiple of ``vec`` is removed.
    """
    # Dot product of each last-axis slice with `vec`, kept as a size-one
    # trailing axis so it broadcasts against `vec`.
    coeff = einops.einsum(
        matrix, vec.view(-1, 1), "... d_model, d_model single -> ... single"
    )
    along_vec = coeff * vec
    return matrix - along_vec

In [None]:
# Orthogonalize every module that writes to the residual stream.
#
# torch.nn.Linear stores its weight as (out_features, in_features) and
# computes y = x @ W.T, so the residual-stream (output) dimension is the
# FIRST axis of the weight. get_orthogonalized_matrix projects along the LAST
# axis, so the weight is transposed before projecting and transposed back
# afterwards; this yields W' = W - r r^T W. The attention output projection
# is square, so skipping the transpose raises no shape error but removes the
# direction from the wrong (input) side.
for key in layers_to_edit:
    m = get_module(model, key)
    ln = read2edit_layer_map[key]
    refusal_dir = refusal_directions[ln].to(m.weight.device)
    # .contiguous(): the transposed result is a strided view; store a densely
    # laid-out tensor as the new weight.
    m.weight.data = get_orthogonalized_matrix(m.weight.T, refusal_dir).T.contiguous()

In [None]:
clear_mem()
# No hooks here: the edit now lives in the weights. Only the first
# N_INST_TEST completions are displayed afterwards, so that is how many are
# generated (the slice previously used N_INST_TRAIN).
_, orthogonalized_generations = get_generations(
    instructions=harmful_inst_test[:N_INST_TEST],
    model=model,
    tokenizer=tokenizer,
    tokenize_instructions_fn=tokenize_instructions_fn,
    max_new_tokens=max_new_tokens,
    batch_size=batch_size,
)

In [None]:
# Print each held-out instruction followed by its colour-coded completions
# from the baseline, hooked and weight-edited runs.
for i in range(N_INST_TEST):
    print(f"INSTRUCTION {i}: {repr(harmful_inst_test[i])}")
    for colour, title, completions in (
        (Fore.GREEN, "BASELINE COMPLETION:", baseline_generations),
        (Fore.RED, "INTERVENTION COMPLETION:", intervention_generations),
        (Fore.MAGENTA, "ORTHOGONALIZED COMPLETION:", orthogonalized_generations),
    ):
        print(colour + title)
        wrapped = textwrap.fill(
            repr(completions[i]),
            width=100,
            initial_indent="\t",
            subsequent_indent="\t",
        )
        print(wrapped)
    print(Fore.RESET)

In [None]:
# Perplexity after the weight edit, recorded under the label "orthogonalized".
eval_pplx(model, tokenizer, model_name="orthogonalized")

In [None]:


# Tabulate the recorded perplexities, one row per model label.
df_ppx = pd.DataFrame(perplexity_results.items(), columns=["model", "perplexity"]).set_index("model")
# df_ppx.plot(kind="bar")
# Make sure the target directory exists before writing.
ppx_output_dir = Path("../outputs")
ppx_output_dir.mkdir(parents=True, exist_ok=True)
# Write the index too: the model labels live in it, so index=False would
# leave a CSV of unlabeled perplexity values.
df_ppx.to_csv(ppx_output_dir / "perplexity_results.csv")
display(df_ppx)

## Save

The transformer lens library does not have a save feature :(, so as a hack we are going to load the transformer version, and then apply the patch to it.

In [None]:
# 1 / 0
# Save the edited model and its tokenizer side by side.
# Use .name rather than .stem: .stem drops everything after the last dot, which
# would truncate model ids that contain a version such as "3.1".
model_name = Path(MODEL_PATH).name.lower()
f = f"../outputs/{model_name}-extra_helpfull2"
print(f"saving to {f}")
# Create the output directory explicitly instead of relying on the savers.
Path(f).mkdir(parents=True, exist_ok=True)
model.save_pretrained(f)
tokenizer.save_pretrained(f)

# TODO

- [ ] measure perplexity and score before and after to see if it degrades