# Adaptive RAG

- Author: [The LangChain Open Tutorial team](https://github.com/langchainopentutorial)
- Design:
- Peer Review:
- This is a part of [LangChain Open Tutorial](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/99-TEMPLATE/00-BASE-TEMPLATE-EXAMPLE.ipynb) [![Open in GitHub](https://img.shields.io/badge/Open%20in%20GitHub-181717?style=flat-square&logo=github&logoColor=white)](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/99-TEMPLATE/00-BASE-TEMPLATE-EXAMPLE.ipynb)

## Overview

This tutorial covers the implementation of Adaptive Retrieval-Augmented Generation (Adaptive RAG).

Adaptive RAG is a strategy that combines query analysis with active/self-corrective RAG to retrieve and generate information from diverse data sources.

In this tutorial, we use LangGraph to implement routing between web search and self-corrective RAG.

![adaptive-rag](./assets/langgraph-adaptive-rag.png)

**Adaptive RAG** is a **RAG** strategy that combines Query Construction with Self-Reflective RAG.

The paper [Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity](https://arxiv.org/abs/2403.14403) uses query analysis to route each question to one of the following strategies:

- `No Retrieval`
- `Single-shot RAG`
- `Iterative RAG`

In this tutorial, we implement an example using LangGraph to perform the following routing:

- **Web Search**: Used for questions about recent events
- **Self-Reflective RAG**: Used for questions about the indexed documents
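
The routing idea above can be sketched in plain Python before any LangGraph code appears. Everything here is a hypothetical stand-in: the keyword rule, `route`, `answer`, and the two branch functions are assumptions for illustration only, whereas the tutorial itself relies on an LLM-based router.

```python
from typing import Callable, Dict, List


# Hypothetical stand-ins for the two branches; the tutorial builds the real ones later.
def web_search_branch(question: str) -> str:
    return f"[web] {question}"


def self_reflective_rag_branch(question: str) -> str:
    return f"[rag] {question}"


# Assumption: a toy keyword rule replaces the LLM-based query analysis.
INDEXED_TOPICS: List[str] = ["samsung", "anthropic", "ai brief"]


def route(question: str) -> str:
    """Return the label of the branch that should handle the question."""
    lowered = question.lower()
    if any(topic in lowered for topic in INDEXED_TOPICS):
        return "vectorstore"
    return "web_search"


# Label -> branch function
BRANCHES: Dict[str, Callable[[str], str]] = {
    "web_search": web_search_branch,
    "vectorstore": self_reflective_rag_branch,
}


def answer(question: str) -> str:
    """Dispatch the question to the branch chosen by the router."""
    return BRANCHES[route(question)](question)
```

The design point is that the router only returns a label; the mapping from label to branch is kept separate, which is also how the graph is wired later in this tutorial.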

### Table of Contents

- [Overview](#overview)
- [Environment Setup](#environment-setup)
- [Create a basic PDF-based Retrieval Chain](#create-a-basic-pdf-based-retrieval-chain)
- [Query routing and document evaluation](#query-routing-and-document-evaluation)
- [Tools](#tools)
- [Graph Construction](#graph-construction)
- [Define Graph Flows](#define-graph-flows)
- [Define Nodes](#define-nodes)
- [Graph Construction](#graph-construction-1)
- [Execute Graph](#execute-graph)

### References

- [LangChain: Query Construction](https://blog.langchain.dev/query-construction/)
- [LangGraph: Self-Reflective RAG](https://blog.langchain.dev/agentic-rag-with-langgraph/)
- [Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity](https://arxiv.org/abs/2403.14403)
----

## Environment Setup

Set up the environment. You may refer to [Environment Setup](https://wikidocs.net/257836) for more details.

**[Note]**
- `langchain-opentutorial` is a package that provides easy-to-use environment setup, useful functions, and utilities for tutorials.
- You can check out [`langchain-opentutorial`](https://github.com/LangChain-OpenTutorial/langchain-opentutorial-pypi) for more details.

In [1]:
%%capture --no-stderr
!pip install langchain-opentutorial

In [3]:
# Install required packages
from langchain_opentutorial import package

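package.install(
    [
        "langsmith",
        "langchain",
        "langchain_core",
        "langchain-anthropic",
        "langchain_community",
        "langchain_text_splitters",
        "langchain_openai",
        "langgraph",  # Imported later to build the graph
        "langchain-teddynote",  # Assumed PyPI name for the langchain_teddynote helpers imported later
    ],
    verbose=False,
    upgrade=False,
)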


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
# Set environment variables
from langchain_opentutorial import set_env

set_env(
    {
        "OPENAI_API_KEY": "",
        "TAVILY_API_KEY": "",  # Assumed to be required by the Tavily web search tool used later
        "LANGCHAIN_API_KEY": "",
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        "LANGCHAIN_PROJECT": "Adaptive-RAG",  # Set this to match the tutorial title
    }
)

Environment variables have been set successfully.


You can alternatively set API keys such as `OPENAI_API_KEY` in a `.env` file and load them.

**[Note]** This is not necessary if you've already set the required API keys in previous steps.

In [None]:
# Load API keys from .env file
from dotenv import load_dotenv

load_dotenv(override=True)

## Create a basic PDF-based Retrieval Chain

Here, we create a retrieval chain based on a PDF document. This is the simplest retrieval chain structure.

In LangGraph, however, the retriever and the chain are created separately so that each node can perform its own detailed processing.

**[Note]**
- This was covered in a previous tutorial, so a detailed explanation is omitted here.

In [4]:
from rag.pdf import PDFRetrievalChain

# Load the PDF document.
pdf = PDFRetrievalChain(["data/SPRI_AI_Brief_December 2023_F.pdf"]).create_chain()

# create retriever
pdf_retriever = pdf.retriever

# create chain
pdf_chain = pdf.chain

## Query routing and document evaluation

In this step, we perform **query routing** and **document evaluation**. Both are central to **Adaptive RAG** and contribute to efficient information retrieval and answer generation.

- **Query Routing**: Analyzes the user query and routes it to the appropriate information source, so that each query follows the search path best suited to its purpose.
- **Document Evaluation**: Evaluates the quality and relevance of the retrieved documents to improve the accuracy of the final answer.

This step supports the core functionality of **Adaptive RAG** and aims to provide accurate and reliable information.
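
The router defined in the next cell constrains its output to a fixed set of labels through a `Literal` type. As a rough illustration of what that constraint provides, the sketch below validates a raw answer against the same labels using only the standard library. `parse_route` is a hypothetical helper, not part of LangChain or pydantic.

```python
from typing import Literal, get_args

# Same labels as the `datasource` field of the tutorial's RouteQuery model.
Datasource = Literal["vectorstore", "web_search"]


def parse_route(raw: str) -> str:
    """Hypothetical helper: normalize a raw answer and reject unknown labels."""
    label = raw.strip().lower()
    if label not in get_args(Datasource):
        raise ValueError(f"unexpected datasource: {raw!r}")
    return label
```

Rejecting anything outside the allowed labels keeps the downstream branching simple, because the graph only ever has to handle the two known routes.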

In [5]:
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_teddynote.models import get_model_name, LLMs

# Get latest LLM model name
MODEL_NAME = get_model_name(LLMs.GPT4)


# Data model that routes user queries to the most relevant data sources
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    # Literal type field for data source selection
    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )


# Initialize the LLM and bind the RouteQuery schema to get structured output
llm = ChatOpenAI(model=MODEL_NAME, temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)

# System prompt describing when to use the vectorstore and when to use web search
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to DEC 2023 AI Brief Report(SPRI) with Samsung Gauss, Anthropic, etc.
Use the vectorstore for questions on these topics. Otherwise, use web-search."""

# Create a prompt template for routing
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

# Create a question router by combining the prompt template and structured LLM router
question_router = route_prompt | structured_llm_router

Next, we test the router with two sample questions and check where each one is routed.

In [None]:
# Questions requiring document search
print(
    question_router.invoke(
        {"question": "What is the name of the generative AI created by Samsung Electronics in AI Brief?"}
    )
)

In [None]:
# Questions that require web search
print(question_router.invoke({"question": "Find the best dim sum restaurant in Pangyo"}))

### Retrieval Grader

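The retrieval grader checks whether a retrieved document is relevant to the user's question and returns a binary score, `yes` or `no`.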

In [8]:
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate


# Define data model for document evaluation
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


# Generate structured output through LLM initialization and function calls
llm = ChatOpenAI(model=MODEL_NAME, temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Create prompt templates including system messages and user questions
system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""

grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

# Create a document search result evaluator
retrieval_grader = grade_prompt | structured_llm_grader

Evaluate the **document search result** using the `retrieval_grader` you created.

In [9]:
# User question settings
question = "What is the name of the generative AI created by Samsung Electronics?"

# Search related documents for your question
docs = pdf_retriever.invoke(question)

In [None]:
# Get the contents of the searched document
retrieved_doc = docs[1].page_content

# Print evaluation results
print(retrieval_grader.invoke({"question": question, "document": retrieved_doc}))

In [11]:
# Filtering code example
filtered_docs = []

for doc in docs:
    # Check document evaluation results
    result = retrieval_grader.invoke(
        {
            "question": question,
            "document": doc.page_content,
        }
    )
    # Filter only relevant documents
    if result.binary_score == "yes":
        filtered_docs.append(doc)

### Create a RAG chain to generate answers

In [12]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

# Import prompts from LangChain Hub (RAG prompts can be freely modified)
prompt = hub.pull("teddynote/rag-prompt")

# Initialize LLM
llm = ChatOpenAI(model_name=MODEL_NAME, temperature=0)


# Document formatting function
def format_docs(docs):
    return "\n\n".join(
        [
            f'<document><content>{doc.page_content}</content><source>{doc.metadata["source"]}</source><page>{doc.metadata["page"]+1}</page></document>'
            for doc in docs
        ]
    )


# Create RAG chain
rag_chain = prompt | llm | StrOutputParser()

Now we generate the answer by passing the question to the `rag_chain` we created.

In [None]:
# Pass questions to the RAG chain to generate answers
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
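
`format_docs` above indexes `doc.metadata["page"]` directly, which assumes every document carries a page number. The web search node later in this tutorial builds documents with only a `source` entry, so a more tolerant variant may be useful if you decide to format those documents too. The following is a sketch under that assumption: `SimpleDoc` is a hypothetical stand-in for LangChain's `Document`, and `format_docs_tolerant` is not part of the original code.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class SimpleDoc:
    """Hypothetical stand-in exposing the two attributes that format_docs reads."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def format_docs_tolerant(docs: List[SimpleDoc]) -> str:
    """Format documents as XML-like blocks, skipping the page tag when it is missing."""
    parts = []
    for doc in docs:
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page")
        # Mirror the +1 page offset applied by the tutorial's format_docs.
        page_tag = f"<page>{page + 1}</page>" if page is not None else ""
        parts.append(
            f"<document><content>{doc.page_content}</content>"
            f"<source>{source}</source>{page_tag}</document>"
        )
    return "\n\n".join(parts)
```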

### Hallucination checker for answers

In [14]:
# Define data model for hallucination check
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


# LLM initialization through function call
llm = ChatOpenAI(model=MODEL_NAME, temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

# Prompt settings
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n 
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

# Create prompt template
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)

# Create a hallucination evaluator
hallucination_grader = hallucination_prompt | structured_llm_grader

Use the `hallucination_grader` created above to check whether the generated answer is grounded in the retrieved documents.

In [None]:
# Check whether the generated answer is grounded in the retrieved documents
hallucination_grader.invoke({"documents": docs, "generation": generation})

In [16]:
class GradeAnswer(BaseModel):
    """Binary scoring to evaluate the appropriateness of answers to questions"""

    binary_score: str = Field(
        description="Indicate 'yes' or 'no' whether the answer solves the question"
    )


# LLM initialization through function call
llm = ChatOpenAI(model=MODEL_NAME, temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)

# Prompt settings
system = """You are a grader assessing whether an answer addresses / resolves a question \n 
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

# Create an answer evaluator by combining a prompt template and a structured LLM evaluator
answer_grader = answer_prompt | structured_llm_grader

In [None]:
# Use the evaluator to evaluate whether the generated answer solves the question
answer_grader.invoke({"question": question, "generation": generation})

### Query Rewriter

In [18]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Initialize LLM
llm = ChatOpenAI(model=MODEL_NAME, temperature=0)

# Definition of Query Rewriter prompt (can be freely modified)
system = """You are a question re-writer that converts an input question to a better version that is optimized \n 
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""

# Create a Query Rewriter prompt template
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

# Create Query Rewriter
question_rewriter = re_write_prompt | llm | StrOutputParser()

Create an improved question by passing the question to the created `question_rewriter`.

In [None]:
# Generate improved questions by passing the questions to the question rewriter
question_rewriter.invoke({"question": question})

## Tools

### Web search tools

The **Web Search Tool** is an important component of **Adaptive RAG** and is used to retrieve up-to-date information, which helps answer questions about current events.

- **Settings**: Set up the web search tool so that it is ready to search for the latest information.
- **Perform Search**: Search the web for information relevant to the query.
- **Result Analysis**: Analyze the search results to find the information most appropriate to the user's question.

In [20]:
from langchain_teddynote.tools.tavily import TavilySearch

# Create a web search tool
web_search_tool = TavilySearch(max_results=3)

Run the web search tool and check the results.

In [None]:
# Call web search tool
result = web_search_tool.search("Please tell me the Teddy Note Wikidocs LangChain tutorial URL")
print(result)

In [None]:
# Check the first result of web search results
result[0]

## Graph Construction

### Defining graph states

In [23]:
from typing import List
from typing_extensions import TypedDict, Annotated


# Define the state of the graph
class GraphState(TypedDict):
    """
    A data model representing the state of the graph

    Attributes:
        question: question
        generation: LLM generated answers
        documents: document list
    """

    question: Annotated[str, "User question"]
    generation: Annotated[str, "LLM generated answer"]
    documents: Annotated[List[str], "List of documents"]
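
Each node in the graph returns a dictionary containing only the keys it wants to change, and those values are written into the shared state. The sketch below imitates that behavior in plain Python for the default case where no reducer is attached to a key. `apply_update` is a hypothetical helper and the values are placeholders; LangGraph's actual machinery is more involved.

```python
from typing import List, TypedDict


class SketchState(TypedDict, total=False):
    """Simplified copy of the graph state, for illustration only."""

    question: str
    generation: str
    documents: List[str]


def apply_update(state: SketchState, update: dict) -> SketchState:
    """Hypothetical helper: overwrite only the keys present in the update."""
    merged = dict(state)
    merged.update(update)
    return merged  # type: ignore[return-value]


state: SketchState = {"question": "placeholder question"}
state = apply_update(state, {"documents": ["doc-1", "doc-2"]})  # e.g. a retrieval step
state = apply_update(state, {"documents": ["doc-1"]})  # e.g. a grading step that filters
state = apply_update(state, {"generation": "placeholder answer"})  # e.g. a generation step
```

Note that the grading step replaces the whole `documents` list rather than appending to it, which matches how the filtering node in this tutorial returns its result.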

## Define Graph Flows

Defining the **graph flow** makes explicit how **Adaptive RAG** works. This step establishes the states and transitions of the graph.

- **State Definition**: Track the progress of a query by clearly defining each state in the graph.
- **Set Transitions**: Set transitions between states so that each query follows the appropriate path.
- **Flow Optimization**: Arrange the flow of the graph to improve the accuracy of information retrieval and answer generation.

### Define Nodes

Define the nodes and the routing functions used by the graph.

- `retrieve`: document retrieval node
- `generate`: answer generation node
- `grade_documents`: document relevance evaluation node
- `transform_query`: question rewrite node
- `web_search`: web search node
- `route_question`: question routing function (used as a conditional edge)
- `decide_to_generate`: decides whether to generate an answer or rewrite the question (used as a conditional edge)
- `hallucination_check`: hallucination and answer relevance check (used as a conditional edge)
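
Some of the functions listed above return a route label rather than a state update, and the graph then maps that label to the next node. A plain-Python sketch of the three-way decision made after generation follows. The two boolean inputs stand in for the LLM graders, and `check_generation` and `NEXT_NODE` are hypothetical names that mirror the labels used in this tutorial.

```python
from typing import Dict


def check_generation(grounded: bool, addresses_question: bool) -> str:
    """Stand-in for the post-generation check; booleans replace the LLM graders."""
    if not grounded:
        return "hallucination"
    return "relevant" if addresses_question else "not relevant"


# Label -> next node ("END" is a placeholder string in this sketch).
NEXT_NODE: Dict[str, str] = {
    "hallucination": "generate",  # regenerate the answer
    "relevant": "END",  # finish
    "not relevant": "transform_query",  # rewrite the question and try again
}
```

Groundedness is checked first, so an answer that is not supported by the documents is regenerated regardless of whether it appears to address the question.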

In [24]:
from langchain_core.documents import Document


# Document search node
def retrieve(state):
    print("==== [RETRIEVE] ====")
    question = state["question"]

    # Perform document search
    documents = pdf_retriever.invoke(question)
    return {"documents": documents}


# Answer generation node
def generate(state):
    print("==== [GENERATE] ====")
    # Get questions and document search results
    question = state["question"]
    documents = state["documents"]

    # Generate RAG answer
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"generation": generation}


# Document relevance evaluation node
def grade_documents(state):
    print("==== [CHECK DOCUMENT RELEVANCE TO QUESTION] ====")
    # Get questions and document search results
    question = state["question"]
    documents = state["documents"]

    # Calculate relevance score for each document
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            # Add relevant documents
            filtered_docs.append(d)
        else:
            # Skip irrelevant documents
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs}


# Question rewrite node
def transform_query(state):
    print("==== [TRANSFORM QUERY] ====")
    # Get the question
    question = state["question"]

    # Rewrite the question
    better_question = question_rewriter.invoke({"question": question})
    return {"question": better_question}


# Web search node
def web_search(state):
    print("==== [WEB SEARCH] ====")
    # Get the question
    question = state["question"]

    # Perform a web search
    web_results = web_search_tool.invoke({"query": question})
    web_results_docs = [
        Document(
            page_content=web_result["content"],
            metadata={"source": web_result["url"]},
        )
        for web_result in web_results
    ]

    return {"documents": web_results_docs}


# Question routing node
def route_question(state):
    print("==== [ROUTE QUESTION] ====")
    # Get questions
    question = state["question"]
    # Question routing
    source = question_router.invoke({"question": question})
    # Node routing based on question routing results
    if source.datasource == "web_search":
        print("==== [ROUTE QUESTION TO WEB SEARCH] ====")
        return "web_search"
    elif source.datasource == "vectorstore":
        print("==== [ROUTE QUESTION TO VECTORSTORE] ====")
        return "vectorstore"


# Decide whether to generate an answer or rewrite the question
def decide_to_generate(state):
    print("==== [DECISION TO GENERATE] ====")
    # Get document search results
    filtered_documents = state["documents"]

    if not filtered_documents:
        # Rewrite question if all documents are irrelevant
        print(
            "==== [DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY] ===="
        )
        return "transform_query"
    else:
        # Generate answer if relevant document exists
        print("==== [DECISION: GENERATE] ====")
        return "generate"


# Hallucination and answer relevance check
def hallucination_check(state):
    print("==== [CHECK HALLUCINATIONS] ====")
    # Get questions and document search results
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    # Hallucination Assessment
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    # Check for hallucination
    if grade == "yes":
        print("==== [DECISION: GENERATION IS GROUNDED IN DOCUMENTS] ====")

        # Evaluate the relevance of the answer
        print("==== [GRADE GENERATED ANSWER vs QUESTION] ====")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score

        # Processing according to relevance evaluation results
        if grade == "yes":
            print("==== [DECISION: GENERATED ANSWER ADDRESSES QUESTION] ====")
            return "relevant"
        else:
            print("==== [DECISION: GENERATED ANSWER DOES NOT ADDRESS QUESTION] ====")
            return "not relevant"
    else:
        print("==== [DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY] ====")
        return "hallucination"

## Graph Construction

The **graph compilation** step builds the **Adaptive RAG** workflow and makes it executable. It connects the nodes and edges of the graph to define the overall flow of query processing.

- **Node Definition**: Register each node with the graph.
- **Set Edges**: Set edges between nodes so that queries proceed along the appropriate path.
- **Build Workflow**: Compile the complete graph so that it can be executed.

In [26]:
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver

# Initialize graph state
workflow = StateGraph(GraphState)

# Node definition
workflow.add_node("web_search", web_search) # Web search
workflow.add_node("retrieve", retrieve) # Retrieve document
workflow.add_node("grade_documents", grade_documents) # Evaluate documents
workflow.add_node("generate", generate) # Generate answer
workflow.add_node("transform_query", transform_query) # Transform query

# Build graph
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search", # Route to web search
        "vectorstore": "retrieve", # Routing to vectorstore
    },
)
workflow.add_edge("web_search", "generate") # Generate answer after web search
workflow.add_edge("retrieve", "grade_documents") # Evaluate documents after retrieval
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query", # Query transformation required
        "generate": "generate", # Can generate answers
    },
)
workflow.add_edge("transform_query", "retrieve") # Retrieve documents after transforming query
workflow.add_conditional_edges(
    "generate",
    hallucination_check,
    {
        "hallucination": "generate", # Regenerate when a hallucination is detected
        "relevant": END, # Finish when the answer addresses the question
        "not relevant": "transform_query", # Rewrite the query when the answer does not address the question
    },
)

# Graph compilation
app = workflow.compile(checkpointer=MemorySaver())

Visualize the graph.

In [None]:
from langchain_teddynote.graphs import visualize_graph

visualize_graph(app)

## Execute Graph

In this step, we run the **Adaptive RAG** graph and check the results. Each query is processed along the nodes and edges of the graph to produce the final answer.

- **Graph Execution**: Execute the compiled graph and follow the flow of the query.
- **Check Results**: After running the graph, review the output to confirm that the query was processed properly.
- **Result Analysis**: Analyze the output to evaluate whether it meets the purpose of the query.
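
Because the graph contains cycles (for example, rewriting the query and retrieving again), a run could in principle loop for a long time, which is why the configuration in the next cell sets `recursion_limit`. The sketch below imitates that safeguard with a plain loop. `run_with_limit`, `StepLimitExceeded`, and the toy transition function are hypothetical and only illustrate the idea; they are not LangGraph APIs.

```python
from typing import Callable, Dict


class StepLimitExceeded(RuntimeError):
    """Hypothetical error raised when a run takes too many steps."""


def run_with_limit(start: str, next_node: Callable[[str, int], str], limit: int) -> int:
    """Follow transitions from `start` until "END"; return the number of steps taken."""
    node, steps = start, 0
    while node != "END":
        if steps >= limit:
            raise StepLimitExceeded(f"END not reached within {limit} steps")
        node = next_node(node, steps)
        steps += 1
    return steps


# Toy graph: documents are judged irrelevant until the query has been rewritten twice.
def toy_transitions(node: str, step: int) -> str:
    table: Dict[str, str] = {
        "retrieve": "grade_documents",
        "transform_query": "retrieve",
        "generate": "END",
    }
    if node == "grade_documents":
        return "generate" if step >= 7 else "transform_query"
    return table[node]
```

With a generous limit the toy run finishes normally, while a small limit stops it partway through the rewrite loop, which is the failure mode the limit is meant to surface.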

In [None]:
from langchain_teddynote.messages import stream_graph, random_uuid
from langchain_core.runnables import RunnableConfig

# config settings (maximum number of recursions, thread_id)
config = RunnableConfig(recursion_limit=20, configurable={"thread_id": random_uuid()})

# Enter question
inputs = {
    "question": "What is the name of the generative AI developed by Samsung Electronics?",
}

# Run graph
stream_graph(app, inputs, config, ["agent", "rewrite", "generate"])

In [None]:
# Enter question
inputs = {
    "question": "Who won the 2024 Nobel Prize in Literature?",
}

# Run graph
stream_graph(app, inputs, config, ["agent", "rewrite", "generate"])