In [None]:
import os
import warnings

from dotenv import load_dotenv

# Set before any native numeric libraries are loaded further down the notebook.
# NOTE(review): presumably works around the "duplicate OpenMP runtime" abort
# seen when faiss and another OpenMP-linked library coexist — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Silences every warning category for the whole kernel session.
warnings.filterwarnings("ignore")

# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

In [None]:
from langchain_ollama import ChatOllama
# Chat model served by a local Ollama instance.
model = "llama3.2:3b"
llm = ChatOllama(
    model=model,
    base_url="http://localhost:11434",
)

# Smoke test: confirms the Ollama server is reachable and the model responds.
print(llm.invoke("What is the capital of France?"))

In [None]:
from langchain_ollama.embeddings import OllamaEmbeddings
import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore

In [None]:
# Embedding model used to encode queries for the vector store loaded below.
embeddings = OllamaEmbeddings(
    model="nomic-embed-text",
    base_url="http://localhost:11434",
)

In [None]:
db_name = "health_supplements"

# SECURITY: allow_dangerous_deserialization=True lets FAISS unpickle the saved
# docstore. Only load index folders that you created yourself or fully trust.
vector_store = FAISS.load_local(
    db_name,
    embeddings,
    allow_dangerous_deserialization=True,
)

# Plain similarity search returning the 5 closest chunks per query.
retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5},
)

# Quick check that the index loads and returns documents for a sample query.
question = "how to gain muscle mass?"
retriever.invoke(question)

In [None]:
from langchain.tools.retriever import create_retriever_tool

# Wrap the retriever as a named tool so the chat model can call it.
retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="health_supplements",
    description="Search and return information about the health supplements for workout and gym",
)


In [None]:
# Tools offered to the agent; currently only the retriever tool.
tools = [retriever_tool]

In [None]:
# Display the registered tools as the cell output.
tools

In [None]:
from typing import Annotated,Sequence,TypedDict,Literal
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

class State(TypedDict):
    """Graph state shared by every node: the running list of conversation messages."""

    # add_messages is attached as the reducer for this key, so a node's returned
    # "messages" are merged into the existing list instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]

In [None]:
from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

from pydantic import BaseModel, Field
from langgraph.prebuilt import tools_condition




In [None]:
def grade_documents(state)-> Literal["generate","rewrite"]:
    """
    Decide whether the retrieved documents are relevant to the user's question.

    The first message in the state is treated as the user question and the
    content of the last message as the retrieved context. A structured-output
    call to the module-level ``llm`` grades the relevance.

    Args:
        state (State): The current graph state; only ``state["messages"]`` is read.

    Returns:
        str: ``"generate"`` when the grader answers yes, otherwise ``"rewrite"``.
    """

    # Schema for the grader's structured output. A ``#`` comment is used instead
    # of a class docstring so the schema sent to the model stays unchanged.
    class grade(BaseModel):

        binary_score: str = Field(description="Relevance score 'yes'or 'no'")

    llm_with_structured_output = llm.with_structured_output(grade)

    prompt =PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question.\n
        Here is the retrieved document:\n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.""",
        input_variables=["context","question"],
    )

    chain = prompt | llm_with_structured_output
    messages = state["messages"]

    question = messages[0].content
    docs = messages[-1].content

    scored_result = chain.invoke({"question": question, "context":docs})

    # The model may answer "Yes", "YES" or "yes.", and structured output can come
    # back as None when the model emits no parsable result. Normalise before
    # comparing so a relevant document is not sent to "rewrite" by a formatting quirk.
    raw_score = getattr(scored_result, "binary_score", None)
    score = str(raw_score or "").strip().lower()

    if score.startswith("yes"):
        print("---DECISION: DOCS RELEVANT---")
        return "generate"

    print("---DECISION: DOCS NOT RELEVANT---")
    print(raw_score)
    return "rewrite"



In [None]:
def agent(state):
    """
    Run the tool-enabled chat model over the conversation so far.

    Args:
        state (State): The current graph state; its "messages" are sent to the model.

    Returns:
        dict: A partial state update whose "messages" list holds the model's reply.
    """
    print("---CALL AGENT---")

    # Bind the module-level tools to the module-level llm and request a tool call.
    # NOTE(review): whether tool_choice is honoured depends on the chat backend — confirm.
    tool_enabled_llm = llm.bind_tools(tools, tool_choice = "required")
    reply = tool_enabled_llm.invoke(state["messages"])

    return {"messages": [reply]}

### Rewrite Node

In [None]:
def rewrite(state):
    """
    Ask the model to rephrase the original user question.

    The first message in the state is taken as the original question. The
    module-level ``llm`` is prompted to infer its intent and produce a better
    formulation.

    Args:
        state (State): The current graph state; only ``state["messages"]`` is read.

    Returns:
        dict: A partial state update whose "messages" list holds the model's
        rephrased question.
    """

    print("---Transform Query---")
    messages = state["messages"]
    question = messages[0].content

    # The prompt previously had a stray "g" after the first "\n" and asked for
    # "an improved questions"; both were typos in text sent to the model.
    msg = [
        HumanMessage(content = f"""\n
    Look at the input and try to reason about the underlying semantic intent / meaning. \n
    Here is the initial question:
    \n ------------\n
    {question}
    \n ------------\n
    Formulate an improved question: """,

        )
    ]

    response = llm.invoke(msg)
    return {"messages": [response]}


In [None]:
def generate(state):
    """
    Produce the final answer from the retrieved context.

    The first message in the state is taken as the user question and the
    content of the last message as the retrieved context. Both are fed through
    the "rlm/rag-prompt" prompt and the module-level ``llm``.

    Args:
        state (State): The current graph state; only ``state["messages"]`` is read.

    Returns:
        dict: A partial state update whose "messages" list holds the answer text.
    """

    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    # Pulled on every call; the prompt is expected to take "question" and "context".
    prompt = hub.pull("rlm/rag-prompt")

    rag_chain = prompt | llm | StrOutputParser()

    response = rag_chain.invoke({"question": question, "context": docs})
    # NOTE(review): this returns a bare string. The add_messages reducer appears to
    # coerce plain strings into human messages — confirm whether an AI message is wanted.
    return {"messages": [response]}

In [None]:
from langgraph.graph import END,START,StateGraph
from langgraph.prebuilt import ToolNode

In [None]:
graph_builder = StateGraph(State)

# Nodes
graph_builder.add_node("agent", agent)
# Use a dedicated name for the ToolNode. Rebinding the global `retriever` would
# shadow the vector-store retriever and break re-running the earlier cells.
retriever_node = ToolNode([retriever_tool])
graph_builder.add_node("retriever", retriever_node)
graph_builder.add_node("rewrite", rewrite)
graph_builder.add_node("generate", generate)

# Edges
graph_builder.add_edge(START, "agent")

# After the agent: run the retriever tool if a tool call was made, otherwise stop.
graph_builder.add_conditional_edges(
    "agent",
    tools_condition,
    {
        "tools": "retriever",
        END: END,
    },
)

# After retrieval: grade_documents returns the name of the next node.
# The explicit map keeps the routing visible and independent of annotation inference.
graph_builder.add_conditional_edges(
    "retriever",
    grade_documents,
    {
        "generate": "generate",
        "rewrite": "rewrite",
    },
)

graph_builder.add_edge("generate", END)
# A rewritten question goes back to the agent for another retrieval attempt.
graph_builder.add_edge("rewrite", "agent")

graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG and show it inline.
# NOTE(review): draw_mermaid_png may rely on an external rendering service by
# default — confirm this cell works offline before relying on Run All.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
from pprint import pprint

query = {"messages": [HumanMessage("How to gain muscle mass?")]}

# Stream the graph run and show each node's update as it is produced.
for output in graph.stream(query):
    for key, value in output.items():
        # Use print for plain strings: pprint would show their repr (quotes and
        # escaped "\n"), so the separators would not render as intended.
        print(f"Output from node '{key}':")
        print("-----")
        pprint(value, indent=4, width=120)

    print("\n-------\n")