Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class SwarmState:
# Total metrics across all agents
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_time: int = 0 # Total execution time in milliseconds
handoff_node: SwarmNode | None = None # The agent to execute next
handoff_message: str | None = None # Message passed during agent handoff

def should_continue(
Expand Down Expand Up @@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No
# Execute handoff
swarm_ref._handle_handoff(target_node, message, context)

return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]}
return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]}
except Exception as e:
return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]}

Expand All @@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
)
return

# Update swarm state
previous_agent = cast(SwarmNode, self.state.current_node)
self.state.current_node = target_node
current_node = cast(SwarmNode, self.state.current_node)

# Store handoff message for the target agent
self.state.handoff_node = target_node
self.state.handoff_message = message

# Store handoff context as shared context
if context:
for key, value in context.items():
self.shared_context.add_context(previous_agent, key, value)
self.shared_context.add_context(current_node, key, value)

logger.debug(
"from_node=<%s>, to_node=<%s> | handed off from agent to agent",
previous_agent.node_id,
"from_node=<%s>, to_node=<%s> | handing off from agent to agent",
current_node.node_id,
target_node.node_id,
)

Expand Down Expand Up @@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
logger.debug("reason=<%s> | stopping execution", reason)
break

# Get current node
current_node = self.state.current_node
if not current_node or current_node.node_id not in self.nodes:
logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None")
Expand All @@ -680,13 +678,8 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
len(self.state.node_history) + 1,
)

# Store the current node before execution to detect handoffs
previous_node = current_node

# Execute node with timeout protection
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
# Execute with timeout wrapper for async generator streaming
await self.hooks.invoke_callbacks_async(
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
)
Expand All @@ -699,30 +692,33 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
yield event

self.state.node_history.append(current_node)

# After self.state add current node, swarm state finish updating, we persist here
await self.hooks.invoke_callbacks_async(
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

# Check if handoff occurred during execution
if self.state.current_node is not None and self.state.current_node != previous_node:
# Emit handoff event (single node transition in Swarm)
# Check if handoff requested during execution
if self.state.handoff_node:
previous_node = current_node
current_node = self.state.handoff_node

self.state.handoff_node = None
self.state.current_node = current_node

handoff_event = MultiAgentHandoffEvent(
from_node_ids=[previous_node.node_id],
to_node_ids=[self.state.current_node.node_id],
to_node_ids=[current_node.node_id],
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event
logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
self.state.current_node.node_id,
current_node.node_id,
)

else:
# No handoff occurred, mark swarm as complete
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break
Expand Down Expand Up @@ -866,11 +862,12 @@ def _build_result(self) -> SwarmResult:
def serialize_state(self) -> dict[str, Any]:
"""Serialize the current swarm state to a dictionary."""
status_str = self.state.completion_status.value
next_nodes = (
[self.state.current_node.node_id]
if self.state.completion_status == Status.EXECUTING and self.state.current_node
else []
)
if self.state.handoff_node:
next_nodes = [self.state.handoff_node.node_id]
elif self.state.completion_status == Status.EXECUTING and self.state.current_node:
next_nodes = [self.state.current_node.node_id]
else:
next_nodes = []

return {
"type": "swarm",
Expand Down
27 changes: 27 additions & 0 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,3 +1149,30 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span):
assert final_state["status"] == "completed"
assert len(final_state["node_history"]) == 1
assert "test_agent" in final_state["node_results"]


@pytest.mark.asyncio
async def test_swarm_handle_handoff():
first_agent = create_mock_agent("first")
second_agent = create_mock_agent("second")

swarm = Swarm([first_agent, second_agent])

async def handoff_stream(*args, **kwargs):
yield {"agent_start": True}

swarm._handle_handoff(swarm.nodes["second"], "test message", {})

assert swarm.state.current_node.node_id == "first"
assert swarm.state.handoff_node.node_id == "second"

yield {"result": first_agent.return_value}

first_agent.stream_async = Mock(side_effect=handoff_stream)

result = await swarm.invoke_async("test")
assert result.status == Status.COMPLETED

tru_node_order = [node.node_id for node in result.node_history]
exp_node_order = ["first", "second"]
assert tru_node_order == exp_node_order
Loading