# Attribution Patching on GPT2, M1, M2 and M3 models using nnsight
- Developed on Google Colab using an A100 with 40GB GPU and 80GB system RAM.
- Requires a GITHUB_TOKEN secret to access Martian repository.
- Qwen runs out of memory??
- Llama runs out of memory??

In [1]:
model_num = 2   # 0=GPT2, 1=TinyStories, 2=Qwen or 3=Llama
cs_num = 1      # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3

# Part 0: Import libraries
Imports standard libraries. Do not read.

In [2]:
!pip install -U nnsight

Collecting nnsight
  Downloading nnsight-0.3.7-py3-none-any.whl.metadata (15 kB)
Collecting python-socketio[client] (from nnsight)
  Downloading python_socketio-5.11.4-py3-none-any.whl.metadata (3.2 kB)
Collecting bidict>=0.21.0 (from python-socketio[client]->nnsight)
  Downloading bidict-0.23.1-py3-none-any.whl.metadata (8.7 kB)
Collecting python-engineio>=4.8.0 (from python-socketio[client]->nnsight)
  Downloading python_engineio-4.10.1-py3-none-any.whl.metadata (2.2 kB)
Collecting simple-websocket>=0.10.0 (from python-engineio>=4.8.0->python-socketio[client]->nnsight)
  Downloading simple_websocket-1.1.0-py3-none-any.whl.metadata (1.5 kB)
Collecting wsproto (from simple-websocket>=0.10.0->python-engineio>=4.8.0->python-socketio[client]->nnsight)
  Downloading wsproto-1.2.0-py3-none-any.whl.metadata (5.6 kB)
Downloading nnsight-0.3.7-py3-none-any.whl (3.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m37.4 MB/s[0m eta [36m0:00:00[0m
[?25hDow

In [3]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

from nnsight import LanguageModel

In [4]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

Collecting git+https://github.com/PhilipQuirke/quanta_mech_interp.git
  Cloning https://github.com/PhilipQuirke/quanta_mech_interp.git to /tmp/pip-req-build-2f542tux
  Running command git clone --filter=blob:none --quiet https://github.com/PhilipQuirke/quanta_mech_interp.git /tmp/pip-req-build-2f542tux
  Resolved https://github.com/PhilipQuirke/quanta_mech_interp.git to commit 24dccfc92b6978f7f186a0e4bfe189525b745457
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting scikit-optimize (from QuantaMechInterp==1.0)
  Downloading scikit_optimize-0.10.2-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting transformer_lens (from QuantaMechInterp==1.0)
  Downloading transformer_lens-2.9.1-py3-none-any.whl.metadata (12 kB)
Collecting torchtyping>=0.1.4 (from QuantaMechInterp==1.0)
  Downloading torchtyping-0.1.5-py3-none-any.whl.metadata (9.5 kB)
INFO: pip is looking a

In [5]:
import os
from getpass import getpass
from google.colab import userdata

In [6]:
# Set GitHub credentials
#github_token = getpass('Enter your GitHub PAT: ')  # This will prompt for token securely
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

# Now you can import the package
import QuantaTextToSql as qts

Collecting git+https://****@github.com/withmartian/quanta_text_to_sql.git
  Cloning https://****@github.com/withmartian/quanta_text_to_sql.git to /tmp/pip-req-build-ri85v3vt
  Running command git clone --filter=blob:none --quiet 'https://****@github.com/withmartian/quanta_text_to_sql.git' /tmp/pip-req-build-ri85v3vt
  Resolved https://****@github.com/withmartian/quanta_text_to_sql.git to commit fd3989a1604ef54b72573029443edb9227713909
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: QuantaTextToSql
  Building wheel for QuantaTextToSql (pyproject.toml) ... [?25l[?25hdone
  Created wheel for QuantaTextToSql: filename=QuantaTextToSql-1.0-py3-none-any.whl size=37274 sha256=6d8f7f11d26e286b5c19781c3766abbaf5faf3143093f422c9483aa20dec0219
  Stored in directory: /tmp/pip-ephem-wheel-cache-pghe082q/wheels/ec/2e/0b/f2473c76d3035232d5

# Tutorial on GPT2
https://nnsight.net/notebooks/tutorials/attribution_patching/


In [7]:
if model_num == 0:
    model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
    clear_output()
    print(model)

In [10]:
answer_token_indices = None
if model_num == 0:
    prompts = [
        "When John and Mary went to the shops, John gave the bag to",
        "When John and Mary went to the shops, Mary gave the bag to",
        "When Tom and James went to the park, James gave the ball to",
        "When Tom and James went to the park, Tom gave the ball to",
        "When Dan and Sid went to the shops, Sid gave an apple to",
        "When Dan and Sid went to the shops, Dan gave an apple to",
        "After Martin and Amy went to the park, Amy gave a drink to",
        "After Martin and Amy went to the park, Martin gave a drink to",
    ]

    # Answers are each formatted as (correct, incorrect):
    answers = [
        (" Mary", " John"),
        (" John", " Mary"),
        (" Tom", " James"),
        (" James", " Tom"),
        (" Dan", " Sid"),
        (" Sid", " Dan"),
        (" Martin", " Amy"),
        (" Amy", " Martin"),
    ]

    # Tokenize clean and corrupted inputs:
    clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]
    # The associated corrupted input is the prompt after the current clean prompt
    # for even indices, or the prompt prior to the current clean prompt for odd indices
    corrupted_tokens = clean_tokens[
        [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
    ]

    # Tokenize answers for each prompt:
    answer_token_indices = torch.tensor(
        [
            [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
            for i in range(len(answers))
        ]
    )

In [11]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

In [12]:
if model_num == 0:
    clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
    corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

    CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

In [13]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

In [14]:
if model_num == 0:
    print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
    print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

In [15]:
if model_num == 0:
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each layer's attention
            for layer in model.transformer.h:
                # Get clean attention output for this layer
                # across all attention heads
                attn_out = layer.attn.c_proj.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each layer's attention and gradients
            for layer in model.transformer.h:
                # Get corrupted attention output for this layer
                # across all attention heads
                attn_out = layer.attn.c_proj.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = ioi_metric(logits.cpu())

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

In [16]:
if model_num == 0:
    patching_results = []

    for corrupted_grad, corrupted, clean, layer in zip(
        corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
    ):

        residual_attr = einops.reduce(
            corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
            "batch (head dim) -> head",
            "sum",
            head = 12,
            dim = 64,
        )

        patching_results.append(
            residual_attr.detach().cpu().numpy()
        )

In [17]:
if model_num == 0:
    fig = px.imshow(
        patching_results,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Attention Heads",
        labels={"x": "Head", "y": "Layer","color":"Norm. Logit Diff"},

    )

    fig.show()

In [18]:
if model_num == 0:
    patching_results = []

    for corrupted_grad, corrupted, clean, layer in zip(
        corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
    ):

        residual_attr = einops.reduce(
            corrupted_grad.value * (clean.value - corrupted.value),
            "batch pos dim -> pos",
            "sum",
        )

        patching_results.append(
            residual_attr.detach().cpu().numpy()
        )

In [19]:
if model_num == 0:
    fig = px.imshow(
        patching_results,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Token Position",
        labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},
    )

    fig.show()

# Run on TinySQL model

In [20]:
# Load the tokenizer and trained model for model 1, 2, or 3 and command set 0 (base model), 1, 2, or 3
def sql_interp_model_location( model_num : int, cs_num : int):
    if model_num == 1:
        if cs_num == 0:
            return "roneneldan/TinyStories-Instruct-2Layers-33M"

        elif cs_num == 1:
            return "withmartian/sql_interp_bm1_cs1_experiment_1.1"

        elif cs_num == 2:
            return "withmartian/sql_interp_bm1_cs2_experiment_2.3"

        elif cs_num == 3:
            return "withmartian/sql_interp_bm1_cs3_experiment_3.3"

    if model_num == 2:
        if cs_num == 0:
            return "huggingface.co/Qwen/Qwen2.5-0.5B-Instruct"

        elif cs_num == 1:
            return "withmartian/sql_interp_bm2_cs1_experiment_4.1"

        elif cs_num == 2:
            return "withmartian/sql_interp_bm2_cs2_experiment_5.1"

        elif cs_num == 3:
            return "withmartian/sql_interp_bm2_cs3_experiment_6.1"

    if model_num == 3:
        if cs_num == 0:
            return "huggingface.co/meta-llama/Llama-3.2-1B-Instruct"

        elif cs_num == 1:
            return "withmartian/sql_interp_bm3_cs1_experiment_7.1"

        elif cs_num == 2:
            return "withmartian/sql_interp_bm3_cs2_experiment_8.1"

        elif cs_num == 3:
            return "withmartian/sql_interp_bm3_cs3_experiment_9.1"

    return ""


In [21]:
if model_num > 0:
    model_location = sql_interp_model_location(model_num, cs_num)

    model = LanguageModel(model_location, device_map="auto")
    clear_output()
    print(model)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50258, 1024)
    (wpe): Embedding(2048, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-1): 2 x GPTNeoBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
          (c_proj): L

In [22]:
if model_num > 0:
    prompts = [
        "### Instruction: What do we have for size in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT",
        "### Instruction: What do we have for age in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT",
        "### Instruction: What do we have for name in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT",
    ]

    # Answers are each formatted as (correct, incorrect):
    answers = [
        (" size", " age"),
        (" age", " name"),
        (" name", " size"),
    ]

    # Tokenize clean and corrupted inputs:
    clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]
    corrupted_tokens = clean_tokens.clone()
    corrupted_tokens[0,8] = clean_tokens[1,8]
    corrupted_tokens[1,8] = clean_tokens[2,8]
    corrupted_tokens[2,8] = clean_tokens[0,8]

    # Tokenize answers for each prompt:
    answer_token_indices = torch.tensor(
        [
            [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
            for i in range(len(answers))
        ]
    )

In [23]:
if model_num > 0:
    clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
    corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

    CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


model.safetensors:   0%|          | 0.00/315M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

Clean logit diff: 0.3169
Corrupted logit diff: -0.2237


In [24]:
def sql_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

In [25]:
if model_num > 0:
    print(f"Clean Baseline is 1: {sql_metric(clean_logits).item():.4f}")
    print(f"Corrupted Baseline is 0: {sql_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [26]:
if model_num > 0:
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Gather each layer's attention
            for layer in model.transformer.h:
                # Get clean attention output for this layer
                # across all attention heads
                attn_out = layer.attn.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Gather each layer's attention and gradients
            for layer in model.transformer.h:
                # Get corrupted attention output for this layer
                # across all attention heads
                attn_out = layer.attn.input
                corrupted_out.append(attn_out.save())
                # save corrupted gradients for attribution patching
                corrupted_grads.append(attn_out.grad.save())

            # Let's get the logits for the model's output
            # for the corrupted run
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = sql_metric(logits.cpu())

            # We also need to run a backwards pass to
            # update gradient values
            value.backward()

In [27]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        head = 12,
        dim = 64,
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )

EinopsError:  Error while processing sum-reduction pattern "batch (head dim) -> head".
 Input tensor shape: torch.Size([3, 1024]). Additional info: {'head': 12, 'dim': 64}.
 Shape mismatch, 1024 != 768

In [None]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Attribution Patching Over Token Position",
    labels={"x": "Token Position", "y": "Layer","color":"Norm. Logit Diff"},

)

fig.show()