# Attribution Patching

📗 This tutorial is adapted from Neel Nanda’s [blog post](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching).

Activation patching is a method to determine how model components influence model computations (see our activation patching tutorial for more information). Although activation patching is a useful tool for circuit identification, it requires a separate forward pass through the model for each patched activation, making it time- and resource-intensive.

**Attribution patching** uses gradients to take a linear approximation to activation patching and can be done simultaneously in two forward and one backward pass, making it much more scalable to large models.

You can find a colab version of the tutorial [here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/docs/source/notebooks/tutorials/attribution_patching.ipynb) or Neel’s version [here](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Attribution_Patching_Demo.ipynb).

Read more about an application of Attribution Patching in [Attribution Patching Outperforms Automated Circuit Discovery](https://arxiv.org/abs/2310.10348). 📙

## Setup


If you are using Colab or haven't yet installed NNsight, install the package:
```
!pip install -U nnsight
```

In [72]:
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight



Import libraries

In [73]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab" if is_colab else "plotly_mimetype+notebook_connected+notebook"


from nnsight import LanguageModel

In [74]:
import nnsight
print(nnsight.__version__)

0.4.3


## 1️⃣ Indirect Object Identification (IOI) Patching


Indirect object identification (IOI) is the ability to infer the correct indirect object in a sentence, allowing one to complete sentences like "John and Mary went to the shops, John gave a bag to" with the correct answer " Mary". Understanding how language models like GPT-2 perform linguistic tasks like IOI helps us gain insights into their internal mechanisms and decision-making processes.

Here, we apply the [IOI task](https://arxiv.org/abs/2211.00593) to explore how GPT-2 small is performing IOI with attribution patching.

*📚 Note: For more detail on the IOI task, check out the [ARENA walkthrough](https://arena3-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification).*

In [75]:
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
clear_output()
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)


Looking at the model architecture, we can see there are 12 layers, each with 12 GPT-2 Blocks. We will use attribution patching to approximate the contribution of each layer and each attention head for the IOI task.

We next define 8 IOI prompts, with each prompt having one related corrupted prompt variation (i.e., the indirect object is swapped out).

In [96]:
prompts = [
    "When John and Mary went to the shops, John gave the bag to",
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the park, James gave the ball to",
    "When Tom and James went to the park, Tom gave the ball to",
    "When Dan and Sid went to the shops, Sid gave an apple to",
    "When Dan and Sid went to the shops, Dan gave an apple to",
    "After Martin and Amy went to the park, Amy gave a drink to",
    "After Martin and Amy went to the park, Martin gave a drink to",
]

# Answers are each formatted as (correct, incorrect):
answers = [
    (" Mary", " John"),
    (" John", " Mary"),
    (" Tom", " James"),
    (" James", " Tom"),
    (" Dan", " Sid"),
    (" Sid", " Dan"),
    (" Martin", " Amy"),
    (" Amy", " Martin"),
]

# Tokenize clean and corrupted inputs:
clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]
# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt:
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)


In [103]:
prompts = [
    "If a company's revenue grows and expenses stay flat, profit will likely",
    "If a company's revenue shrinks and expenses stay flat, profit will likely",
    "When inflation is high and rates are low, real returns will likely",
    "When inflation is low and rates are high, real returns will likely",
    "If a firm's debt is high and cash is low, its risk will likely",
    "If a firm's debt is low and cash is high, its risk will likely",
    "After a firm cuts costs and raises prices, margins will likely",
    "After a firm raises costs and cuts prices, margins will likely",
]

# Answers are each formatted as (correct, incorrect):
answers = [
    ("increase", "decrease"),
    ("decrease", "increase"),
    ("decrease", "increase"),
    ("increase", "decrease"),
    ("increase", "decrease"),
    ("decrease", "increase"),
    ("increase", "decrease"),
    ("decrease", "increase"),
]

# Tokenize clean and corrupted inputs:
clean_tokens = model.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)["input_ids"]
# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt:
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)

In [110]:
prompts = [
    "If a company's revenue grows steadily, its stock price will likely", # Tests understanding of revenue growth and stock market response
    "A company with high debt and low cash flow is considered financially", # Tests understanding of financial risk and stability
    "During periods of economic recession, consumer spending typically", # Tests understanding of macroeconomic factors and consumer behavior
    "When interest rates rise, bond prices generally tend to", # Tests understanding of interest rate and bond market dynamics
    "A company with a diversified portfolio of investments is less", # Tests understanding of investment diversification and risk reduction
    "If a company consistently reports strong earnings, its investors are likely to be", # Tests understanding of financial performance and investor sentiment
    "When inflation increases, the purchasing power of money", # Tests understanding of inflation and its impact on currency value
    "A company with a strong brand reputation often commands a higher", # Tests understanding of intangible assets and brand value
]

# Answers are each formatted as (correct, incorrect):
answers = [
    ("increase", "decrease"),
    ("risky", "stable"),
    ("decrease", "increase"),
    ("fall", "rise"),
    ("risky", "safe"),
    ("satisfied", "dissatisfied"),
    ("decreases", "increases"),
    ("price", "discount"),
]

# Tokenize clean and corrupted inputs:
clean_tokens = model.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)["input_ids"]
# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt:
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)

Meaning for Attribution Patching

"If a company's revenue grows steadily, its stock price will likely..."

Attribution Patching Focus: This prompt aims to identify attention heads that are sensitive to the relationship between revenue growth and stock market performance. Patching these heads could reveal how the model connects fundamental financial indicators to market expectations.
"A company with high debt and low cash flow is considered financially..."

Attribution Patching Focus: This prompt targets attention heads that assess financial risk and stability. Patching these heads could shed light on how the model evaluates the financial health of a company based on its debt and cash flow situation.
"During periods of economic recession, consumer spending typically..."

Attribution Patching Focus: This prompt focuses on attention heads that capture macroeconomic influences on consumer behavior. Patching these heads could reveal how the model understands the relationship between economic conditions and consumer spending patterns.
"When interest rates rise, bond prices generally tend to..."

Attribution Patching Focus: This prompt aims to identify attention heads that understand the inverse relationship between interest rates and bond prices. Patching these heads could provide insights into how the model processes financial market dynamics.
"A company with a diversified portfolio of investments is less..."

Attribution Patching Focus: This prompt targets attention heads that recognize the benefits of investment diversification. Patching these heads could reveal how the model evaluates risk reduction strategies in investment portfolios.
"If a company consistently reports strong earnings, its investors are likely to be..."

Attribution Patching Focus: This prompt focuses on attention heads that connect financial performance with investor sentiment. Patching these heads could provide insights into how the model understands the relationship between a company's earnings and investor satisfaction.
"When inflation increases, the purchasing power of money..."

Attribution Patching Focus: This prompt aims to identify attention heads that understand the impact of inflation on currency value. Patching these heads could reveal how the model processes macroeconomic factors and their effects on purchasing power.
"A company with a strong brand reputation often commands a higher..."

Attribution Patching Focus: This prompt targets attention heads that recognize the value of intangible assets like brand reputation. Patching these heads could shed light on how the model evaluates brand strength and its influence on pricing or market position.
Overall, by using these distinct financial prompts for attribution patching, we can gain a more nuanced understanding of which attention heads are responsible for specific aspects of financial reasoning within the language model. This information can be used to improve the model's performance, explainability, and its ability to handle diverse financial scenarios.

Next, we create a function to calculate the mean logit difference for the correct vs incorrect answer tokens.

In [111]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

We then calculate the logit difference for both the clean and the corrupted baselines.



In [112]:
clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

Clean logit diff: 0.8367
Corrupted logit diff: 0.2636


Now let's define an `ioi_metric` function to evaluate patched IOI changes normalized to our clean and corruped baselines.

In [113]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


## 2️⃣ Attribution Patching Over Components


Attribution patching is a technique that uses gradients to take a linear approximation to activation patching. The key assumption is that the corrupted run is a locally linear function of its activations.

We thus take the gradient of the patch metric (`ioi_metric`) with respect to its activations, where we consider a patch of activations to be applying `corrupted_x` to `corrupted_x + (clean_x - corrupted_x)`. Then, we compute the patch metric's change: `(corrupted_grad_x * (clean_x - corrupted_x)).sum()`. All we need to do is take a backwards pass on the corrupted prompt with respect to the patch metric and cache all gradients with respect to the activations.

Let’s see how this breaks down in NNsight!

**A note on c_proj:** *Most HuggingFace models don’t have nice individual attention head representations to hook. Instead, we have the linear layer `c_proj` which implicitly combines the “projection per attention head” and the “sum over attention head” operations. See [this snippet](https://arena3-chapter1-transformer-interp.streamlit.app/~/+/[1.4.2]_Function_Vectors_&_Model_Steering) from ARENA for more information.*

TL;DR: We will use the input to `c_proj` for causal interventions on a particular attention head.

In [114]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:
# Using nnsight's tracer.invoke context, we can batch the clean and the
# corrupted runs into the same tracing context, allowing us to access
# information generated within each of these runs within one forward pass

    with tracer.invoke(clean_tokens) as invoker_clean:
        # Gather each layer's attention
        for layer in model.transformer.h:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in model.transformer.h:
            # Get corrupted attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            corrupted_out.append(attn_out.save())
            # save corrupted gradients for attribution patching
            corrupted_grads.append(attn_out.grad.save())

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = model.lm_head.output.save()

        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        value.backward()

Next, for a given activation we compute `(corrupted_grad_act * (clean_act - corrupted_act)).sum()`. We use `einops.reduce` to rearrange and sum activations over the correct dimension. In this case, we want to estimate the effect of specific attention heads, so we sum over heads rather than token position.

In [115]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 12,
        dim = 64,
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )

In [116]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Attention Heads",
    labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

Here, we see that the early layer attention heads may not be important for IOI.

## 3️⃣ Attribution Patching Over Position


One benefit of attribution patching is efficiency. Activation patching requires a separate forward pass per activation patched while every attribution patch can be done simultaneously in two forward passes and one backward pass. Attribution patching makes patching much more scalable to large models and can serve as a useful heuristic to find the interesting activations to patch.

In practice, whie this approximation is decent when patching in “small” activations like head outputs, performance decreases significantly when patching in “big” activations like those found in the residual stream.

Using the same outputs we cached above, we can get the individual contributions at each token position simply by summing across token positions. Although this is messy, it's a quick approximation of the attention mechanism's contribution across token position.

*Note: in our specific case here, patching across positions does NOT reflect the entire residual stream, just the post-attention output (i.e., excludes MLPs).*

In [84]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )

In [85]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position",
    labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

This result looks similar to our previous result using activation patching but is much less precise, as expected!

# Remote Attribution Patching

Now that we know how to run an attribution patching experiment in `nnsight`, let's go over how you can use NDIF's publicly-hosted models to further scale your research!

We're going to run the same experiment, but now using Llama 3.1 8B. Completing this section of the tutorial will require you to [configure NNsight for remote execution](https://nnsight.net/notebooks/features/remote_execution/) if you haven't already.

## Remote Setup

In [86]:
from nnsight import CONFIG

if is_colab:
    # include your HuggingFace Token and NNsight API key on Colab secrets
    from google.colab import userdata
    NDIF_API = userdata.get('NDIF_API')
    HF_TOKEN = userdata.get('HF_TOKEN')

    CONFIG.set_default_api_key(NDIF_API)
    !huggingface-cli login -token HF_TOKEN

clear_output()

In [87]:
import os
os.environ["HF_TOKEN"] = HF_TOKEN

In [88]:
import torch
import nnsight
from nnsight import LanguageModel

In [89]:
from google.colab import userdata
HF_TOKEN = userdata.get('HF_TOKEN')
NDIF_API = userdata.get('NDIF_API')

Next, let's load the Llama 3.1 8B model, once again using NNsight's `LanguageModel` wrapper. Because we'll be running the model on NDIF's remote servers, no need to specify a `device_map`!

In [90]:
# Load model
llm = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
#llm = LanguageModel("mistralai/Mistral-Small-24B-Instruct-2501")
print(llm)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

## IOI Task Setup

We've already defined some prompts in the above tutorial, but we'll have to re-tokenize them for Llama 8B.

In [91]:
# Tokenize clean and corrupted inputs:
clean_tokens = llm.tokenizer(prompts, return_tensors="pt")["input_ids"]
# The associated corrupted input is the prompt after the current clean prompt
# for even indices, or the prompt prior to the current clean prompt for odd indices
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

# Tokenize answers for each prompt:
answer_token_indices = torch.tensor(
    [
        [llm.tokenizer(answers[i][j])["input_ids"][1] for j in range(2)]
        for i in range(len(answers))
    ]
)

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [None]:
from nnsight import CONFIG

CONFIG.set_default_api_key("546253536da74e51a1ccdfe887eafcea")

Next, we'll establish clean & corrupted baselines for our IOI metric, using the model's clean and corrupted logits and the `get_logit_diff` function defined earlier.

In [None]:
clean_logits = llm.trace(clean_tokens, trace=False, remote=True)
corrupted_logits = llm.trace(corrupted_tokens, trace=False, remote=True)

clean_logits = clean_logits['logits']
corrupted_logits = corrupted_logits['logits']

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"\n\nClean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

We've also already defined our `ioi_metric` function. Let's plug in our logit values.

In [None]:
print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

## Remote Attribution Patching

Great! We have some baselines. Now, let's run the attribution patching pipeline on Llama 8B. We can't copy the code exactly, because Llama 8B has a different model structure than GPT-2, but we're following the same steps: a clean run and a corrupted run as invokes during one tracing context.

In [None]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with llm.trace(remote = True) as tracer:
# Using nnsight's tracer.invoke context, we can batch the clean and the
# corrupted runs into the same tracing context, allowing us to access
# information generated within each of these runs within one forward pass

    with tracer.invoke(clean_tokens) as invoker_clean:
      # need to set requires grad to true for remote
        llm.model.layers[0].self_attn.o_proj.input.requires_grad = True
        # Gather each layer's attention
        for layer in llm.model.layers:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.self_attn.o_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in llm.model.layers:
            # Get corrupted attention output for this layer
            # across all attention heads
            attn_out = layer.self_attn.o_proj.input
            corrupted_out.append(attn_out.save())
            # save corrupted gradients for attribution patching
            corrupted_grads.append(attn_out.grad.save())

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = llm.lm_head.output.save()

        # Our IOI metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        value.backward()

Awesome! Let's take a look at attention head contributions across layers.

In [None]:
# format data for plotting across attention heads
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 32,
        dim = 128,
    )

    patching_results.append(
        (residual_attr.float()).detach().numpy()
    )

In [None]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Attention Heads",
    labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

Next, let's check out the contribution of the residual stream over token position across layers.

In [None]:
# format data for plotting across input tokens
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        (residual_attr.detach().cpu().float()).numpy()
    )

In [None]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position",
    labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()

Great! We've now successfully performed an attribution patching experiment on GPT-2 and Llama 8b.