# Attribution Patching on M0 to M3 models

The nnsight library tutorial [attribution_patching](https://nnsight.net/notebooks/tutorials/attribution_patching/) says:

Activation patching is a method to determine how model components influence model computations. It is time- and resource-intensive.
**Attribution patching** uses gradients to compute a linear approximation of activation patching, and needs only two forward passes and one backward pass in total. It scales to large models.

This notebook implements **Attribution patching**. It:
- Runs with GPT2/TinyStories/Qwen/Llama models with base/CS1/CS2/CS3 command sets.
- Was developed on Google Colab using an **A100** for Qwen and a **T4** for other models.
- Requires a GITHUB_TOKEN secret to access the Martian TinySQL code repository.
- Requires an HF_TOKEN secret to access the Martian HuggingFace repository.

# Part 0: Import libraries
Installs and imports the required libraries. You can skip reading this section.

In [None]:
# https://nnsight.net/
!pip install -U nnsight

In [None]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
import numpy as np

from nnsight import LanguageModel

In [None]:
import os
from google.colab import userdata
import itertools

In [None]:
!pip install datasets

In [None]:
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

import TinySQL as qts

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

# Select model and command set

In [None]:
# Model selector: 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
# (the experiment runner in this notebook only handles 1, 2 and 3; 0 uses the nnsight tutorial cells)
model_num = 1
# Command set selector: 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
cs_num = 2

# Investigate m0 : nnsight tutorial using GPT2
Reproduces the nnsight attribution patching tutorial: https://nnsight.net/notebooks/tutorials/attribution_patching/



In [None]:
if model_num == 0:
    # m0 reproduces the nnsight tutorial on GPT2; it has no TinySQL feature, so the
    # title suffix used by the graphing functions is empty.
    feature_name = ""
    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
    # Clear earlier cell output, then show the module tree (used to find layer paths).
    clear_output()
    print(model)

In [None]:
answer_token_indices = None
if model_num == 0:
    # Indirect-object-identification prompts, arranged in adjacent pairs that differ
    # only in which name is the subject of the second clause.
    prompts = [
        "When John and Mary went to the shops, John gave the bag to",
        "When John and Mary went to the shops, Mary gave the bag to",
        "When Tom and James went to the park, James gave the ball to",
        "When Tom and James went to the park, Tom gave the ball to",
        "When Dan and Sid went to the shops, Sid gave an apple to",
        "When Dan and Sid went to the shops, Dan gave an apple to",
        "After Martin and Amy went to the park, Amy gave a drink to",
        "After Martin and Amy went to the park, Martin gave a drink to",
    ]

    # (correct, incorrect) completion for each prompt, aligned by index with `prompts`.
    answer_pairs = [
        (" Mary", " John"),
        (" John", " Mary"),
        (" Tom", " James"),
        (" James", " Tom"),
        (" Dan", " Sid"),
        (" Sid", " Dan"),
        (" Martin", " Amy"),
        (" Amy", " Martin"),
    ]

    clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

    # Each prompt's corrupted counterpart is its pair partner (0<->1, 2<->3, ...):
    # XOR with 1 flips the lowest bit, giving i+1 for even i and i-1 for odd i.
    partner_rows = [row ^ 1 for row in range(len(clean_tokens))]
    corrupted_tokens = clean_tokens[partner_rows]

    # First token id of each answer string, one [correct, incorrect] row per prompt.
    answer_token_indices = torch.tensor(
        [[model.tokenizer(name)["input_ids"][0] for name in pair] for pair in answer_pairs]
    )

In [None]:
def get_logit_diff(logits, answer_token_indices):
    """Return the mean (correct - incorrect) logit gap at the final sequence position.

    Args:
        logits: tensor indexed as [batch, pos, vocab].
        answer_token_indices: integer tensor [batch, 2]; column 0 holds the correct
            token id and column 1 the incorrect token id for each batch row.

    Returns:
        A 0-dim tensor with the batch-mean logit difference.
    """
    final_logits = logits[:, -1, :]
    # Keep a trailing singleton dim so gather picks one vocab entry per row.
    correct_ids = answer_token_indices[:, 0:1]
    incorrect_ids = answer_token_indices[:, 1:2]
    gap = final_logits.gather(1, correct_ids) - final_logits.gather(1, incorrect_ids)
    return gap.mean()

In [None]:
if model_num == 0:
    # Untraced forward passes (trace=False) on each batch. Logits are moved to CPU
    # to match answer_token_indices, which was built on CPU with torch.tensor.
    clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
    corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    # Mean (correct - incorrect) logit gap at the last position for each run.
    # ioi_metric reads these two globals to normalise its output.
    clean_baseline = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {clean_baseline:.4f}")

    corrupted_baseline = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {corrupted_baseline:.4f}")

In [None]:
def ioi_metric(
    logits,
    answer_token_indices,
):
    """Normalised IOI logit difference: 0 at the corrupted baseline, 1 at the clean one.

    Reads the notebook globals ``clean_baseline`` and ``corrupted_baseline`` set by
    the m0 baseline cell.

    Args:
        logits: model logits indexed as [batch, pos, vocab].
        answer_token_indices: [batch, 2] (correct, incorrect) token ids.
    """
    baseline_span = clean_baseline - corrupted_baseline
    shifted_diff = get_logit_diff(logits, answer_token_indices) - corrupted_baseline
    return shifted_diff / baseline_span

In [None]:
if model_num == 0:
    # Sanity check: ioi_metric is normalised so the clean run scores 1 and the
    # corrupted run scores 0.
    print(f"Clean Baseline is 1: {ioi_metric(clean_logits, answer_token_indices).item():.4f}")
    print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits, answer_token_indices).item():.4f}")

In [None]:
if model_num == 0:
    # Per-layer saved proxies, one entry per transformer block.
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # nnsight's tracer.invoke batches the clean and the corrupted runs into the
        # same tracing context, so values from both runs are available after a
        # single forward pass.

        with tracer.invoke(clean_tokens):
            for layer in model.transformer.h:
                # Input to c_proj: this layer's attention output across all heads
                # (before the output projection), for the clean run.
                attn_out = layer.attn.c_proj.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens):
            for layer in model.transformer.h:
                # Same activation for the corrupted run ...
                attn_out = layer.attn.c_proj.input
                corrupted_out.append(attn_out.save())
                # ... plus its gradient, needed for attribution patching.
                corrupted_grads.append(attn_out.grad.save())

            # Logits for the corrupted run only.
            logits = model.lm_head.output.save()

            # The metric compares against CPU tensors, so move the logits to CPU.
            # ioi_metric requires answer_token_indices; omitting it raised TypeError.
            value = ioi_metric(logits.cpu(), answer_token_indices)

            # Backward pass populates the saved gradients.
            value.backward()

In [None]:
if model_num == 0:
    # GPT2 (openai-community/gpt2) head count and per-head width. The graphing
    # functions use these to split the attention activations into heads.
    N_HEADS = 12
    D_HEAD = 64

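The m0 results can be graphed with the functions in the "Graph logit changes by ..." sections below (e.g. `get_attention_head_results` followed by `show_attention_head_results`).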

# Investigate m1, m2 and m3 models

## Load model

In [None]:
if model_num > 0:
    # Load the TinySQL model for the selected model/command set (needs the HF_TOKEN secret).
    model = qts.load_tinysql_model(model_num, cs_num, auth_token=userdata.get("HF_TOKEN"))
    clear_output()
    print(model)

    # Model dimensions; N_HEADS is also read by the graphing functions below.
    # NOTE(review): the "Set up experiments" cell calls get_model_sizes again with
    # cs_num as a third argument — confirm which signature is intended.
    N_LAYERS, N_HEADS, D_MODEL, D_HEAD = qts.get_model_sizes(model_num, model)

    # Singleton QuantaTool "main" configuration class. qmi.AlgoConfig is derived from the chain qmi.UsefulConfig > qmi.ModelConfig
    cfg = qmi.AlgoConfig()
    cfg.main_model = model
    cfg.n_layers = N_LAYERS
    cfg.n_heads = N_HEADS
    cfg.d_model = D_MODEL
    cfg.d_head = D_HEAD
    cfg.file_config_prefix = ""
    cfg.set_seed(673023)

## Generate clean and corrupted data

In [None]:
def get_clean_and_corrupt_data( generator, examples ):
    """Build clean/corrupted token batches and (correct, incorrect) answer token ids.

    Args:
        generator: the qts generator that produced ``examples``; supplies the
            model-specific offset used when tokenizing answer strings.
        examples: generated examples. The token positions are read from
            ``examples[0]`` and assumed to apply to every example.

    Returns:
        ``(clean_tokens, corrupted_tokens, answer_token_indices)``:
        clean and corrupted token ids truncated to ``answer_token_index`` positions,
        so the last position is the one that predicts the answer token, and a
        ``[len(examples), 2]`` tensor of (correct, incorrect) answer token ids.
    """
    # Index of prompt token to corrupt
    prompt_token_index = examples[0].prompt_token_index
    # Index of answer token we expect to be impacted
    answer_token_index = examples[0].answer_token_index

    prompts = [example.get_alpaca_prompt() + example.sql_statement for example in examples]

    # Generate answers as (correct, incorrect) pairs, one per example
    answer_pairs = [(example.clean_token_str, example.corrupt_token_str) for example in examples]

    # Tokenize, then truncate every row to the positions before the answer token.
    # Truncation must apply to token columns, not to the list of prompts: get_logit_diff
    # reads logits[:, -1], which must be the position that predicts the answer.
    # NOTE(review): assumes answer_token_index indexes this same (padded) tokenization,
    # as prompt_token_index does below — confirm against qts.
    clean_tokens = model.tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]
    clean_tokens = clean_tokens[:, :answer_token_index]

    # Different models tokenize differently giving different indexes for the corrupted text
    answer_offset = generator.tokenize_answer_offset()

    # Corrupt each row by copying in the next row's token at prompt_token_index
    # (circular rotation: the last row takes the first row's token).
    # NOTE(review): the "incorrect" answer for row i is example i's corrupt_token_str,
    # while its corrupted prompt uses row i+1's token — confirm these agree.
    corrupted_tokens = clean_tokens.clone()
    corrupted_tokens[:, prompt_token_index] = clean_tokens[:, prompt_token_index].roll(-1)

    # Tokenize answer_pairs
    answer_token_indices = torch.tensor([
        [model.tokenizer(pair[j])["input_ids"][answer_offset] for j in range(2)]
        for pair in answer_pairs
    ])

    return clean_tokens, corrupted_tokens, answer_token_indices

## Trace clean and corrupted (batched) examples

In [None]:
def get_clean_and_corrupt_baselines( clean_tokens, corrupted_tokens, answer_token_indices ):
    """Run untraced forward passes on both batches and compute logit-diff baselines.

    Returns:
        ``(clean_logits, corrupted_logits, clean_baseline, corrupted_baseline)``:
        the CPU logits of each run and each run's mean (correct - incorrect)
        logit difference at the last position, as Python floats.
    """
    # Clean run first, then corrupted, as two separate untraced passes.
    run_logits = [
        model.trace(batch, trace=False).logits.cpu()
        for batch in (clean_tokens, corrupted_tokens)
    ]
    run_baselines = [
        get_logit_diff(batch_logits, answer_token_indices).item()
        for batch_logits in run_logits
    ]
    return run_logits[0], run_logits[1], run_baselines[0], run_baselines[1]

In [None]:
def sql_metric(
    logits,
    clean_baseline,
    corrupted_baseline,
    answer_token_indices,
):
    """Normalised logit difference: 0 at ``corrupted_baseline``, 1 at ``clean_baseline``.

    Args:
        logits: model logits indexed as [batch, pos, vocab].
        clean_baseline: logit difference of the clean run.
        corrupted_baseline: logit difference of the corrupted run.
        answer_token_indices: [batch, 2] (correct, incorrect) token ids.
    """
    denominator = clean_baseline - corrupted_baseline
    numerator = get_logit_diff(logits, answer_token_indices) - corrupted_baseline
    return numerator / denominator

In [None]:
def trace_clean_and_corrupt_tinysql( clean_tokens, clean_baseline, corrupted_tokens, corrupted_baseline, answer_token_indices ):
    """Collect per-layer activations and corrupted-run gradients (model.transformer.h layout).

    Runs the clean and corrupted batches in one nnsight trace. For every block in
    ``model.transformer.h`` it saves ``layer.attn.input`` for both runs. It also
    saves the gradient of ``sql_metric`` (computed on the corrupted run's logits)
    with respect to the corrupted-run value.

    Args:
        clean_tokens: token ids of the clean prompts.
        clean_baseline: clean-run logit difference (normalises sql_metric).
        corrupted_tokens: token ids of the corrupted prompts.
        corrupted_baseline: corrupted-run logit difference (normalises sql_metric).
        answer_token_indices: [batch, 2] (correct, incorrect) token ids.

    Returns:
        ``(clean_out, corrupted_out, corrupted_grads)``: lists with one saved proxy
        per layer; the graphing functions read them via ``.value``.

    NOTE(review): this saves the attention module's *input*, whereas the GPT2 m0
    cell saves ``layer.attn.c_proj.input`` (attention output across heads before
    projection). The graphing code splits the last dim into N_HEADS chunks, which
    only maps to real heads for the latter — confirm which activation is intended.
    """
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each layer's attention-module input for the clean run
            for layer in model.transformer.h:
                # Value entering this layer's attention module (see NOTE above)
                attn_out = layer.attn.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each layer's attention-module input and its gradient for the corrupted run
            for layer in model.transformer.h:
                # Value entering this layer's attention module (see NOTE above)
                attn_out = layer.attn.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = sql_metric(logits.cpu(), clean_baseline, corrupted_baseline, answer_token_indices)

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

    qts.free_memory() # Free up GPU and CPU memory

    return clean_out, corrupted_out, corrupted_grads

In [None]:
def trace_clean_and_corrupt_llm( clean_tokens, clean_baseline, corrupted_tokens, corrupted_baseline, answer_token_indices ):
    """Collect per-layer activations and corrupted-run gradients (model.model.layers layout).

    Same procedure as the TinyStories variant, but iterates ``model.model.layers``
    (used here for Qwen and Llama). For each decoder layer it saves ``layer.input``
    for both runs. It also saves the gradient of ``sql_metric`` on the corrupted run
    with respect to the corrupted-run value.

    Args:
        clean_tokens: token ids of the clean prompts.
        clean_baseline: clean-run logit difference (normalises sql_metric).
        corrupted_tokens: token ids of the corrupted prompts.
        corrupted_baseline: corrupted-run logit difference (normalises sql_metric).
        answer_token_indices: [batch, 2] (correct, incorrect) token ids.

    Returns:
        ``(clean_out, corrupted_out, corrupted_grads)``: lists with one saved proxy
        per layer; the graphing functions read them via ``.value``.

    NOTE(review): ``layer.input`` is the value entering the whole decoder layer,
    not the attention output across heads. The graphing code splits its last dim
    into N_HEADS chunks, which may not correspond to attention heads — confirm
    whether the attention output-projection input was intended.
    """
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each decoder layer's input for the clean run
            for layer in model.model.layers:
                # Value entering this decoder layer (see NOTE above)
                #tracer.log("layer shape", layer)
                attn_out = layer.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each decoder layer's input and its gradient for the corrupted run
            for layer in model.model.layers:
                # Value entering this decoder layer (see NOTE above)
                attn_out = layer.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = sql_metric(logits.cpu(), clean_baseline, corrupted_baseline, answer_token_indices)

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

    qts.free_memory() # Free up GPU and CPU memory

    return clean_out, corrupted_out, corrupted_grads

# Graph logit changes by attention head
Heatmap to examine how the logit difference changes after patching each layer’s output across attention heads.

In [None]:
def get_attention_head_results(clean_out, corrupted_grads, corrupted_out):
    """Attribution score per (layer, head) at the final token position.

    For each layer, computes grad_corrupted * (clean - corrupted) at the last
    position. This is the linear attribution-patching estimate of the effect of
    patching the clean activation into the corrupted run. The result is summed
    over the batch and within-head dimensions.

    Args:
        clean_out, corrupted_grads, corrupted_out: per-layer saved proxies (read via
            ``.value``) from a trace function; last dim must equal N_HEADS * D_HEAD.

    Returns:
        List with one numpy array of length N_HEADS per layer.
    """
    attention_head_results = []

    for corrupted_grad, corrupted, clean in zip(corrupted_grads, corrupted_out, clean_out):
        residual_attr = einops.reduce(
            corrupted_grad.value[:, -1, :] * (clean.value[:, -1, :] - corrupted.value[:, -1, :]),
            "batch (head dim) -> head",
            "sum",
            head=N_HEADS,
            dim=D_HEAD,
        )

        attention_head_results.append(residual_attr.detach().cpu().numpy())

    return attention_head_results

In [None]:
def show_attention_head_results(attention_head_results, feature_name=None):
    """Plot per-(layer, head) attribution scores as a diverging heatmap.

    Args:
        attention_head_results: one array of head scores per layer, as returned by
            get_attention_head_results.
        feature_name: text appended to the title. When omitted, falls back to a
            notebook-level ``feature_name`` if one exists, else "".
    """
    if feature_name is None:
        # Only the m0 GPT2 cell defines a global feature_name; reading it
        # unconditionally raised NameError for the m1-m3 models.
        feature_name = globals().get("feature_name", "")

    fig = px.imshow(
        attention_head_results,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Attention Heads: " + feature_name,
        labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},
    )

    fig.show()

# Graph logit changes by token position
Heatmap to examine how the logit difference changes after patching each layer’s output across token positions.

In [None]:
def get_token_pos_results(clean_out, corrupted_grads, corrupted_out):
    """Attribution score per (layer, token position).

    For each layer, computes grad_corrupted * (clean - corrupted) (the linear
    attribution-patching estimate) and sums it over the batch and hidden
    dimensions, leaving one score per token position.

    Args:
        clean_out, corrupted_grads, corrupted_out: per-layer saved proxies (read via
            ``.value``) from a trace function, each indexed as [batch, pos, dim].

    Returns:
        List with one numpy array of length ``pos`` per layer.
    """
    token_pos_results = []

    for corrupted_grad, corrupted, clean in zip(corrupted_grads, corrupted_out, clean_out):
        residual_attr = einops.reduce(
            corrupted_grad.value * (clean.value - corrupted.value),
            "batch pos dim -> pos",
            "sum",
        )

        token_pos_results.append(residual_attr.detach().cpu().numpy())

    return token_pos_results

In [None]:
def show_token_pos_results(token_pos_results, feature_name=None):
    """Plot per-(layer, token position) attribution scores as a diverging heatmap.

    Args:
        token_pos_results: one array of position scores per layer, as returned by
            get_token_pos_results.
        feature_name: text appended to the title. When omitted, falls back to a
            notebook-level ``feature_name`` if one exists, else "".
    """
    if feature_name is None:
        # Only the m0 GPT2 cell defines a global feature_name; reading it
        # unconditionally raised NameError for the m1-m3 models.
        feature_name = globals().get("feature_name", "")

    fig = px.imshow(
        token_pos_results,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Token Position: " + feature_name,
        labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},
    )

    fig.show()

# Graph logit changes by layer+head vs token position


In [None]:
def show_logit_changes_by_node( example0, feature_name, use_novel_names, use_synonyms_table, use_synonyms_field, clean_out, corrupted_out, corrupted_grads, answer_token_indices ):
    """Heatmap of attribution scores for every (layer, head) row against token position.

    For each layer, computes grad_corrupted * (clean - corrupted), splits the last
    dimension into N_HEADS equal chunks, and sums over batch and the within-chunk
    dimension. The per-layer [head, pos] matrices are stacked into one
    [n_layers * N_HEADS, pos] matrix and shown with plotly.

    Args:
        example0: first generated example (not used here; kept for callers).
        feature_name: corrupted TinySQL feature; shown in the title.
        use_novel_names, use_synonyms_table, use_synonyms_field: generator settings
            of this run; shown in the title so plots from different settings can be
            told apart.
        clean_out, corrupted_out, corrupted_grads: per-layer saved proxies from a
            trace function (read via ``.value``), each indexed as [batch, pos, hidden].
        answer_token_indices: not used here; kept for callers.
    """
    pos_layer_attnhead_results = []
    for cg, corr, cln in zip(corrupted_grads, corrupted_out, clean_out):
        # Linear attribution estimate, shape [batch, pos, hidden_dim]
        residual_attr = cg.value * (cln.value - corr.value)

        # [batch, pos, hidden_dim] -> [batch, pos, head, d_head]
        residual_attr = einops.rearrange(
            residual_attr,
            "batch pos (head d_head) -> batch pos head d_head",
            head=N_HEADS
        )

        # Sum over batch and d_head -> [pos, head]
        residual_attr = einops.reduce(
            residual_attr,
            "batch pos head d_head -> pos head",
            "sum",
        )

        # Transpose to [head, pos] so each head becomes one heatmap row
        pos_layer_attnhead_results.append(residual_attr.T.detach().cpu().numpy())

    # Stack layers -> [num_layers * num_heads, pos]
    final_matrix = np.concatenate(pos_layer_attnhead_results, axis=0)

    layer_head_labels = [
        f"L{layer_idx}H{head_idx}"
        for layer_idx in range(len(corrupted_out))
        for head_idx in range(N_HEADS)
    ]

    title = (
        "Attribution Patching Over (Layer, Head) vs. Token Position"
        f"<br>{feature_name} (novel names={use_novel_names}, "
        f"synonyms table={use_synonyms_table}, synonyms field={use_synonyms_field})"
    )

    fig = px.imshow(
        final_matrix,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        labels={"x": "Token Position", "y": "Layer/Head", "color": "Norm. Logit Diff"},
        title=title,
    )

    # Label every row with its layer/head name
    fig.update_yaxes(
        tickmode="array",
        tickvals=list(range(len(layer_head_labels))),
        ticktext=layer_head_labels
    )

    # Figure size in pixels, with tight margins
    fig.update_layout(
        width=800,
        height=600,
        margin=dict(l=20, r=20, t=50, b=20),
    )

    fig.show()

# Set up experiments

In [None]:
if model_num > 0:
    # Refresh the model dimensions; N_HEADS is read by the graphing functions.
    # NOTE(review): unlike the "Load model" cell, this passes cs_num as a third
    # argument to qts.get_model_sizes — confirm the intended signature.
    N_LAYERS, N_HEADS, D_MODEL, D_HEAD = qts.get_model_sizes(model_num, model, cs_num)

In [None]:
# Generator settings swept by run_experiments (every combination is run).
use_novel_names_list = [False]       # Novel words tokenize to more than one token, so they are not used to corrupt the prompt
use_synonyms_table_list = [True,False]    # In the English instructions, use a synonym for 80% of table names?
use_synonyms_field_list = [False]         # In the English instructions, use a synonym for 50% of field names?

In [None]:
def run_experiments( feature_name, num_examples=10 ):
    """Run attribution patching for one TinySQL feature across all generator settings.

    For each combination of use_novel_names_list x use_synonyms_table_list x
    use_synonyms_field_list, this:
      - generates ``num_examples`` corrupted-feature examples;
      - builds the clean/corrupted batches;
      - prints the baselines and the normalised-metric sanity checks;
      - traces the model and plots the (layer, head) x token-position heatmap.

    Args:
        feature_name: TinySQL feature to corrupt (e.g. qts.DEFTABLENAME).
        num_examples: examples generated per setting (default 10).

    Raises:
        ValueError: if model_num has no trace function here (only 1, 2 and 3 do).
    """
    # Choose the trace function before doing any work. An unsupported model then
    # fails immediately, instead of with an UnboundLocalError after printing.
    if model_num == 1: # TinyStories
        trace_fn = trace_clean_and_corrupt_tinysql
    elif model_num in (2, 3): # Qwen or Llama
        trace_fn = trace_clean_and_corrupt_llm
    else:
        raise ValueError(
            f"run_experiments supports model_num 1 (TinyStories), 2 (Qwen) or 3 (Llama); got {model_num}"
        )

    # The setting lists are module-level globals that are only read here, so no
    # `global` declaration is needed.
    for (use_novel_names, use_synonyms_table, use_synonyms_field) in itertools.product(use_novel_names_list, use_synonyms_table_list, use_synonyms_field_list):
        generator = qts.CorruptFeatureTestGenerator(model_num=model_num, cs_num=cs_num, tokenizer=model.tokenizer, use_novel_names=use_novel_names, use_synonyms_table=use_synonyms_table, use_synonyms_field=use_synonyms_field )

        examples = generator.generate_feature_examples(feature_name, num_examples)
        example0 = examples[0]

        clean_prompt_str = example0.get_alpaca_prompt()
        clean_prompt_tokens = model.tokenizer(clean_prompt_str)["input_ids"]

        clean_tokens, corrupted_tokens, answer_token_indices = get_clean_and_corrupt_data( generator, examples )
        clean_logits, corrupted_logits, clean_baseline, corrupted_baseline = get_clean_and_corrupt_baselines(clean_tokens, corrupted_tokens, answer_token_indices )

        print("Feature name:", feature_name)
        print("Use synonyms table:", use_synonyms_table)
        print("Use synonyms field:", use_synonyms_field)
        print("Clean prompt  :", clean_prompt_str )
        print("Corrupt prompt:", example0.corrupt_BatchItem.get_alpaca_prompt() )
        print("Prompt token index:", example0.prompt_token_index, "of", len(clean_prompt_tokens) )
        print("Answer token index:", example0.answer_token_index)
        print(f"Clean logit diff: {clean_baseline:.4f}")
        print(f"Corrupted logit diff: {corrupted_baseline:.4f}")
        print(f"Clean Baseline is 1: {sql_metric(clean_logits, clean_baseline, corrupted_baseline, answer_token_indices).item():.4f}")
        print(f"Corrupted Baseline is 0: {sql_metric(corrupted_logits, clean_baseline, corrupted_baseline, answer_token_indices).item():.4f}")

        clean_out, corrupted_out, corrupted_grads = trace_fn( clean_tokens, clean_baseline, corrupted_tokens, corrupted_baseline, answer_token_indices )

        # Alternative views (disabled):
        #attention_head_results = get_attention_head_results(clean_out, corrupted_grads, corrupted_out)
        #show_attention_head_results(attention_head_results)

        #token_pos_results = get_token_pos_results(clean_out, corrupted_grads, corrupted_out)
        #show_token_pos_results(token_pos_results)

        show_logit_changes_by_node( example0, feature_name, use_novel_names, use_synonyms_table, use_synonyms_field, clean_out, corrupted_out, corrupted_grads, answer_token_indices )


# Run experiments

## Run experiment - DefTableName

In [None]:
run_experiments(feature_name=qts.DEFTABLENAME)

## Run experiment - EngTableName

In [None]:
run_experiments(feature_name=qts.ENGTABLENAME)

## Run experiment - DefFieldName

In [None]:
run_experiments(feature_name=qts.DEFFIELDNAME)

## Run experiment - EngFieldName

In [None]:
run_experiments(feature_name=qts.ENGFIELDNAME)