In [None]:
!uv pip install auto_gptq
!uv pip install optimum
!uv pip install transformers matplotlib pandas seaborn torch

import json
from difflib import SequenceMatcher
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM


[2mResolved [1m97 packages[0m [2min 0.37ms[0m[0m
[2mAudited [1m78 packages[0m [2min 0.02ms[0m[0m


#### config.py


In [None]:
# Model under study. A lighter alternative: "Qwen/Qwen2.5-0.5B-Instruct".
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Cap on newly generated tokens (passed to `generate` as `max_new_tokens`).
MAX_LENGTH = 512

# Eager attention is selected so attention weights can actually be returned;
# fused SDPA / flash-attention kernels do not materialize them.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="eager",
    output_attentions=True,
    return_dict_in_generate=True,
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

#### mrcr_utils.py


In [None]:
def load_mrcr_parquet(filename: str = "2needle.parquet") -> pd.DataFrame:
    """
    Download one parquet file of the ``openai/mrcr`` dataset from the Hugging
    Face Hub and load it as a DataFrame.

    Parameters
    ----------
    filename : str
        Path of the parquet file inside the dataset repo
        (default ``"2needle.parquet"``, the file this notebook was built on).

    Returns
    -------
    pd.DataFrame
        The parquet contents as returned by ``pd.read_parquet``.
    """
    local_path = hf_hub_download(
        repo_id="openai/mrcr", filename=filename, repo_type="dataset"
    )
    return pd.read_parquet(local_path)


def grade(
    response, answer, random_string_to_prepend, strict: bool = False
) -> float:
    """
    Score a model response against the reference answer with difflib similarity.

    Parameters
    ----------
    response : str
        Model output.
    answer : str
        Reference answer.
    random_string_to_prepend : str
        Prefix the response is expected to begin with. It is stripped from both
        strings before comparison.
    strict : bool
        If True, a response that does not start with the prefix scores 0.0.
        Defaults to False (lenient grading: the prefix is only stripped if present).

    Returns
    -------
    float
        ``SequenceMatcher`` ratio in [0, 1].
    """
    if strict and not response.startswith(random_string_to_prepend):
        return 0.0
    response = response.removeprefix(random_string_to_prepend)
    answer = answer.removeprefix(random_string_to_prepend)
    return float(SequenceMatcher(None, response, answer).ratio())


def n_tokens(messages: list[dict]) -> int:
    """
    Return the token count of `messages` after rendering them with the
    tokenizer's chat template, including the trailing generation prompt.
    """
    rendered = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    encoding = tokenizer(rendered)
    token_ids = encoding.input_ids
    return len(token_ids)


# Only samples below this `n_chars` value are kept (presumably the prompt length
# in characters); short prompts keep the eager full attention maps tractable.
MAX_PROMPT_CHARS = 20_000

df = load_mrcr_parquet()
dataset = df[df["n_chars"] < MAX_PROMPT_CHARS]

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Generation samples (do_sample=True); fix the RNG so reruns reproduce the same
# response and attention maps.
SEED = 0

# One MRCR sample; its "prompt" column is a JSON-encoded list of chat messages.
test = dataset.iloc[1]
messages = json.loads(test["prompt"])

# Release cached allocator blocks from earlier runs before the memory-heavy
# attention-returning generation below.
torch.cuda.empty_cache()

text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

torch.manual_seed(SEED)
with torch.inference_mode():
    output = model.generate(
        **inputs,
        max_new_tokens=MAX_LENGTH,
        output_attentions=True,
        return_dict_in_generate=True,
        use_cache=True,
        do_sample=True,
    )

# Drop the prompt prefix from each returned sequence, keeping only new tokens.
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(inputs.input_ids, output.sequences)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(grade(response, test["answer"], test["random_string_to_prepend"]))

# One entry per generation step, each a tuple over layers of attention tensors
# indexed [batch][head][query][key]; step 0 is the prompt (prefill) pass, which
# the heatmap cells read.
attention = output.attentions

In [None]:
def find_subtensor_indices(
    haystack: torch.LongTensor, needle: torch.LongTensor
) -> torch.LongTensor:
    """
    Returns a 1D tensor of all start-positions where `needle`
    appears as a contiguous slice of `haystack`.

    Both inputs are treated as 1D sequences. The result always lives on
    `haystack`'s device, including the empty-needle and needle-longer-than-
    haystack edge cases. `needle` is moved to that device before comparison.

    Parameters
    ----------
    haystack : torch.LongTensor
        1D sequence to search in.
    needle : torch.LongTensor
        1D pattern to search for.

    Returns
    -------
    torch.LongTensor
        Sorted start indices of every (possibly overlapping) match. For an empty
        needle, every position 0..len(haystack) inclusive is returned.
    """
    device = haystack.device
    n, m = haystack.size(0), needle.size(0)
    if m == 0:
        # every position (including "after" the last) is a match
        return torch.arange(n + 1, dtype=torch.long, device=device)
    if m > n:
        return torch.empty(0, dtype=torch.long, device=device)

    # all length-m sliding windows: shape (n-m+1, m)
    windows = haystack.unfold(0, m, 1)
    # a window matches when every element equals the needle
    matches = (windows == needle.to(device)).all(dim=1)  # -> (n-m+1,)
    return matches.nonzero(as_tuple=True)[0]


# Token id(s) of the literal "<|im_start|>" marker, presumably the chat-template
# token that opens each rendered turn. NOTE(review): despite the variable name,
# this is not MRCR's random prefix string.
random_string_tokens = tokenizer(["<|im_start|>"], return_tensors="pt").to(model.device)
random_string_tokens = random_string_tokens.input_ids[0]

# Start offset of every "<|im_start|>" occurrence in the tokenized prompt.
indices = find_subtensor_indices(inputs.input_ids[0], random_string_tokens)

desired_msg_index = test["desired_msg_index"]

# Token range [index_start, index_end) used as the query rows of the attention
# heatmaps (the "desired" message's span).
# NOTE(review): the "+2"/"+3" turn offsets presumably account for extra turns the
# chat template renders before `messages` (e.g. a default system message) plus
# MRCR's indexing convention. Confirm this selects the intended message.
# NOTE(review): the "+2"/"-2" token offsets appear meant to skip the turn header
# and the previous turn's closing tokens. Verify with
# tokenizer.convert_ids_to_tokens(inputs.input_ids[0][index_start:index_end]).
index_start = indices[desired_msg_index + 2].item() + 2
index_end = indices[desired_msg_index + 3].item() - 2

In [None]:
def block_reduce(
    matrix: torch.Tensor,
    block_size: int = 64,
    mode: str = "max",  # one of "max" or "mean"
) -> torch.Tensor:
    """
    Collapse each contiguous block of `block_size` columns in every row of `matrix`
    down to either its maximum or its average, returning a tensor of shape
    (a, ceil(b/block_size)).

    When b is not a multiple of `block_size`, the last block is narrower and is
    reduced over only the columns it contains.

    Parameters
    ----------
    matrix : torch.Tensor
        2D tensor of shape (a, b).
    block_size : int
        Number of columns per block (default 64). Must be >= 1.
    mode : str
        Reduction to apply: "max" or "mean".

    Returns
    -------
    torch.Tensor
        2D tensor of shape (a, num_blocks) where num_blocks = ceil(b / block_size),
        each entry is the max or mean over that block in the original row.

    Raises
    ------
    ValueError
        If `matrix` is not 2D, `mode` is not "max"/"mean", or `block_size` < 1.
    """
    if matrix.dim() != 2:
        raise ValueError("`matrix` must be 2-dimensional")
    if mode not in {"max", "mean"}:
        raise ValueError("mode must be 'max' or 'mean'")
    if block_size < 1:
        # 0 would otherwise surface as a ZeroDivisionError below, and negative
        # sizes would silently produce meaningless slicing.
        raise ValueError("block_size must be a positive integer")

    a, b = matrix.shape
    full_blocks = b // block_size
    split = full_blocks * block_size

    # all full-width blocks at once: (a, full_blocks, block_size)
    if full_blocks > 0:
        blocks = matrix[:, :split].unfold(1, block_size, block_size)
        if mode == "max":
            full_reduced, _ = blocks.max(dim=2)  # -> (a, full_blocks)
        else:  # mean
            full_reduced = blocks.mean(dim=2)  # -> (a, full_blocks)
    else:
        full_reduced = matrix.new_empty((a, 0))

    # trailing partial block, if any
    if b - split > 0:
        tail = matrix[:, split:]  # -> (a, rem)
        if mode == "max":
            rem_reduced, _ = tail.max(dim=1, keepdim=True)  # -> (a, 1)
        else:
            rem_reduced = tail.mean(dim=1, keepdim=True)  # -> (a, 1)
        return torch.cat([full_reduced, rem_reduced], dim=1)

    return full_reduced


In [None]:
# Layer/head counts come from the captured prefill attentions instead of being
# hard-coded, so this cell keeps working if MODEL_NAME is switched.
n_layers = len(attention[0])  # attention[0]: prefill step, one tensor per layer
n_heads = attention[0][0].shape[1]  # tensor indexed [batch][head][query][key]
div_factor = 4


def _plot_attention_grid(mode: str, nrows: int, figsize: tuple, block_size: int = 64) -> None:
    """
    Save one heatmap grid per layer to ``imgs_{mode}/layer_XX_{mode}.png``.

    Each panel is one head. Rows are the query positions index_start:index_end,
    and key columns are pooled into `block_size`-wide blocks with
    ``block_reduce(..., mode=mode)``.
    """
    out_dir = Path(f"imgs_{mode}")
    out_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
    ncols = -(-n_heads // nrows)  # ceil division: every head gets a panel
    for layer in range(n_layers):
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
        for i, ax in enumerate(axes.flat):
            if i >= n_heads:
                ax.axis("off")  # spare panel when n_heads is not a multiple of nrows
                continue
            weights = attention[0][layer][0][i][index_start:index_end]
            weights = block_reduce(weights, block_size=block_size, mode=mode)
            sns.heatmap(
                weights.cpu().float().numpy(),
                ax=ax,
                cmap="bone",
                yticklabels=False,
            )
            ax.set_title(f"Head {i + 1}")
        fig.suptitle(f"Layer {layer + 1} ({mode}-pooled)")
        fig.tight_layout()
        fig.savefig(out_dir / f"layer_{layer + 1:02}_{mode}.png")
        plt.close(fig)  # avoid accumulating one open figure per layer


_plot_attention_grid("mean", nrows=div_factor, figsize=(8 * div_factor, 5 * div_factor))
_plot_attention_grid("max", nrows=2, figsize=(50, 10))