In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
# (presumably OPENAI_API_KEY for the ChatOpenAI clients below — confirm).
load_dotenv()

In [None]:
from langchain_openai import ChatOpenAI

# Lighter companion model; not referenced by the cells shown in this excerpt.
small_llm = ChatOpenAI(model="gpt-4o-mini")
# Main model: this is the one bound to the tools and driving the agent node.
llm = ChatOpenAI(model="gpt-4o")

In [None]:
from langchain_core.tools import tool

# Two toy arithmetic tools exposed to the model.
# NOTE: the Korean docstrings are intentionally left untouched: LangChain's @tool
# uses the docstring as the tool description sent to the model, so editing it
# changes runtime behavior. English translations are given as comments.

@tool
def add(a: int, b: int) -> int:
    """숫자 a와 b를 더합니다."""
    # Tool description (translated): "Adds numbers a and b."
    return a + b

@tool
def multiply(a: int, b: int) -> int:
    """숫자 a와 b를 곱합니다."""
    # Tool description (translated): "Multiplies numbers a and b."
    return a * b

In [None]:
# To try a tool directly, call .invoke() with a dict of its arguments:
# add.invoke({"a": 1, "b": 2})

In [None]:
from langgraph.prebuilt import ToolNode

# A single shared list keeps the schemas advertised to the model and the
# executor node in sync.
tool_list = [add, multiply]
tool_node = ToolNode(tool_list)              # graph node that executes tool calls
llm_with_tools = llm.bind_tools(tool_list)   # model variant that may emit tool calls

In [None]:
multiply.invoke(dict(a=3, b=5))  # call the tool directly; expect 15

In [None]:
# Ask the tool-bound model a question; the reply is expected to contain
# tool_calls (a request to run `add`) rather than a final answer — inspect below.
ai_message = llm_with_tools.invoke("What is 3 plus 5?")
ai_message

In [None]:
# Run the tool node by hand. Input: {"messages": list[AnyMessage]}, where the last
# message must be an AIMessage that contains tool_calls.
tool_node.invoke({"messages": [ai_message]}) # last message = AIMessage with tool_calls

In [None]:
from langgraph.graph import MessagesState, StateGraph

# Builder for a graph whose shared state is MessagesState, i.e. a 'messages'
# list of chat messages that the nodes below read and append to.
graph_builder = StateGraph(MessagesState)

In [None]:
def agent(state: MessagesState):
    """Agent node: send the whole conversation to the tool-bound model.

    Returns a partial state update holding only the model's reply; the graph
    merges it into the existing 'messages' history.
    """
    reply = llm_with_tools.invoke(state['messages'])
    return {'messages': [reply]}

In [None]:
from langgraph.graph import END

def should_continue(state: MessagesState):
    """Routing function evaluated after the 'agent' node.

    Returns:
        'tools' if the most recent message requests one or more tool calls,
        otherwise END to finish the run.
    """
    messages = state['messages']
    if not messages:
        # Nothing to inspect (not expected right after the agent node) — stop.
        return END
    last_message = messages[-1]
    # Only AI messages carry `tool_calls`; getattr lets any other message type
    # (human/tool) route to END instead of raising AttributeError.
    if getattr(last_message, 'tool_calls', None):
        return 'tools'
    return END

In [None]:
# Register the nodes. 'agent' calls the model; 'tools' executes the tool calls it
# emits. The name 'tools' must match the value returned by should_continue.
graph_builder.add_node('agent', agent)
graph_builder.add_node('tools', tool_node)

In [None]:
from langgraph.graph import START, END  # END re-imported so this cell is self-contained

# Wire the loop: START -> agent; agent -> tools or END (chosen by should_continue);
# tools -> agent, so the model sees the tool results and can answer or call again.
graph_builder.add_edge(START, 'agent')
graph_builder.add_conditional_edges(
    'agent',
    should_continue,
    ['tools', END],  # every destination should_continue can return
)
graph_builder.add_edge('tools', 'agent')

In [None]:
# Turn the builder into a runnable graph (used below via get_graph() and stream()).
graph = graph_builder.compile()

In [None]:
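# %%capture --no-stderr
# (Disabled: a %%capture cell magic only affects the cell whose first line it is,
#  so as a commented-out line in its own cell it has no effect.)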

In [None]:
from IPython.display import Image, display

# Visualize the compiled graph (agent <-> tools loop).
graph_view = graph.get_graph()
try:
    # NOTE: per the LangChain docs, draw_mermaid_png() renders through the remote
    # mermaid.ink API by default, so it can fail offline or behind a proxy.
    display(Image(graph_view.draw_mermaid_png()))
except Exception as exc:  # best-effort: the diagram is illustrative, not required
    print(f"Could not render graph PNG ({exc!r}); Mermaid source follows:")
    print(graph_view.draw_mermaid())

In [None]:
from langchain_core.messages import HumanMessage

# Question (translated): "Add 5 to 3, then multiply that by 8 — what is the result?"
# stream_mode='values' yields the full state after each step; print the newest
# message of each snapshot so every agent/tool step is shown in turn.
initial_state = {'messages': [HumanMessage("3에다 5를 더하고 거기에 8을 곱하면?")]}
for chunk in graph.stream(initial_state, stream_mode='values'):
    newest = chunk['messages'][-1]
    newest.pretty_print()