# Building a small dataset 

- Download the triplets, queries, and passages from the MS MARCO ranking dataset:
    https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz
    https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz
    https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz

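- Extract the downloaded archives into the working directory. The cells below read `./qidpidtriples.train.full.2.tsv`, `./passages.train.tsv`, and `./queries.train.tsv`.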

In [90]:
from functools import reduce
import numpy as np

# Sampling configuration.
NUM_QUERIES = 5000
NUM_TRIPLETS_PER_QUERY = 20
# The original (misspelled) name is kept as an alias so existing references keep working.
NUM_TRIPLETS_PER_QUERIY = NUM_TRIPLETS_PER_QUERY

# qid -> number of triplets kept so far for that query.
qids = {}
# Every passage id (columns 2 and 3 of a kept triplet line).
passage_id_set = set()
# Kept lines in file order; joined once at the end instead of growing a string per line.
selected_lines = []

with open("./qidpidtriples.train.full.2.tsv") as f:
    # Iterating the file ends cleanly at EOF; if the file runs out before the
    # stop condition is met, whatever was collected so far is still written.
    for line in f:
        if not line.strip():
            # Blank lines carry no triplet.
            continue

        triplet = line.strip().split("\t")
        qid = int(triplet[0])

        if qid not in qids:
            qids[qid] = 0
            # Stop at the first unseen query once strictly more than NUM_QUERIES
            # queries have reached NUM_TRIPLETS_PER_QUERY triplets. The cheap
            # length check short-circuits the full scan until it can matter.
            if len(qids) > NUM_QUERIES and sum(
                count >= NUM_TRIPLETS_PER_QUERY for count in qids.values()
            ) > NUM_QUERIES:
                del qids[qid]
                break
        elif qids[qid] >= NUM_TRIPLETS_PER_QUERY:
            # This query already has its quota of triplets.
            continue

        qids[qid] += 1
        passage_id_set.add(int(triplet[1]))
        passage_id_set.add(int(triplet[2]))
        selected_lines.append(line)

output_string = "".join(selected_lines)

print(f"Total number of queries: {len(qids)}")
print(f"Total number of passages: {len(passage_id_set)}")

with open("./diverse.triplets.all.tsv", "w") as f:
    f.write(output_string)

Total number of queries: 5057
Total number of passages: 102765


## Generate query and passage mapping

In [44]:
# passage id -> passage text, restricted to the ids in `passage_id_set`.
passage_mapping = {}
# Working copy of the wanted ids; ids are removed as they are found so the
# scan can stop as soon as every wanted passage has been collected.
passage_id_set_clone = passage_id_set.copy()

with open("./passages.train.tsv") as f:
    for line in f:
        # Each line is expected to be "<id>\t<text>".
        passage_id, text = line.strip().split("\t")
        passage_id = int(passage_id)
        if passage_id in passage_id_set_clone:
            passage_mapping[passage_id] = text
            passage_id_set_clone.remove(passage_id)
        if not passage_id_set_clone:
            break

# Write the selected passages back out in the same "<id>\t<text>" layout.
with open('diverse.passages.all.tsv', 'w') as f:
    for passage_id, text in passage_mapping.items():
        f.write(f"{passage_id}\t{text}\n")

print(len(passage_mapping))

102765


In [45]:
# query id -> query text, restricted to the query ids held in `qids`.
query_mapping = {}
# set(qids) already builds a fresh set (no extra copy needed) and works
# whether `qids` is the qid -> count dict or a plain set of ids.
query_id_set_clone = set(qids)

with open("./queries.train.tsv") as f:
    for line in f:
        # Each line is expected to be "<id>\t<text>".
        query_id, text = line.strip().split("\t")
        query_id = int(query_id)
        if query_id in query_id_set_clone:
            query_mapping[query_id] = text
            query_id_set_clone.remove(query_id)
        # Stop scanning once every wanted query has been found.
        if not query_id_set_clone:
            break

# Write the selected queries back out in the same "<id>\t<text>" layout.
with open('diverse.queries.all.tsv', 'w') as f:
    for query_id, text in query_mapping.items():
        f.write(f"{query_id}\t{text}\n")

print(len(query_mapping))

5057


## Split the small dataset into train/dev/test

In [76]:
import math

# Query-count thresholds for the train / dev / test boundaries (70% / 90%).
num_queries = len(qids)
train_split = math.ceil(num_queries * 0.7)
dev_split = math.ceil(num_queries * 0.9)

train_line_split = None
dev_line_split = None
# A separate name is used for the running set so the global `qids` is not
# rebound to a different type by this cell.
seen_qids = set()
all_lines = []
line_num = 0

with open("./diverse.triplets.all.tsv") as f:
    for line in f:
        qid = int(line.strip().split("\t")[0])
        line_num += 1
        all_lines.append(line)

        seen_qids.add(qid)
        # Each boundary is reassigned on every line read while the distinct-query
        # count sits at the threshold, so it ends up at the LAST such line: the
        # first `train_line_split` lines contain train_split + 1 distinct queries.
        # NOTE(review): a clean per-query split presumably relies on each query's
        # triplets being contiguous in the file -- confirm.
        if len(seen_qids) == train_split + 1:
            train_line_split = line_num
        elif len(seen_qids) == dev_split + 1:
            dev_line_split = line_num

if train_line_split is None or dev_line_split is None:
    raise ValueError(
        f"Could not locate split boundaries: found {len(seen_qids)} distinct queries, "
        f"need at least {dev_split + 1} to split at {train_split + 1} / {dev_split + 1}."
    )

# Pure-Python equivalent of the head/tail pipeline: the first `train_line_split`
# lines go to train, the lines up to `dev_line_split` to dev, the remainder to test.
with open("diverse.triplets.train.tsv", "w") as f:
    f.writelines(all_lines[:train_line_split])
with open("diverse.triplets.dev.tsv", "w") as f:
    f.writelines(all_lines[train_line_split:dev_line_split])
with open("diverse.triplets.test.tsv", "w") as f:
    f.writelines(all_lines[dev_line_split:])

In [2]:
def _read_split_ids(path):
    """Collect the ids referenced by one triplet file.

    Each line is split on tabs; column 1 is added to the query-id set and
    columns 2 and 3 to the passage-id set.

    Args:
        path: Path of the tab-separated triplet file to read.

    Returns:
        A ``(query_ids, passage_ids)`` tuple of ``set[int]``.
    """
    split_qids = set()
    split_pids = set()
    with open(path) as f:
        for line in f:
            triplet = line.strip().split("\t")
            split_qids.add(int(triplet[0]))
            split_pids.add(int(triplet[1]))
            split_pids.add(int(triplet[2]))
    return split_qids, split_pids


train_qids, train_pids = _read_split_ids("./diverse.triplets.train.tsv")
dev_qids, dev_pids = _read_split_ids("./diverse.triplets.dev.tsv")
test_qids, test_pids = _read_split_ids("./diverse.triplets.test.tsv")

print("Train set queries:", len(train_qids))
print("Train set passages:", len(train_pids))
print("Dev set queries:", len(dev_qids))
print("Dev set passages:", len(dev_pids))
print("Test set queries:", len(test_qids))
print("Test set passages:", len(test_pids))
print()

# Each intersection should print as an empty set if no query appears in two splits.
print("Overlapping queries between train and dev:", train_qids & dev_qids)
print("Overlapping queries between train and test:", train_qids & test_qids)
print("Overlapping queries between dev and test:", test_qids & dev_qids)

Train set queries: 3541
Train set passages: 72231
Dev set queries: 1012
Dev set passages: 20871
Test set queries: 504
Test set passages: 10540

Overlapping queries between train and dev: set()
Overlapping queries between train and test: set()
Overlapping queries between dev and test: set()
