# Agentic RAG with LangGraph

## Introduction

Retrieval-Augmented Generation (RAG) combines a retriever with a generative model so that responses are grounded in retrieved context. In agentic RAG, the LLM (Large Language Model) itself decides whether and when to retrieve. In this example, we implement a **Retrieval Agent** using LangGraph: the LLM is given access to a retriever tool over an indexed dataset and decides, for each query, whether to call the tool for additional context or to respond directly.

This example demonstrates how to:
1. Load and index blog posts into a vector store.
2. Create a retriever tool for searching the indexed data.
3. Define an agent state and workflow using LangGraph.
4. Implement nodes for grading document relevance, rewriting queries, and generating responses.
5. Visualize and execute the graph-based workflow.

## Step 0: Install Required Libraries

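Install the Python libraries for the project: the LangChain integrations for OpenAI and Anthropic, the LangChain community and experimental packages, LangGraph, and ChromaDB. Only the OpenAI integration is used for models and embeddings in the code below.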

In [None]:
!pip install -qU langchain-openai
!pip install -qU langchain-anthropic
!pip install -qU langchain_community
!pip install -qU langchain_experimental
!pip install -qU langgraph
!pip install -qU chromadb

## Step 1: Setup and Document Loading

Load blog posts from specified URLs, split them into smaller chunks, and store them in a vector database for retrieval.
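
To make the effect of `chunk_size` and `chunk_overlap` concrete, here is a minimal pure-Python sketch of sliding-window chunking. It is not how `RecursiveCharacterTextSplitter` works internally (that splitter recursively splits on separators and measures length with a tokenizer); the whitespace tokenization and the helper name `chunk_tokens` are assumptions made for illustration.

```python
from typing import List


def chunk_tokens(tokens: List[str], chunk_size: int, chunk_overlap: int) -> List[List[str]]:
    """Split tokens into windows of at most chunk_size that share chunk_overlap tokens."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
    return chunks


# Whitespace "tokens" stand in for tiktoken tokens (an assumption for this sketch).
tokens = "a b c d e f g h i j".split()
chunks = chunk_tokens(tokens, chunk_size=4, chunk_overlap=2)
for chunk in chunks:
    print(" ".join(chunk))
```

Each window starts `chunk_size - chunk_overlap` tokens after the previous one, so a larger overlap produces more chunks and more duplicated text in the index.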

In [None]:
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from kaggle_secrets import UserSecretsClient

# Retrieve the OpenAI API key from Kaggle secrets
user_secrets = UserSecretsClient()
my_api_key = user_secrets.get_secret("my-openai-api-key")

# Use OpenAI embeddings for vectorization
embedding = OpenAIEmbeddings(model="text-embedding-3-small", api_key=my_api_key)

# Define URLs to load blog posts
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load documents from the specified URLs
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]  # Flatten the list of documents

# Split documents into smaller chunks for processing
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100,   # Maximum chunk size, in tokens
    chunk_overlap=50  # Tokens shared between consecutive chunks to maintain context
)
doc_splits = text_splitter.split_documents(docs_list)

# Add the split documents to a vector store for retrieval
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=embedding,
)
retriever = vectorstore.as_retriever()

## Step 2: Retriever Tool Creation

Wrap the retriever in a tool that the agent can call to search the indexed blog posts. The retriever is built from the vector store, which contains the processed and split document chunks. The tool's name and description tell the LLM what the tool is for, and the tool is added to the list of tools available to the agent.

In [None]:
from langchain.tools.retriever import create_retriever_tool

# Create a retriever tool to search and retrieve blog post information
retriever_tool = create_retriever_tool(retriever, "retrieve_blog_posts",
    "Search and return information about Lilian Weng blog posts on LLM agents, prompt engineering, and adversarial attacks on LLMs.",
)

# List of tools available for the agent
tools = [retriever_tool]
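
When the agent calls this tool, the retriever embeds the query and returns the stored chunks whose embeddings are closest to it. The sketch below illustrates that idea with cosine similarity over hand-made 3-dimensional vectors. The vectors, the chunk texts, and the `top_k` helper are all invented for illustration; Chroma's actual index structure and distance metric depend on its configuration.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either vector is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec: Sequence[float], index: List[Tuple[str, Sequence[float]]], k: int = 2) -> List[str]:
    """Return the texts of the k entries most similar to the query vector."""
    scored = sorted(index, key=lambda item: cosine_similarity(query_vec, item[1]), reverse=True)
    return [text for text, _ in scored[:k]]


# Toy "embeddings": each axis loosely stands for one blog post topic.
toy_index = [
    ("chunk about agent memory", [0.9, 0.1, 0.0]),
    ("chunk about prompt engineering", [0.1, 0.9, 0.0]),
    ("chunk about adversarial attacks", [0.0, 0.1, 0.9]),
]
query = [0.8, 0.2, 0.0]  # a query that mostly points along the "agent memory" axis
results = top_k(query, toy_index, k=2)
print(results)
```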

## Step 3: Agent State Definition

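Define the agent state, which consists of a sequence of messages. The state is passed between nodes in the graph; because `messages` is annotated with `add_messages`, the messages a node returns are appended to the existing list rather than replacing it.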

In [None]:
from typing import Annotated, Sequence
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

# Define the state of the agent, which is a sequence of messages
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]  # Messages are appended to the state
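
The `add_messages` annotation acts as a reducer: when a node returns `{"messages": [...]}`, the new messages are merged into the existing list. The pure-Python stand-in below mimics only the append behavior, using plain tuples rather than `BaseMessage` objects. The real `add_messages` does more (for example, it also takes message IDs into account), and the names `append_messages`, `REDUCERS`, and `apply_update` are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Message = Tuple[str, str]  # (role, content); a simplification of BaseMessage


def append_messages(existing: List[Message], new: List[Message]) -> List[Message]:
    """Return a new list with the node's messages appended to the existing ones."""
    return existing + new


# Maps each state key to the reducer that merges updates for that key.
REDUCERS: Dict[str, Callable] = {"messages": append_messages}


def apply_update(state: dict, update: dict) -> dict:
    """Merge a node's partial update into the state using the per-key reducer."""
    merged = dict(state)
    for key, value in update.items():
        reducer = REDUCERS.get(key)
        # Keys without a reducer are simply overwritten.
        merged[key] = reducer(state.get(key, []), value) if reducer else value
    return merged


state = {"messages": [("user", "What is agent memory?")]}
state = apply_update(state, {"messages": [("ai", "Let me look that up.")]})
print(state["messages"])
```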

## Step 4: Nodes and Edges Definition

In this step, we define the functions that make up the graph-based workflow. Nodes perform a task and return a state update, while edges define the flow of logic between nodes. The functions are:

1. **`grade_documents`**: Determines whether the retrieved documents are relevant to the user's question. It is used as a conditional edge rather than a node: if the documents are relevant, the workflow proceeds to `generate_node`; if not, it proceeds to `rewrite_node`.
2. **`agent_node`**: The core decision-making node. The LLM either calls the retriever tool to fetch additional information or answers directly, in which case the workflow ends.
3. **`rewrite_node`**: Rewrites the user's question to better capture its underlying semantic intent. It runs when the retrieved documents are not relevant, so that the next retrieval attempt uses a refined query.
4. **`generate_node`**: Generates the final response to the user's question using the retrieved documents. It runs only when the documents are deemed relevant.
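
The control flow described above can be traced without any LLM calls. The sketch below replaces every node with a stub and walks the same routing decisions. The `simulate` function, its boolean inputs, and the `max_steps` guard are assumptions made for illustration and are not part of LangGraph.

```python
from typing import Iterator, List


def simulate(wants_retrieval: Iterator[bool], relevance: Iterator[bool], max_steps: int = 10) -> List[str]:
    """Trace the node sequence given stubbed agent and grader decisions."""
    trace = []
    node = "agent_node"
    for _ in range(max_steps):
        trace.append(node)
        if node == "agent_node":
            # The agent either calls the retriever tool or answers directly.
            if next(wants_retrieval):
                node = "retrieve_node"
            else:
                return trace
        elif node == "retrieve_node":
            # grade_documents routes on relevance after retrieval.
            node = "generate_node" if next(relevance) else "rewrite_node"
        elif node == "rewrite_node":
            node = "agent_node"  # retry with the rewritten question
        elif node == "generate_node":
            return trace  # the generated answer ends the run
    raise RuntimeError(f"no answer after {max_steps} steps")


# The first retrieval is judged irrelevant, the second one is relevant.
trace = simulate(iter([True, True]), iter([False, True]))
print(" -> ".join(trace))
```

Because an irrelevant retrieval loops back to the agent, a run can in principle cycle many times. The guard here plays the role that the `recursion_limit` entry in the run config plays in LangGraph.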

In [None]:
from typing import Dict, Literal, Sequence
from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langgraph.prebuilt import tools_condition

def grade_documents(state: AgentState) -> Literal["generate_node", "rewrite_node"]:
    """
    Grades the relevance of retrieved documents to the user's question.
    Returns "generate_node" if the documents are relevant, otherwise "rewrite_node".

    Args:
        state (AgentState): The current state containing messages.

    Returns:
        Literal["generate_node", "rewrite_node"]: Decision based on document relevance.
    """
    print("---CHECK RELEVANCE---")

    # Define a Pydantic model for grading relevance
    class Grade(BaseModel):
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # Initialize the LLM for grading
    model = ChatOpenAI(model_name="gpt-4o-mini", api_key=my_api_key, temperature=0, streaming=True)
    llm_with_tool = model.with_structured_output(Grade)

    # Define the prompt for relevance grading
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    # Create a chain to process the grading
    chain = prompt | llm_with_tool

    # Extract the last message and question from the state
    messages = state["messages"]
    last_message = messages[-1]
    question = messages[0].content
    docs = last_message.content

    # Invoke the chain to grade the documents
    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score

    # Return the decision based on the score
    if score == "yes":
        print("---DECISION: DOCS RELEVANT---")
        return "generate_node"
    else:
        print("---DECISION: DOCS NOT RELEVANT---")
        print(score)
        return "rewrite_node"

def agent_node(state: AgentState) -> Dict[str, Sequence[BaseMessage]]:
    """
    Invokes the agent to generate a response or decide to retrieve information.

    Args:
        state (AgentState): The current state containing messages.

    Returns:
        Dict[str, Sequence[BaseMessage]]: Updated state with the agent's response appended to messages.
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatOpenAI(model_name="gpt-4o-mini", api_key=my_api_key, temperature=0, streaming=True)
    llm_with_tool = model.bind_tools(tools)    # Bind the available tools to the model
    response = llm_with_tool.invoke(messages)  # Generate a response
    return {"messages": [response]}

def rewrite_node(state: AgentState) -> Dict[str, Sequence[BaseMessage]]:
    """
    Rewrites the user's question to improve clarity or relevance.

    Args:
        state (AgentState): The current state containing messages.

    Returns:
        Dict[str, Sequence[BaseMessage]]: Updated state with the rewritten question appended to messages.
    """
    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    question = messages[0].content

    # Create a message to request a rewritten question
    msg = [
        HumanMessage(
            content=f""" \n 
    Look at the input and try to reason about the underlying semantic intent or meaning. \n 
    Here is the initial question:
    \n -------------------------------------------------------- \n
    {question} 
    \n -------------------------------------------------------- \n
    Formulate an improved question: """,
        )
    ]

    # Invoke the LLM to rewrite the question
    model = ChatOpenAI(model_name="gpt-4o-mini", api_key=my_api_key, temperature=0, streaming=True)
    response = model.invoke(msg)
    return {"messages": [response]}

def generate_node(state: AgentState) -> Dict[str, Sequence[BaseMessage]]:
    """
    Generates a response to the user's question using retrieved documents.

    Args:
        state (AgentState): The current state containing messages.

    Returns:
        Dict[str, Sequence[BaseMessage]]: Updated state with the generated response appended to messages.
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]

    docs = last_message.content  # Retrieved documents

    # Pull the RAG prompt from the hub
    prompt = hub.pull("rlm/rag-prompt")

    # Initialize the LLM for response generation
    model = ChatOpenAI(model_name="gpt-4o-mini", api_key=my_api_key, temperature=0, streaming=True)

    # Create a RAG chain; the model's message is returned as-is so the answer is stored
    # with the AI role (add_messages would coerce a bare string into a human message)
    rag_chain = prompt | model

    # Invoke the chain to generate the response
    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}
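
`grade_documents` compares the grader's output with the exact string `"yes"`, so a reply such as `"Yes"` would be routed to `rewrite_node`. If that turns out to be a problem in practice, one option is to normalize the score before comparing. The helper name `is_relevant` is hypothetical and does not appear in the code above.

```python
def is_relevant(binary_score: str) -> bool:
    """Interpret a grader's 'yes'/'no' answer, ignoring case and surrounding whitespace."""
    return binary_score.strip().lower() == "yes"


# Anything other than a (case-insensitive) "yes" is treated as not relevant.
for raw in ["yes", " Yes ", "YES", "no", "maybe"]:
    print(repr(raw), "->", is_relevant(raw))
```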

In [None]:
# Show what the prompt looks like (pretty_print() prints and returns None, so nothing is assigned)
hub.pull("rlm/rag-prompt").pretty_print()

## Step 5: Graph Construction

Construct the graph by defining nodes (agent, retrieve, rewrite, generate) and edges (conditional logic). The graph determines the flow of the workflow.

In [None]:
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode

# Node for retrieving documents
retrieve_node = ToolNode([retriever_tool])

# Define the workflow graph
workflow = StateGraph(AgentState)

# Add nodes to the graph
workflow.add_node("agent_node", agent_node)        # Agent node for decision-making
workflow.add_node("retrieve_node", retrieve_node)  # Node for retrieving documents
workflow.add_node("rewrite_node", rewrite_node)    # Node for rewriting the question
workflow.add_node("generate_node", generate_node)  # Node for generating the final response

# Define edges between nodes
workflow.add_edge(START, "agent_node")             # Start with the agent node

# tools_condition routes to the ToolNode if the last message contains tool calls;
# otherwise, it routes to END.
workflow.add_conditional_edges(
    "agent_node",
    tools_condition,               # Condition to decide whether to retrieve or end
    {
        "tools": "retrieve_node",  # If tools are needed, go to the retrieve node
        END: END,                  # Otherwise, end the workflow
    },
)
workflow.add_conditional_edges("retrieve_node", grade_documents, ["generate_node", "rewrite_node"])
workflow.add_edge("generate_node", END)            # End after generating the response
workflow.add_edge("rewrite_node", "agent_node")    # Go back to the agent after rewriting

# Compile the graph
graph = workflow.compile()
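
The prebuilt `tools_condition` chooses between the two branches by inspecting the last message for tool calls. The simplified stand-in below shows the idea with a small dataclass instead of LangChain message objects. `FakeMessage` and `route_on_tool_calls` are hypothetical names, and the real function accepts more state shapes than this sketch does.

```python
from dataclasses import dataclass, field
from typing import List

END = "__end__"  # stand-in for langgraph.graph.END in this sketch


@dataclass
class FakeMessage:
    """Minimal message with the two attributes the router looks at."""
    content: str
    tool_calls: List[dict] = field(default_factory=list)


def route_on_tool_calls(state: dict) -> str:
    """Return "tools" if the last message requests a tool call, otherwise END."""
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "tools"
    return END


with_call = {
    "messages": [
        FakeMessage("", [{"name": "retrieve_blog_posts", "args": {"query": "agent memory"}}])
    ]
}
without_call = {"messages": [FakeMessage("Hello! How can I help?")]}

print(route_on_tool_calls(with_call))
print(route_on_tool_calls(without_call))
```

The mapping passed to `add_conditional_edges` then translates these return values into node names: `"tools"` becomes `retrieve_node`, and `END` stops the run.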

## Step 6: Visualization

Visualize the graph to understand its structure. This step is optional and requires additional dependencies.

In [None]:
from IPython.display import Image, display

# Visualize the graph (optional, requires additional dependencies)
try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    pass

## Step 7: Execution

Execute the graph with a user question and print the output of each node. The code suppresses a LangSmith API key warning (relevant only when LangSmith is not in use) and includes a helper function that converts non-serializable objects into dictionaries for JSON serialization. The first example queries the graph with a question about agent memory and streams the outputs for display; the two cells after it repeat the same loop with questions about adversarial attacks and agent planning.

In [None]:
# Execute the graph with an input
import json
import warnings

# Suppress LangSmith API key warning (if not using LangSmith)
warnings.filterwarnings("ignore", category=UserWarning, message="API key must be provided when using hosted LangSmith API")

# Helper function to convert non-serializable objects to dictionaries
def convert_to_serializable(obj):
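    if hasattr(obj, "model_dump"):  # Pydantic v2 models (recent LangChain messages)
        return obj.model_dump()
    elif hasattr(obj, "dict"):  # Older objects that only expose a .dict() method
        return obj.dict()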
    elif isinstance(obj, (list, tuple)):  # Handle lists and tuples
        return [convert_to_serializable(item) for item in obj]
    elif isinstance(obj, dict):  # Handle dictionaries
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    else:  # Return the object as-is if it's already serializable
        return obj

# Example 1: Question about Agent Memory
inputs = {
    "messages": [
        ("user", "What does Lilian Weng say about the types of agent memory?"),
    ]
}

# Stream the graph execution and print outputs
for output in graph.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

In [None]:
# Example 2: Question about Adversarial Attacks on LLMs
inputs = {
    "messages": [
        ("user", "How can adversarial attacks be mitigated in large language models?"),
    ]
}

# Stream the graph execution and print outputs
for output in graph.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

In [None]:
# Example 3: Question about Agent Planning
inputs = {
    "messages": [
        ("user", "What are the main components of agent planning as discussed by Lilian Weng?"),
    ]
}

# Stream the graph execution and print outputs
for output in graph.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)
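
The three cells above repeat the same streaming loop with different questions. One way to avoid the repetition is a small helper like the sketch below. The names `stream_question` and `FakeGraph` are hypothetical; `FakeGraph` only imitates the shape of `graph.stream()` output so the helper can be tried without API calls, and `default=str` is a shortcut used here in place of `convert_to_serializable`.

```python
import json
from typing import Callable, List


def stream_question(graph, question: str, printer: Callable[[str], None] = print) -> List[str]:
    """Stream one question through the graph, print each node's update, and return the nodes visited."""
    inputs = {"messages": [("user", question)]}
    visited = []
    for output in graph.stream(inputs):
        for node_name, update in output.items():
            printer(f"Output from node '{node_name}':")
            printer("-" * 80)
            # default=str stringifies anything the json module cannot encode.
            printer(json.dumps(update, indent=2, default=str))
            visited.append(node_name)
        printer("=" * 80)
    return visited


class FakeGraph:
    """Stand-in that yields one {node_name: update} dict per step, like graph.stream()."""

    def stream(self, inputs):
        yield {"agent_node": {"messages": ["(tool call)"]}}
        yield {"retrieve_node": {"messages": ["(retrieved chunks)"]}}
        yield {"generate_node": {"messages": ["(final answer)"]}}


visited = stream_question(FakeGraph(), "What does Lilian Weng say about the types of agent memory?")
```

With the compiled graph from Step 5, the same call would be `stream_question(graph, "...")`.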

## Conclusion

Agentic RAG with LangGraph provides a modular way to build retrieval-augmented systems. Because the LLM decides when to retrieve, the relevance grader filters out off-topic results, and the rewrite step refines queries that miss, the workflow adapts to each user query instead of following a fixed retrieve-then-generate pipeline. This is particularly useful for question answering, where retrieving and processing external information is critical. Since each step is a separate node, the workflow is straightforward to extend and customize for specific use cases.