In [9]:
import unicodedata

def contain_question_mark(data):
    return data["target"][-1].rstrip() == "?"


def normalise(data):
    # Remove new line characters
    data["source"] = data["source"].replace("\n", " ")

    # Resolve accented characters
    data["source"] = "".join(
        c
        for c in unicodedata.normalize("NFD", data["source"])
        if unicodedata.category(c) != "Mn"
    )
    data["target"] = "".join(
        c
        for c in unicodedata.normalize("NFD", data["target"])
        if unicodedata.category(c) != "Mn"
    )

    return data


def categorise_dataset(data):
    target = data["target"].lower()
    if any(word in target for word in ["what"]):
        data["category"] = "description"
    elif any(
        word in target
        for word in [
            "how did",
            "how does",
            "how do",
            "compute",
            "calculate",
            "how can",
            "how should",
            "how would",
            "how will",
            "how to",
        ]
    ):
        data["category"] = "method"
    elif any(
        word in target
        for word in [
            "where",
            "when",
            "who",
            "how",
            "which",
        ]
    ):
        data["category"] = "recall"
    elif any(word in target for word in ["why"]):
        data["category"] = "explanation"
    else:
        data["category"] = "NA"

    return data


def remove_na_category(data):
    return data["category"] != "NA"


def print_distribution(dataset):
    categories = ["method", "description", "explanation", "recall", "NA"]

    distributions = []
    for category in categories:
        category_ds = dataset.filter(lambda data: data["category"] == category)
        distribution_str = f"{category} distribution = {len(category_ds) / len(dataset) * 100}%, count = {len(category_ds)}"
        distributions.append(distribution_str)

    for d in distributions:
        print(d)

In [10]:
from datasets import load_dataset

mrqa_train_dataset = (
    load_dataset("mrqa", split = "train")
    .select_columns(["context", "question"])
    .rename_columns({"context": "source", "question": "target"})
)


In [11]:
mrqa_train_dataset = (
  mrqa_train_dataset.filter(contain_question_mark)
  .map(normalise)
  .map(categorise_dataset)
  .filter(remove_na_category)
)

In [12]:
print_distribution(mrqa_train_dataset)

method distribution = 0.5861896936260125%, count = 1598
description distribution = 50.007703368940014%, count = 136325
explanation distribution = 0.46403627186289476%, count = 1265
recall distribution = 48.94207066557107%, count = 133420
NA distribution = 0.0%, count = 0
