# Activation Patching on M0 .. M5 models using nnsight
- Developed on Google Colab using an A100 with 40GB GPU and 80GB system RAM.
- Runs with GPT2/TinyStories/Qwen/Llama/Granite/SmolLM with base/CS1/CS2/CS3.
- Runs CS1 feature tests (as defined in corrupt_data/clean_corrupt_data.py)    
- Requires a GITHUB_TOKEN secret to access Martian TinySQL code repository.
- Requires a HF_TOKEN secret to access Martian HuggingFace repository.


# Import libraries
Imports the required libraries (IPython, torch, plotly, nnsight). Safe to skip when reading.

In [None]:
# Install nnsight, the library used below to trace and patch model activations.
# https://nnsight.net/
# Access 0.4 prerelease version (as at Dec 2024)
#!pip install nnsight==0.4.0.dev0
# NOTE(review): the active line installs the *latest* nnsight release, not the pinned
# 0.4 prerelease above. The tracing code in this notebook (e.g. `invoker.inputs[...]`,
# `tracer.invoker`) targets a specific nnsight API; if a newer release breaks it, pin
# the version that works instead of using -U.
!pip install -U nnsight

In [None]:
# Notebook display helpers, tensor/plotting libraries and nnsight.
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
# Render plotly figures inline in Google Colab.
pio.renderers.default = "colab"

import nnsight
from nnsight import LanguageModel, util

In [None]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [None]:
# Read the GitHub token from Colab secrets; it is needed to install the private TinySQL repo.
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
# NOTE(review): the token is interpolated into the pip URL. Avoid sharing this cell's
# output or logs, since a failing install may echo the URL.
!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

# `qts` provides the helpers used below (model loading, corrupt-prompt generation, memory cleanup).
import TinySQL as qts

In [None]:
# Empty defaults so later cells can reference these names before any patching run fills them.
clean_tokens, patching_results = [], []

# Select model, command set and feature to investigate


In [None]:
# Model to analyse: 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
model_num = 1
# Command set: 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
cs_num = 1
# Feature to corrupt. Options: ENGTABLENAME, ENGFIELDNAME, DEFTABLESTART, DEFTABLENAME,
# DEFFIELDNAME, DEFFIELDSEPARATOR
feature_name = qts.DEFFIELDNAME
# When True, corrupt using words that appear in neither the clean prompt nor the SQL
# (e.g. "little" or "hammer").
use_novel_names = False

In [None]:
# Key global "input" variables, filled in by the model-specific cells below.
clean_prompt = corrupt_prompt = ""
# Vocab indices of the clean and corrupted words, and the position (token index) of
# that word within the SQL answer. All start as "unknown".
clean_tokenizer_index = corrupt_tokenizer_index = answer_token_index = qts.UNKNOWN_VALUE

# Key global "results" variables
clean_logit_diff = corrupt_logit_diff = qts.UNKNOWN_VALUE

# Run M0: nnsight tutorial using GPT2
Based on https://nnsight.net/notebooks/tutorials/activation_patching/



In [None]:
if model_num == 0:
    # GPT2 reproduction of the nnsight activation-patching tutorial.
    model = LanguageModel("openai-community/gpt2", device_map="auto")
    clear_output()
    print(model)

    # The prompts differ only in who "gave": the clean prompt should end in " John",
    # the corrupt one in " Mary".
    clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
    corrupt_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

    # Words carry a leading space, as they appear mid-sentence; [0] takes the first token id.
    clean_tokenizer_index, corrupt_tokenizer_index = (
        model.tokenizer(word)["input_ids"][0] for word in (" John", " Mary")
    )

    print(f"' John': {clean_tokenizer_index}")
    print(f"' Mary': {corrupt_tokenizer_index}")

In [None]:
def run_gpt2_patching():
    """Patch clean GPT2 residual-stream activations into the corrupted run, per layer and token.

    Reads the notebook globals `model`, `clean_prompt`, `corrupt_prompt`,
    `clean_tokenizer_index` and `corrupt_tokenizer_index`.

    Returns:
        (clean_tokens, clean_logit_diff, corrupt_logit_diff, run_results), where
        clean_tokens are the clean prompt's token ids (from the invoker inputs), the two
        logit diffs are converted with `.item()`, and run_results[layer][token] is
        (patched - corrupt) / (clean - corrupt). That ratio is 0 when patching has no
        effect on the last-token logit difference and 1 when it fully restores the
        clean value.
    """
    run_results = []

    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context. Every `tracer.invoke` below adds another input to this
    # same trace: clean, corrupt, and one patched run per (layer, token).
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run last token
            clean_logit_diff = (
                clean_logits[0, -1, clean_tokenizer_index] - clean_logits[0, -1, corrupt_tokenizer_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupt_prompt) as invoker:
            corrupt_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run last token
            corrupt_logit_diff = (
                corrupt_logits[0, -1, clean_tokenizer_index] - corrupt_logits[0, -1, corrupt_tokenizer_index]
            ).save()

        # Iterate through all the layers
        for layer_idx in range(N_LAYERS):
            layer_results = []

            # Iterate through all tokens. Positions of the clean prompt are reused on the corrupt
            # run, so this assumes both prompts tokenize to the same length.
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupt_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, clean_tokenizer_index] - patched_logits[0, -1, corrupt_tokenizer_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    patched_result = (patched_logit_diff - corrupt_logit_diff) / (
                        clean_logit_diff - corrupt_logit_diff
                    )

                    # Keep only the scalar result so the per-patch tensors need not be retained.
                    layer_results.append(patched_result.item().save())

            run_results.append(layer_results)

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupt_logit_diff.item(), run_results


# Only the GPT2 tutorial model uses this runner; the results are plotted by the final cell.
if model_num == 0: # GPT2
    clean_tokens, clean_logit_diff, corrupt_logit_diff, patching_results = run_gpt2_patching()

Results are printed and plotted by the shared code at the end of the notebook.

# Run M1 .. M5 models (TinyStories, Qwen, Llama, Granite, SmolLM)

In [None]:
if model_num > 0:
    # The HF_TOKEN Colab secret authenticates against the Martian HuggingFace repository.
    hf_token = userdata.get("HF_TOKEN")
    model = qts.load_tinysql_model(model_num, cs_num, auth_token=hf_token)
    clear_output()
    print(model)

In [None]:
if model_num > 0:
    # Generate a small batch of clean/corrupt example pairs for `feature_name`.
    batch_size = 5
    generator = qts.CorruptFeatureTestGenerator(model_num, cs_num, model.tokenizer, use_novel_names=use_novel_names)
    examples = generator.generate_feature_examples(feature_name, batch_size)

    # Only the first example is patched. Each example is corrupted at a prompt token, and
    # the resulting impact is expected in the SQL answer at answer_token_index.
    example = examples[0]
    clean_tokenizer_index = example.clean_tokenizer_index
    corrupt_tokenizer_index = example.corrupt_tokenizer_index
    answer_token_index = example.answer_token_index

    def _truncate_at_answer(batch_item):
        """Return (token_ids, text) of prompt + SQL, keeping tokens 0..answer_token_index."""
        full_text = batch_item.get_alpaca_prompt() + batch_item.sql_statement
        kept_ids = model.tokenizer(full_text)["input_ids"][:answer_token_index + 1]
        return kept_ids, model.tokenizer.decode(kept_ids)

    # NOTE(review): the slice keeps the token *at* answer_token_index, so the final position's
    # logits predict the token after it. Confirm this matches how CorruptFeatureTestGenerator
    # defines answer_token_index. Also, decode() followed by re-tokenization inside the trace
    # may not reproduce these ids exactly (e.g. tokenizers that prepend a BOS token).
    clean_tokens, clean_prompt = _truncate_at_answer(example.clean_BatchItem)
    corrupt_tokens, corrupt_prompt = _truncate_at_answer(example.corrupt_BatchItem)

    print("Case:", example.feature_name)
    print("Clean: Token=", example.clean_token_str)
    print("Corrupt: Token=", example.corrupt_token_str)
    print()
    print("Clean prompt:", clean_prompt)
    print()
    print("Corrupt prompt:", corrupt_prompt)

In [None]:
def run_tinystories_patching():
    """Patch clean TinyStories activations into the corrupted run, per layer and token.

    Same procedure as the GPT2 runner (the TinyStories model exposes `model.transformer.h`).
    Additionally prints the argmax prediction at every position of the final patched run
    (last layer, last token patched).

    Reads the notebook globals `model`, `clean_prompt`, `corrupt_prompt`,
    `clean_tokenizer_index` and `corrupt_tokenizer_index`.

    Returns:
        (clean_tokens, clean_logit_diff, corrupt_logit_diff, run_results), where
        run_results[layer][token] is (patched - corrupt) / (clean - corrupt) of the
        last-token logit difference: 0 = no effect, 1 = clean value fully restored.
    """
    run_results = []

    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            clean_logits = model.lm_head.output

            # Calculate the difference between the clean and corrupt token for the clean run
            clean_logit_diff = (
                clean_logits[0, -1, clean_tokenizer_index] - clean_logits[0, -1, corrupt_tokenizer_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupt_prompt) as invoker:
            corrupt_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run
            corrupt_logit_diff = (
                corrupt_logits[0, -1, clean_tokenizer_index] - corrupt_logits[0, -1, corrupt_tokenizer_index]
            ).save()

        # Iterate through all the layers
        for layer_idx in range(N_LAYERS):
            layer_results = []

            # Iterate through all tokens (assumes clean and corrupt prompts have equal token length)
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupt_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, clean_tokenizer_index] - patched_logits[0, -1, corrupt_tokenizer_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    patched_result = (patched_logit_diff - corrupt_logit_diff) / (
                        clean_logit_diff - corrupt_logit_diff
                    )

                    layer_results.append(patched_result.item().save())

                    # On the very last patched run, also keep the argmax token id at every position.
                    if layer_idx == N_LAYERS - 1 and token_idx == len(clean_tokens) - 1:
                        final_output = model.lm_head.output.argmax(dim=-1).save()

            run_results.append(layer_results)

    # Show the argmax prediction at each position of that last patched run.
    print("Model output: ", "".join(model.tokenizer.decode(token) for token in final_output[0]))

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupt_logit_diff.item(), run_results


if model_num == 1: # TinyStories
    clean_tokens, clean_logit_diff, corrupt_logit_diff, patching_results = run_tinystories_patching()

In [None]:
def run_llm_patching():
    """Patch clean residual-stream activations into the corrupted run for decoder LLMs.

    Used for models whose layers live at `model.model.layers`. Unlike the GPT2/TinyStories
    runners, every patch is a separate `model.trace` call, and the clean hidden states of
    all layers are saved explicitly so they can be reused across traces.

    Reads the notebook globals `model`, `clean_prompt`, `corrupt_prompt`,
    `clean_tokenizer_index` and `corrupt_tokenizer_index`.

    Returns:
        (clean_tokens, clean_logit_diff, corrupt_logit_diff, run_results), where
        run_results[layer][token] is (patched - corrupt) / (clean - corrupt) of the
        last-token logit difference: 0 = no effect, 1 = clean value fully restored.
    """
    run_results = []

    N_LAYERS = len(model.model.layers)

    # Clean run
    with model.trace(clean_prompt) as tracer:
        clean_tokens = tracer.invoker.inputs[0]['input_ids'][0]

        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        clean_hs = [
            model.model.layers[layer_idx].output[0].save()
            for layer_idx in range(N_LAYERS)
        ]

        clean_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
        clean_logit_diff = (
            clean_logits[0, -1, clean_tokenizer_index] - clean_logits[0, -1, corrupt_tokenizer_index]
        ).save()

    # Corrupted run
    with model.trace(corrupt_prompt) as tracer:
        corrupt_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
        corrupt_logit_diff = (
            corrupt_logits[0, -1, clean_tokenizer_index] - corrupt_logits[0, -1, corrupt_tokenizer_index]
        ).save()

    for layer_idx in range(N_LAYERS):
        layer_results = []

        # Positions of the clean prompt are reused on the corrupt run (assumes equal token length).
        for token_idx in range(len(clean_tokens)):
            # Patching corrupted run at given layer and token
            with model.trace(corrupt_prompt) as tracer:
                # Apply the patch from the clean hidden states to the corrupted hidden states.
                model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, clean_tokenizer_index] - patched_logits[0, -1, corrupt_tokenizer_index]
                )

                # Calculate the improvement in the correct token after patching.
                patching_result = (patched_logit_diff - corrupt_logit_diff) / (
                    clean_logit_diff - corrupt_logit_diff
                )

                # Save only the Python scalar (via .item()) rather than the tensor to reduce memory usage
                one_result = patching_result.item().save()

            layer_results.append(one_result)

        run_results.append(layer_results)

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupt_logit_diff.item(), run_results


# Granite (4) and SmolLM (5) are expected to share the HF decoder layout (`model.model.layers`,
# `lm_head`) used above. Previously only 2/3 were dispatched, so 4/5 loaded but never patched.
if model_num in (2, 3, 4, 5): # Qwen, Llama, Granite or SmolLM
    clean_tokens, clean_logit_diff, corrupt_logit_diff, patching_results = run_llm_patching()

# Graph results

In [None]:
# Summarise the run and plot the normalised patching effect per (layer, token position).

# `example` only exists for the TinySQL models (model_num > 0); the GPT2 tutorial has no feature case.
if model_num > 0:
    print("Case:", example.feature_name)
    print("Clean token:", example.clean_token_str)
    print("Corrupt token:", example.corrupt_token_str)

if not patching_results:
    # No patching runner handled this model_num, so there are no logit diffs or heatmap to show.
    print(f"No patching results for model_num={model_num}; nothing to plot.")
else:
    print(f"Clean logit difference: {clean_logit_diff:.3f}")
    print(f"Corrupt logit difference: {corrupt_logit_diff:.3f}")

    clean_decoded_tokens = [model.tokenizer.decode(token) for token in clean_tokens]
    # Suffix each token with its position so repeated tokens get distinct x-axis labels.
    token_labels = [f"{token}_{index}" for index, token in enumerate(clean_decoded_tokens)]

    fig = px.imshow(
        patching_results,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": "Position", "y": "Layer","color":"Norm. Logit Diff"},
        x=token_labels,
        title="Patching Residual Stream")
    fig.show()