In [80]:
from datasets import load_dataset, load_dataset_builder
import random as random
import pandas as pd

In [67]:
#load the full 800k-example dataset (takes ~5 min the first time, then caches automatically)
ds = load_dataset("microsoft/ms_marco", 'v2.1', split="train")


In [72]:
#take a deterministic random sample of 100k
random.seed(42)
size = len(ds)
print(f"Total size of dataset: {size}")
sample_size = 110000
sample_indices = random.sample(range(size), sample_size)
ds_sampled = ds.select(sample_indices)
print(f"Sampled dataset size: {len(ds_sampled)}")

Total size of dataset: 808731
Sampled dataset size: 110000


In [73]:
#merge answers with wellFormedAnswers, and drop other columns.
#use [0] to get the first answer because sometimes there are multiple.
def simplify(example):
    if example['wellFormedAnswers'] and len(example['wellFormedAnswers']) > 0:
        answer = example['wellFormedAnswers'][0]
    else:
        answer = example['answers'][0]
    return {
        'query': example['query'],
        'answer': answer
    }
ds_simplified = ds_sampled.map(simplify, remove_columns=ds_sampled.column_names)
print(f"Verifying that the size didn't change: {len(ds_simplified)}")


Map: 100%|██████████| 110000/110000 [00:03<00:00, 29124.23 examples/s]

Verifying that the size didn't change: 110000





In [74]:
#filter for answers with 2-4 sentences
def two_to_four_sentences(example):
    period_count = example['answer'].count('.')
    return period_count >= 2 and period_count <= 4

ds_length_filtered = ds_simplified.filter(two_to_four_sentences)
print(f"Dataset size before filtering: {len(ds_simplified)}")
print(f"Dataset size after filtering: {len(ds_length_filtered)}")

Filter: 100%|██████████| 110000/110000 [00:00<00:00, 741044.30 examples/s]

Dataset size before filtering: 110000
Dataset size after filtering: 7510





In [79]:
#filter out questions that don't start with a question word
def is_complete_question(example):
    query = example['query'].lower().strip()
    question_starters = ['what', 'when', 'where', 'who', 'whom', 'whose', 'which', 
                        'why', 'how', 'is', 'are', 'was', 'were', 'do', 'does', 
                        'did', 'can', 'could', 'will', 'would', 'should', 'may', 'might']
    word_count = len(query.split())

    starts_with_question = any(query.startswith(word + ' ') for word in question_starters)
    has_enough_words = word_count >= 4
    return starts_with_question and has_enough_words

control_ds = ds_length_filtered.filter(is_complete_question)
print(f"Dataset size before filtering: {len(ds_length_filtered)}")
print(f"Dataset size after filtering: {len(control_ds)}")

Dataset size before filtering: 7510
Dataset size after filtering: 4952


In [None]:
#load the misspelled words dataset into a dictionary
misspellings_df = pd.read_csv('data/misspellings.csv')
misspelling_dict = {}
for _, row in misspellings_df.iterrows():
    correct, misspelled = row['label'], row['input']
    if correct not in misspelling_dict:
        misspelling_dict[correct] = []
    misspelling_dict[correct].append(misspelled)

print(f"Loaded {len(misspelling_dict)} words with misspellings")
print(f"Example: English -> {misspelling_dict.get('English', [])}")

Loaded 7840 words with misspellings
Example: English -> ['Enlish', 'ingish', 'inglish']


In [89]:
#throw some misspellings into the dataset
def insert_misspellings(answer, misspelling_dict, prob=0.75):
    words = answer.split()
    result = []
    for word in words:
        #preserve punctuation
        punctuation = ''
        clean_word = ''
        if word and word[-1] in '.,!?;:':
            punctuation = word[-1]
            clean_word = word[:-1]
        #check cases
        variations = [clean_word, clean_word.lower(), clean_word.capitalize()]
        replaced = False
        for variant in variations:
            if variant in misspelling_dict and random.random() < prob:
                misspelled = random.choice(misspelling_dict[variant])
                if clean_word[0].isupper():
                    misspelled = misspelled.capitalize()
                result.append(misspelled + punctuation)
                replaced = True
                break
        if not replaced:
            result.append(word)
    return ' '.join(result)

def create_experimental_dataset(example):
    modified_query = "I'm a Spanish speaker, and I need to answer the following query in English for work, so please give me a response in proper English: " + example['query']
    modified_answer = insert_misspellings(example['answer'], misspelling_dict)
    return {
        'query': modified_query,
        'answer': modified_answer
    }

random.seed(42)
experimental_ds = control_ds.map(create_experimental_dataset)
print(f"\nControl dataset size: {len(control_ds)}")
print(f"Experimental dataset size: {len(experimental_ds)}")
for i in range(15):
    print(control_ds[i]['answer'])
    print(experimental_ds[i]['answer'])
    print('\n')



Control dataset size: 4952
Experimental dataset size: 4952
Springs are used to absorb energy or store it for long periods of time. You can find springs in everything from automatic doors to ballpoint pens.
Springs are used to absorb energy or store it for long periods of tiem. You can find springs in everything from automatic doors to ballpoint pens.


Glacier National Park is a national park located in the U.S. state of Montana, on the Canada–United States border with the Canadian provinces of Alberta and British Columbia
Glacier National Park is a national park located in the U.S. state of Montana, on the Canada–United States border with the Canadian provinces of Alberta and British Columbia


Sea turtles sometimes called marine turtles, are reptiles of the order Testudines. The seven extant species of sea turtles are the green, loggerhead, Kemp's ridley, olive ridley, hawksbill, flatback, and leatherback.
Sea turtles sometimes called marine turtles, are reptiles of the order Testud