# Implementing Corrective RAG: Improving Document Relevance in Language Models

## Introduction

In the rapidly evolving landscape of natural language processing, **Retrieval-Augmented Generation (RAG)** has emerged as a powerful framework for enhancing language models by integrating external knowledge sources. Traditional RAG systems retrieve documents to inform the generation of responses, but they typically do not check whether the retrieved information is relevant or of sufficient quality. This can lead to inaccurate or irrelevant answers, especially when the retrieved documents do not adequately address the user's query.

To address this challenge, **Corrective RAG (CRAG)** has been introduced. `CRAG` extends the RAG framework with self-reflection and self-grading mechanisms that evaluate the relevance of retrieved documents. By assessing each document's pertinence to the user's question, `CRAG` filters out information judged irrelevant before it reaches the generator, and it turns to additional sources such as web search only when the retrieved documents fall short. This aims to improve the accuracy of the generated answers while limiting external lookups to the cases that need them.

In this example, we implement a simplified version of the `CRAG` strategy using LangGraph and LangChain. While the full `CRAG` methodology involves steps such as knowledge refinement and partitioning documents into "knowledge strips," this implementation focuses on the core idea of grading document relevance and supplementing retrieval with web searches when necessary. By leveraging tools like DuckDuckGo for web searches and incorporating query rewriting for optimized searches, this workflow demonstrates how `CRAG` principles can be effectively applied to enhance RAG systems.
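Before building the LangGraph version, the corrective loop can be outlined in plain Python. This is a minimal sketch, not the notebook's implementation: the `retrieve`, `grade`, `rewrite`, `web_search`, and `generate` callables are hypothetical stand-ins for the retriever, grader chain, question rewriter, DuckDuckGo tool, and RAG chain built below.

```python
from typing import Callable, List

def corrective_rag(
    question: str,
    retrieve: Callable[[str], List[str]],
    grade: Callable[[str, str], bool],
    rewrite: Callable[[str], str],
    web_search: Callable[[str], str],
    generate: Callable[[str, List[str]], str],
) -> str:
    """Hypothetical plain-Python outline of the CRAG-style flow used in this notebook."""
    docs = retrieve(question)
    # Keep only the documents the grader judges relevant.
    relevant = [d for d in docs if grade(question, d)]
    # Same rule as the notebook: any irrelevant document triggers a rewrite + web search.
    if len(relevant) < len(docs):
        question = rewrite(question)
        relevant.append(web_search(question))
    # Generation uses the (possibly rewritten) question, as the graph does.
    return generate(question, relevant)

answer = corrective_rag(
    "agent memory",
    retrieve=lambda q: ["short-term memory is in-context learning", "unrelated text"],
    grade=lambda q, d: "memory" in d,
    rewrite=lambda q: q + " in LLM agents",
    web_search=lambda q: f"web results for: {q}",
    generate=lambda q, docs: f"{q} | {len(docs)} docs",
)
print(answer)  # agent memory in LLM agents | 2 docs
```

The toy grader rejects the second document, so the question is rewritten and one web result replaces it before generation.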

## Installing Required Packages

This code block installs the necessary Python packages quietly (`-q`), upgrading them if they are already installed (`-U`). The packages include components of `LangChain`, `LangGraph`, `ChromaDB`, and the `DuckDuckGo` search client used to build and run the Retrieval-Augmented Generation (RAG) workflow. `langchain-anthropic` and `langchain_experimental` are not used by the code below and can be skipped.

In [None]:
!pip install -qU langchain-openai
!pip install -qU langchain-anthropic
!pip install -qU langchain_community
!pip install -qU langchain_experimental
!pip install -qU langgraph
!pip install -qU chromadb
!pip install -qU duckduckgo_search
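To confirm the installs succeeded, you can query the installed versions with the standard library's `importlib.metadata`. This is a small sketch; `installed_versions` is a hypothetical helper, and the distribution names passed to it are assumed to match the pip names used above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {package: version}, with None for packages that are not installed."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result

print(installed_versions(["langchain-openai", "langchain-community", "langgraph", "chromadb", "duckduckgo-search"]))
```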

## Loading and Preparing Documents

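This block imports the necessary modules and sets up the environment to load, split, and vectorize documents from the specified URLs. It retrieves the OpenAI API key from Kaggle secrets (the `kaggle_secrets` module only exists inside Kaggle notebooks; elsewhere, replace those lines by reading the key from the `OPENAI_API_KEY` environment variable, which `langchain_openai` also checks by default). It then uses OpenAI embeddings for vectorization, loads the documents from the URLs, splits them into chunks of roughly 250 tokens, and stores them in a Chroma vector database for later retrieval.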

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from kaggle_secrets import UserSecretsClient

# Load OpenAI API Key securely from Kaggle secrets
my_api_key = UserSecretsClient().get_secret("my-openai-api-key")

# Use OpenAI embeddings for vectorization
embed = OpenAIEmbeddings(model="text-embedding-3-small", api_key=my_api_key)

# URLs of the blog posts to index
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load documents from the URLs
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split documents into smaller chunks for processing
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=250, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)

# Add documents to Chroma vector database
vectorstore = Chroma.from_documents(documents=doc_splits, collection_name="rag-chroma", embedding=embed)
retriever = vectorstore.as_retriever()
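Conceptually, `vectorstore.as_retriever()` embeds the query and returns the stored chunks whose vectors are closest to it (four by default). The sketch below illustrates that top-k idea with a hypothetical bag-of-words `embed_text` function and cosine similarity in place of OpenAI embeddings and Chroma, so it runs without an API key. Chroma's default distance is L2, which ranks unit-length vectors such as OpenAI embeddings in the same order as cosine similarity.

```python
import math
from collections import Counter
from typing import Dict, List

def embed_text(text: str) -> Dict[str, float]:
    """Hypothetical stand-in embedding: L2-normalized word counts."""
    counts = Counter(text.lower().split())
    norm = math.sqrt(sum(c * c for c in counts.values()))
    return {w: c / norm for w, c in counts.items()} if norm else {}

def cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    # Both vectors are unit length, so their dot product is the cosine similarity.
    return sum(v * b.get(w, 0.0) for w, v in a.items())

def top_k(query: str, chunks: List[str], k: int = 4) -> List[str]:
    """Return the k chunks most similar to the query."""
    q = embed_text(query)
    ranked = sorted(chunks, key=lambda c: cosine(q, embed_text(c)), reverse=True)
    return ranked[:k]

chunks = [
    "Short-term memory uses in-context learning.",
    "Long-term memory uses an external vector store.",
    "Chain-of-thought prompting elicits reasoning.",
    "Adversarial prompts can jailbreak a model.",
]
print(top_k("agent memory", chunks, k=2))
```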

## Grading Document Relevance

This section defines a data model for grading the relevance of retrieved documents to a user’s question. It initializes a language model (LLM) with function-calling capabilities to perform the grading. A prompt is constructed to instruct the LLM on how to assess the relevance of each document by assigning a binary score ('yes' or 'no'). A retrieval grader chain is then created and tested with a sample question to demonstrate its functionality.

In [None]:
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

# Data model for grading document relevance
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # Literal restricts the structured output to exactly 'yes' or 'no'
    binary_score: Literal["yes", "no"] = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# LLM with function call for grading documents
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt for grading document relevance
system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

# Chain for grading document relevance
retrieval_grader = grade_prompt | structured_llm_grader

# Test retrieval grader with a sample question
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content

result = retrieval_grader.invoke({"question": question, "document": doc_txt})
print(result)
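Structured output keeps the grader's reply in a fixed shape. If you swap in a model without structured-output support, you have to parse free text yourself, and the later check `grade == "yes"` is case-sensitive. A minimal sketch of a hypothetical `normalize_grade` helper that maps free-text replies onto `'yes'`/`'no'`; treating anything unrecognized as `'no'` (which triggers the web-search fallback) is a design assumption.

```python
import re

def normalize_grade(raw: str) -> str:
    """Hypothetical helper: map a free-text grader reply to 'yes' or 'no'."""
    # Take the first alphabetic word, e.g. "Yes." -> "yes", "**No**, because..." -> "no".
    match = re.search(r"[a-zA-Z]+", raw)
    word = match.group(0).lower() if match else ""
    # Anything other than an explicit "yes" is treated as not relevant.
    return "yes" if word == "yes" else "no"

for reply in ["Yes", "yes.", "**No**, it is unrelated", "", "maybe"]:
    print(repr(reply), "->", normalize_grade(reply))
```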

## Setting Up the RAG Chain for Answer Generation

This block sets up the Retrieval-Augmented Generation (RAG) chain, which is responsible for generating answers based on the retrieved documents. It pulls a predefined RAG prompt from LangChain Hub, initializes an LLM for generating responses, and constructs a chain that processes the context and question to produce the final answer. A sample question is then run through the RAG chain to demonstrate its operation.

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Pull the RAG prompt from the hub
prompt = hub.pull("rlm/rag-prompt")

# LLM for generating answers
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)

# Chain for generating answers using RAG
rag_chain = prompt | llm | StrOutputParser()

# Run the RAG chain with the sample question
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)
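Here `docs` is a list of `Document` objects, so the prompt's `{context}` slot receives their Python representation, metadata included. A common refinement is to join only the `page_content` strings before passing them in. A minimal sketch; the `Doc` dataclass is a hypothetical stand-in that exposes the same `page_content` attribute as LangChain's `Document`, so the helper runs without LangChain installed.

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Doc:
    """Hypothetical stand-in for langchain_core.documents.Document."""
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)

def format_docs(documents: List[Doc]) -> str:
    # Keep only the text, separated by blank lines, so metadata does not leak into the prompt.
    return "\n\n".join(d.page_content for d in documents)

docs_demo = [Doc("Memory can be short-term.", {"source": "a"}), Doc("Or long-term.", {"source": "b"})]
print(format_docs(docs_demo))
```

With the notebook's objects, the call would become `rag_chain.invoke({"context": format_docs(docs), "question": question})`, since only `page_content` is accessed.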

In [None]:
# Show what the prompt looks like
prompt.pretty_print()

## Implementing the Question Rewriter

This section creates a mechanism to improve user questions for better search results. It defines an LLM that rephrases input questions to be more optimized for web searches by understanding the underlying semantic intent. A prompt is crafted to guide the LLM in transforming the questions, and a chain is established to process and rephrase the questions accordingly.

In [None]:
# Question Re-writer
# LLM for re-writing questions
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=my_api_key)

# Prompt for re-writing questions
system = """You are a question re-writer that converts an input question to a better version that is optimized \n 
     for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
    ]
)

# Chain for re-writing questions
question_rewriter = re_write_prompt | llm | StrOutputParser()
result = question_rewriter.invoke({"question": question})
print(result)
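`ChatPromptTemplate.from_messages` treats each `(role, text)` pair as an f-string-style template, so invoking the chain fills `{question}` and sends a two-message conversation (system, then human) to the model. The sketch below imitates that substitution with `str.format`; `build_messages` is a hypothetical helper for illustration, not part of LangChain.

```python
from typing import List, Tuple

def build_messages(templates: List[Tuple[str, str]], **variables: str) -> List[Tuple[str, str]]:
    """Hypothetical helper: fill {placeholders} in each (role, template) pair."""
    return [(role, text.format(**variables)) for role, text in templates]

rewrite_templates = [
    ("system", "You are a question re-writer that converts an input question to a better version optimized for web search."),
    ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
]
for role, content in build_messages(rewrite_templates, question="agent memory"):
    print(f"{role}: {content}")
```

As with `str.format`, literal braces inside such a template must be doubled (`{{` and `}}`).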

## Initializing the Web Search Tool

This block sets up the DuckDuckGo search tool from LangChain Community tools. The tool performs a web search when at least one of the initially retrieved documents is graded irrelevant, so the system can fetch additional information from the web to answer the user's query.

In [None]:
from langchain_community.tools import DuckDuckGoSearchRun

# Initialize DuckDuckGo search tool
web_search_tool = DuckDuckGoSearchRun()
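DuckDuckGo may throttle or reject bursts of queries, in which case the tool call can raise an exception and stop the graph run. One option is to wrap the search in a retry with exponential backoff. A minimal sketch under the assumption that failures surface as exceptions; `with_retries` is a hypothetical helper, and in the notebook you would pass `web_search_tool.run` as `func`.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")

def with_retries(func: Callable[[str], T], query: str, attempts: int = 3, base_delay: float = 1.0) -> T:
    """Hypothetical helper: call func(query), retrying with exponential backoff on any exception."""
    for attempt in range(attempts):
        try:
            return func(query)
        except Exception:
            if attempt == attempts - 1:
                raise  # Out of retries: surface the last error.
            # Sleep base_delay, 2 * base_delay, 4 * base_delay, ... between attempts.
            time.sleep(base_delay * 2 ** attempt)
    raise ValueError("attempts must be at least 1")

# Demo with a stub that fails twice before succeeding.
calls = {"n": 0}

def flaky_search(q: str) -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("rate limited")
    return f"results for {q}"

print(with_retries(flaky_search, "agent memory", base_delay=0.0))  # results for agent memory
```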

## Defining the Graph State and Workflow Nodes

Here, the graph state and workflow nodes are defined to manage the flow of operations in the RAG system. A `GraphState` TypedDict outlines the necessary attributes such as the question, generated answer, web search flag, and retrieved documents. Functions for each node in the workflow—retrieving documents, generating answers, grading document relevance, transforming queries, and performing web searches—are implemented. Additionally, a decision function determines the next step based on the relevance of the retrieved documents.

In [None]:
from typing import List, Literal, Dict, Any
from typing_extensions import TypedDict
from langchain_core.documents import Document

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: The current question.
        generation: The LLM generation (answer).
        web_search: Whether to perform a web search. ("Yes" or "No")
        documents: List of retrieved documents.
    """
    question: str
    generation: str
    web_search: Literal["Yes", "No"]
    documents: List[Document]

def retrieve_node(state: GraphState) -> Dict[str, Any]:
    """
    Retrieve documents relevant to the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with retrieved documents.
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieve documents using the retriever
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate_node(state: GraphState) -> Dict[str, Any]:
    """
    Generate an answer using the RAG chain.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with the generated answer.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # Generate answer using the RAG chain
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents_node(state: GraphState) -> Dict[str, Any]:
    """
    Grade the relevance of retrieved documents to the question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with filtered relevant documents.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each document for relevance
    filtered_docs = []
    web_search: Literal["Yes", "No"] = "No"
    for d in documents:
        score = retrieval_grader.invoke({"question": question, "document": d.page_content})
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def transform_query_node(state: GraphState) -> Dict[str, Any]:
    """
    Transform the query to produce a better version for web search.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with the re-phrased question.
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write the question for better search results
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

def web_search_node(state: GraphState) -> Dict[str, Any]:
    """
    Perform a web search based on the re-phrased question.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Dict[str, Any]: Updated state with appended web search results.
    """
    print("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]

    # Perform web search using DuckDuckGo
    search_results = web_search_tool.run(question)
    web_results = Document(page_content=search_results)
    # Build a new list rather than mutating the list held in the graph state
    documents = documents + [web_results]

    return {"documents": documents, "question": question}

def decide_to_generate(state: GraphState) -> Literal["transform_query_node", "generate_node"]:
    """
    Decide whether to generate an answer or rewrite the question for web search.

    Args:
        state (GraphState): The current graph state.

    Returns:
        Literal["transform_query_node", "generate_node"]: Decision for the next node to call.
    """
    print("---ASSESS GRADED DOCUMENTS---")
    web_search = state["web_search"]

    if web_search == "Yes":
        # If any document was graded not relevant, rewrite the question for web search
        print("---DECISION: SOME DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query_node"
    else:
        # If all documents are relevant, generate the answer
        print("---DECISION: GENERATE---")
        return "generate_node"
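Each node returns only the keys it changes. With a plain `TypedDict` state and no reducers, LangGraph overwrites each returned key with the new value and keeps the others, which behaves roughly like a dictionary merge. The sketch below imitates that in plain Python to show how `question` is replaced by its rewritten version before generation; `run_nodes` and the `fake_*` node functions are hypothetical stubs, not the nodes defined above.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]

def run_nodes(state: State, nodes: List[Callable[[State], State]]) -> State:
    """Apply each node's partial update by overwriting the keys it returns."""
    for node in nodes:
        state = {**state, **node(state)}
    return state

# Hypothetical stubs mirroring the shape of the real nodes' return values.
def fake_grade(s: State) -> State:
    return {"documents": [], "web_search": "Yes"}

def fake_transform(s: State) -> State:
    return {"question": s["question"] + " (rewritten)"}

def fake_search(s: State) -> State:
    # Build a new list so earlier states are not mutated.
    return {"documents": s["documents"] + ["web result"]}

final = run_nodes({"question": "agent memory", "documents": ["d1"]}, [fake_grade, fake_transform, fake_search])
print(final)
```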

## Compiling and Visualizing the Workflow Graph

This block constructs the workflow graph by adding the defined nodes and establishing the connections (edges) between them. It uses LangGraph's `StateGraph` to manage the workflow, specifying the start and end points. It then attempts to visualize the graph with Mermaid. Visualization is optional: by default, `draw_mermaid_png` renders the diagram through the external mermaid.ink service, so it may fail without network access.

In [None]:
from langgraph.graph import END, StateGraph, START
from IPython.display import Image, display

# Initialize the workflow graph
workflow = StateGraph(GraphState)

# Define the nodes in the workflow
workflow.add_node("retrieve_node", retrieve_node)                # Retrieve documents
workflow.add_node("grade_documents_node", grade_documents_node)  # Grade documents
workflow.add_node("generate_node", generate_node)                # Generate answer
workflow.add_node("transform_query_node", transform_query_node)  # Transform query
workflow.add_node("web_search_node", web_search_node)            # Web search

# Build the graph edges
workflow.add_edge(START, "retrieve_node")
workflow.add_edge("retrieve_node", "grade_documents_node")
workflow.add_conditional_edges("grade_documents_node", decide_to_generate, ["transform_query_node", "generate_node"])
workflow.add_edge("transform_query_node", "web_search_node")
workflow.add_edge("web_search_node", "generate_node")
workflow.add_edge("generate_node", END)

# Compile the workflow
app = workflow.compile()

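# Visualize the graph (optional; PNG rendering typically calls an external Mermaid service)
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except Exception as e:
    # Fall back to printing the Mermaid source, which any Mermaid renderer can display
    print(f"Graph rendering failed ({e}); Mermaid source:")
    print(app.get_graph(xray=True).draw_mermaid())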
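The compiled graph has a single branch point after grading. The sketch below encodes the same edges as a plain dictionary and walks them for both outcomes of `decide_to_generate`; `EDGES` and `walk` are hypothetical illustration code, not LangGraph APIs.

```python
from typing import Dict, List

# Fixed edges copied from the add_edge calls above; "START"/"END" stand in for LangGraph's sentinels.
EDGES: Dict[str, str] = {
    "START": "retrieve_node",
    "retrieve_node": "grade_documents_node",
    "transform_query_node": "web_search_node",
    "web_search_node": "generate_node",
    "generate_node": "END",
}

def walk(web_search: str) -> List[str]:
    """Return the node sequence for a given web_search flag ("Yes" or "No")."""
    path: List[str] = []
    node = EDGES["START"]
    while node != "END":
        path.append(node)
        if node == "grade_documents_node":
            # Conditional edge: same rule as decide_to_generate.
            node = "transform_query_node" if web_search == "Yes" else "generate_node"
        else:
            node = EDGES[node]
    return path

print(walk("No"))   # ['retrieve_node', 'grade_documents_node', 'generate_node']
print(walk("Yes"))
```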

## Executing the Workflow with Examples

This section demonstrates how to execute the compiled workflow by providing an input question. It includes a helper function to convert non-serializable objects into JSON-serializable formats for easier output visualization. The workflow is run with a sample question, and the outputs from each node are printed in a structured JSON format.

In [None]:
import json
import warnings

# Suppress LangSmith API key warning (if not using LangSmith)
warnings.filterwarnings("ignore", category=UserWarning, message="API key must be provided when using hosted LangSmith API")

# Helper function to convert non-serializable objects to dictionaries
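def convert_to_serializable(obj):
    if hasattr(obj, "model_dump"):        # Pydantic v2 objects (e.g. Document in recent langchain_core)
        return obj.model_dump()
    elif hasattr(obj, "dict"):            # Older objects exposing a Pydantic v1-style .dict() method
        return obj.dict()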
    elif isinstance(obj, (list, tuple)):  # Handle lists and tuples
        return [convert_to_serializable(item) for item in obj]
    elif isinstance(obj, dict):           # Handle dictionaries
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    else:                                 # Return the object as-is if it's already serializable
        return obj
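An alternative to walking the structure by hand is the `default` hook of `json.dumps`, which the encoder calls only for objects it cannot serialize natively. A sketch; `to_jsonable` is a hypothetical helper, and `FakeDocument` is a hypothetical stand-in for a Pydantic-based object such as LangChain's `Document`.

```python
import json
from typing import Any

def to_jsonable(obj: Any) -> Any:
    """Fallback for json.dumps: convert model-like objects to plain data."""
    if hasattr(obj, "model_dump"):   # Pydantic v2-style objects
        return obj.model_dump()
    if hasattr(obj, "dict"):         # Pydantic v1-style objects
        return obj.dict()
    return repr(obj)                 # Last resort: a readable string instead of a TypeError

class FakeDocument:
    """Hypothetical stand-in exposing model_dump() like a Pydantic model."""

    def __init__(self, text: str) -> None:
        self.text = text

    def model_dump(self) -> dict:
        return {"page_content": self.text, "metadata": {}}

print(json.dumps({"documents": [FakeDocument("hello")], "question": "q"}, indent=2, default=to_jsonable))
```

In the example cells, `print(json.dumps(value, indent=2, default=to_jsonable))` would replace the separate convert-then-dump steps.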

In [None]:
# Example 1: Types of Agent Memory
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

In [None]:
# Example 2: Prompt Engineering Techniques
inputs = {"question": "What are some prompt engineering techniques?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)

In [None]:
# Example 3: Adversarial Attacks on LLMs
inputs = {"question": "What are adversarial attacks on large language models?"}
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("-"*80)
        # Convert non-serializable objects to dictionaries
        serializable_value = convert_to_serializable(value)
        # Print the serialized value as a JSON string
        print(json.dumps(serializable_value, indent=2))
    print("="*80)
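The three example cells repeat the same loop. A small helper keeps them in sync; `print_stream` is a hypothetical name, and it accepts any iterable of `{node_name: update}` dictionaries, which is the shape `app.stream(inputs)` yields here. The demo feeds it a hand-written list so it runs without API keys.

```python
import json
from typing import Any, Dict, Iterable, List

def print_stream(chunks: Iterable[Dict[str, Any]]) -> List[str]:
    """Print each node's update as JSON and return the visited node names in order."""
    visited: List[str] = []
    for chunk in chunks:
        for node, update in chunk.items():
            visited.append(node)
            print(f"Output from node '{node}':")
            print("-" * 80)
            # default=str is a simple fallback for objects json cannot encode natively.
            print(json.dumps(update, indent=2, default=str))
        print("=" * 80)
    return visited

demo = [{"retrieve_node": {"question": "q", "documents": []}}, {"generate_node": {"generation": "answer"}}]
print(print_stream(demo))
```

In the notebook, each example would reduce to a call such as `print_stream(app.stream({"question": "What are some prompt engineering techniques?"}))`.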

## Conclusion

The implementation presented showcases a foundational approach to integrating **Corrective RAG (CRAG)** principles within a retrieval-augmented generation framework using LangGraph and LangChain. By grading the relevance of retrieved documents, the system discards documents the grader judges irrelevant before generation, which should improve the accuracy and reliability of the output.

While this example omits the knowledge refinement phase for simplicity, it lays the groundwork for more sophisticated enhancements, such as partitioning documents into knowledge strips and implementing deeper self-reflection capabilities. The web search fallback through DuckDuckGo also shows how the system can dynamically seek additional information when the initial retrieval falls short.
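To make the omitted refinement step more concrete: in the full CRAG method, a decompose-then-recompose step splits relevant documents into small "knowledge strips", keeps the strips judged relevant, and joins them back together. The sketch below approximates this with sentence-level strips and a hypothetical keyword-overlap scorer standing in for the LLM or trained evaluator; `split_into_strips`, `keyword_score`, `refine`, and the 0.5 threshold are illustrative assumptions, not the method's actual components.

```python
import re
from typing import Callable, List

def split_into_strips(text: str) -> List[str]:
    # Approximate "knowledge strips" as sentences (split after ., ! or ?).
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

def keyword_score(question: str, strip: str) -> float:
    """Hypothetical relevance scorer: fraction of question words present in the strip."""
    q_words = set(question.lower().split())
    s_words = set(re.findall(r"[a-z0-9-]+", strip.lower()))
    return len(q_words & s_words) / len(q_words) if q_words else 0.0

def refine(question: str, documents: List[str], threshold: float = 0.5,
           score: Callable[[str, str], float] = keyword_score) -> str:
    """Decompose documents into strips, keep strips scoring >= threshold, recompose."""
    kept = [strip for doc in documents for strip in split_into_strips(doc)
            if score(question, strip) >= threshold]
    return " ".join(kept)

doc = "Agents use memory to store context. The weather was sunny. Long-term memory uses a vector store."
print(refine("agent memory", [doc]))
```

In a fuller version, `score` would call a grader like `retrieval_grader` per strip rather than counting shared words.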

Moving forward, further refinements could include integrating the full knowledge refinement step, expanding the range of data sources, and enhancing the grading criteria to capture more nuanced aspects of relevance. These changes would make the system better suited to applications that demand precise, contextually appropriate language generation.
