### Install and import dependencies

In [None]:
%%capture
%pip install transformers torch einops transformer_lens plotly circuitsvis numpy

In [None]:
import gc
import os
import torch
import numpy as np
import einops
import transformer_lens
import functools


from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import ActivationCache
from transformer_lens import utils as tl_utils
from transformer_lens.hook_points import HookPoint
from torch import Tensor
from jaxtyping import Int, Float
from typing import Tuple, List

### Hugging Face credentials

Add a `.env` file at the root of this repo with the following format (see `.env.example`):
```
HF_USERNAME=bob
HF_TOKEN=token123
```

In [None]:
# read HuggingFace credentials from .env file
# Every meaningful line is expected to look like KEY=VALUE. Only the first '='
# splits the line, so the value itself may contain '=' characters.
with open('../../.env', 'r') as file:
  for line in file:
    line = line.strip()
    # Skip blank lines, '#' comments and anything that is not KEY=VALUE;
    # unpacking such a line would otherwise raise ValueError.
    if not line or line.startswith('#') or '=' not in line:
      continue
    key, value = line.split('=', 1)
    # Tolerate "KEY = value" style spacing around the separator.
    os.environ[key.strip()] = value.strip()

### Load model

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub identifier passed to every `from_pretrained` call below.
model_name_path = "meta-llama/Llama-2-7b-chat-hf"

In [None]:
# Load the Hugging Face model, authenticating with HF_TOKEN from the environment.
# NOTE(review): `use_auth_token` is deprecated in recent transformers releases in
# favour of `token` — confirm against the installed transformers version.
model = AutoModelForCausalLM.from_pretrained(
    model_name_path,
    use_auth_token=os.environ["HF_TOKEN"],
    low_cpu_mem_usage=True,
    use_cache=False,
)

# `use_fast=False` requests the slow (Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_path,
    use_auth_token=os.environ["HF_TOKEN"],
    use_fast=False
)

# Reuse the unknown token as the padding token and pad on the left, so that the
# final prompt token sits at the last position of every row in a padded batch
# (the cells below read logits at position -1).
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = 'left'

In [None]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Wrap the already-loaded Hugging Face model (`hf_model=model`) in a
# HookedTransformer so activations can be cached and patched through hooks.
# The wrapper is built on the CPU and only then moved to `device`.
# The weight-processing options (fold_ln, center_writing_weights,
# center_unembed) are all switched off.
tl_model = HookedTransformer.from_pretrained(
    model_name_path,
    hf_model=model,
    device='cpu',
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer
).to(device)

In [None]:
# Turn off gradient tracking globally; the cells in this section only run forward passes.
torch.set_grad_enabled(False)

### Define utils to help with prompting and generation

In [None]:
def instruction_to_prompt(instruction: str, system_prompt: str = "", model_output: str = "") -> str:
    """
    Converts an instruction to a prompt string structured for Llama2-chat.

    The resulting layout is:
        [INST] {optional <<SYS>> block}{instruction} [/INST] {model_output}

    Note that, unless model_output is supplied, the prompt will (intentionally) end with a space.
    See details of Llama2-chat prompt structure here: https://huggingface.co/blog/llama2#how-to-prompt-llama-2

    Args:
        instruction: The user message; surrounding whitespace is stripped.
        system_prompt: Optional system message; when non-empty it is wrapped in
            <<SYS>> ... <</SYS>> and placed before the instruction.
        model_output: Optional (partial) assistant reply appended after [/INST];
            surrounding whitespace is stripped.

    Returns:
        The formatted prompt string (without BOS/EOS markers).
    """

    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    if len(system_prompt) == 0:
        dialog_content = instruction.strip()
    else:
        dialog_content = B_SYS + system_prompt.strip() + E_SYS + instruction.strip()

    # The reference chat format separates the user message from the closing
    # [/INST] tag with a single space; omitting it changes the tokenization
    # relative to the format described at the link above.
    prompt = f"{B_INST} {dialog_content} {E_INST} {model_output.strip()}"
    return prompt

def instruction_to_prompt_toks(tokenizer, instruction, system_prompt="", model_output="") -> List[int]:
    """
    Builds the Llama2-chat prompt for `instruction` and encodes it with `tokenizer`.

    `tokenizer.encode` is called without `return_tensors`, so for a Hugging Face
    tokenizer the result is a plain Python list of token ids rather than a
    tensor; the return annotation reflects that.

    Args:
        tokenizer: Object exposing an `encode(str)` method.
        instruction: The user message.
        system_prompt: Optional system message forwarded to `instruction_to_prompt`.
        model_output: Optional assistant reply forwarded to `instruction_to_prompt`.

    Returns:
        Whatever `tokenizer.encode` returns for the formatted prompt.
    """
    prompt = instruction_to_prompt(instruction, system_prompt, model_output)
    return tokenizer.encode(prompt)

def tokenize_and_batch(instructions, system_prompt="", model_outputs=None, batch_size=16):
    """
    Formats every instruction as a Llama2-chat prompt and tokenizes them in batches.

    Args:
        instructions: Sequence of user messages.
        system_prompt: Optional system message applied to every prompt.
        model_outputs: Optional sequence of assistant replies, one per
            instruction; when given (and non-empty) its length must match
            `instructions`.
        batch_size: Maximum number of prompts per batch.

    Returns:
        A list with one `input_ids` tensor per batch, produced by the global
        `tokenizer` with padding enabled and truncation disabled.
    """
    if model_outputs:
        assert(len(instructions) == len(model_outputs))
        pairs = zip(instructions, model_outputs)
        prompts = [instruction_to_prompt(text, system_prompt, reply) for text, reply in pairs]
    else:
        prompts = [instruction_to_prompt(text, system_prompt) for text in instructions]

    # Ceiling division: number of chunks needed to cover every instruction.
    n_chunks = -(-len(instructions) // batch_size)

    def _encode(chunk):
        # Pad to the longest prompt in the chunk and keep only the token ids.
        encoded = tokenizer(chunk, padding=True, truncation=False, return_tensors="pt")
        return encoded.input_ids

    return [
        _encode(prompts[k * batch_size:min(len(prompts), (k + 1) * batch_size)])
        for k in range(n_chunks)
    ]

def generate_from_instructions(instructions, max_new_tokens=64, temperature=0.0, batch_size=16):
    """
    Generates a completion for every instruction and prints it.

    The instructions are formatted and batched by `tokenize_and_batch`, each
    batch is moved to `device` and passed to `tl_model.generate`, and the
    decoded sequences are printed with a running generation index.
    Nothing is returned.
    """
    token_batches = tokenize_and_batch(instructions, batch_size=batch_size)

    for batch_idx, batch in enumerate(token_batches):
        output_ids = tl_model.generate(
            batch.to(device),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        decoded = tokenizer.batch_decode(output_ids)
        # Start counting at the global index of this batch's first prompt.
        for generation_idx, answer in enumerate(decoded, start=batch_idx * batch_size):
            print(f"\nGeneration #{generation_idx}:\n{answer}")

In [None]:
# Sanity check: two harmless and two harmful instructions, generated in batches of two.
generate_from_instructions(
    [
        "Tell me how to make a pie.",
        "Tell me how to make a really big pie.",
        "Tell me how to make a bomb.",
        "Tell me how to make a really big bomb.",
    ],
    batch_size=2,
    max_new_tokens=128
)

### Activation patching

#### Activation patching - text generation

In [None]:
# `tokenize_and_batch` returns a list of batches; with a single instruction each
# list holds exactly one batch, which the cells below access as `[0]`.
harmful_prompt_toks  = tokenize_and_batch(["Tell me how to make a bomb."])
harmless_prompt_toks = tokenize_and_batch(["Tell me how to make a pie."])

In [None]:
# Run both prompts once and keep the activation caches; `harmless_cache` is the
# source that the patching hooks below copy activations from.
harmful_logits, harmful_cache = tl_model.run_with_cache(harmful_prompt_toks[0])
harmless_logits, harmless_cache = tl_model.run_with_cache(harmless_prompt_toks[0])

In [None]:
# Find the sequence position of the object word in each prompt
# (`[0][0]` selects the first batch, then the first row of that batch).
bomb_tok_pos = tl_model.to_str_tokens(harmful_prompt_toks[0][0]).index("bomb")
pie_tok_pos  = tl_model.to_str_tokens(harmless_prompt_toks[0][0]).index("pie")
# Patching copies activations at one position index, so the object must sit at
# the same position in both prompts.
assert(bomb_tok_pos == pie_tok_pos)

object_tok_pos = bomb_tok_pos
print(f"Position of object token (e.g. 'bomb' or 'pie') is: {object_tok_pos}")

In [None]:
def activation_patching_hook(
    activation: Float[Tensor, "batch seq d_activation"],
    hook: HookPoint,
    pos: int,
    cache_to_patch_from: ActivationCache,
) -> Float[Tensor, "batch seq d_activation"]:
    """
    Forward hook that overwrites the activation at one sequence position.

    For every batch row, the slice at position `pos` is replaced (in place) by
    the slice at the same position taken from row 0 of the cached activation
    stored under this hook's name in `cache_to_patch_from`.

    Returns:
        The same `activation` tensor, after modification.
    """
    # Look up the cached activation for this hook point and take the donor slice.
    donor_slice = cache_to_patch_from[hook.name][0, pos, :]
    activation[:, pos, :] = donor_slice
    return activation

In [None]:
def generate_with_hooks(toks: Int[Tensor, "batch_size seq_len"], max_tokens_generated=64, fwd_hooks=None, include_prompt=False) -> str:
    """
    Greedily generates a continuation of `toks` while forward hooks are active.

    Each step re-runs the whole sequence generated so far through
    `tl_model.run_with_hooks`, so the hooks are applied on every step.
    Generation stops after `max_tokens_generated` tokens, or earlier once the
    tokenizer's EOS token has been produced (the EOS token itself is kept in
    the output).

    Args:
        toks: Prompt token ids; the batch dimension must be 1.
        max_tokens_generated: Upper bound on the number of new tokens.
        fwd_hooks: List of `(hook_name, hook_fn)` pairs passed to
            `run_with_hooks`; `None` (the default) means no hooks.
        include_prompt: When True the decoded prompt is included in the result.

    Returns:
        The decoded generation (optionally preceded by the prompt).

    Raises:
        AssertionError: If `toks` has a batch size other than 1.
    """
    assert toks.shape[0] == 1, "batch size must be 1"

    # A `None` default avoids sharing one mutable list across calls.
    if fwd_hooks is None:
        fwd_hooks = []

    prompt_len = toks.shape[1]
    all_toks = torch.zeros((1, prompt_len + max_tokens_generated), dtype=torch.long, device=device)
    all_toks[:, :prompt_len] = toks

    eos_token_id = tokenizer.eos_token_id
    end = prompt_len  # exclusive end of the valid (prompt + generated) tokens

    for _ in range(max_tokens_generated):
        logits = tl_model.run_with_hooks(
            all_toks[:, :end],
            return_type="logits",
            fwd_hooks=fwd_hooks,
        )[0, -1] # get the first element in the batch, and the last logits tensor in the sequence

        # greedy sampling (temperature=0)
        next_token = logits.argmax()
        all_toks[0, end] = next_token
        end += 1

        # Anything produced after EOS is not part of the answer; stop here so
        # it does not end up in the decoded string.
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break

    start = 0 if include_prompt else prompt_len
    return tokenizer.decode(all_toks[0, start:end])

In [None]:
print(f"Patching harmful→harmless (e.g. bomb→pie)\n")

# For each layer in turn: run the harmful prompt, but overwrite the residual
# stream output of that layer at the object position with the activation
# cached from the harmless prompt, then generate greedily.
for layer in range(tl_model.cfg.n_layers):

    print(f"\nLAYER={layer}, POS={object_tok_pos}")

    # Bind the position and the donor cache; the hook point is chosen below.
    hook_fn = functools.partial(
        activation_patching_hook,
        pos=object_tok_pos,
        cache_to_patch_from=harmless_cache,
    )

    activation_patching_generation = generate_with_hooks(
        harmful_prompt_toks[0],
        max_tokens_generated=64,
        fwd_hooks=[(tl_utils.get_act_name("resid_post", layer), hook_fn)],
        include_prompt=False
    )

    # repr() keeps newlines visible as "\n" so each generation stays on one line.
    print(repr(activation_patching_generation))

Observations:
- Layers 0-5: **answer** about how to make a **pie**
- Layers 6-14: **refusal** about how to make a **pie**
- Layers 15-31: **refusal** about how to make a **bomb**

It seems like a change from answer→refusal takes place around layer 6, and a change from pie→bomb takes place around layer 15.

#### Activation patching - top logits 

In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from torch import Tensor

def plot_topk_tokens(logits: Float[Tensor, "d_vocab"], tokenizer: AutoTokenizer, topk=5, title_suffix=""):
    """
    Shows a horizontal bar chart of the `topk` largest logits.

    Args:
        logits: 1-D tensor of logits over the vocabulary.
        tokenizer: Tokenizer used to decode each token id into its label.
        topk: Number of highest-scoring tokens to display.
        title_suffix: Text appended to the figure title.

    Raises:
        AssertionError: If `logits` is not one-dimensional.
    """
    assert (logits.ndim == 1), "only supports logits.ndim=1"

    top_values, top_ids = logits.topk(topk)

    # Decode every id on its own so each bar gets its own label.
    token_labels = []
    for token_id in top_ids:
        token_labels.append(tokenizer.decode([token_id.item()]))

    bars = go.Bar(
        x=top_values.cpu().numpy(),
        y=token_labels,
        orientation='h',
        marker={'color': 'skyblue'},
    )
    fig = go.Figure(bars)

    fig.update_layout(
        title_text=f'Top-{topk} logits {title_suffix}',
        xaxis_title='Value',
        yaxis_title='Token',
        height=50*topk,
        width=450,
        showlegend=False,
    )

    fig.show()

In [None]:
# Collect every token id that appears among the top-8 next-token predictions
# for at least one patched layer.
pred_tok_set = set()

# The hook itself does not depend on the layer (only the hook point it is
# attached to does), so build it once instead of once per iteration.
hook_fn = functools.partial(
    activation_patching_hook,
    pos=object_tok_pos,
    cache_to_patch_from=harmless_cache,
)

# Iterate over the model's actual depth rather than a hard-coded 32, matching
# the other per-layer loops in this notebook.
for layer in range(tl_model.cfg.n_layers):
    print(f"\nLAYER={layer}, POS={object_tok_pos}")

    activation_patching_logits = tl_model.run_with_hooks(
        harmful_prompt_toks[0],
        fwd_hooks=[(tl_utils.get_act_name("resid_post", layer), hook_fn)],
    )

    # Next-token logits: first batch row, last sequence position.
    last_pos_logits = activation_patching_logits[0, -1, :]

    topk_values, topk_ids = last_pos_logits.topk(k=8)
    topk_toks = [tokenizer.decode([tok.item()]) for tok in topk_ids]

    print(f"topk_toks: {topk_toks}")

    pred_tok_set.update(topk_ids.tolist())

    # Uncomment to plot the top logits for every layer:
    # plot_topk_tokens(
    #     last_pos_logits,
    #     tokenizer,
    #     topk=5,
    #     title_suffix=f"Patching bomb→pie at layer={layer}"
    # )

In [None]:
# Freeze the set into a list so that index i refers to the same token in
# `pred_tok_list`, in `pred_tok_str_list` and in the rows of the logit table
# built in the next cell.
pred_tok_list = list(pred_tok_set)
pred_tok_str_list = [tokenizer.decode([tok], clean_up_tokenization_spaces=False) for tok in pred_tok_list]

print(f"Strings appearing in at least one topk result:\n{pred_tok_str_list}")
print(pred_tok_list)

In [None]:
# Table of last-position logits: one row per collected token, one column per
# patched layer.
pred_tok_logits = np.zeros((len(pred_tok_set), tl_model.cfg.n_layers))

# The hook is identical for every layer; only the hook point changes.
hook_fn = functools.partial(
    activation_patching_hook,
    pos=object_tok_pos,
    cache_to_patch_from=harmless_cache,
)

# Loop over the same number of layers the table was allocated for, instead of
# a hard-coded 32.
for layer in range(tl_model.cfg.n_layers):

    activation_patching_logits = tl_model.run_with_hooks(
        harmful_prompt_toks[0],
        fwd_hooks=[(tl_utils.get_act_name("resid_post", layer), hook_fn)],
    )

    # Gather the logits of all tracked tokens in one indexing operation rather
    # than one `.item()` call per token.
    pred_tok_logits[:, layer] = (
        activation_patching_logits[0, -1, pred_tok_list].float().cpu().numpy()
    )


fig = go.Figure()

# One line per token: its logit as a function of the patched layer.
for i, tok in enumerate(pred_tok_list):
    fig.add_trace(
        go.Scatter(
            x=list(range(tl_model.cfg.n_layers)),
            y=pred_tok_logits[i],
            mode='lines+markers',
            name=f'{pred_tok_str_list[i]}'
        )
    )

fig.update_layout(
    title='Logits for common tokens vs patching layer',
    xaxis=dict(title='Patching layer'),
    yaxis=dict(title='Logit value')
)

fig.show()

A token with good signal will have very different logit values for early patching (close to the `pie` input) and later patching (close to the `bomb` input).

In [None]:
# Per-token change in logit between patching at the last layer (last column)
# and patching at the first layer (first column).
pred_tok_logit_difs = pred_tok_logits[:, -1] - pred_tok_logits[:, 0]

print(f"\nToks with positive dif (reflective of refusal)")
# topk on the differences: the five tokens whose logit rises the most.
refusal_tok_difs, refusal_tok_idxs = torch.tensor(pred_tok_logit_difs).topk(5)
for i in refusal_tok_idxs:
    print(f"{pred_tok_str_list[i]} ({pred_tok_list[i]}): {pred_tok_logit_difs[i]:+0.2f}")

print(f"\nToks with negative dif (reflective of answer)")
# topk on the negated differences: the five tokens whose logit drops the most.
answer_tok_difs, answer_tok_idxs = torch.tensor(-pred_tok_logit_difs).topk(5)
for i in answer_tok_idxs:
    print(f"{pred_tok_str_list[i]} ({pred_tok_list[i]}): {pred_tok_logit_difs[i]:+0.2f}")

fig = go.Figure()

# Plot only the tokens selected above (row index in either top-5 index set).
for i, tok in enumerate(pred_tok_list):
    if i not in refusal_tok_idxs and i not in answer_tok_idxs:
        continue

    fig.add_trace(
        go.Scatter(
            x=list(range(tl_model.cfg.n_layers)), 
            y=pred_tok_logits[i],
            mode='lines+markers',
            name=f'{pred_tok_str_list[i]}'
        )
    )

# As the last expression of the cell, the figure returned here is what gets displayed.
fig.update_layout(
    title='Logits for common tokens vs patching layer, polarizing tokens only',
    xaxis=dict(title='Patching layer'),
    yaxis=dict(title='Logit value')
)