# Attribution Patching on GPT2 and the TinySQL M1 (TinyStories), M2 (Qwen) and M3 (Llama) models using nnsight
- Developed on Google Colab using an A100 with 40GB GPU and 80GB system RAM.
- Requires a GITHUB_TOKEN secret to access Martian repository.
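- Known issue (unconfirmed): Qwen (model_num = 2) appears to run out of memory.
- Known issue (unconfirmed): Llama (model_num = 3) appears to run out of memory.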

In [None]:
# Which model to analyse: 0=GPT2, 1=TinyStories, 2=Qwen or 3=Llama
model_num = 1
# Which command set the model was trained on: 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
cs_num = 1

# Part 0: Import libraries
Installs and imports the required libraries. Nothing of interest here; safe to skip.

In [None]:
!pip install -U nnsight

In [None]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

from nnsight import LanguageModel

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

In [None]:
import os
from getpass import getpass
from google.colab import userdata

In [None]:
# Set GitHub credentials
#github_token = getpass('Enter your GitHub PAT: ')  # This will prompt for token securely
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

# Now you can import the package
import QuantaTextToSql as qts

# Tutorial on GPT2
https://nnsight.net/notebooks/tutorials/attribution_patching/


In [None]:
if model_num == 0:
    # Wrap GPT-2 in an nnsight LanguageModel so its internals can be traced.
    model = LanguageModel(
        "openai-community/gpt2",
        device_map="auto",
        dispatch=True,
    )
    # Drop the download/progress output before showing the module tree.
    clear_output()
    print(model)

In [None]:
# Stays None unless one of the model-specific cells fills it in.
answer_token_indices = None
if model_num == 0:
    # Prompts come in adjacent pairs that differ only in which name is the
    # subject of the second clause.
    prompts = [
        "When John and Mary went to the shops, John gave the bag to",
        "When John and Mary went to the shops, Mary gave the bag to",
        "When Tom and James went to the park, James gave the ball to",
        "When Tom and James went to the park, Tom gave the ball to",
        "When Dan and Sid went to the shops, Sid gave an apple to",
        "When Dan and Sid went to the shops, Dan gave an apple to",
        "After Martin and Amy went to the park, Amy gave a drink to",
        "After Martin and Amy went to the park, Martin gave a drink to",
    ]

    # One (correct, incorrect) answer pair per prompt, in prompt order.
    answers = [
        (" Mary", " John"),
        (" John", " Mary"),
        (" Tom", " James"),
        (" James", " Tom"),
        (" Dan", " Sid"),
        (" Sid", " Dan"),
        (" Martin", " Amy"),
        (" Amy", " Martin"),
    ]

    # Clean inputs: the prompts as written.
    clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

    # Corrupted inputs: each prompt is swapped with its pair partner
    # (0<->1, 2<->3, ...); XOR with 1 flips an index to its partner.
    partner_rows = [row ^ 1 for row in range(len(clean_tokens))]
    corrupted_tokens = clean_tokens[partner_rows]

    # First token id of each answer string, laid out as [prompt, (correct, incorrect)].
    answer_token_indices = torch.tensor(
        [
            [model.tokenizer(candidate)["input_ids"][0] for candidate in pair[:2]]
            for pair in answers
        ]
    )

In [None]:
def get_logit_diff(logits, answer_token_indices=None):
    """Return the mean final-position logit difference (correct - incorrect).

    Args:
        logits: Tensor indexed as [batch, position, vocab]; only the last
            position is read.
        answer_token_indices: Integer tensor indexed as [batch, 2] holding the
            (correct, incorrect) token id for each batch row. When omitted,
            the notebook-level ``answer_token_indices`` is looked up at call
            time, so re-running a tokenization cell is picked up without
            redefining this function.

    Returns:
        A scalar tensor: the batch mean of ``logit[correct] - logit[incorrect]``.

    Raises:
        ValueError: If no answer token indices are passed and none are defined
            at notebook level.
    """
    if answer_token_indices is None:
        # Resolve late: a default bound at definition time would freeze
        # whatever value (possibly None) the global had when this cell ran.
        answer_token_indices = globals().get("answer_token_indices")
    if answer_token_indices is None:
        raise ValueError(
            "answer_token_indices is not set; run a tokenization cell or pass it explicitly"
        )

    final_logits = logits[:, -1, :]
    # gather() needs the index tensor on the same device as the logits.
    answer_token_indices = answer_token_indices.to(final_logits.device)
    correct_logits = final_logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = final_logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

In [None]:
if model_num == 0:
    # The baselines only need forward passes. Running them under no_grad stops
    # autograd from building a graph that the saved logits would otherwise
    # keep alive (together with its activations) for the rest of the session.
    with torch.no_grad():
        clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
        corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    # Mean logit difference on the clean prompts.
    CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

    # Mean logit difference on the corrupted prompts, scored against the same
    # (clean) answer pairs.
    CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

In [None]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    """Rescale the logit difference so the baselines map to fixed points.

    The result is 0 when ``logits`` reproduce ``CORRUPTED_BASELINE`` and 1
    when they reproduce ``CLEAN_BASELINE``. Both baselines are read from the
    notebook globals at call time.
    """
    raw_diff = get_logit_diff(logits, answer_token_indices)
    shifted = raw_diff - CORRUPTED_BASELINE
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    return shifted / baseline_gap

In [None]:
if model_num == 0:
    # Sanity check: the metric should read 1 on the clean run and 0 on the
    # corrupted run, by construction.
    clean_score = ioi_metric(clean_logits).item()
    print(f"Clean Baseline is 1: {clean_score:.4f}")
    corrupted_score = ioi_metric(corrupted_logits).item()
    print(f"Corrupted Baseline is 0: {corrupted_score:.4f}")

In [None]:
if model_num == 0:
    # Per-layer activations from the clean run, from the corrupted run, and
    # the gradients taken on the corrupted run. All three lists are filled in
    # layer order and consumed by the attribution cells that follow.
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Clean run: activations only, no gradients are requested here.
            for layer in model.transformer.h:
                # Input to the attention output projection (c_proj).
                # Presumably the concatenated per-head outputs, which is what
                # the later "(head dim)" split relies on - TODO confirm.
                attn_out = layer.attn.c_proj.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Corrupted run: the same activations plus their gradients.
            for layer in model.transformer.h:
                # Same hook point as in the clean run, so the two lists line
                # up element-wise.
                attn_out = layer.attn.c_proj.input
                corrupted_out.append(attn_out.save())
                # Gradient at this hook point; it is produced by the
                # backward() call at the end of this invoke.
                corrupted_grads.append(attn_out.grad.save())

            # Output of the LM head (the logits) for the corrupted run.
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = ioi_metric(logits.cpu())

            # Backward pass from the scalar metric; this is what populates
            # the .grad values requested above.
            value.backward()

In [None]:
if model_num == 0:
    # Take the head count from the model config instead of hard-coding GPT-2
    # small's 12 x 64 layout; einops infers the per-head width from the
    # activation size.
    n_heads = model.config.num_attention_heads

    # One row per layer, one column per attention head.
    patching_results = []

    for corrupted_grad, corrupted, clean in zip(
        corrupted_grads, corrupted_out, clean_out
    ):
        # First-order (attribution patching) estimate at the final token
        # position: gradient on the corrupted run times the clean-minus-
        # corrupted activation difference, summed over batch and per-head
        # features.
        activation_delta = clean.value[:, -1, :] - corrupted.value[:, -1, :]
        residual_attr = einops.reduce(
            corrupted_grad.value[:, -1, :] * activation_delta,
            "batch (head dim) -> head",
            "sum",
            head=n_heads,
        )

        patching_results.append(residual_attr.detach().cpu().numpy())

In [None]:
if model_num == 0:
    # Layer x head heatmap of the attribution scores, with a diverging colour
    # scale centred on zero.
    head_plot_options = dict(
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Attention Heads",
        labels={"x": "Head", "y": "Layer", "color": "Norm. Logit Diff"},
    )
    fig = px.imshow(patching_results, **head_plot_options)

    fig.show()

In [None]:
if model_num == 0:
    # One row per layer, one column per token position.
    patching_results = []

    layer_triples = zip(corrupted_grads, corrupted_out, clean_out)
    for layer, (corrupted_grad, corrupted, clean) in enumerate(layer_triples):
        # Gradient times (clean - corrupted) activation, kept at every
        # position and summed over the batch and feature axes.
        activation_delta = clean.value - corrupted.value
        residual_attr = einops.reduce(
            corrupted_grad.value * activation_delta,
            "batch pos dim -> pos",
            "sum",
        )

        per_position = residual_attr.detach().cpu().numpy()
        patching_results.append(per_position)

In [None]:
if model_num == 0:
    # Layer x token-position heatmap of the attribution scores, with a
    # diverging colour scale centred on zero.
    position_plot_options = dict(
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Token Position",
        labels={"x": "Token Position", "y": "Layer", "color": "Norm. Logit Diff"},
    )
    fig = px.imshow(patching_results, **position_plot_options)

    fig.show()

# Run on TinySQL model

In [None]:
if model_num > 0:
    # Ask the TinySQL helper package where the checkpoint for the selected
    # model / command-set pair lives.
    model_location = qts.sql_interp_model_location(model_num, cs_num)

    # Wrap the checkpoint in an nnsight LanguageModel so it can be traced.
    model = LanguageModel(
        model_location,
        device_map="auto",
    )
    # Drop the download/progress output before showing the module tree.
    clear_output()
    print(model)

In [None]:
if model_num > 0:
    # Three prompts that differ only in the column name being asked about.
    prompts = [
        "### Instruction: What do we have for size in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT",
        "### Instruction: What do we have for age in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT",
        "### Instruction: What do we have for name in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT",
    ]

    # Answers are each formatted as (correct, incorrect). The incorrect answer
    # is the column named in the matching corrupted prompt.
    answers = [
        (" size", " age"),
        (" age", " name"),
        (" name", " size"),
    ]

    # Tokenize the clean inputs. No padding is requested, so the prompts must
    # tokenize to equal lengths.
    clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

    # Corrupted input i is clean prompt i + 1 (wrapping around). Rotating whole
    # rows replaces the former hard-coded "token position 8" swap, which
    # silently did nothing whenever the tokenizer prepended a BOS token or
    # split the instruction differently. Advanced indexing returns a copy, so
    # clean_tokens is left untouched.
    n_prompts = len(prompts)
    corrupted_tokens = clean_tokens[[(i + 1) % n_prompts for i in range(n_prompts)]]

    # First real token id of each answer string. Special tokens are disabled so
    # that index 0 is never a BOS id on tokenizers that prepend one.
    answer_token_indices = torch.tensor(
        [
            [
                model.tokenizer(candidate, add_special_tokens=False)["input_ids"][0]
                for candidate in pair
            ]
            for pair in answers
        ]
    )

In [None]:
if model_num > 0:
    # The baselines only need forward passes. Running them under no_grad stops
    # autograd from building a graph that the saved logits would otherwise
    # keep alive (together with its activations), which matters for the
    # larger models.
    with torch.no_grad():
        clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
        corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

    # Mean logit difference on the clean prompts.
    CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
    print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

    # Mean logit difference on the corrupted prompts, scored against the same
    # (clean) answer pairs.
    CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
    print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

In [None]:
def sql_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    """Rescale the logit difference so the baselines map to fixed points.

    The result is 0 when ``logits`` reproduce ``CORRUPTED_BASELINE`` and 1
    when they reproduce ``CLEAN_BASELINE``. Both baselines are read from the
    notebook globals at call time.
    """
    raw_diff = get_logit_diff(logits, answer_token_indices)
    shifted = raw_diff - CORRUPTED_BASELINE
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    return shifted / baseline_gap

In [None]:
if model_num > 0:
    # Sanity check: the metric should read 1 on the clean run and 0 on the
    # corrupted run, by construction.
    clean_score = sql_metric(clean_logits).item()
    print(f"Clean Baseline is 1: {clean_score:.4f}")
    corrupted_score = sql_metric(corrupted_logits).item()
    print(f"Corrupted Baseline is 0: {corrupted_score:.4f}")

In [None]:
if model_num > 0:
    # Per-layer activations from the clean run, from the corrupted run, and
    # the gradients taken on the corrupted run. All three lists are filled in
    # layer order and consumed by the attribution cell that follows.
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:
        # Using nnsight's tracer.invoke context, we can batch the clean and the
        # corrupted runs into the same tracing context, allowing us to access
        # information generated within each of these runs within one forward pass

        with tracer.invoke(clean_tokens) as invoker_clean:
            # Clean run: activations only, no gradients are requested here.
            # NOTE(review): model.transformer.h assumes a GPT-2/GPT-Neo style
            # module tree - confirm it exists for the Qwen and Llama models.
            for layer in model.transformer.h:
                # NOTE(review): this saves the INPUT of the attention module,
                # whereas the GPT-2 cell saves the input of attn.c_proj.
                # Confirm this is the intended hook point for a per-head split.
                attn_out = layer.attn.input
                clean_out.append(attn_out.save())

        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            # Corrupted run: the same activations plus their gradients.
            for layer in model.transformer.h:
                # Same hook point as in the clean run, so the two lists line
                # up element-wise.
                attn_out = layer.attn.input
                corrupted_out.append(attn_out.save())
                # Gradient at this hook point; it is produced by the
                # backward() call at the end of this invoke.
                corrupted_grads.append(attn_out.grad.save())

            # Output of the LM head (the logits) for the corrupted run.
            logits = model.lm_head.output.save()

            # Our metric uses tensors saved on cpu, so we
            # need to move the logits to cpu.
            value = sql_metric(logits.cpu())

            # Backward pass from the scalar metric; this is what populates
            # the .grad values requested above.
            value.backward()

In [None]:
# One row per layer, one column per attention head (left empty for GPT2).
patching_results = []

if model_num > 0:
    # Take the head count from the model config instead of assuming GPT-2
    # small's 12 x 64 layout, which the TinySQL models need not share; einops
    # infers the per-head width from the activation size.
    n_heads = model.config.num_attention_heads

    for corrupted_grad, corrupted, clean in zip(
        corrupted_grads, corrupted_out, clean_out
    ):
        # First-order (attribution patching) estimate at the final token
        # position: gradient on the corrupted run times the clean-minus-
        # corrupted activation difference, summed over batch and per-head
        # features.
        # NOTE(review): the per-head split is only meaningful if the saved
        # activations are laid out as concatenated heads - confirm the hook
        # point used when tracing.
        activation_delta = clean.value[:, -1, :] - corrupted.value[:, -1, :]
        residual_attr = einops.reduce(
            corrupted_grad.value[:, -1, :] * activation_delta,
            "batch (head dim) -> head",
            "sum",
            head=n_heads,
        )

        patching_results.append(residual_attr.detach().cpu().numpy())

In [None]:
if model_num > 0:
    # patching_results holds one value per (layer, head), so the heatmap is
    # labelled by head. The title and x label previously said
    # "Token Position", carried over from the per-position plot.
    fig = px.imshow(
        patching_results,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Attribution Patching Over Attention Heads",
        labels={"x": "Head", "y": "Layer", "color": "Norm. Logit Diff"},
    )

    fig.show()