Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/strands/tools/_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ async def acall() -> ToolResult:

tool_result = run_async(acall)

# Apply conversation management if agent supports it (traditional agents)
if hasattr(self._agent, "conversation_manager"):
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
from ..agent import Agent

if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)

return tool_result
Expand Down
52 changes: 29 additions & 23 deletions src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,19 @@ async def _invoke_before_tool_call_hook(
invocation_state: dict[str, Any],
) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]:
"""Invoke the appropriate before tool call hook based on agent type."""
event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent
return await agent.hooks.invoke_callbacks_async(
event_cls(
agent=agent,
selected_tool=tool_func,
tool_use=tool_use,
invocation_state=invocation_state,
)
kwargs = {
"selected_tool": tool_func,
"tool_use": tool_use,
"invocation_state": invocation_state,
}
event = (
BeforeToolCallEvent(agent=cast("Agent", agent), **kwargs)
if ToolExecutor._is_agent(agent)
else BidiBeforeToolCallEvent(agent=cast("BidiAgent", agent), **kwargs)
)

return await agent.hooks.invoke_callbacks_async(event)

@staticmethod
async def _invoke_after_tool_call_hook(
agent: "Agent | BidiAgent",
Expand All @@ -70,19 +73,22 @@ async def _invoke_after_tool_call_hook(
cancel_message: str | None = None,
) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]:
"""Invoke the appropriate after tool call hook based on agent type."""
event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent
return await agent.hooks.invoke_callbacks_async(
event_cls(
agent=agent,
selected_tool=selected_tool,
tool_use=tool_use,
invocation_state=invocation_state,
result=result,
exception=exception,
cancel_message=cancel_message,
)
kwargs = {
"selected_tool": selected_tool,
"tool_use": tool_use,
"invocation_state": invocation_state,
"result": result,
"exception": exception,
"cancel_message": cancel_message,
}
event = (
AfterToolCallEvent(agent=cast("Agent", agent), **kwargs)
if ToolExecutor._is_agent(agent)
else BidiAfterToolCallEvent(agent=cast("BidiAgent", agent), **kwargs)
)

return await agent.hooks.invoke_callbacks_async(event)

@staticmethod
async def _stream(
agent: "Agent | BidiAgent",
Expand Down Expand Up @@ -247,7 +253,7 @@ async def _stream(

@staticmethod
async def _stream_with_trace(
agent: "Agent | BidiAgent",
agent: "Agent",
tool_use: ToolUse,
tool_results: list[ToolResult],
cycle_trace: Trace,
Expand All @@ -259,7 +265,7 @@ async def _stream_with_trace(
"""Execute tool with tracing and metrics collection.

Args:
agent: The agent (Agent or BidiAgent) for which the tool is being executed.
agent: The agent for which the tool is being executed.
tool_use: Metadata and inputs for the tool to be executed.
tool_results: List of tool results from each tool execution.
cycle_trace: Trace object for the current event loop cycle.
Expand Down Expand Up @@ -308,7 +314,7 @@ async def _stream_with_trace(
# pragma: no cover
def _execute(
self,
agent: "Agent | BidiAgent",
agent: "Agent",
tool_uses: list[ToolUse],
tool_results: list[ToolResult],
cycle_trace: Trace,
Expand All @@ -319,7 +325,7 @@ def _execute(
"""Execute the given tools according to this executor's strategy.

Args:
agent: The agent (Agent or BidiAgent) for which tools are being executed.
agent: The agent for which tools are being executed.
tool_uses: Metadata and inputs for the tools to be executed.
tool_results: List of tool results from each tool execution.
cycle_trace: Trace object for the current event loop cycle.
Expand Down
9 changes: 4 additions & 5 deletions src/strands/tools/executors/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

if TYPE_CHECKING: # pragma: no cover
from ...agent import Agent
from ...experimental.bidi import BidiAgent
from ..structured_output._structured_output_context import StructuredOutputContext


Expand All @@ -22,7 +21,7 @@ class ConcurrentToolExecutor(ToolExecutor):
@override
async def _execute(
self,
agent: "Agent | BidiAgent",
agent: "Agent",
tool_uses: list[ToolUse],
tool_results: list[ToolResult],
cycle_trace: Trace,
Expand All @@ -33,7 +32,7 @@ async def _execute(
"""Execute tools concurrently.

Args:
agent: The agent (Agent or BidiAgent) for which tools are being executed.
agent: The agent for which tools are being executed.
tool_uses: Metadata and inputs for the tools to be executed.
tool_results: List of tool results from each tool execution.
cycle_trace: Trace object for the current event loop cycle.
Expand Down Expand Up @@ -79,7 +78,7 @@ async def _execute(

async def _task(
self,
agent: "Agent | BidiAgent",
agent: "Agent",
tool_use: ToolUse,
tool_results: list[ToolResult],
cycle_trace: Trace,
Expand All @@ -94,7 +93,7 @@ async def _task(
"""Execute a single tool and put results in the task queue.

Args:
agent: The agent (Agent or BidiAgent) executing the tool.
agent: The agent executing the tool.
tool_use: Tool use metadata and inputs.
tool_results: List of tool results from each tool execution.
cycle_trace: Trace object for the current event loop cycle.
Expand Down
5 changes: 2 additions & 3 deletions src/strands/tools/executors/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

if TYPE_CHECKING: # pragma: no cover
from ...agent import Agent
from ...experimental.bidi import BidiAgent
from ..structured_output._structured_output_context import StructuredOutputContext


Expand All @@ -21,7 +20,7 @@ class SequentialToolExecutor(ToolExecutor):
@override
async def _execute(
self,
agent: "Agent | BidiAgent",
agent: "Agent",
tool_uses: list[ToolUse],
tool_results: list[ToolResult],
cycle_trace: Trace,
Expand All @@ -34,7 +33,7 @@ async def _execute(
Breaks early if an interrupt is raised by the user.

Args:
agent: The agent (Agent or BidiAgent) for which tools are being executed.
agent: The agent for which tools are being executed.
tool_uses: Metadata and inputs for the tools to be executed.
tool_results: List of tool results from each tool execution.
cycle_trace: Trace object for the current event loop cycle.
Expand Down
Loading