In [1]:
from typing import TypedDict, List, Optional
from langgraph.graph import StateGraph, END
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
import operator

# State definition for the RAG workflow
class RAGState(TypedDict):
    """Shared state dict passed between the nodes of the AgenticRAG LangGraph workflow."""

    question: str  # the user's question, supplied by AgenticRAG.query
    documents: List[Document]  # chunks from the retriever, then narrowed by the relevance grader
    generation: str  # first-draft answer produced by the generate node
    query_type: str  # lowercased router label; the prompt asks for factual/analytical/creative/unclear
    reflection: str  # reviewer feedback text from the reflect node
    revision_needed: bool  # True when the reflection did not contain "APPROVED"
    final_answer: str  # answer handed back to the caller (revised answer or the original draft)

class AgenticRAG:
    """Self-reflective RAG pipeline orchestrated as a LangGraph state machine.

    The compiled workflow runs:
        query_router -> retrieve -> grade_docs -> generate -> reflect
        -> (revise if the reflection did not approve) -> finalize

    Call ``setup_vectorstore`` before ``query``; retrieval raises
    ``ValueError`` otherwise.
    """

    # Categories the router prompt asks the LLM to choose from.
    _QUERY_TYPES = ("factual", "analytical", "creative", "unclear")

    def __init__(self, llm_model="gpt-3.5-turbo", embedding_model=None):
        """Create the chat model, embeddings and compiled workflow.

        Args:
            llm_model: OpenAI chat model name (used with temperature 0).
            embedding_model: Embeddings instance; ``OpenAIEmbeddings()`` when None.
        """
        self.llm = ChatOpenAI(model=llm_model, temperature=0)
        self.embeddings = embedding_model or OpenAIEmbeddings()
        self.vectorstore = None
        self.retriever = None
        self.workflow = self._build_workflow()

    def setup_vectorstore(self, documents: List[str], chunk_size=1000, chunk_overlap=200, k=5):
        """Chunk, embed and index ``documents`` in a FAISS store.

        Args:
            documents: Raw texts; chunks from the i-th text get metadata
                ``{"source": "doc_<i>"}``.
            chunk_size: Maximum characters per chunk.
            chunk_overlap: Characters shared between consecutive chunks.
            k: Number of chunks the retriever returns per query.

        Raises:
            ValueError: If splitting ``documents`` yields no chunks at all.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

        # Tag every chunk with the index of the text it came from so answers
        # can be traced back to a source.
        doc_chunks = text_splitter.create_documents(
            documents,
            metadatas=[{"source": f"doc_{i}"} for i, _ in enumerate(documents)],
        )
        if not doc_chunks:
            # An empty index is useless for retrieval; fail early with a clear message.
            raise ValueError("No text to index: 'documents' is empty or contains only blank text.")

        self.vectorstore = FAISS.from_documents(doc_chunks, self.embeddings)
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})

    @staticmethod
    def _format_context(documents: List[Document]) -> str:
        """Join document contents, separated by blank lines, into one context string."""
        return "\n\n".join(doc.page_content for doc in documents)

    def query_router(self, state: RAGState) -> RAGState:
        """Classify the question and store the label in ``state["query_type"]``.

        The label is normalized to one of ``_QUERY_TYPES``; unrecognized LLM
        output becomes ``"unclear"``.
        """
        router_prompt = ChatPromptTemplate.from_template("""
        Analyze the following question and classify it into one of these categories:
        - factual: Direct factual questions that can be answered from documents
        - analytical: Questions requiring analysis or synthesis of information
        - creative: Questions asking for creative content or opinions
        - unclear: Questions that are ambiguous or unclear
        
        Question: {question}
        
        Respond with only the category name.
        """)

        chain = router_prompt | self.llm | StrOutputParser()
        raw = chain.invoke({"question": state["question"]}).strip().lower()

        # The LLM may decorate the label ("factual.", "Category: analytical").
        # Pick the known category that appears earliest in the reply so that
        # downstream equality checks on query_type keep working.
        found = [label for label in self._QUERY_TYPES if label in raw]
        state["query_type"] = min(found, key=raw.find) if found else "unclear"
        return state

    def retrieve_documents(self, state: RAGState) -> RAGState:
        """Fetch candidate chunks for the question into ``state["documents"]``.

        Raises:
            ValueError: If ``setup_vectorstore`` has not been called.
        """
        if self.retriever is None:
            raise ValueError("Vector store not initialized. Call setup_vectorstore() first.")

        question = state["question"]
        # Analytical questions get an expanded query to pull in related context;
        # every other type searches with the question verbatim.
        if state["query_type"] == "analytical":
            search_query = f"Analysis of {question} including key factors and relationships"
        else:
            search_query = question

        state["documents"] = self.retriever.invoke(search_query)
        return state

    def grade_documents(self, state: RAGState) -> RAGState:
        """Keep only the retrieved chunks the LLM grades as relevant."""
        if not state["documents"]:
            return state

        grader_prompt = ChatPromptTemplate.from_template("""
        You are a document grader. Assess whether the following document is relevant to the user question.
        
        Question: {question}
        Document: {document}
        
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant.
        Respond with only 'yes' or 'no'.
        """)

        chain = grader_prompt | self.llm | StrOutputParser()

        # Grade all chunks in one batched call; outputs come back in input order.
        scores = chain.batch([
            {"question": state["question"], "document": doc.page_content}
            for doc in state["documents"]
        ])

        # Accept "Yes", "yes." or "Yes, because ..."; the model does not
        # always return the bare token the prompt asks for.
        state["documents"] = [
            doc
            for doc, score in zip(state["documents"], scores)
            if score.strip().lower().startswith("yes")
        ]
        return state

    def generate_answer(self, state: RAGState) -> RAGState:
        """Draft an answer from the graded documents into ``state["generation"]``.

        If no documents survived grading, a fixed "not enough information"
        message is stored instead of calling the LLM.
        """
        if not state["documents"]:
            state["generation"] = "I don't have enough relevant information to answer your question."
            return state

        context = self._format_context(state["documents"])

        generation_prompt = ChatPromptTemplate.from_template("""
        You are an AI assistant tasked with answering questions based on the provided context.
        
        Question: {question}
        Query Type: {query_type}
        
        Context:
        {context}
        
        Instructions:
        - For factual questions: Provide direct, accurate answers based on the context
        - For analytical questions: Synthesize information and provide insights
        - For unclear questions: Ask for clarification while providing what information you can
        - Always cite specific parts of the context when possible
        - If the context doesn't contain sufficient information, clearly state this
        
        Answer:
        """)

        chain = generation_prompt | self.llm | StrOutputParser()
        state["generation"] = chain.invoke({
            "question": state["question"],
            "query_type": state["query_type"],
            "context": context,
        })
        return state

    def reflect_on_answer(self, state: RAGState) -> RAGState:
        """Have the LLM critique the draft and decide whether it needs revision.

        Sets ``state["reflection"]`` to the critique and
        ``state["revision_needed"]`` to True unless the critique approves the
        draft. When there are no documents, reflection is skipped and no
        revision is requested.
        """
        if not state["documents"]:
            # Nothing to check the draft against; revising with an empty
            # context would only invite an ungrounded answer, so keep the
            # "not enough information" fallback unchanged.
            state["reflection"] = "Skipped: no relevant documents were available."
            state["revision_needed"] = False
            return state

        reflection_prompt = ChatPromptTemplate.from_template("""
        You are a quality assessor. Evaluate the following answer for:
        1. Accuracy based on the provided context
        2. Completeness in addressing the question
        3. Clarity and coherence
        4. Proper use of context
        
        Question: {question}
        Answer: {generation}
        Context: {context}
        
        Provide your assessment and suggest improvements if needed.
        If the answer is satisfactory, respond with "APPROVED".
        If improvements are needed, provide specific suggestions.
        """)

        chain = reflection_prompt | self.llm | StrOutputParser()
        reflection = chain.invoke({
            "question": state["question"],
            "generation": state["generation"],
            "context": self._format_context(state["documents"]),
        })

        state["reflection"] = reflection
        verdict = reflection.upper()
        # A bare substring test is not enough: "NOT APPROVED" also contains "APPROVED".
        state["revision_needed"] = "APPROVED" not in verdict or "NOT APPROVED" in verdict
        return state

    def revise_answer(self, state: RAGState) -> RAGState:
        """Rewrite the draft using the reflection feedback into ``state["final_answer"]``."""
        revision_prompt = ChatPromptTemplate.from_template("""
        Based on the feedback provided, revise the following answer to improve its quality.
        
        Original Question: {question}
        Original Answer: {generation}
        Feedback: {reflection}
        Context: {context}
        
        Provide a revised, improved answer:
        """)

        chain = revision_prompt | self.llm | StrOutputParser()
        state["final_answer"] = chain.invoke({
            "question": state["question"],
            "generation": state["generation"],
            "reflection": state["reflection"],
            "context": self._format_context(state["documents"]),
        })
        return state

    def finalize_answer(self, state: RAGState) -> RAGState:
        """Use the draft as the final answer unless a revision was produced."""
        if not state["revision_needed"]:
            state["final_answer"] = state["generation"]
        return state

    def should_revise(self, state: RAGState) -> str:
        """Routing function for the conditional edge after ``reflect``."""
        return "revise" if state.get("revision_needed", False) else "finalize"

    def _build_workflow(self):
        """Wire the nodes into a StateGraph and return the *compiled* graph."""
        workflow = StateGraph(RAGState)

        # Nodes
        workflow.add_node("query_router", self.query_router)
        workflow.add_node("retrieve", self.retrieve_documents)
        workflow.add_node("grade_docs", self.grade_documents)
        workflow.add_node("generate", self.generate_answer)
        workflow.add_node("reflect", self.reflect_on_answer)
        workflow.add_node("revise", self.revise_answer)
        workflow.add_node("finalize", self.finalize_answer)

        # Linear part of the pipeline
        workflow.set_entry_point("query_router")
        workflow.add_edge("query_router", "retrieve")
        workflow.add_edge("retrieve", "grade_docs")
        workflow.add_edge("grade_docs", "generate")
        workflow.add_edge("generate", "reflect")

        # Revise only when the reflection did not approve the draft
        workflow.add_conditional_edges(
            "reflect",
            self.should_revise,
            {
                "revise": "revise",
                "finalize": "finalize",
            },
        )

        workflow.add_edge("revise", "finalize")
        workflow.add_edge("finalize", END)

        return workflow.compile()

    def query(self, question: str) -> dict:
        """Run the full workflow for ``question`` and return the final state dict."""
        initial_state: RAGState = {
            "question": question,
            "documents": [],
            "generation": "",
            "query_type": "",
            "reflection": "",
            "revision_needed": False,
            "final_answer": "",
        }
        return self.workflow.invoke(initial_state)

# Example usage
if __name__ == "__main__":
    # Demo: build the pipeline, index three short passages, then ask a few questions.
    rag = AgenticRAG()

    # Toy corpus -- swap in your own texts.
    corpus = [
        """
        Artificial Intelligence (AI) is a branch of computer science that aims to create 
        intelligent machines that work and react like humans. AI has been around since 
        the 1950s and has evolved significantly over the decades. Modern AI applications 
        include machine learning, natural language processing, computer vision, and robotics.
        """,
        """
        Machine Learning is a subset of AI that enables computers to learn and improve 
        from experience without being explicitly programmed. There are three main types 
        of machine learning: supervised learning, unsupervised learning, and reinforcement 
        learning. Popular algorithms include neural networks, decision trees, and support vector machines.
        """,
        """
        Large Language Models (LLMs) are AI systems trained on vast amounts of text data 
        to understand and generate human-like text. Examples include GPT, BERT, and Claude. 
        These models have revolutionized natural language processing and are used in 
        chatbots, content generation, and language translation.
        """
    ]
    rag.setup_vectorstore(corpus)

    demo_questions = (
        "What is artificial intelligence?",
        "Compare supervised and unsupervised learning",
        "How do large language models work?",
    )
    separator = "-" * 80

    for q in demo_questions:
        print(f"\nQuestion: {q}")
        outcome = rag.query(q)
        print(f"Answer: {outcome['final_answer']}")
        print(f"Query Type: {outcome['query_type']}")
        print(separator)


Question: What is artificial intelligence?
Answer: Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines that work and react like humans. It has been around since the 1950s and has evolved significantly over the decades. Modern AI applications include machine learning, natural language processing, computer vision, and robotics.
Query Type: factual
--------------------------------------------------------------------------------

Question: Compare supervised and unsupervised learning
Answer: Supervised learning and unsupervised learning are two main types of machine learning. 

Supervised learning, as mentioned in the context, involves training a model on labeled data where the algorithm learns to map input data to the correct output. This type of learning requires a dataset with input-output pairs for the algorithm to learn from. Examples of supervised learning algorithms include neural networks, decision trees, and support vector machine