# Adaptive RAG with LangGraph

## Introduction

Adaptive RAG is an advanced strategy for Retrieval-Augmented Generation (RAG) that combines **query analysis** with **active/self-corrective RAG** to dynamically adapt the retrieval and generation process based on the nature of the user's query. This approach ensures that the system can handle a wide range of questions effectively, from simple factual queries to complex, multi-step reasoning tasks. 

In this implementation, we use **LangGraph** to build a workflow that routes queries between two primary paths:
1. **Web Search**: For questions related to recent events or topics not covered in the indexed documents.
2. **Self-Corrective RAG**: For questions related to the indexed documents, where the system retrieves relevant information, generates an answer, and iteratively refines the response to ensure accuracy and relevance.

By leveraging LangGraph's graph-based workflow, we create a flexible and adaptive RAG pipeline that can dynamically switch between retrieval strategies, evaluate the quality of generated answers, and self-correct when necessary. This approach mirrors the principles outlined in recent research, where query analysis is used to route queries across different retrieval strategies, such as **No Retrieval**, **Single-shot RAG**, and **Iterative RAG**.

## Step 0: Installation Commands

The commands below install the Python libraries used for building language model applications, agents, and related tools. Here's a brief description of each library:

1. **`langchain-openai`**: Integrates OpenAI's models (like GPT) with the LangChain framework for building language model applications.
2. **`langchain-anthropic`**: Integrates Anthropic's models (like Claude) with LangChain.
3. **`langchain_community`**: Provides community-contributed tools, integrations, and utilities for LangChain.
4. **`langchain_experimental`**: Contains experimental features and tools for LangChain that are still under development.
5. **`langgraph`**: A library for building stateful, graph-based workflows out of nodes, edges, and a shared state, often used in conjunction with LangChain.
6. **`tiktoken`**: A tokenizer for OpenAI models, used to count and manage tokens in text.
7. **`chromadb`**: A vector database for storing and querying embeddings, often used in semantic search and retrieval-augmented generation (RAG) pipelines.
8. **`duckduckgo_search`**: A Python wrapper for the DuckDuckGo search engine, useful for retrieving real-time information from the web.

These libraries are essential for building advanced language model applications, including chatbots, agents, and retrieval systems.

In [None]:
!pip install -qU langchain-openai
!pip install -qU langchain-anthropic
!pip install -qU langchain_community
!pip install -qU langchain_experimental
!pip install -qU langgraph
!pip install -qU tiktoken
!pip install -qU chromadb
!pip install -qU duckduckgo_search

## Step 1: Build Index

This step sets up the document index by loading web-based documents, splitting them into chunks, and storing them in a vector store (Chroma) using OpenAI embeddings. The index is used for efficient retrieval of relevant documents during query processing.
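
To make the `chunk_size` and `chunk_overlap` parameters concrete, here is a minimal sketch of fixed-size chunking with overlap. It is a simplification and not how `RecursiveCharacterTextSplitter` works internally: it counts whitespace-separated words instead of tiktoken tokens and ignores separators. The function name `chunk_words` is hypothetical.

```python
from typing import List


def chunk_words(text: str, chunk_size: int = 500, chunk_overlap: int = 0) -> List[str]:
    """Split text into chunks of at most `chunk_size` words (hypothetical helper).

    Consecutive chunks share `chunk_overlap` words. Words stand in for tokens here.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    words = text.split()
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # the last window already reached the end of the text
    return chunks


sample = "one two three four five six seven"
print(chunk_words(sample, chunk_size=3, chunk_overlap=0))
print(chunk_words(sample, chunk_size=3, chunk_overlap=1))
```

With `chunk_overlap=0`, as in this notebook, chunks never share text. A small overlap repeats the boundary words in both neighboring chunks, at the cost of a slightly larger index.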

In [None]:
# Phase 1: Build Index
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from kaggle_secrets import UserSecretsClient

# Fetch API key securely
user_secrets = UserSecretsClient()
my_api_key = user_secrets.get_secret("my-openai-api-key")
#my_api_key = user_secrets.get_secret("my-deepseek-api-key")

# Initialize OpenAI embeddings
embd = OpenAIEmbeddings(model="text-embedding-3-small", api_key=my_api_key)

# URLs of documents to index
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load documents from URLs
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split documents into smaller chunks for efficient processing
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Store documents in Chroma vector store
vectorstore = Chroma.from_documents(
    documents=doc_splits, collection_name="rag-chroma", embedding=embd
)
retriever = vectorstore.as_retriever()

## Step 2: Router

The router determines whether a user query should be answered using the vector store (for questions on the indexed topics) or web search (for everything else). It uses GPT-4o-mini with structured output, constrained by a Pydantic schema, to classify the query.
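
The routing contract itself is small: a question goes in, and one of two labels comes out. The sketch below illustrates that contract with a keyword heuristic standing in for the LLM call. The names `INDEXED_TOPICS` and `route_by_keywords` are hypothetical, and the keyword list is only an assumption about what the index covers.

```python
from typing import Literal

Datasource = Literal["vectorstore", "web_search"]

# Hypothetical keyword list summarizing what the index covers.
INDEXED_TOPICS = ("agent", "prompt", "adversarial")


def route_by_keywords(question: str) -> Datasource:
    """Stand-in for the LLM router: pick a datasource from simple keyword matches."""
    lowered = question.lower()
    if any(topic in lowered for topic in INDEXED_TOPICS):
        return "vectorstore"
    return "web_search"


print(route_by_keywords("What are the types of agent memory?"))
print(route_by_keywords("Who will the Bears draft first in the NFL draft?"))
```

A heuristic like this can serve as a cheap baseline when testing the graph offline, but it cannot judge meaning the way the LLM router does.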

In [None]:
# Phase 2: Router
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

# Define a Pydantic model for routing decisions
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )

# Initialize LLM for routing
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_router = llm.with_structured_output(RouteQuery)

# Define the routing prompt
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. Otherwise, use web-search."""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

# Create the question router chain
question_router = route_prompt | structured_llm_router

# Test the router with sample questions
print(question_router.invoke({"question": "Who will the Bears draft first in the NFL draft?"}))
print(question_router.invoke({"question": "What are the types of agent memory?"}))

## Step 3: Retrieval Grader

This step evaluates the relevance of retrieved documents to the user's query. It uses a binary grading system (yes/no) to filter out irrelevant documents, ensuring only contextually appropriate content is used for answer generation.
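
Because `binary_score` is declared as a plain `str`, the model could in principle return variants such as `Yes` or `yes.`. The sketch below shows one lenient way to interpret the grade and filter a list of documents. `is_yes`, `filter_relevant`, and `stub_grader` are hypothetical names, and the stub replaces the LLM call.

```python
from typing import Callable, List


def is_yes(binary_score: str) -> bool:
    """Interpret a free-form 'yes'/'no' string leniently (hypothetical helper)."""
    return binary_score.strip().strip(".!'\"").lower() == "yes"


def filter_relevant(docs: List[str], grade: Callable[[str], str]) -> List[str]:
    """Keep only the documents whose grade reads as 'yes'."""
    return [doc for doc in docs if is_yes(grade(doc))]


# Stub grader for illustration: "relevant" means the text mentions memory.
def stub_grader(doc: str) -> str:
    return "Yes." if "memory" in doc.lower() else "no"


sample_docs = ["Agents use short-term memory.", "The Bears play in Chicago."]
print(filter_relevant(sample_docs, stub_grader))
```

An alternative is to tighten the schema itself, for example by declaring the field as `Literal["yes", "no"]`, so that other values are rejected at parse time.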

In [None]:
# Phase 3: Retrieval Grader
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# Initialize LLM for grading
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Define the grading prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

# Create the retrieval grader chain
retrieval_grader = grade_prompt | structured_llm_grader

# Test the retrieval grader
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content

result = retrieval_grader.invoke({"question": question, "document": doc_txt})
print(result)

## Step 4: Generate

The generate phase constructs a RAG chain to produce answers based on the retrieved documents and the user's query. It uses a pre-defined prompt and GPT-4o-mini to generate coherent and contextually accurate responses.
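
The cell below passes the list of retrieved documents directly as `context`, so the prompt receives the documents through their string representation, which typically includes metadata as well as text. A common alternative is to join only the text of each document. The sketch below assumes a minimal stand-in class `Doc` with a `page_content` field. Both `Doc` and `format_docs` are hypothetical names for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Doc:
    """Minimal stand-in for a LangChain Document (only the text field)."""
    page_content: str


def format_docs(docs: List[Doc]) -> str:
    """Join document texts with blank lines so the prompt sees plain text."""
    return "\n\n".join(doc.page_content for doc in docs)


print(format_docs([Doc("First chunk."), Doc("Second chunk.")]))
```

The same helper would work on real `Document` objects, since it only reads `page_content`.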

In [None]:
# Phase 4: Generate
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Pull the RAG prompt from the hub
prompt = hub.pull("rlm/rag-prompt")

# Initialize LLM for generation
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)

# Create the RAG chain
rag_chain = prompt | llm | StrOutputParser()

# Generate an answer using the RAG chain
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)

In [None]:
# Show what the prompt looks like
# (pretty_print() returns None, so its result must not be assigned back to `prompt`)
prompt.pretty_print()

## Step 5: Hallucination Grader

This step checks if the generated answer is grounded in the retrieved documents. It ensures the answer is factually supported and not hallucinated, using a binary grading system.

In [None]:
# Phase 5: Hallucination Grader
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")

# Initialize LLM for hallucination grading
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

# Define the hallucination grading prompt
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n 
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)

# Create the hallucination grader chain
hallucination_grader = hallucination_prompt | structured_llm_grader

# Test the hallucination grader
result = hallucination_grader.invoke({"documents": docs, "generation": generation})
print(result)

## Step 6: Answer Grader

The answer grader evaluates whether the generated answer fully addresses the user's query. It ensures the response is relevant and resolves the question effectively.
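
Later in the pipeline, the hallucination grader and the answer grader are combined into a single three-way decision. The sketch below shows that decision as a plain function. The name `decide_next_step` and the outcome labels are illustrative and not part of the notebook.

```python
def decide_next_step(grounded: bool, addresses_question: bool) -> str:
    """Map the two grader verdicts to an outcome label (labels are illustrative)."""
    if not grounded:
        return "regenerate"        # answer is not supported by the documents
    if not addresses_question:
        return "rewrite_question"  # supported, but does not resolve the question
    return "finish"                # supported and resolves the question


for grounded in (True, False):
    for addresses in (True, False):
        print(grounded, addresses, "->", decide_next_step(grounded, addresses))
```

Note that the answer grader only matters once the answer is grounded. An ungrounded answer is regenerated regardless of whether it appears to address the question.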

In [None]:
# Phase 6: Answer Grader
class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""

    binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")

# Initialize LLM for answer grading
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeAnswer)

# Define the answer grading prompt
system = """You are a grader assessing whether an answer addresses / resolves a question \n 
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

# Create the answer grader chain
answer_grader = answer_prompt | structured_llm_grader

# Test the answer grader
result = answer_grader.invoke({"question": question, "generation": generation})
print(result)

## Step 7: Question Re-writer

This step rewrites the user's query to optimize it for vector store retrieval. It improves the semantic understanding of the query, enhancing the relevance of retrieved documents.

In [None]:
# Phase 7: Question Re-writer
# Initialize LLM for question rewriting
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)

# Define the question rewriting prompt
system = """You are a question re-writer that converts an input question to a better version that is optimized \n 
     for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

# Create the question rewriter chain
question_rewriter = re_write_prompt | llm | StrOutputParser()

# Test the question rewriter
result = question_rewriter.invoke({"question": question})
print(result)

## Step 8: Search

The search phase integrates DuckDuckGo Search for web-based queries. It retrieves web results when the router determines that a question is better answered using external sources.

In [None]:
# Phase 8: Search
from langchain_community.tools import DuckDuckGoSearchRun

# Initialize DuckDuckGo search tool
web_search_tool = DuckDuckGoSearchRun()

## Step 9: Define Graph State and Flow

This step defines the **state** and **flow** of the Retrieval-Augmented Generation (RAG) pipeline as a **graph-based workflow**. The graph is composed of interconnected **nodes**, each representing a specific task in the pipeline, such as document retrieval, question routing, answer generation, and evaluation. The **state** of the graph is maintained throughout the workflow, ensuring that each node has access to the necessary information (e.g., the user's question, retrieved documents, generated answers).
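
Each node returns a dictionary of state keys, and by default the returned values overwrite the corresponding keys in the shared state. The plain-Python sketch below approximates that merge without using LangGraph. `apply_update` and `fake_retrieve` are hypothetical names for illustration.

```python
from typing import Any, Dict


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new state where keys in `update` overwrite the old values."""
    merged = dict(state)
    merged.update(update)
    return merged


def fake_retrieve(state: Dict[str, Any]) -> Dict[str, Any]:
    # A node only needs to return the keys it changes.
    return {"documents": [f"doc about {state['question']}"]}


state: Dict[str, Any] = {"question": "agent memory"}
state = apply_update(state, fake_retrieve(state))
print(state)
```

The nodes in this notebook also return unchanged keys such as `question`. That is harmless, since overwriting a key with the same value leaves the state as it was.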

In [None]:
# Phase 9: Define Graph State and Flow
from typing import List, Dict, Any
from typing_extensions import TypedDict
from langchain_core.documents import Document
from langgraph.graph import END, StateGraph, START

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """

    question: str
    generation: str
    documents: List[Document]

def retrieve_node(state: GraphState) -> Dict[str, Any]:
    """
    Retrieve documents based on the user's question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with retrieved documents.
    """
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate_node(state: GraphState) -> Dict[str, Any]:
    """
    Generate an answer using the RAG chain.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with generated answer.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents_node(state: GraphState) -> Dict[str, Any]:
    """
    Grade the relevance of retrieved documents to the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with filtered relevant documents.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke({"question": question, "document": d.page_content})
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs, "question": question}

def transform_query_node(state: GraphState) -> Dict[str, Any]:
    """
    Rewrite the user's question for better retrieval.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with rephrased question.
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

def web_search_node(state: GraphState) -> Dict[str, Any]:
    """
    Perform a web search using DuckDuckGo.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with web search results.
    """
    print("---WEB SEARCH---")
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    
    # Handle the case where docs is a list of strings or a single string
    if isinstance(docs, str):
        web_results = docs  # If docs is a single string, use it directly
    elif isinstance(docs, list):
        web_results = "\n".join(docs)  # If docs is a list of strings, join them
    else:
        raise ValueError(f"Unexpected type for web search results: {type(docs)}")
    
    web_results = Document(page_content=web_results)
    return {"documents": [web_results], "question": question}

def route_search_store(state: GraphState) -> Literal["web_search_node", "retrieve_node"]:
    """
    Route the question to either web search or RAG.

    Args:
        state (GraphState): The current graph state.

    Returns:
        str: Next node to call ("web_search_node" or "retrieve_node").
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search_node"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "retrieve_node"

def decide_to_generate(state: GraphState) -> Literal["transform_query_node", "generate_node"]:
    """
    Decide whether to generate an answer or rewrite the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        str: Next node to call ("transform_query_node", "generate_node").
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]
    if not filtered_documents:
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query_node"
    else:
        print("---DECISION: GENERATE---")
        return "generate_node"

def route_generate_transform(state: GraphState) -> Literal["generate_node", "transform_query_node", END]:
    """
    Grade the generation against the documents and the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        str: Decision for next node ("generate_node", "transform_query_node", END).
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    score = hallucination_grader.invoke({"documents": documents, "generation": generation})
    grade = score.binary_score
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return END
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "transform_query_node"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "generate_node"

## Step 10: Compile and Use Graph

The graph is compiled into an executable workflow that wires together the nodes and routing functions defined in Step 9. Its main pieces are:

1. **Graph State (`GraphState`):**
   - A `TypedDict` that represents the state of the graph at any point in the workflow.
   - Contains:
     - **`question`**: The user's input question.
     - **`generation`**: The LLM-generated answer.
     - **`documents`**: A list of retrieved documents relevant to the question.

2. **Nodes and routing functions:**
   - Each node is a function that performs a specific task and returns an update to the graph state.
   - Nodes registered with `add_node`:
     - **`retrieve_node`**: Retrieves documents from the vector store based on the user's question.
     - **`generate_node`**: Generates an answer using the RAG chain.
     - **`grade_documents_node`**: Grades the relevance of retrieved documents to the question.
     - **`transform_query_node`**: Rewrites the user's question for better retrieval.
     - **`web_search_node`**: Performs a web search using DuckDuckGo for external information.
   - Routing functions used by the conditional edges (they return the name of the next node and do not update the state):
     - **`route_search_store`**: Routes the question to either web search or the vector store.
     - **`decide_to_generate`**: Decides whether to generate an answer or rephrase the question.
     - **`route_generate_transform`**: Evaluates the generated answer and decides the next step (end the workflow, rephrase the question, or regenerate the answer).

3. **Conditional Edges:**
   - The graph uses **conditional edges** to dynamically route the workflow based on the results of each node.
   - For example:
     - If none of the retrieved documents is graded as relevant, the workflow routes to `transform_query_node` to rephrase the question.
     - If the generated answer is not grounded in the documents, the workflow routes back to `generate_node` to regenerate the answer.

4. **Flow Logic:**
   - The workflow starts with the user's question and routes it to either the vector store or web search based on the router's classification of its topic.
   - Retrieved documents are graded for relevance, and irrelevant documents are filtered out.
   - The RAG chain generates an answer, which is then evaluated for hallucinations and relevance to the question.
   - If the answer is satisfactory, the workflow ends. Otherwise, it loops back to rephrase the question or regenerate the answer.
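
The flow logic above can be traced in plain Python, without LangGraph, by replacing every LLM call with a stub. The sketch below is only an illustration of the control flow: `run_flow` is a hypothetical name, retrieval and grading are collapsed into one step, and `MAX_STEPS` is an assumption added here so that the loop always terminates.

```python
from typing import Callable, Dict, List

MAX_STEPS = 10  # assumption: a cap added here so the sketch cannot loop forever


def run_flow(
    route: Callable[[str], str],
    relevant_docs: Callable[[str], List[str]],
    grounded: Callable[[str], bool],
    addresses: Callable[[str, str], bool],
    question: str,
) -> Dict[str, object]:
    """Walk the adaptive RAG flow with stub callables; return the path and answer."""
    path: List[str] = []
    node = "web_search" if route(question) == "web_search" else "retrieve"
    docs: List[str] = []
    answer = ""
    for _ in range(MAX_STEPS):
        path.append(node)
        if node == "web_search":
            docs = [f"web result for: {question}"]
            node = "generate"
        elif node == "retrieve":
            docs = relevant_docs(question)  # retrieval and grading collapsed
            node = "generate" if docs else "transform_query"
        elif node == "transform_query":
            question = f"{question} (rewritten)"
            node = "retrieve"
        elif node == "generate":
            answer = f"answer from {len(docs)} doc(s)"
            if not grounded(answer):
                node = "generate"  # not grounded: try generating again
            elif addresses(question, answer):
                return {"path": path, "answer": answer}
            else:
                node = "transform_query"  # grounded but off-target
    return {"path": path, "answer": None}  # gave up after MAX_STEPS


# Stubs: the first retrieval finds nothing, the rewritten question finds one chunk.
result = run_flow(
    route=lambda q: "vectorstore",
    relevant_docs=lambda q: ["chunk"] if "rewritten" in q else [],
    grounded=lambda a: True,
    addresses=lambda q, a: True,
    question="agent memory",
)
print(result["path"])
```

The step cap matters because both loops (regenerate and rewrite) can repeat indefinitely if a grader keeps returning `no`.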


In [None]:
# Phase 10: Compile and Use Graph
# Initialize the workflow
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("web_search_node", web_search_node)
workflow.add_node("retrieve_node", retrieve_node)
workflow.add_node("grade_documents_node", grade_documents_node)
workflow.add_node("generate_node", generate_node)
workflow.add_node("transform_query_node", transform_query_node)

# Build the graph
workflow.add_conditional_edges(START, route_search_store, ["web_search_node", "retrieve_node"])
workflow.add_edge("web_search_node", "generate_node")
workflow.add_edge("retrieve_node", "grade_documents_node")
workflow.add_conditional_edges("grade_documents_node", decide_to_generate, ["transform_query_node", "generate_node"])
workflow.add_edge("transform_query_node", "retrieve_node")
workflow.add_conditional_edges("generate_node", route_generate_transform, ["generate_node", "transform_query_node", END])

# Compile the workflow
app = workflow.compile()

## Step 11: Visualization

This optional phase visualizes the graph structure using Mermaid and IPython. It provides a graphical representation of the pipeline's flow and decision points.

In [None]:
# Phase 11: Visualization
from IPython.display import Image, display

# Visualize the graph (optional, requires additional dependencies)
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    pass

## Step 12: Execute the Graph

The final step executes the graph with a user query. It processes the query through the pipeline, prints intermediate outputs in JSON format, and displays the final generated answer. Non-serializable objects are converted to dictionaries for clean output.
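
The cells below read the final answer from the loop variable `value` after the loop has finished, which works because the last node to run is `generate_node`. A slightly more defensive variant tracks the most recent `generation` explicitly. The sketch below uses a fake stream of `{node_name: state_update}` dictionaries. `last_generation` and `fake_stream` are hypothetical names for illustration.

```python
from typing import Any, Dict, Iterable, Optional


def last_generation(stream: Iterable[Dict[str, Dict[str, Any]]]) -> Optional[str]:
    """Return the most recent 'generation' seen in a stream of node updates."""
    latest: Optional[str] = None
    for step in stream:
        for update in step.values():
            if isinstance(update, dict) and "generation" in update:
                latest = update["generation"]
    return latest


# Fake stream shaped like {node_name: state_update}, for illustration only.
fake_stream = [
    {"retrieve_node": {"documents": ["chunk"], "question": "q"}},
    {"generate_node": {"documents": ["chunk"], "question": "q", "generation": "an answer"}},
]
print(last_generation(fake_stream))
```

If no node ever produced a generation, the helper returns `None` instead of raising a `KeyError`.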

In [None]:
# Phase 12: Execute the Graph with an Input
import json
import warnings

# Suppress LangSmith API key warning (if not using LangSmith)
warnings.filterwarnings("ignore", category=UserWarning, message="API key must be provided when using hosted LangSmith API")

# Helper function to convert non-serializable objects to dictionaries
def convert_to_serializable(obj):
    if hasattr(obj, "model_dump"):  # Pydantic v2 models expose .model_dump()
        return obj.model_dump()
    elif hasattr(obj, "dict"):  # Fall back to .dict() for Pydantic v1-style objects
        return obj.dict()
    elif isinstance(obj, (list, tuple)):  # Handle lists and tuples
        return [convert_to_serializable(item) for item in obj]
    elif isinstance(obj, dict):  # Handle dictionaries
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    else:  # Return the object as-is if it's already serializable
        return obj

# Run the graph with an input question
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

# Print the final generated answer
print(f"final generation:\n{value['generation']}\n")

In [None]:
# Run the graph with an input question related to agents
inputs = {"question": "What are the key components of an AI agent?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

# Print the final generated answer
print(f"final generation:\n{value['generation']}\n")

In [None]:
# Run the graph with an input question unrelated to agents
inputs = {"question": "What is the capital of France?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

# Print the final generated answer
print(f"final generation:\n{value['generation']}\n")

## Conclusion

The Adaptive RAG pipeline built with LangGraph demonstrates the power of combining **query analysis** with **self-corrective mechanisms** to create a robust and flexible question-answering system. By dynamically routing queries between **web search** and **self-corrective RAG**, the system can answer both questions covered by the indexed documents and questions that need recent information from the web, and it checks each answer for grounding and relevance before returning it.

This implementation highlights the importance of modularity and adaptability in modern RAG systems. The use of LangGraph's graph-based workflow allows for seamless integration of multiple retrieval strategies, real-time evaluation of generated answers, and iterative refinement to improve response quality. As RAG systems continue to evolve, strategies like Adaptive RAG will play a crucial role in enhancing their ability to handle diverse and challenging queries effectively.

In conclusion, Adaptive RAG with LangGraph represents a significant step forward in building intelligent, self-correcting question-answering systems that can adapt to the needs of users and the nature of their queries.
