<a href="https://colab.research.google.com/github/ranabilal09/Self-Corrective-Rag-Chatbot/blob/main/self_reflective_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
!pip install --quiet --upgrade langchain langgraph langchain_community langchain-google-genai langchain_chroma langchain_huggingface

In [14]:
from google.colab import userdata
import os
# --- Runtime configuration -------------------------------------------------
# All credentials are read from Colab's secret store (`userdata`) and exported
# as environment variables, which is where the client libraries look for them.

# Project name under which LangSmith groups the traces of this notebook.
os.environ['LANGCHAIN_PROJECT'] = "self-reflective-rag"
# Enable LangSmith tracing. The key was previously misspelled as
# "LANGCHAI_TRACING_V2", so tracing was silently never switched on.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
# NOTE(review): the Colab secret is looked up as 'langchai_api_key' (sic); the
# name is kept as-is because it must match the secret stored in the notebook —
# confirm whether the secret itself should be renamed.
os.environ['LANGCHAIN_API_KEY'] = userdata.get('langchai_api_key')
# Gemini key: kept in a variable as well as exported to the environment.
google_api_key = userdata.get('Gemini_Api_Key')
os.environ['GOOGLE_API_KEY'] = google_api_key
# Tavily key, exported for the web-search tool.
os.environ["TAVILY_API_KEY"] = userdata.get('tavily_api_key')
# Hugging Face Hub token.
os.environ["HUGGINGFACEHUB_API_TOKEN"] = userdata.get('HF_TOKEN')

In [15]:
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI

# Load: fetch each page as whole documents. `load()` is used instead of
# `load_and_split()` because the latter already chunks the pages with a default
# splitter; chunking again below would then cut the text twice with two
# different configurations.
urls = ["https://python.langchain.com/docs/tutorials/graph/",
        "https://python.langchain.com/docs/how_to/#retrievers",
        "https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/"]

docs = [WebBaseLoader(url).load() for url in urls]
# Flatten the per-URL lists into a single list of documents.
docs_split = [items for sublist in docs for items in sublist]

# Split: one pass, 1000-character chunks with 200 characters of overlap.
spliting = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
docs_split = spliting.split_documents(docs_split)

# Embedding model used to vectorise the chunks.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Vector store: embed the chunks and index them in Chroma.
vectorstore = Chroma.from_documents(docs_split, embeddings)



In [17]:
# Expose the Chroma index through the retriever interface (default search
# settings); the `retrieve` graph node below queries this object.
retriever = vectorstore.as_retriever()

In [18]:
from typing import Any, Dict, TypedDict

from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
  """State object passed between the graph nodes.

  Every node reads its inputs from, and returns its outputs under, the single
  ``keys`` entry.
  """

  # Free-form payload; the nodes in this file store "question", "documents",
  # "search" and "generation" here. Annotated with `typing.Any` — the previous
  # annotation used the builtin function `any`, which is not a type.
  keys: Dict[str, Any]

In [36]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser , PydanticOutputParser
from langchain_core.pydantic_v1 import BaseModel , Field
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.tools import tool
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate


#nodes

def retrieve(state):
  """Fetch documents for the current question from the vector-store retriever.

  Args:
    state: graph state; reads ``state["keys"]["question"]``.

  Returns:
    A state update whose ``keys`` hold the retrieved ``documents`` and the
    unchanged ``question``.
  """
  print("----RETRIEVE----")
  state_dict = state["keys"]
  question = state_dict["question"]
  # `invoke` is the Runnable entry point of a retriever; it replaces the
  # deprecated `get_relevant_documents`.
  documents = retriever.invoke(question)
  return {"keys": {"documents": documents , "question": question}}

def generation(state):
  """Generate the final answer from the documents collected in the state.

  Args:
    state: graph state; reads ``question`` and ``documents`` from
      ``state["keys"]``.

  Returns:
    A state update whose ``keys`` hold the generated text under
    ``generation`` plus the ``documents`` and ``question`` it was based on.
  """
  print("----GENERATION----")
  state_dict = state["keys"]
  question = state_dict["question"]
  # Use the documents that earlier nodes graded (and possibly extended with a
  # web result). Querying the retriever again here would throw that work away.
  documents = state_dict["documents"]

  #prompt
  prompt = hub.pull("rlm/rag-prompt")

  #llm
  llm = ChatGoogleGenerativeAI(
      model= "gemini-1.5-flash-8b"
  )
  rag_chain = prompt | llm | StrOutputParser()

  # Give the prompt plain text rather than the repr of Document objects.
  context = "\n\n".join(doc.page_content for doc in documents)

  #generation (named `answer` so the local does not shadow this function)
  answer = rag_chain.invoke({"context": context , "question": question})
  return {"keys": {"generation": answer , "documents" : documents , "question": question}}

def grade_documents(state):
  """Grade each retrieved document for relevance to the question.

  Each document is sent to the LLM with a structured-output schema. Documents
  graded 'yes' are kept; a 'no' grade, or ending up with no usable document at
  all, requests the web-search fallback.

  Args:
    state: graph state; reads ``question`` and ``documents`` from
      ``state["keys"]``.

  Returns:
    A state update whose ``keys`` hold the filtered ``documents``, the
    ``question`` and ``search`` ("yes" when a web search should follow,
    otherwise "no").
  """
  print("----Check Relevance----")
  state_dict= state["keys"]
  question = state_dict["question"]
  documents = state_dict["documents"]

  class grade(BaseModel):
    """ check the relevance documents"""

    binary_score: str = Field(
        description=("Check binary score 'yes' or 'no' ")
    )

  #llm
  llm = ChatGoogleGenerativeAI(
    model= "gemini-1.5-flash-8b"
  )

  #prompt
  prompt = PromptTemplate(
      template= """You are a grader accessing relevance of retrieved documents to the user question.\n
      Here is the retrieved documnets:\n{context}.\n
      Here is the user Question:{question}.\n
      If the documents contain keyword(s) or semantic meaning relative to the user question ,grade them as relevant.\n
      Give a relevance score 'yes' or 'no' score for all documents to indicate that weather all documents are relevant.\n
      your response should be in json format:""",
      input_variables=["context" , "question"]
  )

  chain= prompt | llm.with_structured_output(grade,include_raw=True)

  search = "no"
  filtered_docs=[]
  for d in documents:
    result = chain.invoke({"context": d.page_content , "question": question})
    # With include_raw=True the chain returns a dict with the keys "raw",
    # "parsed" and "parsing_error" (not a tuple); "parsed" is None when the
    # model output could not be parsed into `grade`.
    parsed = result.get("parsed") if isinstance(result, dict) else None
    if parsed is None:
      print("----Could not parse relevance score, skipping document")
      continue
    # Normalise so answers such as "Yes" or " yes " still count as relevant.
    if parsed.binary_score.strip().lower() == "yes":
      filtered_docs.append(d)
    else:
      print("----Documents are not relevant")
      search = "yes"

  # Nothing usable survived grading (all irrelevant, all unparsable, or no
  # input documents): fall back to the web search.
  if not filtered_docs:
    search = "yes"
  return {"keys": {"documents": filtered_docs , "question": question , "search": search}}

In [37]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document

def translate_query(state):
  """Ask the LLM to rephrase the question into a retrieval-friendly form.

  Reads ``question`` and ``documents`` from ``state["keys"]`` and returns a
  state update carrying the rewritten question together with the untouched
  documents.
  """
  print("----Translate Query----")
  keys = state["keys"]
  original_question = keys["question"]
  current_docs = keys["documents"]

  # Model that performs the rewrite.
  rewriter_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-8b")

  # Instruction text for the rewrite.
  rewrite_template = """ You are generating question that is well optimized for retrieval.\n
      Look at the input and try to reason about underlying semantic intent / meanings.\n
      Here is the initial question:
      \n-------\n
      {question}
      \n-------\n
      Formulate an improved question:"""
  rewrite_prompt = PromptTemplate(
      template=rewrite_template,
      input_variables=["question"],
  )

  # prompt -> model -> plain string
  rewriter = rewrite_prompt | rewriter_llm | StrOutputParser()
  improved_question = rewriter.invoke({"question": original_question})

  return {"keys": {"question": improved_question, "documents": current_docs}}

def web_search(state):
  """Add a Tavily web-search result to the documents in the state.

  Args:
    state: graph state; reads ``question`` and ``documents`` from
      ``state["keys"]``.

  Returns:
    A state update whose ``keys`` hold the ``question`` and the ``documents``
    extended with one Document containing the joined search-hit contents. If
    the search yields no list of hits, the documents are returned unchanged.
  """
  print("----Web Search----")
  state_dict = state["keys"]
  question = state_dict["question"]
  documents = state_dict["documents"]

  tavily = TavilySearchResults(max_results=1)
  tavily_search= tavily.invoke(question)

  # The tool reports failures as a plain string rather than a list of hits;
  # indexing that with d["content"] would raise TypeError, so skip it.
  if not isinstance(tavily_search, list) or not tavily_search:
    print("----Web search returned no usable results")
    return {"keys": {"documents": documents , "question": question}}

  web_results = "\n".join(d["content"] for d in tavily_search)
  # Build a new list instead of appending in place, so the list object held
  # by the incoming state is not mutated behind the graph's back.
  documents = [*documents, Document(page_content=web_results)]
  return {"keys": {"documents": documents , "question": question}}

def decide(state):
  """Route the graph after grading: rewrite and search the web, or generate.

  Args:
    state: graph state; only ``state["keys"]["search"]`` is read.

  Returns:
    "translate" when ``search`` equals "yes", otherwise "generation".
  """
  print("----Decide----")
  # Only the search flag drives the decision; the unused question/documents
  # lookups were dropped so they can no longer fail with KeyError.
  search = state["keys"]["search"]

  if search == "yes":
    print("----DECISION: Translate Query and Search the Web")
    return "translate"

  print("----DECISION: Generation")
  return "generation"



In [41]:
from langgraph.graph import StateGraph ,END

# Assemble the workflow: retrieve -> grade -> (rewrite + web search)? -> generate.
graph = StateGraph(AgentState)

# Register the node functions under the names used by the edges below.
for node_name, node_fn in (
    ("retrieve", retrieve),
    ("generation", generation),
    ("grade_documents", grade_documents),
    ("translate_query", translate_query),
    ("web_search", web_search),
):
  graph.add_node(node_name, node_fn)

# The run starts at retrieval, followed by grading.
graph.set_entry_point("retrieve")
graph.add_edge("retrieve", "grade_documents")

# `decide` returns "translate" or "generation"; map each label to its node.
routes = {"translate": "translate_query", "generation": "generation"}
graph.add_conditional_edges("grade_documents", decide, routes)

# Fallback path (rewrite -> web search -> generation), then finish.
for source, target in (
    ("translate_query", "web_search"),
    ("web_search", "generation"),
    ("generation", END),
):
  graph.add_edge(source, target)

app = graph.compile()

In [None]:
# Stream the compiled graph for one question and print each node's output.
inputs = "how to retrieve relevant documents from vectorestore using langgraph?"
initial_state = {"keys": {"question": inputs}}
run_config = {"recursion_limit": 150}

for step in app.stream(initial_state, run_config):
  # Each streamed item is a mapping that is iterated as (key, value) pairs.
  for node_name, node_state in step.items():
    print(f"{node_name}: {node_state}")