# TinySQL : M1 Activation Patching

**Background:** A "TinySQL" model takes as input 1) an Instruction, which is an English-language data request sentence, and 2) a Context, which is a SQL CREATE TABLE statement. The model outputs a Response, which is a SQL SELECT statement.

**Notebook purpose:** Visualize changes in attention head activations when a token is corrupted. We corrupt one of: 1) the instruction table name, 2) an instruction field name, 3) the context table name, or 4) a context field name.

**Notebook details:** This notebook:
- Was developed on Google Colab using an A100
- Runs with M1 (TinyStories) with base/CS1/CS2/CS3 models.
- Requires a GITHUB_TOKEN secret to access Martian TinySQL code repository.
- Requires a HF_TOKEN secret to access Martian HuggingFace repository.
- Was developed under a grant provided by withmartian.com ( https://withmartian.com )
- Relies on the nnsight library. Also refer to the https://nnsight.net/notebooks/tutorials/activation_patching/ tutorial.
- Relies on the https://github.com/PhilipQuirke/quanta_mech_interp library for graphing useful nodes.


# Import libraries
Installs and imports the required libraries. No need to read.

In [None]:
# https://nnsight.net/
# !pip install -U nnsight
!pip install nnsight==0.3.7 -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m38.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m86.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m82.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m664.8/664.8 MB[0m [31m190.6 MB/s[0m eta [36m0:00:01[0m

In [None]:
!pip install pandas plotly -q

In [None]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

import nnsight
from nnsight import LanguageModel, util

In [None]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [None]:
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import datetime

In [None]:
!pip install datasets

In [None]:
github_token = userdata.get("GITHUB_TOKEN")

!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

import TinySQL as qts

In [None]:
# Placeholders, overwritten by the data-generation and patching cells below.
clean_tokens, patching_results = [], []

In [None]:
# Key global "input" variables (filled in by the data-generation cell).
clean_prompt, corrupt_prompt = "", ""
# Vocabulary ids of the clean / corrupted word, plus the position of that word
# in the tokenized prompt + SQL answer.
clean_tokenizer_index = corrupt_tokenizer_index = answer_token_index = qts.UNKNOWN_VALUE

# Key global "results" variables (filled in by the patching cell).
clean_logit_diff = corrupt_logit_diff = qts.UNKNOWN_VALUE

# Select model, command set and feature to investigate


In [None]:
# Which model / command set to load, and which feature to corrupt.
model_num = 1                     # 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
cs_num = 1                        # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
feature_name = qts.DEFFIELDNAME   # ENGTABLENAME, ENGFIELDNAME, DEFTABLENAME, DEFFIELDNAME

# Corruption options forwarded to qts.CorruptFeatureTestGenerator.
use_novel_names = False           # If True, we corrupt using words not found in the clean prompt or create sql e.g. "little" or "hammer"
use_synonyms_table = use_synonyms_field = False

# Passed to generate_feature_examples; only examples[0] is patched below.
batch_size = 5

# Load model

In [None]:
# HF_TOKEN (a Colab secret) grants access to the Martian HuggingFace repository.
hf_token = userdata.get("HF_TOKEN")

# Load the selected TinySQL model and remember where it lives on the Hub
# (model_hf is stored in the saved JSON metadata).
model = qts.load_tinysql_model(model_num, cs_num, auth_token=hf_token)
model_hf = qts.sql_interp_model_location(model_num, cs_num)

# Hide the download noise and show the module tree instead.
clear_output()
print(model)

In [None]:
# Architecture sizes; N_LAYERS and N_HEADS drive the patching loops below.
_model_sizes = qts.get_model_sizes(model_num, model)
N_LAYERS, N_HEADS, D_MODEL, D_HEAD = _model_sizes

# Generate clean and corrupt data

In [None]:
generator = qts.CorruptFeatureTestGenerator(model_num, cs_num, model.tokenizer, use_novel_names=use_novel_names, use_synonyms_field=use_synonyms_field, use_synonyms_table=use_synonyms_table )
examples = generator.generate_feature_examples(feature_name, batch_size)

# Only the first generated example is analysed. Each example corrupts one
# prompt token; a resulting impact is expected at answer_token_index.
example = examples[0]
clean_tokenizer_index = example.clean_tokenizer_index
corrupt_tokenizer_index = example.corrupt_tokenizer_index
answer_token_index = example.answer_token_index


def _truncate_at_answer(batch_item, last_index):
    """Tokenize batch_item's alpaca prompt + SQL and keep tokens 0..last_index (inclusive).

    Returns (token_ids, decoded_text).
    NOTE(review): the kept span includes the token AT last_index, so the final
    position predicts the token after it; confirm this matches how TinySQL
    defines answer_token_index.
    """
    text = batch_item.get_alpaca_prompt() + batch_item.sql_statement
    token_ids = model.tokenizer(text)["input_ids"][:last_index + 1]
    return token_ids, model.tokenizer.decode(token_ids)


clean_tokens, clean_prompt = _truncate_at_answer(example.clean_BatchItem, answer_token_index)
corrupt_tokens, corrupt_prompt = _truncate_at_answer(example.corrupt_BatchItem, answer_token_index)

# The patching cells copy clean activations into the corrupt run position by
# position, and re-tokenize these decoded strings. That is only meaningful if
# both prompts re-tokenize to the same length, so warn loudly if they do not.
n_clean_retokenized = len(model.tokenizer(clean_prompt)["input_ids"])
n_corrupt_retokenized = len(model.tokenizer(corrupt_prompt)["input_ids"])
if n_clean_retokenized != n_corrupt_retokenized:
    print(f"WARNING: clean prompt re-tokenizes to {n_clean_retokenized} tokens but corrupt prompt "
          f"to {n_corrupt_retokenized}; position-wise patching will be misaligned.")

print("Case:", example.feature_name)
print("Clean: Token=", example.clean_token_str)
print("Corrupt: Token=", example.corrupt_token_str)
print()
print("Clean prompt:", clean_prompt)
print()
print("Corrupt prompt:", corrupt_prompt)

# Perform activation patching

In [None]:
def run_tinystories_patching():
    """Run per-head activation patching for a TinyStories (GPT-Neo style) model.

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupt_prompt``,
    ``clean_tokenizer_index``, ``corrupt_tokenizer_index``, ``N_LAYERS`` and
    ``N_HEADS``. Everything runs inside one nnsight trace: a clean invoke, a
    corrupt invoke, then one corrupt invoke per (layer, head, token position)
    in which only that head's output at that position is replaced by its
    value from the clean run.

    A head's output is read from the *input* of the attention output
    projection, reshaped as ``(nh dh)``. This is the same point
    ``run_llm_patching`` patches (``o_proj``). Splitting the transformer
    block's output (the residual stream) into N_HEADS chunks would not isolate
    heads, because the residual stream's hidden dimension is not partitioned
    by head.

    Returns:
        tuple: (clean_tokens, clean_logit_diff, corrupt_logit_diff, run_results),
        where run_results[layer][head][position] is the normalised logit
        difference (patched - corrupt) / (clean - corrupt).
    """

    def _attn_out_proj(layer_idx):
        # Attention output projection of one transformer block.
        # NOTE(review): path assumes the HF GPT-Neo module layout used by
        # TinyStories models (transformer.h[i].attn.attention.out_proj);
        # confirm against the printed model structure.
        return model.transformer.h[layer_idx].attn.attention.out_proj

    final_output = None
    results_dict = {}

    # Enter nnsight tracing context; every invoke below joins the same batch.
    with model.trace() as tracer:
        # Clean run: record each head's output at every position.
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = invoker.inputs[0]['input_ids'][0]

            clean_attn_outputs = {}
            for layer_idx in range(N_LAYERS):
                z_heads = einops.rearrange(
                    _attn_out_proj(layer_idx).input,
                    'b s (nh dh) -> b s nh dh',
                    nh=N_HEADS
                )
                for head_idx in range(N_HEADS):
                    clean_attn_outputs[(layer_idx, head_idx)] = z_heads[:, :, head_idx, :].save()

            clean_logits = model.lm_head.output

            # Difference between the clean and corrupt token logits at the last position.
            clean_logit_diff = (
                clean_logits[0, -1, clean_tokenizer_index] -
                clean_logits[0, -1, corrupt_tokenizer_index]
            ).save()

        # Corrupted run (unpatched baseline).
        with tracer.invoke(corrupt_prompt):
            corrupt_logits = model.lm_head.output

            corrupt_logit_diff = (
                corrupt_logits[0, -1, clean_tokenizer_index] -
                corrupt_logits[0, -1, corrupt_tokenizer_index]
            ).save()

        n_tokens = len(clean_tokens)
        last_combo = (N_LAYERS - 1, N_HEADS - 1, n_tokens - 1)

        for layer_idx in range(N_LAYERS):
            for head_idx in range(N_HEADS):
                head_results = []

                for token_idx in range(n_tokens):
                    # Corrupted run with one head at one position patched from the clean run.
                    with tracer.invoke(corrupt_prompt):
                        out_proj = _attn_out_proj(layer_idx)
                        z_heads = einops.rearrange(
                            out_proj.input,
                            'b s (nh dh) -> b s nh dh',
                            nh=N_HEADS
                        )

                        z_heads[:, token_idx:token_idx+1, head_idx, :] = \
                            clean_attn_outputs[(layer_idx, head_idx)][:, token_idx:token_idx+1]

                        # Feed the patched per-head outputs back into the projection.
                        out_proj.input = einops.rearrange(
                            z_heads,
                            'b s nh dh -> b s (nh dh)',
                            nh=N_HEADS
                        )

                        patched_logits = model.lm_head.output

                        patched_logit_diff = (
                            patched_logits[0, -1, clean_tokenizer_index] -
                            patched_logits[0, -1, corrupt_tokenizer_index]
                        )

                        # 0 => the patch had no effect; 1 => patched diff equals the clean diff.
                        patched_result = (patched_logit_diff - corrupt_logit_diff) / (
                            clean_logit_diff - corrupt_logit_diff
                        )

                        head_results.append(patched_result.item().save())

                        # Keep the greedy predictions of the very last patched run for display.
                        if (layer_idx, head_idx, token_idx) == last_combo:
                            final_output = patched_logits.argmax(dim=-1).save()

                results_dict[(layer_idx, head_idx)] = head_results

    # Arrange as run_results[layer][head][position].
    run_results = [
        [results_dict[(layer_idx, head_idx)] for head_idx in range(N_HEADS)]
        for layer_idx in range(N_LAYERS)
    ]

    if final_output is not None:
        decoded_tokens = [model.tokenizer.decode(token) for token in final_output[0]]
        print("Model output: ", "".join(decoded_tokens))

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory()  # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupt_logit_diff.item(), run_results

if model_num == 1:  # TinyStories
    clean_tokens, clean_logit_diff, corrupt_logit_diff, patching_results = run_tinystories_patching()

In [None]:
def run_llm_patching():
    """Run per-head activation patching for models laid out as ``model.model.layers``.

    Reads the notebook globals ``model``, ``clean_prompt``, ``corrupt_prompt``,
    ``clean_tokenizer_index``, ``corrupt_tokenizer_index`` and ``N_HEADS``.
    N_LAYERS is recomputed locally from ``model.model.layers``.

    A head's activation is taken as one ``dh``-wide slice of the input to
    ``self_attn.o_proj`` after reshaping it as ``(nh dh)``. The function runs
    one clean trace and one corrupt trace. It then runs a separate corrupt
    trace for every (layer, head, token position), in which only that slice
    at that position is overwritten with its clean value. That is
    N_LAYERS * N_HEADS * len(clean_tokens) extra forward passes in total.

    Returns:
        tuple: (clean_tokens, clean_logit_diff, corrupt_logit_diff, run_results),
        where run_results[layer][head][position] is
        (patched_diff - corrupt_diff) / (clean_diff - corrupt_diff).
        NOTE(review): this divides by zero if the clean and corrupt logit
        differences are equal.
    """
    run_results = []  # overwritten before use further down
    N_LAYERS = len(model.model.layers)

    # Clean run
    with model.trace(clean_prompt) as tracer:
        clean_tokens = tracer.invoker.inputs[0]['input_ids'][0]

        # Get clean attention outputs for each layer and head
        # (assumes o_proj's input width is N_HEADS * head_dim — TODO confirm for GQA models)
        clean_attn_outputs = {}
        for layer_idx in range(N_LAYERS):
            z = model.model.layers[layer_idx].self_attn.o_proj.input
            z_reshaped = einops.rearrange(z, 'b s (nh dh) -> b s nh dh', nh=N_HEADS)

            for head_idx in range(N_HEADS):
                clean_attn_outputs[(layer_idx, head_idx)] = z_reshaped[:, :, head_idx, :].save()

        clean_logits = model.lm_head.output

        # Calculate logit difference for clean run (clean token minus corrupt token, last position)
        clean_logit_diff = (
            clean_logits[0, -1, clean_tokenizer_index] -
            clean_logits[0, -1, corrupt_tokenizer_index]
        ).save()

    # Corrupted run
    with model.trace(corrupt_prompt) as tracer:
        corrupt_logits = model.lm_head.output

        # Calculate logit difference for corrupted run
        corrupt_logit_diff = (
             corrupt_logits[0, -1, clean_tokenizer_index] -
            corrupt_logits[0, -1, corrupt_tokenizer_index]
        ).save()

    # Initialize results structure for layer-head combinations
    results_dict = {}

    # Iterate through each layer-head combination
    for layer_idx in tqdm(range(N_LAYERS), desc="Processing layers"):
        for head_idx in range(N_HEADS):
            head_results = []

            # For each position in the sequence
            for token_idx in range(len(clean_tokens)):
                # Patching corrupted run at given layer-head and token position
                # (a fresh trace per combination, unlike the single-trace TinyStories version)
                with model.trace(corrupt_prompt) as tracer:
                    # Get corrupted attention output and reshape
                    z_corrupt = model.model.layers[layer_idx].self_attn.o_proj.input
                    z_corrupt = einops.rearrange(z_corrupt, 'b s (nh dh) -> b s nh dh', nh=N_HEADS)

                    # Patch only the specific head at the specific position
                    z_corrupt[:, token_idx:token_idx+1, head_idx, :] = \
                        clean_attn_outputs[(layer_idx, head_idx)][:, token_idx:token_idx+1]

                    # Reshape back and feed the patched tensor into o_proj
                    z_corrupt = einops.rearrange(z_corrupt, 'b s nh dh -> b s (nh dh)', nh=N_HEADS)
                    model.model.layers[layer_idx].self_attn.o_proj.input = z_corrupt

                    patched_logits = model.lm_head.output

                    patched_logit_diff = (
                        patched_logits[0, -1, clean_tokenizer_index] -
                        patched_logits[0, -1, corrupt_tokenizer_index]
                    )

                    # Calculate improvement: 0 => no effect, 1 => patched diff equals clean diff
                    patching_result = (patched_logit_diff - corrupt_logit_diff) / (
                        clean_logit_diff - corrupt_logit_diff
                    )

                    # Convert to item and save
                    one_result = patching_result.item().save()
                    head_results.append(one_result)

            results_dict[(layer_idx, head_idx)] = head_results

    # Convert results to run_results[layer][head][position]
    run_results = []
    for layer_idx in range(N_LAYERS):
        layer_results = []
        for head_idx in range(N_HEADS):
            layer_results.append(results_dict[(layer_idx, head_idx)])
        run_results.append(layer_results)

    run_results = qts.replace_weak_references(run_results)
    qts.free_memory()  # Free up GPU and CPU memory

    return clean_tokens, clean_logit_diff.item(), corrupt_logit_diff.item(), run_results

if model_num > 1:  # Qwen, Llama, Granite or SmolLM (see the model_num legend)
    clean_tokens, clean_logit_diff, corrupt_logit_diff, patching_results = run_llm_patching()

# Graph results

In [None]:
# Fail early with a clear message instead of an IndexError further down: the
# patching cells only fill patching_results when model_num == 1 or model_num > 1.
if len(patching_results) == 0:
    raise RuntimeError(
        f"No patching results to plot (model_num={model_num}). "
        "Run the activation patching cell for a supported model first."
    )

print("Case:", example.feature_name)
print("Clean token:", example.clean_token_str)
print("Corrupt token:", example.corrupt_token_str)
print(f"Clean logit difference: {clean_logit_diff:.3f}")
print(f"Corrupt logit difference: {corrupt_logit_diff:.3f}")

# Decode tokens for labels; the index suffix keeps repeated tokens distinct on the axis.
clean_decoded_tokens = [model.tokenizer.decode(token) for token in clean_tokens]
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_decoded_tokens)]

# patching_results is indexed [layer][head][position].
n_layers = len(patching_results)
n_heads = len(patching_results[0])
n_positions = len(patching_results[0][0])

# Reshape to combine layers and heads into a single row dimension.
results_2d = np.array(patching_results).reshape(n_layers * n_heads, n_positions)

# Create labels for each layer-head combination (row order matches the reshape).
layer_head_labels = [f"L{l}_H{h}" for l in range(n_layers) for h in range(n_heads)]

# Create the heatmap
fig = px.imshow(
    results_2d,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    labels={"x": "Position", "y": "Layer_Head", "color": "Norm. Logit Diff"},
    x=token_labels,
    y=layer_head_labels,
    title="Layer-Head Patching Analysis"
)

# Adjust layout
fig.update_layout(
    xaxis_tickangle=-45,  # Rotate labels 45 degrees
    margin=dict(b=100, l=150),  # Increase margins for readability
    xaxis=dict(
        tickmode='array',
        ticktext=token_labels,
        tickvals=list(range(len(token_labels))),
        tickfont=dict(size=10)
    ),
    yaxis=dict(
        tickmode='array',
        ticktext=layer_head_labels,
        tickvals=list(range(len(layer_head_labels))),
        tickfont=dict(size=10),
        title="Layer_Head"
    ),
    height=800  # Make plot taller to accommodate all layer-head combinations
)

fig.show()

# Save results as JSON

In [None]:
# Save results
def save_patching_results(patching_results, clean_tokens, clean_logit_diff, corrupt_logit_diff,
                          model_num, cs_num, model_hf, use_novel_names, use_synonyms_table, use_synonyms_field,
                          num_sql_fields, batch_size, feature_name=None):
    """Write activation-patching results plus run metadata to a JSON file.

    Args:
        patching_results: nested sequence indexed [layer][head][position] of
            normalised logit differences.
        clean_tokens: token ids of the clean prompt; anything with ``tolist()``
            (tensor / ndarray) or a plain list.
        clean_logit_diff, corrupt_logit_diff: scalar logit differences.
        model_num, cs_num, model_hf, use_novel_names, use_synonyms_table,
        use_synonyms_field, num_sql_fields, batch_size: run settings copied
            into the ``metadata`` section.
        feature_name: optional corrupted feature. When given, it is stored in
            the metadata and appended to the filename, so runs for different
            features do not overwrite each other.

    Returns:
        The filename written, relative to the current working directory.

    Raises:
        ValueError: if patching_results has no layers, heads or positions.
    """
    if (len(patching_results) == 0 or len(patching_results[0]) == 0
            or len(patching_results[0][0]) == 0):
        raise ValueError("patching_results must be a non-empty [layer][head][position] nested sequence")

    # Timestamp is recorded in the metadata only (not in the filename).
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    metadata = {
        "source": "tinysql_activation_patching_json.ipynb",
        "timestamp": timestamp,
        "model_num": model_num,
        "model_hf": model_hf,
        "commandset": cs_num,
        "use_novel_names": use_novel_names,
        "use_synonyms_table": use_synonyms_table,
        "use_synonyms_field": use_synonyms_field,
        "num_sql_fields": num_sql_fields,
        "batch_size": batch_size,
        "n_positions": len(patching_results[0][0]),
        "n_layers": len(patching_results),
        "n_heads": len(patching_results[0])
    }
    if feature_name is not None:
        # str() keeps this JSON-serialisable even if the feature is an enum-like object.
        metadata["feature_name"] = str(feature_name)

    results_dict = {
        "metadata": metadata,
        "patching_results": np.array(patching_results).tolist(),
        "clean_tokens": clean_tokens.tolist() if hasattr(clean_tokens, 'tolist') else clean_tokens,
        "clean_logit_diff": float(clean_logit_diff),
        "corrupt_logit_diff": float(corrupt_logit_diff),
    }

    # Filename is derived from the run settings only, so reruns with identical
    # settings overwrite the previous file.
    feature_suffix = f'.feature.{feature_name}' if feature_name is not None else ''
    filename = (f'activation_patching_results_{model_num}_{cs_num}_novel_names.{use_novel_names}'
                f'.table.{use_synonyms_table}.field.{use_synonyms_field}{feature_suffix}.json')

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results_dict, f, indent=2)

    print(f"Results saved to {filename}")
    return filename

# Load and plot results
def load_and_plot_results(json_path):
    """Load a JSON file written by save_patching_results and build its heatmap.

    Prints the stored metadata and logit differences as a side effect.

    Args:
        json_path: path to the results JSON file.

    Returns:
        A plotly Figure with one row per layer/head and one column per token
        position. It is not shown; call ``fig.show()``.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    patching_results = np.array(data['patching_results'])
    metadata = data['metadata']

    # Combine layers and heads into a single row dimension ([layer][head][pos] -> [layer*head][pos]).
    results_2d = patching_results.reshape(metadata['n_layers'] * metadata['n_heads'],
                                          metadata['n_positions'])

    layer_head_labels = [f"L{l}_H{h}" for l in range(metadata['n_layers'])
                         for h in range(metadata['n_heads'])]
    position_labels = [f"pos_{i}" for i in range(metadata['n_positions'])]

    # Plotly renders text as HTML-like markup and ignores "\n"; use <br> for the
    # three-line title (the top margin below leaves room for it).
    title = ("Layer-Head Patching Analysis<br>"
             f"Source: {metadata['source']}, Model: {metadata['model_num']}, Dataset: {metadata['commandset']}<br>"
             f"Time: {metadata['timestamp']}")

    fig = px.imshow(
        results_2d,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": "Position", "y": "Layer_Head", "color": "Norm. Logit Diff"},
        x=position_labels,
        y=layer_head_labels,
        title=title
    )

    fig.update_layout(
        xaxis_tickangle=-45,
        margin=dict(b=100, l=150, t=130),
        xaxis=dict(
            tickmode='array',
            ticktext=position_labels,
            tickvals=list(range(len(position_labels))),
            tickfont=dict(size=10)
        ),
        yaxis=dict(
            tickmode='array',
            ticktext=layer_head_labels,
            tickvals=list(range(len(layer_head_labels))),
            tickfont=dict(size=10),
            title="Layer_Head"
        ),
        height=800
    )

    print("Experiment Information:")
    for key, value in metadata.items():
        print(f"{key.replace('_', ' ').title()}: {value}")

    print("\nMetrics:")
    print(f"Clean logit difference: {data['clean_logit_diff']:.3f}")
    print(f"Corrupt logit difference: {data['corrupt_logit_diff']:.3f}")

    return fig



In [None]:
# Persist this run; num_sql_fields is fixed at 2 for these experiments.
save_patching_results(
    patching_results, clean_tokens, clean_logit_diff, corrupt_logit_diff,
    model_num=model_num,
    cs_num=cs_num,
    model_hf=model_hf,
    use_novel_names=use_novel_names,
    use_synonyms_table=use_synonyms_table,
    use_synonyms_field=use_synonyms_field,
    num_sql_fields=2,
    batch_size=batch_size,
)