# Statement DataModule Analysis

This notebook showcases how to prepare the train/valid/test partitions, as done for the paper submission.

In [None]:
import itertools
import os
import re
import string
import time

import hydra
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import pandas
import torch
import torchmetrics
import tqdm
from openai import OpenAI

import qut01.utils.config
import qut01.utils.logging

In [None]:
# Logging is configured first so that every later cell can report progress through `logger`.
logger = qut01.utils.logging.setup_logging_for_analysis_script()

# Name of the hydra data config to compose, and the overrides applied on top of it.
data_config_name = "statement_sampler.yaml"
overrides = [f"data={data_config_name}", "data.classif_setup=any", "data.num_criteria=11"]

# Compose the full config; the result is reused by the cells below.
logger.info(f"initializing hydra and fetching data config for '{data_config_name}'...")
config = qut01.utils.config.init_hydra_and_compose_config(overrides=overrides)
logger.info("initialization complete!")

In [None]:
# Build the datamodule described by the composed config and make sure it has the expected type.
_datamodule_cfg = config.data.datamodule
logger.info(f"Instantiating datamodule: {_datamodule_cfg._target_}")  # noqa
datamodule: pl.LightningDataModule = hydra.utils.instantiate(_datamodule_cfg)
assert isinstance(datamodule, pl.LightningDataModule), f"unexpected type: {type(datamodule)}"

# Prepare the data, then set the datamodule up for the "fit" stage.
logger.info("running 'datamodule.prepare_data()'...")
datamodule.prepare_data()
logger.info("running 'datamodule.setup()'...")
datamodule.setup(stage="fit")

# Fetch the training loader.
logger.info("fetching train data loader...")
dataloader = datamodule.train_dataloader()
logger.info("train data loader ready!")

In [None]:
# Walk over the sentences of one data partition, sanity-check every sample, and
# count how many sentences were visited (optionally stopping early).
data = []  # NOTE(review): never filled in this cell; presumably used by later cells -- confirm
tot_count = 0  # number of sentences visited so far
max_amount = -1  # stop once this many sentences were visited; -1 disables the limit
print_data = False  # set to True to print each sentence along with its target classes

try:
    # To iterate on another partition, swap the loader in the `for` statement below for
    # `datamodule.train_dataloader()` or `datamodule.test_dataloader()`, or go through all three with:
    # for item in tqdm.tqdm(itertools.chain(datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader())):
    for item in tqdm.tqdm(datamodule.val_dataloader()):
        for i, sentence_text in enumerate(item["sentence_orig_text"]):
            tot_count += 1
            sentence_statement_id = int(item["statement_id"][i])
            # each sample is expected to map to exactly one original sentence index
            sentence_orig_idxs = item["sentence_orig_idxs"][i]
            assert len(sentence_orig_idxs) == 1
            sentence_orig_idxs = int(sentence_orig_idxs[0])
            text_with_context = item["text"][i]
            # relevance values of the i-th sentence, converted to plain ints
            target_classes = [int(x) for x in item["relevance"][i, :]]
            # the model input text must be the raw sentence itself (no added context)
            assert (
                text_with_context == sentence_text
            ), f"context must be disabled in this experiment. Found '{text_with_context}'"

            if print_data:
                print(f"sentence text: {sentence_text}")
                print(f"target classes: {target_classes}")
                print("\n")

            if max_amount > -1 and tot_count >= max_amount:
                break

        # the inner `break` only leaves the current batch; re-check to stop the outer loop too
        if max_amount > -1 and tot_count >= max_amount:
            # report before breaking: a statement placed after `break` would never run
            print(f"reached the max amount of {max_amount}")
            break
finally:
    # always report the count, even if an assertion above interrupted the walk
    print(f"{tot_count} have been parsed")