In [4]:
import pandas as pd
from sentence_transformers import CrossEncoder
from tqdm import tqdm
import numpy as np

from eval_metrics import average_precision_at_k, ndcg_at_k, recall_at_k

# Files for evaluation
retrieval_file = "dpr_results.tsv"          # qid, retrieved_pid, rank, vector_score
queries_file   = "sampled_queries_1k.tsv"   # qid, query
qrels_file     = "qrels_for_eval.tsv"       # qid, pid, rel
common_dataset_file = "common_dataset_80k.tsv"  # pid, passage text

# Load
dpr_results = pd.read_csv(retrieval_file, sep="\t", dtype=str)
dpr_results["rank"] = dpr_results["rank"].astype(int)
dpr_results["vector_score"] = dpr_results["vector_score"].astype(float)

queries_df   = pd.read_csv(queries_file, sep="\t", dtype=str)
qrels_df     = pd.read_csv(qrels_file, sep="\t", dtype=str)
common_dataset   = pd.read_csv(common_dataset_file, sep="\t", names=["pid","passage"], dtype={"pid":str,"passage":str})

# Build relevance lookup: qid -> set of positive pids
rel_sets = qrels_df.groupby("qid")["pid"].apply(set).to_dict()
rel_sets # {'100000': {'782783', '782784'}}

{'1000000': {'7264308'},
 '1000003': {'7264269'},
 '1000004': {'7264266'},
 '1000006': {'7264253'},
 '1000012': {'7942175'},
 '1000016': {'7264249'},
 '1000017': {'7264235'},
 '1000025': {'7882562'},
 '1000030': {'7264233'},
 '1000047': {'7886596'},
 '1000054': {'2472752'},
 '1000061': {'7264199'},
 '100007': {'7388139'},
 '1000076': {'7264160'},
 '1000083': {'2745456'},
 '1000085': {'7264147'},
 '1000086': {'770150'},
 '1000096': {'3074820'},
 '1000097': {'7844847'},
 '1000098': {'7264143'},
 '1000101': {'7264131'},
 '1000102': {'2302724', '7264123'},
 '1000107': {'7725278'},
 '1000117': {'7264099'},
 '100013': {'7711719'},
 '1000139': {'7740884'},
 '1000149': {'7264072'},
 '1000151': {'7887533'},
 '1000164': {'7932950'},
 '1000170': {'7264060'},
 '1000173': {'7264053', '7264056'},
 '1000179': {'7264041'},
 '1000183': {'7264029'},
 '1000185': {'7264001'},
 '100019': {'7379101'},
 '100020': {'7399433'},
 '1000202': {'4526068'},
 '1000210': {'7795731'},
 '1000224': {'7879193'},
 '100023

In [5]:
cand_df = dpr_results.merge(queries_df, on="qid", how="left")

cand_df = cand_df.merge(common_dataset.rename(columns={"pid":"retrieved_pid"}),
                        on="retrieved_pid", how="left")

cand_df["label"] = cand_df.apply(
    lambda r: 1 if r["qid"] in rel_sets and r["retrieved_pid"] in rel_sets[r["qid"]] else 0,
    axis=1
)
print(cand_df.shape)
cand_df.head(20)

(10000, 7)


Unnamed: 0,qid,retrieved_pid,rank,vector_score,query,passage,label
0,507646,7548862,1,0.693434,symptoms of flu a & b in children,A: Symptoms of influenza in children include a...,1
1,507646,7619619,2,0.621538,symptoms of flu a & b in children,Below are the symptoms that some individuals m...,0
2,507646,7828612,3,0.581536,symptoms of flu a & b in children,Flu Symptoms. The most common symptoms of the ...,0
3,507646,7480406,4,0.577805,symptoms of flu a & b in children,Symptoms of TEF in adult patients may include:...,0
4,507646,109900,5,0.565854,symptoms of flu a & b in children,"Influenza, commonly known as the flu, is an in...",0
5,507646,7590755,6,0.56441,symptoms of flu a & b in children,The list of signs and symptoms mentioned in va...,0
6,507646,7649123,7,0.561954,symptoms of flu a & b in children,Other mild childhood illnesses: EBV infection ...,0
7,507646,7480405,8,0.558329,symptoms of flu a & b in children,Symptoms of TEF in infants are generally worse...,0
8,507646,964216,9,0.557815,symptoms of flu a & b in children,Signs and symptoms of depression in teens. 1 ...,0
9,507646,7661062,10,0.556867,symptoms of flu a & b in children,About 1 out of 4 people with poliovirus infect...,0


In [None]:
# --- Load (trained) CrossEncoder ---
cross_model = CrossEncoder("./cross-encoder-model")

# --- Calculate Cross-encoder scoring using batching for faster execution ---
pairs = list(zip(cand_df["query"], cand_df["passage"]))
BATCH_SIZE = 32
scores = []
for i in tqdm(range(0, len(pairs), BATCH_SIZE), desc="CrossEncoder scoring"):
    batch_pairs = pairs[i:i+BATCH_SIZE]
    scores.extend(cross_model.predict(batch_pairs))

print("Device:", cross_model.model.device)

cand_df["cross_score"] = scores

# --- Re-rank per query ---
def rerank_group(df):
    return df.sort_values("cross_score", ascending=False).reset_index(drop=True)

reranked = cand_df.groupby("qid", group_keys=False).apply(rerank_group)

reranked.to_csv("dpr_reranked_cross.tsv", sep="\t", index=False)

# --- Metric computation (original DPR vs CrossEncoder) ---
def compute_metrics(df, score_col, top_k=10, recall_k=50):
    ap_list, ndcg_list, recall_list = [], [], []
    for qid, group in df.groupby("qid"):
        # Order by chosen score
        g = group.sort_values(score_col, ascending=False)

        # Example: ['4321', '98', '777', '12005', ...]   
        ranked_pids = g["retrieved_pid"].tolist()

        # Example: {'98', '555', '12005'}
        rel_set = {pid for pid in ranked_pids if pid in rel_sets.get(qid, set())}

        # Need rel_dict for nDCG (binary)
        # Example: {'98': 1, '555': 1, '12005': 1}
        rel_dict = {pid: 1 for pid in rel_set}

        ap_list.append(average_precision_at_k(ranked_pids, rel_set, k=top_k))
        ndcg_list.append(ndcg_at_k(ranked_pids, rel_dict, k=top_k))
        recall_list.append(recall_at_k(ranked_pids, rel_set, k=recall_k))

    return {
        f"MAP@{top_k}": float(np.mean(ap_list)) if ap_list else 0.0,
        f"nDCG@{top_k}": float(np.mean(ndcg_list)) if ndcg_list else 0.0,
        f"Recall@{recall_k}": float(np.mean(recall_list)) if recall_list else 0.0,
        "num_queries": df["qid"].nunique()
    }

orig_metrics  = compute_metrics(cand_df, "vector_score", top_k=10, recall_k=10)
cross_metrics = compute_metrics(reranked, "cross_score", top_k=10, recall_k=10)

print("Original DPR ranking:", orig_metrics)
print("Cross-encoder reranked:", cross_metrics)

CrossEncoder scoring:   0%|          | 0/313 [00:00<?, ?it/s]

CrossEncoder scoring: 100%|██████████| 313/313 [00:14<00:00, 21.14it/s]


Device: cuda:0


  reranked = cand_df.groupby("qid", group_keys=False).apply(rerank_group)


Original DPR ranking: {'MAP@10': 0.6203, 'nDCG@10': 0.6654583165609399, 'Recall@10': 0.804, 'num_queries': 1000}
Cross-encoder reranked: {'MAP@10': 0.739502380952381, 'nDCG@10': 0.756151911859901, 'Recall@10': 0.804, 'num_queries': 1000}
