In [None]:
# import libraries
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_ollama import OllamaEmbeddings, ChatOllama

In [1]:
# Knowledge base built from web pages and stored in a vector store.
#
# Steps: load each URL, split the loaded documents into chunks, embed the
# chunks, index them in Chroma, and expose the index as `retriever`.
#

# Source pages for the knowledge base
urls = [
    "https://cobusgreyling.medium.com/corrective-rag-crag-5e40467099f8",
    "https://gabrielgomes61320.medium.com/advanced-rag-techniques-the-corrective-rag-strategy-93451a49db61",
]

# One list of loaded documents per URL
docs = [WebBaseLoader(page_url).load() for page_url in urls]

# Flatten the per-URL lists into a single list of documents
docs_list = [page_doc for per_url_docs in docs for page_doc in per_url_docs]

# Splitter whose chunk size is measured with a tiktoken encoder; chunks do not overlap
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_overlap=0,
    chunk_size=250,
)
doc_splits = text_splitter.split_documents(docs_list)

# Embedding model served by Ollama
embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")

# Index the chunks and expose the store as a retriever
vectorstore = Chroma.from_documents(
    collection_name="rag-chroma",
    embedding=embeddings,
    documents=doc_splits,
)
retriever = vectorstore.as_retriever()

In [7]:
### Judge the relevance of a document to a question
### Outputs in JSON format: {"score": "A"} (relevant) or {"score": "F"} (not relevant)
###

# LLM (JSON mode, temperature 0)
llm = ChatOllama(model="llama3.1:8b", format="json", temperature=0)

# Prompt
# The grading scale is stated in terms of "A"/"F" throughout, because the
# consumer of this chain (judge_documents) compares the score against "A".
prompt = PromptTemplate(
    template="""You are a teacher grading a quiz. You will be given: 
    1/ a QUESTION
    2/ A FACT provided by the student
    
    You are grading RELEVANCE RECALL:
    A score of "A" means that ANY of the statements in the FACT are relevant to the QUESTION. 
    A score of "F" means that NONE of the statements in the FACT are relevant to the QUESTION. 
    "A" is the highest (best) score. "F" is the lowest score you can give. 
    
    Explain your reasoning in a step-by-step manner. Ensure your reasoning and conclusion are correct. 
    
    Avoid simply stating the correct answer at the outset.
    
    Question: {question} \n
    Fact: \n\n {documents} \n\n
    
    Give a score of 'A' or 'F' to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
    """,
    input_variables=["question", "documents"],
)

# Chain: prompt -> LLM -> parsed JSON
judge = prompt | llm | JsonOutputParser()

In [8]:
# Smoke test for the judge chain: grade the first retrieved document
# against a sample question and display the parsed score.

question = "What is Corrective Retrieval Augmented Generation."
docs = retriever.invoke(question)
doc_txt = docs[0].page_content  # text of the first document the retriever returned
score = judge.invoke(dict(question=question, documents=doc_txt))
score

OllamaEndpointNotFoundError: Ollama call failed with status code 404. Maybe your model is not found and you should pull the model with `ollama pull llama3.1:8b`.

In [6]:
### Question Rephraser (aware of the questions that were already asked)
### Outputs in JSON format: {"question": "<rephrased question>"}
###

# LLM (JSON mode, temperature 0)
llm = ChatOllama(model="llama3.1:8b", format="json", temperature=0)

promptPreviousQuestions = PromptTemplate(
    template="""You are a reporter asking a question.  
    Your job is to take a question and rephrase it in a way that is more likely to get a useful answer. 
    You want to make sure you do not ask a repeat question from the same interview.  Make sure to provide a question.  Do not respond with an empty string.
    Take into account this is a search engine you are asking the question to so make sure the question is clear and concise. 
    Return the rephrased question as a JSON with a single key 'question' and no preamble or explanation.
    
    Question: {question} \n
    Previous Questions: \n\n {previous_questions} \n\n
    
    """,
    # One list entry per template placeholder (previously a single
    # comma-containing string, which names no real placeholder).
    input_variables=["question", "previous_questions"],
)

# Chain: prompt -> LLM -> parsed JSON
question_rephraser_with_previousquestions = promptPreviousQuestions | llm | JsonOutputParser()

# Smoke test: rephrase a sample question given a short history of earlier questions.

previous_questions = [
    "What is the capital of France?",
    "How does photosynthesis work?",
    "What is the speed of light?",
]
question = "Explain the theory of relativity."

# The prompt takes the history as one newline-separated string
prompt_input = dict(
    question=question,
    previous_questions="\n".join(previous_questions),
)
response = question_rephraser_with_previousquestions.invoke(prompt_input)

# The chain returns parsed JSON; the rephrased text is under 'question'
question_value = response['question']
print(question_value)

In [None]:
### Question Rephraser (no question history)
### Outputs in JSON format
###

# JSON-mode chat model, temperature 0
llm = ChatOllama(temperature=0, format="json", model="llama3.1:8b")

# Prompt with a single placeholder: the question to rephrase
promptQuestion = PromptTemplate(
    input_variables=["question"],
    template="""You are a reporter asking a question.  
    Your job is to take the question below and rephrase it in a way that is more likely to get a useful answer through a search engine. 
    Keep the question concise and clear.  Do not respond with an empty string. Do not have more than one sentence.  Take into account 
    this is a search engine you are asking the question to so make sure the question is clear and concise. 
    Return the rephrased question as a JSON with a single key 'question' and no premable or explanation.
    
    Question: {question} \n
    
    """,
)

# prompt -> LLM -> parsed JSON
question_rephraser = promptQuestion | llm | JsonOutputParser()

In [None]:
# Smoke test: rephrase a sample question and print the parsed JSON response.

question = "Explain the theory of relativity."

prompt_input = dict(question=question)
response = question_rephraser.invoke(prompt_input)

print(response)

In [None]:
### Generate an answer based on the question and documents

# Prompt
prompt = PromptTemplate(
    template="""You are an assistant for question-answering tasks. 
    
    Use the following documents to answer the question. 
    
    If you don't know the answer, just say that you don't know. 
    
    Use three sentences maximum and keep the answer concise:
    Question: {question} 
    Documents: {documents} 
    Answer: 
    """,
    input_variables=["question", "documents"],
)

# LLM (plain-text output here, so no JSON mode)
llm = ChatOllama(model="llama3.1:8b", temperature=0)

# Chain: prompt -> LLM -> plain string
rag_chain = prompt | llm | StrOutputParser()

# Run
# Set the question and retrieve its documents in this cell. Relying on the
# `question` and `docs` left behind by earlier test cells pairs the most
# recently assigned question with documents retrieved for a different one.
question = "What is Corrective Retrieval Augmented Generation."
docs = retriever.invoke(question)
generation = rag_chain.invoke({"documents": docs, "question": question})
print(generation)

In [None]:
### Search Tool

# Tavily web search, capped at three results per query.
# `max_results` is the tool's result-count field; `k` is not a field of
# TavilySearchResults, so the intended cap of 3 was not being applied.
# NOTE(review): the tool presumably needs a Tavily API key in the environment — confirm setup.
web_search_tool = TavilySearchResults(max_results=3)

In [None]:
from typing import List
from typing_extensions import TypedDict
from IPython.display import Image, display
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph

# Graph State that is passed along the graph from node to node
class GraphState(TypedDict):
    """
    Represents the state of our graph, passed along from node to node.

    Attributes:
        question: the question currently being asked (nodes may replace it
            with a rephrased or human-supplied version)
        previous_questions: questions that were asked earlier in this run
        generation: LLM generation
        search: routing flag written by judge_documents and read by router
            ("Yes", "No" or "ask_human")
        documents: list of documents (objects exposing ``page_content``)
        steps: names of the steps executed so far, in order
        search_count: number of web searches performed
    """

    question: str
    previous_questions: List[str]
    generation: str
    search: str
    documents: List[Document]
    steps: List[str]
    search_count: int

In [None]:
# Agent Node Functions that are called by the graph

### Rephrase the question
def rephrase_question(state: GraphState) -> GraphState:
    """
    Rephrase the current question and record it in the question history.

    When earlier questions exist, the history-aware rephraser is used so the
    new wording does not repeat them; otherwise the plain rephraser is used.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updated ``question`` (the rephrased text),
        ``previous_questions`` (history plus the question that was just
        rephrased), ``steps`` and ``search_count``
    """
    current_question = state["question"]
    # Work on a copy; a missing or None history is treated as empty.
    previous_questions = list(state.get("previous_questions") or [])

    if previous_questions:
        prompt_input = {
            "question": current_question,
            "previous_questions": "\n".join(previous_questions),
        }
        response = question_rephraser_with_previousquestions.invoke(prompt_input)
    else:
        response = question_rephraser.invoke({"question": current_question})

    # The question being replaced has now been asked, so it joins the history
    # in both branches. (list.append returns None, so its result must not be
    # assigned back to the list.)
    previous_questions.append(current_question)

    # Keep the current wording if the model returned no usable question.
    rephrased = response.get("question") if isinstance(response, dict) else None
    if isinstance(rephrased, str) and rephrased.strip():
        question_value = rephrased
    else:
        question_value = current_question

    steps = state["steps"]
    steps.append("rephrase_question")
    return {
        "question": question_value,
        "previous_questions": previous_questions,
        "steps": steps,
        "search_count": state["search_count"],
    }

### Retrieve documents
def retrieve(state: GraphState) -> GraphState:
    """
    Fetch documents for the current question from the retriever.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): ``documents`` holds the retriever's results; the question
        history is reset to an empty list and ``search_count`` to 0
    """
    query = state["question"]
    hits = retriever.invoke(query)

    # The step log is the same list object held by the incoming state
    trail = state["steps"]
    trail.append("retrieve_documents")

    return dict(
        documents=hits,
        question=query,
        previous_questions=[],
        steps=trail,
        search_count=0,
    )

### Not found documents
def not_found(state: GraphState) -> GraphState:
    """
    Record that no documents were found and clear the document list.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): ``documents`` set to an empty list; question, history,
        step log and search count carried over unchanged (plus the new step)
    """
    asked = state["question"]
    history = state["previous_questions"]
    searches = state["search_count"]

    # Log this step on the state's own step list
    trail = state["steps"]
    trail.append("not_found")

    return dict(
        documents=[],
        question=asked,
        previous_questions=history,
        steps=trail,
        search_count=searches,
    )

def generate(state):
    """
    Produce the answer with the RAG chain.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): The incoming documents, question, history and search
        count, plus ``generation`` holding the chain's output
    """
    asked = state["question"]
    context_docs = state["documents"]
    history = state["previous_questions"]
    searches = state["search_count"]

    answer = rag_chain.invoke(dict(documents=context_docs, question=asked))

    # Log this step on the state's own step list
    trail = state["steps"]
    trail.append("generate_answer")

    return dict(
        documents=context_docs,
        question=asked,
        previous_questions=history,
        generation=answer,
        steps=trail,
        search_count=searches,
    )

### Grade documents
def _is_relevant(score) -> bool:
    """Return True when the judge's parsed output carries the passing grade "A"."""
    if not isinstance(score, dict):
        return False
    return str(score.get("score", "")).strip().upper() == "A"


def judge_documents(state: GraphState) -> GraphState:
    """
    Determines whether the retrieved documents are relevant to the question.

    Each document is graded by the ``judge`` chain. Documents graded "A" are
    kept. The ``search`` flag is set to "Yes" when any document is rejected
    or when no relevant document remains (including an empty input list),
    to "No" when every document is relevant, and to "ask_human" once more
    than two searches have already been made.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): ``documents`` filtered to the relevant ones, ``search``
        set as described above, and ``search_count`` reset to 0 when handing
        over to a human
    """
    question = state["question"]
    documents = state["documents"]
    count = state["search_count"]
    steps = state["steps"]
    steps.append("judge_documents")

    filtered_docs = []
    search = "No"
    # After 3 searches, stop grading and kick off to the human in the loop.
    if count > 2:
        search = "ask_human"
        count = 0
    else:
        for d in documents:
            score = judge.invoke(
                {"question": question, "documents": d.page_content}
            )
            # Malformed judge output (missing key, non-dict) counts as not relevant.
            if _is_relevant(score):
                filtered_docs.append(d)
            else:
                search = "Yes"
        # Nothing relevant to answer from (e.g. nothing was retrieved): search.
        if not filtered_docs:
            search = "Yes"

    return {
        "documents": filtered_docs,
        "question": question,
        "previous_questions": state["previous_questions"],
        "search": search,
        "steps": steps,
        "search_count": count,
    }

### Web search
def web_search(state: GraphState) -> GraphState:
    """
    Run a web search for the current question and add the hits to the documents.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): ``documents`` extended with one Document per web result,
        and ``search_count`` increased by one
    """
    query = state["question"]
    docs = state.get("documents", [])
    history = state["previous_questions"]
    trail = state["steps"]
    searches = 1 + state["search_count"]
    trail.append("web_search")

    # Convert every hit before touching the document list
    found = []
    for hit in web_search_tool.invoke({"query": query}):
        found.append(
            Document(page_content=hit["content"], metadata={"url": hit["url"]})
        )
    docs.extend(found)

    return dict(
        documents=docs,
        question=query,
        previous_questions=history,
        steps=trail,
        search_count=searches,
    )

### Router to decide next node to go to
def router(state: GraphState) -> str:
    """
    Choose the next node from the ``search`` flag in the state.

    Args:
        state (dict): The current graph state

    Returns:
        str: Name of the next node to call — "rephrase_question" when
        ``search`` is "Yes", "ask_human" when it is "ask_human", and
        "generate" for any other value
    """
    search = state["search"]
    if search == "Yes":
        return "rephrase_question"
    if search == "ask_human":
        return "ask_human"
    return "generate"
    
### Ask human
def ask_human(state: GraphState) -> GraphState:
    """
    Ask the user for a new or rephrased question.

    The prompt lists the questions that were already tried. A blank reply
    keeps the current question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): ``question`` replaced by the user's input (or kept when
        the input is blank); documents, history, step log and search count
        carried over
    """
    documents = state.get("documents", [])
    # Treat a missing or None history as empty so the join below cannot fail.
    previous_questions = state.get("previous_questions") or []
    steps = state["steps"]
    steps.append("ask_human")
    search_count = state["search_count"]

    # Join outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12.
    asked_so_far = "\n".join(str(q) for q in previous_questions)
    input_msg = (
        "We have tried searching based on the question asked and could not find any relevant documents.\n"
        "Please rephrase the question taking into account the previous questions that were asked below, "
        "or ask a new question.\n\n"
        f"Previous Questions:\n{asked_so_far}\n"
    )
    # Fall back to the current question on a blank reply.
    question = input(input_msg).strip() or state["question"]

    return {
        "documents": documents,
        "question": question,
        "previous_questions": previous_questions,
        "steps": steps,
        "search_count": search_count,
    }

In [None]:
# Graph
workflow = StateGraph(GraphState)

# Register each node under the name the edges below refer to
for node_name, node_fn in [
    ("retrieve", retrieve),                    # retrieve documents
    ("judge_documents", judge_documents),      # grade documents
    ("rephrase_question", rephrase_question),  # rephrase the question
    ("web_search", web_search),                # web search
    ("ask_human", ask_human),                  # ask the user for a question
    ("generate", generate),                    # generate the answer
]:
    workflow.add_node(node_name, node_fn)

# Entry path: retrieve first, then grade what was retrieved
for edge_from, edge_to in [
    (START, "retrieve"),
    ("retrieve", "judge_documents"),
]:
    workflow.add_edge(edge_from, edge_to)

# After grading, router() returns one of the keys below; each maps to a node
workflow.add_conditional_edges(
    "judge_documents",
    router,
    {
        "rephrase_question": "rephrase_question",
        "generate": "generate",
        "ask_human": "ask_human"
    },
)

# Rephrased and human-supplied questions both go to web search, whose
# results are graded again; generation ends the run
for edge_from, edge_to in [
    ("rephrase_question", "web_search"),
    ("web_search", "judge_documents"),
    ("ask_human", "web_search"),
    ("generate", END),
]:
    workflow.add_edge(edge_from, edge_to)

# Set up memory
# from langgraph.checkpoint.memory import MemorySaver

# memory = MemorySaver()

custom_graph = workflow.compile()

# TODO:  Need to figure out how to interrupt the graph and ask for human feedback and then take that feedback and feed it back into the loop
# custom_graph = workflow.compile(interrupt_before=["ask_human"])

# Render the compiled graph as a Mermaid PNG
display(Image(custom_graph.get_graph(xray=True).draw_mermaid_png()))

In [None]:
import uuid

# Prediction Function
def call_langgraph_agents(example: dict[str, str]):
    """
    Run the compiled graph on one example and return its answer and step log.

    Args:
        example: mapping whose "input" entry is the question to ask

    Returns:
        dict with "response" (the graph's ``generation``) and "steps"
    """
    # A fresh thread id is generated for every call
    config = dict(configurable=dict(thread_id=str(uuid.uuid4())))

    initial_state = {
        "question": example["input"],
        "steps": [],
        "previous_questions": [],
        "generation": "",
        "search": "",
        "documents": [],
        "search_count": 0,
    }
    state_dict = custom_graph.invoke(initial_state, config)

    print(f"State Dictionary: {state_dict}")  # Debug log
    return dict(response=state_dict["generation"], steps=state_dict["steps"])

# Run the graph end to end on a sample question and print the returned
# response/steps dictionary.
example = dict(input="What is CRAG?")
response = call_langgraph_agents(example)
print(response)

In [None]:
# Show only the generated answer from the previous cell's response
print(f"The answer is:\n{response['response']}")