From 6e39ad0012122ddc4745125023d4fdd1327aef2a Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 13:37:45 -0500 Subject: [PATCH 1/7] feat(multiagent): Add structured output support to Multi-Agent Allow for structured output in Graph workflows by allowing for structured output on the graph level and on the agent level as well. This allows for devs to not only get structured output results but also unlocks features such as enabling routing to the correct agent based structured output result of another agent - conditional routing --- src/strands/multiagent/base.py | 32 +++++-- src/strands/multiagent/graph.py | 68 ++++++++++---- src/strands/multiagent/swarm.py | 46 +++++++--- tests/strands/multiagent/test_base.py | 9 +- tests/strands/multiagent/test_graph.py | 4 + tests/strands/multiagent/test_swarm.py | 2 +- tests_integ/test_multiagent_graph.py | 121 ++++++++++++++++++++++++- 7 files changed, 245 insertions(+), 37 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f163d05b5..8b3f00712 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -8,7 +8,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Any, AsyncIterator, Mapping, Union +from typing import Any, AsyncIterator, Mapping, Type, Union + +from pydantic import BaseModel from .._async import run_async from ..agent import AgentResult @@ -188,7 +190,11 @@ class MultiAgentBase(ABC): @abstractmethod async def invoke_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> MultiAgentResult: """Invoke asynchronously. @@ -196,12 +202,18 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. + Individual nodes may override this with their own default model. **kwargs: Additional keyword arguments passed to underlying agents. """ raise NotImplementedError("invoke_async not implemented") async def stream_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during multi-agent execution. @@ -212,6 +224,8 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. + Individual nodes may override this with their own default model. **kwargs: Additional keyword arguments passed to underlying agents. Yields: @@ -222,11 +236,15 @@ async def stream_async( """ # Default implementation for backward compatibility # Execute invoke_async and yield the result as a single event - result = await self.invoke_async(task, invocation_state, **kwargs) + result = await self.invoke_async(task, invocation_state, structured_output_model, **kwargs) yield {"result": result} def __call__( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> MultiAgentResult: """Invoke synchronously. @@ -234,6 +252,8 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. + Individual nodes may override this with their own default model. **kwargs: Additional keyword arguments passed to underlying agents. """ if invocation_state is None: @@ -243,7 +263,7 @@ def __call__( invocation_state.update(kwargs) warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - return run_async(lambda: self.invoke_async(task, invocation_state)) + return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model)) def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..0de352de0 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -19,9 +19,10 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, Type, cast from opentelemetry import trace as trace_api +from pydantic import BaseModel from .._async import run_async from ..agent import Agent @@ -462,7 +463,11 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> GraphResult: """Invoke the graph synchronously. @@ -470,15 +475,20 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} - return run_async(lambda: self.invoke_async(task, invocation_state)) + return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model)) async def invoke_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> GraphResult: """Invoke the graph asynchronously. @@ -489,9 +499,10 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. """ - events = self.stream_async(task, invocation_state, **kwargs) + events = self.stream_async(task, invocation_state, structured_output_model, **kwargs) final_event = None async for event in events: final_event = event @@ -502,7 +513,11 @@ async def invoke_async( return cast(GraphResult, final_event["result"]) async def stream_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during graph execution. @@ -510,6 +525,7 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. Yields: @@ -552,7 +568,7 @@ async def stream_async( self.node_timeout or "None", ) - async for event in self._execute_graph(invocation_state): + async for event in self._execute_graph(invocation_state, structured_output_model): yield event.as_dict() # Set final status based on execution results @@ -591,7 +607,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_graph( + self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None + ) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) @@ -610,7 +628,7 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ready_nodes.clear() # Execute current batch - async for event in self._execute_nodes_parallel(current_batch, invocation_state): + async for event in self._execute_nodes_parallel(current_batch, invocation_state, structured_output_model): yield event # Find newly ready nodes after batch execution @@ -634,7 +652,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ready_nodes.extend(newly_ready) async def _execute_nodes_parallel( - self, nodes: list["GraphNode"], invocation_state: dict[str, Any] + self, + nodes: list["GraphNode"], + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> AsyncIterator[Any]: """Execute multiple nodes in parallel and merge their event streams in real-time. @@ -644,7 +665,12 @@ async def _execute_nodes_parallel( event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks - tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes] + tasks = [ + asyncio.create_task( + self._stream_node_to_queue(node, event_queue, invocation_state, structured_output_model) + ) + for node in nodes + ] try: # Consume events from the queue as they arrive @@ -695,6 +721,7 @@ async def _stream_node_to_queue( node: GraphNode, event_queue: asyncio.Queue[Any | None | Exception], invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> None: """Stream events from a node to the shared queue with optional timeout.""" try: @@ -702,7 +729,7 @@ async def _stream_node_to_queue( if self.node_timeout is not None: async def stream_node() -> None: - async for event in self._execute_node(node, invocation_state): + async for event in self._execute_node(node, invocation_state, structured_output_model): await event_queue.put(event) try: @@ -713,7 +740,7 @@ async def stream_node() -> None: await event_queue.put(timeout_exc) else: # No timeout - stream normally - async for event in self._execute_node(node, invocation_state): + async for event in self._execute_node(node, invocation_state, structured_output_model): await event_queue.put(event) except Exception as e: # Send exception through queue for fail-fast behavior @@ -780,7 +807,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_node( + self, + node: GraphNode, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, + ) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -818,7 +850,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if isinstance(node.executor, MultiAgentBase): # For nested multi-agent systems, stream their events and collect result multi_agent_result = None - async for event in node.executor.stream_async(node_input, invocation_state): + async for event in node.executor.stream_async(node_input, invocation_state, structured_output_model): # Forward nested multi-agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event @@ -842,7 +874,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) elif isinstance(node.executor, Agent): # For agents, stream their events and collect result agent_response = None - async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + # Use agent's own model if it has one, otherwise use graph-level model + effective_output_model = node.executor._default_structured_output_model or structured_output_model + async for event in node.executor.stream_async( + node_input, invocation_state=invocation_state, structured_output_model=effective_output_model + ): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7eec49649..b8feb1e33 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,9 +19,10 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, Type, cast from opentelemetry import trace as trace_api +from pydantic import BaseModel from .._async import run_async from ..agent import Agent @@ -300,7 +301,11 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> SwarmResult: """Invoke the swarm synchronously. @@ -308,14 +313,19 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. """ if invocation_state is None: invocation_state = {} - return run_async(lambda: self.invoke_async(task, invocation_state)) + return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model)) async def invoke_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> SwarmResult: """Invoke the swarm asynchronously. @@ -326,9 +336,10 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. """ - events = self.stream_async(task, invocation_state, **kwargs) + events = self.stream_async(task, invocation_state, structured_output_model, **kwargs) final_event = None async for event in events: final_event = event @@ -339,7 +350,11 @@ async def invoke_async( return cast(SwarmResult, final_event["result"]) async def stream_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during swarm execution. @@ -347,6 +362,7 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. Yields: @@ -394,7 +410,7 @@ async def stream_async( self.execution_timeout, ) - async for event in self._execute_swarm(invocation_state): + async for event in self._execute_swarm(invocation_state, structured_output_model): if isinstance(event, MultiAgentNodeInterruptEvent): interrupts = event.interrupts @@ -702,7 +718,9 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M return MultiAgentNodeInterruptEvent(node.node_id, interrupts) - async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_swarm( + self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None + ) -> AsyncIterator[Any]: """Execute swarm and yield TypedEvent objects.""" try: # Main execution loop @@ -758,7 +776,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato break node_stream = self._stream_with_timeout( - self._execute_node(current_node, self.state.task, invocation_state), + self._execute_node(current_node, self.state.task, invocation_state, structured_output_model), self.node_timeout, f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", ) @@ -825,7 +843,11 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato ) async def _execute_node( - self, node: SwarmNode, task: MultiAgentInput, invocation_state: dict[str, Any] + self, + node: SwarmNode, + task: MultiAgentInput, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> AsyncIterator[Any]: """Execute swarm node and yield TypedEvent objects.""" start_time = time.time() @@ -856,7 +878,9 @@ async def _execute_node( # Stream agent events with node context and capture final result result = None - async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + async for event in node.executor.stream_async( + node_input, invocation_state=invocation_state, structured_output_model=structured_output_model + ): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node_name, event) yield wrapped_event diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 4e8a5dd06..5ae3114a0 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -142,7 +142,9 @@ class IncompleteMultiAgent(MultiAgentBase): # Test that complete implementations can be instantiated class CompleteMultiAgent(MultiAgentBase): - async def invoke_async(self, task: str) -> MultiAgentResult: + async def invoke_async( + self, task: str, invocation_state=None, structured_output_model=None + ) -> MultiAgentResult: return MultiAgentResult(results={}) def serialize_state(self) -> dict: @@ -164,12 +166,14 @@ def __init__(self): self.invoke_async_called = False self.received_task = None self.received_kwargs = None + self.received_structured_output_model = None - async def invoke_async(self, task, invocation_state, **kwargs): + async def invoke_async(self, task, invocation_state, structured_output_model=None, **kwargs): self.invoke_async_called = True self.received_task = task self.received_kwargs = kwargs self.received_invocation_state = invocation_state + self.received_structured_output_model = structured_output_model return MultiAgentResult( status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) @@ -188,6 +192,7 @@ def deserialize_state(self, payload: dict) -> None: assert agent.invoke_async_called assert agent.received_task == "test task" assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"} + assert agent.received_structured_output_model is None assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 4875d1bec..1f97d9607 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -23,6 +23,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() + agent._default_structured_output_model = None # For structured output support if metrics is None: metrics = Mock( @@ -289,6 +290,7 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) # Add required attributes for validation failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent._default_structured_output_model = None # For structured output support async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") @@ -1530,6 +1532,7 @@ async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span) failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent._default_structured_output_model = None # For structured output support async def failing_stream(*args, **kwargs): yield {"agent_start": True} @@ -1703,6 +1706,7 @@ async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span): failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent._default_structured_output_model = None # For structured output support async def mock_invoke_failure(*args, **kwargs): await asyncio.sleep(0.05) # Small delay diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f2abed9f7..4c6812b87 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1346,4 +1346,4 @@ def test_swarm_interrupt_on_agent(agenerator): exp_status = Status.COMPLETED assert tru_status == exp_status - agent.stream_async.assert_called_once_with(responses, invocation_state={}) + agent.stream_async.assert_called_once_with(responses, invocation_state={}, structured_output_model=None) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 08343a554..efd81b9f0 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,8 +1,9 @@ -from typing import Any, AsyncIterator +from typing import Any, AsyncIterator, Literal from unittest.mock import patch from uuid import uuid4 import pytest +from pydantic import BaseModel from strands import Agent, tool from strands.hooks import ( @@ -533,6 +534,124 @@ async def failing_stream_async(*args, **kwargs): session_manager.delete_session(session_id) +@pytest.mark.asyncio +async def test_graph_structured_output_conditional_routing(): + """Test conditional edge traversal based on structured output from a node.""" + + # Classifier's structured output for routing decisions + class ClassificationResult(BaseModel): + category: Literal["technical", "billing", "general"] + confidence: float + reasoning: str + + # Specialist agents have their own structured output models + class TechnicalResponse(BaseModel): + issue_type: str + resolution: str + + class BillingResponse(BaseModel): + issue_type: str + refund_eligible: bool + + class GeneralResponse(BaseModel): + answer: str + + # Classifier agent uses ClassificationResult for routing + classifier_agent = Agent( + name="classifier", + model="us.amazon.nova-pro-v1:0", + system_prompt=( + "You are a customer support classifier. Classify queries into:\n" + "- technical: software issues, bugs, crashes\n" + "- billing: payments, charges, refunds\n" + "- general: hours, contact info, policies" + ), + structured_output_model=ClassificationResult, + ) + + # Specialist agents - each has their OWN structured output model + technical_agent = Agent( + name="technical_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are technical support. Diagnose the issue and provide resolution steps.", + structured_output_model=TechnicalResponse, + ) + + billing_agent = Agent( + name="billing_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are billing support. Analyze the billing issue and determine refund eligibility.", + structured_output_model=BillingResponse, + ) + + general_agent = Agent( + name="general_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are general support. Answer general inquiries helpfully.", + structured_output_model=GeneralResponse, + ) + + # Conditional edge functions that read classifier's structured output + def route_to_technical(state): + classifier_result = state.results.get("classifier") + if classifier_result and classifier_result.result: + structured = classifier_result.result.structured_output + return structured and structured.category == "technical" + return False + + def route_to_billing(state): + classifier_result = state.results.get("classifier") + if classifier_result and classifier_result.result: + structured = classifier_result.result.structured_output + return structured and structured.category == "billing" + return False + + def route_to_general(state): + classifier_result = state.results.get("classifier") + if classifier_result and classifier_result.result: + structured = classifier_result.result.structured_output + return structured and structured.category == "general" + return False + + # Build the graph with conditional routing + builder = GraphBuilder() + builder.add_node(classifier_agent, "classifier") + builder.add_node(technical_agent, "technical") + builder.add_node(billing_agent, "billing") + builder.add_node(general_agent, "general") + + # Conditional edges from classifier to specialists + builder.add_edge("classifier", "technical", condition=route_to_technical) + builder.add_edge("classifier", "billing", condition=route_to_billing) + builder.add_edge("classifier", "general", condition=route_to_general) + + builder.set_entry_point("classifier") + graph = builder.build() + + # Test: Technical query should route to technical_support + technical_query = "My application keeps crashing when I upload large files" + result = await graph.invoke_async(technical_query) + + assert result.status == Status.COMPLETED + assert "classifier" in result.results + + # Verify classifier's structured output was captured + classifier_result = result.results["classifier"] + assert classifier_result.result.structured_output is not None + assert isinstance(classifier_result.result.structured_output, ClassificationResult) + assert classifier_result.result.structured_output.category == "technical" + + # Verify only technical agent was executed (not billing or general) + assert "technical" in result.results + assert "billing" not in result.results + assert "general" not in result.results + + # Verify technical agent's own structured output was preserved + technical_result = result.results["technical"] + assert technical_result.result.structured_output is not None + assert isinstance(technical_result.result.structured_output, TechnicalResponse) + + @pytest.mark.asyncio async def test_self_loop_resume_from_persisted_state(tmp_path): """Test resuming self-loop from persisted state where next node is itself.""" From 18d2bf8d5c50e8ddaa34be573cd09554248ef070 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 14:43:02 -0500 Subject: [PATCH 2/7] use agent's own structured output model over swarm-level Prioritize individual agent's _default_structured_output_model when executing in a swarm, falling back to swarm-level model only if the agent doesn't have one. This allows agents with different structured output schemas to work correctly during handoffs. --- src/strands/multiagent/swarm.py | 4 +- tests/strands/multiagent/test_graph.py | 8 +-- tests/strands/multiagent/test_swarm.py | 1 + tests_integ/test_multiagent_swarm.py | 91 ++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 5 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index b8feb1e33..5eaf3a720 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -878,8 +878,10 @@ async def _execute_node( # Stream agent events with node context and capture final result result = None + # Use agent's own model if it has one, otherwise use swarm-level model + effective_output_model = node.executor._default_structured_output_model or structured_output_model async for event in node.executor.stream_async( - node_input, invocation_state=invocation_state, structured_output_model=structured_output_model + node_input, invocation_state=invocation_state, structured_output_model=effective_output_model ): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node_name, event) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 1f97d9607..749dd2131 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -23,7 +23,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() - agent._default_structured_output_model = None # For structured output support + agent._default_structured_output_model = None if metrics is None: metrics = Mock( @@ -290,7 +290,7 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) # Add required attributes for validation failing_agent._session_manager = None failing_agent.hooks = HookRegistry() - failing_agent._default_structured_output_model = None # For structured output support + failing_agent._default_structured_output_model = None async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") @@ -1532,7 +1532,7 @@ async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span) failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() - failing_agent._default_structured_output_model = None # For structured output support + failing_agent._default_structured_output_model = None async def failing_stream(*args, **kwargs): yield {"agent_start": True} @@ -1706,7 +1706,7 @@ async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span): failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() - failing_agent._default_structured_output_model = None # For structured output support + failing_agent._default_structured_output_model = None async def mock_invoke_failure(*args, **kwargs): await asyncio.sleep(0.05) # Small delay diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 4c6812b87..501261ad8 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -31,6 +31,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent._call_count = 0 agent._should_fail = should_fail agent._session_manager = None + agent._default_structured_output_model = None agent.hooks = HookRegistry() if metrics is None: diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e8e969af1..3a3e2e4bb 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,5 +1,7 @@ +from typing import Literal from uuid import uuid4 +from pydantic import BaseModel import pytest from strands import Agent, tool @@ -357,6 +359,95 @@ async def test_swarm_get_agent_results_flattening(): assert agent_results[0].message is not None +@pytest.mark.asyncio +async def test_swarm_structured_output_with_handoffs(): + """Test that swarm properly handles structured output from each agent during handoffs.""" + + + # Coordinator's analysis model + class CoordinatorAnalysis(BaseModel): + request_type: Literal["technical", "billing", "general"] + complexity: Literal["simple", "moderate", "complex"] + summary: str + + # Technical agent's response model + class TechnicalResponse(BaseModel): + issue_type: str + severity: Literal["low", "medium", "high", "critical"] + troubleshooting_steps: list[str] + resolution_summary: str + + # Billing agent's response model + class BillingResponse(BaseModel): + issue_type: str + refund_eligible: bool + resolution_action: str + + # Coordinator agent that routes to specialists + coordinator = Agent( + name="coordinator", + model="us.amazon.nova-pro-v1:0", + system_prompt=( + "You are a customer support coordinator. Analyze the request and:\n" + "- For technical issues (bugs, crashes, performance), hand off to 'technical_support'\n" + "- For billing issues (charges, refunds), hand off to 'billing_support'\n" + "- You MUST use handoff_to_agent for technical or billing issues." + ), + structured_output_model=CoordinatorAnalysis, + ) + + # Technical specialist + technical_agent = Agent( + name="technical_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are technical support. Diagnose the issue and provide resolution steps.", + structured_output_model=TechnicalResponse, + ) + + # Billing specialist + billing_agent = Agent( + name="billing_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are billing support. Analyze billing issues and determine resolution.", + structured_output_model=BillingResponse, + ) + + swarm = Swarm( + nodes=[coordinator, technical_agent, billing_agent], + entry_point=coordinator, + max_handoffs=5, + max_iterations=5, + ) + + # Test: Technical query should hand off to technical_support + technical_query = "My application keeps crashing when I upload files larger than 100MB" + result = await swarm.invoke_async(technical_query) + + # Verify completion + assert result.status.value in ["completed", "failed"] + + if result.status.value == "completed": + # Verify coordinator was executed + assert "coordinator" in result.results + coordinator_result = result.results["coordinator"] + assert coordinator_result.result.structured_output is not None + assert isinstance(coordinator_result.result.structured_output, CoordinatorAnalysis) + assert coordinator_result.result.structured_output.request_type == "technical" + + # Verify handoff occurred and technical agent was executed + assert len(result.node_history) >= 2, "Expected at least 2 agents in execution history (coordinator + specialist)" + node_ids = [n.node_id for n in result.node_history] + assert "coordinator" in node_ids + assert "technical_support" in node_ids + + # Verify technical agent's structured output was preserved + assert "technical_support" in result.results + technical_result = result.results["technical_support"] + assert technical_result.result.structured_output is not None + assert isinstance(technical_result.result.structured_output, TechnicalResponse) + assert len(technical_result.result.structured_output.troubleshooting_steps) > 0 + + def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook): """Test swarm resuming from EXECUTING state using BeforeNodeCallEvent hook.""" session_id = f"swarm_resume_{uuid4()}" From d4eee549cc8ab7e72160dc5178f9b6ad8a032236 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 15:15:49 -0500 Subject: [PATCH 3/7] refactor(multiagent): Remove redundant __call__ methods from Swarm and Graph Both Swarm and Graph classes had __call__ implementations that were functionally identical to the base class. Removing them consolidates the synchronous invocation logic in MultiAgentBase, improving maintainability and ensuring consistent kwargs deprecation handling. --- src/strands/multiagent/graph.py | 21 --------------------- src/strands/multiagent/swarm.py | 20 -------------------- tests_integ/test_multiagent_swarm.py | 7 ++++--- 3 files changed, 4 insertions(+), 44 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 0de352de0..e7a3d4aaa 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -462,27 +462,6 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) - def __call__( - self, - task: MultiAgentInput, - invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, - **kwargs: Any, - ) -> GraphResult: - """Invoke the graph synchronously. - - Args: - task: The task to execute - invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues. - structured_output_model: Pydantic model to use for structured output from nodes. - **kwargs: Keyword arguments allowing backward compatible future changes. - """ - if invocation_state is None: - invocation_state = {} - - return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model)) - async def invoke_async( self, task: MultiAgentInput, diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 5eaf3a720..5c9ebd63c 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -300,26 +300,6 @@ def __init__( self._inject_swarm_tools() run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) - def __call__( - self, - task: MultiAgentInput, - invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, - **kwargs: Any, - ) -> SwarmResult: - """Invoke the swarm synchronously. - - Args: - task: The task to execute - invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues. - structured_output_model: Pydantic model to use for structured output from nodes. - **kwargs: Keyword arguments allowing backward compatible future changes. - """ - if invocation_state is None: - invocation_state = {} - return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model)) - async def invoke_async( self, task: MultiAgentInput, diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 3a3e2e4bb..d8ed28bd0 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,8 +1,8 @@ from typing import Literal from uuid import uuid4 -from pydantic import BaseModel import pytest +from pydantic import BaseModel from strands import Agent, tool from strands.experimental.hooks.multiagent import BeforeNodeCallEvent @@ -363,7 +363,6 @@ async def test_swarm_get_agent_results_flattening(): async def test_swarm_structured_output_with_handoffs(): """Test that swarm properly handles structured output from each agent during handoffs.""" - # Coordinator's analysis model class CoordinatorAnalysis(BaseModel): request_type: Literal["technical", "billing", "general"] @@ -435,7 +434,9 @@ class BillingResponse(BaseModel): assert coordinator_result.result.structured_output.request_type == "technical" # Verify handoff occurred and technical agent was executed - assert len(result.node_history) >= 2, "Expected at least 2 agents in execution history (coordinator + specialist)" + assert len(result.node_history) >= 2, ( + "Expected at least 2 agents in execution history (coordinator + specialist)" + ) node_ids = [n.node_id for n in result.node_history] assert "coordinator" in node_ids assert "technical_support" in node_ids From 9edebca1d275a85e9437bcdbcca108fb99f3ab90 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 15:28:54 -0500 Subject: [PATCH 4/7] refactor(multiagent): Remove redundant __call__ methods from Swarm and Graph Both Swarm and Graph classes had __call__ implementations that were functionally identical to the base class. Removing them consolidates the synchronous invocation logic in MultiAgentBase, improving maintainability and ensuring consistent kwargs deprecation handling. --- tests_integ/test_multiagent_swarm.py | 34 ++++++++-------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index d8ed28bd0..c8016c39e 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -422,31 +422,15 @@ class BillingResponse(BaseModel): technical_query = "My application keeps crashing when I upload files larger than 100MB" result = await swarm.invoke_async(technical_query) - # Verify completion - assert result.status.value in ["completed", "failed"] - - if result.status.value == "completed": - # Verify coordinator was executed - assert "coordinator" in result.results - coordinator_result = result.results["coordinator"] - assert coordinator_result.result.structured_output is not None - assert isinstance(coordinator_result.result.structured_output, CoordinatorAnalysis) - assert coordinator_result.result.structured_output.request_type == "technical" - - # Verify handoff occurred and technical agent was executed - assert len(result.node_history) >= 2, ( - "Expected at least 2 agents in execution history (coordinator + specialist)" - ) - node_ids = [n.node_id for n in result.node_history] - assert "coordinator" in node_ids - assert "technical_support" in node_ids - - # Verify technical agent's structured output was preserved - assert "technical_support" in result.results - technical_result = result.results["technical_support"] - assert technical_result.result.structured_output is not None - assert isinstance(technical_result.result.structured_output, TechnicalResponse) - assert len(technical_result.result.structured_output.troubleshooting_steps) > 0 + # Verify coordinator was executed + coordinator_result = result.results["coordinator"] + assert coordinator_result.result.structured_output is not None + assert isinstance(coordinator_result.result.structured_output, CoordinatorAnalysis) + assert coordinator_result.result.structured_output.request_type == "technical" + + # Verify technical agent was invoked and structured output was created + technical_result = result.results["technical_support"] + assert isinstance(technical_result.result.structured_output, TechnicalResponse) def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook): From ae5c65485a7635c231de02dbf6403cd2de40ac7c Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 15:30:27 -0500 Subject: [PATCH 5/7] update tests --- tests_integ/test_multiagent_graph.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index efd81b9f0..a98362152 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -633,11 +633,9 @@ def route_to_general(state): result = await graph.invoke_async(technical_query) assert result.status == Status.COMPLETED - assert "classifier" in result.results # Verify classifier's structured output was captured classifier_result = result.results["classifier"] - assert classifier_result.result.structured_output is not None assert isinstance(classifier_result.result.structured_output, ClassificationResult) assert classifier_result.result.structured_output.category == "technical" @@ -646,9 +644,8 @@ def route_to_general(state): assert "billing" not in result.results assert "general" not in result.results - # Verify technical agent's own structured output was preserved + # Verify technical agent's structured output was generated technical_result = result.results["technical"] - assert technical_result.result.structured_output is not None assert isinstance(technical_result.result.structured_output, TechnicalResponse) From 1a3c5adc50e5be25823105c5d45fdc4fbd175e4f Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 15:47:24 -0500 Subject: [PATCH 6/7] update tests --- tests_integ/test_multiagent_graph.py | 8 ++++++-- tests_integ/test_multiagent_swarm.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index a98362152..34b71b8fa 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterator, Literal +from typing import Any, AsyncIterator, Literal, Type from unittest.mock import patch from uuid import uuid4 @@ -242,7 +242,11 @@ async def invoke_async( return MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result}) async def stream_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: str | list[ContentBlock], + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: yield {"custom_event": "start", "node": self.name} result = await self.agent.invoke_async(task, **kwargs) diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index c8016c39e..346e22c66 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -428,7 +428,7 @@ class BillingResponse(BaseModel): assert isinstance(coordinator_result.result.structured_output, CoordinatorAnalysis) assert coordinator_result.result.structured_output.request_type == "technical" - # Verify technical agent was invoked and structured output was created + # Verify technical agent was invoked and structured output was created technical_result = result.results["technical_support"] assert isinstance(technical_result.result.structured_output, TechnicalResponse) From 9758bef3ea8281eca9e4b2203d01c9bd555b3f70 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 8 Dec 2025 15:56:44 -0500 Subject: [PATCH 7/7] code coverage --- tests/strands/multiagent/test_base.py | 29 +++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 5ae3114a0..335a19064 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -244,3 +244,32 @@ def test_serialize_node_result_for_persist(agent_result): assert "result" in serialized_exception assert serialized_exception["result"]["type"] == "exception" assert serialized_exception["result"]["message"] == "Test error" + + +@pytest.mark.asyncio +async def test_multi_agent_base_stream_async_default_implementation(): + """Test the default stream_async implementation in MultiAgentBase.""" + + class TestMultiAgent(MultiAgentBase): + async def invoke_async(self, task, invocation_state=None, structured_output_model=None, **kwargs): + return MultiAgentResult( + status=Status.COMPLETED, + results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}, + ) + + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + + agent = TestMultiAgent() + + events = [] + async for event in agent.stream_async("test task"): + events.append(event) + + assert len(events) == 1 + assert "result" in events[0] + assert isinstance(events[0]["result"], MultiAgentResult) + assert events[0]["result"].status == Status.COMPLETED