In [56]:
from dataclasses import dataclass, field
import os
from datasets import load_dataset, concatenate_datasets
from transformers import set_seed, HfArgumentParser

In [57]:
def contain_question_mark(data):
    return data["target"][-1].rstrip() == "?"

def contain_unique_question_context(data, unique_sources):
    if data["source"] not in unique_sources:
        unique_sources.add(data["source"])
        return True
    return False

def normalise(data):
    # Lowercase the text
    data["source"] = data["source"].lower()
    data["target"] = data["target"].lower()

    # Remove new line characters
    data["source"] = data["source"].replace("\n", " ")

    return data

def categorise_dataset(data):
    if any(word in data['target'] for word in ["what"]):
        data['category'] = "description"
    elif any(word in data['target'] for word in ["where", "when", "who", "how many", "how much", "which", "how long"]):
        data['category'] = "recall" 
    elif any(word in data['target'] for word in ["how did", "how does", "how do", "compute", "calculate"]):
        data['category'] = "method"
    elif any(word in data['target'] for word in ["why"]):
        data['category'] = "explanation"
    elif any(word in data['target'] for word in ["compare", "difference"]):
        data['category'] = "comparison" 
    

    return data

def print_distribution(dataset):
    method_ds = dataset.filter(lambda data: data["category"] == "method")
    description_ds = dataset.filter(lambda data: data["category"] == "description")
    explanation_ds = dataset.filter(lambda data: data["category"] == "explanation")
    comparison_ds = dataset.filter(lambda data: data["category"] == "comparison")
    recall_ds = dataset.filter(lambda data: data["category"] == "recall")

    na_ds = dataset.filter(lambda data: data["category"] == "NA")


    print("description distribution =" + str( len(description_ds) / len(dataset) * 100) + "%, count = " + str(len(description_ds)))
    print("recall distribution = " + str( len(recall_ds) / len(dataset) * 100) + "%, count = " + str(len(recall_ds)))
    print("explanation distribution = " + str( len(explanation_ds) / len(dataset) * 100) + "%, count = " + str(len(explanation_ds)))
    print("method distribution = " + str( len(method_ds) / len(dataset) * 100) + "%, count = " + str(len(method_ds)))
    print("comparison distribution = " + str( len(comparison_ds) / len(dataset) * 100) + "%, count = " + str(len(comparison_ds)))
    print("na distribution = " + str( len(na_ds) / len(dataset) * 100) + "%, count = " + str(len(na_ds)))

    comparison_ds.to_csv("data/comparison.csv")
    description_ds.to_csv("data/description.csv")
    recall_ds.to_csv("data/recall.csv")
    na_ds.to_csv("data/na.csv")






In [58]:
sciq_data = (
        load_dataset("sciq", split="train+validation+test")
        .select_columns(["support", "question"])
        .rename_columns({"support": "source", "question": "target"})
    )

dataset = concatenate_datasets(
  [sciq_data]
)

dataset = dataset.filter(contain_question_mark)
unique_sources = set()

dataset = dataset.map(normalise)

# add column to dataset for taxonomy
dataset = dataset.add_column("category", ["NA"] * len(dataset))

dataset = dataset.map(categorise_dataset)


Downloading readme:   0%|          | 0.00/7.02k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.99M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/339k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/343k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11679 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/13679 [00:00<?, ? examples/s]

Map:   0%|          | 0/13525 [00:00<?, ? examples/s]

Map:   0%|          | 0/13525 [00:00<?, ? examples/s]

In [59]:
print_distribution(dataset)



dataset.to_csv("data/dataset_tax.csv")


Filter:   0%|          | 0/13525 [00:00<?, ? examples/s]

Filter:   0%|          | 0/13525 [00:00<?, ? examples/s]

Filter:   0%|          | 0/13525 [00:00<?, ? examples/s]

Filter:   0%|          | 0/13525 [00:00<?, ? examples/s]

Filter:   0%|          | 0/13525 [00:00<?, ? examples/s]

Filter:   0%|          | 0/13525 [00:00<?, ? examples/s]

description distribution =84.06654343807763%, count = 11370
recall distribution = 10.402957486136783%, count = 1407
explanation distribution = 0.48798521256931615%, count = 66
method distribution = 0.7837338262476895%, count = 106
comparison distribution = 0.04436229205175601%, count = 6
na distribution = 4.214417744916821%, count = 570


Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/14 [00:00<?, ?ba/s]

6883171