# Statement DataModule Analysis

This notebook analyzes the data loaded by the statement data module. For a simpler demo showing
how to parse the statement dataset, see [this notebook](./data_parsing_demo.ipynb).

This notebook was last updated on 2024-04-02 for framework v0.4.0.

In [None]:
import itertools
import os
import re
import string
import time

import hydra
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import pandas
import torch
import torchmetrics
import tqdm
from openai import OpenAI

import qut01.utils.config
import qut01.utils.logging

In [None]:
# Set up logging for this analysis, then compose the hydra config for the data module.
logger = qut01.utils.logging.setup_logging_for_analysis_script()
data_config_name = "statement_sampler.yaml"
logger.info(f"initializing hydra and fetching data config for '{data_config_name}'...")
# config overrides: select the data config above and set its `classif_setup` and
# `num_criteria` options to the values used throughout this analysis
overrides = [
    f"data={data_config_name}",
    "data.classif_setup=any",
    "data.num_criteria=11",
]
config = qut01.utils.config.init_hydra_and_compose_config(overrides=overrides)
logger.info("initialization complete!")

In [None]:
# Build the datamodule described by the composed config and walk it through the usual
# prepare -> setup -> dataloader sequence (the order of these calls matters).
logger.info(f"Instantiating datamodule: {config.data.datamodule._target_}")  # noqa
datamodule: pl.LightningDataModule = hydra.utils.instantiate(config.data.datamodule)
# fail early if the config did not resolve to a Lightning datamodule
assert isinstance(datamodule, pl.LightningDataModule), f"unexpected type: {type(datamodule)}"
logger.info("running 'datamodule.prepare_data()'...")
datamodule.prepare_data()
logger.info("running 'datamodule.setup()'...")
datamodule.setup(stage="fit")
logger.info("fetching train data loader...")
dataloader = datamodule.train_dataloader()
logger.info("train data loader ready!")

In [None]:
# OpenAI API client used by `gpt_classify` (created with default arguments)
client = OpenAI()
# translation table that deletes every character listed in `string.punctuation`
punct_remover = str.maketrans("", "", string.punctuation)

# patterns for answers that begin with "no"/"yes" followed by a non-alphanumeric character;
# they are applied with `.match`, i.e. anchored at the start of the cleaned-up answer
no_regex = re.compile("no[^0-9a-zA-Z]")
yes_regex = re.compile("yes[^0-9a-zA-Z]")

# name of the chat model to query
openai_model = "gpt-3.5-turbo-0125"
# openai_model = "gpt-4o"

# prompt template; its single `{}` placeholder is filled with the target sentence via `str.format`
prompt = """You are an analyst that inspects modern slavery declarations made by Australian reporting entities. You are specialized in the analysis of statements made with respect to the Australian Modern Slavery Act of 2018, and not of any other legislation.

You are currently looking for sentences in statements that describe the SUPPLY CHAINS of an entity, where supply chains refer to the sequences of processes involved in the procurement of products and services (including labour) that contribute to the reporting entity's own products and services. The description of a supply chain can be related, for example, to 1) the products that are provided by suppliers; 2) the services provided by suppliers, or 3) the location, category, contractual arrangement, or other attributes that describe the suppliers. Any sentence that contains these kinds of information is considered relevant. Descriptions that apply to indirect suppliers (i.e. suppliers-of-suppliers) are considered relevant. Descriptions of the supply chains of entities owned or controlled by the reporting entity making the statement are also considered relevant. However, descriptions of 'downstream' supply chains, i.e. of how customers and clients of the reporting entity use its products or services, are NOT considered relevant. Finally, sentences that describe how the reporting entity lacks information on some of its supply chain, or how some of its supply chains are still unmapped or unidentified, are also considered relevant.

Given the above definitions of what constitutes a relevant sentence, you will need to determine if a target sentence is relevant or not. You must avoid labeling sentences with only vague descriptions or corporate talk (and no actual information) as relevant. The answer you provide regarding whether the sentence is relevant or not can only be 'YES' or 'NO', and nothing else.

The target sentence to classify is the following:
------------
{}
------------

Is the target sentence relevant? (YES/NO)"""
# If YES, tell me why you think it is relevant"""


def gpt_classify(target_sentence, debug=False):
    """Ask the chat model whether a sentence is relevant and parse its YES/NO answer.

    The module-level ``prompt`` template is filled with ``target_sentence`` and sent as a
    single user message to ``openai_model`` through the module-level ``client``.

    Args:
        target_sentence: text of the sentence to classify.
        debug: when true, print the full prompt, the raw answer, and the parsed label.

    Returns:
        A ``(final_result, result, massaged_result)`` tuple. ``final_result`` is 1 for a
        "yes" answer, 0 for a "no" answer, and -1 when the answer could not be parsed.
        ``result`` is the raw message content returned by the API (``None`` if the message
        carried no text). ``massaged_result`` is the lowercased, punctuation-free, stripped
        text that was actually parsed.

    Raises:
        AssertionError: if the API response does not contain exactly one choice.
    """
    current_prompt = prompt.format(target_sentence)

    cp = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": current_prompt,
            }
        ],
        model=openai_model,
    )
    # dump the response only once and reuse the resulting choices list
    choices = cp.model_dump()["choices"]
    assert len(choices) == 1
    result = choices[0]["message"]["content"]

    # the message content is optional in the API response; a missing content is handled as
    # an empty (hence unparsable) answer instead of crashing on `None.lower()`
    massaged_result = (result or "").lower().translate(punct_remover).strip()

    if massaged_result == "no" or no_regex.match(massaged_result):
        final_result = 0
    elif massaged_result == "yes" or yes_regex.match(massaged_result):
        final_result = 1
    else:
        # neither a clear "yes" nor a clear "no": report it and flag the answer as not ok
        final_result = -1
        print(f"GPT result '{result}' not ok (after massage, it is '{massaged_result}')")

    if debug:
        print(current_prompt)
        print(result)
        print(final_result)
        print("\n\n")

    return final_result, result, massaged_result


# display the prompt template that gets sent to the model
print(prompt)

# example sentence for trying out the classifier by hand (see the commented-out call below)
target_sentence = "Baby Bunting Modern Slavery Statement 2020 This statement, pursuant to the Modern Slavery Act 2018 (Cth), describes the risks of modern slavery in the operations and supply chains of Baby Bunting1 and includes information about actions taken to address those risks for the financial year ended 28 June 2020"

# cp = gpt_classify(target_sentence, debug=True)
# print(cp)

In [None]:
def mock_classify(sentence_text, debug=False):
    """Stand-in classifier that performs no query.

    Echoes ``sentence_text`` to stdout and returns a fixed "not classified" triplet with the
    same layout as the output of ``gpt_classify``. ``debug`` is accepted but not used.
    """
    print(sentence_text)
    placeholder_output = (-1, "raw example text", "processed example text")
    return placeholder_output

In [None]:
# Name of the class whose annotations and predictions get compared further down.
class_of_interest = "c2 (supply chains)"

# Pull a single training batch to discover the class names and locate the class of interest.
example_item = next(iter(datamodule.train_dataloader()))
class_names = example_item["class_names"]
amount_of_classes = len(class_names)
class_of_interest_index = class_names.index(class_of_interest)
print(f"class names: {class_names}")
print(f"class index for {class_of_interest} is {class_of_interest_index}")

In [None]:
# Classify every validation sentence for the class of interest, reusing previously saved
# results when available, and always save what was collected (even if a query fails midway).
count_sentence_was_annotated_with_class_of_interest = 0
count_sentence_was_not_annotated_with_class_of_interest = 0
data = []
already_there = 0
annotated = 0
not_ok_results = 0
not_ok_details = []
tot_count = 0

# THESE PARAMETERS SHOULD BE MODIFIED TO REFLECT THE WANTED BEHAVIOUR
max_amount = -1  # 2000  # 1000 # set this to limit the computation
print_count = True
inter_query_wait_time_in_sec = 0.5  # set to 0 to not wait

results_csv_path = f"result_for_class_index_{class_of_interest_index}.csv"
not_ok_csv_path = f"not_ok_for_class_index_{class_of_interest_index}.csv"

# results saved by a previous run act as a cache; on a first run the file does not exist
# yet, in which case every sentence is classified from scratch
prev_df = pandas.read_csv(results_csv_path) if os.path.exists(results_csv_path) else None
# prev_df = None

# predict_fun = mock_classify
predict_fun = gpt_classify

# use debug for GPT
debug = False  # True

try:  # to handle ChatGPT possible exceptions but still save results
    # for item in tqdm.tqdm(itertools.chain(datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader())):
    for item in tqdm.tqdm(datamodule.val_dataloader()):
        for i, sentence_text in enumerate(item["sentence_orig_text"]):
            tot_count += 1
            sentence_statement_id = int(item["statement_id"][i])
            sentence_orig_idxs = item["sentence_orig_idxs"][i]
            assert len(sentence_orig_idxs) == 1
            sentence_orig_idxs = int(sentence_orig_idxs[0])
            text_with_context = item["text"][i]
            assert (
                text_with_context == sentence_text
            ), f"context must be disabled in this experiment. Found '{text_with_context}'"

            # look the sentence up in the cache; the mask is computed a single time and
            # reused for both the membership test and the row selection
            rows = None
            if prev_df is not None:
                cache_mask = (prev_df["sentence_statement_id"].isin([sentence_statement_id])) & (
                    prev_df["sentence_orig_idxs"].isin([sentence_orig_idxs])
                )
                if cache_mask.any():
                    rows = prev_df[cache_mask]

            if rows is not None:
                if len(rows) != 1:
                    print(rows)
                    raise ValueError(f"found multiple rows: {sentence_statement_id} : {sentence_orig_idxs}")
                # cached entry: carry the stored targets/predictions over as-is
                row = rows.iloc[0]
                data.append(
                    [row.sentence_statement_id, row.sentence_orig_idxs, row.target_classes, row.predicted_classes]
                )
                already_there += 1
            else:
                target_classes = [int(x) for x in item["relevance"][i, :]]
                predicted_classes = [-1] * len(class_names)
                # only query the model when the sentence has a target (> -1) for this class
                if target_classes[class_of_interest_index] > -1:
                    predicted_class, raw_result, processed_result = predict_fun(sentence_text, debug=debug)
                    time.sleep(inter_query_wait_time_in_sec)

                    predicted_classes[class_of_interest_index] = predicted_class
                    count_sentence_was_annotated_with_class_of_interest += 1

                    if predicted_class == -1:
                        # the answer could not be parsed: keep the details for later inspection
                        not_ok_results += 1
                        not_ok_details.append([sentence_statement_id, sentence_orig_idxs, raw_result, processed_result])
                else:
                    count_sentence_was_not_annotated_with_class_of_interest += 1
                data.append([sentence_statement_id, sentence_orig_idxs, target_classes, predicted_classes])
                annotated += 1

            if max_amount > -1 and tot_count >= max_amount:
                break
            if print_count:
                print(f"done {tot_count} / {max_amount}")
        if max_amount > -1 and tot_count >= max_amount:
            # report before leaving the outer loop (a print placed after `break` never runs)
            print(f"reached the max amount of {max_amount}")
            break
finally:
    # NOTE: the results file is overwritten with only the rows visited during this run, so
    # cached rows that were not reached (early stop or exception) are dropped from it
    df = pandas.DataFrame(
        data, columns=["sentence_statement_id", "sentence_orig_idxs", "target_classes", "predicted_classes"]
    )
    print(
        f"{already_there} sentences already in cache. {annotated} have been annotated.\nOf these, "
        f"{count_sentence_was_annotated_with_class_of_interest} had annotations for {class_of_interest}, and {count_sentence_was_not_annotated_with_class_of_interest} without."
        f"\nTotal is {count_sentence_was_annotated_with_class_of_interest + count_sentence_was_not_annotated_with_class_of_interest}"
        f"\nNot ok results are {not_ok_results}"
    )
    df.to_csv(results_csv_path)

    not_ok_df = pandas.DataFrame(
        not_ok_details, columns=["statement_id", "sentence_orig_idxs", "raw_result", "processed_result"]
    )
    not_ok_df.to_csv(not_ok_csv_path)

In [None]:
# debugging
# prev_df = pandas.read_csv(f"result_for_class_index_{class_of_interest_index}.csv")
# prev_df[prev_df['sentence_statement_id'].isin([61]) & prev_df['sentence_orig_idxs'].isin([70])]

In [None]:
# Reload the saved results from disk.
df = pandas.read_csv(f"result_for_class_index_{class_of_interest_index}.csv")


def parse_list(list_as_str):
    """Turn the CSV text form of a per-class list (e.g. ``"[0, 1, -1]"``) back into ints.

    The parsed list must contain exactly ``amount_of_classes`` entries.
    """
    inner_text = list_as_str.strip("[]")
    values = list(map(int, inner_text.split(",")))
    assert len(values) == amount_of_classes
    return values


# keep only the entry of the class of interest from each stored per-class list
preds = [parse_list(cell)[class_of_interest_index] for cell in df["predicted_classes"]]
targets = [parse_list(cell)[class_of_interest_index] for cell in df["target_classes"]]
stat_ids = df["sentence_statement_id"]
sent_ids = df["sentence_orig_idxs"]
assert len(preds) == len(targets) == len(stat_ids) == len(sent_ids)

# Predictions equal to -1 are mapped to 0 because (per the original note) torchmetrics
# complains about them otherwise. The targets keep their -1 markers, which the metrics
# skip through `ignore_index=-1`. Unparsable answers (prediction -1 with a valid target)
# therefore count as negative predictions.
fixed_preds = [0 if pred < 0 else pred for pred in preds]

In [None]:
# print(preds)
# print(targets)

In [None]:
# Binary metrics for the class of interest; entries whose target is -1 are ignored.
pred_tensor = torch.tensor(fixed_preds)
target_tensor = torch.tensor(targets)

f1_fun = torchmetrics.classification.F1Score(task="binary", ignore_index=-1)
p_fun = torchmetrics.classification.Precision(task="binary", ignore_index=-1)
r_fun = torchmetrics.classification.Recall(task="binary", ignore_index=-1)
acc_fun = torchmetrics.classification.Accuracy(task="binary", ignore_index=-1)

f1 = f1_fun(pred_tensor, target_tensor)
p = p_fun(pred_tensor, target_tensor)
r = r_fun(pred_tensor, target_tensor)
acc = acc_fun(pred_tensor, target_tensor)

In [None]:
# Report the number of evaluated examples followed by the metric values (3 decimals).
for report_line in (
    f"example amount is {len(preds)}",
    f"precision is {p:.3f}, recall is {r:.3f}, f1 is {f1:.3f}",
    f"accuracy is {acc:.3f}",
):
    print(report_line)

# Text analysis

In [None]:
from collections import defaultdict


def index_data(dataloader):
    """Group the original sentence texts of a dataloader by (statement id, sentence index).

    Returns a ``defaultdict(list)`` keyed by ``(statement_id, sentence_orig_idx)`` whose
    values list every sentence text seen for that key, in iteration order. Each sample must
    refer to exactly one original sentence index.
    """
    texts_by_key = defaultdict(list)

    for batch in tqdm.tqdm(dataloader):
        for pos, text in enumerate(batch["sentence_orig_text"]):
            statement_id = int(batch["statement_id"][pos])
            orig_idxs = batch["sentence_orig_idxs"][pos]
            assert len(orig_idxs) == 1
            key = (statement_id, int(orig_idxs[0]))
            texts_by_key[key].append(text)
    return texts_by_key


# Only the validation split is indexed here; the train/test variants are left disabled.
# train_dataloader_indexed = index_data(datamodule.train_dataloader())
valid_dataloader_indexed = index_data(datamodule.val_dataloader())
# test_dataloader_indexed = index_data(datamodule.test_dataloader())

In [None]:
# Boolean masks (one entry per result row) for each confusion-matrix outcome. Rows with a
# -1 prediction or target fall into none of the four outcome masks.
pred_target_pairs = list(zip(preds, targets))
all_res = [True] * len(preds)
tp = [pair == (1, 1) for pair in pred_target_pairs]
tn = [pair == (0, 0) for pair in pred_target_pairs]
fp = [pair == (1, 0) for pair in pred_target_pairs]
fn = [pair == (0, 1) for pair in pred_target_pairs]

print(f"all res is {len(all_res)} / tp is {sum(tp)} / tn is {sum(tn)} / fp is {sum(fp)}/ fn is {sum(fn)}")


def analyse_results(indexed_data, idxs_to_show, idx_from, idx_to):
    """Print the indexed texts of the flagged result rows inside an index window.

    Row ``i`` is shown when ``idxs_to_show[i]`` is truthy and ``idx_from <= i <= idx_to``
    (both bounds inclusive). Its texts are looked up in ``indexed_data`` under the key
    ``(stat_ids[i], sent_ids[i])``, using the module-level ``stat_ids`` and ``sent_ids``.
    The number of displayed rows is printed at the end.
    """
    shown = 0
    for position, flagged in enumerate(idxs_to_show):
        if position < idx_from or position > idx_to or not flagged:
            continue
        shown += 1
        key = (stat_ids[position], sent_ids[position])
        print(f"{indexed_data[key]}")
    print(f"displayed {shown} examples")


# show the false positives found among result rows 0 to 6 (both bounds inclusive)
analyse_results(valid_dataloader_indexed, fp, 0, 6)

# Duplicate check

In [None]:
def count_duplicates(indexed_data):
    """Print how many indexed texts belong to keys that were seen more than once.

    Args:
        indexed_data: mapping from a key to the list of texts collected for that key
            (the layout produced by ``index_data``).

    Raises:
        AssertionError: if the texts stored under a repeated key differ from each other.
    """
    count_dup = 0
    count_tot = 0
    for texts in indexed_data.values():
        if len(texts) > 1:
            assert all(texts[0] == text for text in texts), "texts sharing a key must be identical"
            count_dup += len(texts)
        count_tot += len(texts)
    # an empty index (or one holding only empty lists) has no texts: report a zero ratio
    # instead of dividing by zero
    ratio = count_dup / count_tot if count_tot else 0.0
    print(f"duplicates are there {count_dup} over {count_tot} ({ratio})")


# Duplicate statistics are reported for the validation index only; the train/test
# variants are left disabled.
# print("train")
# count_duplicates(train_dataloader_indexed)
print("valid")
count_duplicates(valid_dataloader_indexed)
# print("test")
# count_duplicates(test_dataloader_indexed)

In [None]:
import re

# Quick sanity check of the "no" answer pattern: an answer made of "no" followed by a
# non-alphanumeric character must be recognised.
# The names used here are chosen so that this cell does not rebind `string` (the stdlib
# module imported at the top, needed for `string.punctuation`) nor `r` (the recall value).
sanity_regex = re.compile(r"no[^0-9a-zA-Z]")
sanity_answer = "no."
if sanity_answer == "no" or sanity_regex.match(sanity_answer):
    print("ok")
else:
    print("no")