In [8]:
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
top_k = 20


encoder_path = "/kaggle/input/distillbert_retriever/pytorch/default/1/model/encoder"
tokenizer_path = "/kaggle/input/distillbert_retriever/pytorch/default/1/model/tokenizer"
model = AutoModel.from_pretrained(encoder_path).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)


def embed_texts(texts, batch_size=64):
    all_embeddings = []
    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
            batch = texts[i:i+batch_size]
            inputs = tokenizer(
                batch,
                padding="max_length",
                truncation=True,
                max_length=256,
                return_tensors="pt"
            ).to(device)
            outputs = model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1).cpu()
            all_embeddings.append(embeddings)
    return torch.cat(all_embeddings, dim=0).numpy()


squad = load_dataset("squad", split="validation")
questions = squad["question"]
contexts  = squad["context"]


context_embeddings = embed_texts(contexts)


index = NearestNeighbors(n_neighbors=top_k, metric="cosine")
index.fit(context_embeddings)


query_embeddings = embed_texts(questions)


distances, indices = index.kneighbors(query_embeddings, return_distance=True)


n = len(questions)

recall_at_1 = np.mean([1 if i == idxs[0] else 0
                       for i, idxs in enumerate(indices)])


recall_at_5 = np.mean([1 if i in idxs[:5] else 0
                       for i, idxs in enumerate(indices)])


recall_at_20 = np.mean([1 if i in idxs else 0
                        for i, idxs in enumerate(indices)])


rr_scores = [1/(idx.tolist().index(i)+1) if i in idx else 0
             for i, idx in enumerate(indices)]
mrr = np.mean(rr_scores)

print(f"Retriever Evaluation on SQuAD Test Set:")
print(f"  Recall@1 : {recall_at_1:.3f}")
print(f"  Recall@5 : {recall_at_5:.3f}")
print(f"  Recall@20: {recall_at_20:.3f}")
print(f"  MRR@20   : {mrr:.3f}")

Embedding: 100%|██████████| 166/166 [01:16<00:00,  2.16it/s]
Embedding: 100%|██████████| 166/166 [01:14<00:00,  2.23it/s]


Retriever Evaluation on SQuAD Test Set:
  Recall@1 : 0.111
  Recall@5 : 0.499
  Recall@20: 0.744
  MRR@20   : 0.266
