# TinySQL : M1 useful nodes

**Background:** A "TinySQL" model takes as input 1) an Instruction, which is an English data-request sentence, and 2) a Context, which is a SQL table create statement. The model outputs a Response, which is a SQL select statement.  

**Notebook purpose:** Identifies nodes (attention heads and MLPs) in the M1 model that, when ablated, cause a decrease in model prediction accuracy. These nodes are needed (aka useful) for accurate predictions.

**Notebook details:** This notebook:
- Was developed on Google Colab using a **T4** GPU
- Runs with M1 (TinyStories) with base/CS1/CS2/CS3 models.
- Requires a GITHUB_TOKEN secret to access Martian TinySQL code repository.
- Requires a HF_TOKEN secret to access Martian HuggingFace repository.
- Was developed under a grant provided by withmartian.com ( https://withmartian.com )
- Relies on the nnsight library. See also https://nnsight.net/notebooks/tutorials/walkthrough/#Batching and https://nnsight.net/notebooks/tutorials/walkthrough/#Looping
- Relies on the https://github.com/PhilipQuirke/quanta_mech_interp library for graphing useful nodes.
- Is based on the https://nnsight.net/notebooks/tutorials/activation_patching tutorial


# Import libraries
Imports standard libraries. Do not read.

In [None]:
# https://nnsight.net/
!pip install -U nnsight

In [None]:
from IPython.display import clear_output
import einops
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
import matplotlib.pyplot as plt
import tqdm.auto as tqdm

import nnsight
from nnsight import LanguageModel

In [None]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

In [None]:
!pip install datasets

In [None]:
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

import TinySQL as qts

# Select model and command set to investigate


In [None]:
# Notebook-wide settings. These are read as globals by the cells below.
model_num = 1                 # 1=TinyStories, 2=Qwen, 3=Llama
cs_num = 3                    # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
use_synonyms = False          # Use synonyms for English Instruction table and field names?
batch_size = 30               # Number of examples requested from the generator per batch; also copied into cfg.batch_size

# Load model

In [None]:
# Read the HuggingFace access token from the Colab secrets store and load the selected model.
# (The original chained assignment also bound an unintended global named `auth_token`.)
hf_token = userdata.get("HF_TOKEN")
model = qts.load_tinysql_model(model_num, cs_num, auth_token=hf_token)
clear_output()  # Clear whatever the load step wrote to the cell output before printing the model
print(model)

In [None]:
# Model dimensions reported by TinySQL for the loaded model.
# Later cells use N_LAYERS / N_HEADS to size loops and result arrays, and D_HEAD to slice the hidden dimension per head.
N_LAYERS, N_HEADS, D_MODEL, D_HEAD = qts.get_model_sizes(model_num, model)

In [None]:
# Location of the selected model, in "<repo>/<model>" form (split on the last "/")
model_hf_name = qts.sql_interp_model_location(model_num, cs_num)

# Singleton QuantaTool "main" configuration class. qmi.AlgoConfig is derived from the chain qmi.UsefulConfig > qmi.ModelConfig
cfg = qmi.AlgoConfig()

# Model identity
repo_part, name_part = model_hf_name.rsplit("/", 1)
cfg.repo_name = repo_part
cfg.model_name = name_part
cfg.main_model = model

# Model dimensions
cfg.n_layers = N_LAYERS
cfg.n_heads = N_HEADS
cfg.d_model = D_MODEL
cfg.d_head = D_HEAD

# Run settings
cfg.set_seed(673023)
cfg.batch_size = batch_size

print(f"cfg.repo_name={cfg.repo_name}", f"cfg.model_name={cfg.model_name}")

In [None]:
# Short label used in plot titles and image file names, e.g. "BM1 CS3"
model_title = f"BM{model_num} CS{cs_num}"

# Generate test data and experiment runs

In [None]:
# Generate a batch of prompts with 3 field names
def generate_batch():
    """Return one batch of generated examples using the notebook-wide settings.

    The generator is asked for `batch_size` examples with min_cols and max_cols both
    set to 3. The command set passed is `cs_num`, raised to 1 when `cs_num` is 0
    (the base model).
    """
    command_set = max(1, cs_num)
    return qts.generate_csn(
        batch_size=batch_size,
        csn=command_set,
        min_cols=3,
        max_cols=3,
        use_synonyms=use_synonyms,
    )

In [None]:
# Check if the model can correctly predict the question's answer
def run_attention_experiment(run_layer_idx, run_head_idx, run_token_idx, all_tokens, run_prompt_tokens, run_generate, make_changes = True, extra_tokens=0):
    """Generate from a token prefix, optionally with one node ablated, and compare with the expected text.

    Parameters:
        run_layer_idx, run_head_idx, run_token_idx: node to ablate (layer, head, token position).
        all_tokens: prompt tokens followed by the expected answer tokens.
        run_prompt_tokens: number of leading tokens of `all_tokens` fed to the model.
        run_generate: number of trailing tokens of `all_tokens` the model is expected to produce.
        make_changes: when False the model is run without any intervention.
        extra_tokens: additional generation budget on top of `run_generate`.

    Returns:
        (same, decoded_input, decoded_output) where `same` is True when the decoded
        generation equals the decoded `all_tokens`. When `run_generate` is 0 the model
        is not run and (True, "", "") is returned.
    """
    assert len(all_tokens) == run_prompt_tokens + run_generate

    # Nothing to generate, so nothing can differ
    if run_generate == 0:
        return True, "", ""

    # Columns [first_col, last_col) of the hidden dimension are treated as this head's slice
    first_col = run_head_idx * D_HEAD
    last_col = first_col + D_HEAD

    prefix_tokens = all_tokens[:run_prompt_tokens]
    token_budget = run_generate + extra_tokens

    with model.generate(prefix_tokens, max_new_tokens=token_budget, pad_token_id=model.tokenizer.eos_token_id):

        if make_changes:
            # Overwrite the slice with the scalar mean of its own values.
            # NOTE(review): this slices the transformer block's output[0], not the attention
            # module's per-head output - confirm this is the intended stand-in for a head.
            layer = model.transformer.h[run_layer_idx]
            ablation_value = layer.output[0][:, run_token_idx, first_col:last_col].mean()
            layer.output[0][:, run_token_idx, first_col:last_col] = ablation_value

        final_output = model.generator.output.save()

    # Compare at text level: expected full sequence versus what was actually generated
    decoded_input = model.tokenizer.decode(all_tokens, skip_special_tokens=True)
    decoded_output = model.tokenizer.decode(final_output[0], skip_special_tokens=True)

    return decoded_input == decoded_output, decoded_input, decoded_output

In [None]:
# Check if the model can correctly predict the question's answer and print result
def check_experiment(prompt, answer, all_tokens, run_prompt_tokens, run_generate, make_changes = True, show_same = True):
    """Run one experiment at node (layer 0, head 0, position 0) with 100 extra tokens of budget and report it.

    Mismatches are always printed, preceded by an "Ignoring example" notice; matches are
    printed only when `show_same` is True. Newlines are shown as a literal backslash-n so
    each field stays on one line.

    Returns the (same, decoded_input, decoded_output) tuple from run_attention_experiment.
    """
    outcome = run_attention_experiment(0, 0, 0, all_tokens, run_prompt_tokens, run_generate, make_changes, extra_tokens=100)
    same, _, decoded_output = outcome

    def one_line(text):
        return text.replace('\n', '\\n')

    if not same:
        print( "Ignoring example that model doesn't predict correctly:")

    if show_same or not same:
        print(f"Prompt ({run_prompt_tokens}) :", one_line(prompt))
        print(f"Answer ({run_generate}) :", one_line(answer))
        print("Output     :", one_line(decoded_output))
        print()

    return outcome

In [None]:
# Check that generators work and models are accurate
def check_generator_on_clean_input(title, examples):
    """Count how many generated examples the un-ablated model reproduces, and print a summary.

    Parameters:
        title: label for the generator being checked; shown at the start of the summary line.
        examples: iterable of generated examples exposing get_alpaca_prompt() and sql_statement.

    Each example is run through check_experiment with make_changes=False. Examples the
    model gets wrong are printed by check_experiment; correct ones are not (show_same=False).
    """
    num_good = 0
    for example in examples:
        prompt = example.get_alpaca_prompt()
        answer = example.sql_statement

        prompt_tokens = model.tokenizer(prompt)["input_ids"]
        answer_tokens = model.tokenizer(answer)["input_ids"]

        same, _, _ = check_experiment(prompt, answer, prompt_tokens + answer_tokens, len(prompt_tokens), len(answer_tokens), make_changes = False, show_same = False)
        if same:
            num_good += 1

    # Include the title so the summary says which generator was checked
    print(title + ":", "Use synonyms:", use_synonyms, "#examples:", len(examples), "num_good:", num_good )


check_generator_on_clean_input( "generate_csn", generate_batch() )

In [None]:
# Return a list of experiments to run
def get_experiment_list(examples, by_attention_head):
    """Build the list of ablation experiments for the examples the clean model predicts correctly.

    Parameters:
        examples: iterable of generated examples exposing get_alpaca_prompt() and sql_statement.
        by_attention_head: selects the shape of each experiment entry:
            True  -> [layer_idx, head_idx, token_idx, all_tokens, exp_prompt_tokens, exp_num_generate]
                     for every layer, head and token position of prompt + answer.
            False -> [prompt, answer, token_idx, prompt_tokens] for every prompt token position.

    Returns:
        (run_list, max_prompt_tokens, max_answer_tokens, max_tokens). The three maxima are
        taken over ALL examples, including those skipped because the model got them wrong.
    """
    experiments = []
    longest_prompt = 0
    longest_answer = 0
    longest_total = 0
    examples_left_to_show = 3

    for example in examples:
        prompt = example.get_alpaca_prompt()
        answer = example.sql_statement

        prompt_tokens = model.tokenizer(prompt)["input_ids"]
        answer_tokens = model.tokenizer(answer)["input_ids"]
        all_tokens = prompt_tokens + answer_tokens

        n_prompt, n_answer, n_total = len(prompt_tokens), len(answer_tokens), len(all_tokens)
        assert n_total == n_prompt + n_answer

        longest_prompt = max(longest_prompt, n_prompt)
        longest_answer = max(longest_answer, n_answer)
        longest_total = max(longest_total, n_total)

        # Only keep examples the un-ablated model already gets right
        same, _, _ = check_experiment(prompt, answer, all_tokens, n_prompt, n_answer, make_changes = False, show_same = examples_left_to_show > 0)
        if not same:
            continue
        examples_left_to_show -= 1

        if not by_attention_head:
            experiments.extend([prompt, answer, token_idx, prompt_tokens] for token_idx in range(n_prompt))
            continue

        # Important logic:
        # n_prompt and n_answer are the normal interpretation of the prompt/answer sizes.
        # Ablating a position inside the PROMPT: feed the whole prompt and generate the whole answer.
        # Ablating a position inside the ANSWER: the fed prefix grows to include that position
        # and the number of generated tokens shrinks accordingly.
        splits = []
        for token_idx in range(n_total):
            exp_prompt_tokens = max(n_prompt, token_idx + 1)
            exp_num_generate = n_total - exp_prompt_tokens
            assert n_total == exp_prompt_tokens + exp_num_generate
            splits.append((token_idx, exp_prompt_tokens, exp_num_generate))

        for layer_idx in range(N_LAYERS):
            for head_idx in range(N_HEADS):
                experiments.extend(
                    [layer_idx, head_idx, token_idx, all_tokens, fed, generated]
                    for token_idx, fed, generated in splits
                )

    return experiments, longest_prompt, longest_answer, longest_total

# Which token positions are useful?
This information is used to shrink the size of the search spaces in the following sections. For the SQL model, all token positions are useful.

In [None]:
def run_token_experiments( ):
    """Zero-ablate each prompt token position (across all layers) and count changed outputs.

    For every (example, prompt position) pair from get_experiment_list(..., False), the
    output[0] of every transformer block is set to 0 at that position, the model generates
    up to max_good_answer_tokens new tokens, and the decoded text is compared with
    prompt + answer.

    Returns:
        (max_tokens, fail_results, failure_rate): fail_results is an int array of failure
        counts per position; failure_rate is the percentage of attempts that failed,
        rounded to 2 decimals.
    """
    run_list, max_prompt_tokens, max_good_answer_tokens, max_tokens = get_experiment_list(generate_batch(), False)
    cfg.initialize_token_positions( max_prompt_tokens, max_good_answer_tokens, True )

    print(f"batch_size={batch_size}", f"max_prompt_tokens={max_prompt_tokens}", f"max_good_answer_tokens={max_good_answer_tokens}", f"max_tokens={max_tokens}", f"num_exps={len(run_list)}")

    attempts = np.zeros(max_tokens, dtype=int)
    failures = np.zeros(max_tokens, dtype=int)

    for run_prompt, run_answer, run_token_idx, prompt_tokens in tqdm.tqdm(run_list):

        with model.generate(prompt_tokens, max_new_tokens=max_good_answer_tokens, pad_token_id=model.tokenizer.eos_token_id) :

            # Zero out this token position in the output of every layer
            for run_layer_idx in range(N_LAYERS):
                model.transformer.h[run_layer_idx].output[0][:, run_token_idx, :] = 0

            final_output = model.generator.output.save()

        decoded_output = model.tokenizer.decode(final_output[0], skip_special_tokens=True)

        attempts[run_token_idx] += 1
        # Any difference from the expected text counts as a failure
        if decoded_output != run_prompt + run_answer:
            failures[run_token_idx] += 1

    # Failure rate as a percentage; the tiny epsilon avoids dividing by zero for untried positions
    failure_rate = np.round((1.0 * failures / (attempts + 1e-10)) * 100, 2)

    return max_tokens, failures, failure_rate

g_max_tokens, g_token_fail_results, g_token_failure_rate = run_token_experiments()

In [None]:
# Register every position whose ablation changed at least one output, and list it
print( "Useful token positions:" )
for token_idx, (fail_pct, fail_count) in enumerate(zip(g_token_failure_rate, g_token_fail_results)):
    if fail_pct > 0 :
        cfg.add_useful_position(token_idx)
        print( "Position:", token_idx, "    % Fails:", fail_pct, "    # Fails:", fail_count )

In [None]:
#cfg.calc_position_failures_map(g_token_fail_results.tolist())
#qmi.save_plt_to_file(cfg=cfg, full_title="Failures When Position Ablated")
#plt.show()

# Which token+layer+attention head nodes are useful?

In [None]:
def run_attention_experiments():
    """Ablate every (layer, head, token position) node in turn and count changed outputs.

    Experiments come from get_experiment_list(..., True); each one is executed by
    run_attention_experiment. Only the first failure seen is printed in detail.

    Returns:
        (max_tokens, fail_results, failure_rate, try_results). The three arrays have shape
        (N_LAYERS, N_HEADS, max_tokens); failure_rate is a percentage rounded to 2 decimals.
    """
    run_list, max_prompt_tokens, max_good_answer_tokens, max_tokens = get_experiment_list(generate_batch(), True)

    print(f"batch_size={batch_size}", f"max_prompt_tokens={max_prompt_tokens}", f"max_good_answer_tokens={max_good_answer_tokens}", f"max_tokens={max_tokens}", f"num_exps={len(run_list)}")

    grid_shape = (N_LAYERS, N_HEADS, max_tokens)
    attempts = np.zeros(grid_shape, dtype=int)
    failures = np.zeros(grid_shape, dtype=int)
    first_failure_pending = True

    for run_layer_idx, run_head_idx, run_token_idx, all_tokens, num_prompt_tokens, num_generate in tqdm.tqdm(run_list):

        node = (run_layer_idx, run_head_idx, run_token_idx)
        same, decoded_input, decoded_output = run_attention_experiment(run_layer_idx, run_head_idx, run_token_idx, all_tokens, num_prompt_tokens, num_generate)

        attempts[node] += 1
        if same:
            continue

        failures[node] += 1
        if first_failure_pending:
            # Show one worked example of a failure, then stay quiet
            print("Failure when intervening:", f"Layer={run_layer_idx}", f"Head={run_head_idx}", f"Pos={run_token_idx}", f"NumPrompts={num_prompt_tokens}", f"NumGenerate={num_generate}")
            print("Input :", decoded_input.replace('\n', '\\n'))
            print("Output:", decoded_output.replace('\n', '\\n'))
            first_failure_pending = False

    # Failure rate as a percentage; the tiny epsilon avoids dividing by zero for untried nodes
    failure_rate = np.round((1.0 * failures / (attempts + 1e-10)) * 100, 2)

    return max_tokens, failures, failure_rate, attempts

g_max_tokens, g_attn_failure_results, g_attn_failure_rate, g_attn_try_results = run_attention_experiments()

In [None]:
# Rebuild the useful node list from the attention failure rates.
# A node is tagged only when its failure percentage, truncated to an integer, is above zero.
cfg.useful_nodes = qmi.UsefulNodeList()
for layer_idx, head_idx, token_idx in np.ndindex(N_LAYERS, N_HEADS, g_max_tokens):
    fail_perc = int(g_attn_failure_rate[layer_idx, head_idx, token_idx])
    if fail_perc <= 0:
        continue

    # Add percentage failure quanta
    node_location = qmi.NodeLocation(token_idx, layer_idx, True, head_idx)
    cfg.add_useful_node_tag( node_location, qmi.QType.FAIL.value, str(fail_perc) )

cfg.useful_nodes.sort_nodes()

In [None]:
print( "Use synonyms:", use_synonyms )

# One heat map per layer: rows are attention heads, columns are token positions
for layer_idx, layer_rates in enumerate(g_attn_failure_rate):
    heat_map = plt.imshow(layer_rates, cmap="viridis", aspect="auto")
    plt.colorbar(heat_map, label="Percentage Change")
    plt.title(f"{model_title}: % change in output with mean ablation in Layer {layer_idx}")
    plt.xlabel("Token Position")
    plt.ylabel("Attention Head")

    # Save before show() so the figure is still populated when written to disk
    plt.savefig(f"{model_title}_layer_{layer_idx}.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
# cfg.useful_nodes.print_node_tags()

In [None]:
# Draw the quanta map of failure percentages for the useful attention nodes
title = f"{model_title} useful attention heads"
ax1, quanta_results, num_results = qmi.calc_quanta_map(
    cfg, True, 6,
    cfg.useful_nodes, qmi.QType.FAIL.value, "", qmi.get_quanta_fail_perc,
    combine_identical_cells=False)

if num_results > 0:
    # With a graph file suffix configured the map is saved to file; otherwise it only gets a title
    save_to_file = cfg.graph_file_suffix > ""
    if not save_to_file:
        ax1.set_title(f"{title} ({len(quanta_results)} nodes)")
    else:
        print("Saving quanta map:", title)
        qmi.save_plt_to_file(cfg=cfg, full_title=title)

    plt.show()

In [None]:
# Write the useful node list (with its behavior tags) as JSON to a file in the Colab working
# directory. The file is temporary there, so download it manually.
useful_node_json_filename = f"tinysql_bm{model_num}_cs{cs_num}_useful_nodes.json"
print( "Saving useful node list with behavior tags:", useful_node_json_filename)
cfg.useful_nodes.save_nodes(useful_node_json_filename)

#TODO: Auto save to Martian wandb