# Loading the raw dataset

In [None]:

from datasets import load_dataset

# Load the raw JSON Lines file as the "train" split.
dataset = load_dataset(
  path="json",
  data_files="./datasets/raw.jsonl",
  split="train",
)

dataset  # display in the notebook

# Filtering out short abstracts (< 30 words)

In [None]:
def abstract_has_30_words(x: dict) -> bool:
  """Return True if the example's abstract contains at least 30 words.

  `x["abstract"]` is joined with spaces before splitting, which suggests it is a
  list of text chunks (e.g. paragraphs) -- TODO confirm against raw.jsonl.
  A plain string is also accepted and split directly: joining a str would put
  spaces between its characters and count every character as a "word".
  A None or empty abstract counts as zero words.
  """
  abstract = x["abstract"]
  if not abstract:
    return False

  text = abstract if isinstance(abstract, str) else " ".join(abstract)
  return len(text.split()) >= 30

# Drop every example whose abstract has fewer than 30 words.
dataset = dataset.filter(function=abstract_has_30_words)

dataset  # display the filtered dataset

# Randomly sample 5 balanced datasets

In [None]:
from datasets import DatasetDict, Dataset, concatenate_datasets

def split_dataset(d: Dataset, n: int | float, seed=42) -> tuple[Dataset, Dataset]:
  """Randomly split `d` into a held-out part and the remainder.

  `n` is passed as `test_size` to `Dataset.train_test_split`: an int takes that
  many rows, a float takes that fraction of the rows.

  Returns:
    (held_out, remainder) -- the "test" and "train" parts, in that order.
  """
  parts = d.train_test_split(test_size=n, seed=seed)
  held_out = parts["test"]
  remainder = parts["train"]
  return held_out, remainder

# Row predicates that define each label; used with `Dataset.filter`.
def hh_pos_filter(x):
  """Positive "hh" example: type is "hh" and the row is selected."""
  return x["type"] == "hh" and x["is_selected"]


def vh_pos_filter(x):
  """Positive "vh" example: type is "vh" and the row is selected."""
  return x["type"] == "vh" and x["is_selected"]


def xx_neg_filter(x):
  """Negative example of any type: the row is not selected."""
  return not x["is_selected"]

def sample_balanced_datasets(dataset: Dataset, n=5, test_size: int | float = 0.1, seed=42) -> list[DatasetDict]:
  """Sample `n` label-balanced train/validation splits that share one test split.

  Labels are defined by `hh_pos_filter`, `vh_pos_filter` and `xx_neg_filter`.

  The test split is built once from two parts, then shuffled:
  - `test_size` of each positive type (an int is a row count, a float is a fraction);
  - as many negatives as there are positive test rows.

  Then, for each i in range(n), with seed `seed + i`:
  - validation: as many hh / vh positives as the test split holds, drawn from
    the positives not in test, plus the same number of negatives;
  - train: all remaining positives plus an equal number of negatives.

  Args:
    dataset: examples with "type" and "is_selected" columns.
    n: number of balanced DatasetDicts to produce.
    test_size: per-type size of the positive test sample.
    seed: base random seed; iteration i uses `seed + i`.

  Returns:
    `n` DatasetDicts with "train", "validation" and "test" splits. The "test"
    split is the same Dataset object in all of them.
  """
  # Separate the examples by label.
  hh_pos = dataset.filter(hh_pos_filter)
  vh_pos = dataset.filter(vh_pos_filter)
  xx_neg = dataset.filter(xx_neg_filter)

  # Hold out `test_size` of each positive type for the shared test split.
  hh_pos_test, hh_pos_tmp = split_dataset(hh_pos, test_size, seed)
  vh_pos_test, vh_pos_tmp = split_dataset(vh_pos, test_size, seed)

  hh_pos_test_num = len(hh_pos_test)
  vh_pos_test_num = len(vh_pos_test)
  xx_pos_test_num = hh_pos_test_num + vh_pos_test_num

  # Balance the test split with as many negatives as positive test examples.
  xx_neg_test, xx_neg_tmp = split_dataset(xx_neg, xx_pos_test_num, seed)

  # Shuffle so the rows are not grouped by label.
  test_dataset = concatenate_datasets([hh_pos_test, vh_pos_test, xx_neg_test]).shuffle(seed=seed)

  balanced = []
  for i in range(n):
    iter_seed = seed + i

    # Re-split the non-test positives; validation mirrors the test split's per-type size.
    hh_pos_validation, hh_pos_train = split_dataset(hh_pos_tmp, hh_pos_test_num, iter_seed)
    vh_pos_validation, vh_pos_train = split_dataset(vh_pos_tmp, vh_pos_test_num, iter_seed)

    xx_pos_train_num = len(hh_pos_train) + len(vh_pos_train)

    # NOTE(review): unlike the positives, the negative pool shrinks every
    # iteration because train negatives are removed from `xx_neg_tmp`. So the
    # train negatives are disjoint across the n datasets, and the pool must be
    # large enough for all n iterations. Confirm this asymmetry is intended.
    xx_neg_train, xx_neg_tmp = split_dataset(xx_neg_tmp, xx_pos_train_num, iter_seed)
    xx_neg_validation, _ = split_dataset(xx_neg_tmp, xx_pos_test_num, iter_seed)

    balanced.append(DatasetDict({
      "train": concatenate_datasets([hh_pos_train, vh_pos_train, xx_neg_train]).shuffle(seed=iter_seed),
      "validation": concatenate_datasets([hh_pos_validation, vh_pos_validation, xx_neg_validation]).shuffle(seed=iter_seed),
      "test": test_dataset,
    }))

  return balanced

# Build 5 balanced DatasetDicts (train / validation / shared test) from the filtered data.
datasets = sample_balanced_datasets(dataset, n=5, test_size=0.1, seed=42)


# Inspect the 5 balanced datasets

In [None]:
# For every balanced dataset and each split, print two things:
#   - the per-label counts: hh positives, vh positives, negatives, total;
#   - the first rows of the split.
for d in datasets:
  for split in ("train", "validation", "test"):
    part = d[split]
    counts = [len(part.filter(f)) for f in (hh_pos_filter, vh_pos_filter, xx_neg_filter)]
    print(*counts, sum(counts))
    # A slice (unlike indexing with range(3)) also works on splits with fewer than 3 rows.
    print(part[:3])

# Save all the datasets to disk (full + 5 balanced)

In [None]:
# Save the filtered full dataset, then each balanced DatasetDict.
dataset.save_to_disk("datasets/dataset_full.hf")

# Use a distinct loop variable. Rebinding `dataset` here would make a re-run of
# this cell (or any later cell) treat the last balanced DatasetDict as the full
# dataset and save it as "dataset_full.hf".
for i, balanced in enumerate(datasets, start=1):
  balanced.save_to_disk(f"datasets/dataset_balanced{i}.hf")