# TinySQL : M1 Activation Patching

**Background:** A "TinySQL" model takes as input 1) an Instruction, which is an English-language data request sentence, and 2) a Context, which is a SQL CREATE TABLE statement. The model outputs a Response, which is a SQL SELECT statement.

**Notebook purpose:** Visualize changes in attention head activations when a token is corrupted. We corrupt one of: 1) the instruction table name, 2) an instruction field name, 3) the context table name, or 4) a context field name.

**Notebook details:** This notebook:
- Was developed on Google Colab using an A100
- Runs with M1 (TinyStories) with base/CS1/CS2/CS3 models.
- Requires a GITHUB_TOKEN secret to access the Martian TinySQL code repository.
- Requires an HF_TOKEN secret to access the Martian HuggingFace repository.
- Was developed under a grant provided by withmartian.com ( https://withmartian.com )
- Relies on the nnsight library. Also refer to the https://nnsight.net/notebooks/tutorials/activation_patching/ tutorial.
- Relies on the https://github.com/PhilipQuirke/quanta_mech_interp library for graphing useful nodes.


# Import libraries
Installs and imports the required libraries. This section can be skipped when reading.

In [1]:
# https://nnsight.net/
# Install the nnsight tracing/intervention library used for the ablations below.
# To use the 0.4 prerelease version (as at Dec 2024) instead, uncomment:
#!pip install nnsight==0.4.0.dev0
# NOTE(review): "-U" installs whatever nnsight release is latest (unpinned). The
# intervention code later in this notebook depends on nnsight's module
# .input/.output API, so consider pinning the version this notebook was run with.
!pip install -U nnsight -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.5/3.5 MB[0m [31m124.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m68.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/59.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.2/59.2 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/76.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.9/76.9 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# pandas and plotly are imported in the next cells (pandas as pd, plotly.express as px).
!pip install pandas plotly -q

In [3]:
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"

import nnsight
from nnsight import LanguageModel, util

In [4]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [5]:
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import datetime

In [6]:
# Installs the `datasets` package (unpinned; the recorded run pulled datasets-3.2.0).
# NOTE(review): `datasets` is not imported in the visible cells — presumably it is
# needed by the TinySQL package installed next; confirm, and consider adding -q.
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [8]:
# Read the GITHUB_TOKEN Colab secret; it authorises cloning the private TinySQL repo.
github_token = userdata.get("GITHUB_TOKEN")

# NOTE(review): the token is interpolated directly into the shell command line.
# pip masks it as "****" in its log, but do not print or echo this command elsewhere.
!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

# TinySQL helpers (model loading, example generation, constants) are used via `qts`.
import TinySQL as qts

Collecting git+https://****@github.com/withmartian/TinySQL.git
  Cloning https://****@github.com/withmartian/TinySQL.git to /tmp/pip-req-build-i0oron2y
  Running command git clone --filter=blob:none --quiet 'https://****@github.com/withmartian/TinySQL.git' /tmp/pip-req-build-i0oron2y
  Resolved https://****@github.com/withmartian/TinySQL.git to commit 69527faef20a65947cd6274d43c0b545b8ecd397
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: TinySQL
  Building wheel for TinySQL (pyproject.toml) ... [?25l[?25hdone
  Created wheel for TinySQL: filename=TinySQL-1.3-py3-none-any.whl size=57656 sha256=f6dbca77824904ab11e57568f3e2eb0f5231e3fb321742293f36f7ae6f19f207
  Stored in directory: /tmp/pip-ephem-wheel-cache-v4fh5stl/wheels/50/0f/e6/79737def9bcdd807f6db4bea479886f3acc4d4a2671f79b776
Successfully built TinySQL
Installing colle

In [9]:
# Empty containers, filled in by later cells (two distinct list objects).
clean_tokens, patching_results = [], []

In [10]:
# Key global "input" variables — placeholders, reassigned by the data-generation cell.
clean_prompt = corrupt_prompt = ""
# clean_/corrupt_tokenizer_index: tokenizer vocab index of the clean / corrupted word.
# answer_token_index: token index in the SQL answer of the clean/corrupt word.
clean_tokenizer_index = corrupt_tokenizer_index = answer_token_index = qts.UNKNOWN_VALUE

# Key global "results" variables
clean_logit_diff = corrupt_logit_diff = qts.UNKNOWN_VALUE

# Select model, command set and feature to investigate


In [11]:
# --- Model selection ---
model_num = 1   # 0=GPT2, 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
cs_num = 1      # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3

# --- Feature to corrupt ---
# One of: ENGTABLENAME, ENGFIELDNAME, DEFTABLESTART, DEFTABLENAME, DEFFIELDNAME, DEFFIELDSEPARATOR
feature_name = qts.DEFFIELDNAME

# --- Corruption options (passed to qts.CorruptFeatureTestGenerator) ---
use_novel_names = False     # If True, we corrupt using words not found in the clean prompt or create sql e.g. "little" or "hammer"
use_synonyms_table = False  # presumably toggles synonym-based table names — TODO confirm in TinySQL
use_synonyms_field = False  # presumably toggles synonym-based field names — TODO confirm in TinySQL

# Passed to generate_feature_examples together with feature_name
# (presumably the number of examples generated).
batch_size = 5

# Load model

In [13]:
# Read the HF_TOKEN Colab secret, used to access the Martian HuggingFace repository.
hf_token = userdata.get("HF_TOKEN")

# Load the TinySQL model chosen by model_num / cs_num in the config cell.
# NOTE(review): the model is loaded with synonym=True, but sql_interp_model_location
# is called without a synonym argument — confirm model_hf names the same checkpoint
# that was actually loaded.
model = qts.load_tinysql_model(model_num, cs_num, auth_token=hf_token, synonym=True)
model_hf = qts.sql_interp_model_location(model_num, cs_num)
clear_output()  # clear any output produced while loading, then show the architecture
print(model)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50258, 1024)
    (wpe): Embedding(2048, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-1): 2 x GPTNeoBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
          (c_proj): L

In [14]:
# Layer count, heads per layer, residual width and per-head width of the loaded model
# (the helper prints them; see the output of this cell).
N_LAYERS, N_HEADS, D_MODEL, D_HEAD = qts.get_model_sizes(model_num, model)

N_LAYERS=2 N_HEADS=16 D_MODEL=1024 D_HEAD=64


# Generate clean and corrupt data

In [15]:
generator = qts.CorruptFeatureTestGenerator(model_num, cs_num, model.tokenizer, use_novel_names=use_novel_names, use_synonyms_field=use_synonyms_field, use_synonyms_table=use_synonyms_table )
examples = generator.generate_feature_examples(feature_name, batch_size)

# Each example is corrupted at one prompt token. A resulting impact is expected at answer_token_index.
example = examples[0]
clean_tokenizer_index = example.clean_tokenizer_index
corrupt_tokenizer_index = example.corrupt_tokenizer_index
answer_token_index = example.answer_token_index


def _truncate_at_answer(batch_item, answer_index):
    """Tokenize the alpaca prompt + SQL of `batch_item`, keeping tokens 0..answer_index inclusive.

    Shared by the clean and corrupt examples so both are truncated identically.

    Returns:
        (token_ids, decoded_prompt) for the truncated sequence.

    Raises:
        AssertionError: if answer_index is not a valid position in the tokenized text.
            Without this check, slicing would silently keep too much (or nothing).
    """
    full_text = batch_item.get_alpaca_prompt() + batch_item.sql_statement
    token_ids = model.tokenizer(full_text)["input_ids"]
    assert 0 <= answer_index < len(token_ids), (
        f"answer_token_index={answer_index} is outside the {len(token_ids)}-token sequence"
    )
    token_ids = token_ids[:answer_index + 1]
    return token_ids, model.tokenizer.decode(token_ids)


# Truncate the clean and corrupt prompts at answer_token_index (inclusive)
clean_tokens, clean_prompt = _truncate_at_answer(example.clean_BatchItem, answer_token_index)
corrupt_tokens, corrupt_prompt = _truncate_at_answer(example.corrupt_BatchItem, answer_token_index)

print("Case:", example.feature_name)
print("Clean: Token=", example.clean_token_str)
print("Corrupt: Token=", example.corrupt_token_str)
print()
print("Clean prompt:", clean_prompt)
print()
print("Corrupt prompt:", corrupt_prompt)

Case: DefFieldName
Clean: Token= label
Corrupt: Token= status

Clean prompt: ### Instruction: show me the label and category from the products table ### Context: CREATE TABLE products ( label JSON, category JSON ) ### Response: SELECT label

Corrupt prompt: ### Instruction: show me the label and category from the products table ### Context: CREATE TABLE products ( status JSON, category JSON ) ### Response: SELECT label


# Selective ablations

In [33]:
# Hand-written copy of the example's clean prompt, stopping right after "### Response: "
# so that the model has to produce the SQL answer itself.
prompt = (
    "### Instruction: show me the label and category from the products table "
    "### Context: CREATE TABLE products ( label JSON, category JSON ) "
    "### Response: "
)

In [52]:
def zero_heads(model, prompt_text, target_layers, heads_per_layer):
    N_HEADS = 16
    inputs = model.tokenizer(prompt_text, return_tensors="pt")

    with model.trace() as tracer:
        with tracer.invoke(inputs) as invoker:
            for layer_idx in target_layers:
                layer_output = model.transformer.h[layer_idx].output[0]
                target_heads = heads_per_layer[layer_idx]

                output_reshaped = einops.rearrange(
                    layer_output,
                    'b s (nh dh) -> b s nh dh',
                    nh=N_HEADS
                )

                for head_idx in range(N_HEADS):
                    if head_idx not in target_heads:
                        output_reshaped[:, :, head_idx, :] = 0

                modified_output = einops.rearrange(
                    output_reshaped,
                    'b s nh dh -> b s (nh dh)',
                    nh=N_HEADS
                )

                model.transformer.h[layer_idx].output = (modified_output,) + model.transformer.h[layer_idx].output[1:]

            final_output = model.lm_head.output.argmax(dim=-1).save()

    print("Modified Output:", model.tokenizer.decode(final_output[0][-1]))
    return final_output

# Heads to KEEP in each layer; zero_heads zeroes every head not listed here.
heads_per_layer = {
    0: [11, 3, 1, 8, 15, 14, 13, 7],
    1: [10, 13, 3, 7, 14, 15, 11, 2, 1, 12, 5],
}
# Ablate every layer that has an entry above, in insertion order: [0, 1].
target_layers = list(heads_per_layer)
output = zero_heads(model, prompt, target_layers, heads_per_layer)

Modified Output: SELECT
