In [33]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

#### Configuration

In [34]:
# Hugging Face Hub id of the checkpoint to evaluate.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Token budget for the prompt and for the generated continuation.
CONTEXT_LENGTH = 1024
MAX_NEW_TOKENS = 50

#### Load Model

In [35]:
# Tokenizer and model come from the same checkpoint so vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Placement is done once, explicitly, with .to(DEVICE). The previous
# combination of device_map="auto" *and* .to(DEVICE) asked Accelerate to
# dispatch the model and then moved the dispatched model again, which
# Accelerate warns against and which can fail when modules are offloaded.
# Keeping a single placement also guarantees the model lives on the same
# device the tokenized inputs are sent to.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    # Eager attention is selected here; presumably so attention weights can
    # be returned for plotting — TODO confirm against the cells that use it.
    attn_implementation="eager",
    return_dict_in_generate=True,
)
model = model.eval().to(DEVICE)

#### Load Dataset

In [40]:
# BABILong from the Hub: the "1k" configuration, task split "qa1".
dataset = load_dataset(
    path="RMT-team/babilong",
    name="1k",
    split="qa1",
)

Generating qa1 split: 100 examples [00:00, 2677.38 examples/s]
Generating qa2 split: 100 examples [00:00, 25055.58 examples/s]
Generating qa3 split: 100 examples [00:00, 28554.05 examples/s]
Generating qa4 split: 100 examples [00:00, 25030.16 examples/s]
Generating qa5 split: 100 examples [00:00, 24558.25 examples/s]
Generating qa6 split: 100 examples [00:00, 22762.97 examples/s]
Generating qa7 split: 100 examples [00:00, 24107.97 examples/s]
Generating qa8 split: 100 examples [00:00, 27784.21 examples/s]
Generating qa9 split: 100 examples [00:00, 26588.30 examples/s]
Generating qa10 split: 100 examples [00:00, 26886.56 examples/s]


#### Utilities

In [37]:
def format_prompt(story, question):
    """Assemble the QA prompt fed to the model.

    Layout: the story, one blank line, a ``Question: ...`` line, and a
    trailing ``Answer:`` cue left open for the model to complete.

    Args:
        story: Context text placed at the top of the prompt.
        question: Question asked about the story.

    Returns:
        The prompt as a single string.
    """
    lines = [f"{story}", "", f"Question: {question}", "Answer:"]
    return "\n".join(lines)

def contains_needle(response, target):
    """Score whether ``target`` appears in ``response``, ignoring case.

    Args:
        response: Text produced by the model.
        target: Expected answer string to look for.

    Returns:
        ``1.0`` if the lower-cased target is a substring of the lower-cased
        response, otherwise ``0.0``.
    """
    needle = target.lower()
    haystack = response.lower()
    if needle in haystack:
        return 1.0
    return 0.0

def find_subtensor_indices(haystack: torch.LongTensor, needle: torch.LongTensor) -> torch.LongTensor:
    """Find every offset at which ``needle`` occurs contiguously in ``haystack``.

    Both tensors are expected to be 1-D; only their first dimension is
    inspected.

    Args:
        haystack: Sequence to search in.
        needle: Sequence to search for.

    Returns:
        A 1-D ``long`` tensor of start indices, always on ``haystack``'s
        device. An empty needle matches at every offset ``0..len(haystack)``;
        a needle longer than the haystack yields an empty tensor.
    """
    n, m = haystack.size(0), needle.size(0)
    # Build every result on the haystack's device so callers get a consistent
    # device regardless of which branch is taken.
    device = haystack.device
    if m == 0:
        return torch.arange(n + 1, dtype=torch.long, device=device)
    if m > n:
        return torch.empty(0, dtype=torch.long, device=device)
    # Sliding windows of length m with stride 1: shape (n - m + 1, m).
    windows = haystack.unfold(0, m, 1)
    # Move the needle next to the haystack; comparing tensors on different
    # devices would raise.
    matches = (windows == needle.to(device)).all(dim=1)
    return matches.nonzero(as_tuple=True)[0]

def plot_attention(attentions, input_ids, needle_ids, layer=0):
    """Draw one heatmap per attention head for the rows of the needle tokens.

    The first occurrence of ``needle_ids`` inside ``input_ids`` is located,
    and for the chosen layer each head's attention rows at those positions
    are rendered as a heatmap. If the needle is not found, a message is
    printed and nothing is plotted.

    Args:
        attentions: Indexable by layer; ``attentions[layer][0]`` is taken as
            the per-head attention for the first batch element.
        input_ids: 1-D token ids of the prompt.
        needle_ids: 1-D token ids of the span to highlight.
        layer: Index of the layer to visualise.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    indices = find_subtensor_indices(input_ids, needle_ids)
    if len(indices) == 0:
        print("Needle tokens not found in input. Skipping attention plot.")
        return
    index_start = indices[0].item()
    index_end = index_start + len(needle_ids)
    attention = attentions[layer][0]  # shape: (num_heads, seq_len, seq_len)

    # Size the grid from the actual head count instead of assuming 14 heads;
    # for 14 heads this is still a 2x7 grid at 28x10 inches.
    num_heads = attention.shape[0]
    ncols = min(7, num_heads)
    nrows = -(-num_heads // ncols)  # ceiling division
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 5 * nrows), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i >= num_heads:
            # Blank out unused cells in the last row of the grid.
            ax.axis("off")
            continue
        needle_rows = attention[i][index_start:index_end]
        sns.heatmap(
            # Cast to float32 first: NumPy has no bfloat16, so .numpy() raises
            # when the model runs in that dtype.
            needle_rows.detach().float().cpu().numpy(),
            ax=ax, cmap="viridis", yticklabels=False
        )
        ax.set_title(f"Head {i + 1}")
    plt.tight_layout()
    plt.show()

#### Run Experiment

In [42]:
# Evaluate the model on the first `total` samples and report how often the
# generated answer contains the target string.
correct = 0
total = 5
log_first_n = 0  # set > 0 to print target/response for the first few samples

for i in range(total):
    print(i)
    sample = dataset[i]
    prompt = format_prompt(sample["input"], sample["question"])
    # Token budgets come from the configuration cell rather than being
    # repeated here as literals, so changing the config actually takes effect.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=CONTEXT_LENGTH,
    ).to(DEVICE)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            return_dict_in_generate=True,
            use_cache=True,
        )

    # Skip the first prompt_len tokens of the returned sequence so only the
    # newly generated tokens are decoded.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = output.sequences[0][prompt_len:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    hit = contains_needle(response, sample["target"])
    if hit:
        correct += 1

    # Optional per-sample log, disabled by default (log_first_n == 0).
    if i < log_first_n:
        print(f"===== SAMPLE {i + 1} =====")
        print("Target:", sample["target"])
        print("Model Output:", response)
        print("Contains needle:", hit)
        print()

# Final score
accuracy = correct / total
print(f"Recall Score (contains_needle) over {total} samples: {accuracy:.3f}")

0
1
2
3
4
Recall Score (contains_needle) over 5 samples: 0.400
