<a href="https://colab.research.google.com/github/ugrani/experiments/blob/main/rnr_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Hugging Face `datasets` library, used below to download MS MARCO.
!pip install -q datasets

In [None]:
# Load the MS MARCO v1.1 validation split from the Hugging Face Hub.
from datasets import load_dataset

dataset = load_dataset("ms_marco", "v1.1", split="validation")

print(dataset)

In [None]:
# Inspect one validation row: a query plus its candidate passages
# (each passage truncated to 200 characters).
example = dataset[0]

print("Query:")
print(example["query"])

print("\nPassages (showing first 2):")
# Slice so the output actually matches the "first 2" header above.
for p in example["passages"]["passage_text"][:2]:
    print("-", p[:200], "...")

In [None]:
# Subsample aggressively to control the GPU cost: NUM_QUERIES queries plus a
# corpus of up to TARGET_PASSAGES passages (their own passages + distractors).
import random

random.seed(42)

NUM_QUERIES = 200
TARGET_PASSAGES = 10_000
NUM_PASSAGES = 10_000

sampled_queries = random.sample(range(len(dataset)), NUM_QUERIES)
# Set copy for O(1) membership tests in the distractor loop below.
sampled_query_set = set(sampled_queries)

queries = []
all_passages = []

for idx in sampled_queries:
    row = dataset[idx]
    queries.append(row["query"])
    all_passages.extend(row["passages"]["passage_text"])

# Deduplicate and subsample passages. dict.fromkeys keeps first-seen order, so
# unlike list(set(...)) the result does not depend on string-hash randomisation.
# This list only feeds the summary print below (all_passages is rebuilt from
# `corpus` at the end of the cell), but the random.sample call also advances the
# RNG that the distractor shuffle uses, so it is kept to preserve the corpus.
all_passages = list(dict.fromkeys(all_passages))
all_passages = random.sample(
    all_passages,
    min(NUM_PASSAGES, len(all_passages))
)

print(f"Queries: {len(queries)}")
print(f"Passages: {len(all_passages)}")

# Start corpus with ALL passages from the sampled queries (so relevant ones are
# included). A dict serves as an insertion-ordered set, which makes the final
# passage order (and hence FAISS ids) reproducible across runs.
corpus = dict.fromkeys(
    p for idx in sampled_queries for p in dataset[idx]["passages"]["passage_text"]
)

print("Initial corpus size (from sampled queries):", len(corpus))

# Add distractor passages from OTHER queries in the validation split until the
# corpus reaches TARGET_PASSAGES.
all_indices = list(range(len(dataset)))
random.shuffle(all_indices)

for idx in all_indices:
    if len(corpus) >= TARGET_PASSAGES:
        break
    if idx in sampled_query_set:
        continue
    for p in dataset[idx]["passages"]["passage_text"]:
        corpus[p] = None
        if len(corpus) >= TARGET_PASSAGES:
            break

all_passages = list(corpus)
print("Final corpus size:", len(all_passages))


In [None]:
# Define relevance: for every sampled query, keep the passage texts whose
# is_selected flag equals 1 (the list may be empty).
query_to_relevant = {}

for idx in sampled_queries:
    row = dataset[idx]
    passages = row["passages"]
    query_to_relevant[row["query"]] = [
        text
        for text, selected in zip(passages["passage_text"], passages["is_selected"])
        if selected == 1
    ]


In [None]:
# sentence-transformers provides the bi-encoder and cross-encoder models;
# faiss-cpu provides the vector index used for dense retrieval.
!pip install -q sentence-transformers faiss-cpu


In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np

# Bi-encoder used for stage-1 dense retrieval.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_name)

# Embed every corpus passage once. Unit-normalised vectors make the inner
# product equal to cosine similarity, which the IndexFlatIP below relies on.
encode_kwargs = dict(
    batch_size=128,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
passage_emb = model.encode(all_passages, **encode_kwargs)

print(passage_emb.shape)  # (num_passages, embedding_dim)


In [None]:
import faiss

# Flat inner-product index over the passage embeddings. The embeddings were
# L2-normalised at encode time, so inner product == cosine similarity.
dim = passage_emb.shape[-1]
index = faiss.IndexFlatIP(dim)
index.add(passage_emb)

print("FAISS ntotal:", index.ntotal)


In [None]:
def retrieve(query, k=50):
    """Return the top-k corpus passages for `query` by dense similarity.

    Args:
        query: query text.
        k: number of passages to return; clamped to the number of indexed passages.

    Returns:
        (passages, scores): passage texts ranked best-first, and the matching
        1-D array of inner-product scores (cosine, since both sides are normalised).
    """
    # FAISS pads results with id -1 when k exceeds the number of indexed
    # vectors, and all_passages[-1] would then silently return the last
    # passage. Clamp k and drop any padding ids defensively.
    k = min(k, index.ntotal)
    if k <= 0:
        return [], np.empty(0, dtype=np.float32)
    q_emb = model.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True
    )
    scores, idxs = index.search(q_emb, k)
    valid = idxs[0] >= 0
    return [all_passages[i] for i in idxs[0][valid]], scores[0][valid]


In [None]:
# Sanity check: show the top-5 dense hits for the first sampled query.
q = queries[0]
top, scores = retrieve(q, k=5)
print("Query:", q)
print("\nTop results:")
for rank, (p, s) in enumerate(zip(top, scores), start=1):
    print(f"\n#{rank} (score={s:.3f})\n{p[:300]}...")


In [None]:
def recall_at_k(queries, query_to_relevant, k):
    """Fraction of queries with at least one labelled relevant passage in the dense top-k.

    Queries that have no labelled relevant passage are skipped and do not count
    towards the denominator.

    Returns:
        (recall, eligible): recall (0.0 if no query is eligible) and the number
        of queries it was computed over.
    """
    hit_count = 0
    eligible = 0

    for q in queries:
        rels = query_to_relevant.get(q, [])
        if not rels:
            continue
        eligible += 1

        retrieved, _ = retrieve(q, k=k)
        # A hit means the top-k shares at least one passage with the relevant list.
        hit_count += bool(set(retrieved).intersection(rels))

    if eligible == 0:
        return 0.0, eligible
    return hit_count / eligible, eligible

# Stage-1 recall at several cut-offs.
Ks = [10, 50, 100, 200]
for k in Ks:
    recall, n_eval = recall_at_k(queries, query_to_relevant, k)
    print(f"Recall@{k}: {recall:.3f}  (evaluated on {n_eval} queries)")


In [None]:
import matplotlib.pyplot as plt

# Recall for each cut-off in Ks, plotted as a curve.
recalls = [recall_at_k(queries, query_to_relevant, k)[0] for k in Ks]

fig, ax = plt.subplots()
ax.plot(Ks, recalls, marker="o")
ax.set_xlabel("K")
ax.set_ylabel("Recall@K")
ax.set_title("Stage-1 Dense Retrieval Recall vs K")
plt.show()


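## Next phase: cross-encoder reranking

1. The dense retriever returns the top-K candidates for each query.
2. A cross-encoder scores each (query, passage) pair.
3. The candidates are reordered by cross-encoder score.
4. MRR@10 and NDCG@10 are computed from the MS MARCO `is_selected` labels.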

In [None]:
# NOTE(review): sentence-transformers is already installed by the earlier
# `pip install -q sentence-transformers faiss-cpu` cell; pip skips an already
# satisfied requirement, so this line is redundant when run top to bottom.
!pip install -q sentence-transformers


In [None]:
from sentence_transformers import CrossEncoder

# Stage-2 reranker: a MiniLM-L6 cross-encoder for MS MARCO.
reranker_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker = CrossEncoder(reranker_name)


In [None]:
# Relevance labels as sets of passage texts, for fast membership checks in the metrics.
query_to_relevant_set = dict(
    (query_text, set(rel_texts)) for query_text, rel_texts in query_to_relevant.items()
)


In [None]:
#Rerank function
import numpy as np

def rerank(query, candidates):
    """Reorder candidate passages by cross-encoder relevance to `query`.

    Args:
        query: query text.
        candidates: list of passage texts (e.g. the output of `retrieve`).

    Returns:
        (reranked, reranked_scores): passages sorted by descending
        cross-encoder score, and the matching 1-D numpy array of scores.
        Empty input yields ([], empty array) without calling the model.
    """
    if len(candidates) == 0:
        return [], np.empty(0, dtype=np.float32)
    pairs = [(query, p) for p in candidates]
    scores = np.asarray(reranker.predict(pairs))
    # Stable sort on negated scores: descending order, and tied passages keep
    # their dense-retrieval order instead of an arbitrary one.
    order = np.argsort(-scores, kind="stable")
    reranked = [candidates[i] for i in order]
    reranked_scores = scores[order]
    return reranked, reranked_scores


In [None]:
#Metrics: MRR@10 and NDCG@10
import math

def mrr_at_k(ranked_list, relevant_set, k=10):
    """Reciprocal rank (1/rank) of the first relevant item in the top-k, or 0.0 if none."""
    first_hit = next(
        (rank for rank, item in enumerate(ranked_list[:k], start=1) if item in relevant_set),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit

def ndcg_at_k(ranked_list, relevant_set, k=10):
    """Binary-relevance NDCG@k.

    DCG@k  = sum over ranks i <= k of rel_i / log2(i + 1), with rel_i in {0, 1}.
    IDCG@k = DCG of the ideal ranking (all relevant items first), i.e.
             sum_{i=1}^{min(|R|, k)} 1 / log2(i + 1).

    Returns a value in [0, 1]; 0.0 when `relevant_set` is empty or k <= 0.
    """
    n_ideal = min(len(relevant_set), k)
    if n_ideal <= 0:
        return 0.0
    dcg = sum(
        1.0 / math.log2(i + 1)
        for i, p in enumerate(ranked_list[:k], start=1)
        if p in relevant_set
    )
    # Normalise by the ideal DCG for |R| relevant items rather than a constant
    # 1.0, which would let the score exceed 1 whenever a query has several
    # relevant passages.
    idcg = sum(1.0 / math.log2(i + 1) for i in range(1, n_ideal + 1))
    return dcg / idcg


In [None]:
#Evaluate: dense-only vs dense+rerank
def eval_ranking(queries, query_to_relevant_set, K_retrieve=50, K_eval=10, do_rerank=False):
    """Mean MRR@K_eval and NDCG@K_eval over queries that have a labelled relevant passage.

    Args:
        queries: query texts to evaluate.
        query_to_relevant_set: query text -> set of relevant passage texts.
        K_retrieve: number of dense candidates retrieved per query.
        K_eval: cut-off for both metrics.
        do_rerank: if True, reorder the dense candidates with the cross-encoder.

    Returns:
        (mean_mrr, mean_ndcg, eligible). When no query is eligible the means
        are 0.0 (consistent with recall_at_k) instead of NaN from np.mean([]).
    """
    mrrs, ndcgs = [], []

    for q in queries:
        rels = query_to_relevant_set.get(q, set())
        if not rels:
            continue

        ranked, _ = retrieve(q, k=K_retrieve)  # dense top-K
        if do_rerank:
            ranked, _ = rerank(q, ranked)

        mrrs.append(mrr_at_k(ranked, rels, k=K_eval))
        ndcgs.append(ndcg_at_k(ranked, rels, k=K_eval))

    eligible = len(mrrs)
    if eligible == 0:
        return 0.0, 0.0, 0
    return float(np.mean(mrrs)), float(np.mean(ndcgs)), eligible

# Compare dense-only ranking against dense + cross-encoder reranking at one candidate depth.
K_retrieve = 50
dense_metrics = eval_ranking(queries, query_to_relevant_set, K_retrieve=K_retrieve, do_rerank=False)
rerank_metrics = eval_ranking(queries, query_to_relevant_set, K_retrieve=K_retrieve, do_rerank=True)
mrr_dense, ndcg_dense, n = dense_metrics
mrr_rerank, ndcg_rerank, _ = rerank_metrics

print(f"Evaluated on {n} queries, retrieve K={K_retrieve}")
print(f"Dense only   -> MRR@10: {mrr_dense:.4f}, NDCG@10: {ndcg_dense:.4f}")
print(f"+ Reranker   -> MRR@10: {mrr_rerank:.4f}, NDCG@10: {ndcg_rerank:.4f}")


In [None]:
# How the candidate depth K affects reranking quality. Cost grows with K:
# rerank() builds one (query, passage) pair per candidate.
for K in (10, 20, 50, 100):
    dense_metrics = eval_ranking(queries, query_to_relevant_set, K_retrieve=K, do_rerank=False)
    rerank_metrics = eval_ranking(queries, query_to_relevant_set, K_retrieve=K, do_rerank=True)
    mrr_dense, ndcg_dense, n = dense_metrics
    mrr_rerank, ndcg_rerank, _ = rerank_metrics
    print(f"K={K:>3} | Dense MRR@10 {mrr_dense:.4f} NDCG@10 {ndcg_dense:.4f} "
          f"|| Rerank MRR@10 {mrr_rerank:.4f} NDCG@10 {ndcg_rerank:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Plot MRR@10 / NDCG@10 against the candidate depth K.
# The values are recomputed here instead of being pasted in from an earlier
# run, so the figures always reflect the current corpus, models and metric
# definitions. NOTE: this repeats the evaluation of the previous K-sweep cell;
# the cross-encoder calls are the expensive part.
Ks = [10, 20, 50, 100]

dense_mrr, rerank_mrr = [], []
dense_ndcg, rerank_ndcg = [], []
for K in Ks:
    mrr_d, ndcg_d, _ = eval_ranking(queries, query_to_relevant_set, K_retrieve=K, do_rerank=False)
    mrr_r, ndcg_r, _ = eval_ranking(queries, query_to_relevant_set, K_retrieve=K, do_rerank=True)
    dense_mrr.append(mrr_d)
    dense_ndcg.append(ndcg_d)
    rerank_mrr.append(mrr_r)
    rerank_ndcg.append(ndcg_r)


def _plot_k_sweep(dense_vals, rerank_vals, metric):
    """Line plot of one metric vs K for the dense-only and reranked pipelines."""
    fig, ax = plt.subplots()
    ax.plot(Ks, dense_vals, marker="o", label="Dense only")
    ax.plot(Ks, rerank_vals, marker="o", label="Dense + Cross-Encoder")
    ax.set_xlabel("K (Candidates Retrieved)")
    ax.set_ylabel(metric)
    ax.set_title(f"{metric} vs Candidate Set Size (K)")
    ax.legend()
    ax.grid(True)
    plt.show()


_plot_k_sweep(dense_mrr, rerank_mrr, "MRR@10")
_plot_k_sweep(dense_ndcg, rerank_ndcg, "NDCG@10")
