# M1 useful nodes
This notebook identifies M1 nodes (attention heads and MLPs) that, when ablated, cause a decrease in model prediction accuracy. These nodes are needed (aka useful) for accurate predictions.


This notebook was:
- Developed on Google Colab using a **T4** GPU.
- Runs with M1 (TinyStories) with base/CS1/CS2/CS3.
- Requires a GITHUB_TOKEN secret to access the Martian quanta_text_to_sql code repository.
- Requires a HF_TOKEN secret to access the Martian HuggingFace repository.

This notebook relies on the nnsight library. Useful background:
- https://nnsight.net/notebooks/tutorials/walkthrough/#Batching
- https://nnsight.net/notebooks/tutorials/walkthrough/#Looping

This notebook relies on the https://github.com/PhilipQuirke/quanta_mech_interp library for graphing useful nodes.

# Import libraries
Installs and imports the required libraries. Safe to skip when reading.

In [None]:
# https://nnsight.net/
# Access 0.4 prerelease version (as at Dec 2024)
#!pip install nnsight==0.4.0.dev0
!pip install -U nnsight

In [None]:
from IPython.display import clear_output
import einops
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
import matplotlib.pyplot as plt
import tqdm.auto as tqdm

import nnsight
from nnsight import LanguageModel

In [None]:
from getpass import getpass
from google.colab import userdata
import gc
import weakref

In [None]:
!pip install --upgrade git+https://github.com/PhilipQuirke/quanta_mech_interp.git
import QuantaMechInterp as qmi

In [None]:
# Read the GitHub token from the Colab secrets store; quanta_text_to_sql is a private repo.
github_token = userdata.get("GITHUB_TOKEN")

# Install the private repository using the token
# NOTE(review): the token is interpolated into the pip URL below. Check that it is not
# echoed into the saved cell output (e.g. in pip error messages) before sharing the notebook.
!pip install --upgrade git+https://{github_token}@github.com/withmartian/quanta_text_to_sql.git

import QuantaTextToSql as qts

# Select model, command set and feature to investigate


In [None]:
# Which model to investigate: 1=TinyStories, 2=Qwen, 3=Llama, 4=Granite, 5=SmolLM
# NOTE(review): the ablation cells below index model.transformer.h, so as written they
# presumably only run for model_num == 1.
model_num = 1

# Which command set to investigate: 0=BaseModel, 1=CS1, 2=CS2, 3=CS3
cs_num = 1

# Load model

In [None]:
model_hf_name = qts.sql_interp_model_location(model_num, cs_num)

if model_num != 1:
    # Let nnsight load the model itself from model_hf_name (presumably a Hugging Face repo id)
    model = LanguageModel(model_hf_name, device_map="auto")
else:
    # M1 (TinyStories): load the tokenizer and model through the project helper, then wrap
    # the already-loaded model in nnsight.
    the_tokenizer, the_model = qts.load_sql_interp_model(
        model_num,
        cs_num,
        auth_token=userdata.get("HF_TOKEN"),
        use_flash_attention=False,
    )
    model = LanguageModel(the_model, the_tokenizer)
    # Assign explicitly so model.tokenizer is guaranteed to be the helper's tokenizer
    # NOTE(review): nnsight may expect the tokenizer as the `tokenizer=` keyword; confirm.
    model.tokenizer = the_tokenizer

# Clear the cell output produced so far, leaving just the model summary
clear_output()
print(model)

In [None]:
# Number of transformer blocks: M1 exposes them as model.transformer.h, others as model.model.layers
N_LAYERS = len(model.transformer.h) if model_num == 1 else len(model.model.layers)

# Heads per layer are hard-coded: 7 for model 2, 16 for every other model.
# NOTE(review): these values are not read from the model config; confirm them for models 2-5.
N_HEADS = {2: 7}.get(model_num, 16)

# Residual width: M1 reads it from the token-embedding layer, others from config.hidden_size
D_MODEL = model.transformer.wte.embedding_dim if model_num == 1 else model.config.hidden_size
D_HEAD = D_MODEL // N_HEADS

print(f"N_LAYERS={N_LAYERS}", f"N_HEADS={N_HEADS}", f"D_MODEL={D_MODEL}", f"D_HEAD={D_HEAD}")

In [None]:
# Singleton QuantaTool "main" configuration object for this notebook. It is a qmi.AlgoConfig
# (qmi class chain: AlgoConfig > UsefulConfig > ModelConfig).
cfg = qmi.AlgoConfig()
cfg.model_name = model_hf_name
cfg.main_model = model
# Model geometry, copied from the values computed in the previous cell
cfg.n_layers = N_LAYERS
cfg.n_heads = N_HEADS
cfg.d_model = D_MODEL
cfg.d_head = D_HEAD
# Empty prefix for config-derived file names (NOTE(review): confirm how qmi uses this)
cfg.file_config_prefix = ""

# Generate test data and experiment runs

In [None]:
# Generate a batch of prompts with 3 field names
def generate_batch(batch_size):
    """Generate `batch_size` text-to-SQL examples, each with exactly 3 columns.

    The generator is chosen by the global command-set number `cs_num`. The base model
    (cs_num 0) shares the CS1 generator. The batch size is also recorded on `cfg`.

    Raises:
        ValueError: if `cs_num` is not 0, 1, 2 or 3.
    """
    # Look generators up by name so only the one actually needed is accessed on qts
    generator_names = {
        0: "generate_cs1",
        1: "generate_cs1",
        2: "generate_cs2",
        3: "generate_cs3",
    }
    if cs_num not in generator_names:
        raise ValueError(f"Unsupported cs_num={cs_num}; expected 0, 1, 2 or 3")

    # Honour the requested size rather than reading the global N_BATCH
    cfg.batch_size = batch_size
    generator = getattr(qts, generator_names[cs_num])
    return generator(batch_size=batch_size, min_cols=3, max_cols=3)

In [None]:
# Expand `examples` into the ablation runs to perform, and report the largest prompt,
# (ground-truth) answer and prompt+answer token counts seen.
def get_experiment_list(examples, by_attention_head):
    """Build ablation runs for each example and return token-count maxima.

    The prompt and the SQL answer are tokenized separately with model.tokenizer.
    One run is emitted per prompt token position. When `by_attention_head` is true,
    one run is emitted per (layer, head, prompt token position), with layers outermost
    and positions innermost.

    Each run is a list:
        by position: [prompt, answer, token_idx, prompt_tokens]
        by head:     [prompt, answer, layer_idx, head_idx, token_idx, prompt_tokens]

    Returns:
        (run_list, max_prompt_tokens, max_good_answer_tokens, max_tokens). Every
        maximum is 0 when `examples` is empty.
    """
    runs = []
    prompt_counts, answer_counts, total_counts = [], [], []

    for example in examples:
        prompt = example.get_alpaca_prompt()
        answer = example.sql_statement

        prompt_tokens = model.tokenizer(prompt)["input_ids"]
        n_prompt = len(prompt_tokens)
        n_answer = len(model.tokenizer(answer)["input_ids"])

        prompt_counts.append(n_prompt)
        answer_counts.append(n_answer)
        total_counts.append(n_prompt + n_answer)

        if by_attention_head:
            runs.extend(
                [prompt, answer, layer, head, position, prompt_tokens]
                for layer in range(N_LAYERS)
                for head in range(N_HEADS)
                for position in range(n_prompt)
            )
        else:
            runs.extend([prompt, answer, position, prompt_tokens] for position in range(n_prompt))

    return (
        runs,
        max(prompt_counts, default=0),
        max(answer_counts, default=0),
        max(total_counts, default=0),
    )

# Which token positions are useful?
The useful positions found here are used to shrink the search space in the following sections.

In [None]:
N_BATCH = 50

def run_token_experiments():
    """Find which prompt token positions the model needs in order to answer correctly.

    For every example and every prompt token position, this zeroes that position in
    output[0] of every transformer block h[0..N_LAYERS-1] and regenerates. A failure
    is counted whenever the decoded text differs from prompt + ground-truth answer.
    It also initializes cfg's token positions from the batch's token-count maxima.

    Returns:
        max_tokens: largest prompt+answer token count in the batch (the array length).
        fail_results: int array of shape (max_tokens,), failed runs per position.
        failure_rate: float array of shape (max_tokens,), failure percentage per
            position rounded to 2 d.p. (0 for positions never ablated).
    """
    examples = generate_batch(N_BATCH)
    run_list, max_prompt_tokens, max_good_answer_tokens, max_tokens = get_experiment_list(examples, False)
    cfg.initialize_token_positions( max_prompt_tokens, max_good_answer_tokens, True )
    num_exps = len(run_list)

    print("N_BATCH="+str(N_BATCH), "max_prompt_tokens="+str(max_prompt_tokens), "max_good_answer_tokens="+str(max_good_answer_tokens), "max_tokens="+str(max_tokens), "num_exps="+str(num_exps))

    try_results = np.zeros(max_tokens, dtype=int)
    fail_results = np.zeros(max_tokens, dtype=int)

    for run_prompt, run_answer, run_token_idx, prompt_tokens in tqdm.tqdm(run_list):

        with model.generate(prompt_tokens, max_new_tokens=max_good_answer_tokens,
                            pad_token_id=model.tokenizer.eos_token_id):

            # Zero this token position (all hidden dims) in output[0] of every block
            for run_layer_idx in range(N_LAYERS):
                model.transformer.h[run_layer_idx].output[0][:, run_token_idx, :] = 0

            final_output = model.generator.output.save()

        final_output = final_output.detach().cpu().numpy()
        decoded_output = model.tokenizer.decode(final_output[0], skip_special_tokens=True)

        # A run fails if the ablated generation no longer reproduces prompt + answer exactly
        try_results[run_token_idx] += 1
        if run_prompt + run_answer != decoded_output:
            fail_results[run_token_idx] += 1

    # Compute the failure rate once, after the loop, so it is defined even when
    # run_list is empty. The epsilon keeps never-tried positions at 0 instead of 0/0.
    failure_rate = (1.0 * fail_results / (try_results + 1e-10)) * 100
    failure_rate = np.round(failure_rate, 2)

    return max_tokens, fail_results, failure_rate

g_max_tokens, g_token_fail_results, g_token_failure_rate = run_token_experiments()

In [None]:
# Any position whose ablation caused at least one failure is registered as useful on cfg
print("Useful token positions:")
for token_idx, fail_rate in enumerate(g_token_failure_rate):
    if fail_rate > 0:
        cfg.add_useful_position(token_idx)
        print("Position:", token_idx, "% Fails:", fail_rate, "# Fails:", g_token_fail_results[token_idx])

In [None]:
# Toggle for the position-failures quanta map (disabled by default; set True to draw it)
SHOW_POSITION_FAILURES_MAP = False

if SHOW_POSITION_FAILURES_MAP:
    # Hand the raw per-position failure counts to qmi, then save and show the resulting figure
    cfg.calc_position_failures_map(g_token_fail_results.tolist())
    qmi.save_plt_to_file(cfg=cfg, full_title="Failures When Position Ablated")
    plt.show()

# Which token+layer+attention head nodes are useful?

In [None]:
N_BATCH = 10

def run_attention_experiments():
    """Ablate one (layer, head, prompt token position) slice at a time and count answer failures.

    For every example and every (layer, head, prompt token) triple, this zeroes columns
    [head_idx*D_HEAD, (head_idx+1)*D_HEAD) at that token in output[0] of block h[layer]
    and regenerates. A failure is counted whenever the decoded text differs from
    prompt + ground-truth answer.

    Returns:
        max_tokens: largest prompt+answer token count in the batch.
        fail_results: int array of shape (N_LAYERS, N_HEADS, max_tokens), failed runs.
        failure_rate: float array of the same shape, failure percentage rounded to
            2 d.p. (0 where never tried).
        try_results: int array of the same shape, number of runs per cell.
    """
    examples = generate_batch(N_BATCH)
    run_list, max_prompt_tokens, max_good_answer_tokens, max_tokens = get_experiment_list(examples, True)
    num_exps = len(run_list)

    print("N_BATCH="+str(N_BATCH), "max_prompt_tokens="+str(max_prompt_tokens), "max_good_answer_tokens="+str(max_good_answer_tokens), "max_tokens="+str(max_tokens), "num_exps="+str(num_exps))

    try_results = np.zeros((N_LAYERS, N_HEADS, max_tokens), dtype=int)
    fail_results = np.zeros((N_LAYERS, N_HEADS, max_tokens), dtype=int)

    for run_prompt, run_answer, run_layer_idx, run_head_idx, run_token_idx, prompt_tokens in tqdm.tqdm(run_list):

        start = run_head_idx * D_HEAD
        end = (run_head_idx + 1) * D_HEAD

        with model.generate(prompt_tokens, max_new_tokens=max_good_answer_tokens,
                            pad_token_id=model.tokenizer.eos_token_id):

            # NOTE(review): this zeroes columns [start:end) of output[0] of the whole block
            # h[run_layer_idx] (presumably the residual-stream hidden states), not the output
            # of attention head run_head_idx; residual dimensions are not partitioned by head.
            # A true per-head ablation would slice the attention's per-head activations before
            # its output projection. TODO: confirm the module path for this model (see print(model)).
            model.transformer.h[run_layer_idx].output[0][:, run_token_idx, start:end] = 0

            final_output = model.generator.output.save()

        final_output = final_output.detach().cpu().numpy()
        decoded_output = model.tokenizer.decode(final_output[0], skip_special_tokens=True)

        # A run fails if the ablated generation no longer reproduces prompt + answer exactly
        try_results[run_layer_idx, run_head_idx, run_token_idx] += 1
        if run_prompt + run_answer != decoded_output:
            fail_results[run_layer_idx, run_head_idx, run_token_idx] += 1

    # Compute the failure rate once, after the loop, so it is defined even when
    # run_list is empty. The epsilon keeps never-tried cells at 0 instead of 0/0.
    failure_rate = (1.0 * fail_results / (try_results + 1e-10)) * 100
    failure_rate = np.round(failure_rate, 2)

    return max_tokens, fail_results, failure_rate, try_results

g_max_tokens, g_attn_failure_results, g_attn_failure_rate, g_attn_try_results = run_attention_experiments()
# Print the shape only after g_attn_failure_rate has been assigned; it is (N_LAYERS, N_HEADS, max_tokens)
print(g_attn_failure_rate.shape)

In [None]:
# Start from a fresh UsefulNodeList, then tag every attention node whose ablation caused
# failures with its (truncated) integer failure percentage.
cfg.useful_nodes = qmi.UsefulNodeList()
for layer_idx in range(N_LAYERS):
    for head_idx in range(N_HEADS):
        for token_idx in range(g_max_tokens):
            fail_perc = int(g_attn_failure_rate[layer_idx, head_idx, token_idx])
            if fail_perc <= 0:
                continue
            node_location = qmi.NodeLocation(token_idx, layer_idx, True, head_idx)
            cfg.add_useful_node_tag(node_location, qmi.QType.FAIL.value, str(fail_perc))

cfg.useful_nodes.sort_nodes()

In [None]:
for layer_idx in range(N_LAYERS):
    # One heatmap per layer: rows are attention heads, columns are token positions
    fig, ax = plt.subplots()
    heatmap = ax.imshow(g_attn_failure_rate[layer_idx], cmap="viridis", aspect="auto")
    fig.colorbar(heatmap, ax=ax, label="Percentage Change")
    ax.set_xlabel("Token Position")
    ax.set_ylabel("Attention Head")
    ax.set_title(f"Percentage of Output Changes by Zeroing Activations in Layer {layer_idx}")
    plt.show()

In [None]:
# List the useful nodes in cfg.useful_nodes along with their tags (via qmi's print_node_tags)
cfg.useful_nodes.print_node_tags()

In [None]:
title = "Useful attention heads"
# Build the quanta map over cfg.useful_nodes for nodes tagged QType.FAIL, with
# qmi.get_quanta_fail_perc as the per-node function (presumably extracts the failure %).
# NOTE(review): the meaning of the positional args `True, 6` is defined by qmi.calc_quanta_map.
ax1, quanta_results, num_results = qmi.calc_quanta_map(
    cfg, True, 6,
    cfg.useful_nodes, qmi.QType.FAIL.value, "", qmi.get_quanta_fail_perc)

if num_results > 0:
    # Truthiness check: a non-empty suffix means "save to file". Unlike `> ""`,
    # this also works if the suffix is None.
    if cfg.graph_file_suffix:
        print("Saving quanta map:", title)
        qmi.save_plt_to_file(cfg=cfg, full_title=title)
    else:
        ax1.set_title(f"{title} ({len(quanta_results)} nodes)")

    plt.show()

In [None]:
# Write the useful-node list, including its behavior tags, to a temporary Colab file in
# JSON format. The file name is fixed, so a later run writes to the same path.
main_fname_behavior_json = 'behavior.json'
print("Saving useful node list with behavior tags:", main_fname_behavior_json)
cfg.useful_nodes.save_nodes(main_fname_behavior_json)