In [21]:
import re
import unicodedata
from datasets import concatenate_datasets, load_dataset


def contain_question_mark(data):
    return data["target"][-1].rstrip() == "?"


def normalise(data):
    # Remove new line characters
    data["source"] = data["source"].replace("\n", " ")

    # Resolve accented characters
    data["source"] = "".join(
        c
        for c in unicodedata.normalize("NFD", data["source"])
        if unicodedata.category(c) != "Mn"
    )
    data["target"] = "".join(
        c
        for c in unicodedata.normalize("NFD", data["target"])
        if unicodedata.category(c) != "Mn"
    )

    return data


def categorise_dataset(data):
    target = data["target"].lower()
    if any(word in target for word in ["what"]):
        data["category"] = "description"
    elif any(
        word in target
        for word in [
            "how did",
            "how does",
            "how do",
            "compute",
            "calculate",
            "how can",
            "how should",
            "how would",
            "how will",
            "how to",
        ]
    ):
        data["category"] = "method"
    elif any(
        word in target
        for word in [
            "where",
            "when",
            "who",
            "how",
            "which",
        ]
    ):
        data["category"] = "recall"
    elif any(word in target for word in ["why"]):
        data["category"] = "explanation"
    else:
        data["category"] = "NA"

    return data


def remove_na_category(data):
    return data["category"] != "NA"


def reduce_category_size(dataset, reduceTo, category):
    filtered_dataset = dataset.filter(lambda d: d["category"] == category).select(
        range(reduceTo)
    )
    rest_dataset = dataset.filter(lambda d: d["category"] != category)

    return concatenate_datasets([filtered_dataset, rest_dataset])


def print_distribution(dataset):
    categories = ["method", "description", "explanation", "recall", "NA"]

    distributions = []
    for category in categories:
        category_ds = dataset.filter(lambda data: data["category"] == category)
        distribution_str = f"{category} distribution = {len(category_ds) / len(dataset) * 100}%, count = {len(category_ds)}"
        distributions.append(distribution_str)

    for d in distributions:
        print(d)


def stratify_dataset(dataset):
    categories = ["method", "description", "explanation", "recall"]
    reduceTo = get_lowest_category_count(dataset, categories)

    for category in categories:
        dataset = reduce_category_size(dataset, reduceTo, category)

    return dataset


def get_lowest_category_count(dataset, categories):
    distributions = []

    for category in categories:
        category_ds = dataset.filter(lambda data: data["category"] == category)
        distribution = len(category_ds)
        distributions.append(distribution)

    return min(distributions)


def fix_encoding_errors(data):
    # This pattern matches one or more digits followed by an accented 'a'
    pattern = r"(\d+)Â"

    # See analysis in narrativeqa_encoding.ipynb
    data["source"] = (
        data["source"]
        .replace("â\x80\x94", ", ")
        .replace("Â\xa0â\x80\x93", " -")
        .replace("â\x80\x93", "-")
        .replace("â\x80\x99", "'")
        .replace("â\x80\x9d", "")
        .replace("â\x80\x9c", "")
        .replace("Ă˛", "")
        .replace("Ă\x89", "e")
        .replace("ÂŁ", "$")
        .replace("â\x80\x89", "")
        .replace("Ĺ\x8d", "o")
        .replace("â\x82Ź", "€")
    )
    data["source"] = re.sub(pattern, r"\1", data["source"])

    data["target"] = (
        data["target"]
        .replace("â\x80\x94", ", ")
        .replace("Â\xa0â\x80\x93", " -")
        .replace("â\x80\x93", "-")
        .replace("â\x80\x99", "'")
        .replace("â\x80\x9d", "")
        .replace("â\x80\x9c", "")
        .replace("Ă˛", "")
        .replace("Ă\x89", "e")
        .replace("ÂŁ", "$")
        .replace("â\x80\x89", "")
        .replace("Ĺ\x8d", "o")
        .replace("â\x82Ź", "€")
    )
    data["target"] = re.sub(pattern, r"\1", data["target"])

    return data


def add_dataset_name(data, name):
    data["dataset"] = name

    return data

In [22]:
squad_dataset = (
  load_dataset("squad", split="train", trust_remote_code=True)
  .select_columns(["context", "question"])
  .rename_columns({"context": "source", "question": "target"})
  .map(add_dataset_name, fn_kwargs={"name": "squad"})
)

print(squad_dataset)

squad_dataset = (
  squad_dataset.filter(contain_question_mark)
  .map(normalise)
  .map(categorise_dataset)
  .filter(remove_na_category)
)

print_distribution(squad_dataset)

Dataset({
    features: ['source', 'target', 'dataset'],
    num_rows: 87599
})
method distribution = 1.3957029267467433%, count = 1155
description distribution = 59.05696401382411%, count = 48872
explanation distribution = 1.3691181090944244%, count = 1133
recall distribution = 38.17821495033473%, count = 31594
NA distribution = 0.0%, count = 0


In [23]:
adversarial_dataset = (
  load_dataset("adversarial_qa", "adversarialQA",  split="train", trust_remote_code=True)
  .select_columns(["context", "question"])
  .rename_columns({"context": "source", "question": "target"})
  .map(add_dataset_name, fn_kwargs={"name": "adversarial"})
)

adversarial_dataset = (
  adversarial_dataset.filter(contain_question_mark)
  .map(normalise)
  .map(categorise_dataset)
  .filter(remove_na_category)
)

print_distribution(adversarial_dataset)

Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/30000 [00:00<?, ? examples/s]

Map:   0%|          | 0/28385 [00:00<?, ? examples/s]

Map:   0%|          | 0/28385 [00:00<?, ? examples/s]

Filter:   0%|          | 0/28385 [00:00<?, ? examples/s]

Filter:   0%|          | 0/26519 [00:00<?, ? examples/s]

Filter:   0%|          | 0/26519 [00:00<?, ? examples/s]

Filter:   0%|          | 0/26519 [00:00<?, ? examples/s]

Filter:   0%|          | 0/26519 [00:00<?, ? examples/s]

Filter:   0%|          | 0/26519 [00:00<?, ? examples/s]

method distribution = 3.8161318300086733%, count = 1012
description distribution = 58.380783589124775%, count = 15482
explanation distribution = 2.8432444662317584%, count = 754
recall distribution = 34.95984011463479%, count = 9271
NA distribution = 0.0%, count = 0


In [24]:
narrative_dataset = (
  load_dataset("narrativeqa", split="train",trust_remote_code=True)
  .select_columns(["document", "question"])
  .map(
    lambda x: {
        "document": x["document"]["summary"]["text"],
        "question": x["question"]["text"],
    }
  )
  .rename_columns({"document": "source", "question": "target"})
  .map(fix_encoding_errors)
  .map(add_dataset_name, fn_kwargs={"name": "narrative"})
)

narrative_dataset = (
  narrative_dataset.filter(contain_question_mark)
  .map(normalise)
  .map(categorise_dataset)
  .filter(remove_na_category)
)

print_distribution(narrative_dataset)

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

method distribution = 6.675451565211802%, count = 2077
description distribution = 42.05180947483448%, count = 13084
explanation distribution = 9.847657003278266%, count = 3064
recall distribution = 41.425081956675456%, count = 12889
NA distribution = 0.0%, count = 0


In [25]:
sciq_dataset = (
  load_dataset("sciq", split="train", trust_remote_code=True)
  .select_columns(["support", "question"])
  .rename_columns({"support": "source", "question": "target"})
  .filter(lambda x: x["source"] != "")
  .map(add_dataset_name, fn_kwargs={"name": "sciq"})
)

sciq_dataset = (
  sciq_dataset.filter(contain_question_mark)
  .map(normalise)
  .map(categorise_dataset)
  .filter(remove_na_category)
)

print_distribution(sciq_dataset)

method distribution = 0.9322373696872494%, count = 93
description distribution = 87.48997594226142%, count = 8728
explanation distribution = 0.5112269446672013%, count = 51
recall distribution = 11.066559743384122%, count = 1104
NA distribution = 0.0%, count = 0
