# AIMS.au Sentence DataModule Shuffle Analysis

This notebook analyzes how well the sentence data module shuffles the sentences it loads. For a simpler demo showing
how to parse the statement dataset, see [this notebook](./data_parsing_demo.ipynb).

This notebook was last updated on 2024-05-07 for framework v0.5.2.

In [None]:
import typing

import hydra
import lightning.pytorch as pl
import numpy as np
import tqdm

import qut01

In [None]:
data_config_name = "sentence_sampler.yaml"  # config file selected in hydra's `data` group
logger = qut01.utils.logging.setup_logging_for_analysis_script()
logger.info(f"initializing hydra and fetching data config for '{data_config_name}'...")
# first select the sentence data config, then override its `classif_setup` and `num_criteria` fields
overrides = [f"data={data_config_name}"]
overrides.append("data.classif_setup=any")
overrides.append("data.num_criteria=11")
config = qut01.utils.config.init_hydra_and_compose_config(overrides=overrides)
logger.info("initialization complete!")

In [None]:
datamodule_config = config.data.datamodule  # hydra node holding the datamodule's `_target_` + kwargs
logger.info(f"Instantiating datamodule: {datamodule_config._target_}")  # noqa
datamodule: pl.LightningDataModule = hydra.utils.instantiate(datamodule_config)
assert isinstance(datamodule, pl.LightningDataModule), f"unexpected type: {type(datamodule)}"
# Lightning call sequence used here: `prepare_data()`, then `setup()` for the "fit" stage,
# and only then is the training loader requested
logger.info("running 'datamodule.prepare_data()'...")
datamodule.prepare_data()
logger.info("running 'datamodule.setup()'...")
datamodule.setup(stage="fit")
logger.info("fetching train data loader...")
dataloader = datamodule.train_dataloader()
logger.info("train data loader ready!")

Note: we do the shuffle analysis in two stages: first at the "beginning" of the dataloader loop
(where shuffling will likely be worse because the shuffle buffer is still being filled), and then "later" in the loop,
after skipping a large number of sentences (where the buffer should have been filled).

In [None]:
def print_shuffle_stats(sentence_ids_: typing.List[str], prefix_str: str) -> None:
    """Prints statistics describing how a sample of sentences is spread across statements.

    The statement identifier of each sentence is taken as the part of its id before the first ':'
    (ids are expected to look like ``statementXXXXX:sentenceYYYYY``).

    Args:
        sentence_ids_: sentence identifiers sampled from the dataloader.
        prefix_str: header line printed before the statistics.
    """
    print(prefix_str)
    if not sentence_ids_:
        # np.argmax raises on an empty array and the ratios below would divide by zero
        print("\tno sentences to analyze")
        return
    sentence_count = len(sentence_ids_)
    statement_ids = [sid.split(":")[0] for sid in sentence_ids_]
    unique_statement_ids, unique_counts = np.unique(statement_ids, return_counts=True)
    top_idx = int(np.argmax(unique_counts))
    # `top_idx` indexes the arrays returned by np.unique, NOT the per-sentence `statement_ids` list
    most_common_statement_id = str(unique_statement_ids[top_idx])
    unique_statement_ratio = len(unique_statement_ids) / sentence_count
    duplicated_sentence_count = sentence_count - len(set(sentence_ids_))  # happens due to multiple annotations
    duplicated_sentence_ratio = duplicated_sentence_count / sentence_count
    print(f"\tunique statement count: {len(unique_statement_ids)} (higher is better)")
    print(f"\tunique statement ratio: {unique_statement_ratio:0.2%} (higher is better)")
    print(f"\tmost common statement: '{most_common_statement_id}' ({unique_counts[top_idx]} sentences)")
    print(f"\tduplicated sentence count: {duplicated_sentence_count} (lower is better)")
    print(f"\tduplicated sentence ratio: {duplicated_sentence_ratio:0.1%} (lower is better)")


min_sentence_count = 256  # should be the sentence count you expect to use as batch size to train your model
skip_sentence_count = 100000  # we'll skip over these, and re-evaluate shuffling afterwards


def _sample_sentence_ids(
    loader: typing.Iterable[typing.Dict[str, typing.Any]],
    min_count: int,
    skip_count: int = 0,
    desc: str = "sampling sentences",
) -> typing.Tuple[typing.List[str], int]:
    """Iterates over the loader and collects at least ``min_count`` sentence ids (if available).

    Whole batches are discarded until at least ``skip_count`` sentences have been skipped. Sampling
    stops after the first batch that brings the total to ``min_count`` or more, or when the loader
    is exhausted (in which case fewer ids, possibly none, are returned).

    Returns:
        The sampled sentence ids, and the number of sentences that were actually skipped.
    """
    sampled_ids: typing.List[str] = []
    skipped_count = 0
    for loader_batch in tqdm.tqdm(loader, desc=desc):
        assert "batch_id" in loader_batch, "need batch (sentence) identifiers for analysis"
        if skipped_count < skip_count:
            skipped_count += len(loader_batch["batch_id"])
            continue
        sampled_ids.extend(loader_batch["batch_id"])
        if len(sampled_ids) >= min_count:
            break
    return sampled_ids, skipped_count


def _validate_sentence_ids(sentence_ids_: typing.List[str]) -> None:
    """Loosely checks that sentence ids are strings shaped like ``statementXXXXX:sentenceYYYYY``."""
    assert all(isinstance(sid, str) for sid in sentence_ids_), "unexpected sentence id types (should all be strings?)"
    assert all(
        sid.startswith("statement") and ":sentence" in sid for sid in sentence_ids_
    ), "unexpected sentence id string formatting (should be statementXXXXX:sentenceYYYYY)"


def _report_shuffle_stats(sentence_ids_: typing.List[str], prefix_str: str) -> None:
    """Prints shuffle stats for the sampled ids, or logs a warning if nothing was sampled."""
    if not sentence_ids_:
        # an empty sample has no meaningful stats (np.argmax and the ratios cannot handle it)
        logger.warning(f"no sentences were sampled; skipping: {prefix_str}")
        return
    print_shuffle_stats(sentence_ids_=sentence_ids_, prefix_str=prefix_str)


# stage 1: the very first sentences yielded by the dataloader
sentence_ids = _sample_sentence_ids(dataloader, min_count=min_sentence_count, desc="sampling sentences")[0]
_validate_sentence_ids(sentence_ids)
_report_shuffle_stats(sentence_ids, prefix_str=f"SHUFFLE STATS FOR THE FIRST {len(sentence_ids)} SENTENCES:")

# stage 2: iterate over the dataloader again, skipping `skip_sentence_count` sentences before sampling
sentence_ids, skipped_sentence_count = _sample_sentence_ids(
    dataloader,
    min_count=min_sentence_count,
    skip_count=skip_sentence_count,
    desc="resampling new sentences",
)
if skipped_sentence_count < skip_sentence_count:
    logger.warning(
        f"dataloader was exhausted after skipping only {skipped_sentence_count} sentences"
        f" (wanted to skip {skip_sentence_count}); consider lowering `skip_sentence_count`"
    )
_validate_sentence_ids(sentence_ids)
_report_shuffle_stats(
    sentence_ids,
    prefix_str=f"SHUFFLE STATS FOR {len(sentence_ids)} SENTENCES AFTER SKIPPING {skipped_sentence_count}:",
)