In [79]:
!uv pip install transformers matplotlib pandas seaborn torch

import zipfile
from difflib import SequenceMatcher
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM

[2mAudited [1m5 packages[0m [2min 29ms[0m[0m


#### utils


In [None]:
def get_random_needles(df_needles, n=5, random_state=None):
    """
    Sample up to ``n`` random needles whose ``arg1`` values are pairwise distinct.

    Deduplication happens on a shuffled copy *before* truncating to ``n`` rows,
    so exactly ``n`` needles are returned whenever at least ``n`` distinct
    ``arg1`` values exist (sampling first and deduplicating afterwards could
    silently return fewer).

    Args:
        df_needles (pd.DataFrame): needle table; must contain an ``arg1`` column.
        n (int): number of needles wanted (non-negative).
        random_state: forwarded to ``DataFrame.sample`` for reproducibility;
            ``None`` uses pandas' default (NumPy's global RNG).

    Returns:
        pd.DataFrame: at most ``n`` rows with a fresh ``0..k-1`` index, so a
        row's index label equals its position in the returned frame.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    # Shuffle the whole frame, then keep the first occurrence of each arg1.
    shuffled = df_needles.sample(frac=1, random_state=random_state)
    distinct = shuffled.drop_duplicates(subset=["arg1"], keep="first")
    # If fewer than n distinct arg1 values exist, all of them are returned.
    return distinct.head(n).reset_index(drop=True)


def generate_context(df_haystack, df_needles) -> str:
    """
    Build the prompt context by interleaving needles with haystack passages.

    The result starts with a single space, then needle 0, then haystack
    passage 0 padded by one space on each side, then needle 1, and so on.
    No passage follows the final needle. Passages are taken positionally from
    the first ``len(df_needles) - 1`` rows of ``df_haystack``; empty passages
    are skipped.
    """
    total = len(df_needles)
    pieces = [" "]
    for pos in range(total):
        pieces.append(df_needles.iloc[pos]["needle"])
        is_last_needle = pos == total - 1
        if is_last_needle:
            continue
        filler = df_haystack.iloc[pos]["text"]
        if filler:
            pieces.append(" " + filler + " ")
    return "".join(pieces)


def generate_messages(df_needles, df_haystack, n=5):
    """
    Build a chat prompt that hides random needles in haystack text and asks
    about one of them.

    Args:
        df_needles (pd.DataFrame): needle table (``needle``, ``arg1``,
            ``retrieval_question`` columns are read here or by helpers).
        df_haystack (pd.DataFrame): haystack table with a ``text`` column.
        n (int): number of random needles to draw.

    Returns:
        messages (list[dict]): system prompt, the context as one user turn, and
            the question as a second user turn.
        prompt_needle (pd.Series): the needle the question is about. Its
            ``.name`` is its zero-based position among the needles placed in
            the context (get_random_needles resets the index, and
            generate_context emits needles in that order).
        question (str): the question text sent to the model.
    """
    df_rand_needles = get_random_needles(df_needles, n=n)
    context = generate_context(df_haystack, df_rand_needles)

    prompt_needle = df_rand_needles.sample(n=1).iloc[0]

    # NOTE(review): "base on" is a typo, kept verbatim so prompts stay identical
    # to earlier runs; fix deliberately if comparability doesn't matter.
    sys_prompt = "You are an intelligent AI assistant skilled in answering user questions base on documents provided by the user. Please keep your answers concise and clear. Do not talk about irrelevant topics or repeat your answers. The document given to you by the user is:"
    # Keep only the text before the first "?"; whatever follows it is dropped.
    # partition() never raises, whereas unpacking split("?") into two names
    # fails unless the question contains exactly one "?".
    question = prompt_needle["retrieval_question"].partition("?")[0]
    question += "? Answer concisely, correctly, and in a complete sentence."
    messages = [
        {
            "role": "system",
            "content": sys_prompt,
        },
        {
            "role": "user",
            "content": context,
        },
        {
            "role": "user",
            "content": question,
        },
    ]
    return messages, prompt_needle, question


def grade(response, answer) -> float:
    """Score a response against the gold answer with difflib's similarity ratio (0.0-1.0)."""
    matcher = SequenceMatcher(a=response, b=answer)
    similarity = matcher.ratio()
    return float(similarity)


def find_subtensor_indices(
    haystack: torch.LongTensor, needle: torch.LongTensor
) -> torch.LongTensor:
    """
    Return a 1-D tensor of every start position where ``needle`` occurs as a
    contiguous slice of ``haystack``.

    The result is always ``torch.long`` on ``haystack.device``, including the
    empty-needle and needle-longer-than-haystack shortcuts. Callers can
    therefore use it uniformly regardless of which path produced it.

    Raises:
        ValueError: if either input is not 1-D.
    """
    if haystack.dim() != 1 or needle.dim() != 1:
        raise ValueError("haystack and needle must both be 1-D tensors")

    n, m = haystack.size(0), needle.size(0)
    device = haystack.device
    if m == 0:
        # every position (including "after" the last) is a match
        return torch.arange(n + 1, dtype=torch.long, device=device)
    if m > n:
        return torch.empty(0, dtype=torch.long, device=device)

    # all length-m windows with stride 1: shape (n - m + 1, m)
    windows = haystack.unfold(0, m, 1)
    # a window matches when every element equals the needle
    matches = (windows == needle.to(device)).all(dim=1)  # -> (n - m + 1,)
    return matches.nonzero(as_tuple=True)[0]


def block_reduce(
    matrix: torch.Tensor,
    block_size: int = 64,
    mode: str = "max",  # one of "max" or "mean"
) -> torch.Tensor:
    """
    Collapse each contiguous block of ``block_size`` columns in every row of
    ``matrix`` to its maximum or mean.

    The last block may be shorter than ``block_size`` when the column count is
    not a multiple of it; it is reduced over whatever columns it has.

    Parameters
    ----------
    matrix : torch.Tensor
        2D tensor of shape (a, b). ``mode="mean"`` needs a floating dtype.
    block_size : int
        Number of columns per block (default 64); must be positive.
    mode : str
        Reduction to apply: "max" or "mean".

    Returns
    -------
    torch.Tensor
        2D tensor of shape (a, ceil(b / block_size)).

    Raises
    ------
    ValueError
        If ``matrix`` is not 2-D, ``block_size`` is not positive, or ``mode``
        is unknown.
    """
    if matrix.dim() != 2:
        raise ValueError("`matrix` must be 2-dimensional")
    if block_size <= 0:
        # 0 would divide by zero below; a negative size silently produced an
        # empty (a, 0) result.
        raise ValueError("block_size must be a positive integer")
    if mode not in {"max", "mean"}:
        raise ValueError("mode must be 'max' or 'mean'")

    a, b = matrix.shape
    full_blocks = b // block_size
    full_width = full_blocks * block_size

    # reduce all complete blocks at once
    if full_blocks > 0:
        blocks = matrix[:, :full_width].unfold(
            1, block_size, block_size
        )  # -> (a, full_blocks, block_size)
        if mode == "max":
            full_reduced = blocks.max(dim=2).values  # -> (a, full_blocks)
        else:
            full_reduced = blocks.mean(dim=2)  # -> (a, full_blocks)
    else:
        full_reduced = matrix.new_empty((a, 0))

    # reduce the trailing partial block, if any
    if b > full_width:
        tail = matrix[:, full_width:]  # -> (a, b - full_width)
        if mode == "max":
            rem_reduced = tail.max(dim=1, keepdim=True).values  # -> (a, 1)
        else:
            rem_reduced = tail.mean(dim=1, keepdim=True)  # -> (a, 1)
        return torch.cat([full_reduced, rem_reduced], dim=1)

    return full_reduced


def aggregate_tensors(
    tensors: List[torch.Tensor], mode: str = "mean", percentile: Optional[float] = None
) -> torch.Tensor:
    """
    Reduce a list of equally-shaped tensors to one tensor, entry by entry.

    The inputs are stacked along a new leading axis and cast to float32. That
    axis is then reduced with the requested statistic.

    Args:
      tensors: non-empty list of tensors that all share one shape.
      mode: "mean", "median" or "percentile". "median" uses torch.median,
        which returns the lower of the two middle values for an even count.
      percentile: required for mode="percentile"; a value in [0, 100].

    Returns:
      A float tensor with the common input shape.

    Raises:
      ValueError: empty input, mismatched shapes, missing or out-of-range
        percentile, or an unknown mode.
    """
    if not tensors:
        raise ValueError("Need at least one tensor")
    reference_shape = tensors[0].shape
    if any(t.shape != reference_shape for t in tensors):
        raise ValueError("All tensors must have the same shape")

    stacked = torch.stack(tensors, dim=0).float()  # (N, *shape)

    if mode == "mean":
        return stacked.mean(dim=0)
    if mode == "median":
        return stacked.median(dim=0).values
    if mode == "percentile":
        if percentile is None:
            raise ValueError("Must specify percentile when mode='percentile'")
        if not 0 <= percentile <= 100:
            raise ValueError("percentile must be between 0 and 100")
        # torch.quantile expects a fraction in [0, 1]
        return torch.quantile(stacked, percentile / 100.0, dim=0)
    raise ValueError(f"Unknown mode '{mode}'. Choose mean, median, or percentile.")

#### dataset


In [None]:
def _read_needlebench_parquet(filename):
    """Download one file from NeedleBench's auto-converted parquet revision and load it."""
    local_path = hf_hub_download(
        repo_id="opencompass/NeedleBench",
        filename=filename,
        repo_type="dataset",
        revision="refs/convert/parquet",
    )
    return pd.read_parquet(local_path)


# Retrieval needles, restricted to the English subset.
df_needles = (
    _read_needlebench_parquet("retrieval_needles/test/0000.parquet")
    .loc[lambda frame: frame["language"] == "English"]
    .reset_index(drop=True)
)

# English haystack passages, keeping those with 5000-7500 characters (inclusive).
df_haystack = (
    _read_needlebench_parquet("en_haystack_texts/test/0000.parquet")
    .loc[lambda frame: frame["text"].str.len().between(5000, 7500)]
    .reset_index(drop=True)
)

0


#### model


In [None]:
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

# Passed to generate() as max_new_tokens: caps the *generated* tokens, not the prompt.
MAX_LENGTH = 128

# Eager attention is requested because attention weights are read back from
# generate() later; fused implementations may not return them (TODO confirm
# for the installed transformers version). return_dict_in_generate is a
# generate() argument and is passed there explicitly, so it is not set here.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    output_attentions=True,
    attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Experiment parameters.
n_needles = 6
SEED = 0  # set to None for a fresh random draw (needles and sampling) on every run
# Prompts longer than this many tokens per needle are rejected before generation
# (presumably to bound memory for the eager attention outputs -- TODO confirm).
MAX_PROMPT_TOKENS_PER_NEEDLE = 1250

if SEED is not None:
    # DataFrame.sample() without random_state draws from NumPy's global RNG and
    # generate(do_sample=True) draws from torch's RNG, so seed both.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

messages, prompt_needle, question = generate_messages(
    df_needles, df_haystack, n=n_needles
)

if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached blocks left over from a previous run

text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
n_prompt_tokens = inputs.input_ids.shape[1]

if n_prompt_tokens > MAX_PROMPT_TOKENS_PER_NEEDLE * n_needles:
    raise RuntimeError(
        f"Text length: {len(text)}\nContext length ({n_prompt_tokens} tokens) too long"
    )

with torch.inference_mode():
    output = model.generate(
        **inputs,
        max_new_tokens=MAX_LENGTH,
        output_attentions=True,
        return_dict_in_generate=True,
        use_cache=True,
        do_sample=True,
    )

# Batch of one with no padding, so everything after the prompt was generated.
generated_ids = output.sequences[:, n_prompt_tokens:]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
# Build the run report once, show it, and persist it to log.txt (overwritten each run).
gold_answer = prompt_needle["gold_standard_answer"]
score = grade(response, gold_answer)
report_lines = [
    f"Model: {MODEL_NAME}",
    f"Text Length: {len(text)}",
    f"Context Length: {len(inputs.input_ids[0])}",
    f"Question: {question}",
    f"Correct Answer: {gold_answer}",
    f"Response: {response}",
    f"Grade: {score}",
    f"Needle Position (zero indexed): {prompt_needle.name}",
    f"Total Number of Needles: {n_needles}",
]
report = "\n".join(report_lines)
print(report)
# Explicit UTF-8: model output may contain characters the platform default can't encode.
Path("log.txt").write_text(report + "\n", encoding="utf-8")


In [None]:
# Locate the prompt needle's tokens inside the tokenized prompt. The needle is
# tokenized with a leading space because generate_context puts a space before
# every needle.
needle_token_ids = (
    tokenizer([" " + prompt_needle["needle"]], return_tensors="pt")
    .to(model.device)
    .input_ids[0]
)

indices = find_subtensor_indices(inputs.input_ids[0], needle_token_ids)
if indices.numel() == 0:
    raise RuntimeError(
        "Needle token sequence not found in the prompt; the standalone "
        "tokenization may differ from the in-context one at the needle boundary."
    )

index_start = indices[0].item()
index_end = index_start + len(needle_token_ids)

n_layers = model.config.num_hidden_layers
n_heads = model.config.num_attention_heads
div_factor = 2  # number of subplot rows in the per-head heatmap grid

Path("imgs").mkdir(parents=True, exist_ok=True)

# Attentions returned by generate(output_attentions=True, return_dict_in_generate=True).
attention = output.attentions

In [None]:
# Positions of every <|im_start|> marker in the prompt (presumably one per chat
# turn in the template -- verify against the printed indices).
im_start_ids = (
    tokenizer(["<|im_start|>"], return_tensors="pt").to(model.device).input_ids[0]
)
msg_indices = find_subtensor_indices(inputs.input_ids[0], im_start_ids)
print(msg_indices)

# Sanity check: show the span between the last two markers that the heatmap cells
# use as query rows. The +2 / -2 offsets presumably trim the role header and the
# closing tokens -- TODO confirm against the chat template.
# decode() on the whole span (rather than joining per-token decodes) keeps
# multi-byte characters that are split across byte-level tokens intact.
print(tokenizer.decode(inputs.input_ids[0][msg_indices[-2] + 2 : msg_indices[-1] - 2]))

#### heatmaps


In [None]:
def plot_heatmaps(
    head_weights,
    agg_mode="mean",
    save_title="imgs/avg_per_head.png",
    percentile=99,
    n_rows=None,
):
    """
    Save a grid of heatmaps, one panel per entry of ``head_weights``.

    Args:
        head_weights: sequence indexed by head. ``head_weights[i]`` is a list of
            equally-shaped 2-D tensors that ``aggregate_tensors`` collapses
            entrywise into the map drawn in panel ``i``.
        agg_mode: passed to ``aggregate_tensors`` as ``mode``.
        save_title: output image path; missing parent directories are created.
        percentile: used by ``aggregate_tensors`` only when
            ``agg_mode == "percentile"``.
        n_rows: number of subplot rows; defaults to the notebook-level
            ``div_factor``.
    """
    if n_rows is None:
        n_rows = div_factor
    n_panels = len(head_weights)
    # Ceil-divide so every head gets a panel even if n_panels % n_rows != 0.
    n_cols = max(1, -(-n_panels // n_rows))

    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(15 * n_rows, 5 * n_rows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i >= n_panels:
            ax.axis("off")  # spare grid cell
            continue
        weights = (
            aggregate_tensors(head_weights[i], mode=agg_mode, percentile=percentile)
            .cpu()
            .float()
            .numpy()
        )
        sns.heatmap(
            weights,
            ax=ax,
            cmap="rocket",
            yticklabels=False,
        )
        ax.set_title(f"Head {i + 1}")

    fig.suptitle(f"Per-head heatmaps (aggregation: {agg_mode})")
    fig.tight_layout()
    Path(save_title).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_title)
    plt.close(fig)


In [None]:
# Query rows: the span between the last two <|im_start|> markers (same offsets as
# the sanity-check cell). Columns: all key positions, block-reduced.
idx = [int(msg_indices[-2]) + 2, int(msg_indices[-1]) - 2]
block_size = 32

# Split the layers into two halves; ranges are 1-indexed and inclusive to match
# the output filenames (e.g. layers112 / layers1324 for a 24-layer model).
half = n_layers // 2
layer_ranges = [(1, half), (half + 1, n_layers)]
# Per-block reduction -> filename tag ("mer" = mean-reduced, "mxr" = max-reduced).
reduce_tags = {"mean": "mer", "max": "mxr"}


def build_head_weights(first_layer, last_layer, reduce_mode):
    """Per head, a list (one per layer in [first_layer, last_layer]) of block-reduced query-row attention maps."""
    return [
        [
            block_reduce(
                attention[0][layer][0][head][idx[0] : idx[1]],
                mode=reduce_mode,
                block_size=block_size,
            )
            for layer in range(first_layer - 1, last_layer)
        ]
        for head in range(n_heads)
    ]


for first_layer, last_layer in layer_ranges:
    layer_tag = f"layers{first_layer}{last_layer}"
    for reduce_mode, tag in reduce_tags.items():
        head_weights = build_head_weights(first_layer, last_layer, reduce_mode)
        plot_heatmaps(
            head_weights,
            agg_mode="mean",
            save_title=f"imgs/avg_per_head_{tag}_{layer_tag}.png",
        )
        plot_heatmaps(
            head_weights,
            agg_mode="percentile",
            save_title=f"imgs/99_percentile_per_head_{tag}_{layer_tag}.png",
        )

In [None]:
from google.colab import files  # Colab-only: this cell fails outside Colab

# Bundle what the live cells actually write: every file under imgs/ (the
# heatmaps) plus log.txt. imgs_max/ and imgs_mean/ are only written by
# commented-out code, so they are not archived.
ARCHIVE_PATH = Path("imgs.zip")
with zipfile.ZipFile(ARCHIVE_PATH, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for artifact in sorted(Path("imgs").rglob("*")):
        if artifact.is_file():
            archive.write(artifact)
    archive.write("log.txt")

files.download(str(ARCHIVE_PATH))

In [None]:
# NOTE(review): dead code, kept for reference only. These loops draw one figure
# per layer from the needle rows [index_start:index_end] and save into
# imgs_mean/ and imgs_max/. No live cell creates those directories (only imgs/
# is created), so re-enabling this requires creating them first. block_reduce
# is called without block_size here, i.e. with its default of 64. Consider
# deleting this cell if it is no longer needed.
# Per-layer heatmaps, mean-reduced blocks:
# for layer in range(n_layers):
#     fig, axes = plt.subplots(
#         div_factor, n_heads // div_factor, figsize=(25 * div_factor, 5 * div_factor)
#     )
#     for i, ax in enumerate(axes.flat):
#         weights = attention[0][layer][0][i][index_start:index_end]
#         weights = block_reduce(weights, mode="mean").cpu().float().numpy()
#         sns.heatmap(
#             weights,
#             ax=ax,
#             cmap="bone",
#             yticklabels=False,
#         )
#         ax.set_title(f"Head {i + 1}")
#     plt.tight_layout()
#     plt.savefig(f"imgs_mean/layer_{layer + 1:02}_mean.png")
#     plt.close()

# Per-layer heatmaps, max-reduced blocks (block_reduce's default mode):
# for layer in range(n_layers):
#     fig, axes = plt.subplots(
#         div_factor, n_heads // div_factor, figsize=(25 * div_factor, 5 * div_factor)
#     )
#     for i, ax in enumerate(axes.flat):
#         weights = attention[0][layer][0][i][index_start:index_end]
#         weights = block_reduce(weights).cpu().float().numpy()
#         sns.heatmap(
#             weights,
#             ax=ax,
#             cmap="bone",
#             yticklabels=False,
#         )
#         ax.set_title(f"Head {i + 1}")
#     plt.tight_layout()
#     plt.savefig(f"imgs_max/layer_{layer + 1:02}_max.png")
#     plt.close()