# Attribution Patching on M0 to M3 models

The nnsight library tutorial [attribution_patching](https://nnsight.net/notebooks/tutorials/attribution_patching/) says:

Activation patching is a method to determine how model components influence model computations. It is time- and resource-intensive.
**Attribution patching** uses gradients to take a linear approximation to activation patching and can be done simultaneously in two forward and one backward pass. It is scalable to large models.
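The linear approximation can be seen on a toy problem. Below is a minimal pure-Python sketch that uses a hypothetical smooth metric of a 3-dimensional activation in place of a real model's logit-difference metric. Activation patching re-evaluates the metric with the clean activation. Attribution patching instead takes the gradient at the corrupted activation and dots it with (clean - corrupted), so it needs no extra forward pass per component.

```python
# Toy illustration of the first-order (attribution patching) approximation.
# Assumption: toy_metric is a hypothetical stand-in for a model's metric.
W = [0.5, -1.0, 2.0]

def toy_metric(x):
    # Weighted sum plus a small quadratic term, so the metric is not linear
    return sum(w * xi for w, xi in zip(W, x)) + 0.1 * sum(xi * xi for xi in x)

def toy_grad(x):
    # Analytic gradient of toy_metric: w_i + 0.2 * x_i
    return [w + 0.2 * xi for w, xi in zip(W, x)]

clean_act = [1.0, 0.5, -0.2]
corrupt_act = [0.8, 0.7, -0.1]

# Activation patching: evaluate the metric again with the clean activation
exact = toy_metric(clean_act) - toy_metric(corrupt_act)

# Attribution patching: gradient at the corrupted run times (clean - corrupted)
g = toy_grad(corrupt_act)
approx = sum(gi * (c - k) for gi, c, k in zip(g, clean_act, corrupt_act))

print(f"exact={exact:.4f} approx={approx:.4f}")  # exact=0.1150 approx=0.1060
```

For this quadratic toy metric the gap is exactly 0.1 times the squared distance between the two activations, so the estimate degrades as clean and corrupted activations move further apart.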

This notebook implements **Attribution patching**. It:
- Runs with GPT2/TinyStories/Qwen/Llama models with base/CS1/CS2/CS3 command sets.
- Was developed on Google Colab using an **A100** for Qwen and a **T4** for other models.
- Requires a GITHUB_TOKEN secret to access the Martian TinySQL code repository.
- Requires a HF_TOKEN secret to access the Martian HuggingFace repository.

# Part 0: Import libraries
Installs and imports the required libraries. Safe to skip.

In [3]:
# https://nnsight.net/
!pip install -U nnsight


In [4]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

from nnsight import LanguageModel


In [None]:
import os
from google.colab import userdata

In [None]:
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

import TinySQL as qts

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

# Select model and command set

In [None]:
model_num = 1                     # 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
cs_num = 1                        # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
feature_name = qts.ENGFIELDNAME   # One of EngTableName, EngFieldName, DefCreateTable, DefTableName, DefFieldName, DefFieldSeparator
use_novel_names = False           # If True, corrupt using words not found in the clean prompt or its CREATE TABLE statement, e.g. "little" or "hammer"

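# Investigate m0: nnsight tutorial using GPT2
Reproduces the nnsight attribution patching tutorial (https://nnsight.net/notebooks/tutorials/attribution_patching/) on GPT2, using indirect object identification (IOI) prompts.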



In [None]:
if model_num == 0:
    feature_name = ""
    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
    clear_output()
    print(model)

In [None]:
answer_token_indices = None
if model_num == 0:
    prompts = [
        "When John and Mary went to the shops, John gave the bag to",
        "When John and Mary went to the shops, Mary gave the bag to",
        "When Tom and James went to the park, James gave the ball to",
        "When Tom and James went to the park, Tom gave the ball to",
        "When Dan and Sid went to the shops, Sid gave an apple to",
        "When Dan and Sid went to the shops, Dan gave an apple to",
        "After Martin and Amy went to the park, Amy gave a drink to",
        "After Martin and Amy went to the park, Martin gave a drink to",
    ]

    # Answers are each formatted as (correct, incorrect):
    answer_pairs = [
        (" Mary", " John"),
        (" John", " Mary"),
        (" Tom", " James"),
        (" James", " Tom"),
        (" Dan", " Sid"),
        (" Sid", " Dan"),
        (" Martin", " Amy"),
        (" Amy", " Martin"),
    ]

    # Tokenize clean and corrupted inputs:
    clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]
    # The associated corrupted input is the prompt after the current clean prompt
    # for even indices, or the prompt prior to the current clean prompt for odd indices
    corrupted_tokens = clean_tokens[
        [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
    ]

    # Tokenize answer_pairs for each prompt:
    answer_token_indices = torch.tensor(
        [
            [model.tokenizer(answer_pairs[i][j])["input_ids"][0] for j in range(2)]
            for i in range(len(answer_pairs))
        ]
    )
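The corrupted batch reuses the clean prompts in a different row order. Indexing a tensor with a Python list of row indices, as `clean_tokens[[...]]` does above, returns those rows in that order. The sketch below applies the same index expression to hypothetical placeholder strings to show the resulting pairing.

```python
# Hypothetical stand-ins for the 8 IOI prompts (the notebook indexes token ids)
n_prompts = 8
toy_prompts = [f"prompt{i}" for i in range(n_prompts)]

# Same index expression as the notebook: pair 0<->1, 2<->3, 4<->5, 6<->7
partner = [(i + 1 if i % 2 == 0 else i - 1) for i in range(n_prompts)]
toy_corrupted = [toy_prompts[j] for j in partner]

print(partner)            # [1, 0, 3, 2, 5, 4, 7, 6]
print(toy_corrupted[:2])  # ['prompt1', 'prompt0']
```

The mapping is its own inverse, so each pair corrupts in both directions, matching how consecutive `answer_pairs` entries swap the correct and incorrect names. Because no padding is requested, this also relies on paired prompts tokenizing to the same length.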

In [None]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()
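`get_logit_diff` keeps only the final-position logits, then uses `gather` to pick each prompt's correct and incorrect answer logit. A minimal NumPy sketch of the same computation, with `np.take_along_axis` playing the role of `torch.gather`; the toy logits are hand-picked, not model outputs.

```python
import numpy as np

def logit_diff_np(logits, answer_token_indices):
    """NumPy sketch of get_logit_diff: batch mean of
    logit[correct] - logit[incorrect] at the final position."""
    last = logits[:, -1, :]  # (batch, vocab)
    correct = np.take_along_axis(last, answer_token_indices[:, 0:1], axis=1)
    incorrect = np.take_along_axis(last, answer_token_indices[:, 1:2], axis=1)
    return float((correct - incorrect).mean())

# Toy logits: batch=2, seq_len=3, vocab=4 (hypothetical values)
toy_logits = np.zeros((2, 3, 4))
toy_logits[0, -1] = [0.0, 3.0, 1.0, 0.0]  # prompt 0: token 1 correct, token 2 incorrect
toy_logits[1, -1] = [2.0, 0.0, 0.0, 0.5]  # prompt 1: token 0 correct, token 3 incorrect
toy_answers = np.array([[1, 2], [0, 3]])

diff = logit_diff_np(toy_logits, toy_answers)
print(diff)  # 1.75, the mean of (3 - 1) and (2 - 0.5)
```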

In [None]:
if model_num == 0:
    clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
    corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

    CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

In [None]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )
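`ioi_metric` (and `sql_metric` later) rescales the logit difference so that the clean run scores 1 and the corrupted run scores 0. A short sketch with hypothetical baseline values, not measured ones:

```python
def normalized_metric(logit_diff, clean_baseline, corrupted_baseline):
    # Rescale so the clean run scores 1.0 and the corrupted run scores 0.0
    return (logit_diff - corrupted_baseline) / (clean_baseline - corrupted_baseline)

TOY_CLEAN, TOY_CORRUPTED = 3.0, -1.0  # hypothetical baselines

print(normalized_metric(3.0, TOY_CLEAN, TOY_CORRUPTED))   # 1.0
print(normalized_metric(-1.0, TOY_CLEAN, TOY_CORRUPTED))  # 0.0
print(normalized_metric(1.0, TOY_CLEAN, TOY_CORRUPTED))   # 0.5
```

The normalization divides by the gap between the baselines, so it is only meaningful when the corruption actually moves the logit difference; equal baselines would divide by zero.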

In [None]:
if model_num == 0:
    print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
    print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

In [None]:
if model_num == 0:
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each layer's attention
            for layer in model.transformer.h:
                # Get clean attention output for this layer
                # across all attention heads
                attn_out = layer.attn.c_proj.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each layer's attention and gradients
            for layer in model.transformer.h:
                # Get corrupted attention output for this layer
                # across all attention heads
                attn_out = layer.attn.c_proj.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = ioi_metric(logits.cpu())

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

In [None]:
if model_num == 0:
    N_HEADS = 12
    D_HEAD = 64

The results are graphed in the "Graph logit changes" sections below.

# Investigate m1, m2 and m3 models

## Load model

In [None]:
if model_num > 0:
    model = qts.load_tinysql_model(model_num, cs_num, auth_token=userdata.get("HF_TOKEN"))
    clear_output()
    print(model)

    N_LAYERS, N_HEADS, D_MODEL, D_HEAD = qts.get_model_sizes(model_num, model)

    # Singleton QuantaTool "main" configuration class. qmi.AlgoConfig is derived from the chain qmi.UsefulConfig > qmi.ModelConfig
    cfg = qmi.AlgoConfig()
    cfg.main_model = model
    cfg.n_layers = N_LAYERS
    cfg.n_heads = N_HEADS
    cfg.d_model = D_MODEL
    cfg.d_head = D_HEAD
    cfg.file_config_prefix = ""
    cfg.set_seed(673023)

## Generate clean and corrupted data

In [None]:
if model_num > 0:
    generator = qts.CorruptFeatureTestGenerator(model_num=model_num, cs_num=cs_num, tokenizer=model.tokenizer, use_novel_names=use_novel_names)

    examples = generator.generate_feature_examples(feature_name, 10)

    examples[0].print_all()

In [None]:
if model_num > 0:
    prompts = [example.get_alpaca_prompt() for example in examples]

    # Generate answers as (correct, incorrect) pairs
    answer_pairs = [(example.clean_token_str, example.corrupt_token_str) for example in examples]

    # Tokenize clean inputs
    clean_tokens = model.tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]

    # Different models tokenize differently giving different indexes for the corrupted text
    answer_offset = generator.tokenize_answer_offset()
    first_field_token = model.tokenizer(f" {examples[0].clean_token_str}")["input_ids"][answer_offset]  # Token id of the first example's clean name (e.g. " size")
    corrupt_index = (clean_tokens[0] == first_field_token).nonzero()[0].item()  # Position of its first occurrence in prompt 0
    # Sanity check: the token at corrupt_index decodes back to the clean name
    assert model.tokenizer.decode(clean_tokens[0, corrupt_index]) == f" {examples[0].clean_token_str}"

    # Validation helper function
    def validate_token_sequence(tokens, idx):
        """Validate that tokens before and after idx are same across sequences, but different at idx"""
        assert all(tokens[0, idx-1] == tokens[i, idx-1] for i in range(1, len(tokens)))
        assert all(tokens[0, idx] != tokens[i, idx] for i in range(1, len(tokens)))
        assert all(tokens[0, idx+1] == tokens[i, idx+1] for i in range(1, len(tokens)))

    # Validate token sequence
    # Disabled: crashes because prompts of different lengths get left-padded. TODO: make qts generate constant-length prompts.
    # for prompt in prompts:
    #     print(prompt.replace('\n', ' '))
    # print(clean_tokens[:,:10])
    #validate_token_sequence(clean_tokens, corrupt_index)

    # Create corrupted tokens using circular rotation
    corrupted_tokens = clean_tokens.clone()
    for i in range(len(prompts)):
        next_idx = (i + 1) % len(prompts)
        corrupted_tokens[i, corrupt_index] = clean_tokens[next_idx, corrupt_index]

    # Tokenize answer_pairs
    answer_token_indices = torch.tensor([
        [model.tokenizer(pair[j])["input_ids"][answer_offset] for j in range(2)]
        for pair in answer_pairs
    ])
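The token lookup, validation and circular rotation above can be checked on a hand-made batch. A minimal NumPy sketch, assuming hypothetical token ids where three prompts differ only at one position; `validate_token_sequence_np` is a hypothetical NumPy rewrite of the notebook's `validate_token_sequence`.

```python
import numpy as np

# Hypothetical token ids: 3 prompts of length 5 that differ only at position 2
toy_tokens = np.array([
    [11, 12, 100, 13, 14],
    [11, 12, 200, 13, 14],
    [11, 12, 300, 13, 14],
])
toy_field_token = 100

# Position of the first occurrence of the field token in prompt 0
toy_index = int(np.flatnonzero(toy_tokens[0] == toy_field_token)[0])

def validate_token_sequence_np(tokens, idx):
    # Neighbours match across prompts; the token at idx differs from prompt 0's
    assert (tokens[:, idx - 1] == tokens[0, idx - 1]).all()
    assert (tokens[1:, idx] != tokens[0, idx]).all()
    assert (tokens[:, idx + 1] == tokens[0, idx + 1]).all()

validate_token_sequence_np(toy_tokens, toy_index)

# Circular rotation: prompt i borrows the field token of prompt (i + 1) % n
toy_corrupted_tokens = toy_tokens.copy()
n = len(toy_tokens)
for i in range(n):
    toy_corrupted_tokens[i, toy_index] = toy_tokens[(i + 1) % n, toy_index]

print(toy_index)                           # 2
print(toy_corrupted_tokens[:, toy_index])  # [200 300 100]
```

The rotation assumes every prompt holds its field name at the same position, which is exactly what padding breaks. It also only corrupts a row if its neighbour's token at that position differs.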

## Trace clean and corrupted (batched) examples

In [None]:
if model_num > 0:
    clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
    corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

    CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

    qts.free_memory() # Free up GPU and CPU memory

In [None]:
def sql_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

In [None]:
if model_num > 0:
    print(f"Clean Baseline is 1: {sql_metric(clean_logits).item():.4f}")
    print(f"Corrupted Baseline is 0: {sql_metric(corrupted_logits).item():.4f}")

In [None]:
if model_num == 1: # TinyStories
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each layer's attention
            for layer in model.transformer.h:
                # Save the input to this layer's attention module (the layer-normed
                # residual stream), not the per-head outputs that GPT2's c_proj.input gives
                attn_out = layer.attn.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each layer's attention and gradients
            for layer in model.transformer.h:
                # Save the input to this layer's attention module (the layer-normed
                # residual stream), not the per-head outputs that GPT2's c_proj.input gives
                attn_out = layer.attn.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = sql_metric(logits.cpu())

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

    qts.free_memory() # Free up GPU and CPU memory

In [None]:
if model_num == 2 or model_num == 3: # Qwen or Llama
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each layer's attention
            for layer in model.model.layers:
                # Save this decoder layer's input (the residual stream entering
                # the layer), not a per-head attention output
                #tracer.log("layer shape", layer)
                attn_out = layer.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each layer's attention and gradients
            for layer in model.model.layers:
                # Save this decoder layer's input (the residual stream entering
                # the layer), not a per-head attention output
                attn_out = layer.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = sql_metric(logits.cpu())

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

In [None]:
if model_num > 0:
    N_LAYERS, N_HEADS, D_MODEL, D_HEAD = qts.get_model_sizes(model_num, model, cs_num)

# Graph logit changes by attention head
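Heatmap of the attribution-patching estimate of how the normalized logit difference changes when each layer's activation at the final token position is patched from the clean run into the corrupted run, split across attention heads.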

In [None]:
attention_head_results = []
for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = N_HEADS,
        dim = D_HEAD,
    )

    attention_head_results.append(
        residual_attr.detach().cpu().numpy()
    )
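The `einops.reduce` pattern `"batch (head dim) -> head"` splits the last axis into `(head, dim)` with `head` as the outer factor, then sums over batch and `dim`. A C-order NumPy `reshape` splits the axis the same way, so the per-head attribution can be sketched as below, with hypothetical toy sizes and random activations. Reading the result as per-head relies on the activation being laid out head-major, as GPT2's `c_proj.input` (the concatenated head outputs) is.

```python
import numpy as np

# Hypothetical toy sizes, not a real model's
batch, n_heads, d_head = 2, 3, 4
rng = np.random.default_rng(0)
grad = rng.normal(size=(batch, n_heads * d_head))         # d metric / d activation, last position
clean_act = rng.normal(size=(batch, n_heads * d_head))
corrupt_act = rng.normal(size=(batch, n_heads * d_head))

# Elementwise attribution: gradient times (clean - corrupted)
contrib = grad * (clean_act - corrupt_act)

# Equivalent of "batch (head dim) -> head": split the last axis, sum batch and dim
per_head = contrib.reshape(batch, n_heads, d_head).sum(axis=(0, 2))

print(per_head.shape)  # (3,)
```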

In [None]:
fig = px.imshow(
    attention_head_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Attention Heads: " + feature_name,
    labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},
)

fig.show()

# Graph logit changes by token position
Heatmap of the attribution-patching estimate of how the normalized logit difference changes when each layer's activation is patched from the clean run into the corrupted run, at each token position.

In [None]:
token_pos_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    token_pos_results.append(
        residual_attr.detach().cpu().numpy()
    )
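The pattern `"batch pos dim -> pos"` keeps the position axis and sums over batch and hidden dimensions. In a causal model, corrupting one token leaves every earlier position's activations identical in both runs, so those positions get exactly zero attribution. The NumPy sketch below emulates this with hypothetical toy sizes, assuming the corrupted token sits at position 2 in every prompt.

```python
import numpy as np

# Hypothetical toy sizes and activations
batch, seq_len, d_model = 2, 5, 8
rng = np.random.default_rng(1)
grad = rng.normal(size=(batch, seq_len, d_model))
clean_act = rng.normal(size=(batch, seq_len, d_model))
corrupt_act = rng.normal(size=(batch, seq_len, d_model))

# Emulate a causal model: positions before the corrupted token (0 and 1) match
clean_act[:, :2, :] = corrupt_act[:, :2, :]

# Equivalent of "batch pos dim -> pos": sum over batch and hidden dims
per_pos = (grad * (clean_act - corrupt_act)).sum(axis=(0, 2))

print(per_pos.shape)  # (5,)
print(per_pos[:2])    # both entries are zero
```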

In [None]:
fig = px.imshow(
    token_pos_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position: " + feature_name,
    labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},
)

fig.show()

In [None]:
seen_suffix = "_novel" if use_novel_names else ""

# Intended: serialize the results to a temporary Colab file in JSON format, then download it manually. Saving is not implemented yet (see TODO below).
attribution_json_filename = 'tinysql_bm' + str(model_num) + "_cs" + str(cs_num) + feature_name + seen_suffix + '_attribution.json'
print( "Saving useful node list with behavior tags:", attribution_json_filename)
#TODO: cfg.useful_nodes.save_nodes(attribution_json_filename)

#TODO: Auto save to Martian wandb
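Until the QMI-based save is implemented, one option is to dump the two result lists as plain JSON. This is a hypothetical stand-in, not the `cfg.useful_nodes.save_nodes` method the TODO refers to; `save_attribution_json` and the demo file path are illustrative names.

```python
import json
import os
import tempfile

import numpy as np

def save_attribution_json(filename, head_results, pos_results):
    """Hypothetical helper: write per-layer attribution arrays as nested lists."""
    payload = {
        # .tolist() converts numpy scalars (e.g. float32) to JSON-safe Python floats
        "attention_head_results": [np.asarray(r).tolist() for r in head_results],
        "token_pos_results": [np.asarray(r).tolist() for r in pos_results],
    }
    with open(filename, "w") as f:
        json.dump(payload, f)

# Demo with toy arrays written to the system temp directory
demo_path = os.path.join(tempfile.gettempdir(), "attribution_demo.json")
save_attribution_json(demo_path, [np.array([0.1, -0.2])], [np.array([0.0, 0.3, 0.5])])
with open(demo_path) as f:
    loaded = json.load(f)
print(loaded["attention_head_results"])  # [[0.1, -0.2]]
```

In the notebook this would be called as `save_attribution_json(attribution_json_filename, attention_head_results, token_pos_results)`.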