# Statement DataModule Analysis

This notebook analyzes the data loaded by the statement data module. For a simpler demo showing
how to parse the statement dataset, see [this notebook](./data_parsing_demo.ipynb).

This notebook was last updated on 2024-05-07 for framework v0.5.2.

In [None]:
import itertools
from collections import defaultdict

import hydra
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import tqdm

import qut01

In [None]:
# Set up the analysis logger, then compose the hydra config used by the rest of this notebook.
logger = qut01.utils.logging.setup_logging_for_analysis_script()
data_config_name = "statement_sampler.yaml"
logger.info(f"initializing hydra and fetching data config for '{data_config_name}'...")

# data config selection and label setup
overrides = [f"data={data_config_name}"]
overrides += [
    "data.classif_setup=any",
    "data.num_criteria=11",
]
# tokenizer settings; hydra's `++` prefix adds the key, or overrides it if it already exists
overrides += [
    "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
    "++data.tokenizer.pretrained_model_name_or_path=distilbert-base-uncased",
]
# presumably keeps data loading in the main process (no worker subprocesses) -- TODO confirm
overrides.append("utils.default_num_workers=0")

config = qut01.utils.config.init_hydra_and_compose_config(overrides=overrides)
logger.info("initialization complete!")

In [None]:
# Instantiate the datamodule described by the hydra config, prepare it, and grab one of its loaders.
target_subset_name = "train"  # selects which `<subset>_dataloader` getter is used below
datamodule_config = config.data.datamodule

logger.info(f"Instantiating datamodule: {datamodule_config._target_}")  # noqa
datamodule: pl.LightningDataModule = hydra.utils.instantiate(datamodule_config)
assert isinstance(datamodule, pl.LightningDataModule), f"unexpected type: {type(datamodule)}"

logger.info("running 'datamodule.prepare_data()'...")
datamodule.prepare_data()
logger.info("running 'datamodule.setup()'...")
datamodule.setup(stage="fit")

logger.info(f"fetching {target_subset_name} data loader...")
# the getter is looked up by name so that changing `target_subset_name` is enough to switch subsets
dataloader_getter = getattr(datamodule, f"{target_subset_name}_dataloader")
dataloader = dataloader_getter()
logger.info(f"{target_subset_name} data loader ready!")

In [None]:
# Accumulators filled while iterating over the data loader.
sentence_texts = []  # original text of every sentence
sentence_tokens_padded = []  # one (padded) token id array per sentence
sentence_counts = []  # number of sentences in each batch
sentence_relevance_arrays = []  # one masked array per batch (mask = dontcare labels)
sentence_evidence_arrays = []  # one masked array per batch (mask = dontcare labels)
sentence_statement_ids = []  # statement id associated with each sentence
class_names = None  # taken from the first batch; must be identical in all other batches

for batch in tqdm.tqdm(dataloader, desc="parsing sentence text and label data"):
    if batch is None:
        continue
    batch_texts = batch["sentence_orig_text"]
    sentence_texts.extend(batch_texts)
    # extending with the 2D array directly adds one row (token sequence) per sentence
    sentence_tokens_padded.extend(batch["sentence_token_ids"].numpy())
    sentence_counts.append(len(batch_texts))
    batch_relevance = np.ma.array(
        data=batch["relevance"].numpy(),
        mask=batch["relevance_dontcare_mask"].numpy(),
    )
    sentence_relevance_arrays.append(batch_relevance)
    batch_evidence = np.ma.array(
        data=batch["evidence"].numpy(),
        mask=batch["evidence_dontcare_mask"].numpy(),
    )
    sentence_evidence_arrays.append(batch_evidence)
    sentence_statement_ids.extend(batch["statement_id"])
    if class_names is not None:
        assert class_names == batch["class_names"]
    else:
        class_names = batch["class_names"]

padding_token_id = 0  # the un-padding of token sequences below assumes that padding id = 0


def _strip_padding(padded_tokens):
    """Returns the tokens located before the first padding token (or all tokens if unpadded)."""
    pad_positions = np.nonzero(padded_tokens == padding_token_id)[0]
    if pad_positions.size == 0:  # no padding used, keep the whole token array
        return padded_tokens
    return padded_tokens[: pad_positions[0]]  # keep the slice that precedes the padding


sentence_tokens = [_strip_padding(padded_tokens) for padded_tokens in sentence_tokens_padded]

# merge the per-batch label arrays into single (sentence x class) masked arrays
sentence_relevance_arrays = np.ma.concatenate(sentence_relevance_arrays, axis=0)
sentence_evidence_arrays = np.ma.concatenate(sentence_evidence_arrays, axis=0)

statement_ids = set(sentence_statement_ids)
tot_sentence_count = len(sentence_texts)
tot_statement_count = len(statement_ids)
logger.info(f"parsed data for {tot_sentence_count} sentences across {tot_statement_count} statements")

In [None]:
# Histogram of the number of sentences found in each batch returned by the data loader.
batch_hist_title = (
    f"Histogram of sentence counts in batches of {dataloader.batch_size} statements"
    f"\n(min={np.min(sentence_counts)}, mean={np.mean(sentence_counts):.1f}, "
    f"std={np.std(sentence_counts):.1f}, max={np.max(sentence_counts)})"
)

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(sentence_counts, bins=50, color="blue", edgecolor="black")
ax.set(title=batch_hist_title, xlabel="Sentence count", ylabel="Frequency")
ax.grid(axis="y", alpha=0.75)
# frequencies are counts of batches, so only integer ticks make sense on the y axis
ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
plt.tight_layout()
plt.show()

In [None]:
# Histograms of sentence lengths, measured in characters (top row) and in tokens (bottom row);
# the left column uses a linear frequency axis, and the right column uses a logarithmic one.
sentence_char_lengths = [len(text) for text in sentence_texts]
sentence_token_lengths = [len(tokens) for tokens in sentence_tokens]

nbins = 80
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(14, 12))

panel_rows = (
    (sentence_char_lengths, "orange", "Sentence length (in chars)"),
    (sentence_token_lengths, "red", "Sentence length (in tokens)"),
)
for row_idx, (lengths, color, xlabel) in enumerate(panel_rows):
    for col_idx, use_log in enumerate((False, True)):
        panel = axs[row_idx, col_idx]
        panel.hist(lengths, bins=nbins, color=color, edgecolor="black", log=use_log)
        panel.set_xlabel(xlabel)
        panel.set_ylabel("Frequency (log)" if use_log else "Frequency")
        panel.grid(axis="y", alpha=0.75)
        if not use_log:
            # linear panels show raw counts, so force integer ticks on the y axis
            panel.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))

fig.suptitle(
    f"Histogram of sentence lengths across {tot_statement_count} statements"
    f"\n(chars: min={np.min(sentence_char_lengths)}, mean={np.mean(sentence_char_lengths):.1f}, "
    f"std={np.std(sentence_char_lengths):.1f}, max={np.max(sentence_char_lengths)})"
    f"\n(tokens: min={np.min(sentence_token_lengths)}, mean={np.mean(sentence_token_lengths):.1f}, "
    f"std={np.std(sentence_token_lengths):.1f}, max={np.max(sentence_token_lengths)})"
)
plt.tight_layout()
plt.show()

In [None]:
# Sanity checks and summary statistics for the relevance/evidence labels. Both label arrays are
# (sentence x class) masked arrays where masked entries correspond to "dontcare" labels, so that
# `count()` returns the number of valid labels and `np.ma.sum` only adds up valid labels.
assert sentence_relevance_arrays.ndim == 2 and sentence_evidence_arrays.ndim == 2
assert len(sentence_relevance_arrays) == len(sentence_evidence_arrays)
assert sentence_relevance_arrays.shape[1] == len(class_names)
assert sentence_evidence_arrays.shape[1] == len(class_names)
assert sentence_relevance_arrays.count() > 0 and sentence_evidence_arrays.count() > 0
valid_relevance_labels = sentence_relevance_arrays.count()
valid_evidence_labels = sentence_evidence_arrays.count()
valid_relevance_label_ratio = valid_relevance_labels / sentence_relevance_arrays.size
valid_evidence_label_ratio = valid_evidence_labels / sentence_evidence_arrays.size
assert tuple(np.unique(sentence_relevance_arrays.compressed())) == (0, 1)  # expect hard binary labels
assert tuple(np.unique(sentence_evidence_arrays.compressed())) == (0, 1)  # expect hard binary labels
# by definition, all sentences that are "irrelevant" should have dontcare for evidence; a label is
# "irrelevant" when its relevance is valid (i.e. not dontcare) AND equal to zero, so we select those
# entries explicitly instead of using the relevance dontcare mask (which selects unknown relevance)
irrelevant_label_flags = np.logical_and(
    ~np.ma.getmaskarray(sentence_relevance_arrays),
    np.ma.getdata(sentence_relevance_arrays) == 0,
)
irrelevant_evidence_flags = sentence_evidence_arrays[irrelevant_label_flags].flatten()
assert irrelevant_evidence_flags.count() == 0
# overall counts/ratios of positive labels (ratios are relative to the number of valid labels)
positive_relevance_labels = np.ma.sum(sentence_relevance_arrays)
positive_evidence_labels = np.ma.sum(sentence_evidence_arrays)
positive_relevance_label_ratio = positive_relevance_labels / sentence_relevance_arrays.count()
positive_evidence_label_ratio = positive_evidence_labels / sentence_evidence_arrays.count()
print(f"total_labels={sentence_relevance_arrays.size}")
print(f"{valid_relevance_labels=}  ({valid_relevance_label_ratio:.1%})")
print(f"{valid_evidence_labels=}  ({valid_evidence_label_ratio:.1%})")
print(f"{positive_relevance_labels=}  ({positive_relevance_label_ratio:.1%})")
print(f"{positive_evidence_labels=}  ({positive_evidence_label_ratio:.1%})")
# per-class counts/ratios of positive labels (one value per class, relative to valid labels only)
class_index = np.arange(len(class_names))
class_pos_relevance_labels = np.ma.sum(sentence_relevance_arrays, axis=0)
class_pos_relevance_ratios = class_pos_relevance_labels / np.ma.count(sentence_relevance_arrays, axis=0)
class_pos_evidence_labels = np.ma.sum(sentence_evidence_arrays, axis=0)
class_pos_evidence_ratios = class_pos_evidence_labels / np.ma.count(sentence_evidence_arrays, axis=0)

In [None]:
# Side-by-side bar charts of the per-class positive/negative label ratios, for relevance (left)
# and for evidence (right). Only one bar group per panel is annotated with its ratio: the negative
# bars in the relevance panel, and the positive bars in the evidence panel.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
bar_width = 0.35

panel_setups = (
    # (axes, positive ratios, y axis label, whether the annotated bars are the positive ones)
    (axs[0], class_pos_relevance_ratios, "Relevance (percentage)", False),
    (axs[1], class_pos_evidence_ratios, "Evidence (percentage)", True),
)
for panel, pos_ratios, ylabel, annotate_positive in panel_setups:
    neg_ratios = 1 - pos_ratios
    pos_bars = panel.bar(class_index, pos_ratios, bar_width, color="blue", label="Positive")
    neg_bars = panel.bar(class_index + bar_width, neg_ratios, bar_width, color="red", label="Negative")
    panel.set_xticks(class_index + bar_width / 2)  # center each tick between the two bars
    panel.set_xticklabels(class_names, rotation=45)
    panel.set_ylabel(ylabel)
    panel.grid(axis="y", alpha=0.75)
    if annotate_positive:
        annotated_bars, annotated_ratios = pos_bars, pos_ratios
    else:
        annotated_bars, annotated_ratios = neg_bars, neg_ratios
    for bar_idx, bar in enumerate(annotated_bars):
        # write the ratio vertically, starting at the middle of the bar
        panel.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() / 2,
            f"{annotated_ratios[bar_idx]:.1%}",
            ha="center",
            va="bottom",
            rotation=90,
        )

# both panels share the same legend entries, so a single figure-level legend is enough
handles, labels = axs[1].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center")
fig.suptitle(f"Distribution of positive/negative labels across {tot_sentence_count} sentences")
plt.tight_layout()
plt.show()