In [59]:
# Basic Imports
import os
from dotenv import load_dotenv, find_dotenv
# Load variables from the nearest .env file into os.environ (presumably the
# OpenAI / Google API keys needed by the embedding and chat models below).
_ = load_dotenv(dotenv_path=find_dotenv())


In [None]:
# Define RAG tool
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.tools.retriever import create_retriever_tool


# Source pages to index. WebBaseLoader fetches each URL as a single page (it does
# not crawl linked posts), so the index only covers text on these landing pages.
urls = [
    "https://research.google/blog/",
    "https://deepmind.google/discover/blog/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Sizes are measured in tiktoken tokens because the splitter is built with
# from_tiktoken_encoder; small chunks with heavy overlap favour precise hits.
CHUNK_SIZE = 100
CHUNK_OVERLAP = 50
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
)
doc_splits = text_splitter.split_documents(docs_list)
if not doc_splits:
    raise RuntimeError(f"No text could be extracted from {urls}; nothing to index.")

# Deterministic ids: Chroma upserts by id, so if this cell is re-run while the
# "rag-chroma" collection still holds earlier chunks, they are overwritten rather
# than duplicated (random default ids would add a second copy of every chunk).
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
    ids=[f"chunk-{i}" for i in range(len(doc_splits))],
)
retriever = vectorstore.as_retriever()

# The description is what the LLM reads when deciding whether to call the tool,
# so it must describe the sources actually indexed above.
retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_blog_posts",
    "Search and return information about the latest AI news from the Google Research and Google DeepMind blogs.",
)

rag_tool = [retriever_tool]
# Build the ToolNode from the same list that is bound to the LLM so they cannot drift apart.
retrieve = ToolNode(rag_tool)


In [61]:
# Create the Gemini chat model and expose the retriever tool to it, so its
# responses may contain calls to retrieve_blog_posts.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
llm_with_tools = llm.bind_tools(tools=rag_tool)

In [50]:
# Define Functions for State Graph
from langgraph.graph import END, START, StateGraph, MessagesState
from langchain_core.messages import  AnyMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage
from typing import Literal
from langgraph.checkpoint.memory import MemorySaver

# Router used after the agent step
def decide_to_call_tool(state: MessagesState) -> Literal["tools", END]:
    """Pick the next node after the agent has responded.

    Returns "tools" when the newest message carries tool calls, otherwise END
    so the graph run finishes.
    """
    latest = state["messages"][-1]
    return "tools" if latest.tool_calls else END

# Agent node: run the tool-aware LLM over the conversation so far
def call_model(state: MessagesState) -> MessagesState:
    """Invoke the tool-bound LLM on the accumulated messages.

    Args:
        state: current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update holding just the new AI message. The ``messages``
        reducer of MessagesState appends it to the history. Only the changed key
        is returned (instead of ``state | {...}``) so that other state keys are
        never re-emitted and re-applied through their reducers if the state
        schema is extended later.
    """
    model_response = llm_with_tools.invoke(state["messages"])
    return {"messages": [model_response]}

# Wire the agent/tool loop: START -> agent; agent -> tools when the model asked
# for a tool (otherwise END); tools -> agent so the model sees the tool output.
rag_graph = StateGraph(MessagesState)
rag_graph.add_node("agent", call_model)
rag_graph.add_node("tools", retrieve)

rag_graph.add_edge(START, "agent")
rag_graph.add_conditional_edges(source="agent", path=decide_to_call_tool)
rag_graph.add_edge("tools", "agent")

# In-memory checkpointer: state is stored per thread_id for this kernel's lifetime.
checkpointer = MemorySaver()
rag_app = rag_graph.compile(checkpointer=checkpointer)


In [None]:
# Invoke LLM by calling RAG app
# Run configuration for the checkpointer (the name `thread_id` is kept in case later
# cells use it). Reusing thread "1" continues the same stored conversation.
thread_id = {"configurable": {"thread_id": "1"}}

user_msg = input("User: ").strip()
if not user_msg:
    # Don't send an empty HumanMessage to the model; fail with a clear message instead.
    raise ValueError("Empty question: please type a message for the agent.")
messages = [HumanMessage(content=user_msg)]
ai_events = rag_app.invoke({"messages": messages}, thread_id)

# The returned state holds the whole message history of this thread, including
# turns from earlier runs on the same thread_id, not only this turn.
for msg in ai_events["messages"]:
    msg.pretty_print()