In [1]:
# TODO: Package
import sys
sys.path.append('/home/tomw/unifi-pdf-llm/')

import pandas as pd
from loguru import logger

from load import load_documents
from preprocess import preprocess_documents
from rag import ModularRAG


# Location of the Train.csv ground-truth file read by validate_retrieval.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
TRAIN_CSV_PATH = "/home/tomw/unifi-pdf-llm/data/Train.csv"

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


'Path to the Train.csv file.'

In [12]:
def _select_validation_rows(
        company: str,
        year: int,
        validation_type: str,
        num: int
    ) -> pd.DataFrame:
    """
    Return the Train.csv rows (``ID`` and ``{year}_Value``) to evaluate.

    Rows are restricted to ``company``, then filtered by ``validation_type``
    ("retrieval": ground-truth value present, "nan": ground-truth value missing),
    and finally truncated to the first ``num`` rows. The index is the row position
    within the company's subset of Train.csv.
    """
    value_col = f"{year}_Value"
    train_df = pd.read_csv(TRAIN_CSV_PATH)

    # IDs look like "<AMKEY>_X_<company>". Match the suffix literally (no regex) so
    # one company's name can never pull in another company's rows; NaN IDs are skipped.
    is_company = train_df["ID"].str.endswith(f"_X_{company}", na=False)
    company_df = train_df.loc[is_company, ["ID", value_col]].reset_index(drop=True)

    if validation_type == "retrieval":
        selected_df = company_df[company_df[value_col].notna()]
    else:  # "nan"
        selected_df = company_df[company_df[value_col].isna()]

    return selected_df.head(n=num)


def _values_match(
        generated: pd.Series,
        expected: pd.Series,
        accept_minus_one_as_missing: bool
    ) -> pd.Series:
    """
    Element-wise comparison of generated values against ground truth.

    A row matches when the values are equal or both are missing. When
    ``accept_minus_one_as_missing`` is True, a generated value of -1 also matches a
    missing ground truth (-1 appears to be used as a "not found" answer — confirm
    against ModularRAG.query).
    """
    expected_missing = expected.isna()
    matches = (generated == expected) | (generated.isna() & expected_missing)
    if accept_minus_one_as_missing:
        matches = matches | ((generated == -1) & expected_missing)
    return matches


def validate_retrieval(
        company: str,
        year: int,
        type: str="retrieval",
        num: int=50,
        window_size: int=1,
        discard_text: bool=True
    ) -> tuple[pd.DataFrame, float, float]:
    """
    Run the RAG pipeline on the Train.csv metrics of one company/year and score it.

    Parameters
    ----------
    company : str
        The company to validate. Train.csv rows whose ``ID`` ends with
        ``"_X_{company}"`` are used (e.g. ``"12_X_Tongaat"``).

    year : int
        The year to validate.

    type : str
        The type of validation test to run. Options are "retrieval" or "nan".
        The "retrieval" test checks the retrieval of values that are present in the
        documents. The "nan" test checks the retrieval of values that are not present
        in the documents (i.e. testing the ability to return 'None' when the value is
        not present).

    num : int
        Maximum number of rows to evaluate (the first ``num`` rows that pass the
        ``type`` filter).

    window_size : int
        The size of the sliding window to use when slicing tables.

    discard_text : bool
        If True, discard text passages when preprocessing the documents. Only tables
        are kept.

    Returns
    -------
    results_df : pd.DataFrame
        One row per evaluated metric with columns ``ID``, ``Metric``,
        ``{year}_Value``, ``{year}_Gen_Unvalidated``, ``{year}_Generated`` and
        ``Correct``. Empty (with those columns) if no rows were selected.

    accuracy : float
        Fraction of rows where ``{year}_Generated`` equals ``{year}_Value``, where both
        are missing, or where the generated value is -1 and the ground truth is
        missing. NaN if no rows were selected.

    unvalidated_accuracy : float
        Fraction of rows where ``{year}_Gen_Unvalidated`` equals ``{year}_Value`` or
        both are missing. NaN if no rows were selected.

    Raises
    ------
    ValueError
        If the year is not 2019, 2020, or 2021, or if ``type`` is not "retrieval"
        or "nan".
    """
    if year not in (2019, 2020, 2021):
        raise ValueError(f"Unable to validate year: {year}")
    # Fail fast, before reading the CSV or building the pipeline
    if type not in ("retrieval", "nan"):
        raise ValueError(f"Invalid validation type: {type}")

    value_col = f"{year}_Value"
    unvalidated_col = f"{year}_Gen_Unvalidated"
    generated_col = f"{year}_Generated"
    output_cols = ["ID", "Metric", value_col, unvalidated_col, generated_col, "Correct"]

    rows_df = _select_validation_rows(company, year, type, num)
    if rows_df.empty:
        # Nothing to query: skip the (expensive) pipeline set-up entirely
        logger.warning(f"No '{type}' rows to validate for {company} in {year}")
        return pd.DataFrame(columns=output_cols), float("nan"), float("nan")

    # Load and preprocess the documents
    docs = load_documents(company, year)
    docs = preprocess_documents(
        docs, window_size=window_size, discard_text=discard_text
    )

    logger.debug(f"Number of documents: {len(docs)}")

    query_pipeline = ModularRAG(
        docs=docs,
        company=company,
    )

    # Query the pipeline once per AMKEY (the numeric prefix of the ID)
    metrics, unvalidated_values, generated_values = [], [], []
    for row_id in rows_df["ID"]:
        amkey = int(row_id.split("_")[0])
        metrics.append(query_pipeline.retrieve_metric_description(amkey))
        value, unvalidated_value = query_pipeline.query(amkey, year)
        generated_values.append(value)
        unvalidated_values.append(unvalidated_value)

    index = rows_df.index
    results_df = rows_df.assign(**{
        "Metric": pd.Series(metrics, index=index, dtype=object),
        value_col: rows_df[value_col].astype(float),
        unvalidated_col: pd.Series(unvalidated_values, index=index),
        generated_col: pd.Series(generated_values, index=index).astype(float),
    })
    results_df["Correct"] = _values_match(
        results_df[generated_col], results_df[value_col],
        accept_minus_one_as_missing=True,
    )
    results_df = results_df[output_cols]

    accuracy = results_df["Correct"].mean()
    logger.info(f"Accuracy: {accuracy}")

    unvalidated_accuracy = _values_match(
        results_df[unvalidated_col], results_df[value_col],
        accept_minus_one_as_missing=False,
    ).mean()
    logger.info(f"Unvalidated accuracy: {unvalidated_accuracy}")

    return results_df, accuracy, unvalidated_accuracy


In [14]:
# Tongaat 2021: values that are present in the ground truth (first 10 rows)
results_df, accuracy, unvalidated_accuracy = validate_retrieval(
    "Tongaat",
    2021,
    type="retrieval",
    num=10,
    window_size=2,
)

2024-02-27 20:36:08.354 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2021ESG_removed_sup_table.json
2024-02-27 20:36:10.017 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 678
2024-02-27 20:36:10.018 | INFO     | rag:initialise_document_store:103 - Initialising document store
2024-02-27 20:36:10.023 | INFO     | rag:initialise_retriever:113 - Initialising retriever
Batches: 100%|██████████| 22/22 [00:00<00:00, 22.62it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 10008.02 docs/s]       
2024-02-27 20:36:12.212 | INFO     | rag:initialise_generation_llm:126 - Initialising generation LLM
2024-02-27 20:36:12.212 | INFO     | rag:initialise_unit_conversion_llm:134 - Initialising unit conversion LLM
2024-02-27 20:36:12.213 | INFO     | rag:initialise_json_conversion_llm:142 - Initialising json conversion LLM
2024-02-27 20:36:12.213 | INFO     | rag:initialise_relevant_context_llm:150 - Initialising re

In [15]:
# Per-metric results from the most recent validate_retrieval call in this kernel
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
7,12_X_Tongaat,Total injury frequency rate (TIFR) – employees...,1.331,1.331,1.331,True
18,28_X_Tongaat,Total – company managed/farmed land (owned and...,60204.0,52883.0,52883.0,False
30,49_X_Tongaat,B-BBEE Level,4.0,4.0,,False
33,52_X_Tongaat,Overall Board and Committee meeting attendance,99.0,99.0,99.0,True
64,114_X_Tongaat,Energy efficiency: total direct and indirect e...,16.63,16.63,16.63,True
71,122_X_Tongaat,"Fatal injury frequency rate (FIFR, i.e. number...",0.005,0.005,0.005,True
76,128_X_Tongaat,Carbon emissions – Scope 1,505575.0,505575.0,505575.0,True
77,129_X_Tongaat,Carbon emissions – Scope 2,51539.0,51539.0,51539.0,True
85,138_X_Tongaat,Hazardous waste disposed of at appropriate fac...,184.0,184.0,184.0,True
94,151_X_Tongaat,"Lost time injury frequency rate (LTIFR, i.e. n...",0.093,0.093,0.093,True


In [17]:
# Tongaat 2021: values missing from the ground truth (first 20 rows)
results_df, accuracy, unvalidated_accuracy = validate_retrieval(
    "Tongaat",
    2021,
    type="nan",
    num=20,
    window_size=2,
)

2024-02-27 20:37:45.551 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2021ESG_removed_sup_table.json
2024-02-27 20:37:47.467 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 678
2024-02-27 20:37:47.467 | INFO     | rag:initialise_document_store:103 - Initialising document store
2024-02-27 20:37:47.472 | INFO     | rag:initialise_retriever:113 - Initialising retriever
Batches: 100%|██████████| 22/22 [00:00<00:00, 31.61it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 13829.26 docs/s]       
2024-02-27 20:37:49.358 | INFO     | rag:initialise_generation_llm:126 - Initialising generation LLM
2024-02-27 20:37:49.359 | INFO     | rag:initialise_unit_conversion_llm:134 - Initialising unit conversion LLM
2024-02-27 20:37:49.360 | INFO     | rag:initialise_json_conversion_llm:142 - Initialising json conversion LLM
2024-02-27 20:37:49.360 | INFO     | rag:initialise_relevant_context_llm:150 - Initialising re

In [4]:
# Per-metric results from the most recent validate_retrieval call in this kernel
# NOTE(review): execution count In [4] predates the call above — re-run in order
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
0,3_X_Tongaat,Advisory fees as per income statement,,,,True
1,6_X_Tongaat,Air emissions of the following pollutants: (1) CO,,,,True
2,7_X_Tongaat,Air emissions of the following pollutants: (2)...,,,,True
3,8_X_Tongaat,Air emissions of the following pollutants: (3)...,,,,True
4,9_X_Tongaat,Air emissions of the following pollutants: (4)...,,,,True
5,10_X_Tongaat,Air emissions of the following pollutants: (5)...,,,,True
6,11_X_Tongaat,ALL Administration expenses per income statement,,,,True
8,13_X_Tongaat,"Amount of assets under management, by asset cl...",,,,True
9,14_X_Tongaat,"Amount of assets under management, by asset cl...",,,,True
10,15_X_Tongaat,"Amount of assets under management, by asset cl...",,,,True


In [8]:
# Absa 2021: values present in the ground truth (first 50 rows).
# validate_retrieval returns three values; unpacking two raises ValueError.
results_df, accuracy, unvalidated_accuracy = validate_retrieval(
    "Absa", 2021, type="retrieval", num=50, window_size=2
)

2024-02-27 20:23:21.498 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2022-Absa-Group-limited-Environmental-Social-and-Governance-Data-sheet.json
2024-02-27 20:23:22.894 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 536
2024-02-27 20:23:22.895 | INFO     | rag:initialise_document_store:103 - Initialising document store
2024-02-27 20:23:22.899 | INFO     | rag:initialise_retriever:113 - Initialising retriever
Batches: 100%|██████████| 17/17 [00:00<00:00, 36.85it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 20859.60 docs/s]       
2024-02-27 20:23:24.787 | INFO     | rag:initialise_generation_llm:126 - Initialising generation LLM
2024-02-27 20:23:24.787 | INFO     | rag:initialise_unit_conversion_llm:134 - Initialising unit conversion LLM
2024-02-27 20:23:24.788 | INFO     | rag:initialise_json_conversion_llm:142 - Initialising json conversion LLM
2024-02-27 20:23:24.788 | INFO     | rag:initiali

In [9]:
# Per-metric results from the most recent validate_retrieval call in this kernel
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
27,46_X_Absa,Total procurement spend on qualifying small en...,4400000000.0,4.4,,False
30,49_X_Absa,B-BBEE level (South Africa),1.0,1.0,,False
33,52_X_Absa,Board meeting attendance (%),98.0,98.0,98.0,True
34,53_X_Absa,Average age 40-49 years,3.0,3.0,,False
35,54_X_Absa,Average age 50+,12.0,3.0,,False
59,109_X_Absa,Staff costs and benefits (Rbn),26133000000.0,26133.0,26133.0,False
71,122_X_Absa,Fatal-injury frequency rate (number of fatalit...,0.0,,,False
76,128_X_Absa,Scope 1,12276.0,,,False
77,129_X_Absa,Scope 2,158756.0,,,False
78,130_X_Absa,Scope 3,16205.0,16205.0,,False


In [10]:
# Absa 2021: values missing from the ground truth (first 50 rows).
# validate_retrieval returns three values; unpacking two raises ValueError.
results_df, accuracy, unvalidated_accuracy = validate_retrieval(
    "Absa", 2021, type="nan", num=50, window_size=2
)

2024-02-27 20:29:06.762 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2022-Absa-Group-limited-Environmental-Social-and-Governance-Data-sheet.json
2024-02-27 20:29:08.158 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 536
2024-02-27 20:29:08.158 | INFO     | rag:initialise_document_store:103 - Initialising document store
2024-02-27 20:29:08.163 | INFO     | rag:initialise_retriever:113 - Initialising retriever
Batches: 100%|██████████| 17/17 [00:00<00:00, 25.47it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 14578.50 docs/s]       
2024-02-27 20:29:10.325 | INFO     | rag:initialise_generation_llm:126 - Initialising generation LLM
2024-02-27 20:29:10.326 | INFO     | rag:initialise_unit_conversion_llm:134 - Initialising unit conversion LLM
2024-02-27 20:29:10.326 | INFO     | rag:initialise_json_conversion_llm:142 - Initialising json conversion LLM
2024-02-27 20:29:10.327 | INFO     | rag:initiali

In [11]:
# Per-metric results from the most recent validate_retrieval call in this kernel
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
0,3_X_Absa,Advisory fees as per income statement,,,,True
1,6_X_Absa,Air emissions of the following pollutants: (1) CO,,187237.0,,True
2,7_X_Absa,Air emissions of the following pollutants: (2)...,,,,True
3,8_X_Absa,Air emissions of the following pollutants: (3)...,,,,True
4,9_X_Absa,Air emissions of the following pollutants: (4)...,,,,True
5,10_X_Absa,Air emissions of the following pollutants: (5)...,,,,True
6,11_X_Absa,ALL Administration expenses per income statement,,26722.0,,True
7,12_X_Absa,All Inury Frequency Rate (Injuries/1m hrs worked),,,,True
8,13_X_Absa,"Amount of assets under management, by asset cl...",,3.35,,True
9,14_X_Absa,"Amount of assets under management, by asset cl...",,,,True
