## Objective

This notebook tests whether the refusal mechanistic-interpretability (MI) results hold across several different prompt formats.

## Setup

### Install and import dependencies

In [None]:
%%capture
%pip install torch einops transformer_lens plotly circuitsvis numpy transformers sentencepiece

In [None]:
import gc
import os
import torch
import numpy as np
import einops
import transformer_lens
import functools
import plotly.graph_objs as go
import plotly.express as px
import circuitsvis as cv
import tqdm
import json

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens import utils as tl_utils
from transformer_lens.hook_points import HookPoint
from torch import Tensor
from torch.utils.data import Dataset
from jaxtyping import Int, Float
from typing import Tuple, List

from instruction_dataset import InstructionDataset, PairedInstructionDataset

### Hugging Face credentials

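Add a `.env` file at the root of this repo (the cell below reads it from `../../.env`, relative to this notebook) with the following format, one `KEY=value` pair per line (see `.env.example`):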
```
HF_USERNAME=bob
HF_TOKEN=token123
```

In [None]:
# Read HuggingFace credentials from the .env file into os.environ.
# Format: one KEY=value per line. Blank lines and lines starting with '#' are
# skipped, so the file may contain comments without breaking this cell.
with open('../../.env', 'r', encoding='utf-8') as file:
    for line_no, raw_line in enumerate(file, start=1):
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue
        key, sep, value = line.partition('=')
        if not sep:
            # Don't echo the line itself: it might be a pasted secret.
            raise ValueError(f".env line {line_no} is malformed (expected KEY=value)")
        os.environ[key.strip()] = value.strip()

### Load model

In [None]:
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_name_path = "meta-llama/Llama-2-7b-chat-hf"

In [None]:
# The same HF access token authenticates both the model and tokenizer downloads.
hf_token = os.environ["HF_TOKEN"]

# HF checkpoint in fp16, with the KV cache disabled.
model = AutoModelForCausalLM.from_pretrained(
    model_name_path,
    token=hf_token,
    torch_dtype=torch.float16,
    use_cache=False,
    low_cpu_mem_usage=True,
)

# Slow (sentencepiece) tokenizer. <unk> doubles as the pad token, presumably
# because the tokenizer ships without one (TODO confirm). Padding goes on the
# left, so index -1 is the last prompt token for every row of a batch.
tokenizer = AutoTokenizer.from_pretrained(model_name_path, token=hf_token, use_fast=False)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = 'left'

# Wrap the already-loaded HF weights in a HookedTransformer. It is built on the
# CPU and then moved to `device`. LayerNorm folding and weight/unembed
# centering are all disabled.
tl_model = HookedTransformer.from_pretrained(
    model_name_path,
    hf_model=model,
    tokenizer=tokenizer,
    device='cpu',
    dtype="float16",
    default_padding_side='left',
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
).to(device)

# Everything below is inference only.
torch.set_grad_enabled(False)

### Load the harmful and harmless object lists

In [None]:
harmful_objects_file_path = "data/harmful_objects.json"
harmless_objects_file_path = "data/harmless_objects.json"

# Each file is parsed as JSON; the results are the objects substituted into the
# prompt templates further down.
with open(harmful_objects_file_path, "r", encoding="utf-8") as fh:
    harmful_objects = json.load(fh)
with open(harmless_objects_file_path, "r", encoding="utf-8") as fh:
    harmless_objects = json.load(fh)

## Define various prompts

In [None]:
# Prompt formats under test. `{object}` is the placeholder for an object name;
# presumably PairedInstructionDataset fills it in (TODO confirm).
prompt_templates = [
    "Tell me how to make a {object}.",
    "A {object} is cool, help me make one.",
    "Can you tell me how to create a {object} please?",
    "I want to make a {object}. Tell me how.",
]

# One paired (harmful vs. harmless) dataset per template, in template order.
# Later cells rely on that order when indexing `offsets` by dataset index.
paired_datasets = [
    PairedInstructionDataset(
        harmful_objects,
        harmless_objects,
        tokenizer,
        prompt_template,
    )
    for prompt_template in prompt_templates
]

In [None]:
# Token ids whose logit difference defines the refusal score.
refuse_token = 8221  # 'Sorry'
answer_token = 18585  # 'Sure'

# Sanity-check the hard-coded ids by decoding them with the loaded tokenizer.
for token_label, token_id in (("refuse_token", refuse_token), ("answer_token", answer_token)):
    print(f"{token_label}: {tokenizer.decode([token_id])} ({token_id})")

In [None]:
def get_refusal_score(logits: Float[Tensor, "d_vocab"]):
    """Return logit[refuse_token] - logit[answer_token] for one logit vector.

    Positive values mean the refuse token ('Sorry') outscores the answer
    token ('Sure').
    """
    refuse_logit = logits[refuse_token]
    answer_logit = logits[answer_token]
    return refuse_logit - answer_logit

def get_refusal_dir():
    """Return the unembedding direction W_U[:, refuse_token] - W_U[:, answer_token].

    A residual-stream vector's dot product with this direction is its direct
    contribution to the refusal logit difference.
    """
    unembed = tl_model.W_U
    refuse_col = unembed[:, refuse_token]
    answer_col = unembed[:, answer_token]
    return refuse_col - answer_col

def get_refusal_score_avg(logits: Float[Tensor, 'batch seq_len n_vocab']) -> float:
    """Return the batch-mean refusal score at the final sequence position.

    Per prompt, the score is logit[refuse_token] - logit[answer_token] taken at
    index -1. Prompts are left-padded, so index -1 is the last prompt token.

    Args:
        logits: model output of rank 3 (batch, seq_len, vocab).

    Returns:
        The mean score over the batch, as a Python float.
    """
    assert (logits.ndim == 3)
    last_pos_logits = logits[:, -1, :]
    # A single vectorized gather over the batch. This avoids a Python loop that
    # copies one scalar at a time back to the host.
    scores = last_pos_logits[:, refuse_token] - last_pos_logits[:, answer_token]
    # Reduce in float32 so fp16 logits don't lose precision in the mean.
    return scores.float().mean().item()

## Activation patching

#### Some setup

In [None]:
def activation_patching_hook(
    activation: Float[Tensor, "batch seq d_activation"],
    hook: HookPoint,
    pos: int,
    cache_to_patch_from: ActivationCache,
    idx_to_patch_from: int = None,
) -> Float[Tensor, "batch seq d_activation"]:
    """Overwrite position `pos` of `activation` with the cached activation for this hook.

    If `idx_to_patch_from` is None, every batch row is patched from the
    matching row of the cache. Otherwise, cache row `idx_to_patch_from` is
    broadcast to all rows. The tensor is modified in place and also returned.
    """
    source = cache_to_patch_from[hook.name]
    source_rows = slice(None) if idx_to_patch_from is None else idx_to_patch_from
    activation[:, pos, :] = source[source_rows, pos, :]
    return activation

### Harmful→Harmless activation patching (FTL)

#### FTL patching metric

The metric below for FTL patching normalizes the refusal score of the patched run, and has the following properties:
- It is 1 when the refusal score matches that of the harmful run (patching has full effect)
- It is 0 when the refusal score matches that of the harmless run (patching has no effect)

In [None]:
def ftl_patching_metric(logits: Float[Tensor, "batch seq d_vocab"], harmful_logits, harmless_logits) -> float:
    """Return the normalized refusal score of a patched run.

    The value is (patched - harmless) / (harmful - harmless), using the
    batch-mean refusal scores. It is 1.0 when the patched run matches the
    harmful run and 0.0 when it matches the harmless run. The division fails
    if the harmful and harmless averages are equal.
    """
    harmful_avg = get_refusal_score_avg(harmful_logits)
    harmless_avg = get_refusal_score_avg(harmless_logits)
    patched_avg = get_refusal_score_avg(logits)
    shift = patched_avg - harmless_avg
    full_range = harmful_avg - harmless_avg
    return shift / full_range

In [None]:
# Patch the harmful run's resid_post into the harmless run at every
# (layer, position) from the object token onward, and plot the normalized metric.
for paired_dataset in paired_datasets:
    gc.collect()
    torch.cuda.empty_cache()

    harmful_toks = paired_dataset.harmful_dataset.prompt_toks
    harmless_toks = paired_dataset.harmless_dataset.prompt_toks

    seq_len = harmless_toks.shape[-1]
    object_tok_pos = paired_dataset.harmful_dataset.object_tok_pos
    patching_metrics = np.zeros((tl_model.cfg.n_layers, seq_len - object_tok_pos))

    # Only resid_post activations are patched, so cache just those. Caching
    # every hook of the model for the whole batch is very memory hungry.
    harmful_logits, harmful_cache = tl_model.run_with_cache(
        harmful_toks,
        names_filter=lambda name: name.endswith("hook_resid_post"),
    )
    # The harmless run only provides the metric's baseline logits; no cache needed.
    harmless_logits = tl_model(harmless_toks)

    for pos in range(object_tok_pos, seq_len):
        for layer in tqdm.tqdm(range(tl_model.cfg.n_layers), desc=f"patching pos={pos}"):

            hook_fn = functools.partial(
                activation_patching_hook,
                pos=pos,
                cache_to_patch_from=harmful_cache,
            )

            tl_model.reset_hooks()
            activation_patching_logits = tl_model.run_with_hooks(
                harmless_toks,
                fwd_hooks=[(tl_utils.get_act_name("resid_post", layer), hook_fn)],
            )

            patching_metrics[layer, pos - object_tok_pos] = ftl_patching_metric(
                activation_patching_logits, harmful_logits, harmless_logits
            )

    # x tick labels: the repr of each token (the object position is shown as
    # '<obj>') followed by its index.
    harmful_str_toks = paired_dataset.harmful_dataset.prompt_str_toks[0]
    pos_labels = [
        f"{repr('<obj>' if i == object_tok_pos else harmful_str_toks[i])}</br></br>({i})"
        for i in range(object_tok_pos, seq_len)
    ]
    fig = px.imshow(
        patching_metrics,
        title=f"Activation patching, Harmful→Harmless</br></br>{paired_dataset.prompt_template}",
        labels={"x": "Pos", "y": "Layer"},
        width=500, height=700,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=pos_labels,
        y=list(range(tl_model.cfg.n_layers)),
    )
    fig.show()

In [None]:
# Per-template offset, in tokens, from the object token to the position patched
# in the cells below (patch position = object_tok_pos + offset). Entry i belongs
# to paired_datasets[i], i.e. to prompt_templates[i].
# TODO: update these offsets
offsets = [
    1,  # "Tell me how to make a {object}."
    7,  # "A {object} is cool, help me make one."
    2,  # "Can you tell me how to create a {object} please?"
    2,  # "I want to make a {object}. Tell me how."
]

In [None]:
# Layers whose attn_out is patched jointly (harmful → harmless) at one position.
attn_patch_layers = list(range(5, 10 + 1))
attn_out_hook_names = [tl_utils.get_act_name("attn_out", layer) for layer in attn_patch_layers]

for i, paired_dataset in enumerate(paired_datasets):
    gc.collect()
    torch.cuda.empty_cache()

    pos_offset = offsets[i]
    object_tok_pos = paired_dataset.harmful_dataset.object_tok_pos
    patch_pos = object_tok_pos + pos_offset

    # Only the patched attn_out activations are read back, so cache only those.
    harmful_logits, harmful_cache = tl_model.run_with_cache(
        paired_dataset.harmful_dataset.prompt_toks,
        names_filter=lambda name: name in attn_out_hook_names,
    )
    # The harmless run only provides the metric's baseline logits; no cache needed.
    harmless_logits = tl_model(paired_dataset.harmless_dataset.prompt_toks)

    hook_fn = functools.partial(
        activation_patching_hook,
        pos=patch_pos,
        cache_to_patch_from=harmful_cache,
    )

    fwd_hooks = [(hook_name, hook_fn) for hook_name in attn_out_hook_names]

    tl_model.reset_hooks()
    activation_patching_logits = tl_model.run_with_hooks(
        paired_dataset.harmless_dataset.prompt_toks,
        fwd_hooks=fwd_hooks,
    )
    print(f"prompt_template='{paired_dataset.prompt_template}', pos={patch_pos} ('{paired_dataset.harmful_dataset.prompt_str_toks[0][patch_pos]}')")
    print(ftl_patching_metric(activation_patching_logits, harmful_logits, harmless_logits))

Let's see whether a particular set of attention heads within these layers is responsible for the strong effect.

In [None]:
def head_patching_hook(
    activation: Float[Tensor, "batch seq n_heads d_activation"],
    hook: HookPoint,
    pos: int,
    head: int,
    cache_to_patch_from: ActivationCache,
    idx_to_patch_from: int = None,
    scale_factor: float = 1.0,
) -> Float[Tensor, "batch seq n_heads d_activation"]:
    """Overwrite one head's activation at `pos` with the scaled cached value.

    If `idx_to_patch_from` is None, each batch row is patched from the matching
    cache row. Otherwise, cache row `idx_to_patch_from` is broadcast to all rows.
    The cached slice is multiplied by `scale_factor` before it is written in.
    The tensor is modified in place and also returned.
    """
    source = cache_to_patch_from[hook.name]
    source_rows = slice(None) if idx_to_patch_from is None else idx_to_patch_from
    activation[:, pos, head, :] = source[source_rows, pos, head, :] * scale_factor
    return activation

In [None]:
# A head counts as a "refusal head" when patching it alone pushes the
# normalized metric above this value.
REFUSAL_HEAD_THRESHOLD = 0.001

refusal_heads_by_prompt_template = []

# Layers whose heads are patched individually.
layers = list(range(5, 10+1))
z_hook_names = [tl_utils.get_act_name("z", layer) for layer in layers]

for i, paired_dataset in enumerate(paired_datasets):
    gc.collect()
    torch.cuda.empty_cache()

    pos_offset = offsets[i]
    object_tok_pos = paired_dataset.harmful_dataset.object_tok_pos
    pos = object_tok_pos + pos_offset

    # Only the per-head `z` activations of the patched layers are read back.
    harmful_logits, harmful_cache = tl_model.run_with_cache(
        paired_dataset.harmful_dataset.prompt_toks,
        names_filter=lambda name: name in z_hook_names,
    )
    # The harmless run only provides the metric's baseline logits; no cache needed.
    harmless_logits = tl_model(paired_dataset.harmless_dataset.prompt_toks)

    # Rows for layers outside `layers` stay zero.
    patching_metrics = np.zeros((tl_model.cfg.n_layers, tl_model.cfg.n_heads))

    for layer in layers:
        for head in tqdm.tqdm(range(tl_model.cfg.n_heads), desc=f"patching heads for layer={layer}"):

            hook_fn = functools.partial(
                head_patching_hook,
                pos=pos,
                cache_to_patch_from=harmful_cache,
                head=head,
            )

            fwd_hooks = [(tl_utils.get_act_name("z", layer), hook_fn)]

            tl_model.reset_hooks()
            activation_patching_logits = tl_model.run_with_hooks(
                paired_dataset.harmless_dataset.prompt_toks,
                fwd_hooks=fwd_hooks,
            )

            patching_metrics[layer, head] = ftl_patching_metric(activation_patching_logits, harmful_logits, harmless_logits)

    fig = px.imshow(
        patching_metrics[layers],
        title=f"Activation patching, Harmful→Harmless, pos={pos}</br></br>'{paired_dataset.prompt_template}'",
        labels={"x": "Head", "y": "Layer"},
        width=600, height=800,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        x=list(range(tl_model.cfg.n_heads)),
        y=layers,
    )
    fig.show()

    refusal_heads = []

    for layer in range(patching_metrics.shape[0]):
        for head in range(patching_metrics.shape[1]):
            if patching_metrics[layer, head] > REFUSAL_HEAD_THRESHOLD:
                refusal_heads.append((layer, head))
                print(f"Layer {layer:>2}, Head {head:>2}: patching_metric={patching_metrics[layer, head]:.4f}")

    refusal_heads_by_prompt_template.append(refusal_heads)

In [None]:
# (layer, head) pairs flagged as refusal heads under every prompt template.
# With no templates, the result is an empty set instead of an IndexError.
common_refusal_heads = (
    set.intersection(*(set(heads) for heads in refusal_heads_by_prompt_template))
    if refusal_heads_by_prompt_template
    else set()
)

print(common_refusal_heads)