In [1]:
from typing import TypedDict, Annotated
from langgraph.graph import add_messages, StateGraph, END
from langchain_groq import ChatGroq
from langchain_core.messages import AIMessage, HumanMessage
from dotenv import load_dotenv
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.prebuilt import ToolNode

# Load variables from a local .env file into the process environment —
# presumably the Groq / Tavily API keys that ChatGroq and TavilySearchResults
# read below (TODO confirm the exact variable names in .env).
load_dotenv()

True

In [2]:
# Graph node names, shared by node registration, the router, and the edges.
CHATBOT, TOOL_NODE = "chatbot", "tool_node"

In [3]:
# Chat model served by Groq.
llm = ChatGroq(model="llama-3.1-8b-instant")

# Web-search tool, capped at two results per query; it is the only tool
# the model is allowed to call.
search_tool = TavilySearchResults(max_results=2)
tools = [search_tool]

# Model variant that knows the tool schemas and can emit tool calls.
llm_with_tools = llm.bind_tools(tools=tools)

In [4]:
class BasicChatBot(TypedDict):
    """State schema for the chatbot graph: the running list of messages.

    The ``add_messages`` reducer in the annotation makes LangGraph merge the
    messages a node returns into the existing list instead of replacing it.
    """

    # Conversation history. The annotation is a plain ``list``; presumably it
    # holds LangChain message objects (Human/AI/Tool messages) — TODO confirm.
    messages: Annotated[list, add_messages]

def chatbot(state: BasicChatBot):
    """Graph node: let the tool-enabled LLM respond to the conversation.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update containing the model's single reply message,
        which the ``add_messages`` reducer on ``BasicChatBot.messages`` merges
        into the history.
    """
    conversation = state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

In [5]:
def tools_router(state: BasicChatBot):
    """Conditional-edge router: send the graph to the tool node or finish.

    Args:
        state: Current graph state; ``state["messages"]`` is inspected.

    Returns:
        ``TOOL_NODE`` when the most recent message requests at least one tool
        call, otherwise ``END``.
    """
    messages = state["messages"]
    if not messages:
        # Nothing to inspect, so there is no tool request to act on.
        return END

    # Only AI messages expose ``tool_calls``; other message types lack the
    # attribute. A missing attribute, None, or an empty list all mean
    # "no tool calls" (the old len() check crashed on None).
    tool_calls = getattr(messages[-1], "tool_calls", None)
    return TOOL_NODE if tool_calls else END

In [6]:
# ToolNode executes the tool calls found in the latest AIMessage and adds
# their results to the history; the tool arguments come from the model's
# tool calls, so the query never has to be passed in manually.
tool_node = ToolNode(tools=tools, messages_key="messages")

graph = StateGraph(BasicChatBot)

graph.add_node(CHATBOT, chatbot)
graph.add_node(TOOL_NODE, tool_node)
graph.set_entry_point(CHATBOT)

# The explicit path_map declares the only two destinations tools_router can
# return. Without it, LangGraph assumes the branch may jump to any node
# (e.g. when drawing the graph).
graph.add_conditional_edges(
    CHATBOT,
    tools_router,
    {TOOL_NODE: TOOL_NODE, END: END},
)
# After the tools run, hand their results back to the chatbot.
graph.add_edge(TOOL_NODE, CHATBOT)

app = graph.compile()

In [7]:
# Interactive loop: type "end" or "exit" (any case, surrounding spaces
# ignored) to stop.
# NOTE(review): each turn invokes the graph with only the new message, so the
# bot keeps no memory of earlier turns. Carry result["messages"] forward or
# compile with a checkpointer if multi-turn memory is wanted.
while True:
    user_input = input("User: ").strip()
    if not user_input:
        # Skip blank lines instead of spending an LLM call on empty content.
        continue
    if user_input.lower() in {"end", "exit"}:
        break
    result = app.invoke({"messages": [HumanMessage(content=user_input)]})
    print(result)

{'messages': [HumanMessage(content="Hi, I'm Sayak", additional_kwargs={}, response_metadata={}, id='1660fb71-f97a-4548-ac49-f5ca80d5abe4'), AIMessage(content="It's nice to meet you, Sayak. Is there anything I can help you with today?", additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 21, 'prompt_tokens': 284, 'total_tokens': 305, 'completion_time': 0.031554253, 'prompt_time': 0.020503523, 'queue_time': 0.005677188, 'total_time': 0.052057776}, 'model_name': 'llama-3.1-8b-instant', 'system_fingerprint': 'fp_d834565e05', 'service_tier': 'on_demand', 'finish_reason': 'stop', 'logprobs': None, 'model_provider': 'groq'}, id='lc_run--2379e715-0c7b-496c-9b7f-d8a6bc0cc0d2-0', usage_metadata={'input_tokens': 284, 'output_tokens': 21, 'total_tokens': 305})]}
