# Implementing Self-RAG with LangGraph

## Introduction

**Self-RAG** is a strategy for **Retrieval-Augmented Generation (RAG)** that adds self-reflection and self-grading steps to improve the accuracy and relevance of both the retrieved documents and the generated response. It uses a large language model (LLM) to assess the quality of its own retrieval and generation, and to retry when a check fails, so that the final output is reliable and pertinent to the user's query.

In the **Self-RAG** framework, several key decisions are systematically made to optimize the interaction between retrieval and generation:

1. **Retrieval Decision**: Determines whether additional document chunks should be retrieved based on the initial question or the current generation output. This decision helps in managing the scope of information and ensures that only relevant data is considered.

2. **Relevance Assessment of Retrieved Passages**: Evaluates each retrieved document chunk to ascertain its usefulness in addressing the user's question. By classifying documents as relevant or irrelevant, the system filters out noise and focuses on high-quality information sources.

3. **Verification of LLM Generation Against Retrieved Chunks**: Assesses whether the statements generated by the LLM are fully supported, partially supported, or lack support from the retrieved documents. This step is crucial in identifying and mitigating hallucinations, thereby maintaining factual accuracy in the generated responses.

4. **Usefulness Evaluation of Generated Responses**: Measures the overall usefulness of the LLM's generation in resolving the user's question. By scoring the response, the system ensures that the final answer is not only accurate but also effectively addresses the user's intent.

Together, these checks embed quality control directly into the retrieval and generation pipeline, so a response is returned only after it has been graded as grounded in the retrieved documents and useful for the question. The implementation in this notebook reduces each grading step to a binary 'yes' or 'no' score.
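
The four decisions above can be read as a small control loop. The following is a minimal, library-free sketch of that loop, not the notebook's implementation: the `retrieve`, `grade_doc`, `generate`, `is_grounded`, `is_useful`, and `rewrite` callables are hypothetical stand-ins for the retriever and LLM graders built later, and `max_rounds` is an assumed safety cap that the notebook's graph does not define.

```python
from typing import Callable, List


def self_rag(
    question: str,
    retrieve: Callable[[str], List[str]],
    grade_doc: Callable[[str, str], bool],
    generate: Callable[[str, List[str]], str],
    is_grounded: Callable[[List[str], str], bool],
    is_useful: Callable[[str, str], bool],
    rewrite: Callable[[str], str],
    max_rounds: int = 5,  # assumed cap, not part of the notebook
) -> str:
    """Run a simplified Self-RAG loop and return the first accepted answer."""
    for _ in range(max_rounds):
        # Decisions 1-2: retrieve, then keep only chunks graded relevant.
        docs = [d for d in retrieve(question) if grade_doc(question, d)]
        if not docs:
            question = rewrite(question)  # nothing relevant: rephrase and retry
            continue
        answer = generate(question, docs)
        # Decision 3: reject answers the documents do not support.
        if not is_grounded(docs, answer):
            continue  # retry the round with the same question
        # Decision 4: accept only answers that address the question.
        if is_useful(question, answer):
            return answer
        question = rewrite(question)
    raise RuntimeError("no acceptable answer within max_rounds")


# Toy stand-ins so the sketch runs without an LLM or a vector store.
corpus = ["agents keep short-term memory in context", "cats sleep a lot"]
answer = self_rag(
    "agent memory",
    retrieve=lambda q: corpus,
    grade_doc=lambda q, d: any(word in d for word in q.split()),
    generate=lambda q, docs: f"Based on {len(docs)} chunk(s): {docs[0]}",
    is_grounded=lambda docs, a: any(d in a for d in docs),
    is_useful=lambda q, a: "memory" in a,
    rewrite=lambda q: q + " in LLM agents",
)
print(answer)
```

The LangGraph version later in this notebook expresses the same branches as conditional edges between nodes rather than as `if` statements inside one function.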

## Package Installation

The following cell uses `pip` to install the Python libraries for this notebook, along with a few related LangChain packages:

1. **langchain-openai**: A library for integrating OpenAI's language models with the LangChain framework.
2. **langchain-anthropic**: A library for integrating Anthropic's language models with LangChain.
3. **langchain_community**: A community-driven library providing additional tools and integrations for LangChain.
4. **langchain_experimental**: A library containing experimental features and extensions for LangChain.
5. **langgraph**: A library for creating and managing graphs of language model interactions.
6. **chromadb**: A vector database for storing and querying embeddings, often used in AI applications.
7. **duckduckgo_search**: A library for performing web searches using DuckDuckGo.

Of these, the code below imports directly from `langchain-openai`, `langchain_community`, and `langgraph`, and uses `chromadb` through the Chroma vector store. `langchain-anthropic`, `langchain_experimental`, and `duckduckgo_search` are installed but not used in this notebook. The `-q` flag keeps pip's output quiet, and `-U` upgrades an existing installation if one is present.

In [None]:
!pip install -qU langchain-openai
!pip install -qU langchain-anthropic
!pip install -qU langchain_community
!pip install -qU langchain_experimental
!pip install -qU langgraph
!pip install -qU chromadb
!pip install -qU duckduckgo_search
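
Because `-q` hides most of pip's output, it can be useful to confirm what actually got installed. The following is a small sketch using only the standard library; `installed_versions` is a hypothetical helper name, and the package list is just the subset this notebook imports from.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


packages = ["langchain-openai", "langchain-community", "langgraph", "chromadb"]
for name, version in installed_versions(packages).items():
    print(f"{name}: {version or 'not installed'}")
```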

## Importing Libraries and Setting Up the Vector Store

This code block imports the necessary libraries and modules for the workflow. It sets up the OpenAI embeddings using a securely loaded API key from Kaggle secrets. It then defines a list of URLs to be indexed, loads the documents from these URLs, splits them into smaller chunks for better retrieval, and adds them to a Chroma vector store for efficient searching.

In [None]:
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langgraph.graph import END, StateGraph, START

from typing import List
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from kaggle_secrets import UserSecretsClient

# Load OpenAI API Key securely from Kaggle secrets
my_api_key = UserSecretsClient().get_secret("my-openai-api-key")

# Use OpenAI embeddings for vectorization
embed = OpenAIEmbeddings(model="text-embedding-3-small", api_key=my_api_key)

# Retriever
# Define URLs of blog posts to index
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load documents from URLs
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split documents into smaller chunks for better retrieval
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=250, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)

# Add documents to vector database (Chroma)
vectorstore = Chroma.from_documents(documents=doc_splits, collection_name="rag-chroma", embedding=embed)
retriever = vectorstore.as_retriever()
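
Conceptually, the retriever embeds the query and returns the chunks whose embeddings are closest to it. The sketch below illustrates that idea with cosine similarity over made-up three-dimensional vectors. It is a toy model for intuition only, not how Chroma is implemented, and `cosine`, `top_k`, and the `index` contents are all hypothetical.

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: Sequence[float], index: Dict[str, Sequence[float]], k: int = 2) -> List[str]:
    """Return the ids of the k chunks most similar to query_vec."""
    ranked = sorted(index, key=lambda cid: cosine(query_vec, index[cid]), reverse=True)
    return ranked[:k]


# Made-up 3-dimensional "embeddings"; real embeddings have many more dimensions.
index = {
    "agent-memory": [0.9, 0.1, 0.0],
    "prompting": [0.1, 0.9, 0.0],
    "adversarial": [0.0, 0.2, 0.9],
}
print(top_k([1.0, 0.2, 0.0], index))  # ['agent-memory', 'prompting']
```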

## Retrieval Grader

This code block defines a retrieval grader using a Pydantic model to assess the relevance of retrieved documents to a user's question. It initializes an OpenAI LLM to perform the grading based on a system prompt that instructs the model to provide a binary score indicating relevance.

In [None]:
# Retrieval Grader
# Define a Pydantic model to grade document relevance
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# Initialize LLM for grading
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Define system prompt for grading relevance
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

# Create a chain for grading document relevance
retrieval_grader = grade_prompt | structured_llm_grader

# Test retrieval grader with a sample question
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content

result = retrieval_grader.invoke({"question": question, "document": doc_txt})
print(result)
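
Note that `binary_score` is declared as a plain `str`, so nothing in the schema prevents the model from returning `"Yes"` or `"yes."`, while the graph nodes later compare it with `== "yes"`. One way to make that comparison more forgiving is sketched below; `is_yes` is a hypothetical helper, not part of the notebook or of LangChain.

```python
def is_yes(score: str) -> bool:
    """Interpret a grader's free-text binary score leniently."""
    # Drop surrounding whitespace, then stray punctuation or quotes, then case.
    return score.strip().strip(".!'\"").lower() == "yes"


for raw in ["yes", "Yes", " YES. ", "no", "maybe"]:
    print(repr(raw), "->", is_yes(raw))
```

Another option, under the same goal, would be to declare the field as `Literal["yes", "no"]` so that the allowed values are part of the schema itself.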

## Generation with Retrieval-Augmented Generation (RAG) Chain

This code block sets up the generation phase using a Retrieval-Augmented Generation (RAG) chain. It pulls a RAG prompt from LangChain's hub, initializes an OpenAI LLM for generation, and creates a chain that generates answers based on the retrieved documents and the user's question.

In [None]:
# Generate
# Pull a RAG prompt from the hub
prompt = hub.pull("rlm/rag-prompt")

# Initialize LLM for generation
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)

# Create a RAG chain for generating answers
rag_chain = prompt | llm | StrOutputParser()
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)

In [None]:
# Show what the prompt looks like.
# pretty_print() prints the prompt and returns None, so its result is not assigned to `prompt`.
hub.pull("rlm/rag-prompt").pretty_print()
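
The generation cell passes the list of retrieved documents directly as `context`, so the prompt receives the string form of that list rather than clean text. A common alternative is to join only the chunk texts first. The sketch below shows the idea with a hypothetical `Doc` dataclass standing in for LangChain's `Document`; `format_docs` is likewise an assumed helper name.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    """Hypothetical stand-in for a LangChain Document (text plus metadata)."""
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


def format_docs(docs: List[Doc]) -> str:
    """Join chunk texts with blank lines, leaving metadata out of the prompt."""
    return "\n\n".join(d.page_content for d in docs)


sample = [
    Doc("Short-term memory lives in the context window.", {"source": "a"}),
    Doc("Long-term memory uses an external vector store.", {"source": "b"}),
]
print(format_docs(sample))
```

Since the helper only reads `page_content`, it could be applied to the retrieved documents as well, for example `rag_chain.invoke({"context": format_docs(docs), "question": question})`.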

## Hallucination Grader

This code block defines a hallucination grader to assess whether the generated answers are grounded in the retrieved documents. It uses a Pydantic model for binary scoring and initializes an OpenAI LLM with a system prompt that instructs the model to determine if the generation is supported by the provided facts.

In [None]:
# Hallucination Grader
# Define a Pydantic model to grade hallucinations in generated answers
class GradeHallucinations(BaseModel):
    """Binary score indicating whether the generated answer is grounded in the facts."""

    binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")

# Initialize LLM for hallucination grading
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

# Define system prompt for hallucination grading
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n 
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)

# Create a chain for hallucination grading
hallucination_grader = hallucination_prompt | structured_llm_grader
result = hallucination_grader.invoke({"documents": docs, "generation": generation})
print(result)

## Answer Grader

This code block defines an answer grader to evaluate whether the generated answer adequately addresses the user's question. It utilizes a Pydantic model for binary scoring and sets up an OpenAI LLM with a system prompt guiding the model to determine if the answer resolves the question.

In [None]:
# Answer Grader
# Define a Pydantic model to grade whether the answer addresses the question
class GradeAnswer(BaseModel):
    """Binary score assessing whether the answer addresses the question."""

    binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")

# Initialize LLM for answer grading
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeAnswer)

# Define system prompt for answer grading
system = """You are a grader assessing whether an answer addresses / resolves a question \n 
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

# Create a chain for answer grading
answer_grader = answer_prompt | structured_llm_grader
result = answer_grader.invoke({"question": question, "generation": generation})
print(result)

## Question Re-writer

This code block sets up a question re-writer to improve the user's input question for optimized retrieval from the vector store. It initializes an OpenAI LLM with a system prompt that instructs the model to rephrase the question to better capture its semantic intent.

In [None]:
# Question Re-writer
# Initialize LLM for question rewriting
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)

# Define system prompt for question rewriting
system = """You are a question re-writer that converts an input question to a better version that is optimized \n 
     for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

# Create a chain for question rewriting
question_rewriter = re_write_prompt | llm | StrOutputParser()
result = question_rewriter.invoke({"question": question})
print(result)

## Defining Graph State and Workflow Nodes

This code block defines the state and nodes for a state graph workflow using LangGraph. It specifies the structure of the graph state, implements functions for each node (retrieve, generate, grade documents, transform query), and defines the decision logic for transitioning between nodes based on the grading results.

In [None]:
from typing import List, Literal
from typing_extensions import TypedDict
from langchain_core.documents import Document
from IPython.display import Image, display

# Graph State
# Define the state of the graph as a TypedDict
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: The user's question.
        generation: The LLM's generated answer.
        documents: List of retrieved documents.
    """
    question: str
    generation: str
    documents: List[Document]

# Nodes
def retrieve_node(state: GraphState) -> GraphState:
    """
    Retrieve documents relevant to the user's question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        GraphState: Updated state with retrieved documents.
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieve documents using the retriever
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate_node(state: GraphState) -> GraphState:
    """
    Generate an answer using the RAG chain.

    Args:
        state (GraphState): The current graph state.

    Returns:
        GraphState: Updated state with generated answer.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # Generate answer using the RAG chain
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents_node(state: GraphState) -> GraphState:
    """
    Grade the relevance of retrieved documents to the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        GraphState: Updated state with filtered relevant documents.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Filter documents based on relevance
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs, "question": question}

def transform_query_node(state: GraphState) -> GraphState:
    """
    Transform the user's question into a better version for retrieval.

    Args:
        state (GraphState): The current graph state.

    Returns:
        GraphState: Updated state with a rephrased question.
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Rewrite the question for better retrieval
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

# Edges
def decide_to_generate(state: GraphState) -> Literal["transform_query_node", "generate_node"]:
    """
    Decide whether to generate an answer or rephrase the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Literal["transform_query_node", "generate_node"]: Decision for the next node.
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        # If no relevant documents, rephrase the question
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query_node"
    else:
        # If relevant documents exist, generate an answer
        print("---DECISION: GENERATE---")
        return "generate_node"

def decide_generation_useful(state: GraphState) -> Literal["generate_node", "transform_query_node", END]:
    """
    Decide whether the generated answer is useful or needs to be regenerated.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Literal["generate_node", "transform_query_node", END]: Decision for the next node.
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    # Check if the generation is grounded in the documents
    score = hallucination_grader.invoke({"documents": documents, "generation": generation})
    grade = score.binary_score

    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check if the generation addresses the question
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return END
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "transform_query_node"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "generate_node"
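
As written, `decide_generation_useful` sends the workflow back to `generate_node` every time the hallucination grader answers 'no', with no cap of its own. LangGraph enforces a recursion limit on graph runs, which would eventually stop such a loop with an error, but an explicit counter gives more control. The following is a library-free sketch of that idea: the `retries` state key, `MAX_RETRIES`, the `"give_up"` route, and both function names are assumptions for illustration, not part of the notebook.

```python
from typing import Callable, Dict

MAX_RETRIES = 3  # assumed cap; tune for your application


def generate_with_count(state: Dict, generate: Callable[[Dict], str]) -> Dict:
    """Node-style function: produce an answer and count the attempt."""
    return {
        **state,
        "generation": generate(state),
        "retries": state.get("retries", 0) + 1,
    }


def route_after_generate(state: Dict, grounded: Callable[[Dict], bool]) -> str:
    """Edge-style function: accept, retry, or give up once the cap is reached."""
    if grounded(state):
        return "end"
    if state["retries"] >= MAX_RETRIES:
        return "give_up"
    return "generate_node"


# Simulate a grader that never approves, to show that the loop terminates.
state = {"question": "agent memory"}
route = "generate_node"
while route == "generate_node":
    state = generate_with_count(state, lambda s: "draft answer")
    route = route_after_generate(state, lambda s: False)
print(route, state["retries"])  # give_up 3
```

The counter is updated in the node rather than in the routing function because routing functions only choose the next node; state changes come from what nodes return.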

## Building the Workflow Graph

This code block constructs the workflow graph using the previously defined nodes and decision functions. It establishes the flow of the workflow by adding nodes and defining the edges between them. Additionally, it attempts to visualize the graph if the necessary dependencies are available.

In [None]:
# Build Graph
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve_node", retrieve_node)                # Retrieve documents
workflow.add_node("grade_documents_node", grade_documents_node)  # Grade document relevance
workflow.add_node("generate_node", generate_node)                # Generate answer
workflow.add_node("transform_query_node", transform_query_node)  # Transform query

# Build graph edges
workflow.add_edge(START, "retrieve_node")
workflow.add_edge("retrieve_node", "grade_documents_node")
workflow.add_conditional_edges("grade_documents_node", decide_to_generate, ["transform_query_node", "generate_node"])
workflow.add_edge("transform_query_node", "retrieve_node")
workflow.add_conditional_edges("generate_node", decide_generation_useful, ["generate_node", "transform_query_node", END])

# Compile the workflow
app = workflow.compile()

# Visualize the graph (optional, requires additional dependencies)
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    pass
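
If the diagram cannot be rendered, the topology can still be read from the edge definitions. The sketch below restates them as plain dictionaries and walks the graph with a scripted list of router decisions. It is only a paper model of the wiring above: `EDGES`, `CONDITIONAL`, and `walk` are hypothetical names, and the strings `"START"` and `"END"` stand in for LangGraph's `START` and `END` constants.

```python
# Static edges: each of these nodes has exactly one successor.
EDGES = {
    "START": "retrieve_node",
    "retrieve_node": "grade_documents_node",
    "transform_query_node": "retrieve_node",
}
# Conditional edges: the successor depends on a router's return value.
CONDITIONAL = {
    "grade_documents_node": {"transform_query_node", "generate_node"},
    "generate_node": {"generate_node", "transform_query_node", "END"},
}


def walk(decisions):
    """Follow the graph from START, consuming one scripted decision per branch."""
    decisions = iter(decisions)
    node, path = "START", []
    while node != "END":
        if node in EDGES:
            node = EDGES[node]
        else:
            choice = next(decisions)
            if choice not in CONDITIONAL[node]:
                raise ValueError(f"{choice!r} is not reachable from {node!r}")
            node = choice
        path.append(node)
    return path


# The first retrieval finds nothing relevant; the second attempt succeeds.
path = walk(["transform_query_node", "generate_node", "END"])
print(" -> ".join(path))
```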

## Helper Functions for Serialization

This code block includes helper functions to suppress specific warnings and convert non-serializable objects into serializable formats (such as dictionaries). This ensures that the outputs from the workflow nodes can be easily serialized and displayed as JSON.

In [None]:
import json
import warnings

# Suppress LangSmith API key warning (if not using LangSmith)
warnings.filterwarnings("ignore", category=UserWarning, message="API key must be provided when using hosted LangSmith API")

# Helper function to convert non-serializable objects to dictionaries
def convert_to_serializable(obj):
    if hasattr(obj, "model_dump"):        # Pydantic v2 models expose .model_dump()
        return obj.model_dump()
    elif hasattr(obj, "dict"):            # Older Pydantic models expose .dict()
        return obj.dict()
    elif isinstance(obj, (list, tuple)):  # Handle lists and tuples
        return [convert_to_serializable(item) for item in obj]
    elif isinstance(obj, dict):           # Handle dictionaries
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    else:                                 # Return the object as-is if it's already serializable
        return obj

In [None]:
# Example 1: Question about agent memory
inputs = {"question": "Explain how the different types of agent memory work?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

# Print the final generation
print(f"Final Generation:\n{value['generation']}\n")

In [None]:
# Example 2: Question about adversarial attacks on LLMs
inputs = {"question": "How can adversarial attacks affect large language models?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

# Print the final generation
print(f"Final Generation:\n{value['generation']}\n")

In [None]:
# Example 3: Question about agentic agents
inputs = {"question": "What are the key components of an agentic agent?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

# Print the final generation
print(f"Final Generation:\n{value['generation']}\n")
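
The three example cells repeat the same streaming loop and read `value['generation']` from whichever node ran last. A small wrapper can remove the duplication and track the answer explicitly. The sketch below is one possible shape: `run_question` and `FakeApp` are hypothetical names, and `FakeApp` only mimics the structure of the compiled graph's `stream` output so that the example runs without an API key.

```python
import json
from typing import Any, Callable, Dict, Iterable


def run_question(app: Any, question: str, to_jsonable: Callable[[Any], Any] = lambda v: v) -> str:
    """Stream one question through `app`, print each node's output, return the answer."""
    generation = ""
    for output in app.stream({"question": question}):
        for node_name, value in output.items():
            print(f"Output from node '{node_name}':")
            print("-" * 80)
            print(json.dumps(to_jsonable(value), indent=2))
            # Remember the most recent answer; not every node produces one.
            generation = value.get("generation", generation)
        print("=" * 80)
    return generation


class FakeApp:
    """Hypothetical stand-in that mimics the shape of the graph's stream output."""

    def stream(self, inputs: Dict[str, str]) -> Iterable[Dict[str, Dict[str, Any]]]:
        yield {"retrieve_node": {"question": inputs["question"], "documents": ["chunk"]}}
        yield {"generate_node": {"question": inputs["question"], "generation": "stub answer"}}


final = run_question(FakeApp(), "What is agent memory?")
print(f"Final Generation:\n{final}\n")
```

With the real workflow, the call would look like `run_question(app, "Explain how the different types of agent memory work?", to_jsonable=convert_to_serializable)`.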

## Conclusion

This notebook builds a Self-RAG workflow with LangGraph. It installs the required packages, loads and indexes documents, grades the relevance of retrieved documents, generates an answer, checks that answer for hallucinations, and verifies that it addresses the user's question. Because the workflow is structured as a state graph, a failed check routes back to query rewriting or regeneration, so both the query and the answer can be improved iteratively.