In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment before any
# API client is constructed (presumably OPENAI_API_KEY for the OpenAI clients below — confirm).
# As the cell's last expression, the returned bool is shown as the cell output.
load_dotenv()

In [None]:
from langchain_openai import ChatOpenAI

# Model names kept in one place so they are easy to swap.
MAIN_MODEL_NAME = 'gpt-4o'
SMALL_MODEL_NAME = 'gpt-4o-mini'

# `llm` drives the agent; `small_llm` is a cheaper model kept alongside it
# (NOTE(review): not referenced anywhere in the visible notebook).
llm = ChatOpenAI(model=MAIN_MODEL_NAME)
small_llm = ChatOpenAI(model=SMALL_MODEL_NAME)

In [None]:
from langchain_core.tools import tool

@tool
def add(a: int, b: int) -> int:
    """숫자 a와 b를 더합니다."""
    # The docstring above ("Adds numbers a and b.") is what @tool exposes to the
    # model as the tool description, so it is intentionally kept verbatim.
    total = a + b
    return total

@tool
def multiply(a: int, b: int) -> int:
    """숫자 a와 b를 곱합니다."""
    # The docstring above ("Multiplies numbers a and b.") is what @tool exposes to
    # the model as the tool description, so it is intentionally kept verbatim.
    product = a * b
    return product

In [None]:
from langchain_community.tools import DuckDuckGoSearchRun

# Web-search tool (DuckDuckGo) the agent can call; it is added to `tool_list` further down.
search_tool = DuckDuckGoSearchRun()

### 구글 메일 발송 Tools

In [None]:
# from langchain_google_community import GmailToolkit
# from langchain_google_community.gmail.utils import (
#     build_resource_service,
#     get_gmail_credentials,
# )

# # Can review scopes here https://developers.google.com/gmail/api/auth/scopes
# # For instance, readonly scope is 'https://www.googleapis.com/auth/gmail.readonly'
# credentials = get_gmail_credentials(
#     token_file="./google/token.json",
#     scopes=["https://mail.google.com/"],
#     client_secrets_file="./google/credentials.json",
# )
# api_resource = build_resource_service(credentials=credentials)
# gmail_toolkit = GmailToolkit(api_resource=api_resource)
# gmail_tool_list = gmail_toolkit.get_tools()

### ArXiv Tools

In [None]:
# `load_tools` lives in langchain-community; importing it through `langchain.agents`
# only works via a deprecated re-export shim that newer langchain releases drop.
from langchain_community.agent_toolkits.load_tools import load_tools

# Load the built-in ArXiv search tool(s) by name.
# NOTE(review): the ArXiv tool presumably needs the `arxiv` pip package installed — confirm.
loaded_tool_list = load_tools(
    ["arxiv"],
)

### Retriever Tools

In [None]:
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.tools.retriever import create_retriever_tool

# Embedding model for queries against the persisted collection.
# NOTE(review): presumably must match the model used when the collection was built — confirm.
embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

# Open the on-disk Chroma collection holding the real-estate-tax documents.
vector_store = Chroma(
    embedding_function=embedding_function,
    collection_name='real_estate_tax',
    persist_directory='./real_estate_tax_collection',
)
# If the directory/collection does not exist yet, Chroma quietly creates an empty one and the
# retriever tool would just return no documents. Fail fast with an actionable message instead.
assert vector_store.get(limit=1)["ids"], (
    "Chroma collection 'real_estate_tax' in ./real_estate_tax_collection is empty; "
    "build/populate the vector store before running this notebook."
)

# Top-3 similarity retriever wrapped as a tool the agent can call by name.
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="real_estate_tax_retriever",
    description="Contains information about real estate tax up to December 2024",
)

In [None]:
from langgraph.prebuilt import ToolNode

# Every tool the agent may call, in a fixed order: arithmetic tools, web search,
# the ArXiv tools loaded above, then the real-estate-tax retriever.
# (gmail_tool_list could be spliced in here once the Gmail cell is enabled.)
tool_list = [add, multiply, search_tool, *loaded_tool_list, retriever_tool]

# The same list drives both the schema the model sees and the node that executes
# the calls, so the two cannot drift apart.
llm_with_tools = llm.bind_tools(tool_list)
tool_node = ToolNode(tool_list)

In [None]:
# multiply.invoke({"a": 3, "b": 5})

In [None]:
# Smoke test of the tool-bound model: for this prompt the reply is expected (not guaranteed)
# to contain a tool call to `add` rather than a plain-text answer.
ai_message = llm_with_tools.invoke("What is 3 plus 5?")
ai_message

In [None]:
# ToolNode takes a state dict whose "messages" value is a list[AnyMessage]; the last entry
# must be an AIMessage carrying tool_calls, which the node then executes.
tool_input_state = {"messages": [ai_message]}
tool_node.invoke(tool_input_state)

In [None]:
from langgraph.graph import MessagesState, StateGraph

# Graph builder whose state is MessagesState (a 'messages' list that node outputs append to).
# Re-run this cell before re-running the add_node/add_edge cells below.
graph_builder = StateGraph(MessagesState)

In [None]:
def agent(state: MessagesState):
    """Agent node: send the whole conversation to the tool-bound LLM.

    Returns a state update holding only the model's reply; MessagesState appends
    it to the existing message list.
    """
    reply = llm_with_tools.invoke(state['messages'])
    return {'messages': [reply]}

In [None]:
from langgraph.graph import END

def should_continue(state: MessagesState):
    """Route after the agent node: to 'tools' if the model requested tool calls, else END.

    Hand-written alternative to langgraph's prebuilt ``tools_condition`` (see the
    commented-out ``add_conditional_edges`` call in the edge-wiring cell).

    Args:
        state: Graph state whose 'messages' list holds the conversation so far.

    Returns:
        'tools' when the last message has non-empty ``tool_calls``, otherwise ``END``.

    Raises:
        ValueError: if the state contains no messages.
    """
    messages = state['messages']
    if not messages:
        raise ValueError(f"No messages found in state: {state}")
    last_message = messages[-1]
    # Only AI messages carry `tool_calls`; other message types (e.g. a HumanMessage)
    # have no such attribute, so a missing attribute means "no tool calls".
    if getattr(last_message, 'tool_calls', None):
        return 'tools'
    return END

In [None]:
# Register the LLM step ('agent') and the tool-execution step ('tools') as graph nodes.
# NOTE(review): node names must be unique per builder, so re-running this cell presumably
# fails unless the StateGraph cell above is re-run first.
graph_builder.add_node('agent', agent)
graph_builder.add_node('tools', tool_node)

In [None]:
# END is already imported in the should_continue cell; re-importing keeps this cell self-contained.
from langgraph.graph import START, END
from langgraph.prebuilt import tools_condition

# Every run starts at the agent node.
graph_builder.add_edge(START, 'agent')
# Hand-written routing alternative, kept for reference (equivalent to tools_condition below):
# graph_builder.add_conditional_edges(
#     'agent',
#     should_continue,
#     ['tools', END],
# )
# After the agent, the prebuilt tools_condition routes to 'tools' when the model asked
# for tool calls and to END otherwise.
graph_builder.add_conditional_edges(
    'agent',
    tools_condition,
)
# Tool results are always fed back to the agent, forming the agent <-> tools loop.
graph_builder.add_edge('tools', 'agent')

In [None]:
# Compile the builder into a runnable graph (used for rendering and streaming below).
graph = graph_builder.compile()

In [None]:
# %%capture --no-stderr

In [None]:
from IPython.display import Image, display

# Render the compiled graph's topology as a Mermaid PNG inline.
# NOTE(review): by default draw_mermaid_png may call an external rendering service — confirm offline needs.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [None]:
from langchain_core.messages import HumanMessage

# Alternative prompts exercising the web-search and ArXiv tools:
# query = "What currency is in Billy Giles' birthplace?"
# query = "Attention Is All You Need라는 논문을 요약해서 설명해줘."
# Active prompt (Korean): "Can you calculate the comprehensive real estate holding tax
# for a house worth 1.5 billion won?" — exercises the retriever tool.
query = "집이 15억일 때 종합부동산세를 계산해줄 수 있나요?"

# stream_mode='values' yields the full state after each step; print only the newest message.
initial_state = {'messages': [HumanMessage(query)]}
for state_snapshot in graph.stream(initial_state, stream_mode='values'):
    newest_message = state_snapshot['messages'][-1]
    newest_message.pretty_print()