Date: 11 Nov 2024

In [None]:
import sys

# Make the parent directory importable so the `src.*` imports below resolve.
# The membership guard keeps this cell idempotent: re-running it does not keep
# appending duplicate ".." entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from dotenv import load_dotenv

import warnings

# Load settings (e.g. API keys) from the project-level .env file.
# load_dotenv returns a falsy value when it loaded no variables (for example when
# ../.env is missing). Warn explicitly instead of failing silently; the original
# code also bound the result to `_`, which shadows IPython's last-output variable.
if not load_dotenv("../.env"):
    warnings.warn("No environment variables were loaded from ../.env", stacklevel=1)

In [None]:
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic

from src.tools.glad.weekly_alerts_tool import glad_weekly_alerts_tool
from src.tools.location.tool import location_tool

In [None]:
# List the models available to the local Ollama server, e.g. to confirm that
# `qwen2.5:latest` (used by ChatOllama below) is present.
!ollama list

In [None]:
# NOTE(review): exact duplicate of the previous cell; it can be deleted.
!ollama list

In [None]:
# Tools the model may call; the ToolNode built further down executes the calls it requests.
tools = [location_tool, glad_weekly_alerts_tool]

# Provider switch (replaces the old commented-out line): set to True to use
# Anthropic instead of the local Ollama model.
# NOTE(review): ChatAnthropic presumably needs an Anthropic API key in the
# environment (e.g. via ../.env) — confirm before switching.
USE_ANTHROPIC = False
if USE_ANTHROPIC:
    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", temperature=0)
else:
    llm = ChatOllama(model="qwen2.5:latest", temperature=0)

# Bind the tool definitions to the model so it can emit tool calls for them.
llm_with_tools = llm.bind_tools(tools)

In [None]:
from IPython.display import Image, display
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

# System prompt prepended to every assistant turn.
# Tool names are interpolated from the tool objects themselves, so the prompt always
# matches the names the model sees via bind_tools. The previous text referred to
# the same tool as both `location-tool` and `location_tool`.
# Assumes both tools are LangChain tool objects (e.g. created with @tool) exposing `.name`.
# NOTE(review): GLAD alerts are generally tree-cover-loss alerts rather than fire
# alerts — confirm what glad_weekly_alerts_tool returns and adjust the wording if so.
sys_msg = SystemMessage(
    content=f"""You are a helpful assistant tasked with answering user queries for the WRI data API.
Use the `{location_tool.name}` to get iso, adm1 & adm2 of any region or place.
Use the `{glad_weekly_alerts_tool.name}` to get forest fire information for a particular year. Think through the solution step-by-step first and then execute.

For example: If the query is "Find forest fires in Milan for the year 2024"
Steps
1. Use the `{location_tool.name}` to get iso, adm1, adm2 for place `Milan` by passing `query=Milan`
2. Pass iso, adm1, adm2 along with year `2024` as args to `{glad_weekly_alerts_tool.name}` to get information about forest fire alerts.
"""
)

In [None]:
def assistant(state: MessagesState):
    """Graph node for one model turn.

    Prepends the system prompt to the conversation held in ``state["messages"]``,
    calls the tool-bound model once, and returns a state update whose
    ``"messages"`` list contains only that model reply.
    """
    conversation = [sys_msg, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

In [None]:
# Build the agent graph over the shared MessagesState.
builder = StateGraph(MessagesState)

# Nodes: the model turn, and a prebuilt node that runs the requested tool calls.
for node_name, node_impl in (("assistant", assistant), ("tools", ToolNode(tools))):
    builder.add_node(node_name, node_impl)

# Edges: start at the assistant; the prebuilt `tools_condition` decides whether the
# assistant's output goes to "tools" or finishes; tool results feed back to the assistant.
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
react_graph = builder.compile()

# Render the compiled graph as a Mermaid PNG inline.
graph_png = react_graph.get_graph(xray=True).draw_mermaid_png()
display(Image(graph_png))

In [None]:
# Run one end-to-end query through the agent graph.
# `messages` ends up holding the graph's final state (indexed by "messages" below).
user_query = HumanMessage(content="find forest fires in Ihorombe for the year 2021")
messages = react_graph.invoke({"messages": [user_query]})

In [None]:
# Pretty-print each message of the final conversation in order.
for message in messages["messages"]:
    message.pretty_print()