In [None]:
!uv pip install transformers matplotlib pandas seaborn torch

import json
from difflib import SequenceMatcher

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

#### Utilities: needle sampling, prompt construction and grading


In [None]:
def get_random_needles(df_needles, n=5, random_state=None):
    """
    Sample up to ``n`` random needles with pairwise-distinct ``arg1`` values.

    Args:
        df_needles (pd.DataFrame): Needle table; must contain an ``arg1`` column.
        n (int): Number of needles to return. Fewer rows are returned only when
            the table has fewer than ``n`` distinct ``arg1`` values.
        random_state: Forwarded to ``DataFrame.sample``. ``None`` draws from
            numpy's global RNG.

    Returns:
        pd.DataFrame: The sampled needles, re-indexed 0..k-1.
    """
    # Shuffle the full table first and de-duplicate afterwards. Sampling n rows
    # and then dropping duplicate arg1 values (the previous approach) silently
    # returned fewer than n needles whenever the sample contained a repeat, and
    # raised when n exceeded the table size.
    shuffled = df_needles.sample(frac=1, random_state=random_state)
    distinct = shuffled.drop_duplicates(subset=["arg1"], keep="first")
    return distinct.head(n).reset_index(drop=True)


def generate_context(df_haystack, df_needles) -> str:
    """
    Build the long document that the needles are hidden in.

    Needle ``i`` is emitted first, followed by haystack passage ``i`` (taken in
    positional order from ``df_haystack``) padded with one space on each side.
    Empty passages are skipped, and once the passages run out the remaining
    needles are appended back to back.

    Args:
        df_haystack (pd.DataFrame): Filler passages; must contain a ``text`` column.
        df_needles (pd.DataFrame): Needles to embed; must contain a ``needle`` column.

    Returns:
        str: The concatenated context.
    """
    parts = []
    for i, needle in enumerate(df_needles["needle"]):
        parts.append(needle)
        # Passage i exists for every i < len(df_haystack); the bound must not
        # subtract one, or the last available passage is never used.
        if i < len(df_haystack):
            haystack = df_haystack.iloc[i]["text"]
            if haystack:
                parts.append(" " + haystack + " ")
    return "".join(parts)


def generate_messages(df_needles, df_haystack, n=5):
    """
    Assemble a needle-in-a-haystack chat prompt.

    A random subset of needles is woven into the haystack passages to form a
    single document. One of those same needles is then picked as the target,
    and its retrieval question becomes the final user turn.

    Args:
        df_needles (pd.DataFrame): Needle table (needs ``needle``, ``arg1`` and
            ``retrieval_question`` columns).
        df_haystack (pd.DataFrame): Filler passages (needs a ``text`` column).
        n (int): How many needles to sample into the document.

    Returns:
        tuple[list[dict], pd.Series]: The chat messages (system instruction,
        document, question) and the row of the needle being asked about.
    """
    chosen = get_random_needles(df_needles, n=n)
    document = generate_context(df_haystack, chosen)

    # The question always targets a needle that is actually present in the document.
    prompt_needle = chosen.sample(n=1).reset_index(drop=True).iloc[0]

    sys_prompt = "You are an intelligent AI assistant skilled in answering user questions base on documents provided by the user. Please keep your answers concise and clear. Do not talk about irrelevant topics or repeat your answers. The document given to you by the user is:"

    system_turn = {"role": "system", "content": sys_prompt}
    # The document and the question are sent as two consecutive user turns.
    document_turn = {"role": "user", "content": document}
    question_turn = {"role": "user", "content": prompt_needle["retrieval_question"]}

    return [system_turn, document_turn, question_turn], prompt_needle


def grade(response, answer) -> float:
    """
    Score a model response against the reference answer.

    Returns difflib's ``SequenceMatcher`` similarity ratio, ``2 * M / T``, where
    ``M`` is the number of matched characters and ``T`` is the combined length.
    The score is 1.0 for identical strings (including two empty strings) and
    0.0 when nothing matches. The comparison is case- and whitespace-sensitive.
    """
    matcher = SequenceMatcher(isjunk=None, a=response, b=answer)
    similarity = matcher.ratio()
    return float(similarity)

#### Dataset: NeedleBench retrieval needles and English haystack passages


In [None]:
NEEDLEBENCH_REPO = "opencompass/NeedleBench"
# Branch holding the parquet-converted copy of the dataset on the Hugging Face Hub.
NEEDLEBENCH_REVISION = "refs/convert/parquet"
# Haystack passages outside this character-length window are discarded.
HAYSTACK_MIN_CHARS, HAYSTACK_MAX_CHARS = 5000, 7500
RANDOM_SEED = 42


def load_needlebench_parquet(filename: str) -> pd.DataFrame:
    """Download one parquet shard of the NeedleBench dataset and load it as a DataFrame."""
    path = hf_hub_download(
        repo_id=NEEDLEBENCH_REPO,
        filename=filename,
        repo_type="dataset",
        revision=NEEDLEBENCH_REVISION,
    )
    return pd.read_parquet(path)


df_needles = load_needlebench_parquet("retrieval_needles/test/0000.parquet")
df_needles = df_needles[df_needles["language"] == "English"].reset_index(drop=True)

df_haystack = load_needlebench_parquet("en_haystack_texts/test/0000.parquet")
df_haystack = df_haystack[
    df_haystack["text"].str.len().between(HAYSTACK_MIN_CHARS, HAYSTACK_MAX_CHARS)
].reset_index(drop=True)

# Fail loudly instead of silently building a prompt without needles or filler text.
assert not df_needles.empty, "no English needles found"
assert not df_haystack.empty, (
    f"no haystack texts between {HAYSTACK_MIN_CHARS} and {HAYSTACK_MAX_CHARS} characters"
)

# DataFrame.sample(random_state=None) draws from numpy's global RNG, so seeding it
# makes the needle selection inside generate_messages reproducible across runs.
np.random.seed(RANDOM_SEED)
messages, prompt_needle = generate_messages(df_needles, df_haystack, n=5)

#### Model: load the Qwen2.5 instruct checkpoint and generate a response


In [None]:
# Checkpoint under test. "Qwen/Qwen2.5-0.5B-Instruct" is a smaller alternative.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Cap on newly generated tokens per response (used as max_new_tokens in generate()).
MAX_LENGTH = 512

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Eager attention is requested together with output_attentions=True, presumably
# because fused attention kernels (SDPA / FlashAttention) do not return the
# attention weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    output_attentions=True,
    attn_implementation="eager",
    return_dict_in_generate=True,
)

In [None]:
# Prompts longer than this many tokens are rejected instead of being run; with
# eager attention and output_attentions=True, memory grows quickly with length.
MAX_CONTEXT_TOKENS = 6000
# Decoding samples (do_sample=True), so seed torch to make responses reproducible.
GENERATION_SEED = 42

# Free unused cached CUDA memory (e.g. from a previous run of this cell); this is
# a no-op when CUDA has not been initialised.
torch.cuda.empty_cache()

text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
prompt_tokens = inputs.input_ids.shape[1]

if prompt_tokens > MAX_CONTEXT_TOKENS:
    raise ValueError(
        f"Text length: {len(text)}\nContext length ({prompt_tokens} tokens) too long"
    )

torch.manual_seed(GENERATION_SEED)
with torch.inference_mode():
    output = model.generate(
        **inputs,
        max_new_tokens=MAX_LENGTH,
        output_attentions=True,
        return_dict_in_generate=True,
        use_cache=True,
        do_sample=True,
    )

# generate() returns prompt + continuation; drop the prompt tokens so only the
# newly generated text is decoded (single-sequence batch, no padding).
generated_ids = output.sequences[:, prompt_tokens:]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
# len(text) counts characters of the templated prompt; the model's context length
# is the token count, so report both under unambiguous labels.
gold_answer = prompt_needle["gold_standard_answer"]
print(f"Text length: {len(text)} characters")
print(f"Context length: {inputs.input_ids.shape[1]} tokens")
print(f"Question: {prompt_needle['retrieval_question']}")
print(f"Gold answer: {gold_answer}")
print(f"Response: {response}")
print(f"Score (SequenceMatcher ratio): {grade(response, gold_answer):.3f}")