In [None]:
from generation.baselines.knnlm.knnlm import KNNLM
from generation.workshop.dataloader import ModelInputPreprocessor
import numpy as np
from tqdm import tqdm


# Dataset identifiers; the statistics loop visits them in this order.
datasets = ["clerc", "cuad", "echr_qa", "oal_qa", "obli_qa"]

def compute_length_percentiles(hf_dataset, column_name: str = "context"):
    """Summarise per-example text lengths, measured in whitespace-separated words.

    Args:
        hf_dataset: Iterable of mapping-like examples (indexed by column name).
        column_name: Which text to measure. ``"context"`` is special-cased:
            the measured text is ``context_prefix`` and ``context`` joined by
            a blank line. Any other value reads ``example[column_name]``
            directly.

    Returns:
        A dict with the ``"mean"`` and ``"p50"`` (median) word counts as
        floats, or ``None`` when ``hf_dataset`` yields no examples.

    Raises:
        KeyError: If an example lacks a required column (propagated from the
            example lookup when examples are plain dicts).
    """

    def _text(example):
        # "context" is measured together with its prefix, exactly as joined here.
        if column_name == "context":
            return f"{example['context_prefix']}\n\n{example['context']}"
        return example[column_name]

    # str.split() with no argument splits on any run of whitespace, so these
    # are word counts, not tokenizer token counts.
    lengths = np.array([len(_text(example).split()) for example in hf_dataset])

    if lengths.size == 0:
        print(f"[!] No documents found in column={column_name}.")
        return None

    return {
        "mean": float(np.mean(lengths)),
        "p50": float(np.percentile(lengths, 50)),
    }


# Load the model once; its tokenizer and max_position_embeddings are passed to
# the preprocessing of every dataset below.
knnlm_model = KNNLM(model_name="mistralai/Mistral-7B-Instruct-v0.3", device=0)

# Write one three-line summary per dataset to stats.txt (overwritten on each run).
with open("stats.txt", "w", encoding="utf-8") as f:
    for dataset in tqdm(datasets):
        # cuad and obli_qa use 10 retrieved passages; every other dataset uses 3.
        top_k_passages = 10 if dataset in ("cuad", "obli_qa") else 3
        config = {
            "dataset_percentage": 1.0,
            "dataset": dataset,
            "method": "knnlm-context-entropy",
            "setup": "bm25_relevant_passages_oracle_documents",
            "split": "test",
            "top_k_passages": top_k_passages,
            "use_instructions": True,
        }
        preprocessor = ModelInputPreprocessor(config)
        work_dataset, original_dataset = preprocessor.process_dataset(
            tokenizer=knnlm_model.tokenizer,
            max_tokens=knnlm_model.model.config.max_position_embeddings,
        )

        citations_percentiles = compute_length_percentiles(work_dataset, column_name="context")
        gold_text_percentiles = compute_length_percentiles(work_dataset, column_name="prompt")

        f.write(f"Dataset: {dataset}\n")
        # The two values written per line are the mean and the median (p50)
        # word count, in that order.
        for label, stats in (
            ("Context", citations_percentiles),
            ("Prompt", gold_text_percentiles),
        ):
            if stats is None:
                # compute_length_percentiles returns None when the dataset is
                # empty; record that instead of crashing on a None subscript
                # and losing the remaining datasets.
                f.write(f"{label} Percentiles: n/a\n")
            else:
                f.write(f"{label} Percentiles: {stats['mean']}, {stats['p50']}\n")


