# 3.2 Corrective RAG


## Setup

### Install dependencies

In [None]:
%pip install python-dotenv~=1.0 docarray~=0.40.0 pypdf~=5.1 --upgrade --quiet
%pip install chromadb~=0.5.18 sentence-transformers~=3.3 --upgrade --quiet 
%pip install langchain~=0.3.7 langchain_openai~=0.2.6 langchain_community~=0.3.5 langchain-chroma~=0.1.4 langchainhub~=0.1.21 --upgrade --quiet

# If running locally, you can do this instead:
#%pip install -r ../requirements.txt

### Load environment variables

In [None]:
import os
from dotenv import load_dotenv, find_dotenv
# Locate the .env file (python-dotenv searches upward through parent directories)
# and export its variables, e.g. AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT.
# The result is bound to `_` so the cell does not echo the returned boolean.
dotenv_path = find_dotenv()
_ = load_dotenv(dotenv_path)

# If running in Google Colab, you can use this code instead:
# from google.colab import userdata
# os.environ["AZURE_OPENAI_API_KEY"] = userdata.get("AZURE_OPENAI_API_KEY")
# os.environ["AZURE_OPENAI_ENDPOINT"] = userdata.get("AZURE_OPENAI_ENDPOINT")

### Setup models

In [None]:
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
# Azure OpenAI API version shared by the chat and embedding clients.
api_version = "2024-10-01-preview"

# Chat model used for grading, generation and the simulated web search.
# Temperature 0 keeps answers stable; the web-search chain overrides it via .bind().
llm = AzureChatOpenAI(
    deployment_name="gpt-4o",
    api_version=api_version,
    temperature=0.0,
)

# Embedding model used to index and query the vector store.
embedding_model = AzureOpenAIEmbeddings(
    api_version=api_version,
    model="text-embedding-3-large",
)

### Setup LangSmith tracing for this notebook

In [None]:
import os

# API key etc is in the .env file
# my_name = "Totoro"
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_PROJECT"] = f"tokyo24-test-{my_name}"

### Setup path to data 

In [None]:
# Relative path to the shared data directory.
# NOTE(review): not referenced by any visible cell of this notebook — presumably kept
# for parity with sibling notebooks; confirm before removing.
data_path = "../data"

# CRAG

Corrective RAG (CRAG), introduced in a recent paper, is an approach to active RAG that adds self-reflection / self-grading of the retrieved documents. 

The general ideas in the [paper](https://arxiv.org/pdf/2401.15884.pdf) are:

* Retrieve documents
* Evaluate them for relevance to the user question
* Perform knowledge refinement
* Use web search as a fallback when the retrieved documents are judged irrelevant (and as a supplement when their relevance is ambiguous)
* The diagrams in the paper also suggest that query re-writing is used here

![CRAG.png](https://www.dropbox.com/scl/fi/fnpz0n5unnr3iadw1ob2n/crag.png?rlkey=yz9ft5aq87aibnace7f3g3xav&dl=1)

**Paper:**
https://arxiv.org/pdf/2401.15884.pdf







---

Let's implement this from scratch using LangChain.

We can make some simplifications:

* Let's skip the knowledge refinement phase as a first pass. This can be added back as a node, if desired. 
* If *any* document is irrelevant, let's opt to supplement retrieval with web search. 
* We'll simulate web search with an LLM query (querying an LLM can itself be a viable fallback). 
* Let's use query re-writing to optimize the query for web search.


![CRAG flow](https://www.dropbox.com/scl/fi/m6x5vsxvb5p67i89stccq/crag-flow.png?rlkey=bkjz3gomb4cn2sx9iyrr3s0e6&dl=1)

## Retriever
 
Let's index 3 blog posts.

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# Use the dedicated langchain-chroma package (installed above); the
# langchain_community Chroma wrapper is deprecated.
from langchain_chroma import Chroma
from langchain_core.documents import Document

# Blog posts to index.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Fetch each page (one list of Documents per URL), then flatten into a single list.
# Named `loaded_docs` so it is not confused with the retrieval results stored in `docs` later.
loaded_docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in loaded_docs for item in sublist]

# Split into ~250-token chunks (measured with tiktoken), no overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=embedding_model,
)
retriever = vectorstore.as_retriever()

## Retrieval Grader

In [None]:
from pydantic import BaseModel, Field

# Data model: structured-output schema for the relevance grader
# (used as `llm.with_structured_output(GradeDocuments)` further down).
# NOTE(review): the class docstring and the Field description are presumably forwarded
# to the LLM as part of the structured-output schema, so editing them changes the
# grader's effective prompt — keep them stable.
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # Expected to hold 'yes' or 'no'. The type is a plain str, so the value is not
    # enforced — consumers should normalize case/whitespace before comparing.
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt: a deliberately lenient yes/no relevance judgement for a single document.
# Clean multi-line text (no literal "\n" escapes mixed with real newlines/indentation)
# and no duplicated "score" in the last instruction.
system = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

In [None]:
# Compose the grader: prompt -> LLM constrained to the GradeDocuments schema.
structured_llm_grader = llm.with_structured_output(GradeDocuments)
retrieval_grader = grade_prompt | structured_llm_grader

# Sanity check: retrieve chunks for a sample question and grade the second one.
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
grader_input = {"question": question, "document": doc_txt}
print(retrieval_grader.invoke(grader_input))

### Define a function for grading documents

In [None]:
from typing import Dict, Any

def grade_documents(input, grader=None) -> Dict[str, Any]:
    """Grade retrieved documents for relevance and decide whether web search is needed.

    Args:
        input: Mapping with ``"question"`` (the user question) and ``"documents"``
            (retrieved documents exposing ``.page_content``).
        grader: Runnable that scores one ``{"question", "document"}`` input and returns
            an object with a ``binary_score`` attribute. Defaults to the notebook's
            ``retrieval_grader``.

    Returns:
        Dict with the relevant ``documents``, the unchanged ``question`` and a
        ``web_search`` flag (``"yes"``/``"no"``). The flag is ``"yes"`` when any
        document was judged irrelevant, or when no relevant document remains
        (including an empty retrieval result).
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    if grader is None:
        grader = retrieval_grader
    question = input["question"]
    documents = input["documents"]

    # Score all documents in one batched call so the LLM requests can run in parallel.
    grader_inputs = [{"question": question, "document": d.page_content} for d in documents]
    scores = grader.batch(grader_inputs) if grader_inputs else []

    filtered_docs = []
    any_irrelevant = False
    for d, score in zip(documents, scores):
        # The LLM may vary case or whitespace ("Yes", " yes"), so normalize before comparing.
        if score.binary_score.strip().lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            # Drop the document and remember that web search should supplement retrieval.
            print("---GRADE: DOCUMENT NOT RELEVANT - WEB SEARCH NEEDED---")
            any_irrelevant = True

    # With nothing relevant left, answering from retrieval alone would use an empty context.
    web_search = "yes" if any_irrelevant or not filtered_docs else "no"
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

## Generate

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt: community RAG prompt from LangChain Hub; it is filled with
# `context` and `question` when the chain is invoked.
RAG_PROMPT_ID = "rlm/rag-prompt"
prompt = hub.pull(RAG_PROMPT_ID)

# Post-processing
def format_docs(docs):
    """Join the ``page_content`` of each document, separated by a blank line."""
    texts = [document.page_content for document in docs]
    return "\n\n".join(texts)

# Chain: fill the RAG prompt, call the LLM and return the answer as plain text.
rag_chain = prompt | llm | StrOutputParser()

# Run: the prompt expects `context` as text, so join the retrieved chunks.
# Passing the Document list directly would embed its Python repr (metadata included).
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)

## Build the final chains

In [None]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch

# Step 1: retrieve and grade. The dict literal is shorthand for a RunnableParallel:
# the input question goes to the retriever and is also passed through unchanged.
retriever_and_grading_chain = (
    {
        "documents": retriever,
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(grade_documents)
)

# Simulated "web search": an LLM role-play stands in for a real search tool
# (higher temperature for more varied answers).
web_search_chain_partial = (
    ChatPromptTemplate.from_messages(
        [
            ("system", "You are an Icelandic Viking from the 11th century, with no knowledge about modern technology. Confidently pretend that you know the answer to the user's question. Always end with a short tale about an heroic deed, seemingly related to the question."),
            ("human", "{question}"),
        ]
    )
    | llm.bind(temperature=0.9)
    | StrOutputParser()
)


def _combine_with_web_results(x):
    """Join the relevant graded documents and the web-search answer into one context string."""
    parts = [format_docs(x["documents"]), x["web_results"]]
    return "\n\n".join(part for part in parts if part)


# Supplement (rather than replace) the relevant documents with the web-search result,
# matching the "supplement retrieval with web search" simplification described above.
# Query re-writing is not implemented yet.
web_search_chain = (
    RunnablePassthrough.assign(web_results=web_search_chain_partial)
    | RunnablePassthrough.assign(context=_combine_with_web_results)
)

# Step 2: route on the grading result — web search when flagged, otherwise answer
# from the graded documents alone.
branch_chain = RunnableBranch(
    (lambda x: x["web_search"] == "yes", web_search_chain | rag_chain),
    RunnableLambda(lambda x: {"context": format_docs(x["documents"]), "question": x["question"]}) | rag_chain,
)

final_chain = retriever_and_grading_chain | branch_chain

final_chain.invoke("agent memory")