diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f163d05b5..8b3f00712 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -8,7 +8,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Any, AsyncIterator, Mapping, Union +from typing import Any, AsyncIterator, Mapping, Type, Union + +from pydantic import BaseModel from .._async import run_async from ..agent import AgentResult @@ -188,7 +190,11 @@ class MultiAgentBase(ABC): @abstractmethod async def invoke_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> MultiAgentResult: """Invoke asynchronously. @@ -196,12 +202,18 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. + Individual nodes may override this with their own default model. **kwargs: Additional keyword arguments passed to underlying agents. """ raise NotImplementedError("invoke_async not implemented") async def stream_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during multi-agent execution. @@ -212,6 +224,8 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. + Individual nodes may override this with their own default model. **kwargs: Additional keyword arguments passed to underlying agents. Yields: @@ -222,11 +236,15 @@ async def stream_async( """ # Default implementation for backward compatibility # Execute invoke_async and yield the result as a single event - result = await self.invoke_async(task, invocation_state, **kwargs) + result = await self.invoke_async(task, invocation_state, structured_output_model, **kwargs) yield {"result": result} def __call__( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> MultiAgentResult: """Invoke synchronously. @@ -234,6 +252,8 @@ def __call__( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. + Individual nodes may override this with their own default model. **kwargs: Additional keyword arguments passed to underlying agents. """ if invocation_state is None: @@ -243,7 +263,7 @@ def __call__( invocation_state.update(kwargs) warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - return run_async(lambda: self.invoke_async(task, invocation_state)) + return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model)) def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..e7a3d4aaa 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -19,9 +19,10 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, Type, cast from opentelemetry import trace as trace_api +from pydantic import BaseModel from .._async import run_async from ..agent import Agent @@ -461,24 +462,12 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) - def __call__( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any - ) -> GraphResult: - """Invoke the graph synchronously. - - Args: - task: The task to execute - invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues. - **kwargs: Keyword arguments allowing backward compatible future changes. - """ - if invocation_state is None: - invocation_state = {} - - return run_async(lambda: self.invoke_async(task, invocation_state)) - async def invoke_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> GraphResult: """Invoke the graph asynchronously. @@ -489,9 +478,10 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. """ - events = self.stream_async(task, invocation_state, **kwargs) + events = self.stream_async(task, invocation_state, structured_output_model, **kwargs) final_event = None async for event in events: final_event = event @@ -502,7 +492,11 @@ async def invoke_async( return cast(GraphResult, final_event["result"]) async def stream_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during graph execution. @@ -510,6 +504,7 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. Yields: @@ -552,7 +547,7 @@ async def stream_async( self.node_timeout or "None", ) - async for event in self._execute_graph(invocation_state): + async for event in self._execute_graph(invocation_state, structured_output_model): yield event.as_dict() # Set final status based on execution results @@ -591,7 +586,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_graph( + self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None + ) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) @@ -610,7 +607,7 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ready_nodes.clear() # Execute current batch - async for event in self._execute_nodes_parallel(current_batch, invocation_state): + async for event in self._execute_nodes_parallel(current_batch, invocation_state, structured_output_model): yield event # Find newly ready nodes after batch execution @@ -634,7 +631,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato ready_nodes.extend(newly_ready) async def _execute_nodes_parallel( - self, nodes: list["GraphNode"], invocation_state: dict[str, Any] + self, + nodes: list["GraphNode"], + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> AsyncIterator[Any]: """Execute multiple nodes in parallel and merge their event streams in real-time. @@ -644,7 +644,12 @@ async def _execute_nodes_parallel( event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks - tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes] + tasks = [ + asyncio.create_task( + self._stream_node_to_queue(node, event_queue, invocation_state, structured_output_model) + ) + for node in nodes + ] try: # Consume events from the queue as they arrive @@ -695,6 +700,7 @@ async def _stream_node_to_queue( node: GraphNode, event_queue: asyncio.Queue[Any | None | Exception], invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> None: """Stream events from a node to the shared queue with optional timeout.""" try: @@ -702,7 +708,7 @@ async def _stream_node_to_queue( if self.node_timeout is not None: async def stream_node() -> None: - async for event in self._execute_node(node, invocation_state): + async for event in self._execute_node(node, invocation_state, structured_output_model): await event_queue.put(event) try: @@ -713,7 +719,7 @@ async def stream_node() -> None: await event_queue.put(timeout_exc) else: # No timeout - stream normally - async for event in self._execute_node(node, invocation_state): + async for event in self._execute_node(node, invocation_state, structured_output_model): await event_queue.put(event) except Exception as e: # Send exception through queue for fail-fast behavior @@ -780,7 +786,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_node( + self, + node: GraphNode, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, + ) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -818,7 +829,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if isinstance(node.executor, MultiAgentBase): # For nested multi-agent systems, stream their events and collect result multi_agent_result = None - async for event in node.executor.stream_async(node_input, invocation_state): + async for event in node.executor.stream_async(node_input, invocation_state, structured_output_model): # Forward nested multi-agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event @@ -842,7 +853,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) elif isinstance(node.executor, Agent): # For agents, stream their events and collect result agent_response = None - async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + # Use agent's own model if it has one, otherwise use graph-level model + effective_output_model = node.executor._default_structured_output_model or structured_output_model + async for event in node.executor.stream_async( + node_input, invocation_state=invocation_state, structured_output_model=effective_output_model + ): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7eec49649..5c9ebd63c 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,9 +19,10 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, Type, cast from opentelemetry import trace as trace_api +from pydantic import BaseModel from .._async import run_async from ..agent import Agent @@ -299,23 +300,12 @@ def __init__( self._inject_swarm_tools() run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) - def __call__( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any - ) -> SwarmResult: - """Invoke the swarm synchronously. - - Args: - task: The task to execute - invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues. - **kwargs: Keyword arguments allowing backward compatible future changes. - """ - if invocation_state is None: - invocation_state = {} - return run_async(lambda: self.invoke_async(task, invocation_state)) - async def invoke_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> SwarmResult: """Invoke the swarm asynchronously. @@ -326,9 +316,10 @@ async def invoke_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. """ - events = self.stream_async(task, invocation_state, **kwargs) + events = self.stream_async(task, invocation_state, structured_output_model, **kwargs) final_event = None async for event in events: final_event = event @@ -339,7 +330,11 @@ async def invoke_async( return cast(SwarmResult, final_event["result"]) async def stream_async( - self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: MultiAgentInput, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: """Stream events during swarm execution. @@ -347,6 +342,7 @@ async def stream_async( task: The task to execute invocation_state: Additional state/context passed to underlying agents. Defaults to None to avoid mutable default argument issues. + structured_output_model: Pydantic model to use for structured output from nodes. **kwargs: Keyword arguments allowing backward compatible future changes. Yields: @@ -394,7 +390,7 @@ async def stream_async( self.execution_timeout, ) - async for event in self._execute_swarm(invocation_state): + async for event in self._execute_swarm(invocation_state, structured_output_model): if isinstance(event, MultiAgentNodeInterruptEvent): interrupts = event.interrupts @@ -702,7 +698,9 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M return MultiAgentNodeInterruptEvent(node.node_id, interrupts) - async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + async def _execute_swarm( + self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None + ) -> AsyncIterator[Any]: """Execute swarm and yield TypedEvent objects.""" try: # Main execution loop @@ -758,7 +756,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato break node_stream = self._stream_with_timeout( - self._execute_node(current_node, self.state.task, invocation_state), + self._execute_node(current_node, self.state.task, invocation_state, structured_output_model), self.node_timeout, f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", ) @@ -825,7 +823,11 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato ) async def _execute_node( - self, node: SwarmNode, task: MultiAgentInput, invocation_state: dict[str, Any] + self, + node: SwarmNode, + task: MultiAgentInput, + invocation_state: dict[str, Any], + structured_output_model: Type[BaseModel] | None = None, ) -> AsyncIterator[Any]: """Execute swarm node and yield TypedEvent objects.""" start_time = time.time() @@ -856,7 +858,11 @@ async def _execute_node( # Stream agent events with node context and capture final result result = None - async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + # Use agent's own model if it has one, otherwise use swarm-level model + effective_output_model = node.executor._default_structured_output_model or structured_output_model + async for event in node.executor.stream_async( + node_input, invocation_state=invocation_state, structured_output_model=effective_output_model + ): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node_name, event) yield wrapped_event diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 4e8a5dd06..335a19064 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -142,7 +142,9 @@ class IncompleteMultiAgent(MultiAgentBase): # Test that complete implementations can be instantiated class CompleteMultiAgent(MultiAgentBase): - async def invoke_async(self, task: str) -> MultiAgentResult: + async def invoke_async( + self, task: str, invocation_state=None, structured_output_model=None + ) -> MultiAgentResult: return MultiAgentResult(results={}) def serialize_state(self) -> dict: @@ -164,12 +166,14 @@ def __init__(self): self.invoke_async_called = False self.received_task = None self.received_kwargs = None + self.received_structured_output_model = None - async def invoke_async(self, task, invocation_state, **kwargs): + async def invoke_async(self, task, invocation_state, structured_output_model=None, **kwargs): self.invoke_async_called = True self.received_task = task self.received_kwargs = kwargs self.received_invocation_state = invocation_state + self.received_structured_output_model = structured_output_model return MultiAgentResult( status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) @@ -188,6 +192,7 @@ def deserialize_state(self, payload: dict) -> None: assert agent.invoke_async_called assert agent.received_task == "test task" assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"} + assert agent.received_structured_output_model is None assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED @@ -239,3 +244,32 @@ def test_serialize_node_result_for_persist(agent_result): assert "result" in serialized_exception assert serialized_exception["result"]["type"] == "exception" assert serialized_exception["result"]["message"] == "Test error" + + +@pytest.mark.asyncio +async def test_multi_agent_base_stream_async_default_implementation(): + """Test the default stream_async implementation in MultiAgentBase.""" + + class TestMultiAgent(MultiAgentBase): + async def invoke_async(self, task, invocation_state=None, structured_output_model=None, **kwargs): + return MultiAgentResult( + status=Status.COMPLETED, + results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}, + ) + + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + + agent = TestMultiAgent() + + events = [] + async for event in agent.stream_async("test task"): + events.append(event) + + assert len(events) == 1 + assert "result" in events[0] + assert isinstance(events[0]["result"], MultiAgentResult) + assert events[0]["result"].status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 4875d1bec..749dd2131 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -23,6 +23,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() + agent._default_structured_output_model = None if metrics is None: metrics = Mock( @@ -289,6 +290,7 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) # Add required attributes for validation failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent._default_structured_output_model = None async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") @@ -1530,6 +1532,7 @@ async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span) failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent._default_structured_output_model = None async def failing_stream(*args, **kwargs): yield {"agent_start": True} @@ -1703,6 +1706,7 @@ async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span): failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent._default_structured_output_model = None async def mock_invoke_failure(*args, **kwargs): await asyncio.sleep(0.05) # Small delay diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f2abed9f7..501261ad8 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -31,6 +31,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent._call_count = 0 agent._should_fail = should_fail agent._session_manager = None + agent._default_structured_output_model = None agent.hooks = HookRegistry() if metrics is None: @@ -1346,4 +1347,4 @@ def test_swarm_interrupt_on_agent(agenerator): exp_status = Status.COMPLETED assert tru_status == exp_status - agent.stream_async.assert_called_once_with(responses, invocation_state={}) + agent.stream_async.assert_called_once_with(responses, invocation_state={}, structured_output_model=None) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 08343a554..34b71b8fa 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,8 +1,9 @@ -from typing import Any, AsyncIterator +from typing import Any, AsyncIterator, Literal, Type from unittest.mock import patch from uuid import uuid4 import pytest +from pydantic import BaseModel from strands import Agent, tool from strands.hooks import ( @@ -241,7 +242,11 @@ async def invoke_async( return MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result}) async def stream_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, + task: str | list[ContentBlock], + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: yield {"custom_event": "start", "node": self.name} result = await self.agent.invoke_async(task, **kwargs) @@ -533,6 +538,121 @@ async def failing_stream_async(*args, **kwargs): session_manager.delete_session(session_id) +@pytest.mark.asyncio +async def test_graph_structured_output_conditional_routing(): + """Test conditional edge traversal based on structured output from a node.""" + + # Classifier's structured output for routing decisions + class ClassificationResult(BaseModel): + category: Literal["technical", "billing", "general"] + confidence: float + reasoning: str + + # Specialist agents have their own structured output models + class TechnicalResponse(BaseModel): + issue_type: str + resolution: str + + class BillingResponse(BaseModel): + issue_type: str + refund_eligible: bool + + class GeneralResponse(BaseModel): + answer: str + + # Classifier agent uses ClassificationResult for routing + classifier_agent = Agent( + name="classifier", + model="us.amazon.nova-pro-v1:0", + system_prompt=( + "You are a customer support classifier. Classify queries into:\n" + "- technical: software issues, bugs, crashes\n" + "- billing: payments, charges, refunds\n" + "- general: hours, contact info, policies" + ), + structured_output_model=ClassificationResult, + ) + + # Specialist agents - each has their OWN structured output model + technical_agent = Agent( + name="technical_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are technical support. Diagnose the issue and provide resolution steps.", + structured_output_model=TechnicalResponse, + ) + + billing_agent = Agent( + name="billing_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are billing support. Analyze the billing issue and determine refund eligibility.", + structured_output_model=BillingResponse, + ) + + general_agent = Agent( + name="general_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are general support. Answer general inquiries helpfully.", + structured_output_model=GeneralResponse, + ) + + # Conditional edge functions that read classifier's structured output + def route_to_technical(state): + classifier_result = state.results.get("classifier") + if classifier_result and classifier_result.result: + structured = classifier_result.result.structured_output + return structured and structured.category == "technical" + return False + + def route_to_billing(state): + classifier_result = state.results.get("classifier") + if classifier_result and classifier_result.result: + structured = classifier_result.result.structured_output + return structured and structured.category == "billing" + return False + + def route_to_general(state): + classifier_result = state.results.get("classifier") + if classifier_result and classifier_result.result: + structured = classifier_result.result.structured_output + return structured and structured.category == "general" + return False + + # Build the graph with conditional routing + builder = GraphBuilder() + builder.add_node(classifier_agent, "classifier") + builder.add_node(technical_agent, "technical") + builder.add_node(billing_agent, "billing") + builder.add_node(general_agent, "general") + + # Conditional edges from classifier to specialists + builder.add_edge("classifier", "technical", condition=route_to_technical) + builder.add_edge("classifier", "billing", condition=route_to_billing) + builder.add_edge("classifier", "general", condition=route_to_general) + + builder.set_entry_point("classifier") + graph = builder.build() + + # Test: Technical query should route to technical_support + technical_query = "My application keeps crashing when I upload large files" + result = await graph.invoke_async(technical_query) + + assert result.status == Status.COMPLETED + + # Verify classifier's structured output was captured + classifier_result = result.results["classifier"] + assert isinstance(classifier_result.result.structured_output, ClassificationResult) + assert classifier_result.result.structured_output.category == "technical" + + # Verify only technical agent was executed (not billing or general) + assert "technical" in result.results + assert "billing" not in result.results + assert "general" not in result.results + + # Verify technical agent's structured output was generated + technical_result = result.results["technical"] + assert isinstance(technical_result.result.structured_output, TechnicalResponse) + + @pytest.mark.asyncio async def test_self_loop_resume_from_persisted_state(tmp_path): """Test resuming self-loop from persisted state where next node is itself.""" diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e8e969af1..346e22c66 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,6 +1,8 @@ +from typing import Literal from uuid import uuid4 import pytest +from pydantic import BaseModel from strands import Agent, tool from strands.experimental.hooks.multiagent import BeforeNodeCallEvent @@ -357,6 +359,80 @@ async def test_swarm_get_agent_results_flattening(): assert agent_results[0].message is not None +@pytest.mark.asyncio +async def test_swarm_structured_output_with_handoffs(): + """Test that swarm properly handles structured output from each agent during handoffs.""" + + # Coordinator's analysis model + class CoordinatorAnalysis(BaseModel): + request_type: Literal["technical", "billing", "general"] + complexity: Literal["simple", "moderate", "complex"] + summary: str + + # Technical agent's response model + class TechnicalResponse(BaseModel): + issue_type: str + severity: Literal["low", "medium", "high", "critical"] + troubleshooting_steps: list[str] + resolution_summary: str + + # Billing agent's response model + class BillingResponse(BaseModel): + issue_type: str + refund_eligible: bool + resolution_action: str + + # Coordinator agent that routes to specialists + coordinator = Agent( + name="coordinator", + model="us.amazon.nova-pro-v1:0", + system_prompt=( + "You are a customer support coordinator. Analyze the request and:\n" + "- For technical issues (bugs, crashes, performance), hand off to 'technical_support'\n" + "- For billing issues (charges, refunds), hand off to 'billing_support'\n" + "- You MUST use handoff_to_agent for technical or billing issues." + ), + structured_output_model=CoordinatorAnalysis, + ) + + # Technical specialist + technical_agent = Agent( + name="technical_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are technical support. Diagnose the issue and provide resolution steps.", + structured_output_model=TechnicalResponse, + ) + + # Billing specialist + billing_agent = Agent( + name="billing_support", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are billing support. Analyze billing issues and determine resolution.", + structured_output_model=BillingResponse, + ) + + swarm = Swarm( + nodes=[coordinator, technical_agent, billing_agent], + entry_point=coordinator, + max_handoffs=5, + max_iterations=5, + ) + + # Test: Technical query should hand off to technical_support + technical_query = "My application keeps crashing when I upload files larger than 100MB" + result = await swarm.invoke_async(technical_query) + + # Verify coordinator was executed + coordinator_result = result.results["coordinator"] + assert coordinator_result.result.structured_output is not None + assert isinstance(coordinator_result.result.structured_output, CoordinatorAnalysis) + assert coordinator_result.result.structured_output.request_type == "technical" + + # Verify technical agent was invoked and structured output was created + technical_result = result.results["technical_support"] + assert isinstance(technical_result.result.structured_output, TechnicalResponse) + + def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook): """Test swarm resuming from EXECUTING state using BeforeNodeCallEvent hook.""" session_id = f"swarm_resume_{uuid4()}"