From 11114e2a1123c4d61e86b08810d8c05dbc14960f Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 2 Oct 2025 14:20:58 -0400 Subject: [PATCH 1/2] hooks - before tool call event - interrupt --- src/strands/agent/agent.py | 61 ++++++ src/strands/agent/agent_result.py | 3 + src/strands/event_loop/event_loop.py | 75 ++++++-- src/strands/hooks/__init__.py | 4 + src/strands/hooks/events.py | 5 +- src/strands/hooks/interrupt.py | 104 ++++++++++ src/strands/hooks/registry.py | 9 +- src/strands/tools/executors/_executor.py | 34 +++- src/strands/tools/executors/sequential.py | 10 +- src/strands/types/_events.py | 24 ++- src/strands/types/agent.py | 3 +- src/strands/types/event_loop.py | 2 + src/strands/types/interrupt.py | 27 +++ tests/strands/agent/hooks/test_events.py | 3 +- tests/strands/agent/test_agent.py | 154 +++++++++++++++ tests/strands/agent/test_agent_hooks.py | 18 +- tests/strands/event_loop/test_event_loop.py | 182 +++++++++++++++++- .../experimental/hooks/test_hook_aliases.py | 4 +- tests/strands/hooks/test_interrupt.py | 81 ++++++++ tests/strands/hooks/test_registry.py | 33 ++++ tests/strands/tools/executors/conftest.py | 1 + .../strands/tools/executors/test_executor.py | 128 +++++++++--- .../tools/executors/test_sequential.py | 54 +++++- tests_integ/test_interrupt.py | 150 +++++++++++++++ 24 files changed, 1108 insertions(+), 61 deletions(-) create mode 100644 src/strands/hooks/interrupt.py create mode 100644 src/strands/types/interrupt.py create mode 100644 tests/strands/hooks/test_interrupt.py create mode 100644 tests/strands/hooks/test_registry.py create mode 100644 tests_integ/test_interrupt.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..2a64d4750 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -39,6 +39,7 @@ BeforeInvocationEvent, HookProvider, HookRegistry, + Interrupt, MessageAddedEvent, ) from ..models.bedrock import BedrockModel @@ -54,6 +55,7 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException +from ..types.interrupt import InterruptContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -349,6 +351,9 @@ def __init__( self.hooks.add_hook(hook) self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + # Map of active interrupt instances + self._interrupts: dict[tuple[str, str], Interrupt] = {} + @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -567,6 +572,8 @@ async def stream_async( yield event["data"] ``` """ + self._resume_interrupt(prompt) + callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) @@ -596,6 +603,56 @@ async def stream_async( self._end_agent_trace_span(error=e) raise + def _resume_interrupt(self, prompt: AgentInput) -> None: + """Configure the agent interrupt state if resuming from an interrupt event. + + Args: + messages: Agent's message history. + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If interrupts are detected but user did not provide responses. + ValueError: If any interrupts are missing corresponding responses. + """ + # Currently, users can only interrupt tool calls. Thus, to determine if the agent was interrupted in the + # previous invocation, we look for tool results with interrupt context in the messages array. + + message = self.messages[-1] if self.messages else None + if not message or message["role"] != "user": + return + + tool_results = [ + content["toolResult"] + for content in message["content"] + if "toolResult" in content and content["toolResult"]["status"] == "error" + ] + reasons = [ + tool_result["content"][0]["json"]["interrupt"] + for tool_result in tool_results + if "json" in tool_result["content"][0] and "interrupt" in tool_result["content"][0]["json"] + ] + if not reasons: + return + + if not isinstance(prompt, list): + raise TypeError( + f"prompt_type=<{type(prompt)}> | must resume from interrupt with list of interruptResponse's" + ) + + responses = [ + cast(InterruptContent, content)["interruptResponse"] + for content in prompt + if isinstance(content, dict) and "interruptResponse" in content + ] + + reasons_map = {(reason["name"], reason["event_name"]): reason for reason in reasons} + responses_map = {(response["name"], response["event_name"]): response for response in responses} + missing_keys = reasons_map.keys() - responses_map.keys() + if missing_keys: + raise ValueError(f"interrupts=<{list(missing_keys)}> | missing responses for interrupts") + + self._interrupts = {key: Interrupt(**{**reasons_map[key], **responses_map[key]}) for key in responses_map} + async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -671,6 +728,10 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A yield event def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + if self._interrupts: + # Do not add interrupt responses to the messages as these are not to be processed by the model + return [] + messages: Messages | None = None if prompt is not None: if isinstance(prompt, str): diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f3758c8d2..9091c0e54 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Any +from ..hooks import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message from ..types.streaming import StopReason @@ -20,12 +21,14 @@ class AgentResult: message: The last message generated by the agent. metrics: Performance metrics collected during processing. state: Additional state information from the event loop. + interrupts: List of interrupts if raised by user. """ stop_reason: StopReason message: Message metrics: EventLoopMetrics state: Any + interrupts: list[Interrupt] | None = None def __str__(self) -> str: """Get the agent's last message as a string. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index d6367e9d9..e523a2c2e 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -27,6 +27,7 @@ ModelStopReason, StartEvent, StartEventLoopEvent, + ToolInterruptEvent, ToolResultMessageEvent, TypedEvent, ) @@ -106,13 +107,18 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span - model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + if agent._interrupts: + stop_reason: StopReason = "tool_use" + message = agent.messages[-2] - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + else: + model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) try: if stop_reason == "max_tokens": @@ -371,25 +377,31 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] + + # Filter to only the interrupted tools when resuming from interrupt + if agent._interrupts: + tool_names = {name for name, _ in agent._interrupts.keys()} + tool_uses = [tool_use for tool_use in tool_uses if tool_use["name"] in tool_names] + if not tool_uses: yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return + tool_interrupts = [] tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) async for tool_event in tool_events: + if isinstance(tool_event, ToolInterruptEvent): + tool_interrupts.append(tool_event["tool_interrupt_event"]["interrupt"]) + yield tool_event # Store parent cycle ID for the next cycle invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], - } + tool_result_message = _convert_tool_results_to_message(agent, tool_results) - agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) yield ToolResultMessageEvent(message=tool_result_message) @@ -397,11 +409,52 @@ async def _handle_tool_execution( tracer = get_tracer() tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) + if tool_interrupts: + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "interrupt", message, agent.event_loop_metrics, invocation_state["request_state"], tool_interrupts + ) + return + if invocation_state["request_state"].get("stop_event_loop", False): agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return + agent._interrupts = {} + events = recurse_event_loop(agent=agent, invocation_state=invocation_state) async for event in events: yield event + + +def _convert_tool_results_to_message(agent: "Agent", results: list[ToolResult]) -> Message: + """Convert tool results to a message. + + For normal execution, we create a new user message with the tool results and append it to the agent's message + history. When resuming from an interrupt, we instead extend the existing results message in history with the + resumed results. + + Args: + agent: The agent instance containing interrupt state and message history. + results: List of tool results to convert or extend into a message. + + Returns: + Tool results message. + """ + if not agent._interrupts: + message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in results], + } + agent.messages.append(message) + return message + + message = agent.messages[-1] + + results_map = {result["toolUseId"]: result for result in results} + for content in message["content"]: + tool_use_id = content["toolResult"]["toolUseId"] + content["toolResult"] = results_map.get(tool_use_id, content["toolResult"]) + + return message diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 30163f207..e6c80182b 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -39,6 +39,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: BeforeToolCallEvent, MessageAddedEvent, ) +from .interrupt import Interrupt, InterruptEvent, InterruptException from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry __all__ = [ @@ -56,4 +57,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: "HookRegistry", "HookEvent", "BaseHookEvent", + "Interrupt", + "InterruptEvent", + "InterruptException", ] diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8f611e4e2..0af1400cb 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -9,6 +9,7 @@ from ..types.content import Message from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse +from .interrupt import InterruptEvent from .registry import HookEvent @@ -84,7 +85,7 @@ class MessageAddedEvent(HookEvent): @dataclass -class BeforeToolCallEvent(HookEvent): +class BeforeToolCallEvent(HookEvent, InterruptEvent): """Event triggered before a tool is invoked. This event is fired just before the agent executes a tool, allowing hook @@ -108,7 +109,7 @@ class BeforeToolCallEvent(HookEvent): cancel_tool: bool | str = False def _can_write(self, name: str) -> bool: - return name in ["cancel_tool", "selected_tool", "tool_use"] + return name in ["interrupt", "cancel_tool", "selected_tool", "tool_use"] @dataclass diff --git a/src/strands/hooks/interrupt.py b/src/strands/hooks/interrupt.py new file mode 100644 index 000000000..bc6bcf8db --- /dev/null +++ b/src/strands/hooks/interrupt.py @@ -0,0 +1,104 @@ +"""Human-in-the-loop interrupt system for agent workflows.""" + +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +from ..types.tools import ToolResultContent + +if TYPE_CHECKING: + from ..agent import Agent + + +@dataclass +class Interrupt: + """Represents an interrupt that can pause agent execution for human-in-the-loop workflows. + + Attributes: + name: Unique identifier for the interrupt. + event_name: Name of the hook event under which the interrupt was triggered. + reasons: User provided reasons for raising the interrupt. + response: Human response provided when resuming the agent after an interrupt. + activated: Whether the interrupt is currently active. + """ + + name: str + event_name: str + reasons: list[Any] + response: Any = None + activated: bool = False + + def __call__(self, reason: Any) -> Any: + """Trigger the interrupt with a reason. + + Args: + reason: User provided reason for the interrupt. + + Returns: + The response from a human user when resuming from an interrupt state. + + Raises: + InterruptException: If human input is required. + """ + if self.response: + self.activated = False + return self.response + + self.reasons.append(reason) + self.activated = True + raise InterruptException(self) + + def to_tool_result_content(self) -> list[ToolResultContent]: + """Convert the interrupt to tool result content if there are reasons. + + Returns: + Tool result content. + """ + if self.reasons: + return [ + {"json": {"interrupt": {"name": self.name, "event_name": self.event_name, "reasons": self.reasons}}}, + ] + + return [] + + @classmethod + def from_agent(cls, name: str, event_name: str, agent: "Agent") -> "Interrupt": + """Initialize an interrupt from agent state. + + Creates an interrupt instance from stored agent state, which will be + populated with the human response when resuming. + + Args: + name: Unique identifier for the interrupt. + event_name: Name of the hook event under which the interrupt was triggered. + agent: The agent instance containing interrupt state. + + Returns: + An Interrupt instance initialized from agent state. + """ + interrupt = agent._interrupts.get((name, event_name)) + params = asdict(interrupt) if interrupt else {"name": name, "event_name": event_name, "reasons": []} + + return cls(**params) + + +class InterruptException(Exception): + """Exception raised when human input is required.""" + + def __init__(self, interrupt: Interrupt) -> None: + """Initialize the exception with an interrupt instance. + + Args: + interrupt: The interrupt that triggered this exception. + """ + self.interrupt = interrupt + + +@dataclass +class InterruptEvent: + """Interface that adds interrupt support to hook events. + + Attributes: + interrupt: The interrupt instance associated with this event. + """ + + interrupt: Interrupt diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index b8e7f82ab..b213670a0 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -10,6 +10,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from .interrupt import InterruptException + if TYPE_CHECKING: from ..agent import Agent @@ -205,7 +207,12 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: ``` """ for callback in self.get_callbacks_for(event): - callback(event) + try: + callback(event) + except InterruptException: + # All callbacks are allowed to finish executing during an interrupt. The state of the interrupt is + # stored as an instance variable on the hook event. + pass return event diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index f78861f81..047efbdeb 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -11,10 +11,10 @@ from opentelemetry import trace as trace_api -from ...hooks import AfterToolCallEvent, BeforeToolCallEvent +from ...hooks import AfterToolCallEvent, BeforeToolCallEvent, Interrupt from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer -from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent +from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse @@ -72,15 +72,38 @@ async def _stream( } ) + interrupt = Interrupt.from_agent(tool_name, BeforeToolCallEvent.__name__, agent) before_event = agent.hooks.invoke_callbacks( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, tool_use=tool_use, invocation_state=invocation_state, + interrupt=interrupt, ) ) + if interrupt.activated: + yield ToolInterruptEvent(interrupt) + + interrupt_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": interrupt.to_tool_result_content(), + } + after_event = agent.hooks.invoke_callbacks( + AfterToolCallEvent( + agent=agent, + tool_use=tool_use, + invocation_state=invocation_state, + selected_tool=None, + result=interrupt_result, + ) + ) + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return + if before_event.cancel_tool: cancel_message = ( before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" @@ -90,7 +113,7 @@ async def _stream( cancel_result: ToolResult = { "toolUseId": str(tool_use.get("toolUseId")), "status": "error", - "content": [{"text": cancel_message}], + "content": [*interrupt.to_tool_result_content(), {"text": cancel_message}], } after_event = agent.hooks.invoke_callbacks( AfterToolCallEvent( @@ -128,7 +151,7 @@ async def _stream( result: ToolResult = { "toolUseId": str(tool_use.get("toolUseId")), "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], + "content": [*interrupt.to_tool_result_content(), {"text": f"Unknown tool: {tool_name}"}], } after_event = agent.hooks.invoke_callbacks( AfterToolCallEvent( @@ -160,6 +183,7 @@ async def _stream( yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) + result["content"] = [*interrupt.to_tool_result_content(), *result["content"]] after_event = agent.hooks.invoke_callbacks( AfterToolCallEvent( @@ -179,7 +203,7 @@ async def _stream( error_result: ToolResult = { "toolUseId": str(tool_use.get("toolUseId")), "status": "error", - "content": [{"text": f"Error: {str(e)}"}], + "content": [*interrupt.to_tool_result_content(), {"text": f"Error: {str(e)}"}], } after_event = agent.hooks.invoke_callbacks( AfterToolCallEvent( diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 60e5c7fa7..f60de9365 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -5,7 +5,7 @@ from typing_extensions import override from ...telemetry.metrics import Trace -from ...types._events import TypedEvent +from ...types._events import ToolInterruptEvent, TypedEvent from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor @@ -39,9 +39,17 @@ async def _execute( Yields: Events from the tool execution stream. """ + interrupted = False + for tool_use in tool_uses: events = ToolExecutor._stream_with_trace( agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state ) async for event in events: + if isinstance(event, ToolInterruptEvent): + interrupted = True + yield event + + if interrupted: + break diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index e20bf658a..d0f5ce7f9 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,10 +5,11 @@ agent lifecycle. """ -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Optional, cast from typing_extensions import override +from ..hooks import Interrupt from ..telemetry import EventLoopMetrics from .citations import Citation from .content import Message @@ -220,6 +221,7 @@ def __init__( message: Message, metrics: "EventLoopMetrics", request_state: Any, + interrupts: Optional[list[Interrupt]] = None, ) -> None: """Initialize with the final execution results. @@ -228,8 +230,9 @@ def __init__( message: Final message from the model metrics: Execution metrics and performance data request_state: Final state of the agent execution + interrupts: Interrupts raised by user during agent execution. """ - super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts)}) @property @override @@ -313,12 +316,25 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancelled_event")).get("tool_use")).get("toolUseId")) + return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) @property def message(self) -> str: """The tool cancellation message.""" - return cast(str, self["message"]) + return cast(str, self["tool_cancel_event"]["message"]) + + +class ToolInterruptEvent(TypedEvent): + """Event emitted when a tool is interrupted.""" + + def __init__(self, interrupt: Interrupt) -> None: + """Set interrupt in the event payload.""" + super().__init__({"tool_interrupt_event": {"interrupt": interrupt}}) + + @property + def interrupt(self) -> Interrupt: + """The interrupt instance.""" + return cast(Interrupt, self["tool_interrupt_event"]["interrupt"]) class ModelMessageEvent(TypedEvent): diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index 151c88f89..6166f9cb3 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -6,5 +6,6 @@ from typing import TypeAlias from .content import ContentBlock, Messages +from .interrupt import InterruptContent -AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None +AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptContent] | Messages | None diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2c240972b..486aa3c2b 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -37,6 +37,7 @@ class Metrics(TypedDict): "content_filtered", "end_turn", "guardrail_intervened", + "interrupt", "max_tokens", "stop_sequence", "tool_use", @@ -46,6 +47,7 @@ class Metrics(TypedDict): - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened +- "interrupt": Agent was interrupted for human input - "max_tokens": Maximum token limit reached - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py new file mode 100644 index 000000000..0ae51da54 --- /dev/null +++ b/src/strands/types/interrupt.py @@ -0,0 +1,27 @@ +"""Interrupt related type definitions for human-in-the-loop workflows.""" + +from typing import Any, TypedDict + + +class InterruptResponse(TypedDict): + """User response to an interrupt. + + Attributes: + name: Unique identifier for the interrupt. + event_name: Name of the hook event under which the interrupt was triggered. + response: User response to the interrupt. + """ + + name: str + event_name: str + response: Any + + +class InterruptContent(TypedDict): + """Content block containing an interrupt response for human-in-the-loop workflows. + + Attributes: + interruptResponse: User response to an interrupt event. + """ + + interruptResponse: InterruptResponse diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 8bbd89c17..30b54f615 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import ANY, Mock import pytest @@ -67,6 +67,7 @@ def before_tool_event(agent, tool, tool_use, tool_invocation_state): selected_tool=tool, tool_use=tool_use, invocation_state=tool_invocation_state, + interrupt=ANY, ) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..52987a7ea 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1877,3 +1877,157 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text + + +def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): + mock_model.mock_stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "resumed"}}}, + {"contentBlockStop": {}}, + ] + ) + + agent = Agent( + messages=[ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_decorated", + "input": {"random_string": "test input"}, + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "tool_decorated", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + ], + }, + }, + ], + }, + ], + model=mock_model, + tools=[tool_decorated], + ) + + prompt = [ + { + "interruptResponse": { + "name": "tool_decorated", + "event_name": "BeforeToolCallEvent", + "response": "user response", + } + } + ] + agent(prompt) + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [ + { + "json": { + "interrupt": { + "name": "tool_decorated", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + {"text": "test input"}, + ], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + assert not agent._interrupts + + +def test_agent__call__resume_interrupt_invalid_prompt(mock_model): + agent = Agent( + messages=[ + { + "role": "user", + "content": [ + { + "toolResult": { + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "test_interrupt", + "event_name": "test_event", + "reasons": ["test reason"], + } + } + } + ], + } + } + ], + } + ], + model=mock_model, + ) + + with pytest.raises(TypeError, match="must resume from interrupt with list of interruptResponse's"): + agent("invalid") + + +def test_agent__call__resume_interrupt_missing_responses(mock_model): + agent = Agent( + messages=[ + { + "role": "user", + "content": [ + { + "toolResult": { + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "test_interrupt", + "event_name": "test_event", + "reasons": ["test reason"], + } + } + } + ], + } + } + ], + } + ], + model=mock_model, + ) + + with pytest.raises(ValueError, match="missing responses for interrupts"): + agent([]) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 6c5625e0b..b17f7ca06 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -124,7 +124,11 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): assert length == 6 assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, + interrupt=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -170,7 +174,11 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, + interrupt=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -231,7 +239,11 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, + interrupt=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2b71f3502..5a7755b7d 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -9,7 +9,9 @@ from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, + BeforeToolCallEvent, HookRegistry, + Interrupt, MessageAddedEvent, ) from strands.telemetry.metrics import EventLoopMetrics @@ -138,6 +140,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor + mock._interrupts = {} return mock @@ -169,7 +172,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -201,7 +204,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -239,7 +242,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -330,7 +333,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -445,7 +448,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -747,7 +750,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -759,7 +762,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -862,3 +865,168 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert next(events) == MessageAddedEvent( agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} ) + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt(agent, model, tool_stream, agenerator, alist): + def interrupt_callback(event): + event.interrupt("test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator(tool_stream)] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, tru_interrupts = events[-1]["stop"] + exp_stop_reason = "interrupt" + exp_interrupts = [ + Interrupt( + name="tool_for_testing", + event_name="BeforeToolCallEvent", + reasons=["test reason"], + activated=True, + ), + ] + + assert tru_stop_reason == exp_stop_reason and tru_interrupts == exp_interrupts + + tru_result_message = agent.messages[-1] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "tool_for_testing", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + ], + } + } + ], + } + assert tru_result_message == exp_result_message + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_2, agenerator, alist): + agent._interrupts = { + ("tool_for_testing", "BeforeToolCallEvent"): Interrupt( + name="tool_for_testing", + event_name="BeforeToolCallEvent", + reasons=["test reason"], + response="test response", + ), + } + + agent.messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_for_testing", + "input": {"random_string": "test input"}, + } + }, + { + "toolUse": { + "toolUseId": "t2", + "name": "tool_times_2", + "input": {}, + } + }, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "tool_for_testing", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + ], + }, + }, + { + "toolResult": { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + }, + ], + }, + ] + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator([{"contentBlockStop": {}}])] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, _ = events[-1]["stop"] + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [ + { + "json": { + "interrupt": { + "name": "tool_for_testing", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + {"text": "test input"}, + ], + } + }, + { + "toolResult": { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + assert not agent._interrupts diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index db9cd3783..ca314d26d 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -7,7 +7,7 @@ import importlib import sys -from unittest.mock import Mock +from unittest.mock import ANY, Mock from strands.experimental.hooks import ( AfterModelInvocationEvent, @@ -44,6 +44,7 @@ def test_before_tool_call_event_type_equality(): selected_tool=Mock(), tool_use={"name": "test", "toolUseId": "123", "input": {}}, invocation_state={}, + interrupt=ANY, ) assert isinstance(before_tool_event, BeforeToolInvocationEvent) @@ -100,6 +101,7 @@ def experimental_callback(event: BeforeToolInvocationEvent): selected_tool=Mock(), tool_use={"name": "test", "toolUseId": "123", "input": {}}, invocation_state={}, + interrupt=ANY, ) # Invoke callbacks - should work since alias points to same type diff --git a/tests/strands/hooks/test_interrupt.py b/tests/strands/hooks/test_interrupt.py new file mode 100644 index 000000000..57a025df0 --- /dev/null +++ b/tests/strands/hooks/test_interrupt.py @@ -0,0 +1,81 @@ +import unittest.mock + +import pytest + +from strands.hooks import Interrupt, InterruptException + + +@pytest.fixture +def interrupt(): + return Interrupt( + name="test", + event_name="test_event", + reasons=[], + ) + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupts = {} + return instance + + +def test_interrupt__call__(interrupt): + with pytest.raises(InterruptException) as exception: + interrupt("test reason") + + tru_interrupt = exception.value.interrupt + exp_interrupt = Interrupt( + name="test", + event_name="test_event", + reasons=["test reason"], + activated=True, + ) + assert tru_interrupt == exp_interrupt + + +def test_interrupt__call__with_response(interrupt): + interrupt.activated = True + interrupt.response = "test response" + + tru_response = interrupt("test reason") + exp_response = "test response" + + assert tru_response == exp_response + assert not interrupt.activated + + +@pytest.mark.parametrize( + ("reasons", "exp_content"), + [ + ( + ["test reason"], + [{"json": {"interrupt": {"name": "test", "event_name": "test_event", "reasons": ["test reason"]}}}], + ), + ([], []), + ], +) +def test_interrupt_to_tool_result_content(reasons, exp_content, interrupt): + interrupt.reasons = reasons + + tru_content = interrupt.to_tool_result_content() + assert tru_content == exp_content + + +def test_interrupt_from_agent(agent): + exp_interrupt = Interrupt(name="test", event_name="test_event", reasons=["test reason"], response="test response") + agent._interrupts = {("test", "test_event"): exp_interrupt} + + tru_interrupt = Interrupt.from_agent("test", "test_event", agent) + assert tru_interrupt == exp_interrupt + + +def test_interrupt_from_agent_empty(agent): + tru_interrupt = Interrupt.from_agent("test", "test_event", agent) + exp_interrupt = Interrupt( + name="test", + event_name="test_event", + reasons=[], + ) + assert tru_interrupt == exp_interrupt diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py new file mode 100644 index 000000000..a90e08d7d --- /dev/null +++ b/tests/strands/hooks/test_registry.py @@ -0,0 +1,33 @@ +import unittest.mock + +import pytest + +from strands.hooks import BeforeToolCallEvent, HookRegistry, Interrupt + + +@pytest.fixture +def registry(): + return HookRegistry() + + +def test_hook_registry_invoke_callbacks_interrupt(registry): + interrupt = Interrupt(name="test", event_name="BeforeToolCallEvent", reasons=[]) + event = BeforeToolCallEvent( + agent=unittest.mock.Mock(), + selected_tool=None, + tool_use={"toolUseId": "test", "name": "test_tool", "input": {}}, + invocation_state={}, + interrupt=interrupt, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test reason")) + callback2 = unittest.mock.Mock() + + registry.add_callback(BeforeToolCallEvent, callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + + registry.invoke_callbacks(event) + + callback1.assert_called_once_with(event) + callback2.assert_called_once_with(event) + assert interrupt.activated diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index be90226f6..550217b8c 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -92,6 +92,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry + mock_agent._interrupts = {} return mock_agent diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 2a0a44e10..521e85a71 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -4,10 +4,10 @@ import pytest import strands -from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, Interrupt from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor -from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import ToolUse @@ -36,6 +36,8 @@ async def test_executor_stream_yields_result( executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist ): tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + interrupt = Interrupt("weather_tool", "BeforeToolCallEvent", reasons=[]) + stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) @@ -55,6 +57,7 @@ async def test_executor_stream_yields_result( selected_tool=weather_tool, tool_use=tool_use, invocation_state=invocation_state, + interrupt=interrupt, ), AfterToolCallEvent( agent=agent, @@ -116,27 +119,6 @@ async def test_executor_stream_passes_through_typed_events( assert tru_events[2] == event_3 -@pytest.mark.asyncio -async def test_executor_stream_wraps_stream_events_if_no_result( - executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator -): - tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} - stream = executor._stream(agent, tool_use, tool_results, invocation_state) - - weather_tool.stream = MagicMock() - last_event = ToolStreamEvent(tool_use, "value 1") - # Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent - weather_tool.stream.return_value = agenerator( - [ - last_event, - ] - ) - - tru_events = await alist(stream) - exp_events = [last_event, ToolResultEvent(last_event)] - assert tru_events == exp_events - - @pytest.mark.asyncio async def test_executor_stream_yields_tool_error( executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist @@ -250,3 +232,103 @@ def cancel_callback(event): tru_results = tool_results exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_interrupt(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + + interrupt = Interrupt( + name="weather_tool", + event_name="BeforeToolCallEvent", + reasons=["test reason"], + activated=True, + ) + + def interrupt_callback(event): + event.interrupt("test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolInterruptEvent(interrupt), + ToolResultEvent( + { + "toolUseId": "1", + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "weather_tool", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + ], + } + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_interrupt_resume(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + + interrupt = Interrupt( + name="weather_tool", + event_name="BeforeToolCallEvent", + reasons=["test reason"], + response="test response", + activated=True, + ) + agent._interrupts = {(interrupt.name, interrupt.event_name): interrupt} + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "1", + "status": "success", + "content": [ + { + "json": { + "interrupt": { + "name": "weather_tool", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + {"text": "sunny"}, + ], + } + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index 37e098142..ab49af108 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,8 @@ import pytest +from strands.hooks import BeforeToolCallEvent, Interrupt from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent +from strands.types._events import ToolInterruptEvent, ToolResultEvent @pytest.fixture @@ -29,3 +30,54 @@ async def test_sequential_executor_execute( tru_results = tool_results exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_sequential_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + interrupt = Interrupt( + name="weather_tool", + event_name="BeforeToolCallEvent", + reasons=["test reason"], + activated=True, + ) + + def interrupt_callback(event): + event.interrupt("test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "2", "input": {}}, + ] + + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolInterruptEvent(interrupt), + ToolResultEvent( + { + "toolUseId": "1", + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "weather_tool", + "event_name": "BeforeToolCallEvent", + "reasons": ["test reason"], + }, + }, + }, + ], + } + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results diff --git a/tests_integ/test_interrupt.py b/tests_integ/test_interrupt.py new file mode 100644 index 000000000..f7c8f951e --- /dev/null +++ b/tests_integ/test_interrupt.py @@ -0,0 +1,150 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider, Interrupt + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + response = event.interrupt("need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:00" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool]) + + +@pytest.mark.asyncio +def test_agent_invoke_interrupt(agent): + result = agent("What is the time?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = result.interrupts + exp_interrupts = [ + Interrupt( + name="time_tool", + event_name="BeforeToolCallEvent", + reasons=["need approval"], + activated=True, + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "name": "time_tool", + "event_name": "BeforeToolCallEvent", + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_result_message = json.dumps(result.message) + exp_result_message = "12:00" + assert exp_result_message in tru_result_message + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + { + "json": { + "interrupt": { + "name": "time_tool", + "event_name": "BeforeToolCallEvent", + "reasons": ["need approval"], + }, + }, + }, + {"text": "12:00"}, + ], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message + + +@pytest.mark.asyncio +def test_agent_invoke_interrupt_reject(agent): + result = agent("What is the time?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + responses = [ + { + "interruptResponse": { + "name": "time_tool", + "event_name": "BeforeToolCallEvent", + "response": "REJECT", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "error", + "content": [ + { + "json": { + "interrupt": { + "name": "time_tool", + "event_name": "BeforeToolCallEvent", + "reasons": ["need approval"], + }, + }, + }, + {"text": "tool rejected"}, + ], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message From 4dd5489abeb6d52338bfcf96b768bd41c7dc202e Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 8 Oct 2025 09:11:37 -0400 Subject: [PATCH 2/2] interrupt responses - do not require event_name --- src/strands/agent/agent.py | 17 ++++++++++------- src/strands/types/interrupt.py | 2 -- tests/strands/agent/test_agent.py | 1 - tests_integ/test_interrupt.py | 2 -- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 2a64d4750..ecd82b86b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -645,13 +645,16 @@ def _resume_interrupt(self, prompt: AgentInput) -> None: if isinstance(content, dict) and "interruptResponse" in content ] - reasons_map = {(reason["name"], reason["event_name"]): reason for reason in reasons} - responses_map = {(response["name"], response["event_name"]): response for response in responses} - missing_keys = reasons_map.keys() - responses_map.keys() - if missing_keys: - raise ValueError(f"interrupts=<{list(missing_keys)}> | missing responses for interrupts") - - self._interrupts = {key: Interrupt(**{**reasons_map[key], **responses_map[key]}) for key in responses_map} + reasons_map = {reason["name"]: reason for reason in reasons} + responses_map = {response["name"]: response for response in responses} + missing_names = reasons_map.keys() - responses_map.keys() + if missing_names: + raise ValueError(f"interrupts=<{list(missing_names)}> | missing responses for interrupts") + + self._interrupts = { + (name, reasons_map[name]["event_name"]): Interrupt(**{**reasons_map[name], **responses_map[name]}) + for name in responses_map + } async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py index 0ae51da54..494ae0861 100644 --- a/src/strands/types/interrupt.py +++ b/src/strands/types/interrupt.py @@ -8,12 +8,10 @@ class InterruptResponse(TypedDict): Attributes: name: Unique identifier for the interrupt. - event_name: Name of the hook event under which the interrupt was triggered. response: User response to the interrupt. """ name: str - event_name: str response: Any diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 52987a7ea..474cbf095 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1933,7 +1933,6 @@ def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): { "interruptResponse": { "name": "tool_decorated", - "event_name": "BeforeToolCallEvent", "response": "user response", } } diff --git a/tests_integ/test_interrupt.py b/tests_integ/test_interrupt.py index f7c8f951e..ca26cfc46 100644 --- a/tests_integ/test_interrupt.py +++ b/tests_integ/test_interrupt.py @@ -58,7 +58,6 @@ def test_agent_invoke_interrupt(agent): { "interruptResponse": { "name": "time_tool", - "event_name": "BeforeToolCallEvent", "response": "APPROVE", }, }, @@ -112,7 +111,6 @@ def test_agent_invoke_interrupt_reject(agent): { "interruptResponse": { "name": "time_tool", - "event_name": "BeforeToolCallEvent", "response": "REJECT", }, },