In [None]:
# --- 1. Load all needed data ---
import pandas as pd
from sentence_transformers import CrossEncoder
from tqdm import tqdm
import numpy as np

from eval_metrics import average_precision_at_k, ndcg_at_k, recall_at_k

# Inputs written by the earlier pipeline stages
retrieval_file = "dpr_results.tsv"          # qid, retrieved_pid, rank, vector_score
queries_file = "sampled_queries_1k.tsv"     # qid, query
qrels_file = "qrels_for_eval.tsv"           # qid, pid, rel
collection_file = "common_dataset_80k.tsv"  # pid, passage text

# Everything is read as text so ids keep their exact spelling; only the two
# numeric retrieval columns are converted afterwards.
dpr_results = pd.read_csv(retrieval_file, sep="\t", dtype=str).astype(
    {"rank": int, "vector_score": float}
)

queries_df = pd.read_csv(queries_file, sep="\t", dtype=str)
qrels_df = pd.read_csv(qrels_file, sep="\t", dtype=str)
# Column names are supplied explicitly for the collection file.
collection = pd.read_csv(
    collection_file,
    sep="\t",
    names=["pid", "passage"],
    dtype={"pid": str, "passage": str},
)

# Relevance lookup: qid -> set of every pid listed for that query in the qrels.
# NOTE(review): the "rel" column is not consulted here, so a row with rel == 0
# would still count as relevant - confirm the qrels file only holds positives.
rel_sets = {qid: set(pids) for qid, pids in qrels_df.groupby("qid")["pid"]}

# --- 2. Merge to attach query text + passage text ---
# Left joins keep every retrieved candidate; a candidate whose query or passage
# id has no match ends up with NaN in the corresponding text column.
cand_df = dpr_results.merge(queries_df, on="qid", how="left")
cand_df = cand_df.merge(
    collection.rename(columns={"pid": "retrieved_pid"}),
    on="retrieved_pid",
    how="left",
)

# Rename for clarity
cand_df = cand_df.rename(columns={"retrieved_pid": "pid", "passage": "passage_text"})

# Drop rows without passage text or query text (should be none if IDs align)
cand_df = cand_df.dropna(subset=["passage_text", "query"])

# --- 3. Add binary relevance label (for metrics) ---
# 1 when the (qid, pid) pair is listed in rel_sets, else 0.
# Built in one pass over the two id columns instead of DataFrame.apply(axis=1):
# apply makes a Python call per row, and on an empty frame it returns an empty
# DataFrame, which cannot be assigned as a column. The explicit dtype keeps the
# column integer-typed even when there are no rows.
cand_df["label"] = np.array(
    [
        int(pid in rel_sets.get(qid, ()))
        for qid, pid in zip(cand_df["qid"], cand_df["pid"])
    ],
    dtype=np.int64,
)

# --- 4. Load (trained) CrossEncoder ---
cross_model = CrossEncoder("./cross-encoder-model")  # or a pretrained name

# --- 5. Cross-encoder scoring (batched) ---
# One (query, passage) tuple per candidate row, in row order.
pairs = list(
    cand_df[["query", "passage_text"]].itertuples(index=False, name=None)
)
BATCH_SIZE = 32
scores = []
for start in tqdm(range(0, len(pairs), BATCH_SIZE), desc="CrossEncoder scoring"):
    stop = start + BATCH_SIZE
    scores += list(cross_model.predict(pairs[start:stop]))

# scores follows the row order of cand_df, so a positional assignment lines up.
cand_df["cross_score"] = scores

# --- 6. Re-rank per query ---
def rerank_group(df):
    """Return a copy of ``df`` ordered by ``cross_score``, best first.

    The result carries a fresh 0..n-1 index; the input frame is not modified.
    """
    return df.sort_values(by="cross_score", ascending=False, ignore_index=True)

# Re-rank the candidates of every query and stitch the groups back together.
# The groups are iterated explicitly rather than passed through GroupBy.apply:
# apply hands the grouping column to the function only under behaviour that
# pandas has deprecated, and without "qid" in the result the per-query metric
# computation on `reranked` would no longer work.
per_query = [rerank_group(group) for _, group in cand_df.groupby("qid")]
# pd.concat rejects an empty list, so fall back to an empty frame that keeps
# the same columns when there is nothing to re-rank.
reranked = pd.concat(per_query) if per_query else cand_df.iloc[0:0].copy()

# (Optional) Save reranked candidates
reranked.to_csv("dpr_reranked_cross.tsv", sep="\t", index=False)

# --- 7. Metric computation (original DPR vs CrossEncoder) ---
def compute_metrics(df, score_col, top_k=10, recall_k=50):
    """Average MAP@k, nDCG@k and Recall@k over the queries found in ``df``.

    Args:
        df: Candidate table with at least the columns ``qid``, ``pid`` and
            ``score_col``; one row per (query, candidate passage).
        score_col: Name of the column to rank by (higher is better).
        top_k: Cut-off passed to the AP and nDCG helpers.
        recall_k: Cut-off passed to the recall helper.

    Returns:
        Dict with the keys ``MAP@{top_k}``, ``nDCG@{top_k}``,
        ``Recall@{recall_k}`` (each the mean over queries, 0.0 when ``df`` has
        no queries) and ``num_queries``.

    Relevance judgements come from the module-level ``rel_sets`` mapping
    (qid -> set of relevant pids). The per-query scores are delegated to the
    helpers imported from ``eval_metrics``.
    """
    ap_list, ndcg_list, recall_list = [], [], []
    for qid, group in df.groupby("qid"):
        # Order by chosen score
        ranked_pids = group.sort_values(score_col, ascending=False)["pid"].tolist()

        # Judge against EVERY relevant pid known for the query, not only the
        # ones that made it into the candidate list. Restricting the set to
        # retrieved pids hides the passages the retriever missed, which shrinks
        # the recall denominator (and any AP / ideal-DCG normalisation) and
        # inflates the scores.
        rel_set = rel_sets.get(qid, set())
        # Need rel_dict for nDCG (binary gains)
        rel_dict = dict.fromkeys(rel_set, 1)

        ap_list.append(average_precision_at_k(ranked_pids, rel_set, k=top_k))
        ndcg_list.append(ndcg_at_k(ranked_pids, rel_dict, k=top_k))
        recall_list.append(recall_at_k(ranked_pids, rel_set, k=recall_k))

    return {
        f"MAP@{top_k}": float(np.mean(ap_list)) if ap_list else 0.0,
        f"nDCG@{top_k}": float(np.mean(ndcg_list)) if ndcg_list else 0.0,
        f"Recall@{recall_k}": float(np.mean(recall_list)) if recall_list else 0.0,
        "num_queries": df["qid"].nunique()
    }

# Both rankings are scored with the same cut-offs so the numbers are comparable.
eval_cutoffs = {"top_k": 10, "recall_k": 10}
orig_metrics = compute_metrics(cand_df, "vector_score", **eval_cutoffs)
cross_metrics = compute_metrics(reranked, "cross_score", **eval_cutoffs)

for title, metrics in (
    ("Original DPR ranking:", orig_metrics),
    ("Cross-encoder reranked:", cross_metrics),
):
    print(title, metrics)

  from .autonotebook import tqdm as notebook_tqdm
CrossEncoder scoring: 100%|██████████| 313/313 [00:17<00:00, 17.66it/s]
  reranked = cand_df.groupby("qid", group_keys=False).apply(rerank_group)


Original DPR ranking: {'MAP@10': 0.5729248412698413, 'nDCG@10': 0.6514923782878013, 'Recall@10': 0.827, 'num_queries': 1000}
Cross-encoder reranked: {'MAP@10': 0.7279261904761904, 'nDCG@10': 0.7588542449935001, 'Recall@10': 0.827, 'num_queries': 1000}
