In [1]:
# TODO: Package
import sys
sys.path.append('/home/tomw/unifi-pdf-llm/')

import pandas as pd
from loguru import logger

from load import load_documents
from preprocess import preprocess_documents
from rag import ModularRAG


TRAIN_CSV_PATH = "/home/tomw/unifi-pdf-llm/data/Train.csv"
"""Path to the Train.csv file."""

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


'Path to the Train.csv file.'

In [2]:
def validate_retrieval(
        company: str,
        year: int,
        type: str="retrieval",
        num: int=50,
        window_size: int=1,
        discard_text: bool=True
    ) -> tuple[pd.DataFrame, float]:
    """
    Returns a DataFrame with the results of the retrieval validation.

    Parameters
    ----------
    company : str
        The company to validate.

    year : int
        The year to validate.

    type : str
        The type of validation test to run. Options are "retrieval" or "nan".
        The "retrieval" test checks the retrieval of values that are present in the
        documents. The "nan" test checks the retrieval of values that are not present
        in the documents (i.e. testing the ability to return 'None' when the value is
        not present).

    window_size : int
        The size of the sliding window to use when slicing tables.

    discard_text : bool
        If True, discard text passages when preprocessing the documents. Only tables
        are kept.

    Returns
    -------
    results_df : pd.DataFrame
        The results of the retrieval validation.

    accuracy : float
        The accuracy of the retrieval validation.

    Raises
    ------
    ValueError
        If the year is not 2019, 2020, or 2021.
    """
    if year not in [2019, 2020, 2021]:
        raise ValueError(f"Unable to validate year: {year}")

    train_df = pd.read_csv(TRAIN_CSV_PATH)

    # Restrict to the company
    train_df = train_df[train_df["ID"].str.contains(f"X_{company}")]
    train_df.reset_index(drop=True, inplace=True)

    # Drop the two columns that we are not interested in
    all_years = ["2021", "2020", "2019"]
    all_years.remove(str(year))
    for _year in all_years:
        train_df.drop(columns=[f"{_year}_Value"], inplace=True)

    if type == "retrieval":
        train_df = train_df.dropna(subset=[f"{year}_Value"], how="all")
    elif type == "nan":
        train_df = train_df[train_df[f"{year}_Value"].isna()]
        # Keep a random sample rows
        # train_df = train_df.sample(n=50)
    else:
        raise ValueError(f"Invalid validation type: {type}")

    train_df = train_df.head(n=num)

    # Load and preprocess the documents
    docs = load_documents(company, year)
    docs = preprocess_documents(
        docs, window_size=window_size, discard_text=discard_text
    )

    logger.debug(f"Number of documents: {len(docs)}")

    query_pipeline = ModularRAG(
        docs=docs,
        company=company,
    )

    results_df = train_df.copy(deep=True)

    # Loop over the rows in the dataframe and retrieve the value for each AMKEY
    for idx, row in train_df.iterrows():
        amkey = int(row["ID"].split("_")[0])

        metric = query_pipeline.retrieve_metric_description(amkey)
        results_df.at[idx, "Metric"] = metric

        value, unvalidated_value = query_pipeline.query(amkey, year)
        results_df.at[idx, f"{year}_Generated"] = value
        results_df.at[idx, f"{year}_Gen_Unvalidated"] = unvalidated_value

    results_df[f"{year}_Value"] = results_df[f"{year}_Value"].astype(float)
    results_df[f"{year}_Generated"] = results_df[f"{year}_Generated"].astype(float)
    results_df["Correct"] = results_df.apply(
        lambda row: (row[f"{year}_Generated"] == row[f"{year}_Value"]) or
        (pd.isna(row[f"{year}_Generated"]) and pd.isna(row[f"{year}_Value"])) or
        (row[f"{year}_Generated"] == -1 and pd.isna(row[f"{year}_Value"])),
        axis=1
    )

    # Reordering the columns
    results_df = results_df[["ID", "Metric", f"{year}_Value", f"{year}_Gen_Unvalidated", f"{year}_Generated", "Correct"]]

    accuracy = results_df["Correct"].sum() / len(results_df)

    logger.info(f"Accuracy w/ validation: {accuracy}")

    unvalidated_accuracy = results_df.apply(
        lambda row: (row[f"{year}_Gen_Unvalidated"] == row[f"{year}_Value"]) or
        (pd.isna(row[f"{year}_Gen_Unvalidated"]) and pd.isna(row[f"{year}_Value"])),
        axis=1
    ).sum() / len(results_df)

    logger.info(f"Accuracy w/o validation: {unvalidated_accuracy}")

    return results_df, accuracy, unvalidated_accuracy


## Tongaat 

In [3]:
results_df, accuracy, unvalidated_accuracy = validate_retrieval("Tongaat", 2021, type="retrieval", num=10, window_size=2)

2024-02-27 21:39:58.827 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2021ESG_removed_sup_table.json


2024-02-27 21:40:00.683 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 678
2024-02-27 21:40:00.684 | INFO     | rag:_initialise_document_store:116 - Initialising document store
2024-02-27 21:40:00.690 | INFO     | rag:_initialise_retriever:134 - Initialising retriever
Batches: 100%|██████████| 22/22 [00:01<00:00, 15.20it/s]ocs/s]
Documents Processed: 10000 docs [00:01, 6769.59 docs/s]        
2024-02-27 21:40:04.817 | INFO     | rag:initialise_unit_conversion_llm:143 - Initialising unit conversion LLM
2024-02-27 21:40:04.922 | INFO     | rag:_initialise_mappings:170 - Initialising mappings
2024-02-27 21:40:04.927 | DEBUG    | rag:query:195 - Retrieving AMKEY: 12
2024-02-27 21:40:04.927 | DEBUG    | rag:query:197 - Retrieving metric: Total injury frequency rate (TIFR) – employees and contractors
Batches: 100%|██████████| 1/1 [00:00<00:00, 210.53it/s]
2024-02-27 21:40:05.239 | DEBUG    | rag:retrieve_value:261 - Retrieval prompt:

Use the following markdown tables to 

In [4]:
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
7,12_X_Tongaat,Total injury frequency rate (TIFR) – employees...,1.331,1.331,1.331,True
18,28_X_Tongaat,Total – company managed/farmed land (owned and...,60204.0,52883.0,52883.0,False
30,49_X_Tongaat,B-BBEE Level,4.0,4.0,4.0,True
33,52_X_Tongaat,Overall Board and Committee meeting attendance,99.0,99.0,99.0,True
64,114_X_Tongaat,Energy efficiency: total direct and indirect e...,16.63,16.63,,False
71,122_X_Tongaat,"Fatal injury frequency rate (FIFR, i.e. number...",0.005,0.005,0.005,True
76,128_X_Tongaat,Carbon emissions – Scope 1,505575.0,505575.0,505575.0,True
77,129_X_Tongaat,Carbon emissions – Scope 2,51539.0,51539.0,51539.0,True
85,138_X_Tongaat,Hazardous waste disposed of at appropriate fac...,184.0,184.0,184.0,True
94,151_X_Tongaat,"Lost time injury frequency rate (LTIFR, i.e. n...",0.093,0.093,0.093,True


In [5]:
results_df, accuracy, unvalidated_accuracy = validate_retrieval("Tongaat", 2021, type="nan", num=20, window_size=2)

2024-02-27 21:43:57.758 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2021ESG_removed_sup_table.json
2024-02-27 21:43:59.722 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 678
2024-02-27 21:43:59.722 | INFO     | rag:_initialise_document_store:116 - Initialising document store
2024-02-27 21:43:59.728 | INFO     | rag:_initialise_retriever:134 - Initialising retriever
Batches: 100%|██████████| 22/22 [00:00<00:00, 23.34it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 10281.25 docs/s]       
2024-02-27 21:44:02.029 | INFO     | rag:initialise_unit_conversion_llm:143 - Initialising unit conversion LLM
2024-02-27 21:44:02.030 | INFO     | rag:_initialise_mappings:170 - Initialising mappings
2024-02-27 21:44:02.034 | DEBUG    | rag:query:195 - Retrieving AMKEY: 3
2024-02-27 21:44:02.035 | DEBUG    | rag:query:197 - Retrieving metric: Advisory fees as per income statement
Batches: 100%|██████████| 1/1 [

In [6]:
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
0,3_X_Tongaat,Advisory fees as per income statement,,,,True
1,6_X_Tongaat,Air emissions of the following pollutants: (1) CO,,,,True
2,7_X_Tongaat,Air emissions of the following pollutants: (2)...,,,,True
3,8_X_Tongaat,Air emissions of the following pollutants: (3)...,,,,True
4,9_X_Tongaat,Air emissions of the following pollutants: (4)...,,,,True
5,10_X_Tongaat,Air emissions of the following pollutants: (5)...,,,,True
6,11_X_Tongaat,ALL Administration expenses per income statement,,,,True
8,13_X_Tongaat,"Amount of assets under management, by asset cl...",,,,True
9,14_X_Tongaat,"Amount of assets under management, by asset cl...",,,,True
10,15_X_Tongaat,"Amount of assets under management, by asset cl...",,,,True


## ABSA

In [3]:
results_df, accuracy, unvalidated_accuracy = validate_retrieval("Absa", 2021, type="retrieval", num=50, window_size=2)

2024-02-29 21:52:56.089 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2022-Absa-Group-limited-Environmental-Social-and-Governance-Data-sheet.json
2024-02-29 21:52:57.298 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 536
2024-02-29 21:52:57.299 | INFO     | rag:_initialise_document_store:131 - Initialising document store
2024-02-29 21:52:57.303 | INFO     | rag:_initialise_retriever:149 - Initialising retriever
Batches: 100%|██████████| 17/17 [00:01<00:00, 12.94it/s]ocs/s]
Documents Processed: 10000 docs [00:01, 7494.32 docs/s]        
2024-02-29 21:53:01.682 | INFO     | rag:_initialise_mappings:177 - Initialising mappings
2024-02-29 21:53:01.687 | DEBUG    | rag:query:202 - Retrieving AMKEY: 46
2024-02-29 21:53:01.688 | DEBUG    | rag:query:204 - Retrieving metric: Total procurement spend on qualifying small enterprises and exempt micro enterprises(Rbn)
Batches: 100%|██████████| 1/1 [00:00<00:00, 

ValueError: could not convert string to float: '18,900,000,000'

In [4]:
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
27,46_X_Absa,Total procurement spend on qualifying small en...,4400000000.0,4.4,4400000000.0,True
30,49_X_Absa,B-BBEE level (South Africa),1.0,1.0,,False
33,52_X_Absa,Board meeting attendance (%),98.0,98.0,98.0,True
34,53_X_Absa,Average age 40-49 years,3.0,3.0,,False
35,54_X_Absa,Average age 50+,12.0,61.0,,False
59,109_X_Absa,Staff costs and benefits (Rbn),26133000000.0,26133.0,26133.0,False
71,122_X_Absa,Fatal-injury frequency rate (number of fatalit...,0.0,0.0,,False
76,128_X_Absa,Scope 1,12276.0,,,False
77,129_X_Absa,Scope 2,158756.0,12.24,,False
78,130_X_Absa,Scope 3,16205.0,16205.0,,False


In [3]:
COMPANY = "Absa"
YEAR = 2021

docs = load_documents(COMPANY, YEAR)
docs = preprocess_documents(
    docs, window_size=2, discard_text=True
)

2024-02-29 22:33:18.213 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2022-Absa-Group-limited-Environmental-Social-and-Governance-Data-sheet.json


In [4]:
AMKEY = 575

query_pipeline = ModularRAG(
    docs=docs,
    company=COMPANY,
)

validated_value, unvalidated_value = query_pipeline.query(AMKEY, YEAR)

print(f'Retrieved values: {validated_value}, {unvalidated_value}')

2024-02-29 22:33:19.463 | INFO     | rag:_initialise_document_store:131 - Initialising document store
2024-02-29 22:33:19.468 | INFO     | rag:_initialise_retriever:149 - Initialising retriever
Batches: 100%|██████████| 17/17 [00:01<00:00, 13.23it/s]ocs/s]
Documents Processed: 10000 docs [00:01, 7661.33 docs/s]        
2024-02-29 22:33:23.728 | INFO     | rag:_initialise_mappings:177 - Initialising mappings
2024-02-29 22:33:23.732 | DEBUG    | rag:query:202 - Retrieving AMKEY: 575
2024-02-29 22:33:23.734 | DEBUG    | rag:query:204 - Retrieving metric: Total procurement spend in South Africa (Rbn)
Batches: 100%|██████████| 1/1 [00:00<00:00, 117.72it/s]
2024-02-29 22:33:24.041 | DEBUG    | rag:retrieve_value:266 - Retrieval prompt:

Use the following markdown tables to as context to answer the question at the end.
The answer must be a value retrieved directly from the context. Please don't do any unit conversion.

It is possible that the answer is not explicitly stated in the context.
If

Retrieved values: 18900000000.0, 18.9


In [5]:
results_df, accuracy, validated_accuracy = validate_retrieval("Absa", 2021, type="nan", num=50, window_size=2)

2024-02-29 21:46:38.467 | INFO     | load:load_documents:62 - Loading documents from /home/tomw/unifi-pdf-llm/data/azureconverter_outputs/2022-Absa-Group-limited-Environmental-Social-and-Governance-Data-sheet.json
2024-02-29 21:46:39.889 | DEBUG    | __main__:validate_retrieval:79 - Number of documents: 536
2024-02-29 21:46:39.890 | INFO     | rag:_initialise_document_store:131 - Initialising document store
2024-02-29 21:46:39.894 | INFO     | rag:_initialise_retriever:149 - Initialising retriever
Batches: 100%|██████████| 17/17 [00:00<00:00, 26.88it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 15330.11 docs/s]       
2024-02-29 21:46:42.661 | INFO     | rag:_initialise_mappings:177 - Initialising mappings
2024-02-29 21:46:42.666 | DEBUG    | rag:query:202 - Retrieving AMKEY: 3
2024-02-29 21:46:42.667 | DEBUG    | rag:query:204 - Retrieving metric: Advisory fees as per income statement
Batches: 100%|██████████| 1/1 [00:00<00:00, 202.57it/s]
2024-02-29 21:46:42.693 | DEBUG    | rag:

In [6]:
results_df

Unnamed: 0,ID,Metric,2021_Value,2021_Gen_Unvalidated,2021_Generated,Correct
0,3_X_Absa,Advisory fees as per income statement,,,,True
1,6_X_Absa,Air emissions of the following pollutants: (1) CO,,187237.0,,True
2,7_X_Absa,Air emissions of the following pollutants: (2)...,,,,True
3,8_X_Absa,Air emissions of the following pollutants: (3)...,,,,True
4,9_X_Absa,Air emissions of the following pollutants: (4)...,,,,True
5,10_X_Absa,Air emissions of the following pollutants: (5)...,,,,True
6,11_X_Absa,ALL Administration expenses per income statement,,7407.0,,True
7,12_X_Absa,All Inury Frequency Rate (Injuries/1m hrs worked),,0.0,,True
8,13_X_Absa,"Amount of assets under management, by asset cl...",,3.35,,True
9,14_X_Absa,"Amount of assets under management, by asset cl...",,3.75,,True
