# Corrective RAG Demo

This demo shows how you can use LlamaCloud and [Tavily AI](https://tavily.com/) to build a [Corrective RAG](https://arxiv.org/abs/2401.15884) workflow. The workflow uses the documents indexed on LlamaCloud as its primary source, but falls back to a web search with Tavily AI if the information requested by the query cannot be found on LlamaCloud.


A brief summary of the paper:  
Corrective Retrieval Augmented Generation (CRAG) is a method designed to make language model generation more robust. It uses an evaluator to assess the relevance of retrieved documents and augments them with large-scale web searches, so that more accurate and reliable information is used in generation.

## Setup

Follow [these instructions](https://docs.cloud.llamaindex.ai/llamacloud/getting_started/quick_start) to set up your index. For this example, we will upload a paper about Llama 2 to LlamaCloud. At the "configure data source" step, download [this PDF paper](https://arxiv.org/pdf/2307.09288) and upload it into your index.

After deploying your index, follow [these instructions](https://docs.cloud.llamaindex.ai/llamacloud/getting_started/api_key) to get an API key. Once you are done, configure `nest_asyncio` and your environment variables.

In [None]:
%pip install llama-index llama-index-indices-managed-llama-cloud llama-index-tools-tavily-research

In [1]:
import nest_asyncio
nest_asyncio.apply()

In [2]:
import os

os.environ["OPENAI_API_KEY"] = "<Your OpenAI API Key>"
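
Hard-coding keys in a notebook cell makes them easy to leak. As a minimal sketch, the keys can instead be read from the environment, failing fast when one is missing. The helper name `require_env` and the variable `DEMO_API_KEY` are hypothetical and are not part of LlamaIndex or LlamaCloud:

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or raise if it is unset or empty."""
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# A placeholder is set here only so that the sketch runs on its own;
# in practice the variable would be exported in the shell before starting Jupyter.
os.environ.setdefault("DEMO_API_KEY", "placeholder-value")
demo_key = require_env("DEMO_API_KEY")
```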

## Designing the Workflow

Corrective RAG consists of the following steps:
1. Ingestion of data - Loads the index and sets up Tavily AI. The ingestion step runs on its own, taking in a start event and returning a stop event.
2. Retrieval - Retrieves the most relevant nodes based on the query.
3. Relevance evaluation - Uses an LLM to determine whether each retrieved node is relevant to the query, given the content of the node.
4. Relevance extraction - Extracts the nodes that the LLM determined to be relevant.
5. Query transformation and Tavily search - If any node is judged irrelevant, uses an LLM to rewrite the query for a web search, then uses Tavily to search the web with the rewritten query.
6. Response generation - Builds a summary index from the text of the relevant nodes and the Tavily search results, then queries this index with the original query.
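
Steps 2 through 6 above can be sketched as ordinary control flow, independent of any library. This is only an illustration under stated assumptions: the function `corrective_rag` and its callable parameters are hypothetical stand-ins for the retriever, the LLM grader, the query rewriter, the web search tool, and the answer generator, and the lambdas in the usage example are toy stubs rather than real components.

```python
from typing import Callable, List


def corrective_rag(
    query: str,
    retrieve: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],
    rewrite_query: Callable[[str], str],
    web_search: Callable[[str], List[str]],
    generate: Callable[[str, str], str],
) -> str:
    """Run steps 2-6 for one query using caller-supplied components."""
    # Step 2: retrieve candidate texts from the index.
    texts = retrieve(query)
    # Step 3: grade every retrieved text against the query.
    verdicts = [is_relevant(query, text) for text in texts]
    # Step 4: keep only the texts graded as relevant.
    relevant = [text for text, ok in zip(texts, verdicts) if ok]
    # Step 5: if any text was graded irrelevant, rewrite the query and search the web.
    search_hits: List[str] = []
    if not all(verdicts):
        search_hits = web_search(rewrite_query(query))
    # Step 6: answer the original query from the combined context.
    context = "\n".join(relevant + search_hits)
    return generate(query, context)


# Toy usage: one relevant and one irrelevant text, so the web search branch runs.
answer = corrective_rag(
    "What is CRAG?",
    retrieve=lambda q: ["CRAG grades retrieved documents.", "Unrelated text."],
    is_relevant=lambda q, t: "CRAG" in t,
    rewrite_query=lambda q: q + " corrective retrieval augmented generation",
    web_search=lambda q: [f"web result for: {q}"],
    generate=lambda q, ctx: ctx,  # echo the context instead of calling an LLM
)
print(answer)
```

The workflow below follows the same shape, but passes data between steps as events and keeps shared state in the workflow context.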

The following events are needed:
1. `RetrieveEvent` - Event containing information about the retrieved nodes.
2. `RelevanceEvalEvent` - Event containing a list of the results of the relevance evaluation.
3. `TextExtractEvent` - Event containing the concatenated string of relevant text from relevant nodes.
4. `QueryEvent` - Event containing both the relevant text and search text.

In [3]:
from typing import List

from llama_index.core.schema import NodeWithScore
from llama_index.core.workflow import (
    Event,
)

class RetrieveEvent(Event):
    """Retrieve event (gets retrieved nodes)."""

    retrieved_nodes: List[NodeWithScore]


class RelevanceEvalEvent(Event):
    """Relevance evaluation event (gets results of relevance evaluation)."""

    relevant_results: List[str]


class TextExtractEvent(Event):
    """Text extract event. Extracts relevant text and concatenates."""

    relevant_text: str


class QueryEvent(Event):
    """Query event. Queries given relevant text and search text."""

    relevant_text: str
    search_text: str

Below is the code for the workflow.

In [4]:
from typing import Optional

from llama_index.core.workflow import (
    StartEvent,
    StopEvent,
    step,
    Workflow,
    Context,
)
from llama_index.core import SummaryIndex
from llama_index.core.schema import Document
from llama_index.core.prompts import PromptTemplate
from llama_index.core.llms import LLM
from llama_index.llms.openai import OpenAI
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
from llama_index.tools.tavily_research import TavilyToolSpec

DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)


class CorrectiveRAGWorkflow(Workflow):
    @step(pass_context=True)
    async def ingest(self, ctx: Context, ev: StartEvent) -> Optional[StopEvent]:
        """Ingest step (for ingesting docs and initializing index)."""
        tavily_ai_apikey: Optional[str] = ev.get("tavily_ai_apikey")
        index: Optional[LlamaCloudIndex] = ev.get("index")

        if any(i is None for i in [tavily_ai_apikey, index]):
            return None

        llm = OpenAI(model="gpt-4")

        ctx.data["llm"] = llm
        ctx.data["index"] = index
        ctx.data["tavily_tool"] = TavilyToolSpec(api_key=tavily_ai_apikey)

        return StopEvent()

    @step(pass_context=True)
    async def retrieve(self, ctx: Context, ev: StartEvent) -> Optional[RetrieveEvent]:
        """Retrieve the relevant nodes for the query."""
        query_str = ev.get("query_str")
        retriever_kwargs = ev.get("retriever_kwargs", {})

        if query_str is None:
            return None

        retriever: BaseRetriever = ctx.data["index"].as_retriever(**retriever_kwargs)
        result = retriever.retrieve(query_str)
        ctx.data["retrieved_nodes"] = result
        ctx.data["query_str"] = query_str
        return RetrieveEvent(retrieved_nodes=result)

    @step(pass_context=True)
    async def eval_relevance(
        self, ctx: Context, ev: RetrieveEvent
    ) -> RelevanceEvalEvent:
        """Evaluate relevancy of retrieved documents with the query."""
        retrieved_nodes = ev.retrieved_nodes
        query_str = ctx.data["query_str"]
        llm: LLM = ctx.data["llm"]

        relevancy_results = []
        for node in retrieved_nodes:
            prompt = DEFAULT_RELEVANCY_PROMPT_TEMPLATE.format(context_str=node.text, query_str=query_str)
            relevancy = llm.complete(prompt)
            relevancy_results.append(relevancy.text.lower().strip())

        ctx.data["relevancy_results"] = relevancy_results
        return RelevanceEvalEvent(relevant_results=relevancy_results)

    @step(pass_context=True)
    async def extract_relevant_texts(
        self, ctx: Context, ev: RelevanceEvalEvent
    ) -> TextExtractEvent:
        """Extract relevant texts from retrieved documents."""
        retrieved_nodes = ctx.data["retrieved_nodes"]
        relevancy_results = ev.relevant_results

        relevant_texts = [
            retrieved_nodes[i].text
            for i, result in enumerate(relevancy_results)
            if result == "yes"
        ]

        result = "\n".join(relevant_texts)
        return TextExtractEvent(relevant_text=result)

    @step(pass_context=True)
    async def transform_query_pipeline(
        self, ctx: Context, ev: TextExtractEvent
    ) -> QueryEvent:
        """Search the transformed query with Tavily API."""
        relevant_text = ev.relevant_text
        relevancy_results = ctx.data["relevancy_results"]
        query_str = ctx.data["query_str"]
        llm: LLM = ctx.data["llm"]

        # If any document is found irrelevant, transform the query string for better search results.
        if "no" in relevancy_results:
            prompt = DEFAULT_TRANSFORM_QUERY_TEMPLATE.format(query_str=query_str)
            result = llm.complete(prompt)
            transformed_query_str = result.text

            # Conduct a search with the transformed query string and collect the results.
            search_results = ctx.data["tavily_tool"].search(
                transformed_query_str, max_results=5
            )
            # Use a distinct loop variable so the LLM completion in `result` is not shadowed.
            search_text = "\n".join([doc.text for doc in search_results])
        else:
            search_text = ""

        return QueryEvent(relevant_text=relevant_text, search_text=search_text)

    @step(pass_context=True)
    async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent:
        """Get result with relevant text."""
        relevant_text = ev.relevant_text
        search_text = ev.search_text
        query_str = ctx.data["query_str"]

        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        result = query_engine.query(query_str)
        return StopEvent(result=result)
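
The `extract_relevant_texts` step keeps a node only when the grader's reply, after lowercasing and stripping, is exactly `yes`, and the web search is triggered only when a reply is exactly `no`. A reply such as `Yes.` would therefore match neither check. One possible way to make this more tolerant is to normalize the reply first. The helper `parse_binary_score` is hypothetical, and treating every unclear reply as `no` (which favors falling back to web search) is an assumption of this sketch rather than something the workflow above does:

```python
import re


def parse_binary_score(raw: str) -> str:
    """Map a free-form grader reply to 'yes' or 'no'.

    Only the leading word of the reply is inspected. Anything that does not
    start with a clear 'yes' is treated as 'no'.
    """
    match = re.match(r"\W*(yes|no)\b", raw.strip().lower())
    if match and match.group(1) == "yes":
        return "yes"
    return "no"


for reply in ["yes", "Yes.", "'yes'", "No, it is unrelated.", "maybe"]:
    print(repr(reply), "->", parse_binary_score(reply))
```

With a helper like this, `eval_relevance` could append `parse_binary_score(relevancy.text)` instead of the raw lowercased text, leaving the later `== "yes"` and `"no" in ...` checks unchanged.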


## Create LlamaCloudIndex

Create a `LlamaCloudIndex` that retrieves information from the index you set up on LlamaCloud.

In [5]:
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex

index = LlamaCloudIndex(
    name="<Your index name>",
    project_name="<Your project name>",
    api_key="llx-...",
    organization_id="<Your organization ID>",
)

See [here](https://docs.cloud.llamaindex.ai/organizations) for a tutorial on how to use organizations.

Run the ingestion step once so that the index, the LLM, and the Tavily tool are stored in the workflow context:

In [6]:
workflow = CorrectiveRAGWorkflow(verbose=True, timeout=60)
await workflow.run(index=index, tavily_ai_apikey="<Your Tavily AI API Key>")

Running step ingest
Step ingest produced event StopEvent
Running step retrieve
Step retrieve produced no event
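
The log above shows both `ingest` and `retrieve` running, because both steps accept a `StartEvent`. Each one inspects the arguments passed to `run` and returns `None` when its own inputs are absent. The following is a simplified plain-Python model of that behavior. It is not how the LlamaIndex workflow engine is implemented; `run_start_steps`, the string event names, and the dictionary `state` are hypothetical stand-ins:

```python
from typing import Any, Dict, List, Optional

state: Dict[str, Any] = {}  # stands in for the shared workflow context


def ingest(kwargs: Dict[str, Any]) -> Optional[str]:
    """Store the index and API key; skip when either is missing."""
    index = kwargs.get("index")
    api_key = kwargs.get("tavily_ai_apikey")
    if index is None or api_key is None:
        return None
    state["index"] = index
    state["tavily_ai_apikey"] = api_key
    return "StopEvent"


def retrieve(kwargs: Dict[str, Any]) -> Optional[str]:
    """Start the query path; skip when no query string was passed."""
    if kwargs.get("query_str") is None:
        return None
    state["query_str"] = kwargs["query_str"]
    return "RetrieveEvent"


def run_start_steps(**kwargs: Any) -> List[str]:
    """Offer the same start arguments to every step and log what each produced."""
    log = []
    for step_fn in (ingest, retrieve):
        event = step_fn(kwargs)
        log.append(f"{step_fn.__name__}: {event or 'no event'}")
    return log


# An ingestion-style call: only `ingest` produces an event.
print(run_start_steps(index="my-index", tavily_ai_apikey="tvly-placeholder"))
# A query-style call: only `retrieve` produces an event.
print(run_start_steps(query_str="How was Llama2 pretrained?"))
```

This is why the query runs below log `Step ingest produced no event`: the ingestion inputs are not passed again, and `retrieve` reads the index stored during the earlier run.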


## Example queries

In [7]:
from IPython.display import display, Markdown

result = await workflow.run(query_str="How was Llama2 pretrained?")  # this is covered in the uploaded paper
display(Markdown(str(result)))

Running step ingest
Step ingest produced no event
Running step retrieve
Step retrieve produced event RetrieveEvent
Running step eval_relevance
Step eval_relevance produced event RelevanceEvalEvent
Running step extract_relevant_texts
Step extract_relevant_texts produced event TextExtractEvent
Running step transform_query_pipeline
Step transform_query_pipeline produced event QueryEvent
Running step query_result
Step query_result produced event StopEvent


Llama 2 was pretrained on Meta's Research Super Cluster and production clusters using NVIDIA A100 GPUs. The pretraining process involved utilizing custom training libraries and third-party cloud compute for fine-tuning, annotation, and evaluation. The pretraining data consisted of 2 trillion tokens from publicly available sources, with a cutoff date of September 2022. The carbon footprint of the pretraining process amounted to 539 tCO2eq, which was fully offset by Meta's sustainability program.

In [10]:
result = await workflow.run(query_str="Where does the airline flight UA 1 fly?")  # this info is not in the paper
display(Markdown(str(result)))

Running step ingest
Step ingest produced no event
Running step retrieve
Step retrieve produced event RetrieveEvent
Running step eval_relevance
Step eval_relevance produced event RelevanceEvalEvent
Running step extract_relevant_texts
Step extract_relevant_texts produced event TextExtractEvent
Running step transform_query_pipeline
Step transform_query_pipeline produced event QueryEvent
Running step query_result
Step query_result produced event StopEvent


The airline flight UA 1 flies from San Francisco, United States to Singapore.