Training Model

In [None]:
import datetime
import os

from azure.ai.formrecognizer import FormRecognizerClient
from azure.ai.formrecognizer import FormTrainingClient
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob import (
    BlobServiceClient,
    ContainerClient,
    BlobSasPermissions,
    ContainerSasPermissions,
    generate_container_sas
)


def create_service_sas_container(container_client: ContainerClient, account_key: str):
    """Generate a service SAS token for a container, signed with the account key.

    The token grants read and list access and is valid for one day, starting
    at the moment of the call (UTC).

    Args:
        container_client: Client for the target container; its account name
            and container name are embedded in the token.
        account_key: Shared access key of the storage account, used to sign
            the token.

    Returns:
        The SAS token produced by ``generate_container_sas``.
    """
    start_time = datetime.datetime.now(datetime.timezone.utc)
    expiry_time = start_time + datetime.timedelta(days=1)

    # A container-level SAS must be described with ContainerSasPermissions:
    # BlobSasPermissions has no `list` flag, so the list permission requested
    # here would never make it into the signed token.
    return generate_container_sas(
        account_name=container_client.account_name,
        container_name=container_client.container_name,
        account_key=account_key,
        permission=ContainerSasPermissions(read=True, list=True),
        expiry=expiry_time,
        start=start_time,
    )


def use_service_sas_container(blob_service_client: BlobServiceClient, container_name: str = "speechtraining"):
    """Build a SAS-authorized URL for a container in the storage account.

    Args:
        blob_service_client: Service client for the storage account. It is
            assumed to have been created with a shared access key, because
            the key is read from ``blob_service_client.credential`` to sign
            the SAS token.
        container_name: Name of the container to grant access to. Defaults
            to "speechtraining".

    Returns:
        The container URL with the SAS token appended as its query string.
    """
    container_client = blob_service_client.get_container_client(container=container_name)
    sas_token = create_service_sas_container(
        container_client=container_client,
        account_key=blob_service_client.credential.account_key,
    )

    # The SAS token string is appended to the resource URL with a "?" delimiter.
    return f"{container_client.url}?{sas_token}"


def main():
    """Train a custom Form Recognizer model from labeled data in blob storage.

    Reads AZURE_STORAGE_KEY, AZURE_FORM_RECOGNIZER_ENDPOINT and
    AZURE_FORM_RECOGNIZER_KEY from the environment, builds a SAS URL for the
    training container, starts a labeled training run, waits for it to finish
    and prints the resulting model's ID, status and timestamps.

    Any exception (including missing configuration) is caught and printed
    rather than propagated.
    """
    try:
        storage_account_name = "rstorage2speech"

        # Fail early with a readable message instead of letting a missing
        # setting surface later as an obscure SDK error.
        required = (
            "AZURE_STORAGE_KEY",
            "AZURE_FORM_RECOGNIZER_ENDPOINT",
            "AZURE_FORM_RECOGNIZER_KEY",
        )
        settings = {name: os.getenv(name) for name in required}
        missing = [name for name in required if not settings[name]]
        if missing:
            raise ValueError("Missing required environment variable(s): " + ", ".join(missing))

        storage_account_key = settings["AZURE_STORAGE_KEY"]
        form_endpoint = settings["AZURE_FORM_RECOGNIZER_ENDPOINT"]
        form_key = settings["AZURE_FORM_RECOGNIZER_KEY"]

        account_url = f"https://{storage_account_name}.blob.core.windows.net"
        blob_service_client_account_key = BlobServiceClient(account_url, credential=storage_account_key)

        trainingDataUrl = use_service_sas_container(blob_service_client=blob_service_client_account_key)
        # Everything after "?" is the SAS token, i.e. a credential. Print only
        # the container URL so the token is not stored in the notebook output.
        print(f"trainingDataUrl : {trainingDataUrl.split('?', 1)[0]}?<sas-token-redacted>")

        # Authenticate Form Training Client
        form_training_client = FormTrainingClient(form_endpoint, AzureKeyCredential(form_key))

        # Train model and block until the long-running operation completes
        poller = form_training_client.begin_training(trainingDataUrl, use_training_labels=True)
        model = poller.result()

        print("Model ID: {}".format(model.model_id))
        print("Status: {}".format(model.status))
        print("Training started on: {}".format(model.training_started_on))
        print("Training completed on: {}".format(model.training_completed_on))

    except Exception as ex:
        print(ex)

# Run the training workflow only when this code executes as the top-level module.
if __name__ == "__main__":
    main()

Test Model

In [6]:
import os 
# from dotenv import load_dotenv

from azure.core.credentials import AzureKeyCredential
from azure.ai.formrecognizer import DocumentAnalysisClient

def main():
    """Analyze a local PDF with a trained custom model and print the results.

    Reads AZURE_FORM_RECOGNIZER_ENDPOINT, AZURE_FORM_RECOGNIZER_KEY and
    AZURE_FORM_RECOGNIZER_MODEL_ID from the environment, submits the test PDF
    to the model, then prints each recognized document's type, confidence and
    fields, followed by the total page count.

    Any exception (including missing configuration) is caught and printed
    rather than propagated.
    """
    try:
        # Fail early with a readable message instead of letting a missing
        # setting surface later as an obscure SDK error.
        required = (
            "AZURE_FORM_RECOGNIZER_ENDPOINT",
            "AZURE_FORM_RECOGNIZER_KEY",
            "AZURE_FORM_RECOGNIZER_MODEL_ID",
        )
        settings = {name: os.getenv(name) for name in required}
        missing = [name for name in required if not settings[name]]
        if missing:
            raise ValueError("Missing required environment variable(s): " + ", ".join(missing))

        form_endpoint = settings["AZURE_FORM_RECOGNIZER_ENDPOINT"]
        form_key = settings["AZURE_FORM_RECOGNIZER_KEY"]
        # Model ID from when you trained your model.
        model_id = settings["AZURE_FORM_RECOGNIZER_MODEL_ID"]

        # Create client using endpoint and key
        document_analysis_client = DocumentAnalysisClient(
            endpoint=form_endpoint, credential=AzureKeyCredential(form_key)
        )

        # Test trained model with a new form read from a local file.
        # (A document in Azure Storage can be analyzed instead by passing its
        # SAS URL to begin_analyze_document_from_url.)
        file_path = "../../data/pdf/test/2308.00479.pdf"
        with open(file_path, "rb") as f:
            poller = document_analysis_client.begin_analyze_document(model_id=model_id, document=f)
            # Wait for the operation to finish while the file is still open.
            result = poller.result()

        for idx, document in enumerate(result.documents or []):
            print("--------Analyzing document #{}--------".format(idx + 1))
            print("Document has type {}".format(document.doc_type))
            print("Document has confidence {}".format(document.confidence))
            print("Document was analyzed by model with ID {}".format(result.model_id))
            for field in document.fields.values():
                # Fall back to the raw content only when no value was parsed;
                # a truthiness test would also discard valid values such as
                # 0, False or an empty string.
                field_value = field.value if field.value is not None else field.content
                print("......found field of type '{}' with value '{}' and with confidence {}".format(field.value_type, field_value, field.confidence))

        # Per-page lines, words and selection marks (result.pages) and table
        # cells (result.tables) are available too; only the page count is shown.
        print(f"Total number of pages: {len(result.pages)}")
        print("-----------------------------------")

    except Exception as ex:
        print(ex)

# Run the model test only when this code executes as the top-level module.
if __name__ == "__main__":
    main()

--------Analyzing document #1--------
Document has type PDFExtract1106beta
Document has confidence 0.964
Document was analyzed by model with ID PDFExtract1106beta
......found field of type 'string' with value 'SELF-RAG: LEARNING TO RETRIEVE, GENERATE, AND CRITIQUE THROUGH SELF-REFLECTION Akari Asait, Zeqiu Wut, Yizhong Wangts, Avirup Sil+, Hannaneh Hajishirzits +University of Washington $Allen Institute for AI #IBM Research AI {akari, zeqiuwu, yizhongw, hannaneh}@cs.washington.edu, avi@us.ibm.com' and with confidence 0.151
......found field of type 'string' with value 'Preprint.' and with confidence None
......found field of type 'string' with value 'ABSTRACT Despite their remarkable capabilities, large language models (LLMs) often produce responses containing factual inaccuracies due to their sole reliance on the paramet- ric knowledge they encapsulate. Retrieval-Augmented Generation (RAG), an ad hoc approach that augments LMs with retrieval of relevant knowledge, decreases such issue