In [None]:
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_google_vertexai.model_garden import ChatAnthropicVertex


def should_continue(state: MessagesState):
    """Pick the next hop after the chatbot node.

    Returns "tools" when the most recent message carries tool calls,
    otherwise END.
    """
    newest = state["messages"][-1]
    # Falsy tool_calls (e.g. an empty list): nothing to execute, so stop.
    if not newest.tool_calls:
        return END
    return "tools"


@tool
def get_weather(location: str):
    """Call to get the current weather."""
    # NOTE: the docstring above is kept verbatim -- `@tool` presumably exposes
    # it to the model as the tool description, so rewording it would change
    # what the model sees.
    # Canned demo data: San Francisco (matched case-insensitively) is foggy,
    # every other location is sunny.
    is_san_francisco = location.lower() in ("sf", "san francisco")
    return (
        "It's 60 degrees and foggy."
        if is_san_francisco
        else "It's 90 degrees and sunny."
    )


@tool
def get_coolest_cities():
    """Get a list of coolest cities"""
    # NOTE: docstring kept verbatim -- `@tool` presumably uses it as the tool
    # description shown to the model.
    # Hard-coded demo answer, returned as one comma-separated string.
    cities = ("nyc", "sf")
    return ", ".join(cities)


# The same tool list is handed to the ToolNode here and to `bind_tools` below,
# so the tools offered to the model and the tools the "tools" node knows about
# stay in sync.
tools = [get_weather, get_coolest_cities]
tool_node = ToolNode(tools)


# Graph builder using MessagesState as its state schema; the node functions in
# this file read the conversation from state["messages"].
graph_builder = StateGraph(MessagesState)

# Alternative model kept for reference. NOTE(review): ChatVertexAI is not
# imported in the visible code, so enabling this line also needs its import.
# llm = ChatVertexAI(model="gemini-2.0-flash-001", temperature=0).bind_tools(tools)
llm = ChatAnthropicVertex(
    model="claude-3-5-sonnet-v2@20241022", location="us-east5", temperature=0
).bind_tools(tools)


def chatbot(state: MessagesState):
    """Invoke the tool-bound model on the conversation so far.

    Returns a state update whose "messages" list holds the single model reply.
    """
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}


# The first argument is the unique node name.
# The second argument is the function or object that will be called whenever
# the node is used.
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)
# Entry point: START leads straight to the chatbot node.
graph_builder.add_edge(START, "chatbot")
# After the chatbot, should_continue chooses the next hop: "tools" when the
# last message carries tool calls, otherwise END.
graph_builder.add_conditional_edges("chatbot", should_continue, ["tools", END])
# Edge from the tools node back to the chatbot, closing the
# chatbot -> tools -> chatbot loop.
graph_builder.add_edge("tools", "chatbot")


<langgraph.graph.state.StateGraph at 0x14780d940>

In [2]:
from langgraph.checkpoint.memory import MemorySaver

# Compile the builder into the runnable `graph`, attaching a MemorySaver as the
# checkpointer. NOTE(review): MemorySaver presumably keeps checkpoints in
# process memory only (lost on restart) -- confirm against the LangGraph docs.
graph = graph_builder.compile(checkpointer=MemorySaver())

In [20]:
from uuid import uuid4
from rich import print
from langchain_core.messages.utils import convert_to_openai_messages

# Fresh thread id for this run; it is passed to the graph through the
# "configurable" section of the run config.
thread_id = str(uuid4())

inputs = {
    "messages": [("human", "What are the coolest cities? and the weather in there?")]
}
config = {"configurable": {"thread_id": thread_id}}

# Stream (message, metadata) pairs and print each message both in its native
# form and converted to the OpenAI message format.
for msg, _metadata in graph.stream(inputs, config, stream_mode="messages"):
    print(msg)
    try:
        # Only the conversion is guarded, so errors raised while printing are
        # not misreported as conversion failures.
        openai_msg = convert_to_openai_messages(msg)
    except Exception as exc:
        # The original bare `except:` with a lone "Failed" string literal was a
        # no-op that also swallowed KeyboardInterrupt. Report the failure and
        # keep streaming. Only the exception type is printed so that rich's
        # console markup cannot be triggered by brackets in the message.
        print(f"Failed to convert message ({type(exc).__name__})")
    else:
        print(openai_msg)
