In [7]:
import os
import tarfile

import requests
%load_ext autoreload
%autoreload 2
# Each entry gives the remote URL, the local name the download is saved under,
# and an "expected" file whose presence means the download (and, for archives,
# the extraction) has already been done.
files = [
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz",
        "name": "collection.tar.gz",
        "expected": "collection.tsv",
    },
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz",
        "name": "queries.tar.gz",
        # The notebook reads queries.dev.tsv, so that is the file to look for.
        "expected": "queries.dev.tsv",
    },
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.dev.tsv",
        "name": "qrels.dev.tsv",
        "expected": "qrels.dev.tsv",
    },
]

DOWNLOAD_CHUNK_BYTES = 1 << 20  # stream in 1 MiB pieces instead of buffering the whole body
DOWNLOAD_TIMEOUT_S = 60         # connect/read timeout so a stalled server cannot hang the cell

for file in files:
    if os.path.exists(file["expected"]):
        continue

    # Download to a temporary name first so an interrupted transfer never
    # leaves a truncated file under the final name.
    partial_path = file["name"] + ".part"
    with requests.get(file["url"], stream=True, timeout=DOWNLOAD_TIMEOUT_S) as response:
        # Fail loudly on HTTP errors rather than saving an error page to disk.
        response.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_BYTES):
                f.write(chunk)
    os.replace(partial_path, file["name"])

    if file["name"].endswith('.tar.gz'):
        with tarfile.open(file["name"], 'r:gz') as tar:
            # Use the safe "data" extraction filter where this Python supports
            # it; older interpreters fall back to the unfiltered call.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path='.', filter='data')
            else:
                tar.extractall(path='.')

  tar.extractall(path='.')


In [8]:
import pandas as pd

# Tab-separated passage file loaded into `merged_df` below (columns: pid, text).
merged_queries_csv_path = "common_dataset_80k.tsv"

# Dev queries: column names are supplied here, and both columns are kept as strings.
queries = pd.read_csv(
    "queries.dev.tsv",
    sep="\t",
    names=["qid", "query"],
    dtype={"qid": str, "query": str},
)

# Dev relevance judgments: four columns, the second of which is labelled "_".
qrels = pd.read_csv(
    "qrels.dev.tsv",
    sep="\t",
    names=["qid","_","pid","rel"],
    dtype={"qid":str,"pid":str,"rel":int},
)

# Passage table that gets indexed; ids are kept as strings.
merged_df = pd.read_csv(
    merged_queries_csv_path,
    sep="\t",
    names=["pid", "text"],
    dtype={"pid": str, "text":str},
)

In [9]:
# Sanity check on the passage table: its (rows, columns) shape, then the first rows.
print(merged_df.shape, merged_df.head(), sep="\n")

(80000, 2)
    pid                                               text
0   448  A postal code (also known locally in various E...
1   466  Therefore, all pathologists must have complete...
2   646  Obesity is a complex disorder involving an exc...
3  1212  Which president appointed FBI Director James C...
4  1213  Comey was confirmed by the Senate on July 29, ...


In [10]:
%%time
from index_bm25 import build_bm25
import pandas as pd

# Build the BM25 index over the passage table in `merged_df`.
# NOTE(review): presumably this writes the Whoosh index that the evaluation
# code opens from "indexes/whoosh" — confirm in index_bm25.
build_bm25(merged_df)

Indexing (Whoosh BM25):   3%|▎         | 2200/80000 [00:00<00:06, 11596.84it/s]

Indexing (Whoosh BM25): 100%|██████████| 80000/80000 [00:09<00:00, 8481.31it/s] 


CPU times: user 41.8 s, sys: 2.87 s, total: 44.7 s
Wall time: 1min


In [11]:
%%time
import pandas as pd

# Relevance judgments used for evaluation; every column is read as a string so
# ids are not coerced to numbers.
qrels_for_eval = pd.read_csv("qrels_for_eval.tsv", sep="\t", dtype=str)

# Query sample used for evaluation, also read entirely as strings.
sampled_queries = pd.read_csv("sampled_queries_1k.tsv", sep="\t", dtype=str)
sampled_queries

CPU times: user 50.3 ms, sys: 9.03 ms, total: 59.3 ms
Wall time: 58.2 ms


Unnamed: 0,qid,query
0,507646,symptoms of flu a & b in children
1,915913,what types of volcanoes
2,460162,mri what to expect
3,1098570,how long did barack obama be president
4,841551,what is the proper name for a cartilage cell?
...,...,...
995,418423,is monarch bank owned by townebank?
996,993480,which president was the first to have air forc...
997,1037744,crestor common symptoms
998,986936,who was carlomagno


In [12]:
from tqdm.auto import tqdm
from whoosh import index
from whoosh.qparser import QueryParser, OrGroup
from whoosh.scoring import BM25F
from pre_process import tokenize
from eval_metrics import ndcg_at_k, average_precision_at_k, recall_at_k

# Directory of the Whoosh index that evaluate_bm25 opens.
IDX_DIR = "indexes/whoosh"
# BM25 parameters handed to Whoosh's BM25F weighting as k1 and b.
K1, B = 1.2, 0.75

def evaluate_bm25(queries_df,
                  qrels_df,
                  topk_run=1000,
                  k_ndcg=10,
                  k_map=10,
                  k_rec=100):
    """Score BM25 rankings from the Whoosh index against relevance judgments.

    Every query is tokenized, parsed against the "text" field with OrGroup,
    and searched with BM25F(k1=K1, b=B) on the index stored in IDX_DIR. The
    ranked pids are then scored per query and the scores averaged.

    Parameters
    ----------
    queries_df : DataFrame with columns ['qid', 'query'].
    qrels_df : DataFrame with columns ['qid', 'pid', 'rel']; only rows with
        rel > 0 count as relevant.
    topk_run : maximum number of hits requested from the searcher per query.
    k_ndcg, k_map, k_rec : cutoffs passed to nDCG, average precision and recall.

    Returns
    -------
    dict with the mean of each metric (keyed 'ndcg@K', 'map@K', 'recall@K';
    0.0 when no query was scored) and 'num_queries', the row count of
    queries_df.
    """
    def _mean(values):
        # Average of a list of scores, or 0.0 for an empty list.
        return float(sum(values)/len(values)) if values else 0.0

    # Normalise dtypes, then collect the positively judged pids of each query.
    typed_qrels = qrels_df.astype({"qid":str,"pid":str,"rel":int})
    positives_by_qid = {}
    for group_qid, judgments in typed_qrels.groupby("qid", sort=False):
        positives = {}
        for pid, rel in zip(judgments["pid"], judgments["rel"]):
            if rel > 0:
                positives[pid] = rel
        positives_by_qid[group_qid] = positives

    whoosh_index = index.open_dir(IDX_DIR)
    ndcg_scores, ap_scores, recall_scores = [], [], []
    with whoosh_index.searcher(weighting=BM25F(k1=K1, b=B)) as searcher:
        parser = QueryParser("text", schema=whoosh_index.schema, group=OrGroup)
        query_rows = queries_df[["qid","query"]].itertuples(index=False, name=None)
        progress = tqdm(query_rows, total=len(queries_df), desc="Evaluating", unit="q")

        for qid, query in progress:
            parsed = parser.parse(" ".join(tokenize(query)))
            hits = searcher.search(parsed, limit=topk_run)
            ranked_pids = [hit["pid"] for hit in hits]

            # Queries without any judgment are scored against an empty set.
            graded = positives_by_qid.get(str(qid), {})
            relevant = {pid for pid, rel in graded.items() if rel > 0}

            ndcg_scores.append(ndcg_at_k(ranked_pids, graded, k=k_ndcg))
            ap_scores.append(average_precision_at_k(ranked_pids, relevant, k=k_map))
            recall_scores.append(recall_at_k(ranked_pids, relevant, k=k_rec))

    return {
        f"ndcg@{k_ndcg}": _mean(ndcg_scores),
        f"map@{k_map}": _mean(ap_scores),
        f"recall@{k_rec}": _mean(recall_scores),
        "num_queries": queries_df.shape[0]
    }


In [13]:
# Evaluate the sampled queries with one shared cutoff: retrieve the top 10
# hits and compute every metric at that same depth.
eval_cutoff = 10
metrics = evaluate_bm25(
    sampled_queries,
    qrels_for_eval,
    topk_run=eval_cutoff,
    k_ndcg=eval_cutoff,
    k_map=eval_cutoff,
    k_rec=eval_cutoff,
)
print('metrics:', metrics)

Evaluating: 100%|██████████| 1000/1000 [00:13<00:00, 73.57q/s]

metrics: {'ndcg@10': 0.702554385721152, 'map@10': 0.6631271825396825, 'recall@10': 0.8185, 'num_queries': 1000}



