Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,18 @@ class BeforeToolCallEvent(HookEvent):
to change which tool gets executed. This may be None if tool lookup failed.
tool_use: The tool parameters that will be passed to selected_tool.
invocation_state: Keyword arguments that will be passed to the tool.
cancel_tool: A user defined message that when set, will cancel the tool call.
The message will be placed into a tool result with an error status. If set to `True`, Strands will cancel
the tool call and use a default cancel message.
"""

selected_tool: Optional[AgentTool]
tool_use: ToolUse
invocation_state: dict[str, Any]
cancel_tool: bool | str = False

def _can_write(self, name: str) -> bool:
return name in ["selected_tool", "tool_use"]
return name in ["cancel_tool", "selected_tool", "tool_use"]


@dataclass
Expand All @@ -124,13 +128,15 @@ class AfterToolCallEvent(HookEvent):
invocation_state: Keyword arguments that were passed to the tool
result: The result of the tool invocation. Either a ToolResult on success
or an Exception if the tool execution failed.
cancel_message: The cancellation message if the user cancelled the tool call.
"""

selected_tool: Optional[AgentTool]
tool_use: ToolUse
invocation_state: dict[str, Any]
result: ToolResult
exception: Optional[Exception] = None
cancel_message: str | None = None

def _can_write(self, name: str) -> bool:
return name == "result"
Expand Down
29 changes: 27 additions & 2 deletions src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...hooks import AfterToolCallEvent, BeforeToolCallEvent
from ...telemetry.metrics import Trace
from ...telemetry.tracer import get_tracer
from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent
from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent
from ...types.content import Message
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse

Expand Down Expand Up @@ -81,6 +81,31 @@ async def _stream(
)
)

if before_event.cancel_tool:
cancel_message = (
before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user"
)
yield ToolCancelEvent(tool_use, cancel_message)

cancel_result: ToolResult = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": cancel_message}],
}
after_event = agent.hooks.invoke_callbacks(
AfterToolCallEvent(
agent=agent,
tool_use=tool_use,
invocation_state=invocation_state,
selected_tool=None,
result=cancel_result,
cancel_message=cancel_message,
)
)
yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)
return

try:
selected_tool = before_event.selected_tool
tool_use = before_event.tool_use
Expand Down Expand Up @@ -123,7 +148,7 @@ async def _stream(
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
# In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
# ToolStreamEvent and the last even is just the result
# ToolStreamEvent and the last event is just the result.

if isinstance(event, ToolResultEvent):
# below the last "event" must point to the tool_result
Expand Down
23 changes: 23 additions & 0 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,29 @@ def tool_use_id(self) -> str:
return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId"))


class ToolCancelEvent(TypedEvent):
"""Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook."""

def __init__(self, tool_use: ToolUse, message: str) -> None:
"""Initialize with tool streaming data.

Args:
tool_use: Information about the tool being cancelled
message: The tool cancellation message
"""
super().__init__({"tool_cancel_event": {"tool_use": tool_use, "message": message}})

@property
def tool_use_id(self) -> str:
"""The id of the tool cancelled."""
return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancelled_event")).get("tool_use")).get("toolUseId"))

@property
def message(self) -> str:
"""The tool cancellation message."""
return cast(str, self["message"])


class ModelMessageEvent(TypedEvent):
"""Event emitted when the model invocation has completed.

Expand Down
37 changes: 36 additions & 1 deletion tests/strands/tools/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent
from strands.telemetry.metrics import Trace
from strands.tools.executors._executor import ToolExecutor
from strands.types._events import ToolResultEvent, ToolStreamEvent
from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent
from strands.types.tools import ToolUse


Expand Down Expand Up @@ -215,3 +215,38 @@ async def test_executor_stream_with_trace(

cycle_trace.add_child.assert_called_once()
assert isinstance(cycle_trace.add_child.call_args[0][0], Trace)


@pytest.mark.parametrize(
("cancel_tool", "cancel_message"),
[(True, "tool cancelled by user"), ("user cancel message", "user cancel message")],
)
@pytest.mark.asyncio
async def test_executor_stream_cancel(
cancel_tool, cancel_message, executor, agent, tool_results, invocation_state, alist
):
def cancel_callback(event):
event.cancel_tool = cancel_tool
return event

agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback)
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}

stream = executor._stream(agent, tool_use, tool_results, invocation_state)

tru_events = await alist(stream)
exp_events = [
ToolCancelEvent(tool_use, cancel_message),
ToolResultEvent(
{
"toolUseId": "1",
"status": "error",
"content": [{"text": cancel_message}],
},
),
]
assert tru_events == exp_events

tru_results = tool_results
exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results
15 changes: 15 additions & 0 deletions tests_integ/tools/executors/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from strands.hooks import BeforeToolCallEvent, HookProvider


@pytest.fixture
def cancel_hook():
class Hook(HookProvider):
def register_hooks(self, registry):
registry.add_callback(BeforeToolCallEvent, self.cancel)

def cancel(self, event):
event.cancel_tool = "cancelled tool call"

return Hook()
16 changes: 16 additions & 0 deletions tests_integ/tools/executors/test_concurrent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json

import pytest

Expand Down Expand Up @@ -59,3 +60,18 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
{"name": "time_tool", "event": "end"},
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_agent_stream_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])

exp_message = "cancelled tool call"
tru_message = ""
async for event in agent.stream_async("What is the time in New York?"):
if "tool_cancel_event" in event:
tru_message = event["tool_cancel_event"]["message"]

assert tru_message == exp_message
assert len(tool_events) == 0
assert exp_message in json.dumps(agent.messages)
16 changes: 16 additions & 0 deletions tests_integ/tools/executors/test_sequential.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json

import pytest

Expand Down Expand Up @@ -59,3 +60,18 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
{"name": "weather_tool", "event": "end"},
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_agent_stream_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])

exp_message = "cancelled tool call"
tru_message = ""
async for event in agent.stream_async("What is the time in New York?"):
if "tool_cancel_event" in event:
tru_message = event["tool_cancel_event"]["message"]

assert tru_message == exp_message
assert len(tool_events) == 0
assert exp_message in json.dumps(agent.messages)