diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index cb5b36839..3913cd837 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -156,6 +156,7 @@ class SwarmState: # Total metrics across all agents accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_time: int = 0 # Total execution time in milliseconds + handoff_node: SwarmNode | None = None # The agent to execute next handoff_message: str | None = None # Message passed during agent handoff def should_continue( @@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No # Execute handoff swarm_ref._handle_handoff(target_node, message, context) - return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]} + return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]} except Exception as e: return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]} @@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st ) return - # Update swarm state - previous_agent = cast(SwarmNode, self.state.current_node) - self.state.current_node = target_node + current_node = cast(SwarmNode, self.state.current_node) - # Store handoff message for the target agent + self.state.handoff_node = target_node self.state.handoff_message = message # Store handoff context as shared context if context: for key, value in context.items(): - self.shared_context.add_context(previous_agent, key, value) + self.shared_context.add_context(current_node, key, value) logger.debug( - "from_node=<%s>, to_node=<%s> | handed off from agent to agent", - previous_agent.node_id, + "from_node=<%s>, to_node=<%s> | handing off from agent to agent", + current_node.node_id, target_node.node_id, ) @@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato logger.debug("reason=<%s> | stopping execution", reason) break - # Get current node current_node = self.state.current_node if not current_node or current_node.node_id not in self.nodes: logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None") @@ -680,13 +678,8 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato len(self.state.node_history) + 1, ) - # Store the current node before execution to detect handoffs - previous_node = current_node - - # Execute node with timeout protection # TODO: Implement cancellation token to stop _execute_node from continuing try: - # Execute with timeout wrapper for async generator streaming await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, current_node.node_id, invocation_state) ) @@ -699,30 +692,33 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato yield event self.state.node_history.append(current_node) - - # After self.state add current node, swarm state finish updating, we persist here await self.hooks.invoke_callbacks_async( AfterNodeCallEvent(self, current_node.node_id, invocation_state) ) logger.debug("node=<%s> | node execution completed", current_node.node_id) - # Check if handoff occurred during execution - if self.state.current_node is not None and self.state.current_node != previous_node: - # Emit handoff event (single node transition in Swarm) + # Check if handoff requested during execution + if self.state.handoff_node: + previous_node = current_node + current_node = self.state.handoff_node + + self.state.handoff_node = None + self.state.current_node = current_node + handoff_event = MultiAgentHandoffEvent( from_node_ids=[previous_node.node_id], - to_node_ids=[self.state.current_node.node_id], + to_node_ids=[current_node.node_id], message=self.state.handoff_message or "Agent handoff occurred", ) yield handoff_event logger.debug( "from_node=<%s>, to_node=<%s> | handoff detected", previous_node.node_id, - self.state.current_node.node_id, + current_node.node_id, ) + else: - # No handoff occurred, mark swarm as complete logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) self.state.completion_status = Status.COMPLETED break @@ -866,11 +862,12 @@ def _build_result(self) -> SwarmResult: def serialize_state(self) -> dict[str, Any]: """Serialize the current swarm state to a dictionary.""" status_str = self.state.completion_status.value - next_nodes = ( - [self.state.current_node.node_id] - if self.state.completion_status == Status.EXECUTING and self.state.current_node - else [] - ) + if self.state.handoff_node: + next_nodes = [self.state.handoff_node.node_id] + elif self.state.completion_status == Status.EXECUTING and self.state.current_node: + next_nodes = [self.state.current_node.node_id] + else: + next_nodes = [] return { "type": "swarm", diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index e8a6a5f79..008b2954d 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1149,3 +1149,30 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): assert final_state["status"] == "completed" assert len(final_state["node_history"]) == 1 assert "test_agent" in final_state["node_results"] + + +@pytest.mark.asyncio +async def test_swarm_handle_handoff(): + first_agent = create_mock_agent("first") + second_agent = create_mock_agent("second") + + swarm = Swarm([first_agent, second_agent]) + + async def handoff_stream(*args, **kwargs): + yield {"agent_start": True} + + swarm._handle_handoff(swarm.nodes["second"], "test message", {}) + + assert swarm.state.current_node.node_id == "first" + assert swarm.state.handoff_node.node_id == "second" + + yield {"result": first_agent.return_value} + + first_agent.stream_async = Mock(side_effect=handoff_stream) + + result = await swarm.invoke_async("test") + assert result.status == Status.COMPLETED + + tru_node_order = [node.node_id for node in result.node_history] + exp_node_order = ["first", "second"] + assert tru_node_order == exp_node_order