In [1]:
import os
import tarfile

import requests
%load_ext autoreload
%autoreload 2
# MS MARCO passage-ranking artifacts needed by this notebook.
#   url    - where the artifact is downloaded from
#   name   - local filename the download is saved under
#   marker - file whose presence means the artifact is already available
#            locally (for archives: a file that extraction produces)
files = [
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz",
        "name": "collection.tar.gz",
        "marker": "collection.tsv"
    },
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz",
        "name": "queries.tar.gz",
        # The archive yields queries.dev.tsv (read later in this notebook);
        # there is no "queries.tsv", so the name-derived guess never matched
        # and the archive was downloaded again on every run.
        "marker": "queries.dev.tsv"
    },
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.dev.tsv",
        "name": "qrels.dev.tsv",
        "marker": "qrels.dev.tsv"
    }
]

for file in files:
    # "marker" is optional: entries without it fall back to the original
    # rule of swapping the ".tar.gz" suffix for ".tsv".
    marker = file.get("marker", file["name"].replace('.tar.gz', '.tsv'))
    if os.path.exists(marker):
        continue

    # Download to a temporary name and rename only on success, so an
    # interrupted transfer can never be mistaken for a finished file.
    partial_path = file["name"] + ".part"
    # stream=True writes the body to disk chunk by chunk instead of holding
    # the whole (potentially very large) archive in memory; the timeout
    # keeps a stalled connection from hanging the cell forever.
    with requests.get(file["url"], stream=True, timeout=60) as response:
        # Fail loudly on 4xx/5xx instead of saving an error page as data.
        response.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                if chunk:
                    f.write(chunk)
    os.replace(partial_path, file["name"])

    if file["name"].endswith('.tar.gz'):
        with tarfile.open(file["name"], 'r:gz') as tar:
            if hasattr(tarfile, 'data_filter'):
                # The 'data' filter rejects members that would escape the
                # target directory and silences the DeprecationWarning that
                # unfiltered extraction raises on recent Python versions.
                tar.extractall(path='.', filter='data')
            else:
                # Older interpreters have no extraction filters.
                tar.extractall(path='.')

  tar.extractall(path='.')


In [1]:
import pandas as pd

# All inputs are header-less, tab-separated files, so column names are given
# explicitly. Identifier columns are read as strings so that later isin/merge
# comparisons between frames are string-to-string.
merged_queries_csv_path = "common_dataset.tsv"

# Dev queries: one (qid, query) pair per line.
queries = pd.read_table(
    "queries.dev.tsv",
    names=["qid", "query"],
    dtype={"qid": str, "query": str},
)

# Relevance judgements; the second column is named "_" and is not used
# anywhere else in this notebook.
qrels = pd.read_table(
    "qrels.dev.tsv",
    names=["qid", "_", "pid", "rel"],
    dtype={"qid": str, "pid": str, "rel": int},
)

# Passages to index: one (pid, text) pair per line.
# NOTE(review): common_dataset.tsv is not created by this notebook; confirm
# where it comes from.
merged_df = pd.read_table(
    merged_queries_csv_path,
    names=["pid", "text"],
    dtype={"pid": str, "text": str},
)

In [2]:
# Quick sanity check of the loaded passages: the (rows, columns) tuple on the
# first line, followed by the first five rows.
print(merged_df.shape, merged_df.head(), sep="\n")

(30000, 2)
    pid                                               text
0   448  A postal code (also known locally in various E...
1   466  Therefore, all pathologists must have complete...
2   646  Obesity is a complex disorder involving an exc...
3  1212  Which president appointed FBI Director James C...
4  1213  Comey was confirmed by the Senate on July 29, ...


In [3]:
%%time
# build_bm25 comes from the project-local index_bm25 module (not part of this
# notebook).
from index_bm25 import build_bm25
import pandas as pd

# Index the passages held in merged_df (columns: pid, text).
# NOTE(review): the progress bar labels this a Whoosh BM25 index; where the
# index is stored and whether an existing one is overwritten is decided inside
# build_bm25 - confirm before re-running this cell.
build_bm25(merged_df)

Indexing (Whoosh BM25): 100%|██████████| 30000/30000 [00:03<00:00, 9811.02it/s] 


CPU times: user 15.8 s, sys: 1.33 s, total: 17.1 s
Wall time: 21.2 s


In [None]:
%%time
from eval_metrics import evaluate_bm25
from load_corpus import read_queries_dev
import pandas as pd

# Keep only the judgements whose passage is present in the indexed subset;
# judgements pointing at passages outside merged_df are dropped.
filtered_qrels = qrels[qrels['pid'].isin(merged_df['pid'])]

# Queries that still have at least one judged passage in the subset,
# one row per qid.
queries_eval = (queries[queries['qid'].isin(filtered_qrels['qid'])]
                .drop_duplicates('qid')
                [['qid','query']])
print('queries_eval shape:', queries_eval.shape)

qrels_for_eval = filtered_qrels[['qid','pid','rel']].astype({"qid":str,"pid":str,"rel":int})

# Evaluate on a fixed-seed sample of at most 1000 queries. The size is clamped
# because DataFrame.sample raises ValueError when n exceeds the number of rows
# (sampling without replacement), which would happen with a smaller subset.
sampled_queries = queries_eval.sample(n=min(1000, len(queries_eval)), random_state=42)

# The keyword arguments are passed straight through to evaluate_bm25; the
# reported metric names (ndcg@10, map@10, recall@100) mirror k_ndcg, k_map
# and k_rec.
metrics = evaluate_bm25(
    sampled_queries,
    qrels_for_eval,
    topk_run=1000,
    k_ndcg=10,
    k_map=10,
    k_rec=100
)
print('metrics:', metrics)

queries_eval shape: (19229, 2)


Evaluating: 100%|██████████| 1000/1000 [00:10<00:00, 91.70q/s]

metrics: {'ndcg@10': 0.7063510535366078, 'map@10': 0.6668823412698413, 'recall@100': 0.9268333333333334}
CPU times: user 11.5 s, sys: 40.9 ms, total: 11.6 s
Wall time: 11.9 s





In [12]:
# Join each evaluation query to its judged passages (on qid), then attach the
# passage text (on pid); inner joins keep only rows present on both sides.
print(
    pd.merge(
        pd.merge(queries_eval, qrels_for_eval, how="inner", on=["qid"]),
        merged_df,
        how="inner",
        on=["pid"],
    )
)

           qid                           query      pid  rel  \
0      1048578  cost of endless pools/swim spa  7187234    1   
1      1048579                    what is pcnt  7187227    1   
2      1048582                  what is paysky  7187185    1   
3      1048583                 what is paydata  7187177    1   
4      1048585    what is paula deen's brother  7187158    1   
...        ...                             ...      ...  ...   
20122  1048524                what is pehlwani  7187301    1   
20123  1048528        what is pegging in trade  7187294    1   
20124  1048551     what is peer review testing  7187280    1   
20125  1048554                what is peekaboo  7187261    1   
20126  1048570    what is pearls before swine?  7187247    1   

                                                    text  
0      Endless pools and swim spas are available in a...  
1      PCNT stands for. 1  8. PCNT. Pericentrin. Medi...  
2      Why PaySky. PaySky is your one-stop-shop for a.