In [1]:
import os
import requests
from tqdm import tqdm
import gzip, shutil
import tarfile, io
%load_ext autoreload
%autoreload 2

files = [
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz",
        "name": "qidpidtriples.train.full.2.tsv.gz",
        "decompressed": "qidpidtriples.train.full.2.tsv"
    }
]

for file in files:
    if os.path.exists(file["decompressed"]) or os.path.exists(file["name"]):
        print(f"File {file['name']} already exists.")
    else:
        print("Downloading")
        response = requests.get(file["url"])
        with open(file["name"], "wb") as f:
            f.write(response.content)

    decompressed = file["decompressed"]
    filename = file["name"]
    if not os.path.exists(decompressed) and filename.endswith('.tar.gz'):
        with tarfile.open(file["name"], 'r:gz') as tar:
            print("Decompressing tar.gz file...")
            tar.extractall(path='.')    
    elif not os.path.exists(decompressed) and filename.endswith(".gz"):
        with gzip.open(file["name"], "rb") as src, open(decompressed, "wb") as dst:
            print("Decompressing .gz file...")
            shutil.copyfileobj(src, dst)
            print(f"Decompressed to {file['decompressed']}")
    else:
        print(f"File {decompressed} already exists.")

File qidpidtriples.train.full.2.tsv.gz already exists.
File qidpidtriples.train.full.2.tsv already exists.


In [15]:
import pandas as pd
import os
qrels_dev = "qrels.dev.tsv"
collection_tsv = "collection.tsv"

# Read qrels (judged passages)
qrels = pd.read_csv(
    qrels_dev,
    sep="\t",
    names=["qid", "_", "pid", "rel"],
    dtype={"qid": str, "pid": str, "rel": int}
)

judged = set(qrels["pid"])
total_target_size = 80_000

out_tsv = f"common_dataset_{int(total_target_size/1000)}k.tsv"

# total size of out_tsv will be greater than this as you will see in below cell
first_target_size = len(judged)

if not os.path.exists(out_tsv):
    written = set()
    limit = min(first_target_size, len(judged))  # only take up to target_size judged passages
    with open(out_tsv, "w", encoding="utf-8") as out:
        for chunk in pd.read_csv(
            collection_tsv,
            sep="\t",
            names=["pid", "text"],
            dtype={"pid": str, "text": str},
            chunksize=1_000_000,
            quoting=3,
            on_bad_lines="skip"
        ):
            if len(written) >= limit:
                break
            # Filter judged not yet written
            remaining = limit - len(written)
            keep = chunk[chunk["pid"].isin(judged - written)]
            if keep.empty:
                continue
            if len(keep) > remaining:
                keep = keep.iloc[:remaining]
            keep.to_csv(out, sep="\t", header=False, index=False)
            written.update(keep["pid"])


In [None]:
common_dataset = pd.read_csv(out_tsv, sep="\t", names=["pid", "text"], dtype={"pid": str, "text": str})
print(common_dataset.shape)

(80000, 2)


In [12]:
seed=42
# add random distractors to reach target_size
need = max(0, total_target_size - len(written))
if need > 0:
        for chunk in pd.read_csv(collection_tsv, sep="\t",
                                 names=["pid","text"], dtype={"pid":str,"text":str},
                                 chunksize=1_000_000, quoting=3, on_bad_lines="skip"):
            cand = chunk[~chunk["pid"].isin(written)]
            if len(cand) == 0: 
                continue
            take = min(need, len(cand))
            samp = cand.sample(n=take, random_state=seed)
            with open(out_tsv, "a", encoding="utf-8") as out:
                samp.to_csv(out, sep="\t", header=False, index=False)
            written.update(samp["pid"])
            need -= take
            if need == 0:
                break
    

In [7]:
common_dataset = pd.read_csv("common_dataset_80k.tsv", sep="\t", names=["pid", "text"], dtype={"pid": str, "text": str})
common_dataset.head()

Unnamed: 0,pid,text
0,448,A postal code (also known locally in various E...
1,466,"Therefore, all pathologists must have complete..."
2,646,Obesity is a complex disorder involving an exc...
3,1212,Which president appointed FBI Director James C...
4,1213,"Comey was confirmed by the Senate on July 29, ..."


In [3]:
DUPLICATE_INDEX = 3
K_SAMPLE = 20_000
triples_top_k_filename = f"qidpidtriples.top{DUPLICATE_INDEX}.tsv"

In [22]:
# reduce qidpidtriples to only keep up to 3 (qid, pid_pos, pid_neg)
src = "qidpidtriples.train.full.2.tsv"
from collections import defaultdict
import os
import pandas as pd
common_dataset = pd.read_csv("common_dataset_80k.tsv", sep="\t", names=["pid", "text"], dtype={"pid": str, "text": str})
queries = pd.read_csv("queries.train.tsv", sep="\t", header=None, names=["qid", "query"])

if not os.path.exists(triples_top_k_filename):
    existing_pids = set(common_dataset["pid"])  # pids to exclude
    counts = defaultdict(int)  # (qid, pid_pos) -> seen count
    selected_pids = set()      # unique pids we keep (pos or neg)
    existing_queries = set(queries["qid"])

    with open(src, "r", encoding="utf-8", newline="") as fin, \
         open(triples_top_k_filename, "w", encoding="utf-8", newline="") as fout:
        for line in fin:
            if len(selected_pids) >= K_SAMPLE:
                break

            parts = line.rstrip("\r\n").split("\t")
            if len(parts) != 3:
                continue
            qid_s, pid_pos_s, pid_neg_s = parts

            # Skip if either pid already in common_dataset
            if pid_pos_s in existing_pids or pid_neg_s in existing_pids or int(qid_s) not in existing_queries:
                continue

            key = (int(qid_s), int(pid_pos_s))
            if counts[key] >= DUPLICATE_INDEX:
                continue

            # How many new unique pids would this line add?
            new_unique = (pid_pos_s not in selected_pids) + (pid_neg_s not in selected_pids)
            if new_unique == 0:
                # Still can keep (does not grow unique set)
                pass
            elif len(selected_pids) + new_unique > K_SAMPLE:
                # Would exceed cap; stop early
                break

            fout.write(line if line.endswith(("\n", "\r\n")) else line + "\n")
            counts[key] += 1
            selected_pids.add(pid_pos_s)

print(len(selected_pids))

19999


In [23]:
qidpidtriples = pd.read_csv(triples_top_k_filename,
                            sep="\t",
                            header=None,
                            encoding="utf-8",
                            names=["qid","pid_pos","pid_neg"],
                            dtype={"qid":"int32","pid_pos":"int32","pid_neg":"int32"})
print(qidpidtriples.shape)
# pid_pos first
pid_pos = (qidpidtriples["pid_pos"].astype(int)
           .dropna()
           .astype(int)
           .drop_duplicates()
           .head(K_SAMPLE))

# pid_neg that weren't already in pid_pos
pid_neg = (qidpidtriples["pid_neg"].astype(int)
                                .dropna()
                                .astype(int)
                                .drop_duplicates()
                                .head(K_SAMPLE * (DUPLICATE_INDEX - 1)))

pids = pid_pos.tolist() + pid_neg.tolist()
pids = set(pids)

# sample the collection from the pid
keep_passage = []
with open("collection.tsv", "r", encoding="utf-8") as f:
    for line in f:
        pid_str, passage = line.rstrip("\n").split("\t", 1)
        pid = int(pid_str)
        if pid in pids:
            keep_passage.append((pid, passage))  # keep as a row
            pids.remove(pid)
            if not pids:
                break

final_passages = pd.DataFrame(keep_passage, columns=["pid", "passage"])
final_passages.shape

(60058, 3)


(59892, 2)

In [None]:
queries = pd.read_csv("queries.train.tsv", sep="\t", header=None, names=["qid", "query"])

# many-to-one merge: multiple triples per qid, single query text per qid; raises if violated
df = qidpidtriples.merge(queries, on="qid", how="left", validate="m:1")
df.head()

Unnamed: 0,qid,pid_pos,pid_neg,query
0,1000094,5399011,4239068,where is whitemarsh island
1,1000094,5399011,271630,where is whitemarsh island
2,1000094,5399011,5534953,where is whitemarsh island
3,1000684,6133670,54955,where is your perineum
4,1000684,6133670,5952792,where is your perineum


In [28]:
# merge pid_pos with final_passages passage
col_pos = final_passages.rename(columns={"pid": "pid_pos", "passage": "passage_pos"})
df = df.merge(col_pos, on="pid_pos", how="left", validate="m:1")

df.head()

Unnamed: 0,qid,pid_pos,pid_neg,query,passage_pos
0,1000094,5399011,4239068,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ..."
1,1000094,5399011,271630,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ..."
2,1000094,5399011,5534953,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ..."
3,1000684,6133670,54955,where is your perineum,That part of the floor of the PELVIS that lies...
4,1000684,6133670,5952792,where is your perineum,That part of the floor of the PELVIS that lies...


In [29]:
# merge pid_neg with final_passages passage
col_neg = final_passages.rename(columns={"pid": "pid_neg", "passage": "passage_neg"})
df = df.merge(col_neg, on="pid_neg", how="left", validate="m:1")
df.head()

Unnamed: 0,qid,pid_pos,pid_neg,query,passage_pos,passage_neg
0,1000094,5399011,4239068,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ...",Pea Island is an island which is part of the O...
1,1000094,5399011,271630,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ...",Underwater Volcano Forms New South Pacific Isl...
2,1000094,5399011,5534953,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ...","Komodo is one of the 17,508 islands that make ..."
3,1000684,6133670,54955,where is your perineum,That part of the floor of the PELVIS that lies...,rule of nines (rÅ«l nÄ«nz) Method used in calc...
4,1000684,6133670,5952792,where is your perineum,That part of the floor of the PELVIS that lies...,This delicate triangle is important during chi...


In [31]:
# compiled datasets
compiled_df = df[["qid", "query", "pid_pos", "passage_pos", "pid_neg", "passage_neg"]].copy()

print(compiled_df.shape)
compiled_df.head(5)

(60058, 6)


Unnamed: 0,qid,query,pid_pos,passage_pos,pid_neg,passage_neg
0,1000094,where is whitemarsh island,5399011,"Whitemarsh Island, Georgia. Whitemarsh Island ...",4239068,Pea Island is an island which is part of the O...
1,1000094,where is whitemarsh island,5399011,"Whitemarsh Island, Georgia. Whitemarsh Island ...",271630,Underwater Volcano Forms New South Pacific Isl...
2,1000094,where is whitemarsh island,5399011,"Whitemarsh Island, Georgia. Whitemarsh Island ...",5534953,"Komodo is one of the 17,508 islands that make ..."
3,1000684,where is your perineum,6133670,That part of the floor of the PELVIS that lies...,54955,rule of nines (rÅ«l nÄ«nz) Method used in calc...
4,1000684,where is your perineum,6133670,That part of the floor of the PELVIS that lies...,5952792,This delicate triangle is important during chi...


In [32]:
out = f"qidpidtriples.top{DUPLICATE_INDEX}.clean.tsv"
compiled_df.to_csv(out,
                           sep="\t",
                           header=False,
                           index=False,
                           encoding="utf-16")
compiled_df.shape

(60058, 6)

## Constructing sampled queries dataset for evaluation

In [2]:
import pandas as pd

# This common_dataset is used for indexing/evaluating DPR/BM25/TF-IDF
merged_queries_csv_path = "common_dataset_80k.tsv"

# Reading queries dev & qrels dev used for eval
queries = pd.read_csv("queries.dev.tsv", sep="\t", names=["qid", "query"], dtype={"qid": str, "query": str})
qrels = pd.read_csv("qrels.dev.tsv", sep="\t", names=["qid","_","pid","rel"], dtype={"qid":str,"pid":str,"rel":int})

merged_df = pd.read_csv(merged_queries_csv_path, sep="\t", names=["pid", "text"], dtype={"pid": str, "text":str})
merged_df

Unnamed: 0,pid,text
0,448,A postal code (also known locally in various E...
1,466,"Therefore, all pathologists must have complete..."
2,646,Obesity is a complex disorder involving an exc...
3,1212,Which president appointed FBI Director James C...
4,1213,"Comey was confirmed by the Senate on July 29, ..."
...,...,...
79995,975161,Kirkby is a name that was brought to England b...
79996,79075,4MOTION is Volkswagenâs permanent four-wheel...
79997,989328,Bucks County International is a premium trucks...
79998,921395,If you're in a higher education undergraduate ...


In [3]:
filtered_qrels = qrels[qrels['pid'].isin(merged_df['pid'])]
queries_eval = (queries[queries['qid'].isin(filtered_qrels['qid'])]
                .drop_duplicates('qid')
                [['qid','query']])

qrels_for_eval = filtered_qrels[['qid','pid','rel']].astype({"qid":str,"pid":str,"rel":int})
qrels_for_eval

Unnamed: 0,qid,pid,rel
0,1102432,2026790,1
1,1102431,7066866,1
2,1102431,7066867,1
3,1090282,7066900,1
4,39449,7066905,1
...,...,...,...
59268,150337,8009410,1
59269,22241,8009429,1
59270,129177,8009442,1
59271,190655,3576091,1


In [4]:

multi_pid_qids = qrels.groupby('qid')['pid'].nunique()
multi_pid_qids = multi_pid_qids[multi_pid_qids > 1].sort_values(ascending=False)
print(f"Number of qids with multiple pids: {len(multi_pid_qids)}")

size = 1000
selected_qids = multi_pid_qids.head(size).index
sampled_queries = queries_eval[queries_eval['qid'].isin(selected_qids)]

sampled_queries.to_csv(f"sampled_queries_{int(size/1000)}k.tsv", sep="\t", index=False)
qrels_for_eval.to_csv(f"qrels_for_eval.tsv", sep="\t", index=False)
sampled_queries

Number of qids with multiple pids: 3136


Unnamed: 0,qid,query
18,524318,treating diabetes
30,524341,treatment and prognosis for lung cancer
87,524447,treatment of varicose veins in legs
150,524569,trend what does it means
179,362,. what are the president's main duties? explai...
...,...,...
100584,576078,what are traditional new year's eve meals
100623,523224,toxic lithium effects
100637,523248,toxoplasmosis symptoms
100821,523645,transaction finance definition
