# Activation Patching on GPT2, TinyStories (M1), Qwen (M2) and Llama (M3) models using nnsight
- Developed on Google Colab using an A100 GPU (40GB) with 80GB of system RAM.
- Runs with GPT2/TinyStories/Qwen/Llama with base/CS1/CS2/CS3.  
- Requires a GITHUB_TOKEN secret to access the Martian quanta_text_to_sql code repository.
- Requires a HF_TOKEN secret to access the Martian HuggingFace repository.


# Select model and command set


In [28]:
# Notebook-wide configuration: every model-specific cell below branches on these two values.
model_num = 1   # 0=GPT2, 1=TinyStories, 2=Qwen or 3=Llama
cs_num = 2      # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3

# Fail fast on a typo: no patching cell runs for a model_num outside 0-3, which
# would otherwise only surface as a confusing error in the final graph cell.
if model_num not in (0, 1, 2, 3):
    raise ValueError(f"model_num must be 0, 1, 2 or 3, got {model_num!r}")
if cs_num not in (0, 1, 2, 3):
    raise ValueError(f"cs_num must be 0, 1, 2 or 3, got {cs_num!r}")

# Import libraries
Installs and imports the required libraries. This section is boilerplate and can be skipped.

In [None]:
# https://nnsight.net/
!pip install -U nnsight



In [None]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

import nnsight
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

In [None]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [None]:
# Read the GitHub access token from the Colab secrets store.
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
# ({github_token} is substituted into the shell command by IPython before it runs).
!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

# Imported here, not with the other imports, because the package is installed by the line above.
import QuantaTextToSql as qts

In [None]:
# Shared placeholders: the model-specific patching cells below fill these in,
# and the graph cell at the end of the notebook reads them.
clean_tokens, patching_results = [], []

# nnsight tutorial using GPT2
Based on https://nnsight.net/notebooks/tutorials/activation_patching/



In [None]:
# GPT2 path only: load the model through nnsight's LanguageModel wrapper.
if model_num == 0:
    model = LanguageModel("openai-community/gpt2", device_map="auto")
    # Drop the output produced while loading so only the module tree printed below remains.
    clear_output()
    print(model)

In [None]:
# Prompt pair for the GPT2 experiment. The two prompts differ only in who gives
# the milk (Mary in the clean prompt, John in the corrupted one).
# These are overwritten further down when model_num > 0.
clean_prompt = (
    "After John and Mary went to the store, Mary gave a bottle of milk to"
)
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [None]:
if model_num == 0:
    # Token ids of the two candidate answers. Both strings include a leading
    # space; the first id of each tokenization is used.
    correct_index, incorrect_index = (
        model.tokenizer(answer)["input_ids"][0] for answer in (" John", " Mary")
    )

    print(f"' John': {correct_index}")
    print(f"' Mary': {incorrect_index}")

In [None]:
if model_num == 0:
    # Activation patching for GPT2: for every (layer, token position) pair, copy the
    # clean run's hidden state into a corrupted run and measure how much of the
    # clean-vs-corrupted logit difference is recovered.

    # Start from an empty result grid so that re-running this cell does not
    # append a second copy of every layer row to the shared list.
    patching_results = []

    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            # Index 0 of each block's output is taken as the hidden state.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        # Iterate through all the layers
        for layer_idx in range(N_LAYERS):
            layer_results = []

            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # The score is (patched - corrupted) / (clean - corrupted).
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    layer_results.append(patched_result.item().save())

            # One row per layer, one column per token position.
            patching_results.append(layer_results)

Results are plotted by the shared code at the end of the notebook.

# Run on M1 (TinyStories), M2 (Qwen) and M3 (Llama) models

In [None]:
if model_num > 0:

    if model_num != 1:
        # Qwen / Llama: hand nnsight the location reported by qts and let it load the model.
        model = LanguageModel(
            qts.sql_interp_model_location(model_num, cs_num),
            device_map="auto",
        )
    else:
        # TinyStories: qts returns an already-loaded tokenizer and model, which are then wrapped.
        the_tokenizer, the_model = qts.load_sql_interp_model(
            model_num, cs_num, auth_token=userdata.get("HF_TOKEN")
        )

        model = LanguageModel(the_model, the_tokenizer)
        # The tokenizer is also attached explicitly; later cells call model.tokenizer.
        model.tokenizer = the_tokenizer

    # Replace the loading output with the wrapped model's module tree.
    clear_output()
    print(model)

In [None]:
if model_num > 0:
    # Text-to-SQL prompt pair; replaces the GPT2 prompts defined earlier.
    # The two prompts differ only in the requested field: "size" is a column of
    # the table in the context, "elephants" is not.
    clean_prompt = (
        "### Instruction: What do we have for size in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT"
    )
    corrupted_prompt = (
        "### Instruction: What do we have for elephants in profiles? ### Context: CREATE TABLE profiles (size INTEGER, age INTEGER, name TEXT) ### Response: SELECT"
    )

In [None]:
if model_num > 0:

    # Llama tokenizes " size" as [128000, 1404], where 128000 is the
    # '<|begin_of_text|>' symbol, so for model_num == 3 the answer token is at
    # index 1 instead of index 0.
    # (model.tokenizer.convert_ids_to_tokens([128000]) shows what 128000 maps to.)
    answer_offset = 0 if model_num != 3 else 1

    # Both answer strings include a leading space.
    correct_index, incorrect_index = (
        model.tokenizer(answer)["input_ids"][answer_offset]
        for answer in (" size", " elephants")
    )

    print(f"' size': {correct_index}")
    print(f"' elephants': {incorrect_index}")

In [None]:
def free_memory():
    """Release unreferenced Python objects, then return cached CUDA memory.

    The garbage collector runs first so that any tensors it frees are already
    back in PyTorch's caching allocator when ``torch.cuda.empty_cache()`` is
    called; emptying the cache first would leave that memory cached until the
    next call. The CUDA step is skipped when no GPU is available.

    Returns:
        None
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
def replace_weak_references(obj):
    """Return ``obj`` with wrapper objects swapped for the values they hold.

    Per the original author's note, a list may contain 'weak references' to
    objects that are garbage collected by free_memory; call this before
    free_memory to replace them with strong references.

    The first matching rule wins:
      * list                  -> new list with every element resolved recursively
      * has a ``value`` attr  -> ``obj.value``
      * has ``get_value``     -> ``obj.get_value()``
      * ``torch.Tensor``      -> a cloned, detached copy
      * has ``item``          -> ``obj.item()``
      * anything else         -> returned unchanged
    """
    if isinstance(obj, list):
        return list(map(replace_weak_references, obj))
    if hasattr(obj, 'value'):
        return obj.value
    if hasattr(obj, 'get_value'):
        return obj.get_value()
    if isinstance(obj, torch.Tensor):
        return obj.clone().detach()
    if hasattr(obj, 'item'):
        return obj.item()
    return obj

In [None]:
def run_tinystories_patching():
    """Run activation patching over every layer and token position.

    All runs (clean, corrupted, and one patched run per layer/token pair) are
    invoked inside a single ``model.trace()`` context. Layers are addressed as
    ``model.transformer.h``.

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupted_prompt``,
    ``correct_index`` and ``incorrect_index``.

    Returns:
        A 4-tuple ``(clean_tokens, clean_logit_diff, corrupted_logit_diff, run_results)``,
        each passed through ``replace_weak_references``. ``run_results`` has one
        row per layer and one entry per clean-prompt token, each entry being
        ``(patched - corrupted) / (clean - corrupted)`` logit difference.
    """
    run_results = []

    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        # Iterate through all the layers
        for layer_idx in range(len(model.transformer.h)):
            layer_results = []

            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, correct_index]
                        - patched_logits[0, -1, incorrect_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    layer_results.append(patched_result.item().save())

            run_results.append(layer_results)

    # Resolve everything that is returned before memory is freed.
    clean_tokens = replace_weak_references(clean_tokens)
    clean_logit_diff = replace_weak_references(clean_logit_diff)
    corrupted_logit_diff = replace_weak_references(corrupted_logit_diff)
    run_results = replace_weak_references(run_results)
    free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff, corrupted_logit_diff, run_results


# Populate the globals read by the graph cell at the end of the notebook.
if model_num == 1: # TinyStories
    clean_tokens, clean_logit_diff, corrupted_logit_diff, patching_results = run_tinystories_patching()

In [None]:
def run_llm_patching():
    """Run activation patching over every layer and token position, one trace per run.

    Unlike a single shared tracing context, each run (clean, corrupted, and one
    patched run per layer/token pair) uses its own ``model.trace(...)`` call, and
    memory is freed after every layer. Layers are addressed as
    ``model.model.layers``.

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupted_prompt``,
    ``correct_index`` and ``incorrect_index``.

    Returns:
        A 4-tuple ``(clean_tokens, clean_logit_diff, corrupted_logit_diff, run_results)``.
        ``run_results`` has one row per layer and one entry per clean-prompt
        token, each entry being ``(patched - corrupted) / (clean - corrupted)``
        logit difference.
    """
    run_results = []

    N_LAYERS = len(model.model.layers)

    # Clean run
    with model.trace(clean_prompt) as tracer:
        clean_tokens = tracer.invoker.inputs[0]['input_ids'][0]

        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        # These are saved because they are used by the separate patched traces below.
        clean_hs = [
            model.model.layers[layer_idx].output[0].save()
            for layer_idx in range(N_LAYERS)
        ]

        # Get logits from the lm_head.
        clean_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
        clean_logit_diff = (
            clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
        ).save()
    clean_tokens = replace_weak_references(clean_tokens)
    clean_logit_diff = replace_weak_references(clean_logit_diff)

    # Corrupted run
    with model.trace(corrupted_prompt) as tracer:
        corrupted_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
        corrupted_logit_diff = (
            corrupted_logits[0, -1, correct_index]
            - corrupted_logits[0, -1, incorrect_index]
        ).save()
    corrupted_logit_diff = replace_weak_references(corrupted_logit_diff)

    for layer_idx in range(N_LAYERS):
        layer_results = []

        for token_idx in range(len(clean_tokens)):
            # Patching corrupted run at given layer and token
            with model.trace(corrupted_prompt) as tracer:
                # Apply the patch from the clean hidden states to the corrupted hidden states.
                model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, correct_index]
                    - patched_logits[0, -1, incorrect_index]
                )

                # Calculate the improvement in the correct token after patching.
                patching_result = (patched_logit_diff - corrupted_logit_diff) / (
                    clean_logit_diff - corrupted_logit_diff
                )

                # Convert from large structure to a plain number before saving to decrease memory usage
                one_result = patching_result.item().save()

            layer_results.append(one_result)

        # Resolve only the row just produced; earlier rows were resolved on
        # previous iterations, so there is no need to walk them again.
        run_results.append(replace_weak_references(layer_results))
        free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff, corrupted_logit_diff, run_results


# Populate the globals read by the graph cell at the end of the notebook.
if model_num == 2 or model_num == 3: # Qwen or Llama
    clean_tokens, clean_logit_diff, corrupted_logit_diff, patching_results = run_llm_patching()

# Graph results

In [None]:
# Summary of the two baseline runs.
print(f"Clean logit difference: {clean_logit_diff:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff:.3f}")

# Decode each clean-prompt token and suffix it with its position so that
# repeated tokens still produce distinct axis labels.
clean_decoded_tokens = list(map(model.tokenizer.decode, clean_tokens))
token_labels = [
    f"{text}_{position}" for position, text in enumerate(clean_decoded_tokens)
]

# Heatmap of the patching scores: one row per layer, one column per token position.
fig = px.imshow(
    patching_results,
    x=token_labels,
    labels=dict(x="Position", y="Layer", color="Norm. Logit Diff"),
    title="Patching Residual Stream",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
)
fig.show()