diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e13b9f6d8..232e2ca2a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,9 +9,7 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import json import logging -import random import warnings from typing import ( TYPE_CHECKING, @@ -52,16 +50,16 @@ from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize +from ..tools._caller import _ToolCaller from ..tools.executors import ConcurrentToolExecutor from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher -from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent +from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException -from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -101,114 +99,8 @@ class Agent: 6. Produces a final response """ - class ToolCaller: - """Call tool as a function.""" - - def __init__(self, agent: "Agent") -> None: - """Initialize instance. - - Args: - agent: Agent reference that will accept tool results. - """ - # WARNING: Do not add any other member variables or methods as this could result in a name conflict with - # agent tools and thus break their execution. - self._agent = agent - - def __getattr__(self, name: str) -> Callable[..., Any]: - """Call tool as a function. - - This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). - It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). - - Args: - name: The name of the attribute (tool) being accessed. - - Returns: - A function that when called will execute the named tool. - - Raises: - AttributeError: If no tool with the given name exists or if multiple tools match the given name. - """ - - def caller( - user_message_override: Optional[str] = None, - record_direct_tool_call: Optional[bool] = None, - **kwargs: Any, - ) -> Any: - """Call a tool directly by name. - - Args: - user_message_override: Optional custom message to record instead of default - record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class - attribute if provided. - **kwargs: Keyword arguments to pass to the tool. - - Returns: - The result returned by the tool. - - Raises: - AttributeError: If the tool doesn't exist. - """ - if self._agent._interrupt_state.activated: - raise RuntimeError("cannot directly call tool during interrupt") - - normalized_name = self._find_normalized_tool_name(name) - - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs - - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - if isinstance(event, ToolInterruptEvent): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") - - tool_result = tool_results[0] - - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call - - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - await self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - - return tool_result - - tool_result = run_async(acall) - self._agent.conversation_manager.apply_management(self._agent) - return tool_result - - return caller - - def _find_normalized_tool_name(self, name: str) -> str: - """Lookup the tool represented by name, replacing characters with underscores as necessary.""" - tool_registry = self._agent.tool_registry.registry - - if tool_registry.get(name, None): - return name - - # If the desired name contains underscores, it might be a placeholder for characters that can't be - # represented as python identifiers but are valid as tool names, such as dashes. In that case, find - # all tools that can be represented with the normalized name - if "_" in name: - filtered_tools = [ - tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name - ] - - # The registry itself defends against similar names, so we can just take the first match - if filtered_tools: - return filtered_tools[0] - - raise AttributeError(f"Tool '{name}' not found") + # For backwards compatibility + ToolCaller = _ToolCaller def __init__( self, @@ -347,7 +239,7 @@ def __init__( else: self.state = AgentState() - self.tool_caller = Agent.ToolCaller(self) + self.tool_caller = _ToolCaller(self) self.hooks = HookRegistry() @@ -395,7 +287,7 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) @property - def tool(self) -> ToolCaller: + def tool(self) -> _ToolCaller: """Call tool as a function. Returns: @@ -854,71 +746,6 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") return messages - async def _record_tool_execution( - self, - tool: ToolUse, - tool_result: ToolResult, - user_message_override: Optional[str], - ) -> None: - """Record a tool execution in the message history. - - Creates a sequence of messages that represent the tool execution: - - 1. A user message describing the tool call - 2. An assistant message with the tool use - 3. A user message with the tool result - 4. An assistant message acknowledging the tool call - - Args: - tool: The tool call information. - tool_result: The result returned by the tool. - user_message_override: Optional custom message to include. - """ - # Filter tool input parameters to only include those defined in tool spec - filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) - - # Create user message describing the tool call - input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") - - user_msg_content: list[ContentBlock] = [ - {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} - ] - - # Add override message if provided - if user_message_override: - user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) - - # Create filtered tool use for message history - filtered_tool: ToolUse = { - "toolUseId": tool["toolUseId"], - "name": tool["name"], - "input": filtered_input, - } - - # Create the message sequence - user_msg: Message = { - "role": "user", - "content": user_msg_content, - } - tool_use_msg: Message = { - "role": "assistant", - "content": [{"toolUse": filtered_tool}], - } - tool_result_msg: Message = { - "role": "user", - "content": [{"toolResult": tool_result}], - } - assistant_msg: Message = { - "role": "assistant", - "content": [{"text": f"agent.tool.{tool['name']} was called."}], - } - - # Add to message history - await self._append_message(user_msg) - await self._append_message(tool_use_msg) - await self._append_message(tool_result_msg) - await self._append_message(assistant_msg) - def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: """Starts a trace span for the agent. @@ -960,25 +787,6 @@ def _end_agent_trace_span( self.tracer.end_agent_span(**trace_attributes) - def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: - """Filter input parameters to only include those defined in the tool specification. - - Args: - tool_name: Name of the tool to get specification for - input_params: Original input parameters - - Returns: - Filtered parameters containing only those defined in tool spec - """ - all_tools_config = self.tool_registry.get_all_tools_config() - tool_spec = all_tools_config.get(tool_name) - - if not tool_spec or "inputSchema" not in tool_spec: - return input_params.copy() - - properties = tool_spec["inputSchema"]["json"]["properties"] - return {k: v for k, v in input_params.items() if k in properties} - def _initialize_system_prompt( self, system_prompt: str | list[SystemContentBlock] | None ) -> tuple[str | None, list[SystemContentBlock] | None]: diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py new file mode 100644 index 000000000..fc7a3efb9 --- /dev/null +++ b/src/strands/tools/_caller.py @@ -0,0 +1,215 @@ +"""Support direct tool calls through agent. + +Example: + ``` + agent = Agent(tools=[my_tool]) + agent.tool.my_tool() + ``` +""" + +import json +import random +from typing import TYPE_CHECKING, Any, Callable + +from .._async import run_async +from ..tools.executors._executor import ToolExecutor +from ..types._events import ToolInterruptEvent +from ..types.content import ContentBlock, Message +from ..types.tools import ToolResult, ToolUse + +if TYPE_CHECKING: + from ..agent import Agent + + +class _ToolCaller: + """Call tool as a function.""" + + def __init__(self, agent: "Agent") -> None: + """Initialize instance. + + Args: + agent: Agent reference that will accept tool results. + """ + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Call tool as a function. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: str | None = None, + record_direct_tool_call: bool | None = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default + record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class + attribute if provided. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. + """ + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") + + tool_result = tool_results[0] + + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._record_tool_execution(tool_use, tool_result, user_message_override) + + return tool_result + + tool_result = run_async(acall) + self._agent.conversation_manager.apply_management(self._agent) + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + + async def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: str | None, + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. + """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content: list[ContentBlock] = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + await self._agent._append_message(user_msg) + await self._agent._append_message(tool_use_msg) + await self._agent._append_message(tool_result_msg) + await self._agent._append_message(assistant_msg) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self._agent.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 76aeadeff..ea6b09b75 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -33,12 +33,6 @@ FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") -@pytest.fixture -def mock_randint(): - with unittest.mock.patch.object(strands.agent.agent.random, "randint") as mock: - yield mock - - @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): @@ -803,93 +797,6 @@ async def test_agent_invoke_async(mock_model, agent, agenerator): assert tru_message == exp_message -def test_agent_tool(mock_randint, agent): - conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) - agent.conversation_manager = conversation_manager_spy - - mock_randint.return_value = 1 - - tru_result = agent.tool.tool_decorated(random_string="abcdEfghI123") - exp_result = { - "content": [ - { - "text": "abcdEfghI123", - }, - ], - "status": "success", - "toolUseId": "tooluse_tool_decorated_1", - } - - assert tru_result == exp_result - conversation_manager_spy.apply_management.assert_called_with(agent) - - -@pytest.mark.asyncio -async def test_agent_tool_in_async_context(mock_randint, agent): - mock_randint.return_value = 123 - - tru_result = agent.tool.tool_decorated(random_string="abcdEfghI123") - exp_result = { - "content": [ - { - "text": "abcdEfghI123", - }, - ], - "status": "success", - "toolUseId": "tooluse_tool_decorated_123", - } - - assert tru_result == exp_result - - -def test_agent_tool_user_message_override(agent): - agent.tool.tool_decorated(random_string="abcdEfghI123", user_message_override="test override") - - tru_message = agent.messages[0] - exp_message = { - "content": [ - { - "text": "test override\n", - }, - { - "text": ( - 'agent.tool.tool_decorated direct tool call.\nInput parameters: {"random_string": "abcdEfghI123"}\n' - ), - }, - ], - "role": "user", - } - - assert tru_message == exp_message - - -def test_agent_tool_do_not_record_tool(agent): - agent.record_direct_tool_call = False - agent.tool.tool_decorated(random_string="abcdEfghI123", user_message_override="test override") - - tru_messages = agent.messages - exp_messages = [] - - assert tru_messages == exp_messages - - -def test_agent_tool_do_not_record_tool_with_method_override(agent): - agent.record_direct_tool_call = True - agent.tool.tool_decorated( - random_string="abcdEfghI123", user_message_override="test override", record_direct_tool_call=False - ) - - tru_messages = agent.messages - exp_messages = [] - - assert tru_messages == exp_messages - - -def test_agent_tool_tool_does_not_exist(agent): - with pytest.raises(AttributeError): - agent.tool.does_not_exist() - - @pytest.mark.parametrize("tools", [None, [tool_decorated]], indirect=True) def test_agent_tool_names(tools, agent): actual = agent.tool_names @@ -904,45 +811,6 @@ def test_agent_init_with_no_model_or_model_id(): assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID -def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator): - @strands.tools.tool(name="system_prompter") - def function(system_prompt: str) -> str: - return system_prompt - - agent.tool_registry.register_tool(function) - - mock_randint.return_value = 1 - - tru_result = agent.tool.system_prompter(system_prompt="tool prompt") - exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} - assert tru_result == exp_result - - -def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, agenerator): - tool_name = "system-prompter" - - @strands.tools.tool(name=tool_name) - def function(system_prompt: str) -> str: - return system_prompt - - agent.tool_registry.register_tool(function) - - mock_randint.return_value = 1 - - tru_result = agent.tool.system_prompter(system_prompt="tool prompt") - exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} - assert tru_result == exp_result - - -def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): - mock_randint.return_value = 1 - - with pytest.raises(AttributeError) as err: - agent.tool.system_prompter_1(system_prompt="tool prompt") - - assert str(err.value) == "Tool 'system_prompter_1' not found" - - def test_agent_with_none_callback_handler_prints_nothing(): agent = Agent() @@ -1738,98 +1606,6 @@ def test_agent_with_session_and_conversation_manager(): assert agent.conversation_manager.removed_message_count == agent_2.conversation_manager.removed_message_count -def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint): - """Test that non-serializable objects in tool parameters are properly filtered during tool call recording.""" - mock_randint.return_value = 42 - - # Create a non-serializable object (Agent instance) - another_agent = Agent() - - # This should not crash even though we're passing non-serializable objects - result = agent.tool.tool_decorated( - random_string="test_value", - non_serializable_agent=another_agent, # This would previously cause JSON serialization error - user_message_override="Testing non-serializable parameter filtering", - ) - - # Verify the tool executed successfully - expected_result = { - "content": [{"text": "test_value"}], - "status": "success", - "toolUseId": "tooluse_tool_decorated_42", - } - assert result == expected_result - - # The key test: this should not crash during execution - # Check that we have messages recorded (exact count may vary) - assert len(agent.messages) > 0 - - # Check user message with filtered parameters - this is the main test for the bug fix - user_message = agent.messages[0] - assert user_message["role"] == "user" - assert len(user_message["content"]) == 2 - - # Check override message - assert user_message["content"][0]["text"] == "Testing non-serializable parameter filtering\n" - - # Check tool call description with filtered parameters - this is where JSON serialization would fail - tool_call_text = user_message["content"][1]["text"] - assert "agent.tool.tool_decorated direct tool call." in tool_call_text - assert '"random_string": "test_value"' in tool_call_text - assert '"non_serializable_agent": "<>"' not in tool_call_text - - -def test_agent_tool_no_non_serializable_parameters(agent, mock_randint): - """Test that normal tool calls with only serializable parameters work unchanged.""" - mock_randint.return_value = 555 - - # Call with only serializable parameters - result = agent.tool.tool_decorated(random_string="normal_call", user_message_override="Normal tool call test") - - # Verify successful execution - expected_result = { - "content": [{"text": "normal_call"}], - "status": "success", - "toolUseId": "tooluse_tool_decorated_555", - } - assert result == expected_result - - # Check message recording works normally - assert len(agent.messages) > 0 - user_message = agent.messages[0] - tool_call_text = user_message["content"][1]["text"] - - # Verify normal parameter serialization (no filtering needed) - assert "agent.tool.tool_decorated direct tool call." in tool_call_text - assert '"random_string": "normal_call"' in tool_call_text - # Should not contain any "< str: - """Test tool with single parameter.""" - return action - - agent = Agent(tools=[test_tool]) - - # Call tool with extra non-spec parameters - result = agent.tool.test_tool( - action="test_value", - agent=agent, # Should be filtered out - extra_param="filtered", # Should be filtered out - ) - - # Verify tool executed successfully - assert result["status"] == "success" - assert result["content"] == [{"text": "test_value"}] - - # Check that only spec parameters are recorded in message history - assert len(agent.messages) > 0 - user_message = agent.messages[0] - tool_call_text = user_message["content"][0]["text"] - - # Should only contain the 'action' parameter - assert '"action": "test_value"' in tool_call_text - assert '"agent"' not in tool_call_text - assert '"extra_param"' not in tool_call_text - - def test_agent__call__handles_none_invocation_state(mock_model, agent): """Test that agent handles None invocation_state without AttributeError.""" mock_model.mock_stream.return_value = [ @@ -2094,39 +1837,6 @@ def test_agent_structured_output_interrupt(user): agent.structured_output(type(user), "invalid") -def test_agent_tool_caller_interrupt(): - @strands.tool(context=True) - def test_tool(tool_context): - tool_context.interrupt("test-interrupt") - - agent = Agent(tools=[test_tool]) - - exp_message = r"cannot raise interrupt in direct tool call" - with pytest.raises(RuntimeError, match=exp_message): - agent.tool.test_tool(agent=agent) - - tru_state = agent._interrupt_state.to_dict() - exp_state = { - "activated": False, - "context": {}, - "interrupts": {}, - } - assert tru_state == exp_state - - tru_messages = agent.messages - exp_messages = [] - assert tru_messages == exp_messages - - -def test_agent_tool_caller_interrupt_activated(): - agent = Agent() - agent._interrupt_state.activated = True - - exp_message = r"cannot directly call tool during interrupt" - with pytest.raises(RuntimeError, match=exp_message): - agent.tool.test_tool() - - def test_latest_message_tool_use_skips_model_invoke(tool_decorated): mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "I see the tool result"}]}]) diff --git a/tests/strands/tools/test_caller.py b/tests/strands/tools/test_caller.py new file mode 100644 index 000000000..18de6d3f0 --- /dev/null +++ b/tests/strands/tools/test_caller.py @@ -0,0 +1,314 @@ +import unittest.mock + +import pytest + +from strands import Agent, tool + + +@pytest.fixture +def randint(): + with unittest.mock.patch("strands.tools._caller.random.randint") as mock: + yield mock + + +@pytest.fixture +def model(): + return unittest.mock.Mock() + + +@pytest.fixture +def test_tool(): + @tool(name="test_tool") + def function(random_string: str) -> str: + return random_string + + return function + + +@pytest.fixture +def agent(model, test_tool): + return Agent(model=model, tools=[test_tool]) + + +def test_agent_tool(randint, agent): + conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) + agent.conversation_manager = conversation_manager_spy + + randint.return_value = 1 + + tru_result = agent.tool.test_tool(random_string="abcdEfghI123") + exp_result = { + "content": [ + { + "text": "abcdEfghI123", + }, + ], + "status": "success", + "toolUseId": "tooluse_test_tool_1", + } + + assert tru_result == exp_result + conversation_manager_spy.apply_management.assert_called_with(agent) + + +@pytest.mark.asyncio +async def test_agent_tool_in_async_context(randint, agent): + randint.return_value = 123 + + tru_result = agent.tool.test_tool(random_string="abcdEfghI123") + exp_result = { + "content": [ + { + "text": "abcdEfghI123", + }, + ], + "status": "success", + "toolUseId": "tooluse_test_tool_123", + } + + assert tru_result == exp_result + + +def test_agent_tool_user_message_override(agent): + agent.tool.test_tool(random_string="abcdEfghI123", user_message_override="test override") + + tru_message = agent.messages[0] + exp_message = { + "content": [ + { + "text": "test override\n", + }, + { + "text": ( + 'agent.tool.test_tool direct tool call.\nInput parameters: {"random_string": "abcdEfghI123"}\n' + ), + }, + ], + "role": "user", + } + + assert tru_message == exp_message + + +def test_agent_tool_do_not_record_tool(agent): + agent.record_direct_tool_call = False + agent.tool.test_tool(random_string="abcdEfghI123", user_message_override="test override") + + tru_messages = agent.messages + exp_messages = [] + + assert tru_messages == exp_messages + + +def test_agent_tool_do_not_record_tool_with_method_override(agent): + agent.record_direct_tool_call = True + agent.tool.test_tool( + random_string="abcdEfghI123", user_message_override="test override", record_direct_tool_call=False + ) + + tru_messages = agent.messages + exp_messages = [] + + assert tru_messages == exp_messages + + +def test_agent_tool_tool_does_not_exist(agent): + with pytest.raises(AttributeError): + agent.tool.does_not_exist() + + +def test_agent_tool_no_parameter_conflict(agent, randint): + @tool(name="system_prompter") + def function(system_prompt: str) -> str: + return system_prompt + + agent.tool_registry.register_tool(function) + + randint.return_value = 1 + + tru_result = agent.tool.system_prompter(system_prompt="tool prompt") + exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} + assert tru_result == exp_result + + +def test_agent_tool_with_name_normalization(agent, randint): + tool_name = "system-prompter" + + @tool(name=tool_name) + def function(system_prompt: str) -> str: + return system_prompt + + agent.tool_registry.register_tool(function) + + randint.return_value = 1 + + tru_result = agent.tool.system_prompter(system_prompt="tool prompt") + exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} + assert tru_result == exp_result + + +def test_agent_tool_with_no_normalized_match(agent, randint): + randint.return_value = 1 + + with pytest.raises(AttributeError) as err: + agent.tool.system_prompter_1(system_prompt="tool prompt") + + assert str(err.value) == "Tool 'system_prompter_1' not found" + + +def test_agent_tool_non_serializable_parameter_filtering(agent, randint): + """Test that non-serializable objects in tool parameters are properly filtered during tool call recording.""" + randint.return_value = 42 + + # Create a non-serializable object (Agent instance) + another_agent = Agent() + + # This should not crash even though we're passing non-serializable objects + result = agent.tool.test_tool( + random_string="test_value", + non_serializable_agent=another_agent, # This would previously cause JSON serialization error + user_message_override="Testing non-serializable parameter filtering", + ) + + # Verify the tool executed successfully + expected_result = { + "content": [{"text": "test_value"}], + "status": "success", + "toolUseId": "tooluse_test_tool_42", + } + assert result == expected_result + + # The key test: this should not crash during execution + # Check that we have messages recorded (exact count may vary) + assert len(agent.messages) > 0 + + # Check user message with filtered parameters - this is the main test for the bug fix + user_message = agent.messages[0] + assert user_message["role"] == "user" + assert len(user_message["content"]) == 2 + + # Check override message + assert user_message["content"][0]["text"] == "Testing non-serializable parameter filtering\n" + + # Check tool call description with filtered parameters - this is where JSON serialization would fail + tool_call_text = user_message["content"][1]["text"] + assert "agent.tool.test_tool direct tool call." in tool_call_text + assert '"random_string": "test_value"' in tool_call_text + assert '"non_serializable_agent": "<>"' not in tool_call_text + + +def test_agent_tool_no_non_serializable_parameters(agent, randint): + """Test that normal tool calls with only serializable parameters work unchanged.""" + randint.return_value = 555 + + # Call with only serializable parameters + result = agent.tool.test_tool(random_string="normal_call", user_message_override="Normal tool call test") + + # Verify successful execution + expected_result = { + "content": [{"text": "normal_call"}], + "status": "success", + "toolUseId": "tooluse_test_tool_555", + } + assert result == expected_result + + # Check message recording works normally + assert len(agent.messages) > 0 + user_message = agent.messages[0] + tool_call_text = user_message["content"][1]["text"] + + # Verify normal parameter serialization (no filtering needed) + assert "agent.tool.test_tool direct tool call." in tool_call_text + assert '"random_string": "normal_call"' in tool_call_text + # Should not contain any "< str: + """Test tool with single parameter.""" + return action + + agent = Agent(tools=[test_tool]) + + # Call tool with extra non-spec parameters + result = agent.tool.test_tool( + action="test_value", + agent=agent, # Should be filtered out + extra_param="filtered", # Should be filtered out + ) + + # Verify tool executed successfully + assert result["status"] == "success" + assert result["content"] == [{"text": "test_value"}] + + # Check that only spec parameters are recorded in message history + assert len(agent.messages) > 0 + user_message = agent.messages[0] + tool_call_text = user_message["content"][0]["text"] + + # Should only contain the 'action' parameter + assert '"action": "test_value"' in tool_call_text + assert '"agent"' not in tool_call_text + assert '"extra_param"' not in tool_call_text + + +def test_agent_tool_caller_interrupt(): + @tool(context=True) + def test_tool(tool_context): + tool_context.interrupt("test-interrupt") + + agent = Agent(tools=[test_tool]) + + exp_message = r"cannot raise interrupt in direct tool call" + with pytest.raises(RuntimeError, match=exp_message): + agent.tool.test_tool(agent=agent) + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state + + tru_messages = agent.messages + exp_messages = [] + assert tru_messages == exp_messages + + +def test_agent_tool_caller_interrupt_activated(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot directly call tool during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.tool.test_tool()