diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 4a74dec18..1e0ca2c8d 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -106,8 +106,10 @@ async def acall() -> ToolResult: tool_result = run_async(acall) - # Apply conversation management if agent supports it (traditional agents) - if hasattr(self._agent, "conversation_manager"): + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + from ..agent import Agent + + if isinstance(self._agent, Agent): self._agent.conversation_manager.apply_management(self._agent) return tool_result diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index a4f9e7e1f..5d01c5d48 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -49,16 +49,19 @@ async def _invoke_before_tool_call_hook( invocation_state: dict[str, Any], ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" - event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) + kwargs = { + "selected_tool": tool_func, + "tool_use": tool_use, + "invocation_state": invocation_state, + } + event = ( + BeforeToolCallEvent(agent=cast("Agent", agent), **kwargs) + if ToolExecutor._is_agent(agent) + else BidiBeforeToolCallEvent(agent=cast("BidiAgent", agent), **kwargs) ) + return await agent.hooks.invoke_callbacks_async(event) + @staticmethod async def _invoke_after_tool_call_hook( agent: "Agent | BidiAgent", @@ -70,19 +73,22 @@ async def _invoke_after_tool_call_hook( cancel_message: str | None = None, ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" - event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - cancel_message=cancel_message, - ) + kwargs = { + "selected_tool": selected_tool, + "tool_use": tool_use, + "invocation_state": invocation_state, + "result": result, + "exception": exception, + "cancel_message": cancel_message, + } + event = ( + AfterToolCallEvent(agent=cast("Agent", agent), **kwargs) + if ToolExecutor._is_agent(agent) + else BidiAfterToolCallEvent(agent=cast("BidiAgent", agent), **kwargs) ) + return await agent.hooks.invoke_callbacks_async(event) + @staticmethod async def _stream( agent: "Agent | BidiAgent", @@ -247,7 +253,7 @@ async def _stream( @staticmethod async def _stream_with_trace( - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -259,7 +265,7 @@ async def _stream_with_trace( """Execute tool with tracing and metrics collection. Args: - agent: The agent (Agent or BidiAgent) for which the tool is being executed. + agent: The agent for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -308,7 +314,7 @@ async def _stream_with_trace( # pragma: no cover def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -319,7 +325,7 @@ def _execute( """Execute the given tools according to this executor's strategy. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index da5c1ff10..216eee379 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -22,7 +21,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -33,7 +32,7 @@ async def _execute( """Execute tools concurrently. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -79,7 +78,7 @@ async def _execute( async def _task( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -94,7 +93,7 @@ async def _task( """Execute a single tool and put results in the task queue. Args: - agent: The agent (Agent or BidiAgent) executing the tool. + agent: The agent executing the tool. tool_use: Tool use metadata and inputs. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 6163fc195..f78e60872 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +20,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -34,7 +33,7 @@ async def _execute( Breaks early if an interrupt is raised by the user. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle.