In [24]:
from logging import getLogger

from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import RemoveMessage
from langchain_core.tools.structured import StructuredTool
from langchain_ollama import ChatOllama
from langgraph.prebuilt import ToolNode

from local_llm_tools.langfamily_agent.build_graph import build_graph
from local_llm_tools.langfamily_agent.utils import get_role_of_message
from local_llm_tools.tools import MATH_TOOLS, MATH_TOOLS_DS

logger = getLogger(__name__)


class ChatBot:
    """Chat wrapper around a LangGraph agent that talks to an Ollama model.

    Conversation history is not kept on this object: it is read and modified
    through the agent's ``get_state`` / ``update_state`` using the ``config``
    argument (e.g. ``{"configurable": {"thread_id": ...}}``) passed to each method.
    """

    def __init__(
        self,
        model_name: str,
        tools: list[StructuredTool],
        params: dict | None = None,
    ):
        """
        Args:
            model_name: Ollama model name passed to ``ChatOllama(model=...)``.
            tools: Tools bound to the model and executed by the graph's ``ToolNode``.
            params: Extra keyword arguments forwarded to ``ChatOllama`` in ``build()``.
                Copied (shallowly) into ``self.params``.
        """
        self.model_name = model_name
        # NOTE(review): `messages` / `messages_model` are not used by any method of
        # this class; kept in case external callers read them.
        self.messages: list[AIMessage | HumanMessage | SystemMessage] = []
        self.messages_model: list[str | None] = []

        self.params: dict = {}
        if params is not None:
            self.params.update(params)

        self.tools = tools

        # Compiled agent graph; stays None until build() is called.
        self._agent = None

    @property
    def agent(self):
        """The compiled agent graph.

        Raises:
            ValueError: if ``build()`` has not been called yet.
        """
        if not self.is_build():
            raise ValueError("graph is not built.")
        return self._agent

    def is_build(self):
        """Return True once ``build()`` has produced an agent."""
        return self._agent is not None

    def set_params(self, **kwargs):
        """Update the ChatOllama parameters.

        New values only take effect on the next ``build()`` / ``reset_message()``.

        Raises:
            ValueError: if called without any keyword arguments.
        """
        if not kwargs:
            raise ValueError("One or more parameters are required.")

        self.params.update(kwargs)

    def build(self):
        """(Re)create the LLM with the tools bound and compile a new agent graph."""
        # NOTE(review): `build_graph` is resolved as a module global at call time, and
        # later cells in this notebook redefine it (one version takes only `llm`).
        llm = ChatOllama(model=self.model_name, **self.params, stream=True)
        llm = llm.bind_tools(self.tools)
        self._agent = build_graph(llm, ToolNode(self.tools))

    def chat_stream(
        self,
        user_input: str,
        config: dict,
        system_promt: str | list[str] | None = None,
    ):
        """Send one user turn to the agent and yield the content as it streams.

        Args:
            user_input: The user's message.
            config: LangGraph run config selecting the conversation thread.
            system_promt: Optional system-message content sent before the user
                message for this turn (parameter name kept for compatibility).

        Yields:
            The ``content`` of each message chunk emitted with ``stream_mode="messages"``.

        Raises:
            ValueError: if the agent has not been built (raised on first iteration).
        """
        messages = []
        if system_promt is not None:
            messages.append({"role": "system", "content": system_promt})
        messages.append({"role": "user", "content": user_input})

        # In "messages" mode each event is a (message_chunk, metadata) pair.
        for message_chunk, _metadata in self.agent.stream(
            {"messages": messages},
            config,
            stream_mode="messages",
        ):
            yield message_chunk.content

    def delete_messages(self, message_idx: int, config: dict):
        """Delete every stored message from position ``message_idx`` onward.

        Messages before ``message_idx`` are kept; the message at ``message_idx`` and
        everything after it are removed (slice semantics, so a negative index counts
        from the end).
        """
        stored_messages = self.agent.get_state(config).values["messages"]
        self.agent.update_state(
            config,
            {
                "messages": [
                    RemoveMessage(id=msg.id) for msg in stored_messages[message_idx:]
                ]
            },
        )

    def reset_message(self):
        """Reset the conversation by rebuilding the agent.

        Assumes ``build_graph`` creates a fresh checkpointer on every call (the
        notebook version does, via ``MemorySaver()``), which is what drops history.
        """
        self.build()

    def history(self, config):
        """Yield ``(message, model_name, role)`` for each message in the thread.

        ``model_name`` is ``response_metadata["model"]``, or None when absent
        (e.g. user messages have empty response metadata).

        Raises:
            ValueError: if the agent has not been built (raised on first iteration).
        """
        for msg in self.agent.get_state(config).values["messages"]:
            yield msg, msg.response_metadata.get("model", None), get_role_of_message(msg)

In [2]:
from logging import getLogger
from typing import Literal

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode


def build_graph(llm, tool_node: ToolNode):
    """Compile a chat <-> tools agent loop with in-memory checkpointing.

    Flow: START -> chat -> (tools -> chat)* -> END. The chat node loops through
    the tool node for as long as the model's reply requests tool calls.

    Args:
        llm: Chat model (expected to have tools bound) invoked on the message list.
        tool_node: Node that executes the tool calls requested by ``llm``.

    Returns:
        The compiled graph, checkpointed by a fresh ``MemorySaver``.
    """

    def chat(state: MessagesState):
        logger.debug("Called chat node")
        return {"messages": [llm.invoke(state["messages"])]}

    def route_after_chat(state: MessagesState) -> Literal["tools", END]:
        # Self-contained router: the `should_continue` defined later in this notebook
        # is a Command-returning node, not a conditional-edge function.
        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            return "tools"
        return END

    graph_builder = StateGraph(MessagesState)

    # Nodes
    graph_builder.add_node("chat", chat)
    graph_builder.add_node("tools", tool_node)

    # Edges: termination is decided by route_after_chat.
    graph_builder.add_edge(START, "chat")
    graph_builder.add_conditional_edges("chat", route_after_chat)
    graph_builder.add_edge("tools", "chat")

    # Memory
    memory = MemorySaver()
    return graph_builder.compile(checkpointer=memory)

In [28]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import render_text_description

# Tool set for the prompt-based tool-calling experiments; both `invoke_tool`
# definitions below look tools up in this global by name.
tools = MATH_TOOLS


# Plain-text "signature - description" listing of each tool, embedded in the system prompt.
rendered_tools = render_text_description(tools)
print(rendered_tools)

add(a: int | float, b: int | float) -> int | float - 足し算を行う関数

Args:
    a (int | float): 足し算を行う1つ目の値
    b (int | float): 足し算を行う2つ目の値

Returns:
    int | float: 足し算の結果
minus(a: int | float, b: int | float) -> int | float - 引き算を行う関数

Args:
    a (int | float): 引き算を行う1つ目の値
    b (int | float): 引き算を行う2つ目の値

Returns:
    int | float: 引き算の結果
multiply(a: int | float, b: int | float) -> int | float - 掛け算を行う関数

Args:
    a (int | float): 掛け算を行う1つ目の値
    b (int | float): 掛け算を行う2つ目の値

Returns:
    int | float: 掛け算の結果
divide(a: int | float, b: int | float) -> int | float - 割り算を行う関数

Args:
    a (int | float): 割り算を行う1つ目の値
    b (int | float): 割り算を行う2つ目の値

Returns:
    int | float: 割り算の結果


In [29]:
# Instructions for prompt-based tool calling: the model is shown a plain-text listing
# of the tools and must reply with a JSON blob naming one of them. Written as adjacent
# string literals so the trailing spaces inside the prompt are explicit.
# NOTE(review): this string becomes a ChatPromptTemplate, so any "{" or "}" inside
# rendered_tools would be parsed as template variables.
system_prompt = (
    "You are an assistant that has access to the following set of tools. \n"
    "Here are the names and descriptions for each tool:\n"
    "\n"
    f"{rendered_tools}\n"
    "\n"
    "Given the user input, return the name and input of the tool to use. \n"
    "Return your response as a JSON blob with 'name' and 'arguments' keys.\n"
    "\n"
    "The `arguments` should be a dictionary, with keys corresponding \n"
    "to the argument names and the values corresponding to the requested values.\n"
)

# Fixed system instructions followed by the user's request in the `{input}` slot.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("user", "{input}"),
    ]
)

In [33]:
# Inspect the system prompt with the rendered tool descriptions filled in.
print(system_prompt)

You are an assistant that has access to the following set of tools. 
Here are the names and descriptions for each tool:

add(a: int | float, b: int | float) -> int | float - 足し算を行う関数

Args:
    a (int | float): 足し算を行う1つ目の値
    b (int | float): 足し算を行う2つ目の値

Returns:
    int | float: 足し算の結果
minus(a: int | float, b: int | float) -> int | float - 引き算を行う関数

Args:
    a (int | float): 引き算を行う1つ目の値
    b (int | float): 引き算を行う2つ目の値

Returns:
    int | float: 引き算の結果
multiply(a: int | float, b: int | float) -> int | float - 掛け算を行う関数

Args:
    a (int | float): 掛け算を行う1つ目の値
    b (int | float): 掛け算を行う2つ目の値

Returns:
    int | float: 掛け算の結果
divide(a: int | float, b: int | float) -> int | float - 割り算を行う関数

Args:
    a (int | float): 割り算を行う1つ目の値
    b (int | float): 割り算を行う2つ目の値

Returns:
    int | float: 割り算の結果

Given the user input, return the name and input of the tool to use. 
Return your response as a JSON blob with 'name' and 'arguments' keys.

The `arguments` should be a dictionary, with keys co

In [75]:
# Local Ollama chat model used by the experiments below (no tools bound);
# temperature=0 to reduce sampling randomness.
model = ChatOllama(model="gemma3:4b-it-fp16", temperature=0)

In [40]:
# Prompt-based tool selection: the model replies with a JSON blob that
# JsonOutputParser turns into a {"name": ..., "arguments": ...} dict.
chain = prompt | model | JsonOutputParser()
output = chain.invoke({"input": "what's thirteen times 4"})
print(output)

{'name': 'multiply', 'arguments': {'a': 13.0, 'b': 4.0}}


In [13]:
from typing import Any, Dict, Optional, TypedDict

from langchain_core.runnables import RunnableConfig


class ToolCallRequest(TypedDict):
    """Tool selection parsed from the model's JSON reply.

    ``name`` is the tool to call and ``arguments`` holds its keyword arguments.
    Used as the parameter type of the dict-based ``invoke_tool``.
    """

    name: str
    arguments: dict[str, Any]


def invoke_tool(
    tool_call_request: ToolCallRequest, config: Optional[RunnableConfig] = None
):
    """Invoke the tool named in ``tool_call_request`` with its arguments.

    Tools are looked up by name in the module-level ``tools`` list.

    Args:
        tool_call_request: a dict with keys ``name`` and ``arguments``. ``name`` must
            match one of ``tools``; ``arguments`` is passed to that tool's ``invoke``.
        config: LangChain runnable configuration (callbacks, metadata, etc.; see the
            LCEL documentation about ``RunnableConfig``). Forwarded to the tool.

    Returns:
        The output of the requested tool.

    Raises:
        KeyError: if ``name`` does not match any tool; the message lists the
            available tool names so a hallucinated name is easy to diagnose.
    """
    tool_name_to_tool = {tool.name: tool for tool in tools}
    name = tool_call_request["name"]
    if name not in tool_name_to_tool:
        raise KeyError(
            f"Unknown tool {name!r}; available tools: {sorted(tool_name_to_tool)}"
        )
    requested_tool = tool_name_to_tool[name]
    return requested_tool.invoke(tool_call_request["arguments"], config=config)

In [14]:
# Execute the tool call parsed by the previous cell.
invoke_tool(output)

[DEBUG] local_llm_tools.langfamily_agent.tools.math 2025-03-13 22:05:17,726 - math.py: 19: Called multiply tool


52

In [15]:
# Same chain, now ending in invoke_tool so the selected tool is executed directly.
chain = prompt | model | JsonOutputParser() | invoke_tool
chain.invoke({"input": "what's thirteen times 4.14137281"})

[DEBUG] local_llm_tools.langfamily_agent.tools.math 2025-03-13 22:06:07,814 - math.py: 19: Called multiply tool


53.83784653

In [16]:
from langchain_core.runnables import RunnablePassthrough

# RunnablePassthrough.assign keeps the parsed tool call and adds the tool's
# result under the "output" key.
chain = (
    prompt | model | JsonOutputParser() | RunnablePassthrough.assign(output=invoke_tool)
)
chain.invoke({"input": "what's thirteen times 4.14137281"})

[DEBUG] local_llm_tools.langfamily_agent.tools.math 2025-03-13 22:07:50,748 - math.py: 19: Called multiply tool


{'name': 'multiply',
 'arguments': {'a': 13, 'b': 4.14137281},
 'output': 53.83784653}

In [143]:
# Name of the Ollama model backing `model`.
model.model

'gemma3:4b-it-fp16'

In [135]:
# Redefinition of system_prompt that adds an "unknown" escape hatch: should_continue
# reads this global at call time and falls back to plain chat when the selector
# answers with name 'unknown'. (Avoid braces here: it is used as a ChatPromptTemplate.)
system_prompt = f"""\
You are an assistant that has access to the following set of tools. 
Here are the names and descriptions for each tool:

{rendered_tools}

Given the user input, return the name and input of the tool to use. 
Return your response as a JSON blob with 'name' and 'arguments' keys.

The `arguments` should be a dictionary, with keys corresponding 
to the argument names and the values corresponding to the requested values.

If you cannot determine which tool to use, respond with a JSON blob whose 'name' is 'unknown' and whose 'arguments' is an empty dictionary.
"""


from typing import Annotated, TypedDict

from langgraph.graph.message import add_messages
from langgraph.types import Command


class MyMessageState(TypedDict):
    """Graph state: the chat history plus the tool call chosen by should_continue."""

    # Chat history; node updates are merged into the existing list by the
    # `add_messages` reducer rather than replacing it.
    messages: Annotated[list, add_messages]
    # Selector output ({"name": ..., "arguments": ...}) written by should_continue
    # and read by invoke_tool.
    tool_call_request: dict


def should_continue(
    state: MyMessageState,
) -> Command[Literal["chat", "tools", END]]:
    """Tool-selection node: ask a JSON-mode LLM which tool to call, then route.

    Only the last message is shown to the selector, together with the global
    ``system_prompt``. Routing is returned as a ``Command``:

    * last message is a system message -> END, no state update
    * selector reply is not a dict, or its ``name`` is missing, ``"unknown"``, or
      not one of the global ``tools`` -> "chat", no state update
    * otherwise -> "tools", storing the reply in ``tool_call_request``

    The ``Command[Literal[...]]`` annotation tells LangGraph which nodes this
    node may jump to.
    """
    last_message = state["messages"][-1]

    # A trailing system message is passed through unchanged.
    if get_role_of_message(last_message) == "system":
        return Command(update=None, goto=END)

    prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("user", "{input}")]
    )
    model = ChatOllama(model="gemma3:4b-it-fp16", temperature=0, format="json")
    chain = prompt | model | JsonOutputParser()
    tool_call_request = chain.invoke({"input": last_message.content})

    # Fall back to plain chat instead of letting invoke_tool crash on a declined
    # ("unknown"), missing, or hallucinated tool name.
    available_tool_names = {tool.name for tool in tools}
    if (
        not isinstance(tool_call_request, dict)
        or tool_call_request.get("name") not in available_tool_names
    ):
        return Command(update=None, goto="chat")

    return Command(
        update={"tool_call_request": tool_call_request},
        goto="tools",
    )


def invoke_tool(state: MyMessageState, config: Optional[RunnableConfig] = None):
    """Graph node: run the tool selected by should_continue and report its result.

    Args:
        state: Graph state; ``state["tool_call_request"]`` must be a dict with
            ``name`` (one of the global ``tools``) and ``arguments`` (passed to that
            tool's ``invoke``).
        config: LangChain runnable configuration (callbacks, metadata, etc.; see the
            LCEL documentation about ``RunnableConfig``). Forwarded to the tool.

    Returns:
        A state update appending one ``SystemMessage`` that states the tool result
        and asks the model to use it when answering.

    Raises:
        ValueError: if the state has no ``tool_call_request``.
        KeyError: if the requested tool name is not in ``tools``.
    """
    tool_call_request = state.get("tool_call_request")
    if not tool_call_request:
        raise ValueError(
            "No tool_call_request in state; should_continue must select a tool first."
        )

    tool_name_to_tool = {tool.name: tool for tool in tools}
    name = tool_call_request["name"]
    if name not in tool_name_to_tool:
        raise KeyError(
            f"Unknown tool {name!r}; available tools: {sorted(tool_name_to_tool)}"
        )

    result = tool_name_to_tool[name].invoke(
        tool_call_request["arguments"], config=config
    )
    return {
        "messages": [
            SystemMessage(
                f"Result of {name} is {result}. Please use these results to answer user questions."
            )
        ]
    }


def build_graph(llm):
    """Compile the prompt-based tool-calling graph with in-memory checkpointing.

    Flow: START -> should_continue, which jumps (via ``Command(goto=...)``) to one of
      * "tools" (invoke_tool) -> "chat" -> END
      * "chat" -> END
      * END
    so should_continue needs no outgoing edges here.

    Args:
        llm: Chat model that writes the reply in the "chat" node.

    Returns:
        The compiled graph, checkpointed by a fresh ``MemorySaver``.
    """

    def chat(state: MyMessageState):
        print("Called chat node")
        return {"messages": [llm.invoke(state["messages"])]}

    graph_builder = StateGraph(MyMessageState)

    # Nodes: one "chat" node serves both the direct route and the post-tool route.
    graph_builder.add_node("should_continue", should_continue)
    graph_builder.add_node("tools", invoke_tool)
    graph_builder.add_node("chat", chat)

    # Edges: the run always ends explicitly after "chat".
    graph_builder.add_edge(START, "should_continue")
    graph_builder.add_edge("tools", "chat")
    graph_builder.add_edge("chat", END)

    # Memory
    memory = MemorySaver()
    return graph_builder.compile(checkpointer=memory)

## TODO: ツールの選択をするNodeを作る？ (Split tool selection into its own dedicated node?)

In [136]:
# Build the prompt-based graph around the plain chat model and run one turn on thread "2".
# NOTE(review): the captured output below contains a "Called should_continue node" line
# that the current should_continue does not print — it looks stale; re-run to refresh.
graph = build_graph(model)
config = {"configurable": {"thread_id": "2"}}

output = graph.invoke(
    {"messages": [{"role": "user", "content": "what's thirteen times 4"}]}, config
)

Called should_continue node
content="what's thirteen times 4" additional_kwargs={} response_metadata={} id='4f462827-f885-449f-90c3-b9db714f76f4'
Called chat node


In [137]:
# Final graph state: every message in the thread after the run above.
output

{'messages': [HumanMessage(content="what's thirteen times 4", additional_kwargs={}, response_metadata={}, id='4f462827-f885-449f-90c3-b9db714f76f4'),
  SystemMessage(content='Result of multiply is 52. Please use these results to answer user questions.', additional_kwargs={}, response_metadata={}, id='19f3ca80-de60-4982-b7ce-28fa09277002'),
  AIMessage(content='Okay! Thirteen times four is 52. 😊 \n\nHow can I help you with that result? Do you want to:\n\n*   Do another calculation using 52?\n*   Solve a word problem involving 52?\n*   Just confirm that I got it right?', additional_kwargs={}, response_metadata={'model': 'gemma3:4b-it-fp16', 'created_at': '2025-03-14T11:45:19.674833Z', 'done': True, 'done_reason': 'stop', 'total_duration': 3866281375, 'load_duration': 27268208, 'prompt_eval_count': 38, 'prompt_eval_duration': 104000000, 'eval_count': 62, 'eval_duration': 3733000000, 'message': Message(role='assistant', content='', images=None, tool_calls=None)}, id='run-d0e8ffd0-2854-421c

In [140]:
MATH_TOOLS[

[StructuredTool(name='add', description='足し算を行う関数\n\nArgs:\n    a (int | float): 足し算を行う1つ目の値\n    b (int | float): 足し算を行う2つ目の値\n\nReturns:\n    int | float: 足し算の結果', args_schema=<class 'langchain_core.utils.pydantic.add'>, func=<function add at 0x125a2e200>),
 StructuredTool(name='minus', description='引き算を行う関数\n\nArgs:\n    a (int | float): 引き算を行う1つ目の値\n    b (int | float): 引き算を行う2つ目の値\n\nReturns:\n    int | float: 引き算の結果', args_schema=<class 'langchain_core.utils.pydantic.minus'>, func=<function minus at 0x125a2e7a0>),
 StructuredTool(name='multiply', description='掛け算を行う関数\n\nArgs:\n    a (int | float): 掛け算を行う1つ目の値\n    b (int | float): 掛け算を行う2つ目の値\n\nReturns:\n    int | float: 掛け算の結果', args_schema=<class 'langchain_core.utils.pydantic.multiply'>, func=<function multiply at 0x125a2e520>),
 StructuredTool(name='divide', description='割り算を行う関数\n\nArgs:\n    a (int | float): 割り算を行う1つ目の値\n    b (int | float): 割り算を行う2つ目の値\n\nReturns:\n    int | float: 割り算の結果', args_schema=<class 'langchain_