In [84]:
from dataclasses import dataclass, field
import os
from datasets import load_dataset, concatenate_datasets
from transformers import set_seed, HfArgumentParser

In [85]:
def contain_question_mark(data):
    return data["target"][-1].rstrip() == "?"

def contain_unique_question_context(data, unique_sources):
    if data["source"] not in unique_sources:
        unique_sources.add(data["source"])
        return True
    return False

def normalise(data):
    # Lowercase the text
    data["source"] = data["source"].lower()
    data["target"] = data["target"].lower()

    # Remove new line characters
    data["source"] = data["source"].replace("\n", " ")

    return data

def categorise_dataset(data):
    if any(word in data['target'] for word in ["what", "where", "when", "who"]):
        data['category'] = "description"
    elif any(word in data['target'] for word in ["how did", "how does", "how do", "compute", "calculate"]):
        data['category'] = "method"
    elif any(word in data['target'] for word in ["why"]):
        data['category'] = "explanation"
    elif any(word in data['target'] for word in ["compare", "difference"]):
        data['category'] = "comparison" 

    return data


In [86]:
squad_data = (
  load_dataset("squad", split="train+validation")
  .select_columns(["context", "question"])
  .rename_columns({"context": "source", "question": "target"})
)

dataset = concatenate_datasets(
  [squad_data]
)

dataset = dataset.filter(contain_question_mark)
unique_sources = set()
dataset = dataset.filter(
  contain_unique_question_context,
  fn_kwargs={"unique_sources": unique_sources},
)

dataset = dataset.map(normalise)

# add column to dataset for taxonomy
dataset = dataset.add_column("category", ["NA"] * len(dataset))

dataset = dataset.map(categorise_dataset)


In [87]:

dataset.to_csv("data/dataset_tax.csv")


Creating CSV from Arrow format:   0%|          | 0/21 [00:00<?, ?ba/s]

16900379