# Activation Patching on M0 .. M5 models using nnsight
- Developed on Google Colab using an A100 GPU (40GB) with 80GB of system RAM.
- Runs with GPT2/TinyStories/Qwen/Llama/Granite/SmolLM with base/CS1/CS2/CS3.
- Runs CS1 feature tests (as defined in corrupt_data/clean_corrupt_data.py).
- Requires a GITHUB_TOKEN secret to access the Martian quanta_text_to_sql code repository.
- Requires an HF_TOKEN secret to access the Martian HuggingFace repository.


# Import libraries
Installs and imports the required libraries. This is boilerplate setup and can be skipped when reading.

In [None]:
# https://nnsight.net/
!pip install -U nnsight

In [None]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

import nnsight
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

In [None]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [None]:
# Read the GitHub token from the Colab secrets store (needed for the private repo below).
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
# NOTE(review): the token is interpolated into the pip URL; avoid echoing this
# command or sharing cell output that could reveal it.
!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

import QuantaTextToSql as qts

In [None]:
# Notebook-wide results; the patching dispatch cells below overwrite both.
clean_tokens, patching_results = [], []

# Select model, command set and feature to investigate


In [None]:
# Experiment configuration: edit these values, then re-run the cells below.
# cs_num and feat_name are only read by the model_num > 0 cells; the GPT2 (m0)
# tutorial path ignores them.
model_num = 1                 # 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
cs_num = 1                    # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
feat_name = qts.DEFFIELDNAME  # ENGTABLENAME, ENGFIELDNAME, DEFTABLESTART, DEFTABLENAME, DEFFIELDNAME, DEFFIELDSEPARATOR

In [None]:
# Key global variables, (re)assigned by the model-specific cells below.
clean_prompt = corrupted_prompt = ""
# Tokenizer indices of the clean (correct) and corrupted answer words;
# qts.UNKNOWN_VALUE until a model-specific cell sets them.
clean_token_index = corrupt_token_index = qts.UNKNOWN_VALUE

# Run m0: nnsight tutorial using GPT2
Based on the nnsight activation patching tutorial: https://nnsight.net/notebooks/tutorials/activation_patching/



In [None]:
# m0 only: wrap the public GPT-2 checkpoint with nnsight.
if model_num == 0:
    model = LanguageModel("openai-community/gpt2", device_map="auto")
    clear_output()  # clear this cell's earlier output before printing the model
    print(model)

In [None]:
# Prompts from the nnsight activation patching tutorial. The m1+ prompt cell
# overwrites both variables when model_num > 0.
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

In [None]:
if model_num == 0:
    # Answer tokens for the tutorial prompts: " John" completes clean_prompt and
    # " Mary" completes corrupted_prompt. Only the first token id is kept, which
    # assumes each name (with its leading space) is a single GPT-2 token.
    clean_token_index = model.tokenizer(" John")["input_ids"][0] # includes a space
    corrupt_token_index = model.tokenizer(" Mary")["input_ids"][0] # includes a space

    print(f"' John': {clean_token_index}")
    print(f"' Mary': {corrupt_token_index}")

In [None]:
def run_gpt2_patching():
    """Residual-stream activation patching for GPT-2 style models (m0).

    For every block in ``model.transformer.h`` and every token position of the
    clean prompt, the clean run's block output at that position is copied into
    a corrupted run. The resulting logit difference is then normalized so that
    0 means "same as the corrupted run" and 1 means "same as the clean run".

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupted_prompt``,
    ``clean_token_index`` and ``corrupt_token_index``.

    Every run (clean, corrupted, and one per (layer, position) patch) is a
    separate ``tracer.invoke`` inside a single ``model.trace()``.
    NOTE(review): per the nnsight tutorial, invokes in one trace execute as one
    batched forward pass, which is why clean activations need no ``.save()``.
    Confirm this for the installed nnsight version.

    Assumes the clean and corrupted prompts tokenize to the same length: the
    loop ranges over clean positions and uses them to index the corrupted run.
    The normalization divides by zero if the clean and corrupted logit
    differences are equal.

    Returns:
        tuple: ``(clean_tokens, clean_logit_diff, corrupted_logit_diff,
        run_results)``. ``clean_tokens`` are the clean prompt's input ids, the
        two diffs are ``.item()`` of the saved logit differences, and
        ``run_results[layer][position]`` is the normalized score after
        ``qts.replace_weak_references``.
    """
    run_results = []

    N_LAYERS = len(model.transformer.h)

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            # No need to call .save() as we don't need the values after the run, just within the experiment run.
            # output[0] is the first element of each block's output (the hidden state).
            clean_hs = [
                model.transformer.h[layer_idx].output[0]
                for layer_idx in range(N_LAYERS)
            ]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
            # Position -1 is the last prompt token, i.e. the next-token prediction.
            clean_logit_diff = (
                clean_logits[0, -1, clean_token_index] - clean_logits[0, -1, corrupt_token_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, clean_token_index] - corrupted_logits[0, -1, corrupt_token_index]
            ).save()

        # Iterate through all the layers
        for layer_idx in range(N_LAYERS):
            layer_results = []

            # Iterate through all tokens
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer and token
                with tracer.invoke(corrupted_prompt) as invoker:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, clean_token_index] - patched_logits[0, -1, corrupt_token_index]
                    )

                    # Calculate the improvement in the correct token after patching.
                    # 0 = no change from the corrupted run, 1 = fully restores the clean run.
                    patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                        clean_logit_diff - corrupted_logit_diff
                    )

                    # Reduce to a scalar before saving to keep the saved state small.
                    layer_results.append(patched_result.item().save())

            run_results.append(layer_results)

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupted_logit_diff.item(), run_results


if model_num == 0: # GPT2
    clean_tokens, clean_logit_diff, corrupted_logit_diff, patching_results = run_gpt2_patching()

The results are printed and plotted by the shared code at the end of the notebook.

# Run m1, m2 and m3 models

In [None]:
if model_num > 0:

    if model_num == 1:
        # TinyStories is loaded through the quanta_text_to_sql helper, which is given
        # the HF_TOKEN secret, and the resulting model is then wrapped with nnsight.
        the_tokenizer, the_model = qts.load_sql_interp_model(model_num, cs_num, auth_token=userdata.get("HF_TOKEN"), use_flash_attention=False)
        # Pass the tokenizer by keyword. LanguageModel's second positional slot is
        # not `tokenizer`, so a positionally passed tokenizer is not attached.
        model = LanguageModel(the_model, tokenizer=the_tokenizer)
        # Explicit assignment kept as a safeguard across nnsight versions.
        model.tokenizer = the_tokenizer
    else:
        # Other models are loaded by nnsight directly from their repository location.
        model = LanguageModel(qts.sql_interp_model_location(model_num, cs_num), device_map="auto")

    clear_output()
    print(model)

In [None]:
if model_num > 0:
    # Generate a batch of clean and corrupt prompts for feat_name.
    # Only the first example of the batch is used below.
    batch_size = 5
    generator = qts.CorruptFeatureTestGenerator(model_num, cs_num, model.tokenizer)
    examples = generator.generate_feature_examples(feat_name, batch_size)

    example = examples[0]
    # Both prompts end with the CLEAN example's SQL statement, so they differ only
    # in the alpaca-prompt part built from the clean vs corrupt BatchItem.
    # NOTE(review): the patching functions read logits at position -1, i.e. after
    # the whole ground_truth string; presumably sql_statement stops just before the
    # token under test. Confirm against CorruptFeatureTestGenerator.
    ground_truth = example.clean_BatchItem.sql_statement
    clean_prompt = example.clean_BatchItem.get_alpaca_prompt() + ground_truth
    corrupted_prompt = example.corrupt_BatchItem.get_alpaca_prompt() + ground_truth # PQR incorrect
    clean_token_index = example.clean_token_index
    corrupt_token_index = example.corrupt_token_index

    print("Case:", example.feature_name )
    print()
    print("Clean: Token=", example.clean_token, "Index=", clean_token_index, "Prompt=", clean_prompt )
    print()
    print("Corrupt: Token=", example.corrupt_token, "Index=", corrupt_token_index, "Prompt=", corrupted_prompt)

In [None]:
def run_tinystories_patching():
    """Residual-stream activation patching for the TinyStories model (m1).

    TinyStories is patched through the same GPT-2-style modules
    (``model.transformer.h`` blocks and ``model.lm_head``) as the m0 path.
    The previous body of this function was a line-for-line copy of
    ``run_gpt2_patching``, so it now delegates to that function and the
    patching logic lives in one place.

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupted_prompt``,
    ``clean_token_index`` and ``corrupt_token_index``.

    Returns:
        tuple: ``(clean_tokens, clean_logit_diff, corrupted_logit_diff,
        run_results)``, as returned by ``run_gpt2_patching``, with
        ``run_results[layer][position]`` the normalized patching score.

    Raises:
        ValueError: if the clean and corrupted prompts tokenize to different
            lengths. Patching copies clean activations position-by-position
            into the corrupted run, so the positions would otherwise be
            misaligned (or land on padding in the batched trace).
    """
    # Tokenize exactly as the other cells do (model.tokenizer(text)["input_ids"])
    # to compare prompt lengths before running the expensive patching sweep.
    clean_len = len(model.tokenizer(clean_prompt)["input_ids"])
    corrupted_len = len(model.tokenizer(corrupted_prompt)["input_ids"])
    if clean_len != corrupted_len:
        raise ValueError(
            f"Clean prompt has {clean_len} tokens but corrupted prompt has "
            f"{corrupted_len}; position-wise patching needs equal lengths."
        )

    return run_gpt2_patching()


if model_num == 1: # TinyStories
    clean_tokens, clean_logit_diff, corrupted_logit_diff, patching_results = run_tinystories_patching()

In [None]:
def run_llm_patching():
    """Residual-stream activation patching for models with ``model.model.layers``.

    For every decoder layer and every token position of the clean prompt, the
    clean run's layer output at that position is copied into a corrupted run.
    The resulting logit difference is then normalized: 0 means "same as the
    corrupted run" and 1 means "same as the clean run".

    Unlike the GPT-2/TinyStories path, each forward pass here is its own
    ``model.trace()`` (clean, corrupted, then one per (layer, position) patch).
    The clean hidden states and both logit differences are therefore
    ``.save()``d so they remain available outside the trace that produced them.

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupted_prompt``,
    ``clean_token_index`` and ``corrupt_token_index``.

    Returns:
        tuple: ``(clean_tokens, clean_logit_diff, corrupted_logit_diff,
        run_results)`` with ``run_results[layer][position]`` the normalized
        patching score (after ``qts.replace_weak_references``).

    Raises:
        ValueError: if the clean and corrupted prompts tokenize to different
            lengths. Position-wise patching would otherwise compare misaligned
            tokens or index past the end of the corrupted run.
    """
    run_results = []

    N_LAYERS = len(model.model.layers)

    # Clean run
    with model.trace(clean_prompt) as tracer:
        clean_tokens = tracer.invoker.inputs[0]['input_ids'][0]

        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        # Saved because they are used as patch sources in later, separate traces.
        clean_hs = [
            model.model.layers[layer_idx].output[0].save()
            for layer_idx in range(N_LAYERS)
        ]

        # Get logits from the lm_head.
        clean_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
        clean_logit_diff = (
            clean_logits[0, -1, clean_token_index] - clean_logits[0, -1, corrupt_token_index]
        ).save()

    # Corrupted run
    with model.trace(corrupted_prompt) as tracer:
        corrupted_tokens = tracer.invoker.inputs[0]['input_ids'][0]

        corrupted_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
        corrupted_logit_diff = (
            corrupted_logits[0, -1, clean_token_index] - corrupted_logits[0, -1, corrupt_token_index]
        ).save()

    # Patching copies clean activations position-by-position into the corrupted
    # run, so both prompts must have the same number of tokens.
    if len(corrupted_tokens) != len(clean_tokens):
        raise ValueError(
            f"Clean prompt has {len(clean_tokens)} tokens but corrupted prompt has "
            f"{len(corrupted_tokens)}; position-wise patching needs equal lengths."
        )

    for layer_idx in range(N_LAYERS):
        layer_results = []

        for token_idx in range(len(clean_tokens)):
            # Patching corrupted run at given layer and token
            with model.trace(corrupted_prompt) as tracer:
                # Apply the patch from the clean hidden states to the corrupted hidden states.
                model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, clean_token_index] - patched_logits[0, -1, corrupt_token_index]
                )

                # Calculate the improvement in the correct token after patching.
                patching_result = (patched_logit_diff - corrupted_logit_diff) / (
                    clean_logit_diff - corrupted_logit_diff
                )

                # Reduce to a scalar with .item() before saving to decrease memory usage.
                one_result = patching_result.item().save()

            layer_results.append(one_result)

        run_results.append(layer_results)

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory() # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupted_logit_diff.item(), run_results


# Qwen, Llama, Granite and SmolLM all expose decoder blocks at model.model.layers
# and an lm_head, so they share this patching path.
if model_num in (2, 3, 4, 5): # Qwen, Llama, Granite or SmolLM
    clean_tokens, clean_logit_diff, corrupted_logit_diff, patching_results = run_llm_patching()

# Graph results

In [None]:
# Fail early with a clear message if no patching cell produced results for the
# selected model_num. patching_results starts out as [] and clean_logit_diff is
# only assigned by a patching dispatch, so without this check the cell would
# raise a NameError or try to plot an empty result list.
if not patching_results:
    raise RuntimeError(
        f"No activation-patching results for model_num={model_num}; "
        "no patching function was run for this model."
    )

print(f"Clean logit difference: {clean_logit_diff:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff:.3f}")

# Label each heatmap column with its decoded clean token plus its position, so
# repeated tokens remain distinguishable on the x axis.
clean_decoded_tokens = [model.tokenizer.decode(token) for token in clean_tokens]
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_decoded_tokens)]

# Rows are layers and columns are token positions (results[layer][position]);
# 0 = corrupted-run baseline, 1 = fully restored clean-run logit difference.
fig = px.imshow(
    patching_results,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    labels={"x": "Position", "y": "Layer","color":"Norm. Logit Diff"},
    x=token_labels,
    title="Patching Residual Stream")
fig.show()