From ce9efae6fc1f8af7647afbafe4adf6ba1418ea7b Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 6 Nov 2025 16:54:13 -0500 Subject: [PATCH] interrupts - swarm - from agent --- .../experimental/hooks/multiagent/events.py | 25 ++- src/strands/interrupt.py | 2 + src/strands/multiagent/base.py | 40 +++-- src/strands/multiagent/graph.py | 4 +- src/strands/multiagent/swarm.py | 162 +++++++++++++----- src/strands/types/_events.py | 43 +++++ src/strands/types/interrupt.py | 17 +- src/strands/types/multiagent.py | 3 +- tests/strands/multiagent/test_swarm.py | 43 +++++ tests_integ/interrupts/multiagent/__init__.py | 0 .../interrupts/multiagent/test_agent.py | 67 ++++++++ .../interrupts/multiagent/test_hook.py | 133 ++++++++++++++ .../interrupts/multiagent/test_session.py | 77 +++++++++ 13 files changed, 553 insertions(+), 63 deletions(-) create mode 100644 tests_integ/interrupts/multiagent/__init__.py create mode 100644 tests_integ/interrupts/multiagent/test_agent.py create mode 100644 tests_integ/interrupts/multiagent/test_hook.py create mode 100644 tests_integ/interrupts/multiagent/test_session.py diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py index 9e54296a4..490365235 100644 --- a/src/strands/experimental/hooks/multiagent/events.py +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -5,10 +5,14 @@ is used—hooks read from the orchestrator directly. """ +import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from typing_extensions import override + from ....hooks import BaseHookEvent +from ....types.interrupt import _Interruptible if TYPE_CHECKING: from ....multiagent.base import MultiAgentBase @@ -28,18 +32,37 @@ class MultiAgentInitializedEvent(BaseHookEvent): @dataclass -class BeforeNodeCallEvent(BaseHookEvent): +class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): """Event triggered before individual node execution starts. Attributes: source: The multi-agent orchestrator instance node_id: ID of the node about to execute invocation_state: Configuration that user passes in + cancel_node: A user defined message that when set, will cancel the node execution. + The message will be placed into the node result with an error status. If set to `True`, Strands will cancel + the node and use a default cancel message. """ source: "MultiAgentBase" node_id: str invocation_state: dict[str, Any] | None = None + cancel_node: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_node"] + + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:before_node_call:{self.node_id}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" @dataclass diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index da89d772b..85997c9be 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -99,6 +99,8 @@ def resume(self, prompt: "AgentInput") -> None: self.interrupts[interrupt_id].response = interrupt_response + self.context["responses"] = contents + def to_dict(self) -> dict[str, Any]: """Serialize to dict for session management.""" return asdict(self) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0a1628530..d3d4fca3b 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -12,6 +12,7 @@ from .._async import run_async from ..agent import AgentResult +from ..interrupt import Interrupt from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput @@ -19,22 +20,26 @@ class Status(Enum): - """Execution status for both graphs and nodes.""" + """Execution status for both graphs and nodes. + + Attributes: + PENDING: Task has not started execution yet. + EXECUTING: Task is currently running. + COMPLETED: Task finished successfully. + FAILED: Task encountered an error and could not complete. + INTERRUPTED: Task was interrupted by user. + """ PENDING = "pending" EXECUTING = "executing" COMPLETED = "completed" FAILED = "failed" + INTERRUPTED = "interrupted" @dataclass class NodeResult: - """Unified result from node execution - handles both Agent and nested MultiAgentBase results. - - The status field represents the semantic outcome of the node's work: - - COMPLETED: The node's task was successfully accomplished - - FAILED: The node's task failed or produced an error - """ + """Unified result from node execution - handles both Agent and nested MultiAgentBase results.""" # Core result data - single AgentResult, nested MultiAgentResult, or Exception result: Union[AgentResult, "MultiAgentResult", Exception] @@ -47,6 +52,7 @@ class NodeResult: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_count: int = 0 + interrupts: list[Interrupt] = field(default_factory=list) def get_agent_results(self) -> list[AgentResult]: """Get all AgentResult objects from this node, flattened if nested.""" @@ -78,6 +84,7 @@ def to_dict(self) -> dict[str, Any]: "accumulated_usage": self.accumulated_usage, "accumulated_metrics": self.accumulated_metrics, "execution_count": self.execution_count, + "interrupts": [interrupt.to_dict() for interrupt in self.interrupts], } @classmethod @@ -100,6 +107,10 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": usage = _parse_usage(data.get("accumulated_usage", {})) metrics = _parse_metrics(data.get("accumulated_metrics", {})) + interrupts = [] + for interrupt_data in data.get("interrupts", []): + interrupts.append(Interrupt(**interrupt_data)) + return cls( result=result, execution_time=int(data.get("execution_time", 0)), @@ -107,17 +118,13 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": accumulated_usage=usage, accumulated_metrics=metrics, execution_count=int(data.get("execution_count", 0)), + interrupts=interrupts, ) @dataclass class MultiAgentResult: - """Result from multi-agent execution with accumulated metrics. - - The status field represents the outcome of the MultiAgentBase execution: - - COMPLETED: The execution was successfully accomplished - - FAILED: The execution failed or produced an error - """ + """Result from multi-agent execution with accumulated metrics.""" status: Status = Status.PENDING results: dict[str, NodeResult] = field(default_factory=lambda: {}) @@ -125,6 +132,7 @@ class MultiAgentResult: accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_count: int = 0 execution_time: int = 0 + interrupts: list[Interrupt] = field(default_factory=list) @classmethod def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": @@ -136,6 +144,10 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": usage = _parse_usage(data.get("accumulated_usage", {})) metrics = _parse_metrics(data.get("accumulated_metrics", {})) + interrupts = [] + for interrupt_data in data.get("interrupts", []): + interrupts.append(Interrupt(**interrupt_data)) + multiagent_result = cls( status=Status(data["status"]), results=results, @@ -143,6 +155,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": accumulated_metrics=metrics, execution_count=int(data.get("execution_count", 0)), execution_time=int(data.get("execution_time", 0)), + interrupts=interrupts, ) return multiagent_result @@ -156,6 +169,7 @@ def to_dict(self) -> dict[str, Any]: "accumulated_metrics": self.accumulated_metrics, "execution_count": self.execution_count, "execution_time": self.execution_time, + "interrupts": [interrupt.to_dict() for interrupt in self.interrupts], } diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 740cbc175..aa5092529 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -964,7 +964,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: if isinstance(self.state.task, str): return [ContentBlock(text=self.state.task)] else: - return self.state.task + return cast(list[ContentBlock], self.state.task) # Combine task with dependency outputs node_input = [] @@ -975,7 +975,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: else: # Add task content blocks with a prefix node_input.append(ContentBlock(text="Original Task:")) - node_input.extend(self.state.task) + node_input.extend(cast(list[ContentBlock], self.state.task)) # Add dependency outputs node_input.append(ContentBlock(text="\nInputs from previous nodes:")) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 1c447f571..1839a7303 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -33,11 +33,14 @@ MultiAgentInitializedEvent, ) from ..hooks import HookProvider, HookRegistry +from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool from ..types._events import ( MultiAgentHandoffEvent, + MultiAgentNodeCancelEvent, + MultiAgentNodeInterruptEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -59,6 +62,7 @@ class SwarmNode: node_id: str executor: Agent + swarm: Optional["Swarm"] = None _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) @@ -87,7 +91,17 @@ def __repr__(self) -> str: return f"SwarmNode(node_id='{self.node_id}')" def reset_executor_state(self) -> None: - """Reset SwarmNode executor state to initial state when swarm was created.""" + """Reset SwarmNode executor state to initial state when swarm was created. + + . + """ + if self.swarm and self.swarm._interrupt_state.activated: + context = self.swarm._interrupt_state.context[self.node_id] + self.executor.messages = context["messages"] + self.executor.state = AgentState(context["state"]) + self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"]) + return + self.executor.messages = copy.deepcopy(self._initial_messages) self.executor.state = AgentState(self._initial_state.get()) @@ -256,11 +270,14 @@ def __init__( self.shared_context = SharedContext() self.nodes: dict[str, SwarmNode] = {} + self.state = SwarmState( current_node=None, # Placeholder, will be set properly task="", completion_status=Status.PENDING, ) + self._interrupt_state = _InterruptState() + self.tracer = get_tracer() self.session_manager = session_manager @@ -335,6 +352,8 @@ async def stream_async( - multi_agent_node_stop: When a node stops execution - result: Final swarm result """ + self._interrupt_state.resume(task) + if invocation_state is None: invocation_state = {} @@ -342,7 +361,10 @@ async def stream_async( logger.debug("starting swarm execution") - if not self._resume_from_session: + if self._resume_from_session or self._interrupt_state.activated: + self.state.completion_status = Status.EXECUTING + self.state.start_time = time.time() + else: # Initialize swarm state with configuration initial_node = self._initial_node() @@ -352,12 +374,11 @@ async def stream_async( completion_status=Status.EXECUTING, shared_context=self.shared_context, ) - else: - self.state.completion_status = Status.EXECUTING - self.state.start_time = time.time() span = self.tracer.start_multiagent_span(task, "swarm") with trace_api.use_span(span, end_on_exit=True): + interrupts = [] + try: current_node = cast(SwarmNode, self.state.current_node) logger.debug("current_node=<%s> | starting swarm execution with node", current_node.node_id) @@ -369,6 +390,9 @@ async def stream_async( ) async for event in self._execute_swarm(invocation_state): + if isinstance(event, MultiAgentNodeInterruptEvent): + interrupts = event.interrupts + yield event.as_dict() except Exception: @@ -381,7 +405,7 @@ async def stream_async( self._resume_from_session = False # Yield final result after execution_time is set - result = self._build_result() + result = self._build_result(interrupts) yield MultiAgentResultEvent(result=result).as_dict() async def _stream_with_timeout( @@ -445,7 +469,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None: if node_id in self.nodes: raise ValueError(f"Node ID '{node_id}' is not unique. Each agent must have a unique name.") - self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node) + self.nodes[node_id] = SwarmNode(node_id, node, swarm=self) # Validate entry point if specified if self.entry_point is not None: @@ -645,6 +669,31 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text + def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + """Activate the interrupt state. + + Args: + node: The interrupted node. + interrupts: The interrupts raised by the user. + + Returns: + MultiAgentNodeInterruptEvent + """ + logger.debug("node=<%s> | node interrupted", node.node_id) + self.state.completion_status = Status.INTERRUPTED + + self._interrupt_state.context[node.node_id] = { + "activated": node.executor._interrupt_state.activated, + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } + + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) + self._interrupt_state.activate() + + return MultiAgentNodeInterruptEvent(node.node_id, interrupts) + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute swarm and yield TypedEvent objects.""" try: @@ -681,9 +730,24 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato # TODO: Implement cancellation token to stop _execute_node from continuing try: - await self.hooks.invoke_callbacks_async( + before_event, interrupts = await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, current_node.node_id, invocation_state) ) + if interrupts: + yield self._activate_interrupt(current_node, interrupts) + break + + if before_event.cancel_node: + cancel_message = ( + before_event.cancel_node + if isinstance(before_event.cancel_node, str) + else "node cancelled by user" + ) + logger.debug("reason=<%s> | cancelling execution", cancel_message) + yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message) + self.state.completion_status = Status.FAILED + break + node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -692,6 +756,14 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato async for event in node_stream: yield event + stop_event = cast(MultiAgentNodeStopEvent, event) + node_result = stop_event["node_result"] + if node_result.status == Status.INTERRUPTED: + yield self._activate_interrupt(current_node, node_result.interrupts) + break + + self._interrupt_state.deactivate() + self.state.node_history.append(current_node) await self.hooks.invoke_callbacks_async( AfterNodeCallEvent(self, current_node.node_id, invocation_state) @@ -753,16 +825,20 @@ async def _execute_node( yield start_event try: - # Prepare context for node - context_text = self._build_node_input(node) - node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] + if self._interrupt_state.activated and self._interrupt_state.context[node_name]["activated"]: + node_input = self._interrupt_state.context["responses"] + + else: + # Prepare context for node + context_text = self._build_node_input(node) + node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] - # Clear handoff message after it's been included in context - self.state.handoff_message = None + # Clear handoff message after it's been included in context + self.state.handoff_message = None - if not isinstance(task, str): - # Include additional ContentBlocks in node input - node_input = node_input + task + if not isinstance(task, str): + # Include additional ContentBlocks in node input + node_input = node_input + cast(list[ContentBlock], task) # Execute node with streaming node.reset_executor_state() @@ -780,13 +856,8 @@ async def _execute_node( if result is None: raise ValueError(f"Node '{node_name}' did not produce a result event") - if result.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in swarms") - execution_time = round((time.time() - start_time) * 1000) + status = Status.INTERRUPTED if result.stop_reason == "interrupt" else Status.COMPLETED # Create NodeResult with extracted metrics result_metrics = getattr(result, "metrics", None) @@ -796,10 +867,11 @@ async def _execute_node( node_result = NodeResult( result=result, execution_time=execution_time, - status=Status.COMPLETED, + status=status, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=1, + interrupts=result.interrupts or [], ) # Store result in state @@ -848,7 +920,7 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None: self.state.accumulated_usage["totalTokens"] += node_result.accumulated_usage.get("totalTokens", 0) self.state.accumulated_metrics["latencyMs"] += node_result.accumulated_metrics.get("latencyMs", 0) - def _build_result(self) -> SwarmResult: + def _build_result(self, interrupts: list[Interrupt]) -> SwarmResult: """Build swarm result from current state.""" return SwarmResult( status=self.state.completion_status, @@ -858,15 +930,18 @@ def _build_result(self) -> SwarmResult: execution_count=len(self.state.node_history), execution_time=self.state.execution_time, node_history=self.state.node_history, + interrupts=interrupts, ) def serialize_state(self) -> dict[str, Any]: """Serialize the current swarm state to a dictionary.""" status_str = self.state.completion_status.value - if self.state.handoff_node: - next_nodes = [self.state.handoff_node.node_id] - elif self.state.completion_status == Status.EXECUTING and self.state.current_node: + if self.state.completion_status == Status.EXECUTING and self.state.current_node: + next_nodes = [self.state.current_node.node_id] + elif self.state.completion_status == Status.INTERRUPTED and self.state.current_node: next_nodes = [self.state.current_node.node_id] + elif self.state.handoff_node: + next_nodes = [self.state.handoff_node.node_id] else: next_nodes = [] @@ -880,8 +955,12 @@ def serialize_state(self) -> dict[str, Any]: "current_task": self.state.task, "context": { "shared_context": getattr(self.state.shared_context, "context", {}) or {}, + "handoff_node": self.state.handoff_node.node_id if self.state.handoff_node else None, "handoff_message": self.state.handoff_message, }, + "_internal_state": { + "interrupt_state": self._interrupt_state.to_dict(), + }, } def deserialize_state(self, payload: dict[str, Any]) -> None: @@ -897,19 +976,23 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. """ - if not payload.get("next_nodes_to_execute"): - for node in self.nodes.values(): - node.reset_executor_state() - self.state = SwarmState( - current_node=SwarmNode("", Agent()), - task="", - completion_status=Status.PENDING, - ) - self._resume_from_session = False - return - else: + if "_internal_state" in payload: + internal_state = payload["_internal_state"] + self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + + self._resume_from_session = "next_nodes_to_execute" in payload + if self._resume_from_session: self._from_dict(payload) - self._resume_from_session = True + return + + for node in self.nodes.values(): + node.reset_executor_state() + + self.state = SwarmState( + current_node=SwarmNode("", Agent(), swarm=self), + task="", + completion_status=Status.PENDING, + ) def _from_dict(self, payload: dict[str, Any]) -> None: self.state.completion_status = Status(payload["status"]) @@ -917,6 +1000,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None: context = payload["context"] or {} self.shared_context.context = context.get("shared_context") or {} self.state.handoff_message = context.get("handoff_message") + self.state.handoff_node = self.nodes[context["handoff_node"]] if context.get("handoff_node") else None self.state.node_history = [self.nodes[nid] for nid in (payload.get("node_history") or []) if nid in self.nodes] diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index afce36f2b..459ddd460 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -524,3 +524,46 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None: "event": agent_event, # Nest agent event to avoid field conflicts } ) + + +class MultiAgentNodeInterruptEvent(TypedEvent): + """Event emitted when a node is interrupted.""" + + def __init__(self, node_id: str, interrupts: list[Interrupt]) -> None: + """Set interrupt in the event payload. + + Args: + node_id: Unique identifier for the node generating the event. + interrupts: Interrupts raised by user. + """ + super().__init__( + { + "type": "multiagent_node_interrupt", + "node_id": node_id, + "interrupts": interrupts, + } + ) + + @property + def interrupts(self) -> list[Interrupt]: + """The interrupt instances.""" + return cast(list[Interrupt], self["interrupts"]) + + +class MultiAgentNodeCancelEvent(TypedEvent): + """Event emitted when a user cancels node execution from their MultiAgentBeforeNodeCallEvent hook.""" + + def __init__(self, node_id: str, message: str) -> None: + """Initialize with tool streaming data. + + Args: + node_id: Unique identifier for the node. + message: The node cancellation message. + """ + super().__init__( + { + "type": "multiagent_node_cancel", + "node_id": node_id, + "message": message, + } + ) diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py index 001ce6993..59c46e807 100644 --- a/src/strands/types/interrupt.py +++ b/src/strands/types/interrupt.py @@ -71,19 +71,14 @@ def approve(self, event: BeforeToolCallEvent) -> None: - Interrupts are session managed in-between return and user response. """ -from typing import TYPE_CHECKING, Any, Protocol, TypedDict +from typing import Any, Protocol, TypedDict from ..interrupt import Interrupt, InterruptException -if TYPE_CHECKING: - from ..agent import Agent - class _Interruptible(Protocol): """Interface that adds interrupt support to hook events and tools.""" - agent: "Agent" - def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any: """Trigger the interrupt with a reason. @@ -97,9 +92,17 @@ def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any: Raises: InterruptException: If human input is required. + RuntimeError: If agent instance attribute not set. """ + for attr_name in ["agent", "source"]: + if hasattr(self, attr_name): + agent = getattr(self, attr_name) + break + else: + raise RuntimeError("agent instance attribute not set") + id = self._interrupt_id(name) - state = self.agent._interrupt_state + state = agent._interrupt_state interrupt_ = state.interrupts.setdefault(id, Interrupt(id, name, reason, response)) if interrupt_.response: diff --git a/src/strands/types/multiagent.py b/src/strands/types/multiagent.py index d9487dbd2..a8fcd4844 100644 --- a/src/strands/types/multiagent.py +++ b/src/strands/types/multiagent.py @@ -3,5 +3,6 @@ from typing import TypeAlias from .content import ContentBlock +from .interrupt import InterruptResponseContent -MultiAgentInput: TypeAlias = str | list[ContentBlock] +MultiAgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 008b2954d..6b6cc2d04 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -7,6 +7,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState from strands.hooks.registry import HookRegistry +from strands.interrupt import _InterruptState from strands.multiagent.base import Status from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState from strands.session.file_session_manager import FileSessionManager @@ -22,6 +23,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent.messages = [] agent.state = AgentState() # Add state attribute + agent._interrupt_state = _InterruptState() # Add interrupt state agent.tool_registry = Mock() agent.tool_registry.registry = {} agent.tool_registry.process_tools = Mock() @@ -1176,3 +1178,44 @@ async def handoff_stream(*args, **kwargs): tru_node_order = [node.node_id for node in result.node_history] exp_node_order = ["first", "second"] assert tru_node_order == exp_node_order + + +def test_swarm_interrupt_on_before_node_call_event_and_resume(): + """Test interrupt triggered in BeforeNodeCallEvent hook followed by resume.""" + from strands.experimental.hooks.multiagent import BeforeNodeCallEvent + from strands.hooks import HookProvider + + # Create a hook that interrupts + class InterruptHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + if event.node_id == "test_agent": + return event.interrupt("approval_needed", reason="need approval") + + # Create agent and swarm with hook + agent = create_mock_agent("test_agent", "Task completed") + hook = InterruptHook() + swarm = Swarm([agent], hooks=[hook]) + + # First call should result in interrupt + result = swarm("Test task requiring approval") + + # Verify interrupt occurred + assert result.status == Status.INTERRUPTED + assert len(result.interrupts) == 1 + assert result.interrupts[0].name == "approval_needed" + assert result.interrupts[0].reason == "need approval" + + # Resume with interrupt response + interrupt_id = result.interrupts[0].id + responses = [{"interruptResponse": {"interruptId": interrupt_id, "response": "approved"}}] + + # Second call should resume and complete + result = swarm(responses) + + # Verify completion + assert result.status == Status.COMPLETED + assert len(result.node_history) == 1 + assert result.node_history[0].node_id == "test_agent" diff --git a/tests_integ/interrupts/multiagent/__init__.py b/tests_integ/interrupts/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_agent.py new file mode 100644 index 000000000..36fcfef27 --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_agent.py @@ -0,0 +1,67 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.interrupt import Interrupt +from strands.multiagent import Swarm +from strands.multiagent.base import Status +from strands.types.tools import ToolContext + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("test_interrupt", reason="need weather") + return response + + return func + + +@pytest.fixture +def swarm(weather_tool): + weather_agent = Agent(name="weather", tools=[weather_tool]) + + return Swarm([weather_agent]) + + +def test_swarm_interrupt_agent(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 1 + weather_result = multiagent_result.results["weather"] + + weather_message = json.dumps(weather_result.result.message).lower() + assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py new file mode 100644 index 000000000..be7682082 --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -0,0 +1,133 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider +from strands.interrupt import Interrupt +from strands.multiagent import Swarm +from strands.multiagent.base import Status + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + if event.node_id == "info": + return + + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_node = "node rejected" + + return Hook() + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def swarm(interrupt_hook, weather_tool): + info_agent = Agent(name="info") + weather_agent = Agent(name="weather", tools=[weather_tool]) + + return Swarm([info_agent, weather_agent], hooks=[interrupt_hook]) + + +def test_swarm_interrupt(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 2 + weather_result = multiagent_result.results["weather"] + + weather_message = json.dumps(weather_result.result.message).lower() + assert "sunny" in weather_message + + +@pytest.mark.asyncio +async def test_swarm_interrupt_reject(swarm): + multiagent_result = swarm("What is the weather?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "REJECT", + }, + }, + ] + tru_cancel_id = None + async for event in swarm.stream_async(responses): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_id = event["node_id"] + + multiagent_result = event["result"] + + exp_cancel_id = "weather" + assert tru_cancel_id == exp_cancel_id + + tru_status = multiagent_result.status + exp_status = Status.FAILED + assert tru_status == exp_status + + assert len(multiagent_result.node_history) == 1 + tru_node_id = multiagent_result.node_history[0].node_id + exp_node_id = "info" + assert tru_node_id == exp_node_id diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py new file mode 100644 index 000000000..d6e8cdbf8 --- /dev/null +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -0,0 +1,77 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.interrupt import Interrupt +from strands.multiagent import Swarm +from strands.multiagent.base import Status +from strands.session import FileSessionManager +from strands.types.tools import ToolContext + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("test_interrupt", reason="need weather") + return response + + return func + + +@pytest.fixture +def swarm(weather_tool): + weather_agent = Agent(name="weather", tools=[weather_tool]) + return Swarm([weather_agent]) + + +def test_swarm_interrupt_session(weather_tool, tmpdir): + weather_agent = Agent(name="weather", tools=[weather_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + swarm = Swarm([weather_agent, summarizer_agent], session_manager=session_manager) + + multiagent_result = swarm("Can you check the weather and then summarize the results?") + + tru_status = multiagent_result.status + exp_status = Status.INTERRUPTED + assert tru_status == exp_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + weather_agent = Agent(name="weather", tools=[weather_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + swarm = Swarm([weather_agent, summarizer_agent], session_manager=session_manager) + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "sunny", + }, + }, + ] + multiagent_result = swarm(responses) + + tru_status = multiagent_result.status + exp_status = Status.COMPLETED + assert tru_status == exp_status + + assert len(multiagent_result.results) == 2 + summarizer_result = multiagent_result.results["summarizer"] + + summarizer_message = json.dumps(summarizer_result.result.message).lower() + assert "sunny" in summarizer_message