In [None]:
%autosave 300
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
%config Completer.use_jedi = False

In [None]:
import os

# Switch the kernel's working directory and print it for confirmation.
# Presumably this is the project root that the relative paths used later in the
# notebook (the dotenv file, the pickled documents) are resolved against.
# NOTE(review): this is a hard-coded, machine-specific absolute path; the
# notebook will fail here on any other machine.
os.chdir(
    "/mnt/batch/tasks/shared/LS_root/mounts/clusters/copilot-model-run/code/Users/Soutrik.Chowdhury/unstructured_data_experiments"
)
print(os.getcwd())

In [None]:
import logging
import os
import re
import urllib.parse

import numpy as np
from joblib import delayed, Parallel, parallel_backend

from glob import glob

from functools import partial
from tenacity import retry, stop_after_attempt
from typing import Any, Dict, List, Union
import asyncio
import nest_asyncio

nest_asyncio.apply()  # Fixing asyncio bug with Jupyter Notebook

from redisvl.index import SearchIndex
from redis import Redis
from urllib.parse import quote
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag
from dotenv import find_dotenv, load_dotenv
from langchain_openai import AzureOpenAIEmbeddings
from redisvl.index import AsyncSearchIndex
from redis.asyncio import Redis

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import PyPDFium2Loader
from langchain_text_splitters import CharacterTextSplitter
import pickle
from redisvl.utils.rerank import HFCrossEncoderReranker
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_openai import AzureChatOpenAI
from langchain.schema import Document

In [None]:
logger = logging.getLogger(__name__)

In [None]:
# Name of the dotenv file that supplies the settings later read through
# os.environ.get / os.getenv (Azure OpenAI and Redis configuration).
env_file = "dev.env"

# Locate the dotenv file and load its values into the process environment.
load_dotenv(find_dotenv(env_file))

In [None]:
class OpenAIEmbeddingFunctions:
    """
    Factory for the Azure OpenAI embedding client used to embed documents.

    Every constructor argument defaults to ``None``; a ``None`` value is
    resolved from the matching environment variable when the instance is
    created (not when the class is defined), so values loaded into the
    environment after this cell has run are still picked up.

    Attributes:
        api_key (str): The API key of the OpenAI model (AZURE_OPENAI_API_KEY).
        api_base (str): The API base of the OpenAI model (AZURE_OPENAI_ENDPOINT).
        api_type (str): The API type of the OpenAI model (OPENAI_API_TYPE).
            Stored for reference only; it is not passed to the client.
        api_version (str): The API version of the OpenAI model (OPENAI_API_VERSION).
        model_name (str): Passed to the client as ``model``.
        model_deployment_name (str): Passed to the client as ``azure_deployment``.

    """

    def __init__(
        self,
        api_key: Union[str, None] = None,
        api_base: Union[str, None] = None,
        api_type: Union[str, None] = None,
        api_version: Union[str, None] = None,
        model_name: Union[str, None] = None,
        model_deployment_name: Union[str, None] = None,
    ) -> None:
        def _resolve(value: Union[str, None], env_name: str) -> Union[str, None]:
            # An explicit argument wins; otherwise read the environment now.
            return value if value is not None else os.environ.get(env_name)

        self.api_key = _resolve(api_key, "AZURE_OPENAI_API_KEY")
        self.api_base = _resolve(api_base, "AZURE_OPENAI_ENDPOINT")
        self.api_type = _resolve(api_type, "OPENAI_API_TYPE")
        self.api_version = _resolve(api_version, "OPENAI_API_VERSION")
        # NOTE(review): the model name is read from the *_DEPLOYMENT_NAME variable
        # and the deployment name from the *_MODEL_NAME variable. The mapping is
        # kept exactly as it was, but it looks crossed - confirm against dev.env.
        self.model_name = _resolve(model_name, "EMBEDDING_ENGINE_ADA_DEPLOYMENT_NAME")
        self.model_deployment_name = _resolve(
            model_deployment_name, "EMBEDDING_ENGINE_ADA_MODEL_NAME"
        )

    def get_openai_embedder(self):
        """
        Build the embedding client from the settings stored on this instance.

        Takes no arguments: model, deployment, key and endpoint all come from
        the instance attributes. The API version is forwarded only when one is
        known, so the client's own default still applies otherwise.

        Returns:
            An AzureOpenAIEmbeddings instance for embedding documents.
        """
        client_kwargs = {
            "model": self.model_name,
            "azure_deployment": self.model_deployment_name,
            "api_key": self.api_key,
            "azure_endpoint": self.api_base,
        }
        if self.api_version is not None:
            client_kwargs["api_version"] = self.api_version

        return AzureOpenAIEmbeddings(**client_kwargs)

In [None]:
def get_gpt_model(
    azure_deployment,
    model_name,
    api_key,
    azure_endpoint,
    openai_api_type,
    api_version,
    temperature,
    request_timeout,
    max_retries,
    seed,
    top_p,
):
    """
    Build the Azure chat model client from the given settings.

    Every argument is forwarded unchanged, under the same keyword name, to
    the AzureChatOpenAI constructor.

    Args:
    - azure_deployment (str): Azure deployment name.
    - model_name (str): Name of the model.
    - api_key (str): API key.
    - azure_endpoint (str): Azure endpoint.
    - openai_api_type (str): OpenAI API type.
    - api_version (str): API version.
    - temperature (float): Temperature for sampling.
    - request_timeout (int): Request timeout.
    - max_retries (int): Maximum number of retries.
    - seed (int): Seed for random number generator.
    - top_p (float): Top-p sampling.

    Returns:
    - AzureChatOpenAI: The configured chat model instance.
    """
    client_settings = dict(
        azure_deployment=azure_deployment,
        model_name=model_name,
        api_key=api_key,
        azure_endpoint=azure_endpoint,
        openai_api_type=openai_api_type,
        api_version=api_version,
        temperature=temperature,
        request_timeout=request_timeout,
        max_retries=max_retries,
        seed=seed,
        top_p=top_p,
    )
    return AzureChatOpenAI(**client_settings)

In [None]:
# Read documents from Documents folder
class DynamicDocumentSplitter:
    """
    Load parsed documents and split them into chunks.

    The chunk size is derived from the average word count of the documents,
    clamped to the range [min_word_count, max_word_count]; the overlap is
    ``overlap_fraction`` of the chosen size.

    Attributes:
        doc_folder_path (str): Path of the pickle file holding the documents.
        split_type (str): "pages" (return the documents unsplit), "recursive"
            or "character".
        min_word_count (int): Lower bound for the chunk size.
        overlap_fraction (float): Fraction of the chunk size used as overlap.
        max_word_count (int): Upper bound for the chunk size.
        documents (list): Optional pre-loaded documents; when non-empty they
            are used instead of reading the pickle file.
    """

    def __init__(
        self,
        doc_folder_path: str,
        split_type: str,
        min_word_count: int,
        overlap_fraction: float,
        max_word_count: int,
        documents: list = None,
    ):

        self.doc_folder_path = doc_folder_path
        self.split_type = split_type
        self.min_word_count = min_word_count
        self.overlap_fraction = overlap_fraction
        self.max_word_count = max_word_count
        self.documents = documents

    def get_document(self):
        """Return the pre-loaded documents, or load them from the pickle file.

        Only entries that are ``Document`` instances are kept when loading
        from disk.
        """
        if self.documents:
            return self.documents

        # NOTE(review): unpickling can execute arbitrary code; only point
        # doc_folder_path at files produced by a trusted pipeline.
        # The context manager closes the file handle even if loading fails.
        with open(self.doc_folder_path, "rb") as pickle_file:
            docs = pickle.load(pickle_file)

        return [doc for doc in docs if isinstance(doc, Document)]

    def get_dynamic_chunk_details(self, doc):
        """Determine dynamic chunk size and overlap for document splitting.

        Args:
            doc: Iterable of objects exposing a ``page_content`` string.

        Returns:
            dict: ``{"chunk_size": int, "chunk_overlap": int}``.
        """
        print("Calculating chunk details.")

        word_counts = [
            len(re.findall(r"\w+", doc_element.page_content)) for doc_element in doc
        ]
        # With no documents there is nothing to average (np.mean would yield
        # NaN and int() would raise); treat the average as 0 so the minimum
        # chunk size is used.
        dynamic_chunk_sz = int(np.mean(word_counts)) if word_counts else 0

        if dynamic_chunk_sz < self.min_word_count:
            chunk_size = self.min_word_count
        elif dynamic_chunk_sz > self.max_word_count:
            chunk_size = self.max_word_count
        else:
            chunk_size = dynamic_chunk_sz

        return {
            "chunk_size": chunk_size,
            "chunk_overlap": int(chunk_size * self.overlap_fraction),
        }

    def create_chunks_docs(self, docs, chunk_details):
        """Create document chunks based on the specified splitting method.

        Args:
            docs: Documents to split.
            chunk_details (dict): Output of ``get_dynamic_chunk_details``.

        Returns:
            The chunks produced by the splitter's ``split_documents``.

        Raises:
            ValueError: If ``split_type`` is neither "recursive" nor "character".
        """
        print("Creating document chunks.")

        if self.split_type == "recursive":
            text_splitter = RecursiveCharacterTextSplitter
        elif self.split_type == "character":
            text_splitter = CharacterTextSplitter
        else:
            raise ValueError("Invalid split type.")

        # NOTE(review): chunk_size/chunk_overlap were computed in words, but the
        # splitters are built without a length_function, so they presumably
        # measure length in characters - confirm this is intended.
        splitter = text_splitter(
            chunk_size=chunk_details["chunk_size"],
            chunk_overlap=chunk_details["chunk_overlap"],
            add_start_index=True,
        )
        return splitter.split_documents(docs)

    def get_chunked_docs(
        self,
    ):
        """Return the documents, split according to ``split_type``.

        With ``split_type == "pages"`` the documents are returned as loaded.
        """
        org_docs = self.get_document()
        if self.split_type == "pages":
            return org_docs

        chunk_details = self.get_dynamic_chunk_details(org_docs)
        chunk_docs = self.create_chunks_docs(org_docs, chunk_details)
        return chunk_docs

In [None]:
# Inputs for DynamicDocumentSplitter.
# Despite its name, doc_folder_path points at a single pickle file of parsed documents.
doc_folder_path = "extracted_docs/Levva__Original__CW194187 -  Levva MSA - Fully Executed 4_images/extracted_docs_all_parsed.pkl"
# "pages" makes get_chunked_docs return the loaded documents unsplit, so the
# three sizing values below are not used in this configuration.
split_type = "pages"
min_word_count = 1000
overlap_fraction = 0.1
max_word_count = 5000

In [None]:
# Load the parsed documents (and split them, depending on split_type).
doc_splitter = DynamicDocumentSplitter(
    doc_folder_path, split_type, min_word_count, overlap_fraction, max_word_count
)
chunks_docs = doc_splitter.get_chunked_docs()

In [None]:
print(len(chunks_docs))

In [None]:
# Configuration for converting documents to records and for the Redis vector index.
# added_metadata fields become "tag" fields in the index schema; static_metadata
# fields (and the embedded text field) become "text" fields.
added_metadata = ["entity_name", "file_version"]
static_metadata = [
    "source",
    "page_number",
    "file_name",
    "token_size",
    "timestamp",
]
open_ai_embedder = OpenAIEmbeddingFunctions().get_openai_embedder()
vector_field_name = "embeddings"  # record key that will hold the embedding
embedding_field_name = "page_content"  # record key whose text gets embedded
embedding_dimension = 1536  # must match the embedding model's output size - TODO confirm for the deployed model
distance_metric = "cosine"
vector_algo = "hnsw"
drop_index = True
llm_re_ranking = False  # when True, RedisvlVectorDB loads a cross-encoder reranker
# Redis connection settings; values read from the environment are strings, or
# None when the variable is not set.
redis_host = os.environ.get("REDIS_HOST")
redis_port = os.environ.get("REDIS_PORT")
redis_database = os.environ.get("REDIS_DATABASE")
redis_password = os.environ.get("REDIS_PASSWORD")
index_name = os.environ.get("REDIS_INDEX")

In [None]:
index_name

In [None]:
def document_to_hash(
    all_docs: list, added_metadata: list, static_metadata: list
) -> list:
    """Convert the documents to a hash for indexing as expected by the Vector DB model.

    Each document becomes a flat dict holding its ``page_content`` plus the
    requested metadata values, each converted to a lower-cased string.

    Args:
        all_docs (list): Objects exposing ``page_content`` and a ``metadata`` dict.
        added_metadata (list): Metadata keys to copy (applied after the static
            ones, so they win on a name clash).
        static_metadata (list): Metadata keys to copy.

    Returns:
        list: One dict per document.

    Raises:
        KeyError: If a requested metadata key is missing from a document.
    """
    hash_docs = []
    for doc in all_docs:
        record = {"page_content": doc.page_content}
        # Both metadata groups go through str() so non-string values (page
        # numbers, numeric versions, ...) do not fail on .lower().
        for meta in list(static_metadata) + list(added_metadata):
            record[f"{meta}"] = str(doc.metadata[f"{meta}"]).lower()
        hash_docs.append(record)
    return hash_docs

#### Embedding the documents

In [None]:
class DocumentEmbedder:
    """
    Add embedding vectors to content dictionaries using an embedding client.

    Attributes:
        open_ai_embedder: Client exposing an async ``aembed_documents`` method.
        max_concurrent_tasks (int): Upper bound on embedding requests that run
            at the same time in ``full_embedder``.
    """

    def __init__(self, open_ai_embedder, max_concurrent_tasks: int = 10):
        self.open_ai_embedder = open_ai_embedder
        self.max_concurrent_tasks = max_concurrent_tasks

    async def content_embedder(
        self,
        content: str,
        op_type: str = "bytes",
    ) -> Union[bytes, np.ndarray]:
        """
        Embed the content using the OpenAI embeddings to create embeddings for one unit.

        Args:
            content (str): The content to embed.
            op_type (str): "bytes" returns the float32 vector serialised with
                ``tobytes()``; any other value returns the float32 array itself.

        Returns:
            bytes | np.ndarray: The embedding in the requested representation.
        """
        embedding_vector = await self.open_ai_embedder.aembed_documents([content])
        embd_array = np.array(embedding_vector[0]).astype(np.float32)

        if op_type == "bytes":
            return embd_array.tobytes()

        return embd_array

    # NOTE(review): the retry policy has no exception filter, so the TypeError
    # raised below for a non-dict input is also attempted five times.
    @retry(stop=stop_after_attempt(5))
    async def atomic_embedder(
        self,
        content_dict: dict,
        vector_field_name: str,
        embedding_field_name: str,
    ) -> dict:
        """
        Embed one content dictionary and store the vector on it.

        The dictionary is updated in place and also returned.

        Args:
            content_dict (dict): The content dictionary.
            vector_field_name (str): The name of the vector field.
            embedding_field_name (str): The name of the embedding field in the document.

        Returns:
            dict: The updated content dictionary with the embedding.

        Raises:
            TypeError: If ``content_dict`` is not a dictionary.
        """
        if not isinstance(content_dict, dict):
            logger.error(f"Content is not a dictionary, but {type(content_dict)}")
            raise TypeError(f"Content is not a dictionary, but {type(content_dict)}")

        item_content = content_dict[embedding_field_name]
        embedding_vector = await self.content_embedder(item_content)
        content_dict[vector_field_name] = embedding_vector
        return content_dict

    async def full_embedder(
        self,
        content_dict: list,
        vector_field_name: str,
        embedding_field_name: str,
    ) -> list:
        """
        Atomic loading of each content dictionary with the embedding field.
        This method is supposed to change as per different use cases.

        Embedding is best effort: an item that still fails after the retries
        does not abort the batch. Its slot in the returned list holds the
        exception instead of a dictionary, and the failure is logged, so
        callers should check the entries before indexing them.

        Args:
            content_dict (list): The list of content dictionaries. A single
                dictionary is also accepted and treated as a one-item list.
            vector_field_name (str): The name of the vector field.
            embedding_field_name (str): The name of the embedding field in the document.

        Returns:
            list: One entry per input, in input order: the updated dictionary,
            or the exception raised for that item.

        Raises:
            TypeError: If ``content_dict`` is neither a dictionary nor a list.
        """
        if isinstance(content_dict, dict):
            # Iterating a dict yields its keys, not records, so wrap a single
            # record into a list before fanning out.
            records = [content_dict]
        elif isinstance(content_dict, list):
            records = content_dict
        else:
            logger.error(
                f"Content is not a dictionary or list, but {type(content_dict)}"
            )
            raise TypeError(
                f"Content is not a dictionary or list, but {type(content_dict)}"
            )

        semaphore = asyncio.Semaphore(self.max_concurrent_tasks)

        async def sem_task(content):
            # The semaphore caps how many embedding calls are in flight.
            async with semaphore:
                return await self.atomic_embedder(
                    content, vector_field_name, embedding_field_name
                )

        tasks = [sem_task(content) for content in records]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Make failures visible instead of letting them pass silently.
        for position, outcome in enumerate(results):
            if isinstance(outcome, BaseException):
                logger.error(f"Embedding failed for item {position}: {outcome!r}")

        return results

In [None]:
# The object to embed the documents
hash_docs = document_to_hash(chunks_docs, added_metadata, static_metadata)

In [None]:
hash_docs[0]

In [None]:
doc_embedder = DocumentEmbedder(open_ai_embedder, max_concurrent_tasks=10)

In [None]:
embedded_hash_docs = await doc_embedder.full_embedder(
    hash_docs,
    vector_field_name,
    embedding_field_name,
)

In [None]:
print(len(embedded_hash_docs))

In [None]:
embedded_hash_docs[10]

In [None]:
from typing import Union, List, Dict, Optional
from urllib.parse import quote
import uuid

In [None]:
redis_url_params = {
    "redis_host": redis_host,
    "redis_port": redis_port,
    "redis_database": redis_database,
    "redis_password": redis_password,
}

In [None]:
# Display the connection settings with the password masked, so the secret is
# not written into the notebook's saved output.
{
    **redis_url_params,
    "redis_password": "***" if redis_url_params.get("redis_password") else None,
}

In [None]:
def get_redis_url(
    redis_host: str,
    redis_port: int,
    redis_database: str,
    redis_password: Union[str, None],
) -> str:
    """
    Create a redis uri from given args
    redis://[username:user_pwd@]name_of_host [:port_number_of_redis_server] [/DB_Name]
    redis://[[username]:[password]]@localhost:6379/0

    Args:
        redis_host (str): Host name of the Redis server.
        redis_port (int): Port of the Redis server.
        redis_database (str): Database name/number placed in the URL path.
        redis_password (str | None): Password; when empty or None the URL has
            no credentials, otherwise the user name "default" is used.

    Returns:
        str: The Redis connection URL.
    """
    location = f"{redis_host}:{redis_port}/{redis_database}"

    if not redis_password:
        return f"redis://{location}"

    # safe="" also percent-encodes "/", which quote() leaves alone by default;
    # an unescaped "/" or "@" in the password would corrupt the URL.
    return f"redis://default:{quote(redis_password, safe='')}@{location}"

In [None]:
class RedisVectorStoreRetriever(VectorStoreRetriever):
    """Retriever for Redis VectorStore.

    Delegates to the vector store's ``retrieve_from_index`` ("similarity") or
    ``retrieve_from_index_score_threshold`` ("similarity_score_threshold"),
    passing ``search_kwargs`` as keyword arguments.
    """

    vectorstore: Any
    """Redis VectorStore."""
    search_type: str = "similarity"
    """Type of search to perform. Can be either
    'similarity',
    'similarity_score_threshold',
    """

    search_kwargs: Dict[str, Any] = {
        "num_results": 5,
        "score_threshold": 0.5,
        "extracted_ner_heads": {},
    }

    allowed_search_types = [
        "similarity",
        "similarity_score_threshold",
    ]

    class Config:
        arbitrary_types_allowed = True

    async def _aget_relevant_documents(
        self,
        query: str,
    ) -> List[Dict]:
        """Get relevant documents from the VectorStore.

        Args:
            query (str): The user query.

        Returns:
            List[Dict]: Whatever the vector store's retrieval method returns.

        Raises:
            ValueError: If ``search_type`` is not one of the supported values.
        """
        if self.search_type == "similarity":
            # Plain similarity search takes no threshold. Filter it out of a
            # copy instead of popping it, so the retriever's configuration is
            # not permanently altered by a query.
            similarity_kwargs = {
                key: value
                for key, value in self.search_kwargs.items()
                if key != "score_threshold"
            }
            docs = await self.vectorstore.retrieve_from_index(
                query, **similarity_kwargs
            )

        elif self.search_type == "similarity_score_threshold":
            docs = await self.vectorstore.retrieve_from_index_score_threshold(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    def _get_relevant_documents(
        self,
        query: str,
    ) -> List[Dict]:
        """Synchronous wrapper that runs the async retrieval to completion.

        Uses ``asyncio.run``; calling it from inside an already running event
        loop (as in a notebook) relies on ``nest_asyncio.apply()``.
        """
        return asyncio.run(self._aget_relevant_documents(query))

In [None]:
class RedisvlVectorDB:
    """Redis Vector DB model for storing and retrieving documents.

    Wraps a redisvl ``AsyncSearchIndex``: builds the index schema from the
    configured field names, uploads embedded records, and runs vector queries
    with optional tag filters and optional cross-encoder re-ranking.
    """

    def __init__(
        self,
        static_metadata: list,
        added_metadata: list,
        embedding_field_name: str,
        vector_field_name: str,
        embedding_dimension: int,
        distance_metric: str,
        vector_algo: str,
        index_name: str,
        doc_embedder: DocumentEmbedder,
        llm_re_ranking: bool,
        redis_connection: Optional[Redis] = None,
        redis_url_params: Optional[Dict[str, Union[str, int]]] = None,
    ):
        """Initialize the Redis Vector DB model.

        Args:
            static_metadata (list): Record keys indexed as "text" fields.
            added_metadata (list): Record keys indexed as "tag" fields.
            embedding_field_name (str): Record key holding the embedded text.
            vector_field_name (str): Record key holding the embedding vector.
            embedding_dimension (int): Dimensionality of the vectors.
            distance_metric (str): Distance metric of the vector field.
            vector_algo (str): Vector index algorithm.
            index_name (str): Name of the Redis search index.
            doc_embedder (DocumentEmbedder): Used to embed queries.
            llm_re_ranking (bool): Whether to re-rank results with a cross-encoder.
            redis_connection: Existing async Redis client to reuse, if any.
            redis_url_params: Keyword arguments for ``get_redis_url``; used only
                when no ``redis_connection`` is given.
        """
        self.static_metadata = static_metadata
        self.added_metadata = added_metadata
        self.embedding_field_name = embedding_field_name
        self.vector_field_name = vector_field_name
        self.embedding_dimension = embedding_dimension
        self.distance_metric = distance_metric
        self.vector_algo = vector_algo
        self.index_name = index_name
        self.doc_embedder = doc_embedder
        self.llm_re_ranking = llm_re_ranking
        self.cross_encoder_reranker = None

        if self.llm_re_ranking:
            self._initialize_reranker()

        # Define both attributes on every path so later code can rely on them.
        self.redis_connection = redis_connection if redis_connection else None
        self.redis_url = None
        if self.redis_connection is None and redis_url_params:
            self.redis_url = get_redis_url(**redis_url_params)

        self.schema = self._create_schema()

    def _initialize_reranker(self):
        """Initialize the cross-encoder reranker for re-ranking."""

        print("Using the cross-encoder reranker for re-ranking.")
        self.cross_encoder_reranker = HFCrossEncoderReranker(
            "cross-encoder/ms-marco-MiniLM-L-6-v2", limit=5
        )

    def _create_schema(
        self,
    ):
        """Schema creation for vector DB index.

        Returns:
            dict: Schema with text fields (static metadata plus the embedded
            text field), tag fields (added metadata) and one float32 vector field.
        """
        text_fields = [
            {"name": name, "type": "text"}
            for name in self.static_metadata + [self.embedding_field_name]
        ]
        tag_fields = [{"name": name, "type": "tag"} for name in self.added_metadata]
        vector_field = {
            "name": self.vector_field_name,
            "type": "vector",
            "attrs": {
                "dims": self.embedding_dimension,
                "distance_metric": self.distance_metric,
                "algorithm": self.vector_algo,
                "datatype": "float32",
            },
        }

        schema = {
            "index": {
                "name": f"{self.index_name}",
                # NOTE(review): the key prefix is random per instance. Records
                # loaded through update_redis_index from a new instance will
                # therefore presumably not fall under the prefix of an index
                # created by an earlier instance - confirm this is intended.
                "prefix": uuid.uuid4().hex,
            },
            "fields": text_fields + tag_fields + [vector_field],
        }

        return schema

    async def get_redis_index(
        self,
    ):
        """Get async-redis index connection.

        Returns:
            AsyncSearchIndex: Index built from the schema and bound to a client.

        Raises:
            ValueError: If neither a connection nor URL parameters were supplied.
        """
        print(f"Schema:{self.schema}")

        if self.redis_connection is not None:
            print("Using the existing redis connection")
            client = self.redis_connection
        else:
            if not self.redis_url:
                raise ValueError(
                    "No Redis connection available: pass redis_connection or redis_url_params."
                )
            print("Establishing redis connection")
            # NOTE(review): a new client is created on every call in this branch.
            client = Redis.from_url(self.redis_url)

        index = AsyncSearchIndex.from_dict(self.schema)
        await index.set_client(client)
        return index

    async def _delete_index(self, index: AsyncSearchIndex):
        """Delete the index (best effort).

        Returns:
            bool: True when the delete call succeeded, False when it raised.
        """
        try:
            await index.delete()
        except Exception as e:
            # Keep the best-effort behaviour, but record the real reason: the
            # failure may be a connection problem rather than a missing index.
            logger.warning(f"Could not drop index {self.index_name}: {e!r}")
            print("No index present to drop")
            return False
        return True

    async def drop_index(self):
        """Drop the index if it exists."""

        index = await self.get_redis_index()
        if await index.exists():
            # Only report success when the delete actually went through.
            if await self._delete_index(index):
                print(f"Deleted the index")
        return None

    @staticmethod
    def _validate_records(embedding_hashes: list):
        """Reject batches containing non-dict entries (e.g. failed embeddings)."""
        invalid = [
            position
            for position, record in enumerate(embedding_hashes)
            if not isinstance(record, dict)
        ]
        if invalid:
            raise TypeError(
                f"embedding_hashes must contain only dictionaries; invalid entries at positions {invalid}"
            )

    async def _load_docs(self, index: AsyncSearchIndex, embedding_hashes: list):
        """Load list of dicts to the index and return the created keys."""
        keys = await index.load(embedding_hashes)
        return keys

    async def create_redis_index_upload(
        self, embedding_hashes: list, drop_index: bool = True
    ):
        """Create a new redis index and upload the documents from the embedding hashes.

        Args:
            embedding_hashes (list): Records (dicts) including the vector field.
            drop_index (bool): Whether to delete an existing index first.

        Returns:
            The keys returned by the index load.

        Raises:
            TypeError: If any entry of ``embedding_hashes`` is not a dict.
        """
        # Validate before touching Redis so a bad batch cannot leave the
        # existing index dropped with nothing uploaded.
        self._validate_records(embedding_hashes)

        index = await self.get_redis_index()
        if drop_index:
            await self._delete_index(index)

        # create the index and upload the documents
        await index.create(overwrite=True)

        keys = await self._load_docs(index, embedding_hashes)
        print(f"created redis index and uploaded {len(keys)} records")
        return keys

    async def update_redis_index(self, embedding_hashes: list):
        """Update the redis index with the new documents.

        Returns:
            The keys returned by the index load, or None when the index does
            not exist.

        Raises:
            TypeError: If any entry of ``embedding_hashes`` is not a dict.
        """
        self._validate_records(embedding_hashes)

        index = await self.get_redis_index()
        if await index.exists():
            updated_keys = await self._load_docs(index, embedding_hashes)
            print(f"Appended to redis index and uploaded {len(updated_keys)} records")
            return updated_keys
        else:
            logger.error("No index present to update")
            return None

    async def _curate_query(
        self,
        query: str,
        num_results: int,
    ):
        """Embed the query text and build the vector query for the search."""

        # Any op_type other than "bytes" makes the embedder return the array.
        query_vector = await self.doc_embedder.content_embedder(query, "vectors")
        query_search = VectorQuery(
            vector=query_vector,
            vector_field_name=self.vector_field_name,
            return_fields=self.static_metadata
            + ["vector_distance", self.embedding_field_name]
            + self.added_metadata,
            num_results=num_results,
            return_score=True,
        )
        return query_search

    async def _conditional_filters(self, extracted_ner_heads: dict):
        """Build one tag filter that requires every given key/value pair.

        Keys and values are lower-cased; values are converted to strings first.

        Returns:
            The combined filter expression, or None for an empty mapping.
        """
        full_condition = None
        for tag_name, tag_value in extracted_ner_heads.items():
            condition = Tag(tag_name.lower()) == str(tag_value).lower()
            full_condition = (
                condition if full_condition is None else full_condition & condition
            )

        return full_condition

    async def _search_index_results(
        self,
        index: AsyncSearchIndex,
        query: str,
        num_results: int,
        extracted_ner_heads: dict,
    ):
        """Search the index for the query and return the results."""

        query_search = await self._curate_query(
            query=query,
            num_results=num_results,
        )
        if extracted_ner_heads:
            print(
                "Extracted NER heads present in the query and applying conditional filters."
            )
            full_condition = await self._conditional_filters(
                extracted_ner_heads=extracted_ner_heads
            )
            query_search.set_filter(full_condition)

        # search the index for the query and return the results
        results = await index.query(query_search)
        return results

    async def _rerank_results(
        self,
        results: list,
        query: str,
    ):
        """Rerank the results using the cross-encoder.

        Args:
            results (list): Search results; each needs an "id" and the
                embedded text field.
            query (str): The user query.

        Returns:
            list: ``({"id": ...}, {"content": ..., "score": ...})`` tuples for
            the re-ranked documents whose score is above 0, in reranker order.
        """
        if not results:
            return []

        content_field = self.embedding_field_name
        re_rank_raw_docs = [r[content_field] for r in results]

        # The reranker returns documents in score order, not input order, so
        # ids must be recovered through the content rather than by position.
        # A list per content keeps duplicates distinguishable.
        ids_by_content = {}
        for r in results:
            ids_by_content.setdefault(r[content_field], []).append(r["id"])

        reranked_docs, scores = await self.cross_encoder_reranker.arank(
            query, re_rank_raw_docs
        )

        filter_op_ls = []
        for reranked_doc, score in zip(reranked_docs, scores):
            if not score > 0:
                continue
            content = reranked_doc["content"]
            pending_ids = ids_by_content.get(content)
            if not pending_ids:
                # Content that cannot be traced back to a result has no id to report.
                continue
            filter_op_ls.append(
                ({"id": pending_ids.pop(0)}, {"content": content, "score": score})
            )
        return filter_op_ls

    async def _re_rank_selection(self, reranked_results, results):
        """Keep the search results whose id survived re-ranking (original order)."""

        # selecting ids from the reranked results
        final_results_ids = {r[0]["id"] for r in reranked_results}
        # filter the results based on the final list that we got from reranking results
        final_results = [r for r in results if r["id"] in final_results_ids]
        return final_results

    async def retrieve_from_index(
        self,
        query: str,
        num_results: int,
        extracted_ner_heads: dict,
    ):
        """Retrieve the results from the index.

        Args:
            query (str): The user query.
            num_results (int): Number of nearest neighbours to request.
            extracted_ner_heads (dict): Tag name -> value filters; may be empty.

        Returns:
            list: The search results, filtered by the reranker when enabled.
        """

        # get the index connection from redis
        index = await self.get_redis_index()
        # search the index for the query and return the results
        redis_results = await self._search_index_results(
            index,
            query,
            num_results,
            extracted_ner_heads,
        )
        if self.llm_re_ranking:
            print("Reranking the results using the cross-encoder.")
            reranked_results = await self._rerank_results(redis_results, query)

            # select the ids from the reranked results and filter the results based on the final list
            final_selected_results = await self._re_rank_selection(
                reranked_results, redis_results
            )
            return final_selected_results

        return redis_results

    async def retrieve_from_index_score_threshold(
        self,
        query: str,
        num_results: int,
        extracted_ner_heads: dict,
        score_threshold: float,
    ):
        """Retrieve the results from the index based on the score threshold.

        A result is kept when ``1 - vector_distance`` is strictly greater than
        ``score_threshold``.
        NOTE(review): treating ``1 - distance`` as a similarity presumably only
        holds for the cosine metric - confirm before using another metric.
        """

        retrieved_results = await self.retrieve_from_index(
            query,
            num_results,
            extracted_ner_heads,
        )
        filtered_results = [
            r
            for r in retrieved_results
            if (1 - float(r["vector_distance"])) > score_threshold
        ]

        return filtered_results

    def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
        """Return a RedisVectorStoreRetriever object bound to this store."""
        return RedisVectorStoreRetriever(vectorstore=self, **kwargs)

In [None]:
# Build the vector DB wrapper from the configuration cells above. No existing
# client is passed, so the connection URL is derived from redis_url_params.
redis_vector_db = RedisvlVectorDB(
    static_metadata,
    added_metadata,
    embedding_field_name,
    vector_field_name,
    embedding_dimension,
    distance_metric,
    vector_algo,
    index_name,
    doc_embedder,
    llm_re_ranking,
    redis_connection=None,
    redis_url_params=redis_url_params,
)

In [None]:
# index = await redis_vector_db.get_redis_index()

In [None]:
# Create the index and upload the documents
# Create the index and upload the documents.
# drop_index is passed as a literal True here; the `drop_index` variable from
# the configuration cell is not used.
new_keys = await redis_vector_db.create_redis_index_upload(
    embedded_hash_docs, drop_index=True
)
print(len(new_keys))

In [None]:
# # Update new documents on the same index
# update_keys = await redis_vector_db.update_redis_index(embedded_hash_docs)
# print(len(update_keys))

#### Retriever from the Redis Index

In [None]:
# await redis_vector_db.drop_index()

In [None]:
query = "Which parties are involved in the MSA contracts?"
# query = "What is the expected YOE for the role of Strategy Consultant in this agreement?"

In [None]:
# result_op = await redis_vector_db.retrieve_from_index(query, 5, {"contractor_name": "Accenture"})

In [None]:
# result_op

In [None]:
# result_op2 = await redis_vector_db.retrieve_from_index_score_threshold(query, 5, {"contractor_name": "Accenture"}, 0.8)
# result_op2

In [None]:
# Wrap the vector DB in a retriever that keeps only results whose
# (1 - vector_distance) is above score_threshold; no tag filters are applied
# because extracted_ner_heads is empty.
retriever = redis_vector_db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "num_results": 5,
        "score_threshold": 0.75,
        "extracted_ner_heads": {},
    },
)

In [None]:
sample_op = await retriever._aget_relevant_documents(
    query
)  # This is the async method to get the results

In [None]:
# retriever._get_relevant_documents(query) # This is the sync method to get the results

#### LangChain summarizer chain with the custom retriever and LCEL

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import (
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.prompts import ChatPromptTemplate

In [None]:
system_message = """
    You are an AI assistant specialized in answering questions based on the provided documents. Your task is to answer the questions based on the documents provided.
    The related documents are provided under the '<<<Documents>>>' context and the question is provided under the '<<<Question>>>' context.
    You can use the documents to answer the questions.
    The final answer should be a clear and concise answer to the question with the relevant information extracted from the documents provided under the '<<<Documents>>>' context.
    The final answer should also be grammatically correct and should be in complete sentences and should be relevant to the question asked.
    At the end of the answer, please provide a short and concise summary of the answer.
    <<<Documents>>>: Context of the documents provided using which the question should be answered. It contains the relevant information extracted from the documents and the page numbers at the start of each content.Page numbers are also indication of the start of a new document content.
    <<<Formating_Instructions>>>: Formatting instructions for the answer.
    <<<Question>>>: The question that needs to be answered using the documents provided.
    """

human_message = """
    <<<Documents>>>:\n
    {context}
    <<<Formating_Instructions>>>:\n
    - While summarizing the answer, please ensure that the answer is clear, concise, and relevant to the question asked. The answer should be in complete sentences and should be grammatically correct.
    - At the end of every sentence in the final summary, please provide the only the page number from where the information is extracted as citations.
    <<<Question>>>:\n
    {question}

    Answer: 
    """

In [None]:
def prompt_generator(
    system_message: str = "", human_message: str = ""
) -> ChatPromptTemplate:
    """Build a two-message chat prompt from a system and a human template.

    Args:
        system_message (str): Template text for the system message.
        human_message (str): Template text for the human message.

    Returns:
        ChatPromptTemplate: Prompt with the system message followed by the
        human message.
    """
    system_part = SystemMessagePromptTemplate.from_template(system_message)
    human_part = HumanMessagePromptTemplate.from_template(human_message)
    return ChatPromptTemplate.from_messages([system_part, human_part])

In [None]:
qna_prompt = prompt_generator(system_message, human_message)

In [None]:
def format_docs(docs):
    """Format the documents for the prompt.

    Args:
        docs: Iterable of mappings with "page_number" and "page_content" keys.

    Returns:
        str: One "PageNumber:/PageContent:" section per document, separated by
        blank lines; an empty string when there are no documents.
    """
    # Interpolation (rather than "+") keeps this working when page_number or
    # page_content is not already a string, e.g. an integer page number.
    return "\n\n".join(
        f"\nPageNumber:{doc['page_number']}\nPageContent:{doc['page_content']}"
        for doc in docs
    )

In [None]:
# print(format_docs(sample_op))

In [None]:
# Chat model used to answer questions. Connection settings come from the
# environment; temperature 0, a fixed seed and a very small top_p are
# presumably chosen to keep the answers as repeatable as possible.
llm_model = get_gpt_model(
    azure_deployment=os.getenv("CHAT_ENGINE_GPT4_DEPLOYMENT_NAME"),
    model_name=os.getenv("CHAT_ENGINE_GPT4_MODEL_NAME"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    openai_api_type=os.getenv("OPENAI_API_TYPE"),
    api_version=os.getenv("OPENAI_API_VERSION"),
    temperature=0.0,
    request_timeout=45,
    max_retries=5,
    seed=1234,
    top_p=0.0001,
)

In [None]:
# RAG pipeline: the input question is sent to the retriever, whose results are
# rendered by format_docs into {context}, and is also passed through unchanged
# as {question}; the filled prompt goes to the chat model and the reply is
# parsed to a plain string.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | qna_prompt
    | llm_model
    | StrOutputParser()
)

In [None]:
op = await rag_chain.ainvoke(query)

In [None]:
print(op)

#################-------------------------------------------------------------############################