diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py index 9e54296a4..87066dc81 100644 --- a/src/strands/experimental/hooks/multiagent/events.py +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -35,11 +35,18 @@ class BeforeNodeCallEvent(BaseHookEvent): source: The multi-agent orchestrator instance node_id: ID of the node about to execute invocation_state: Configuration that user passes in + cancel_node: A user-defined message that, when set, cancels the node execution with status FAILED. + The message will be emitted as a MultiAgentNodeCancelEvent. If set to `True`, Strands will cancel the + node using a default cancel message. """ source: "MultiAgentBase" node_id: str invocation_state: dict[str, Any] | None = None + cancel_node: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_node"] @dataclass diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9f28876bf..e492ed621 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -38,6 +38,7 @@ from ..telemetry import get_tracer from ..types._events import ( MultiAgentHandoffEvent, + MultiAgentNodeCancelEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -776,8 +777,6 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" - await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state)) - # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) @@ -793,8 +792,20 @@ async def _execute_node(self, node: GraphNode, 
invocation_state: dict[str, Any]) ) yield start_event + before_event, _ = await self.hooks.invoke_callbacks_async( + BeforeNodeCallEvent(self, node.node_id, invocation_state) + ) + start_time = time.time() try: + if before_event.cancel_node: + cancel_message = ( + before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" + ) + logger.debug("reason=<%s> | cancelling execution", cancel_message) + yield MultiAgentNodeCancelEvent(node.node_id, cancel_message) + raise RuntimeError(cancel_message) + # Build node input from satisfied dependencies node_input = self._build_node_input(node) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 3913cd837..4571ae5e0 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -38,6 +38,7 @@ from ..tools.decorator import tool from ..types._events import ( MultiAgentHandoffEvent, + MultiAgentNodeCancelEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -678,11 +679,23 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato len(self.state.node_history) + 1, ) + before_event, _ = await self.hooks.invoke_callbacks_async( + BeforeNodeCallEvent(self, current_node.node_id, invocation_state) + ) + # TODO: Implement cancellation token to stop _execute_node from continuing try: - await self.hooks.invoke_callbacks_async( - BeforeNodeCallEvent(self, current_node.node_id, invocation_state) - ) + if before_event.cancel_node: + cancel_message = ( + before_event.cancel_node + if isinstance(before_event.cancel_node, str) + else "node cancelled by user" + ) + logger.debug("reason=<%s> | cancelling execution", cancel_message) + yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message) + self.state.completion_status = Status.FAILED + break + node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ 
-692,40 +705,42 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato yield event self.state.node_history.append(current_node) + + except Exception: + logger.exception("node=<%s> | node execution failed", current_node.node_id) + self.state.completion_status = Status.FAILED + break + + finally: await self.hooks.invoke_callbacks_async( AfterNodeCallEvent(self, current_node.node_id, invocation_state) ) - logger.debug("node=<%s> | node execution completed", current_node.node_id) - - # Check if handoff requested during execution - if self.state.handoff_node: - previous_node = current_node - current_node = self.state.handoff_node + logger.debug("node=<%s> | node execution completed", current_node.node_id) - self.state.handoff_node = None - self.state.current_node = current_node + # Check if handoff requested during execution + if self.state.handoff_node: + previous_node = current_node + current_node = self.state.handoff_node - handoff_event = MultiAgentHandoffEvent( - from_node_ids=[previous_node.node_id], - to_node_ids=[current_node.node_id], - message=self.state.handoff_message or "Agent handoff occurred", - ) - yield handoff_event - logger.debug( - "from_node=<%s>, to_node=<%s> | handoff detected", - previous_node.node_id, - current_node.node_id, - ) + self.state.handoff_node = None + self.state.current_node = current_node - else: - logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) - self.state.completion_status = Status.COMPLETED - break + handoff_event = MultiAgentHandoffEvent( + from_node_ids=[previous_node.node_id], + to_node_ids=[current_node.node_id], + message=self.state.handoff_message or "Agent handoff occurred", + ) + yield handoff_event + logger.debug( + "from_node=<%s>, to_node=<%s> | handoff detected", + previous_node.node_id, + current_node.node_id, + ) - except Exception: - logger.exception("node=<%s> | node execution failed", current_node.node_id) - 
self.state.completion_status = Status.FAILED + else: + logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) + self.state.completion_status = Status.COMPLETED break except Exception: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index afce36f2b..558d3e298 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -524,3 +524,22 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None: "event": agent_event, # Nest agent event to avoid field conflicts } ) + + +class MultiAgentNodeCancelEvent(TypedEvent): + """Event emitted when a user cancels node execution from their BeforeNodeCallEvent hook.""" + + def __init__(self, node_id: str, message: str) -> None: + """Initialize with cancel message. + + Args: + node_id: Unique identifier for the node. + message: The node cancellation message. + """ + super().__init__( + { + "type": "multiagent_node_cancel", + "node_id": node_id, + "message": message, + } + ) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index b32356cb4..4875d1bec 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -6,12 +6,14 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import AgentInitializedEvent from strands.hooks.registry import HookProvider, HookRegistry from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status from strands.session.file_session_manager import FileSessionManager from strands.session.session_manager import SessionManager +from strands.types._events import MultiAgentNodeCancelEvent def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None): @@ 
-2033,3 +2035,36 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): assert final_state["status"] == "completed" assert len(final_state["completed_nodes"]) == 1 assert "test_node" in final_state["node_results"] + + +@pytest.mark.parametrize( + ("cancel_node", "cancel_message"), + [(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")], +) +@pytest.mark.asyncio +async def test_graph_cancel_node(cancel_node, cancel_message): + def cancel_callback(event): + event.cancel_node = cancel_node + return event + + agent = create_mock_agent("test_agent", "Should not execute") + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + builder.set_entry_point("test_agent") + graph = builder.build() + graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback) + + stream = graph.stream_async("test task") + + tru_cancel_event = None + with pytest.raises(RuntimeError, match=cancel_message): + async for event in stream: + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event + + exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message) + assert tru_cancel_event == exp_cancel_event + + tru_status = graph.state.status + exp_status = Status.FAILED + assert tru_status == exp_status diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 008b2954d..66850fa6f 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1,11 +1,12 @@ import asyncio import time -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import ANY, MagicMock, Mock, patch import pytest from strands.agent import Agent, AgentResult from strands.agent.state import AgentState +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks.registry import HookRegistry from strands.multiagent.base import Status from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, 
SwarmResult, SwarmState @@ -1176,3 +1177,38 @@ async def handoff_stream(*args, **kwargs): tru_node_order = [node.node_id for node in result.node_history] exp_node_order = ["first", "second"] assert tru_node_order == exp_node_order + + +@pytest.mark.parametrize( + ("cancel_node", "cancel_message"), + [(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")], +) +@pytest.mark.asyncio +async def test_swarm_cancel_node(cancel_node, cancel_message, alist): + def cancel_callback(event): + event.cancel_node = cancel_node + return event + + agent = create_mock_agent("test_agent", "Should not execute") + swarm = Swarm([agent]) + swarm.hooks.add_callback(BeforeNodeCallEvent, cancel_callback) + + stream = swarm.stream_async("test task") + + tru_events = await alist(stream) + exp_events = [ + { + "message": cancel_message, + "node_id": "test_agent", + "type": "multiagent_node_cancel", + }, + { + "result": ANY, + "type": "multiagent_result", + }, + ] + assert tru_events == exp_events + + tru_status = swarm.state.completion_status + exp_status = Status.FAILED + assert tru_status == exp_status diff --git a/tests_integ/hooks/multiagent/test_cancel.py b/tests_integ/hooks/multiagent/test_cancel.py new file mode 100644 index 000000000..9267330b7 --- /dev/null +++ b/tests_integ/hooks/multiagent/test_cancel.py @@ -0,0 +1,88 @@ +import pytest + +from strands import Agent +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider +from strands.multiagent import GraphBuilder, Swarm +from strands.multiagent.base import Status +from strands.types._events import MultiAgentNodeCancelEvent + + +@pytest.fixture +def cancel_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.cancel) + + def cancel(self, event): + if event.node_id == "weather": + event.cancel_node = "test cancel" + + return Hook() + + +@pytest.fixture +def info_agent(): + return 
Agent(name="info") + + +@pytest.fixture +def weather_agent(): + return Agent(name="weather") + + +@pytest.fixture +def swarm(cancel_hook, info_agent, weather_agent): + return Swarm([info_agent, weather_agent], hooks=[cancel_hook]) + + +@pytest.fixture +def graph(cancel_hook, info_agent, weather_agent): + builder = GraphBuilder() + builder.add_node(info_agent, "info") + builder.add_node(weather_agent, "weather") + builder.add_edge("info", "weather") + builder.set_entry_point("info") + builder.set_hook_providers([cancel_hook]) + + return builder.build() + + +@pytest.mark.asyncio +async def test_swarm_cancel_node(swarm): + tru_cancel_event = None + async for event in swarm.stream_async("What is the weather"): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event + + multiagent_result = event["result"] + + exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel") + assert tru_cancel_event == exp_cancel_event + + tru_status = multiagent_result.status + exp_status = Status.FAILED + assert tru_status == exp_status + + assert len(multiagent_result.node_history) == 1 + tru_node_id = multiagent_result.node_history[0].node_id + exp_node_id = "info" + assert tru_node_id == exp_node_id + + +@pytest.mark.asyncio +async def test_graph_cancel_node(graph): + tru_cancel_event = None + with pytest.raises(RuntimeError, match="test cancel"): + async for event in graph.stream_async("What is the weather"): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event + + exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel") + assert tru_cancel_event == exp_cancel_event + + state = graph.state + + tru_status = state.status + exp_status = Status.FAILED + assert tru_status == exp_status