Aim - Create a RAG based LLM to fetch accurate information
Steps - classify query, get context based on the question, generate answer, evaluate the answer

In [6]:
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain_mistralai import MistralAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from typing import TypedDict, List, Annotated
from langgraph.graph import StateGraph, START, END
from langchain_core.documents import Document



In [None]:
api_key = <your api_key>

In [None]:
class State(TypedDict):
    query : str
    query_type : str
    context : str
    answer : str
    score : str

In [None]:
def classify_query(state : State):
    """Classifies a user query as a generic query or a specific topic related.
    In this case the Transformer Model"""
    
    llm = ChatMistralAI(api_key=api_key, model = "mistral-large-latest")
    system_prompt = """You are an expert in classifying a given query as - generic or transformer-related.
                        If the query is anything about the Transformer Model, then classify it as 'transformer-related' else 'generic'.
                    """

    query_type = llm.invoke(input = [("system",system_prompt),
                                   ("human",state["query"])]).content
    
    return {"query_type" : query_type, "query" : state["query"], "context": "", "answer":"", "score" : ""}

In [None]:
def get_context(state : State):
    """ Gets the Context from the internal document using RAG"""
    api_key = 'Eajkd7toYyYCEoU1LQiNFcPTvyK3ONep'
    doc_loader = PyPDFLoader("transformers.pdf")
    doc = doc_loader.load()

    doc_splitter = RecursiveCharacterTextSplitter()
    doc_chunk = doc_splitter.split_documents(doc)

    embedding = MistralAIEmbeddings(api_key=api_key)
    vector_store = FAISS.from_documents(doc_chunk, embedding)
    retriever = vector_store.as_retriever(search_kwargs={"k":1})
    context = retriever.invoke(input=state["query"])

    return { "query_type" : state["query_type"], "query" : state["query"], "context": context, "answer":"", "score" : ""}


In [None]:
def generate(state : State):
    """Generate the answer based on the given input and context"""
    prompt = f"""With the given context answer the query.
                    ALWAYS append your answer with 'USING RAG'

                    <context>
                    {state["context"]}
                    </context>
                
                    query : {state["query"]}"""
    
    llm = ChatMistralAI(api_key=api_key, model = "mistral-large-latest", max_tokens = 100)
    response = llm.invoke(input = [("system",prompt),
                                    ("human",state["query"])]).content
    return {"query_type" : state["query_type"], "query" : state["query"], "context": state["context"], "answer":response, "score" : ""}
    # return {"query" : state["query"], "answer":response}
    # return query, answer


In [None]:
def evaluate(state : State):
    """Evaluate the answer using another LLM as an evaluator"""
    system_prompt = f"""With the given query and answer, rate the answer from 1-5. 5 is the highest score.
                    {state["query"]}
                    {state["answer"]}
                    """
    llm = ChatMistralAI(api_key=api_key, model='mistral-large-latest')
    score = llm.invoke(input = [("system",system_prompt),
                                   ("human",f"""{state["query"]}
                                    {state["answer"]}""")])
    return {"query_type" : state["query_type"], "query" : state["query"], "context": state["context"], "answer": state["answer"], "score" : score.content}


In [None]:
def router(state : State):
    """Route to the next step. In case of generic query simply answer and avoid fetching context"""
    if 'transformer-related' in state["query_type"].lower():
        return "get_context"
    else:
        return END        

In [None]:
## Initialize the graph to orchestrate the steps
graph = StateGraph(State)
graph.add_node("classify_query", classify_query)
graph.add_node("get_context", get_context)
graph.add_node("generate",generate)
graph.add_node("evaluate", evaluate)

<langgraph.graph.state.StateGraph at 0x1076bc970>

In [None]:
## Add edges
graph.add_edge(START,"classify_query")
graph.add_conditional_edges("classify_query",router)
graph.add_edge("get_context", "generate")
graph.add_edge("generate","evaluate")
graph.add_edge("evaluate", END)

graph.set_entry_point("classify_query")

<langgraph.graph.state.StateGraph at 0x1076bc970>

In [None]:
## compile the graph
workflow = graph.compile()
workflow.get_graph()

Graph(nodes={'__start__': Node(id='__start__', name='__start__', data=RunnablePassthrough(), metadata=None), 'classify_query': Node(id='classify_query', name='classify_query', data=classify_query(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None), 'get_context': Node(id='get_context', name='get_context', data=get_context(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None), 'generate': Node(id='generate', name='generate', data=generate(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None), 'evaluate': Node(id='evaluate', name='evaluate', data=evaluate(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None), '__end__': Node(id='__end__', name='__end__', data=None, metadata=None)}, edges=[Edge(source='__start__', target='classify_query', data=None, conditional=False), Edge(source='clas

In [None]:
## infer/invoke the graph
result = workflow.invoke({"query":"what is multi head attention in transformers?"})
result

  from .autonotebook import tqdm as notebook_tqdm


{'query': 'what is multi head attention in transformers?',
 'query_type': 'Based on the query, "what is multi head attention in transformers?", this is classified as \'transformer-related\' because it specifically asks about a component of the Transformer model architecture.',
 'context': [Document(id='940700da-a178-40cc-87a1-dff62b878211', metadata={'producer': 'pdfTeX-1.40.25', 'creator': 'LaTeX with hyperref', 'creationdate': '2024-02-09T02:33:09+00:00', 'author': '', 'keywords': '', 'moddate': '2024-02-09T02:33:09+00:00', 'ptex.fullbanner': 'This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5', 'subject': '', 'title': '', 'trapped': '/False', 'source': 'transformers.pdf', 'total_pages': 10, 'page': 3, 'page_label': '3'}, page_content='the K×D matrices Uq and Uk are the only parameters of this mechanism.10\nMulti-head self-attention (MHSA). In the self-attention mechanisms de-\nscribed above, there is one attention matrix which describes the simila