In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import load_dataset

In [3]:
msmarco = load_dataset("ms_marco", "v1.1")

### Collection

In [4]:
import tqdm

N = 50000

all_passages = set()
for split, ds in msmarco.items():
    for row in tqdm.tqdm(ds["passages"]):
        for passage in row["passage_text"]:
            all_passages.add(passage)
        if len(all_passages) >= N:
            break

all_passages = sorted(all_passages)
passage_map = {txt: i for i, txt in enumerate(all_passages)}

 61%|██████    | 6151/10047 [00:00<00:00, 307862.24it/s]
  0%|          | 0/82326 [00:00<?, ?it/s]
  0%|          | 0/9650 [00:00<?, ?it/s]


In [5]:
import pandas as pd

collection = pd.DataFrame(all_passages, columns=["passage"])
collection["pid"] = list(range(len(all_passages)))
collection = collection[["pid", "passage"]]
collection.to_csv("../data/msmarco/collection.tsv", sep="\t", index=False, header=False)

### Queries

In [6]:
import jsonlines


for split, ds in msmarco.items():
    queries = []
    for row in tqdm.tqdm(ds):
        queries.append({"qid": row["query_id"], "question": row["query"]})

        if len(queries) >= N:
            break

    with jsonlines.open(f"../data/msmarco/queries_{split}.json", "w") as fh:
        for row in queries:
            fh.write(row)

  0%|          | 0/10047 [00:00<?, ?it/s]

100%|██████████| 10047/10047 [00:00<00:00, 14421.82it/s]
 61%|██████    | 49999/82326 [00:03<00:02, 14977.84it/s]
100%|██████████| 9650/9650 [00:00<00:00, 15080.75it/s]


### Triples

In [7]:
import numpy as np
import tqdm

np.random.seed(1234)


nway = 64
pids_shuffled = np.random.permutation(len(all_passages))
p_ptr = 0
for split, ds in msmarco.items():
    triples = []
    if split == "validation":
        split = "val"
    for row in tqdm.tqdm(ds):
        passages = row["passages"]
        is_selected = np.array(passages["is_selected"], dtype=bool)
        if not is_selected.sum():
            continue
        else:
            positive_pid = passage_map.get(passages["passage_text"][np.where(is_selected)[0][0]], None)
            if positive_pid is None:
                continue
        negative_pids = [passage_map.get(passages["passage_text"][i], None) for i in np.where(is_selected == 0)[0]]
        negative_pids = [pid for pid in negative_pids if pid is not None]
        while len(negative_pids) < nway:
            n_to_add = nway - len(negative_pids)
            if p_ptr + n_to_add >= len(pids_shuffled):
                pids_shuffled = np.random.permutation(len(all_passages))
                p_ptr = 0
            pids_to_add = pids_shuffled[p_ptr : p_ptr + n_to_add]
            if not positive_pid in pids_to_add:
                negative_pids += pids_to_add.tolist()
            p_ptr += n_to_add

        triples.append([row["query_id"], [positive_pid, 1.0], *[[pid, 0.0] for pid in negative_pids]])
    with jsonlines.open(f"../data/msmarco/triples_{split}.json", "w") as fh:
        for line in triples:
            fh.write(line)

  0%|          | 0/10047 [00:00<?, ?it/s]

100%|██████████| 10047/10047 [00:01<00:00, 8272.37it/s]
100%|██████████| 82326/82326 [00:06<00:00, 12828.46it/s]
100%|██████████| 9650/9650 [00:00<00:00, 13043.20it/s]
