In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import load_dataset

In [3]:
msmarco = load_dataset("ms_marco", "v1.1")

### Collection

In [4]:
import tqdm


all_passages = set()
for split, ds in msmarco.items():
    for row in tqdm.tqdm(ds["passages"]):
        for passage in row["passage_text"]:
            all_passages.add(passage)

all_passages = sorted(all_passages)
passage_map = {txt: i for i, txt in enumerate(all_passages)}

100%|██████████| 10047/10047 [00:00<00:00, 311883.75it/s]
100%|██████████| 82326/82326 [00:00<00:00, 277122.80it/s]
100%|██████████| 9650/9650 [00:00<00:00, 257175.38it/s]


In [9]:
import pandas as pd

collection = pd.DataFrame(all_passages, columns=["passage"])
collection["pid"] = list(range(len(all_passages)))
collection = collection[["pid", "passage"]]
collection.to_csv("../data/msmarco/collection.tsv", sep="\t", index=False, header=False)

In [6]:
import pandas as pd
collection = pd.read_csv('../data/msmarco/collection.tsv', sep='\t', header=None)

In [7]:
collection

Unnamed: 0,0,1
0,0,! ! ! page 2. Oxygen in Equilibrium Dissolved ...
1,1,! 2! • CORONAL : [+coronal] consonants are pro...
2,2,"! 6, - $. frenchY vegetableY sideY dishesY sid..."
3,3,"! 6, - $. veganY tortillaY wrapsY raw veganY v..."
4,4,! An archetype is a term used to describe univ...
...,...,...
767670,767670,﻿﻿A.I.M.S. APPALACHIAN INVESTIGATORS OF MYSTER...
767671,767671,﻿﻿​​​​​​​​​​​​All Texel sheep originate from t...
767672,767672,﻿﻿﻿Thales of Miletus. Thales of Miletus (c. 62...
767673,767673,﻿﻿﻿Thales of Miletus. Thales of Miletus (c. 62...


### Queries

In [10]:
import jsonlines

queries_map = {}
for split, ds in msmarco.items():
    queries = []
    for i, row in enumerate(tqdm.tqdm(ds)):
        queries.append({"qid": i, "question": row["query"]})
        queries_map[row["query_id"]] = i

    print(len(queries))

    if split == "validation":
        split = "val"

    with jsonlines.open(f"../data/msmarco/queries_{split}.json", "w") as fh:
        for row in queries:
            fh.write(row)

  0%|          | 0/10047 [00:00<?, ?it/s]

100%|██████████| 10047/10047 [00:00<00:00, 14579.90it/s]


10047


100%|██████████| 82326/82326 [00:05<00:00, 15017.72it/s]


82326


100%|██████████| 9650/9650 [00:00<00:00, 14610.26it/s]

9650





### Triples

In [16]:
import numpy as np
import tqdm

np.random.seed(1234)


nway = 64

max_nway = 0
for ds in msmarco.values():
    max_nway = max(max_nway, max([len(x["passages"]["passage_text"]) for x in ds]))

nway = max_nway


pids_shuffled = np.random.permutation(len(all_passages))
p_ptr = 0

n_random = 0
for split, ds in msmarco.items():
    triples = []

    for row in tqdm.tqdm(ds):
        passages = row["passages"]
        n_passsages = len(passages["is_selected"])

        positive_pid = None
        negative_pids = []
        for i in range(n_passsages):
            if passages["is_selected"]:
                positive_pid = passage_map.get(passages["passage_text"][i], None)
            else:
                negative_pids.append(passage_map.get(passages["passage_text"][i]))

        if positive_pid is None:
            continue

        while len(negative_pids) < nway:
            n_to_add = nway - len(negative_pids)
            if p_ptr + n_to_add >= len(pids_shuffled):
                pids_shuffled = np.random.permutation(len(all_passages))
                p_ptr = 0
            pids_to_add = pids_shuffled[p_ptr : p_ptr + n_to_add]
            if not positive_pid in pids_to_add:
                negative_pids += pids_to_add.tolist()
            p_ptr += n_to_add
            n_random + n_to_add

        triples.append([queries_map[row["query_id"]], [positive_pid, 1.0], *[[pid, 0.0] for pid in negative_pids]])

    if split == "validation":
        split = "val"

    print(len(triples))

    with jsonlines.open(f"../data/msmarco/triples_{split}.json", "w") as fh:
        for line in triples:
            fh.write(line)

100%|██████████| 10047/10047 [00:00<00:00, 10626.86it/s]


10047


100%|██████████| 82326/82326 [00:08<00:00, 9914.96it/s] 


82326


100%|██████████| 9650/9650 [00:00<00:00, 11201.61it/s]


9650


In [16]:
import jsonlines

triples = []
with jsonlines.open('../data/msmarco/triples_train.json') as fh:
    for line in fh:
        triples.append(line)

queries = []
with jsonlines.open('../data/msmarco/queries_train.json') as fh:
    for line in fh:
        queries.append(line)

In [18]:
max([x[0] for x in triples]), max([x['qid'] for x in queries])

(82325, 82325)

In [21]:
max([max([x[0] for x in t[1:]]) for t in triples]), collection[0].max()

(767674, 767674)