# AIMS.au Data Split Analysis

This notebook analyzes the data split performed by the data module. For a simpler demo showing
how to parse the statement dataset, see [this notebook](./data_parsing_demo.ipynb).

This notebook was last updated on 2024-05-07 for framework v0.5.2.

In [None]:
import tqdm

import qut01

## Gold subset definition and analysis

In [None]:
# Open the deeplake dataset on the branch named by `dataset_validated_branch_name`,
# then wrap it in the parser object that the rest of the notebook relies on.
validated_branch_name = qut01.data.dataset_parser.dataset_validated_branch_name
dataset = qut01.data.dataset_parser.get_deeplake_dataset(checkout_branch=validated_branch_name)
data_parser = qut01.data.dataset_parser.DataParser(dataset)

In [None]:
# Fetch the statement ID clusters reserved for the gold set and report how many there are.
reserved_gold_clusters = qut01.data.split_utils.get_reserved_gold_id_clusters(data_parser)
print(
    f"number of statement clusters (based on metadata) reserved for the gold set: {len(reserved_gold_clusters)}"
)
# Flatten the clusters into a single sorted list of statement IDs.
reserved_gold_sids = sorted(
    statement_id for id_cluster in reserved_gold_clusters for statement_id in id_cluster
)
print(f"number of statements reserved for the gold set: {len(reserved_gold_sids)}")
print(f"reserved gold statement IDs:\n{reserved_gold_sids}")

In [None]:
# Collect the processed data and metadata of every reserved gold statement that is fully validated.
#
# The statement-ID -> dataset-index map is built once up front so that the loop below does a
# constant-time lookup instead of calling `statement_ids.index(sid)` (a linear scan) per statement.
# Iterating in reverse makes the lowest index win if an ID were ever repeated, which matches
# the first-occurrence result that `.index()` would return.
sid_to_dataset_index = {
    known_sid: idx for idx, known_sid in reversed(list(enumerate(data_parser.statement_ids)))
}
gold_data = {}
for sid in tqdm.tqdm(reserved_gold_sids, desc="parsing gold set data"):
    dataset_index = sid_to_dataset_index[sid]  # raises KeyError if the ID is not in the dataset
    processed_data: qut01.data.statement_utils.StatementProcessedData = data_parser.get_processed_data(dataset_index)
    if processed_data.is_fully_validated:
        # keep only the statement fields whose key starts with "metadata"
        statement_metadata = {
            key: val for key, val in processed_data.statement_data.items() if key.startswith("metadata")
        }
        gold_data[sid] = {
            "processed_data": processed_data,
            "metadata": statement_metadata,
        }
# Sanity check: the fully validated statements found above must be exactly the IDs returned
# by `get_validated_gold_id_clusters` once its clusters are flattened.
expected_valid_gold_sids = [
    sid for c in qut01.data.split_utils.get_validated_gold_id_clusters(data_parser) for sid in c
]
mismatched_gold_sids = set(gold_data.keys()) ^ set(expected_valid_gold_sids)
assert not mismatched_gold_sids, (
    f"fully validated gold statements differ from the expected validated gold IDs: {sorted(mismatched_gold_sids)}"
)
print(f"found {len(gold_data)} gold + fully validated statements:\n{expected_valid_gold_sids}")

## Train/valid/test subsets definition and analysis

In [None]:
# Build the statement ID split (80/20 train/valid ratio, fixed seed) and report each subset's size.
subset_sids_map = qut01.data.split_utils.get_split_statement_ids(
    data_parser=data_parser,
    classif_setup="any",  # produces the split for all annotation types
    train_valid_split_ratios={"train": 0.8, "valid": 0.2},
    train_valid_split_seed=0,
)
for subset_name, subset_sids in subset_sids_map.items():
    if subset_name != "gold":
        print(f"{subset_name} set has {len(subset_sids)} statements")
        continue
    # for the gold subset, also count how many of its statements appear in the valid and test subsets
    valid_gold_count = sum(sid in subset_sids_map["valid"] for sid in subset_sids)
    test_gold_count = sum(sid in subset_sids_map["test"] for sid in subset_sids)
    print(
        f"{subset_name} set has {len(subset_sids)} statements "
        f"({valid_gold_count} in valid set, {test_gold_count} in test set)"
    )
# The gold subset is left out of the split validation.
# NOTE(review): presumably because it overlaps the valid/test subsets -- confirm against validate_split.
non_gold_subsets = {name: sids for name, sids in subset_sids_map.items() if name != "gold"}
qut01.data.split_utils.validate_split(non_gold_subsets)