In [None]:
!uv pip install transformers datasets tiktoken matplotlib pandas seaborn torch

Resolved 97 packages in 0.37ms
Audited 78 packages in 0.02ms


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from huggingface_hub import hf_hub_download
import json
from difflib import SequenceMatcher
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


#### config.py


In [1]:
# Hugging Face Hub id of the checkpoint; the model/tokenizer cell passes it to
# both AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# NOTE(review): not referenced by any of the visible cells (the generation cell
# hard-codes max_new_tokens=512) — confirm whether this is meant to drive it.
MAX_LENGTH = 512

In [None]:
# Load the causal LM identified by MODEL_NAME. Dtype and device placement are
# left to the library ("auto").
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    # NOTE(review): eager attention is presumably requested so that attention
    # weights can actually be returned — confirm against the transformers docs.
    attn_implementation="eager",
    # NOTE(review): these two are config-style overrides; the generation cell
    # passes both again explicitly to model.generate().
    output_attentions=True,
    return_dict_in_generate=True,
)
# Put the module in evaluation mode (eval() returns the module itself, so
# `model` is the same object the chained form would have bound).
model.eval()

# Tokenizer comes from the same checkpoint id as the weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

#### mrcr_utils.py


In [None]:
def load_mrcr_parquet():
    """Fetch ``2needle.parquet`` from the ``openai/mrcr`` dataset repo and load it.

    The file is obtained through ``hf_hub_download`` and then read with
    ``pandas.read_parquet``.

    Returns:
        pandas.DataFrame: the contents of the downloaded parquet file.
    """
    parquet_path = hf_hub_download(
        repo_id="openai/mrcr",
        filename="2needle.parquet",
        repo_type="dataset",
    )
    return pd.read_parquet(parquet_path)


def grade(response, answer, random_string_to_prepend) -> float:
    """Score a response against the reference answer.

    Args:
        response: Text produced by the model.
        answer: Reference text.
        random_string_to_prepend: Prefix the response is required to start with.

    Returns:
        ``0.0`` when ``response`` does not start with the required prefix.
        Otherwise the ``difflib.SequenceMatcher`` similarity ratio (between 0.0
        and 1.0) of the two strings after the prefix has been removed from each.
    """
    if not response.startswith(random_string_to_prepend):
        # Return a float (not the int 0) so the result type matches the
        # annotation on every path; 0.0 == 0, so existing comparisons still hold.
        return 0.0
    response = response.removeprefix(random_string_to_prepend)
    answer = answer.removeprefix(random_string_to_prepend)
    return float(SequenceMatcher(None, response, answer).ratio())


def n_tokens(messages: list[dict]) -> int:
    """
    Return the token count of `messages` once rendered for generation.

    The messages are first rendered to a string with the tokenizer's chat
    template (generation prompt included); that string is then tokenized and
    the number of resulting input ids is returned.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    token_ids = tokenizer(rendered).input_ids
    return len(token_ids)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Character budget used by both filters below (previously the literal 20000
# was repeated in each place).
MAX_CHARS = 20_000

df = load_mrcr_parquet()
# First filter: the dataset's own "n_chars" column.
dataset = df[df["n_chars"] < MAX_CHARS]

# Only the "prompt" column is read, so iterate over it directly rather than
# building a Series per row with iterrows().
for prompt in dataset["prompt"]:
    # Second filter: length of the raw JSON prompt string itself. The JSON is
    # parsed only for prompts that pass, since skipped ones are never used.
    if len(prompt) < MAX_CHARS:
        print(n_tokens(json.loads(prompt)))

In [None]:
# Use the first row of the filtered dataset as the single sample under study.
test = dataset.iloc[0]

# Render the conversation with the chat template (generation prompt appended)
# and tokenize it as a batch of one on the model's device.
messages = json.loads(test["prompt"])
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

# NOTE(review): any decoding setting not listed here falls back to the
# checkpoint's generation defaults; if those enable sampling the response (and
# its grade) will differ between runs — consider seeding or greedy decoding.
# NOTE(review): 512 matches MAX_LENGTH from the config cell; confirm whether
# that constant was meant to be used here.
generation_kwargs = dict(
    max_new_tokens=512,
    use_cache=True,
    output_attentions=True,
    return_dict_in_generate=True,
)
output = model.generate(**inputs, **generation_kwargs)

# Every returned sequence starts with the prompt tokens; cut them off so only
# the newly generated continuation is decoded.
prompt_length = inputs.input_ids.shape[1]
generated_ids = [sequence[prompt_length:] for sequence in output.sequences]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
print(grade(response, test["answer"], test["random_string_to_prepend"]))

In [None]:
def find_subtensor_indices(
    haystack: torch.LongTensor, needle: torch.LongTensor
) -> torch.LongTensor:
    """
    Returns a 1D tensor of all start-positions where `needle`
    appears as a contiguous slice of `haystack`.

    Args:
        haystack: 1D tensor to search in.
        needle: 1D tensor to search for.

    Returns:
        A 1D ``torch.long`` tensor of start indices (overlapping matches
        included), always located on ``haystack``'s device. An empty needle
        matches at every position ``0..len(haystack)``; a needle longer than
        the haystack yields an empty result.
    """
    device = haystack.device
    n, m = haystack.size(0), needle.size(0)
    if m == 0:
        # every position (including "after" the last) is a match; build the
        # result on the haystack's device so all paths agree on placement
        return torch.arange(n + 1, dtype=torch.long, device=device)
    if m > n:
        return torch.empty(0, dtype=torch.long, device=device)

    # the comparison below needs both operands on the same device
    needle = needle.to(device)

    # create all length-m windows: shape (n-m+1, m)
    windows = haystack.unfold(0, m, 1)
    # a window matches only if every one of its m elements equals the needle's
    matches = (windows == needle).all(dim=1)  # shape (n-m+1,)
    # extract the indices where True
    return matches.nonzero(as_tuple=True)[0]


# Token ids of the literal "<|im_start|>" string, on the model's device.
# (Despite the variable name, this is that marker, not the sample's random
# prefix string.)
marker_encoding = tokenizer(["<|im_start|>"], return_tensors="pt").to(model.device)
random_string_tokens = marker_encoding.input_ids[0]

# Start position of every occurrence of the marker in the tokenized prompt.
indices = find_subtensor_indices(inputs.input_ids[0], random_string_tokens)

desired_msg_index = test["desired_msg_index"]

# NOTE(review): the +2 / +3 occurrence offsets and the +2 / -2 token offsets
# presumably skip template-inserted turns and the role / end-of-turn tokens
# surrounding the message body — confirm against the rendered chat template.
start_marker = indices[desired_msg_index + 2]
end_marker = indices[desired_msg_index + 3]
index_start = start_marker.item() + 2
index_end = end_marker.item() - 2

In [None]:
attention = output.attentions

# attention[0][0][0] selects the first generation step, first layer and first
# batch element. Per the transformers generate() output docs the remaining
# axes are (heads, query position, key position).
layer_attention = attention[0][0][0]

# Keep only the query rows index_start:index_end and convert once for all
# heads instead of once per subplot.
span_attention = layer_attention[:, index_start:index_end].cpu().float().numpy()

# Size the grid from the actual number of heads rather than assuming 14, so a
# different MODEL_NAME neither raises IndexError nor silently drops heads.
n_heads = span_attention.shape[0]
n_cols = min(7, n_heads)
n_rows = (n_heads + n_cols - 1) // n_cols  # ceiling division

# Panel size chosen so that a 2x7 grid reproduces the original (50, 10) figure.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 50 / 7, n_rows * 5), squeeze=False
)
for i, ax in enumerate(axes.flat):
    if i >= n_heads:
        # Unused cell in the last row of the grid.
        ax.set_visible(False)
        continue
    sns.heatmap(
        span_attention[i],
        ax=ax,
        cmap="bone",
        yticklabels=False,
    )
    ax.set_title(f"Head {i + 1}")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position (selected span)")
plt.tight_layout()
plt.show()