
# Adaptive RAG
Adaptive RAG is a strategy for RAG that unites (1) query analysis with (2) active / self-corrective RAG.

In the paper, the authors use query analysis to route across:

- No Retrieval
- Single-shot RAG
- Iterative RAG

Let's build on this using LangGraph.

In our implementation, we will route between:

- Web search: for questions related to recent events
- Self-corrective RAG: for questions related to our index



In [1]:
!pip install -U langchain_community tiktoken langchain-openai langchain-cohere langchainhub chromadb langchain langgraph  tavily-python




In [2]:
!pip install -U langchain_community tiktoken langchain-huggingface chromadb langchain langgraph tavily-python sentence_transformers




In [3]:
!pip install --upgrade langchain langchain-core langchain-community




In [None]:

import json
import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings # Import Hugging Face Embeddings

# Embedding model: a Hugging Face sentence-transformers checkpoint.
# Swap the model name for another embedding checkpoint if desired.
embd = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Pages that make up the local index.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# One loader call per URL; `docs` keeps the per-URL results and
# `docs_list` flattens them into a single sequence.
docs = [WebBaseLoader(page_url).load() for page_url in urls]
docs_list = [page_doc for page_docs in docs for page_doc in page_docs]

# Split into chunks of 500 (tiktoken-measured) with no overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_overlap=0,
    chunk_size=500,
)
doc_splits = text_splitter.split_documents(docs_list)

# Embed the chunks into a Chroma collection and expose it as a retriever.
vectorstore = Chroma.from_documents(
    embedding=embd,
    collection_name="rag-chroma",
    documents=doc_splits,
)
retriever = vectorstore.as_retriever()



In [None]:
import re
from typing import Literal
from pydantic import BaseModel, Field, ValidationError
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from google.colab import userdata

# Pydantic output schema
class RouteQuery(BaseModel):
    """Routing decision: which datasource should answer the question."""

    # Only these two literal values pass validation.
    datasource: Literal["vectorstore", "web_search"]

# Get token and configure LLM
# The token is read from Colab's secret store under the name "HF_TOKEN".
hf_token = userdata.get("HF_TOKEN")
model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Chat wrapper around a Hugging Face text-generation endpoint for `model_id`.
llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        huggingfacehub_api_token=hf_token,
        max_new_tokens=256,
        temperature=0.1
    )
)

# Escape JSON braces with double {{...}} to prevent LangChain parsing error
# (single braces, as in "{question}" below, mark template variables).
system_prompt = """You are a routing classifier.
Reply with ONLY one of these JSON objects:
{{"datasource": "vectorstore"}}
{{"datasource": "web_search"}}

Use 'vectorstore' for questions about agents, prompt engineering, or adversarial attacks.
Use 'web_search' for everything else.
"""

# System instructions followed by the user's question as the human turn.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", "{question}")
])

# Router chain: fill the prompt, then call the model. The reply is parsed
# separately by parse_response.
chain = prompt | llm

# Extract and validate JSON from model output
def parse_response(output_text: str) -> RouteQuery:
    """Parse the router model's raw reply into a ``RouteQuery``.

    The reply is decoded as JSON rather than passed to ``eval``: the text
    comes from a remote model, and ``eval`` would execute whatever Python
    expression it happened to contain.

    Args:
        output_text: Raw text returned by the router model.

    Returns:
        The validated routing decision. If no JSON object is found, the JSON
        is malformed, or it does not match the schema, the fallback
        ``RouteQuery(datasource="web_search")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found.")
        # raw_decode stops at the end of the first complete JSON value, so
        # trailing prose or a second object after it does not break parsing.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return RouteQuery.model_validate(parsed)
    except (ValidationError, ValueError):
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print("⚠️ Failed to parse response:", output_text)
        return RouteQuery(datasource="web_search")

# Wrapper to classify user query
def classify(question: str) -> RouteQuery:
    """Run the router chain on ``question`` and return the parsed decision.

    Args:
        question: The user question to route.

    Returns:
        The ``RouteQuery`` produced by ``parse_response`` from the model reply.
    """
    reply = chain.invoke({"question": question})
    if hasattr(reply, "content"):
        reply_text = reply.content.strip()
    else:
        # Not a message object: fall back to its string form.
        reply_text = str(reply)
    return parse_response(reply_text)

# Run test cases
print(classify("What are the types of agent memory?"))
print(classify("Who will the Bears draft first in the NFL draft?"))


In [None]:
import re
from typing import Literal
from pydantic import BaseModel, Field, ValidationError
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from google.colab import userdata

# Step 1: Define schema
class GradeDocuments(BaseModel):
    """Relevance verdict for one retrieved document."""

    # Only "yes" or "no" pass validation.
    binary_score: Literal["yes", "no"]

# Step 2: Initialize Hugging Face LLM
# The token is read from Colab's secret store under the name "HF_TOKEN".
hf_token = userdata.get("HF_TOKEN")
model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Rebinds the notebook-level `llm` name.
llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        max_new_tokens=256,
        temperature=0.1,
        huggingfacehub_api_token=hf_token,
    )
)

# Step 3: Create prompt (escape braces)
# Doubled braces are literal braces in the template; single-brace names
# in the human message below are template variables.
system_msg = """You are a grader assessing the relevance of a retrieved document to a user question.

If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.

It does not need to be a stringent test. The goal is to filter out erroneous retrievals.

Reply ONLY with one of these JSON objects:
{{"binary_score": "yes"}}
{{"binary_score": "no"}}
"""

# Template variables: {document} and {question}.
grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg),
    ("human", "Retrieved document:\n\n{document}\n\nUser question: {question}")
])

# Grading chain: fill the prompt, then call the model. The reply is parsed
# separately by parse_grade_output.
grading_chain = grade_prompt | llm

# Step 4: Parse the JSON safely
def parse_grade_output(output_text: str) -> GradeDocuments:
    """Parse the relevance grader's raw reply into a ``GradeDocuments``.

    The reply is decoded as JSON rather than passed to ``eval``: the text
    comes from a remote model, and ``eval`` would execute whatever Python
    expression it happened to contain.

    Args:
        output_text: Raw text returned by the grading model.

    Returns:
        The validated grade. If no JSON object is found, the JSON is
        malformed, or it does not match the schema, the fallback
        ``GradeDocuments(binary_score="no")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found.")
        # raw_decode stops at the end of the first complete JSON value, so
        # trailing prose or a second object after it does not break parsing.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return GradeDocuments.model_validate(parsed)
    except (ValidationError, ValueError):
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print("⚠️ Failed to parse:", output_text)
        return GradeDocuments(binary_score="no")

# Step 5: Wrap end-to-end grader
def evaluate_document_relevance(question: str, document: str) -> GradeDocuments:
    """Grade whether ``document`` is relevant to ``question``.

    Args:
        question: The user question.
        document: Text of the retrieved document.

    Returns:
        The ``GradeDocuments`` produced by ``parse_grade_output`` from the
        model reply.
    """
    reply = grading_chain.invoke({"question": question, "document": document})
    if hasattr(reply, "content"):
        reply_text = reply.content.strip()
    else:
        # Not a message object: fall back to its string form.
        reply_text = str(reply)
    return parse_grade_output(reply_text)

# Step 6: Example run
# Grades the second retrieved document (index 1) against the question.
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(evaluate_document_relevance(question, doc_txt))


# Generate Based on Chosen Route: RAG vs. Web Search

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from google.colab import userdata

# Token: read from Colab's secret store under the name "HF_TOKEN".
hf_token = userdata.get("HF_TOKEN")

# Hugging Face chat model setup.
# Rebinds the notebook-level `llm`; this instance allows up to 512 new tokens.
llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # or another chat-compatible model
        task="text-generation",
        huggingfacehub_api_token=hf_token,
        max_new_tokens=512,
        temperature=0.1,
    )
)

# Prompt pulled from the LangChain hub (rag-style).
# It is invoked below with the keys "context" and "question".
prompt = hub.pull("rlm/rag-prompt")

# Output parser
parser = StrOutputParser()

# Combine into a chain: prompt -> model -> output parser.
rag_chain = prompt | llm | parser

# Helper to format retrieved docs
def format_docs(docs):
    """Join the ``page_content`` of every item in ``docs`` with blank lines.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        One string with the contents separated by two newlines; an empty
        string when ``docs`` is empty.
    """
    page_texts = [doc.page_content for doc in docs]
    return "\n\n".join(page_texts)


# Test invocation: retrieve documents for a sample question and join
# their text into one context string.
question = "What types of memory do LLM agents use?"
docs = retriever.invoke(question)
formatted = format_docs(docs)

# Run the RAG chain on the context and question, then print its output.
# `question`, `docs` and `generation` stay bound as notebook globals.
generation = rag_chain.invoke({"context": formatted, "question": question})
print(generation)


# Hallucination Grader

In [None]:
import re
from typing import Literal
from pydantic import BaseModel, Field, ValidationError
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from google.colab import userdata

# Schema for hallucination grader
class GradeHallucinations(BaseModel):
    """Verdict on whether a generation is grounded in the supplied facts."""

    # Only "yes" or "no" pass validation.
    binary_score: Literal["yes", "no"] = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

# Load HF token and model.
# The token is read from Colab's secret store under the name "HF_TOKEN".
hf_token = userdata.get("HF_TOKEN")
model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Rebinds the notebook-level `llm` name.
llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        huggingfacehub_api_token=hf_token,
        max_new_tokens=256,
        temperature=0.1,
    )
)

# Prompt (with escaped braces for LangChain templating): doubled braces
# are literal, single-brace names below are template variables.
system_msg = """
You are a grader assessing whether an LLM generation is grounded in a set of retrieved facts.
Respond with ONLY one of the following JSON objects:
{{"binary_score": "yes"}} or {{"binary_score": "no"}}

'yes' means the answer is grounded in the facts. 'no' means it contains unsupported information.
"""

# Template variables: {documents} and {generation}.
hallucination_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg),
    ("human", "Set of facts:\n\n{documents}\n\nLLM generation:\n\n{generation}")
])

# Chain read by evaluate_hallucination at call time.
grader_chain = hallucination_prompt | llm

# Parse model response as JSON + Pydantic
def parse_hallucination_output(output_text: str) -> GradeHallucinations:
    """Parse the hallucination grader's raw reply into a ``GradeHallucinations``.

    The reply is decoded as JSON rather than passed to ``eval``: the text
    comes from a remote model, and ``eval`` would execute whatever Python
    expression it happened to contain.

    Args:
        output_text: Raw text returned by the grading model.

    Returns:
        The validated grade. If no JSON object is found, the JSON is
        malformed, or it does not match the schema, the fallback
        ``GradeHallucinations(binary_score="no")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found.")
        # The prompt shows both answers on one line ('{..} or {..}'); decoding
        # only the first complete JSON value copes with a reply shaped like that.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return GradeHallucinations.model_validate(parsed)
    except (ValidationError, ValueError):
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print("⚠️ Failed to parse:", output_text)
        return GradeHallucinations(binary_score="no")

# Wrapper
def evaluate_hallucination(documents: str, generation: str) -> GradeHallucinations:
    """Grade whether ``generation`` is grounded in ``documents``.

    Args:
        documents: The retrieved facts as a single string.
        generation: The model answer to check.

    Returns:
        The ``GradeHallucinations`` produced by ``parse_hallucination_output``
        from the model reply.
    """
    reply = grader_chain.invoke({"documents": documents, "generation": generation})
    if hasattr(reply, "content"):
        reply_text = reply.content.strip()
    else:
        # Not a message object: fall back to its string form.
        reply_text = str(reply)
    return parse_hallucination_output(reply_text)

# Example usage: join the page contents of `docs` into one string first.
documents_str = "\n\n".join([doc.page_content for doc in docs])
result = evaluate_hallucination(documents_str, generation)
print(result)


# LLM Judge: Does the Answer Address the Question?

In [None]:
import re
import json
from typing import Literal
from pydantic import BaseModel, Field, ValidationError
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from google.colab import userdata

# Schema
class GradeAnswer(BaseModel):
    """Verdict on whether a generation addresses the user question."""

    # Only "yes" or "no" pass validation.
    binary_score: Literal["yes", "no"] = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )

# Prompt: doubled braces are literal braces in the template; single-brace
# names in the human message below are template variables.
system_msg = """
You are a grader assessing whether an answer addresses a user question.
Respond with ONLY one of these JSON objects:
{{"binary_score": "yes"}} or {{"binary_score": "no"}}

'yes' means the answer directly addresses the question. 'no' means it does not.
"""

# Template variables: {question} and {generation}.
answer_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg),
    ("human", "User question:\n\n{question}\n\nLLM generation:\n\n{generation}")
])

# Chain
# Bound to its own name. Reusing `grader_chain` here would overwrite the
# hallucination chain that evaluate_hallucination looks up when it is called,
# making that function invoke this prompt with the wrong variables.
answer_grader = answer_prompt | llm

def parse_answer_output(output_text: str) -> GradeAnswer:
    """Parse the answer grader's raw reply into a ``GradeAnswer``.

    Args:
        output_text: Raw text returned by the grading model.

    Returns:
        The validated grade. If no JSON object is found, the JSON is
        malformed, or it does not match the schema, the fallback
        ``GradeAnswer(binary_score="no")`` is returned.
    """
    try:
        # Decode the first JSON block only. The prompt shows both answers on
        # one line ('{..} or {..}'), and a first-brace-to-last-brace match
        # would not parse if the model echoed that line.
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON found")
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return GradeAnswer.model_validate(parsed)
    except (ValidationError, ValueError):
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print("⚠️ Failed to parse:", output_text)
        return GradeAnswer(binary_score="no")

# Wrapper
def evaluate_answer_relevance(question: str, generation: str) -> GradeAnswer:
    """Grade whether ``generation`` addresses ``question``.

    Args:
        question: The user question.
        generation: The model answer to check.

    Returns:
        The ``GradeAnswer`` produced by ``parse_answer_output`` from the
        model reply.
    """
    response = answer_grader.invoke({"question": question, "generation": generation})
    if hasattr(response, "content"):
        text = response.content.strip()
    else:
        text = str(response).strip()
    return parse_answer_output(text)


def format_docs(docs):
    """Join the ``page_content`` of every item in ``docs`` with blank lines.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        One string with the contents separated by two newlines; an empty
        string when ``docs`` is empty.
    """
    separator = "\n\n"
    return separator.join([doc.page_content for doc in docs])


# Retrieved docs
# `question` is not assigned in this cell; it is whatever an earlier cell
# left bound.
docs = retriever.invoke(question)

# Join the retrieved page contents into one context string.
formatted_context = format_docs(docs)

# Generate an answer, then grade whether it addresses the question.
generation = rag_chain.invoke({"context": formatted_context, "question": question})
result = evaluate_answer_relevance(question, generation)

print(result)


# Question Re-writer


In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from google.colab import userdata


# Prompt template for the question rewriter.
system_msg = """
You are a question rewriter that converts an input question into a more effective version
for vectorstore retrieval. Your goal is to optimize the query for semantic intent and relevant keywords.
"""

# Template variable: {question}.
re_write_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg),
    ("human", "Here is the initial question:\n\n{question}\n\nRewrite it for semantic search.")
])

# Chain: prompt -> LLM -> plain text parser.
# `llm` is not created in this cell; it is the instance an earlier cell bound.
question_rewriter = re_write_prompt | llm | StrOutputParser()

# Demo: rewrite the `question` left bound by an earlier cell.
improved = question_rewriter.invoke({"question": question})
print("🔍 Rewritten:", improved)


# Setting up Websearch Agent

In [None]:
# pip install -U langchain-tavily


In [None]:
### Search
from langchain_community.tools.tavily_search import TavilySearchResults

from langchain_community.tools.tavily_search import TavilySearchResults


import os
from google.colab import userdata

# Load the Tavily key from the Colab secret named "TAVILY_TOKEN" and export
# it under the environment variable name "TAVILY_API_KEY".
api_key = userdata.get("TAVILY_TOKEN")
if not api_key:
    # Name the secret that was actually looked up, so the message points at
    # the entry the user has to create.
    raise ValueError("Missing TAVILY_TOKEN in Colab secrets.")
os.environ["TAVILY_API_KEY"] = api_key

# `max_results` is the tool's result-count field; `k` is not one of its
# fields, so the intended limit of 3 was not being applied.
web_search_tool = TavilySearchResults(max_results=3)

In [None]:
# web_search_tool = TavilySearchResults(k=3)  # Works once key is set
# print(web_search_tool.invoke("Who is the CEO of OpenAI?"))


# Construct the Graph
#### Capture the flow as a graph.

Define Graph State

In [None]:
from typing import List

from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """

    question: str
    generation: str
    # The graph nodes read `.page_content` from these items and the web-search
    # node stores Document objects here, so the element type is Document (as
    # in the later GraphState definition), not str.
    documents: List[Document]

# Define Graph Flow
###  API Reference: [Document]

In [None]:

# Pydantic output schema for Routing
class RouteQuery(BaseModel):
    """Routing decision: which datasource should answer the question."""

    # Only these two literal values pass validation.
    datasource: Literal["vectorstore", "web_search"]


# Question Router
# Doubled braces are literal braces in the template; "{question}" below is a
# template variable.
system_prompt = """You are a routing classifier.
Reply with ONLY one of these JSON objects:
{{"datasource": "vectorstore"}}
{{"datasource": "web_search"}}

Use 'vectorstore' for questions about agents, prompt engineering, or adversarial attacks.
Use 'web_search' for everything else.
"""
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", "{question}")
])
# Router chain: fill the prompt, then call the model. `llm` is the instance
# bound by an earlier cell; the reply is parsed by parse_route_output.
question_router = prompt | llm

# Retrieval Grader (Relevance)
class GradeDocuments(BaseModel):
    """Relevance verdict for one retrieved document."""

    # Only "yes" or "no" pass validation.
    binary_score: Literal["yes", "no"]

# Doubled braces are literal braces in the template; single-brace names in
# the human message below are template variables.
system_msg_retrieval = """You are a grader assessing the relevance of a retrieved document to a user question.

If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.

It does not need to be a stringent test. The goal is to filter out erroneous retrievals.

Reply ONLY with one of these JSON objects:
{{"binary_score": "yes"}}
{{"binary_score": "no"}}
"""
# Template variables: {document} and {question}.
grade_prompt_retrieval = ChatPromptTemplate.from_messages([
    ("system", system_msg_retrieval),
    ("human", "Retrieved document:\n\n{document}\n\nUser question: {question}")
])
# Chain invoked once per document by grade_documents.
retrieval_grader = grade_prompt_retrieval | llm


# Hallucination Grader
class GradeHallucinations(BaseModel):
    """Verdict on whether a generation is grounded in the supplied facts."""

    # Only "yes" or "no" pass validation.
    binary_score: Literal["yes", "no"] = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

# Doubled braces are literal braces in the template; single-brace names in
# the human message below are template variables.
system_msg_hallucination = """
You are a grader assessing whether an LLM generation is grounded in a set of retrieved facts.
Respond with ONLY one of the following JSON objects:
{{"binary_score": "yes"}} or {{"binary_score": "no"}}

'yes' means the answer is grounded in the facts. 'no' means it contains unsupported information.
"""
# Template variables: {documents} and {generation}.
hallucination_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg_hallucination),
    ("human", "Set of facts:\n\n{documents}\n\nLLM generation:\n\n{generation}")
])
# Chain invoked by grade_generation_v_documents_and_question.
hallucination_grader = hallucination_prompt | llm


# Answer Grader
class GradeAnswer(BaseModel):
    """Verdict on whether a generation addresses the user question."""

    # Only "yes" or "no" pass validation.
    binary_score: Literal["yes", "no"] = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )

# Doubled braces are literal braces in the template; single-brace names in
# the human message below are template variables.
system_msg_answer = """
You are a grader assessing whether an answer addresses a user question.
Respond with ONLY one of these JSON objects:
{{"binary_score": "yes"}} or {{"binary_score": "no"}}

'yes' means the answer directly addresses the question. 'no' means it does not.
"""
# Template variables: {question} and {generation}.
answer_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg_answer),
    ("human", "User question:\n\n{question}\n\nLLM generation:\n\n{generation}")
])
# Chain invoked by grade_generation_v_documents_and_question once the
# generation has been judged grounded.
answer_grader = answer_prompt | llm

# --- Parsing Functions with improved error handling ---

def parse_route_output(output_text: str) -> RouteQuery:
    """Parse the router model's raw reply into a ``RouteQuery``.

    Args:
        output_text: Raw text returned by the router model.

    Returns:
        The validated routing decision. If no JSON object is found, the JSON
        is malformed, or it does not match the schema, the failure is printed
        and ``RouteQuery(datasource="web_search")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found.")
        # Decode only the first complete JSON value. A first-brace-to-last-brace
        # match would swallow any later braces in the reply and fail to parse.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return RouteQuery.model_validate(parsed)
    except (ValidationError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print(f"⚠️ Failed to parse RouteQuery response: {e}")
        print("Raw output text:", output_text) # Print raw output on failure
        return RouteQuery(datasource="web_search") # Fallback


def parse_grade_output(output_text: str) -> GradeDocuments:
    """Parse the relevance grader's raw reply into a ``GradeDocuments``.

    Args:
        output_text: Raw text returned by the grading model.

    Returns:
        The validated grade. If no JSON object is found, the JSON is
        malformed, or it does not match the schema, the failure is printed
        and ``GradeDocuments(binary_score="no")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found.")
        # Decode only the first complete JSON value. A first-brace-to-last-brace
        # match would swallow any later braces in the reply and fail to parse.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return GradeDocuments.model_validate(parsed)
    except (ValidationError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print(f"⚠️ Failed to parse GradeDocuments response: {e}")
        print("Raw output text:", output_text) # Print raw output on failure
        return GradeDocuments(binary_score="no") # Fallback

def parse_hallucination_output(output_text: str) -> GradeHallucinations:
    """Parse the hallucination grader's raw reply into a ``GradeHallucinations``.

    Args:
        output_text: Raw text returned by the grading model.

    Returns:
        The validated grade. If no JSON object is found, the JSON is
        malformed, or it does not match the schema, the failure is printed
        and ``GradeHallucinations(binary_score="no")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON object found.")
        # The prompt shows both answers on one line ('{..} or {..}'); decoding
        # only the first complete JSON value copes with a reply shaped like that.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return GradeHallucinations.model_validate(parsed)
    except (ValidationError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print(f"⚠️ Failed to parse GradeHallucinations response: {e}")
        print("Raw output text:", output_text) # Print raw output on failure
        return GradeHallucinations(binary_score="no") # Fallback

def parse_answer_output(output_text: str) -> GradeAnswer:
    """Parse the answer grader's raw reply into a ``GradeAnswer``.

    Args:
        output_text: Raw text returned by the grading model.

    Returns:
        The validated grade. If no JSON object is found, the JSON is
        malformed, or it does not match the schema, the failure is printed
        and ``GradeAnswer(binary_score="no")`` is returned.
    """
    try:
        start = output_text.find("{")
        if start == -1:
            raise ValueError("No JSON found")
        # The prompt shows both answers on one line ('{..} or {..}'); decoding
        # only the first complete JSON value copes with a reply shaped like that.
        parsed, _ = json.JSONDecoder().raw_decode(output_text, start)
        return GradeAnswer.model_validate(parsed)
    except (ValidationError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass and lands here too.
        print(f"⚠️ Failed to parse GradeAnswer response: {e}")
        print("Raw output text:", output_text) # Print raw output on failure
        return GradeAnswer(binary_score="no") # Fallback


# --- Node Functions (Keep these largely the same, ensuring they use the correct graders/routers) ---

def retrieve(state):
    """
    Fetch documents for the current question from the retriever.

    Args:
        state (dict): The current graph state; must contain "question"

    Returns:
        dict: "documents" set to the retriever's result, plus the unchanged "question"
    """
    print("---RETRIEVE---")
    user_question = state["question"]
    return {
        "documents": retriever.invoke(user_question),
        "question": user_question,
    }


def generate(state):
    """
    Produce an answer from the documents currently in the state.

    Args:
        state (dict): The current graph state; must contain "question" and "documents"

    Returns:
        dict: The same "documents" and "question", plus "generation" holding the RAG chain output
    """
    print("---GENERATE---")
    user_question = state["question"]
    source_docs = state["documents"]

    # The RAG chain takes the documents as one formatted context string.
    answer = rag_chain.invoke(
        {"context": format_docs(source_docs), "question": user_question}
    )
    return {
        "documents": source_docs,
        "question": user_question,
        "generation": answer,
    }

def grade_documents(state):
    """
    Keep only the retrieved documents the grader marks as relevant.

    Args:
        state (dict): The current graph state; must contain "question" and "documents"

    Returns:
        dict: "documents" reduced to those graded "yes", plus the unchanged "question"
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    user_question = state["question"]
    candidates = state["documents"]

    kept = []
    for candidate in candidates:
        reply = retrieval_grader.invoke(
            {"question": user_question, "document": candidate.page_content}
        )
        # Message objects carry the text in `.content`; anything else is stringified.
        reply_text = reply.content if hasattr(reply, 'content') else str(reply)
        if parse_grade_output(reply_text).binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            kept.append(candidate)
            continue
        print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": kept, "question": user_question}


def transform_query(state):
    """
    Rewrite the current question with the question rewriter chain.

    Args:
        state (dict): The current graph state; must contain "question" and "documents"

    Returns:
        dict: The unchanged "documents", with "question" replaced by the rewritten text
    """
    print("---TRANSFORM QUERY---")
    original_question = state["question"]
    current_docs = state["documents"]

    rewritten = question_rewriter.invoke({"question": original_question})
    # Normalise to a plain string whether the chain returned a message or text.
    if hasattr(rewritten, 'content'):
        rewritten = rewritten.content
    else:
        rewritten = str(rewritten)

    return {"documents": current_docs, "question": rewritten}

def web_search(state):
    """
    Run the web search tool on the current question.

    Args:
        state (dict): The current graph state; must contain "question"

    Returns:
        dict: "documents" holding the search results as Document objects,
        plus the unchanged "question"
    """
    print("---WEB SEARCH---")
    query = state["question"]
    search_output = web_search_tool.invoke({"query": query})

    if not isinstance(search_output, list):
        # Single result or unexpected format: wrap it in one Document.
        if isinstance(search_output, dict):
            text = search_output.get("content", "")
        else:
            text = str(search_output)
        return {"documents": [Document(page_content=text)], "question": query}

    # List of result dicts: one Document per hit, empty text if "content" is absent.
    hits = [Document(page_content=hit.get("content", "")) for hit in search_output]
    return {"documents": hits, "question": query}


# --- Edge Functions (Focus on route_question error handling) ---

def route_question(state):
    """
    Choose the first node for a question: "vectorstore" or "web_search".

    Any exception raised while invoking or parsing the router is printed and
    answered with the "web_search" fallback.

    Args:
        state (dict): The current graph state; must contain "question"

    Returns:
        str: The datasource name selected by the router
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    raw = None
    try:
        raw = question_router.invoke({"question": question})
        # Message objects carry the text in `.content`; anything else is stringified.
        raw_text = raw.content if hasattr(raw, 'content') else str(raw)
        decision = parse_route_output(raw_text).datasource
        print(f"---ROUTE DECISION: {decision}---")
    except Exception as e:
        print(f"⚠️ Error during question routing: {e}")
        print(f"Raw output from router (if available): {raw}")
        print("---FALLBACK: WEB SEARCH---")
        return "web_search"
    else:
        return decision


def decide_to_generate(state):
    """
    Pick the next step after document grading.

    Args:
        state (dict): The current graph state; "documents" may be absent

    Returns:
        str: "generate" when "documents" is a non-empty list holding at least
        one Document, otherwise "transform_query"
    """
    print("---ASSESS GRADED DOCUMENTS---")
    graded = state.get("documents", [])

    # Checks run left to right and stop at the first failure, so the
    # per-item Document test only runs for a non-empty list.
    usable = (
        bool(graded)
        and isinstance(graded, list)
        and any(isinstance(doc, Document) for doc in graded)
    )
    if usable:
        print("---DECISION: GENERATE---")
        return "generate"

    print(
        "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT OR DOCUMENTS ARE INVALID, TRANSFORM QUERY---"
    )
    return "transform_query"


def grade_generation_v_documents_and_question(state):
    """
    Judge the generation: is it grounded in the documents, and does it answer the question?

    Args:
        state (dict): The current graph state; must contain "question",
            "documents" and "generation"

    Returns:
        str: "not supported" if the generation is not grounded, "useful" if it
        is grounded and addresses the question, "not useful" otherwise
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    # The grader prompt takes the facts as one string; non-Document items are skipped.
    facts = "\n\n".join(
        [doc.page_content for doc in documents if isinstance(doc, Document)]
    )

    grounding_reply = hallucination_grader.invoke(
        {"documents": facts, "generation": generation}
    )
    if hasattr(grounding_reply, 'content'):
        grounding_text = grounding_reply.content
    else:
        grounding_text = str(grounding_reply)

    if parse_hallucination_output(grounding_text).binary_score != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    # Grounded: now check whether the answer addresses the question.
    print("---GRADE GENERATION vs QUESTION---")
    answer_reply = answer_grader.invoke({"question": question, "generation": generation})
    if hasattr(answer_reply, 'content'):
        answer_text = answer_reply.content
    else:
        answer_text = str(answer_reply)

    if parse_answer_output(answer_text).binary_score == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"

    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"


# Compile Graph
API Reference: END | StateGraph | START

In [None]:

from langgraph.graph import END, StateGraph, START
from langgraph.graph.state import CompiledStateGraph

from typing import List
from typing_extensions import TypedDict

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    This is the state schema passed to ``StateGraph`` when the workflow is
    built; every node function reads from and returns keys of this dict.

    Attributes:
        question: the user question (replaced by transform_query with a rewritten one)
        generation: LLM generation produced by the generate node
        documents: list of documents from retrieval or web search
    """

    question: str
    generation: str
    documents: List[Document] # Ensure documents is List[Document]


workflow = StateGraph(GraphState)

# Register the five processing stages under the names the edges refer to.
workflow.add_node("web_search", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)

# Entry point: the router's return value selects the first node.
workflow.add_conditional_edges(
    START,
    route_question,
    dict(web_search="web_search", vectorstore="retrieve"),
)

# Web results go straight to generation; retrieved documents are graded first.
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")

# After grading: generate if decide_to_generate says so, otherwise rewrite
# the question.
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    dict(transform_query="transform_query", generate="generate"),
)

# A rewritten question is sent back through retrieval.
workflow.add_edge("transform_query", "retrieve")

# After generation: retry when ungrounded, finish when useful, rewrite the
# question when the answer does not address it.
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile the workflow into the runnable graph.
app: CompiledStateGraph = workflow.compile()


# Use Graph


In [None]:


from langchain_core.output_parsers import StrOutputParser
from langchain import hub

# Rebuild the RAG chain used by the generate node: hub prompt -> model ->
# plain string. `llm` is the instance bound by an earlier cell.
prompt_template = hub.pull("rlm/rag-prompt")
rag_chain = prompt_template | llm | StrOutputParser()


# Rebuild the rewriter chain used by the transform_query node.
system_msg_rewrite = """
You are a question rewriter that converts an input question into a more effective version
for vectorstore retrieval. Your goal is to optimize the query for semantic intent and relevant keywords.
"""
# Template variable: {question}.
re_write_prompt = ChatPromptTemplate.from_messages([
    ("system", system_msg_rewrite),
    ("human", "Here is the initial question:\n\n{question}\n\nRewrite it for semantic search.")
])
question_rewriter = re_write_prompt | llm | StrOutputParser()




In [None]:
from pprint import pprint

inputs = {"question": "What are the types of agent memory?"}
final_generation = None

print("--- STARTING GRAPH RUN ---")
try:
    # Each streamed step maps node names to the output that node returned.
    for step in app.stream(inputs):
        for node_name, node_output in step.items():
            pprint(f"🔄 Node: {node_name}")
            pprint(node_output)

            # Remember the most recent generation any node has produced.
            if not isinstance(node_output, dict):
                continue
            if "generation" in node_output:
                final_generation = node_output["generation"]
        print("\n---")

    if not final_generation:
        print("⚠️ No generation found in any step.")
    else:
        print("✅ Final Generation:")
        pprint(final_generation)

except Exception as e:
    # Report the failure instead of surfacing a traceback.
    print(f"\n❌ An error occurred during the graph execution: {e}")
