In [1]:
import os
import requests
from tqdm import tqdm
import gzip, shutil
import tarfile, io
%load_ext autoreload
%autoreload 2

# Archives to fetch. Each entry gives the download URL, the local archive
# name, and the name of the file expected to exist after decompression.
files = [
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz",
        "name": "qidpidtriples.train.full.2.tsv.gz",
        "decompressed": "qidpidtriples.train.full.2.tsv"
    },
]

for file in files:
    filename = file["name"]
    decompressed = file["decompressed"]

    # --- Download (skipped when either the archive or its output is present) ---
    if os.path.exists(decompressed) or os.path.exists(filename):
        print(f"File {filename} already exists.")
    else:
        print("Downloading")
        # Stream to a temporary ".part" file so the archive is never held in
        # memory as a whole, and so an interrupted download cannot be mistaken
        # for a complete archive on the next run.
        partial_download = filename + ".part"
        with requests.get(file["url"], stream=True, timeout=60) as response:
            # Fail loudly instead of saving an HTTP error page as the archive.
            response.raise_for_status()
            total_bytes = int(response.headers.get("content-length", 0)) or None
            with open(partial_download, "wb") as f:
                with tqdm(total=total_bytes, unit="B", unit_scale=True, desc=filename) as bar:
                    for block in response.iter_content(chunk_size=1 << 20):
                        f.write(block)
                        bar.update(len(block))
        os.replace(partial_download, filename)

    # --- Decompress (skipped when the output is already present) ---
    if os.path.exists(decompressed):
        print(f"File {decompressed} already exists.")
    elif filename.endswith('.tar.gz'):
        # NOTE(review): extractall() trusts the member paths stored in the
        # archive; only use it on archives from a trusted source.
        with tarfile.open(filename, 'r:gz') as tar:
            print("Decompressing tar.gz file...")
            tar.extractall(path='.')
    elif filename.endswith(".gz"):
        # Decompress into a ".part" file and rename on success, so a crash
        # mid-way does not leave a truncated file under the final name.
        partial_output = decompressed + ".part"
        with gzip.open(filename, "rb") as src, open(partial_output, "wb") as dst:
            print("Decompressing .gz file...")
            shutil.copyfileobj(src, dst)
        os.replace(partial_output, decompressed)
        print(f"Decompressed to {decompressed}")
    else:
        # Neither a .tar.gz nor a .gz archive: nothing was decompressed.
        print(f"Don't know how to decompress {filename}; {decompressed} was not created.")

File qidpidtriples.train.full.2.tsv.gz already exists.
File qidpidtriples.train.full.2.tsv already exists.


In [7]:
import pandas as pd
import os
qrels_dev = "qrels.dev.tsv"
collection_tsv = "collection.tsv"

# Read qrels (judged passages)
qrels = pd.read_csv(
    qrels_dev,
    sep="\t",
    names=["qid", "_", "pid", "rel"],
    dtype={"qid": str, "pid": str, "rel": int}
)

# Passage ids that appear in the qrels file.
judged = set(qrels["pid"])

out_tsv = "common_dataset.tsv"

# Number of judged passages written by this cell. The final size of out_tsv
# will be greater than this, as you will see in a cell below.
target_size = 20_000

if os.path.exists(out_tsv):
    # The file was built by an earlier run. Rebuild `written` from it so that
    # cells which reference `written` still work after a kernel restart.
    written = set(
        pd.read_csv(
            out_tsv,
            sep="\t",
            names=["pid", "text"],
            usecols=["pid"],
            dtype={"pid": str},
        )["pid"]
    )
else:
    written = set()
    limit = min(target_size, len(judged))  # only take up to target_size judged passages
    # Write to a temporary file and rename on success: an interrupted run then
    # cannot leave a truncated out_tsv that the exists-check above would accept.
    partial_tsv = out_tsv + ".part"
    # newline="" lets pandas control line endings (avoids "\r\r\n" on Windows).
    with open(partial_tsv, "w", encoding="utf-8", newline="") as out:
        # Context manager closes the underlying collection file even on `break`.
        with pd.read_csv(
            collection_tsv,
            sep="\t",
            names=["pid", "text"],
            dtype={"pid": str, "text": str},
            chunksize=1_000_000,
            quoting=3,  # csv.QUOTE_NONE: treat quote characters as plain text
            on_bad_lines="skip"
        ) as reader:
            for chunk in reader:
                # Judged passages not yet written, capped at what is still needed.
                remaining = limit - len(written)
                keep = chunk[chunk["pid"].isin(judged - written)].iloc[:remaining]
                if not keep.empty:
                    keep.to_csv(out, sep="\t", header=False, index=False)
                    written.update(keep["pid"])
                # Check right after writing so no further chunk is parsed
                # once the limit has been reached.
                if len(written) >= limit:
                    break
    os.replace(partial_tsv, out_tsv)


In [8]:
# Load the TSV referenced by out_tsv back into memory (both columns as
# strings) and report its (rows, columns) shape.
common_dataset = pd.read_csv(
    out_tsv,
    sep="\t",
    names=["pid", "text"],
    dtype={"pid": str, "text": str},
)
print(common_dataset.shape)

(20000, 2)


In [11]:
target_size = 30_000
seed = 42

# Use the file on disk as the source of truth for which passages are already
# present. This keeps the cell runnable after a kernel restart, when no
# in-memory `written` set from an earlier cell exists.
written = set(
    pd.read_csv(
        out_tsv,
        sep="\t",
        names=["pid", "text"],
        usecols=["pid"],
        dtype={"pid": str},
    )["pid"]
)

# add random distractors to reach target_size
need = max(0, target_size - len(written))
if need > 0:
    # newline="" lets pandas control line endings (avoids "\r\r\n" on Windows).
    with open(out_tsv, "a", encoding="utf-8", newline="") as out:
        # Context manager closes the underlying collection file even on `break`.
        with pd.read_csv(
            collection_tsv,
            sep="\t",
            names=["pid", "text"],
            dtype={"pid": str, "text": str},
            chunksize=1_000_000,
            quoting=3,  # csv.QUOTE_NONE: treat quote characters as plain text
            on_bad_lines="skip",
        ) as reader:
            for chunk in reader:
                cand = chunk[~chunk["pid"].isin(written)]
                if cand.empty:
                    continue
                # Sampling happens within the current chunk only, so when
                # `need` is smaller than a chunk all distractors come from
                # the first chunk that has candidates.
                take = min(need, len(cand))
                samp = cand.sample(n=take, random_state=seed)
                samp.to_csv(out, sep="\t", header=False, index=False)
                written.update(samp["pid"])
                need -= take
                if need == 0:
                    break
    

In [12]:
# Re-read out_tsv so common_dataset reflects the rows appended to the file,
# then print the resulting (rows, columns) shape.
common_dataset = pd.read_csv(
    out_tsv,
    dtype={"pid": str, "text": str},
    names=["pid", "text"],
    sep="\t",
)
print(common_dataset.shape)

(30000, 2)


In [None]:
from tqdm.auto import tqdm
# Build merged_full chunk by chunk, with a progress bar: keep every training
# triple whose positive AND negative passage are both in common_dataset, and
# attach the passage texts. No size cap is applied; all matches are kept.

# pid -> text lookup. Series.map() requires a unique index, so duplicate pids
# (if any) are collapsed to their first occurrence.
pid_text = common_dataset.drop_duplicates("pid").set_index("pid")["text"]
known_pids = pid_text.index
filename = "qidpidtriples.train.full.2.tsv"

parts = []
total = 0  # running count of matched triples

# Context manager closes the triples file when the loop ends or fails.
with pd.read_csv(
    filename,
    sep="\t",
    names=["qid", "pos_pid", "neg_pid"],
    dtype={"qid": str, "pos_pid": str, "neg_pid": str},
    chunksize=500_000
) as reader:
    for triples_chunk in tqdm(reader, desc="Reading chunks"):
        mask = triples_chunk["pos_pid"].isin(known_pids) & triples_chunk["neg_pid"].isin(known_pids)
        if not mask.any():
            continue
        sub = triples_chunk.loc[mask].copy()
        sub["pos_text"] = sub["pos_pid"].map(pid_text)
        sub["neg_text"] = sub["neg_pid"].map(pid_text)
        parts.append(sub)
        total += len(sub)

if parts:
    merged_full = pd.concat(parts, ignore_index=True)
else:
    # pd.concat raises on an empty list; produce an empty frame with the
    # expected columns instead so downstream cells still get a DataFrame.
    merged_full = pd.DataFrame(
        columns=["qid", "pos_pid", "neg_pid", "pos_text", "neg_text"]
    )

print(merged_full.shape)
print(merged_full.head())

In [None]:
# Persist the merged triples as CSV. index=True is the pandas default, spelled
# out here: the row index is written as the first, unnamed column.
merged_full.to_csv("merged_full.csv", index=True)

In [14]:
# Passage length statistics. A "token" here is one match of the regex \w+,
# i.e. a simple word count rather than a model tokenizer's count.
# The column is only computed when it is not already present on the frame.
if "token_len" not in common_dataset.columns:
    word_matches = common_dataset["text"].str.findall(r"\w+")
    common_dataset["token_len"] = word_matches.str.len()

# Longest passage, measured in regex tokens.
max_tokens = common_dataset["token_len"].max()
print("Max passage length (tokens):", max_tokens)

# The ten longest passages, restricted to the columns of interest.
report_columns = ["pid", "token_len", "text"]
top10_tokens = common_dataset.nlargest(n=10, columns="token_len")[report_columns]
print(top10_tokens)

Max passage length (tokens): 233
           pid  token_len                                               text
28548   564040        233  The subshell with the quantum numbers n =4, l ...
22311   123946        204  full time undergraduate students that do not r...
9190   7086125        198  Order Online at the #1 Source for Gourmet Food...
13295  7147363        197  The DNA double helix is held together by two t...
21840   357224        196  We use cookies to enhance your experience on o...
15927  7186716        194  An asphalt apron typically costs $2-$5 a squar...
10199  7100389        189  A long-term goal is something you want to do i...
15936  7186847        189  Predetermined overhead rate is estimated overh...
10192  7100293        187  Last week in one of our afkar.me, our incubati...
17703  7214400        186  Butterfly Needles Explained A butterfly needle...
