In [1]:
import os
import bs4
from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain.output_parsers import PydanticOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain.chat_models import AzureChatOpenAI
from langchain.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import warnings
warnings.filterwarnings(action='ignore')

In [2]:
os.environ["REQUESTS_CA_BUNDLE"] = r"../../ca-bundle-full.crt"

In [3]:
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

In [4]:
llm = AzureChatOpenAI(
    deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
    openai_api_version=os.environ["OPENAI_API_VERSION"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_type=os.environ["OPENAI_API_TYPE"],
    temperature=0
)

In [51]:
class PromptResponse(BaseModel):
    Value:str = Field(description="Return the response as string")
    Result:bool = Field(description="Return true if the above prompt is valid else false")

    def __repr__(self):
        return f"PromptResponse(Value={self.Value}, Result = {self.Result})"

    def __str__(self):
        return self.__repr__()

In [44]:
#Node("In Rahul Dravid's entire career how many runs he scored in Tests", Chain(llm), output_parser).invoke(retriever | format_docs)

In [77]:
def build_tree(llm):
    chain = Chain(llm)
    output_parser = PydanticOutputParser(pydantic_object=PromptResponse)
    tree = Node("Is the document is related to Cricket", chain, output_parser, 
           left = Node("The given document is related to which sports", chain, output_parser), 
           right = Node("Is the document about Rahul Dravid", chain, output_parser, 
                       left = Node("Then which player is the document about", chain, output_parser), 
                       right = Node("In Rahul Dravid's entire career how many runs he scored in Tests", chain, output_parser)))
    return tree

In [46]:
class Chain:
    def __init__(self, llm):
        self.prompt = hub.pull("rlm/rag-prompt")
        self.llm = llm

    def create_rag_chain(self, retriever):
        chain = (
                {"context":  retriever, "question": RunnablePassthrough()}
                | self.prompt
                | self.llm
                )
        return chain
    def create_chain(self, output_parser):
        template = """{instruction}.{format_instruction}"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | self.llm | output_parser
        return chain

In [68]:
class Node(object):
    def __init__(self, instruction: str, chain: Chain, output_parser:PydanticOutputParser, left=None, right=None):
        self.instruction = instruction
        self.chain = chain
        self.left = left
        self.right = right
        self.output_parser = output_parser

    def invoke(self, retriever) -> bool:
        rag_output = self.chain.create_rag_chain(retriever).invoke(self.instruction)
        #print(rag_output)
        template = """{instruction}.{format_instruction}"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = self.chain.create_chain(self.output_parser)
        response = chain.invoke({"instruction": "Based on the previous question return the answer"+"\n"+rag_output.content, "format_instruction":self.output_parser.get_format_instructions()})
        print(response)
        return response.Result

In [74]:
class PlayerInformation:
    def get_document_embeddings(self, url):
        loader = WebBaseLoader(
            web_paths=(url,),
            bs_kwargs=dict(
                parse_only=bs4.SoupStrainer("body")
            ),
        )
        docs = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        embeddings = AzureOpenAIEmbeddings(
            deployment="text-embedding-ada-002"
        )
        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
        retriever = vectorstore.as_retriever()
        return retriever | format_docs

    def analyze(self, tree, retriever):
        root = tree
        while root:
            if root.invoke(retriever):
                root = root.right
            else:
                root = root.left

In [75]:
url = "https://simple.wikipedia.org/wiki/Lionel_Messi"
tree = build_tree(llm)
player_information = PlayerInformation()
#retriever = player_information.get_document_embeddings(url)
#player_information.analyze(tree, retriever)

In [76]:
retriever = player_information.get_document_embeddings(url)
player_information.analyze(tree, retriever)

PromptResponse(Value=Yes, the document is related to Cricket., Result = True)
PromptResponse(Value=Yes, the document is about Rahul Dravid., Result = True)
PromptResponse(Value=13288, Result = True)
