### Import libraries

In [0]:
# Restart this notebook's Python process via Databricks `dbutils`; any Python
# state defined before this point is discarded, so the cells below start fresh.
dbutils.library.restartPython()

In [0]:
from rade import RADE, RetrievedPage, DocumentPage
from utils.azure_doc_intel import parse_azureDocIntell
import time
import torch
from tqdm import tqdm
from typing import List, Tuple, Dict, Any
import os
import pandas as pd


### Initialize RADE

In [0]:
# Ask PyTorch's CUDA caching allocator to release unused cached memory
# before the RADE models are loaded.
import torch
torch.cuda.empty_cache()

# Notebook-wide RADE instance; the helper functions below bind it as their
# default ``model``. Flash attention is turned off. NOTE(review): the meaning
# of ``max_pages`` is defined inside RADE (presumably pages retrieved per
# query) — confirm.
rade = RADE(
    max_pages=3,
    use_flash_attention=False,
)

### Add and index a document

In [0]:
# Indexing helper
def add_index_document(pdf_file: str, model: RADE = rade, doc_id: int = 1):
    """
    Register a PDF with a RADE model, then rebuild that model's index.

    Args:
        pdf_file (str): Path to the PDF document.
        model (RADE): Model that receives the document. Defaults to the
            notebook-level ``rade`` instance bound when this cell ran.
        doc_id (int): Identifier forwarded to ``model.add_document``.

    Returns:
        None
    """
    # Hand the document to the model under the given id...
    model.add_document(pdf_file, doc_id)
    # ...and rebuild the index straight away (presumably so the new
    # document is visible to the next ``retrieve`` call — confirm in RADE).
    model.build_index()



### Search on the index

In [0]:
def search_documents(queries: List[str], model: RADE = rade) -> List[Dict]:
    """
    Retrieve the best-matching indexed pages for each query.

    Args:
        queries (List[str]): Query strings.
        model (RADE): RADE instance to search. Defaults to the notebook-level
            ``rade`` instance bound when this cell ran.

    Returns:
        List: One entry per query, in the same order as ``queries``, holding
        whatever ``model.retrieve`` produced for that query. Callers iterate
        each entry as a collection of retrieved pages (presumably
        ``RetrievedPage`` objects — confirm against RADE).

    Raises:
        RuntimeError: If ``model.retrieve`` does not return exactly one
            result per query.
    """
    # Search the caller-supplied model rather than the global ``rade``.
    search_result = model.retrieve(queries)
    # Callers pair results with queries by position, so a length mismatch
    # would silently misalign them. This is an explicit check rather than an
    # ``assert`` so it still runs under ``python -O``.
    if len(search_result) != len(queries):
        raise RuntimeError(
            f"retrieve returned {len(search_result)} results for "
            f"{len(queries)} queries"
        )
    return search_result

### Canned queries

In [0]:
# Entity-style questions about the parties to the trust.
entity_queries = [
    "GRANTOR: what are the Grantors names? (Grantor)",
    "TRUSTEE: what are the Trustees names? (Trustee)",
    "Successor TRUSTEE: what are the Successor Trustee names (SUCCESSOR TRUSTEE)?",
    "BENEFICIARIES: what are the beneficiaries names (BENEFICIARIES)?",
    "SUCCESSOR BENEFICIARIES: what are the Successor beneficiaries names (SUCCESSOR Beneficiary)? "
]
# GLiNER target labels: one list per query in ``all_queries`` (the 5 entity
# queries followed by the 3 QA queries), matched by position.
labels_list = [
    ["Grantor names"],
    ["Trustee Names"],
    ["Successor Trustee Names"],
    ["Beneficiary names"],
    ["Successor Beneficiary Names"],
    ["Trust Name"],
    ["Trust Date"],
    ["Revocable", "Irrevocable"],
]
# Free-form questions about the trust itself.
qa_queries = [
    "What is the name of this trust?",
    "What is the date of this trust?",
    "Is this trust revocable or irrevocable?",
]
all_queries_dict = {
    "entity_queries": entity_queries,
    "qa_queries": qa_queries
}
all_queries = all_queries_dict["entity_queries"] + all_queries_dict["qa_queries"]

# process_document looks up labels_list[i] for the i-th query, so fail fast
# here if someone edits one list without the other.
if len(labels_list) != len(all_queries):
    raise ValueError(
        f"labels_list has {len(labels_list)} entries but there are "
        f"{len(all_queries)} queries; they must match one-to-one"
    )


### Processing pipeline

In [0]:
def _build_context(pages, contexts_map: Dict[Any, str]) -> Tuple[str, str]:
    """
    Parse (or reuse cached parses of) the pages retrieved for one query.

    Args:
        pages: Retrieval results for a single query; each item exposes
            ``.page.page_num`` and ``.page.image``.
        contexts_map (Dict): Cache of parsed page text keyed by page number.
            It is shared across queries and updated in place.

    Returns:
        Tuple[str, str]: The parsed page texts joined with single spaces, and
        the retrieved page numbers joined with single spaces. Both are empty
        strings when ``pages`` is empty.
    """
    parsed_pages = []
    page_nums = []
    for page in pages:
        page_num = page.page.page_num
        page_nums.append(str(page_num))
        # NOTE(review): the cache is keyed by page number only, which assumes
        # every retrieved page comes from the same document — confirm if
        # several documents are indexed at once.
        if page_num in contexts_map:
            parsed = contexts_map[page_num]
            print(f"Using cached parsed result for page: {page_num}")
        else:
            print(f"Parsing new page: {page_num}")
            parsed = parse_azureDocIntell(page.page.image)
            contexts_map[page_num] = parsed
        parsed_pages.append(parsed)
    return " ".join(parsed_pages), " ".join(page_nums)


def process_document(pdf_files: str,
                     all_queries: List[str],
                     model: RADE = rade) -> pd.DataFrame:
    """
    Answer every query against the indexed document with both the QA pipeline
    and GLiNER entity extraction.

    The document must already have been added and indexed (see
    ``add_index_document``). This function only retrieves, parses and
    extracts.

    Args:
        pdf_files (str): Path of the PDF being processed. It is currently
            unused, because retrieval runs against whatever ``model`` has
            indexed. It is kept for interface compatibility.
        all_queries (List[str]): Queries to run. The i-th query is paired with
            the module-level ``labels_list[i]`` for GLiNER.
        model (RADE): RADE instance used for retrieval, QA and entity
            extraction. Defaults to the notebook-level ``rade`` instance.

    Returns:
        pd.DataFrame: One row per query (outer merge on ``query``) with the
        columns ``query``, ``retrieved_pages_gliner``, ``GLiNER Answer``,
        ``retrieved_pages_qa``, ``context`` and ``RoBerta Answer``.

    Raises:
        ValueError: If ``labels_list`` has fewer entries than ``all_queries``.
            This is checked before any retrieval or parsing work is done.
    """
    if len(labels_list) < len(all_queries):
        raise ValueError(
            f"labels_list has {len(labels_list)} entries but "
            f"{len(all_queries)} queries were given"
        )

    # Retrieve pages for all queries in one call, using the supplied model.
    search_result = search_documents(all_queries, model)
    # Parsed page text keyed by page number, shared by both passes so each
    # page is sent to the parser at most once.
    contexts_map: Dict[Any, str] = {}

    qa_results = []
    for query_idx, query in enumerate(tqdm(all_queries, desc="Processing QA Queries")):
        combined_pages, page_nums = _build_context(search_result[query_idx], contexts_map)
        qa_answer = model.run_qa_pipeline(query, combined_pages)
        qa_results.append({
            "query": query,
            "retrieved_pages": page_nums,
            "context": combined_pages,
            "RoBerta Answer": qa_answer,
        })

    entity_results = []
    for query_idx, query in enumerate(tqdm(all_queries, desc="Processing Entity Queries")):
        labels = labels_list[query_idx]
        combined_pages, page_nums = _build_context(search_result[query_idx], contexts_map)
        entities = model.extract_entities_with_gliner(combined_pages, labels)
        entity_results.append({
            "query": query,
            "retrieved_pages": page_nums,
            "context": combined_pages,
            "GLiNER Answer": entities,
        })

    df_entities = pd.DataFrame(entity_results)
    df_qa = pd.DataFrame(qa_results)

    # Both passes build the same context for a given query (same retrieved
    # pages, same parse cache), so keep one copy. The entity frame's
    # ``context`` is dropped and the QA frame's is kept. ``retrieved_pages``
    # appears on both sides and gets the _gliner / _qa suffixes.
    combined_df = pd.merge(
        df_entities.drop(columns="context", errors="ignore"),
        df_qa,
        on="query",
        suffixes=("_gliner", "_qa"),
        how="outer",
    )
    return combined_df


### Main method

In [0]:
def main(pdf_dir: str = "/Workspace/Shared/pdf-etl-ocr-inference/SampleTrustDocs/pdf",
         file_index: int = 2,
         output_dir: str = "output"):
    """
    Index one PDF from ``pdf_dir``, run all canned queries against it, and
    save the combined results as CSV.

    Args:
        pdf_dir (str): Directory scanned for files ending in ``.pdf``.
        file_index (int): Position, in the sorted list of PDF paths, of the
            file to process. The default of 2 keeps the original choice of
            "the third file".
        output_dir (str): Directory that ``<pdf name>_results.csv`` is
            written to. It is created if missing.

    Returns:
        None. If no PDFs are found, or ``file_index`` is out of range, a
        message is printed and nothing is processed.
    """
    # os.listdir returns entries in arbitrary order, so sort the paths to make
    # the file chosen by ``file_index`` deterministic.
    pdf_files = sorted(
        os.path.join(pdf_dir, f) for f in os.listdir(pdf_dir) if f.endswith(".pdf")
    )
    if not pdf_files:
        print(f"No PDF files found in directory: {pdf_dir}")
        return

    try:
        pdf_path = pdf_files[file_index]
    except IndexError:
        print(f"Only {len(pdf_files)} PDF file(s) in {pdf_dir}; "
              f"index {file_index} is out of range")
        return

    # Combine the canned entity and QA queries.
    queries = all_queries_dict["entity_queries"] + all_queries_dict["qa_queries"]
    # Add and index the document, then run the retrieval/extraction pipeline.
    add_index_document(pdf_path)
    df = process_document(pdf_path, queries)

    os.makedirs(output_dir, exist_ok=True)  # Ensure the output directory exists
    # Name the CSV after the PDF's basename.
    output_filename = os.path.splitext(os.path.basename(pdf_path))[0] + "_results.csv"
    output_file_path = os.path.join(output_dir, output_filename)
    df.to_csv(output_file_path, index=False)

    print(f"Combined DataFrame saved to: {output_file_path}")
    


In [0]:
import time

# Time the whole pipeline end to end with the high-resolution
# performance counter.
start_time = time.perf_counter()
main()
end_time = time.perf_counter()

print(
    "Processing completed.\n"
    f"Execution Time: {end_time - start_time:.4f} seconds"
)