# Activation Patching on GPT2, M1, M2 and M3 models using nnsight
- Developed on Google Colab using an A100 GPU (40 GB) with 80 GB of system RAM.
- Requires a `GITHUB_TOKEN` Colab secret to access the private Martian repository.
- Qwen (model_num = 2) runs out of GPU memory.
- Llama (model_num = 3) is untested.

In [21]:
import os

model_num = 2   # 0=GPT2, 1=TinyStories, 2=Qwen or 3=Llama
cs_num = 1      # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3

# Fail fast on a typo: every later cell is guarded by model_num, so an
# out-of-range value would otherwise silently skip all of them.
if model_num not in (0, 1, 2, 3):
    raise ValueError(f"model_num must be 0, 1, 2 or 3, got {model_num!r}")
if cs_num not in (0, 1, 2, 3):
    raise ValueError(f"cs_num must be 0, 1, 2 or 3, got {cs_num!r}")

# PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA caching allocator is first
# initialised, i.e. at the first GPU allocation. It must therefore be set here,
# before any model is loaded onto the GPU; setting it after the model is loaded
# has no effect. Expandable segments reduce fragmentation for the large
# batched traces used with the bigger models.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Part 0: Import libraries
Installs and imports the required libraries. No need to read.

In [22]:
!pip install -U nnsight
from IPython.display import clear_output



In [23]:
import nnsight
from nnsight import CONFIG

In [24]:
import os
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

Collecting git+https://github.com/PhilipQuirke/quanta_mech_interp.git
  Cloning https://github.com/PhilipQuirke/quanta_mech_interp.git to /tmp/pip-req-build-h17fa4mx
  Running command git clone --filter=blob:none --quiet https://github.com/PhilipQuirke/quanta_mech_interp.git /tmp/pip-req-build-h17fa4mx
  Resolved https://github.com/PhilipQuirke/quanta_mech_interp.git to commit 24dccfc92b6978f7f186a0e4bfe189525b745457


In [None]:
from getpass import getpass
from google.colab import userdata

In [None]:
# Set GitHub credentials
#github_token = getpass('Enter your GitHub PAT: ')  # This will prompt for token securely
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

# Now you can import the package
import QuantaTextToSql as qts

# Shared Plotting Function

In [None]:
def plot_patching_results(model,
                          patching_results,
                          x_labels,
                          plot_title="Normalized Logit Difference After Patching Residual Stream"):
    """Build a layer-by-position heatmap of activation-patching scores.

    Args:
        model: Unused by this function; kept so existing calls keep working.
        patching_results: Nested list of saved nnsight proxies, as built by the
            patching cells (one row per layer, one entry per token position),
            each holding a scalar score.
        x_labels: Labels for the x-axis (one per token position).
        plot_title: Title of the figure.

    Returns:
        The plotly figure; the caller decides when to show it.
    """
    # Swap every saved Proxy in the nested structure for its Python scalar.
    scores = util.apply(patching_results, lambda proxy: proxy.value.item(), Proxy)

    axis_labels = {"x": "Position", "y": "Layer", "color": "Norm. Logit Diff"}

    # Diverging colour scale centred on 0 so positive and negative effects
    # are visually distinct.
    return px.imshow(
        scores,
        x=x_labels,
        labels=axis_labels,
        title=plot_title,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
    )

# Tutorial on GPT2
Based on https://nnsight.net/notebooks/tutorials/activation_patching/

Runs when model_num == 0


In [None]:
# Load GPT-2 small through nnsight (tutorial path, model_num == 0 only).
if model_num == 0:
    model = LanguageModel("openai-community/gpt2", device_map="auto")
    clear_output()  # clear whatever output the cell produced while loading
    print(model)    # show the module tree (model.transformer.h is used below)

In [None]:
# IOI (indirect object identification) prompt pair. In the clean prompt the
# natural completion is " John"; in the corrupted prompt the giver is swapped
# to John, which pushes the completion towards " Mary".
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [None]:
if model_num == 0:
    # Token ids of the expected (" John") and distractor (" Mary") answers.
    # add_special_tokens=False keeps any BOS/special token out of input_ids,
    # so [0] is always the first token of the answer text itself.
    correct_ids = model.tokenizer(" John", add_special_tokens=False)["input_ids"]  # includes a space
    incorrect_ids = model.tokenizer(" Mary", add_special_tokens=False)["input_ids"]  # includes a space
    correct_index = correct_ids[0]
    incorrect_index = incorrect_ids[0]

    print(f"' John': {correct_index}")
    print(f"' Mary': {incorrect_index}")

    # The logit difference only compares the first token of each answer.
    for answer, ids in ((" John", correct_ids), (" Mary", incorrect_ids)):
        if len(ids) != 1:
            print(f"WARNING: {answer!r} is {len(ids)} tokens; only the first ({ids[0]}) is compared.")

    # Patching copies clean activations into the corrupted run position by
    # position, so both prompts must tokenize to the same length.
    n_clean = len(model.tokenizer(clean_prompt)["input_ids"])
    n_corrupted = len(model.tokenizer(corrupted_prompt)["input_ids"])
    if n_clean != n_corrupted:
        print(f"WARNING: clean prompt has {n_clean} tokens but corrupted prompt has {n_corrupted}; per-position patching will be misaligned.")

In [None]:
# Activation patching on GPT-2 (model_num == 0).
# For every (layer, position) pair, the corrupted prompt is re-run with that
# layer's output at that position overwritten by the clean run's value. The
# recorded score is the normalised logit difference
#   (patched - corrupted) / (clean - corrupted),
# i.e. 0 = no recovery (corrupted behaviour), 1 = full recovery (clean behaviour).
# NOTE(review): every invoke below lives inside one model.trace(); nnsight
# batches the invokes of a trace into a single forward pass, so memory grows
# with N_LAYERS * len(clean_tokens).
# NOTE(review): token_idx indexes both prompts identically, which assumes the
# clean and corrupted prompts tokenize to the same length.
if model_num == 0:
    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            # Token ids of the clean prompt; their count sets how many positions are patched.
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        # patching_results[layer][position] -> saved normalised score.
        patching_results = []

        # Iterate through all the layers
        for layer_idx in range(len(model.transformer.h)):
            _patching_results = []

            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # NOTE(review): divides by zero if the clean and corrupted diffs are equal.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    _patching_results.append(patched_result.save())

            patching_results.append(_patching_results)

In [None]:
if model_num == 0:
    # Baseline logit differences of the unpatched clean and corrupted runs.
    print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
    print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

    # x-axis labels of the form "<decoded token>_<position>".
    clean_decoded_tokens = list(map(model.tokenizer.decode, clean_tokens))
    token_labels = [f"{piece}_{pos}" for pos, piece in enumerate(clean_decoded_tokens)]

    fig = plot_patching_results(
        model,
        patching_results,
        token_labels,
        "Patching GPT-2-small Residual Stream on IOI task",
    )
    fig.show()

# Run on a TinySQL model (model_num > 0)

In [None]:
# Load the TinySQL model selected by model_num and cs_num (model_num > 0 only).
if model_num > 0:

    # Look up the model identifier via the QuantaTextToSql helper package.
    model_location = qts.sql_interp_model_location(model_num, cs_num)

    model = LanguageModel(model_location, device_map="auto")
    clear_output()  # clear whatever output the cell produced while loading

    print(model)  # show the module tree, to check which layer attributes exist

In [None]:
if model_num > 0:
    # Text-to-SQL prompt pair that differs only in the requested field:
    # "size" is a column of the profiles table, "elephants" is not.
    clean_prompt = (
        "### Instruction: What do we have for size in profiles? "
        "### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) "
        "### Response: SELECT"
    )
    corrupted_prompt = (
        "### Instruction: What do we have for elephants in profiles? "
        "### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) "
        "### Response: SELECT"
    )

In [None]:
if model_num > 0:
    # Token ids of the expected (" size") and distractor (" elephants") answers.
    # Some tokenizers (e.g. Llama's) prepend a BOS token by default, in which
    # case input_ids[0] would be BOS instead of the answer. add_special_tokens=False
    # guarantees [0] is the first token of the answer text itself.
    correct_ids = model.tokenizer(" size", add_special_tokens=False)["input_ids"]  # includes a space
    incorrect_ids = model.tokenizer(" elephants", add_special_tokens=False)["input_ids"]  # includes a space
    correct_index = correct_ids[0]
    incorrect_index = incorrect_ids[0]

    print(f"' size': {correct_index}")
    print(f"' elephants': {incorrect_index}")

    # The logit difference only compares the first token of each answer.
    for answer, ids in ((" size", correct_ids), (" elephants", incorrect_ids)):
        if len(ids) != 1:
            print(f"WARNING: {answer!r} is {len(ids)} tokens; only the first ({ids[0]}) is compared.")

    # Patching copies clean activations into the corrupted run position by
    # position, so both prompts must tokenize to the same length.
    n_clean = len(model.tokenizer(clean_prompt)["input_ids"])
    n_corrupted = len(model.tokenizer(corrupted_prompt)["input_ids"])
    if n_clean != n_corrupted:
        print(f"WARNING: clean prompt has {n_clean} tokens but corrupted prompt has {n_corrupted}; per-position patching will be misaligned.")

In [None]:
# Activation patching on the TinyStories-based model (model_num == 1), whose
# transformer blocks are reached via model.transformer.h.
# For every (layer, position) pair, the corrupted prompt is re-run with that
# layer's output at that position overwritten by the clean run's value. The
# recorded score is the normalised logit difference
#   (patched - corrupted) / (clean - corrupted),
# i.e. 0 = no recovery (corrupted behaviour), 1 = full recovery (clean behaviour).
# NOTE(review): every invoke below lives inside one model.trace(); nnsight
# batches the invokes of a trace into a single forward pass, so memory grows
# with N_LAYERS * len(clean_tokens).
# NOTE(review): token_idx indexes both prompts identically, which assumes the
# clean and corrupted prompts tokenize to the same length.
if model_num == 1:
    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            # Token ids of the clean prompt; their count sets how many positions are patched.
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        # patching_results[layer][position] -> saved normalised score.
        patching_results = []

        # Iterate through all the layers
        for layer_idx in range(len(model.transformer.h)):
            _patching_results = []

            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # NOTE(review): divides by zero if the clean and corrupted diffs are equal.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    _patching_results.append(patched_result.save())

            patching_results.append(_patching_results)

In [None]:
# Ask PyTorch's CUDA caching allocator to use expandable segments, which reduces
# memory fragmentation.
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA allocator is
# first initialised. The model was already loaded (device_map="auto") in an
# earlier cell, so if it is on the GPU this assignment most likely has no effect.
# Set it before any CUDA allocation (e.g. in the first cell) for it to apply.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
if model_num == 2:
    # Activation patching on the Qwen model (model_num == 2).
    # For every (layer, position) pair, the corrupted prompt is re-run with that
    # layer's output at that position overwritten by the clean run's value. The
    # recorded score is the normalised logit difference
    #   (patched - corrupted) / (clean - corrupted),
    # i.e. 0 = no recovery (corrupted behaviour), 1 = full recovery (clean behaviour).
    #
    # Memory: nnsight batches all invokes of one model.trace() into a single
    # forward pass. Putting all N_LAYERS * len(clean_tokens) patched runs into one
    # trace means one forward pass over that many sequences, each producing
    # full-vocabulary logits. This is presumably what exhausts GPU memory for this
    # model. Instead, the work is split into:
    #   pass 1: one trace for the clean/corrupted baselines, saving the clean
    #           hidden states so later traces can use them;
    #   pass 2: one trace per layer, holding only that layer's len(clean_tokens) runs.
    # NOTE(review): token_idx indexes both prompts identically, which assumes the
    # clean and corrupted prompts tokenize to the same length.
    N_LAYERS = len(model.model.layers)

    # Pass 1: clean and corrupted baselines.
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # Saved, because the per-layer traces below run after this trace has finished.
            clean_hs = [
                model.model.layers[layer_idx].output[0].save()
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Difference between the correct and incorrect answer logits for the clean run.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt):
            corrupted_logits = model.lm_head.output

            # Difference between the correct and incorrect answer logits for the corrupted run.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

    # Concrete baseline values, used as constants inside every per-layer trace.
    clean_diff_value = clean_logit_diff.value
    corrupted_diff_value = corrupted_logit_diff.value
    # NOTE: if the two baselines are equal the normalisation divides by zero.
    logit_diff_range = clean_diff_value - corrupted_diff_value

    # patching_results[layer][position] -> saved normalised score.
    patching_results = []

    # Pass 2: one trace (one batched forward pass) per layer.
    for layer_idx in range(N_LAYERS):
        clean_layer_hs = clean_hs[layer_idx].value
        _patching_results = []

        with model.trace() as tracer:
            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt):
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_layer_hs[:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Fraction of the clean-vs-corrupted gap recovered by this patch.
                    patched_result = (patched_logit_diff - corrupted_diff_value) / logit_diff_range

                    _patching_results.append(patched_result.save())

        patching_results.append(_patching_results)

In [None]:
if model_num == 3:
    # Activation patching on the Llama model (model_num == 3).
    # For every (layer, position) pair, the corrupted prompt is re-run with that
    # layer's output at that position overwritten by the clean run's value. The
    # recorded score is the normalised logit difference
    #   (patched - corrupted) / (clean - corrupted),
    # i.e. 0 = no recovery (corrupted behaviour), 1 = full recovery (clean behaviour).
    #
    # Memory: nnsight batches all invokes of one model.trace() into a single
    # forward pass. Putting all N_LAYERS * len(clean_tokens) patched runs into one
    # trace means one forward pass over that many sequences, each producing
    # full-vocabulary logits, which can exhaust GPU memory for large-vocabulary
    # models. Instead, the work is split into:
    #   pass 1: one trace for the clean/corrupted baselines, saving the clean
    #           hidden states so later traces can use them;
    #   pass 2: one trace per layer, holding only that layer's len(clean_tokens) runs.
    # NOTE(review): token_idx indexes both prompts identically, which assumes the
    # clean and corrupted prompts tokenize to the same length.
    N_LAYERS = len(model.model.layers)

    # Pass 1: clean and corrupted baselines.
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # Saved, because the per-layer traces below run after this trace has finished.
            clean_hs = [
                model.model.layers[layer_idx].output[0].save()
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Difference between the correct and incorrect answer logits for the clean run.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt):
            corrupted_logits = model.lm_head.output

            # Difference between the correct and incorrect answer logits for the corrupted run.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

    # Concrete baseline values, used as constants inside every per-layer trace.
    clean_diff_value = clean_logit_diff.value
    corrupted_diff_value = corrupted_logit_diff.value
    # NOTE: if the two baselines are equal the normalisation divides by zero.
    logit_diff_range = clean_diff_value - corrupted_diff_value

    # patching_results[layer][position] -> saved normalised score.
    patching_results = []

    # Pass 2: one trace (one batched forward pass) per layer.
    for layer_idx in range(N_LAYERS):
        clean_layer_hs = clean_hs[layer_idx].value
        _patching_results = []

        with model.trace() as tracer:
            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt):
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_layer_hs[:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Fraction of the clean-vs-corrupted gap recovered by this patch.
                    patched_result = (patched_logit_diff - corrupted_diff_value) / logit_diff_range

                    _patching_results.append(patched_result.save())

        patching_results.append(_patching_results)

In [None]:
if model_num > 0:
    # Baseline logit differences of the unpatched clean and corrupted runs.
    print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
    print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

    # x-axis labels of the form "<decoded token>_<position>".
    clean_decoded_tokens = list(map(model.tokenizer.decode, clean_tokens))
    token_labels = [f"{piece}_{pos}" for pos, piece in enumerate(clean_decoded_tokens)]

    fig = plot_patching_results(
        model,
        patching_results,
        token_labels,
        "Patching Residual Stream",
    )
    fig.show()