# Activation Patching on GPT2, M1, M2 and M3 models using nnsight
- Developed on Google Colab using an A100 with 40GB GPU and 80GB system RAM.
  - Succeeds with GPT2/TinyStories. Runs out of memory with Qwen/Llama.
- Requires a GITHUB_TOKEN secret to access the Martian quanta_text_to_sql code repository.
- Requires an HF_TOKEN secret to access the Martian HuggingFace repository.


# Select model and command set


In [None]:
# Notebook-wide selectors: the cells below branch on these two values.
model_num = 1   # 0=GPT2, 1=TinyStories, 2=Qwen or 3=Llama
cs_num = 3      # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3

# Import libraries
Installs and imports the required libraries. This section can be skipped when reading.

In [None]:
# https://nnsight.net/
!pip install -U nnsight
from IPython.display import clear_output

In [None]:
from getpass import getpass
from google.colab import userdata

In [None]:
import nnsight
from nnsight import CONFIG

# Needed if running remote trace which uses the NDIF compute platform.
# e.g. llama.trace("The Eiffel Tower is in the city of", remote=True)
# NDIF only supports models meta-llama/meta-llama-3.1-8B and EleutherAI/gpt-j-6B
#CONFIG.set_default_api_key(userdata.get("NDIF_KEY"))

In [None]:
!pip install flash-attn --no-build-isolation
import flash_attn

In [None]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

In [None]:
# Set GitHub credentials
#github_token = getpass('Enter your GitHub PAT: ')  # This will prompt for token securely
# Read the token from the Colab secret named GITHUB_TOKEN.
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
# NOTE(review): the install and import below are commented out, so github_token
# is not consumed by any line in this cell.
#!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

# Now you can import the package
#import QuantaTextToSql as qts

# Shared Plotting Function

In [None]:
def plot_patching_results(model,
                              patching_results,
                              x_labels,
                              plot_title="Normalized Logit Difference After Patching Residual Stream"):
    """Build a heatmap figure from saved activation-patching scores.

    Args:
        model: Not read by this function; kept in the signature for callers.
        patching_results: Nested collection whose Proxy leaves are each
            converted to a Python number via ``.value.item()``.
        x_labels: Tick labels for the x axis (the "Position" axis).
        plot_title: Title shown above the heatmap.

    Returns:
        The plotly figure. The caller is responsible for displaying it.
    """
    def _saved_scalar(proxy):
        # Unwrap a saved proxy into a plain number.
        return proxy.value.item()

    heatmap_values = util.apply(patching_results, _saved_scalar, Proxy)
    axis_titles = {"x": "Position", "y": "Layer","color":"Norm. Logit Diff"}

    # Diverging colour scale centred on zero.
    return px.imshow(
        heatmap_values,
        x=x_labels,
        labels=axis_titles,
        title=plot_title,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
    )

# Tutorial on GPT2
Based on https://nnsight.net/notebooks/tutorials/activation_patching/

Runs when model_num == 0


In [None]:
# Tutorial path only: wrap the "openai-community/gpt2" checkpoint in an nnsight LanguageModel.
if model_num == 0:
    model = LanguageModel("openai-community/gpt2", device_map="auto")
    clear_output()  # clear the cell output produced so far before printing the model
    print(model)

In [None]:
# Prompt pair for the GPT-2 tutorial: the two texts differ only in which name
# gives the bottle of milk ("Mary gave" vs "John gave").
# NOTE: these assignments are unconditional, so they run for every model_num.
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = (
    "After John and Mary went to the store, John gave a bottle of milk to"
)

In [None]:
if model_num == 0:
    # First token id of each candidate answer. The leading space is part of
    # the text handed to the tokenizer.
    correct_index, incorrect_index = (
        model.tokenizer(answer)["input_ids"][0] for answer in (" John", " Mary")
    )

    print(f"' John': {correct_index}")
    print(f"' Mary': {incorrect_index}")

In [None]:
# Activation-patching sweep for GPT-2 (model_num == 0).
# For every (layer, token position) pair, a corrupted-prompt run has that
# layer's block output at that position overwritten with the value from the
# clean run, and the resulting change in logit difference is recorded.
# Produces: clean_tokens, clean_logit_diff, corrupted_logit_diff and
# patching_results (one inner list per layer, one entry per token position).
if model_num == 0:
    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            # Token ids of the clean prompt; used below to size the position loop.
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            # Element 0 of each block's output is what gets patched below;
            # presumably the hidden-state tensor -- TODO confirm.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            # Index [0, -1, ...] selects the first batch row and the last position.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        patching_results = []

        # Iterate through all the layers
        for layer_idx in range(len(model.transformer.h)):
            _patching_results = []

            # Iterate through all tokens
            # NOTE(review): the loop length comes from the clean prompt; this assumes the
            # corrupted prompt tokenizes to the same number of positions -- confirm.
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # Normalized so that 0 equals the corrupted baseline and 1 equals the clean run.
                    # NOTE(review): the denominator is zero if the clean and corrupted logit
                    # differences are equal.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    _patching_results.append(patched_result.save())

            patching_results.append(_patching_results)

In [None]:
if model_num == 0:
    # Baseline logit differences saved during the trace.
    print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
    print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

    # Axis labels take the form "<decoded token>_<position>".
    clean_decoded_tokens = list(map(model.tokenizer.decode, clean_tokens))
    token_labels = [
        f"{text}_{position}"
        for position, text in enumerate(clean_decoded_tokens)
    ]

    fig = plot_patching_results(
        model,
        patching_results,
        token_labels,
        "Patching GPT-2-small Residual Stream on IOI task",
    )
    fig.show()

# Run on m1, m2 and m3 models

In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model. Uses HF_TOKEN for private models
def load_model(model_location, use_flash_attention=True, auth_token=None):
    """Load a Hugging Face tokenizer and causal language model.

    Args:
        model_location: Repository id or path given to ``from_pretrained``.
        use_flash_attention: When true, load in bfloat16 with the
            ``flash_attention_2`` attention implementation; otherwise load in
            float32 with the default attention implementation.
        auth_token: Access token for private repositories. When ``None`` the
            ``HF_TOKEN`` environment variable is used (which may also be unset).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    if auth_token is None:
        auth_token = os.getenv("HF_TOKEN")

    tokenizer = AutoTokenizer.from_pretrained(model_location, token=auth_token)

    if use_flash_attention:
        # qwen model and llama model with flash attention
        # Requires the flash-attn package.
        # From https://github.com/Dao-AILab/flash-attention
        dtype = torch.bfloat16
        attention_kwargs = {"attn_implementation": "flash_attention_2"}
    else:
        # model without flash attention
        dtype = torch.float32
        attention_kwargs = {}

    # The token is forwarded here as well: previously only the tokenizer
    # received it, so an explicitly supplied auth_token was ignored when
    # fetching the model weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_location,
        torch_dtype=dtype,
        device_map="auto",
        token=auth_token,
        **attention_kwargs,
    )

    return tokenizer, model

In [None]:
# Map a (model number, command set number) pair to its Hugging Face repository id.
def sql_interp_model_location( model_num : int, cs_num : int):
    """Return the repository id for a model / command-set combination.

    Args:
        model_num: 1, 2 or 3, selecting the base model family.
        cs_num: 0 for the base model, or 1, 2 or 3 for a command set.

    Returns:
        The repository id, or an empty string for any other combination.
    """
    # Keyed by (model_num, cs_num).
    locations = {
        (1, 0): "roneneldan/TinyStories-Instruct-2Layers-33M",
        (1, 1): "withmartian/sql_interp_bm1_cs1_experiment_1.1",
        (1, 2): "withmartian/sql_interp_bm1_cs2_experiment_2.3",
        (1, 3): "withmartian/sql_interp_bm1_cs3_experiment_3.3",
        (2, 0): "Qwen/Qwen2.5-0.5B-Instruct",
        (2, 1): "withmartian/sql_interp_bm2_cs1_experiment_4.1",
        (2, 2): "withmartian/sql_interp_bm2_cs2_experiment_5.1",
        (2, 3): "withmartian/sql_interp_bm2_cs3_experiment_6.1",
        (3, 0): "meta-llama/Llama-3.2-1B-Instruct",
        (3, 1): "withmartian/sql_interp_bm3_cs1_experiment_7.1",
        (3, 2): "withmartian/sql_interp_bm3_cs2_experiment_8.1",
        (3, 3): "withmartian/sql_interp_bm3_cs3_experiment_9.1",
    }
    return locations.get((model_num, cs_num), "")

In [None]:
def load_sql_interp_model( model_num : int, cs_num : int, auth_token=None):
    """Load the tokenizer and model for a model / command-set combination.

    Args:
        model_num: 1, 2 or 3, selecting the base model family.
        cs_num: 0 for the base model, or 1, 2 or 3 for a command set.
        auth_token: Optional access token forwarded to ``load_model``.

    Returns:
        A ``(tokenizer, model)`` tuple. For ``model_num == 1`` the tokenizer
        is set to left padding with an added ``<|pad|>`` token, and the model's
        embeddings and ``pad_token_id`` are updated to match.

    Raises:
        ValueError: If no model location is known for the given combination.
    """
    model_location = sql_interp_model_location(model_num, cs_num)
    if not model_location:
        # An empty location would otherwise fail inside Hugging Face loading
        # with an error that does not mention the selectors.
        raise ValueError(
            f"No model is registered for model_num={model_num}, cs_num={cs_num}"
        )

    # Flash attention is enabled for model numbers 2 and 3 only.
    use_flash_attention = model_num in (2, 3)
    tokenizer, model = load_model(model_location, use_flash_attention=use_flash_attention, auth_token=auth_token)

    if model_num == 1:
        tokenizer.padding_side = "left"
        tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

        # Grow the embedding matrix once to cover the new pad token. The
        # original repeated this call with the same size, which was redundant.
        model.resize_token_embeddings(len(tokenizer), mean_resizing=False)

        model.config.pad_token_id = tokenizer.pad_token_id

    return tokenizer, model

In [None]:
# Build the nnsight LanguageModel for the selected SQL model (model_num 1, 2 or 3).
if model_num > 0:
    # Resolve the repository id once; it is used directly by the else branch.
    model_location = sql_interp_model_location(model_num, cs_num)

    if model_num == 1:
        # This path needs the pad-token setup done by load_sql_interp_model,
        # so the model and tokenizer are loaded first and then wrapped.
        the_tokenizer, the_model = load_sql_interp_model(model_num, cs_num)

        # `tokenizer` is a keyword-only argument of LanguageModel. Passing it
        # by keyword removes the need to assign model.tokenizer afterwards.
        model = LanguageModel(the_model, tokenizer=the_tokenizer)
    else:
        #if model_num == 3 and cs_num == 0:
        #    # Access to https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct is restricted. Apply token.
        #    os.environ['HUGGINGFACE_HUB_TOKEN'] = userdata.get("HF_HUB_TOKEN")

        model = LanguageModel(model_location, device_map="auto")

    clear_output()
    print(model)

In [None]:
if model_num > 0:
    # Text-to-SQL prompt pair: the two prompts differ only in the word asked about in the
    # instruction ("size" vs "elephants"); the context and the "SELECT" response prefix are identical.
    clean_prompt = "### Instruction: What do we have for size in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT"
    corrupted_prompt = "### Instruction: What do we have for elephants in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT"

In [None]:
if model_num > 0:
    # First token id of each candidate answer.
    # add_special_tokens=False stops tokenizers that prepend a beginning-of-sequence
    # token (e.g. Llama-3) from returning that token's id at position 0. Without it
    # both indices would be identical and the logit difference always zero.
    correct_index = model.tokenizer(" size", add_special_tokens=False)["input_ids"][0] # includes a space
    incorrect_index = model.tokenizer(" elephants", add_special_tokens=False)["input_ids"][0] # includes a space

    print(f"' size': {correct_index}")
    print(f"' elephants': {incorrect_index}")

In [None]:
#torch.cuda.empty_cache()
#print(torch.cuda.memory_summary())

In [None]:
# Activation-patching sweep for the TinyStories-based model (model_num == 1).
# For every (layer, token position) pair, a corrupted-prompt run has that
# layer's block output at that position overwritten with the value from the
# clean run, and the resulting change in logit difference is recorded.
# Produces: clean_tokens, clean_logit_diff, corrupted_logit_diff and
# patching_results (one inner list per layer, one entry per token position).
if model_num == 1: # TinyStories
    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            # Token ids of the clean prompt; used below to size the position loop.
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            # Element 0 of each block's output is what gets patched below;
            # presumably the hidden-state tensor -- TODO confirm.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            # Index [0, -1, ...] selects the first batch row and the last position.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        patching_results = []

        # Iterate through all the layers
        for layer_idx in range(len(model.transformer.h)):
            _patching_results = []

            # Iterate through all tokens
            # NOTE(review): the loop length comes from the clean prompt; this assumes the
            # corrupted prompt tokenizes to the same number of positions, so that token_idx
            # addresses the same position in both -- confirm ("size" vs "elephants" may differ).
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # Normalized so that 0 equals the corrupted baseline and 1 equals the clean run.
                    # NOTE(review): the denominator is zero if the clean and corrupted logit
                    # differences are equal.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    _patching_results.append(patched_result.save())

            patching_results.append(_patching_results)

In [None]:
# Activation-patching sweep for the Qwen-based model (model_num == 2).
# Same procedure as the other sweeps in this notebook, except that the
# decoder blocks are reached through model.model.layers.
# Produces: clean_tokens, clean_logit_diff, corrupted_logit_diff and
# patching_results (one inner list per layer, one entry per token position).
if model_num == 2: # Qwen
    N_LAYERS = len(model.model.layers)

    # Alternative nnsight entry points, kept for reference (not executed):
    # with model.scan(input):
    # with model.trace(input, scan=True, validate=True) as tracer:
    # with model.generate("The Eiffel Tower is in the city of", max_new_tokens=3):  and then use model.lm_head.next()
    # with llm.edit() as llm_edited:
    # with model.trace(remote=True) as tracer:

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            # Token ids of the clean prompt; used below to size the position loop.
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            # Element 0 of each layer's output is what gets patched below;
            # presumably the hidden-state tensor -- TODO confirm.
            clean_hs = [
                model.model.layers[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            # Index [0, -1, ...] selects the first batch row and the last position.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()
            #tracer.log("Random debug statement")

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        patching_results = []

        # Iterate through all the layers
        for layer_idx in range(len(model.model.layers)):
            _patching_results = []

            # Iterate through all tokens
            # NOTE(review): the loop length comes from the clean prompt; this assumes the
            # corrupted prompt tokenizes to the same number of positions, so that token_idx
            # addresses the same position in both -- confirm ("size" vs "elephants" may differ).
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # Normalized so that 0 equals the corrupted baseline and 1 equals the clean run.
                    # NOTE(review): the denominator is zero if the clean and corrupted logit
                    # differences are equal.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    _patching_results.append(patched_result.save())

            patching_results.append(_patching_results)

In [None]:
# Activation-patching sweep for the Llama-based model (model_num == 3).
# Same procedure as the other sweeps in this notebook; the decoder blocks
# are reached through model.model.layers.
# Produces: clean_tokens, clean_logit_diff, corrupted_logit_diff and
# patching_results (one inner list per layer, one entry per token position).
if model_num == 3:
    N_LAYERS = len(model.model.layers)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            # Token ids of the clean prompt; used below to size the position loop.
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            # Element 0 of each layer's output is what gets patched below;
            # presumably the hidden-state tensor -- TODO confirm.
            clean_hs = [
                model.model.layers[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            # Index [0, -1, ...] selects the first batch row and the last position.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        patching_results = []

        # Iterate through all the layers
        for layer_idx in range(len(model.model.layers)):
            _patching_results = []

            # Iterate through all tokens
            # NOTE(review): the loop length comes from the clean prompt; this assumes the
            # corrupted prompt tokenizes to the same number of positions, so that token_idx
            # addresses the same position in both -- confirm ("size" vs "elephants" may differ).
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # Normalized so that 0 equals the corrupted baseline and 1 equals the clean run.
                    # NOTE(review): the denominator is zero if the clean and corrupted logit
                    # differences are equal.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    _patching_results.append(patched_result.save())

            patching_results.append(_patching_results)

In [None]:
if model_num > 0:
    # Baseline logit differences saved during the trace.
    print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
    print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

    # Axis labels take the form "<decoded token>_<position>".
    clean_decoded_tokens = list(map(model.tokenizer.decode, clean_tokens))
    token_labels = [
        f"{text}_{position}"
        for position, text in enumerate(clean_decoded_tokens)
    ]

    fig = plot_patching_results(
        model,
        patching_results,
        token_labels,
        "Patching Residual Stream",
    )
    fig.show()