In [66]:
from datasets import load_dataset, load_dataset_builder
import random as random

In [67]:
#load the full 800k-example dataset (takes ~5 min the first time, then caches automatically)
ds = load_dataset("microsoft/ms_marco", 'v2.1', split="train")


In [72]:
#take a deterministic random sample of 100k
random.seed(42)
size = len(ds)
print(f"Total size of dataset: {size}")
sample_size = 110000
sample_indices = random.sample(range(size), sample_size)
ds_sampled = ds.select(sample_indices)
print(f"Sampled dataset size: {len(ds_sampled)}")

Total size of dataset: 808731
Sampled dataset size: 110000


In [73]:
#merge answers with wellFormedAnswers, and drop other columns.
#use [0] to get the first answer because sometimes there are multiple.
def simplify(example):
    if example['wellFormedAnswers'] and len(example['wellFormedAnswers']) > 0:
        answer = example['wellFormedAnswers'][0]
    else:
        answer = example['answers'][0]
    return {
        'query': example['query'],
        'answer': answer
    }
ds_simplified = ds_sampled.map(simplify, remove_columns=ds_sampled.column_names)
print(f"Verifying that the size didn't change: {len(ds_simplified)}")


Map: 100%|██████████| 110000/110000 [00:03<00:00, 29124.23 examples/s]

Verifying that the size didn't change: 110000





In [74]:
#filter for answers with 2-4 sentences
def two_to_four_sentences(example):
    period_count = example['answer'].count('.')
    return period_count >= 2 and period_count <= 4

ds_length_filtered = ds_simplified.filter(two_to_four_sentences)
print(f"Dataset size before filtering: {len(ds_simplified)}")
print(f"Dataset size after filtering: {len(ds_length_filtered)}")

Filter: 100%|██████████| 110000/110000 [00:00<00:00, 741044.30 examples/s]

Dataset size before filtering: 110000
Dataset size after filtering: 7510





In [75]:
#filter out questions that don't start with a question word
def is_complete_question(example):
    query = example['query'].lower().strip()
    question_starters = ['what', 'when', 'where', 'who', 'whom', 'whose', 'which', 
                        'why', 'how', 'is', 'are', 'was', 'were', 'do', 'does', 
                        'did', 'can', 'could', 'will', 'would', 'should', 'may', 'might']
    word_count = len(query.split())

    starts_with_question = any(query.startswith(word + ' ') for word in question_starters)
    has_enough_words = word_count >= 4
    return starts_with_question and has_enough_words

ds_q_filtered = ds_length_filtered.filter(is_complete_question)
print(f"Dataset size before filtering: {len(ds_length_filtered)}")
print(f"Dataset size after filtering: {len(ds_q_filtered)}")

Filter: 100%|██████████| 7510/7510 [00:00<00:00, 164682.01 examples/s]

Dataset size before filtering: 7510
Dataset size after filtering: 4952



