From 68fd0c798e9755eab25fecf414a95819aaba5f70 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 18 Nov 2025 13:04:09 -0500 Subject: [PATCH] multi agent input --- src/strands/multiagent/base.py | 8 +++---- src/strands/multiagent/graph.py | 9 ++++---- src/strands/multiagent/swarm.py | 11 +++++----- src/strands/telemetry/tracer.py | 25 ++++++++++++++++------ src/strands/types/agent.py | 4 ++-- src/strands/types/multiagent.py | 7 +++++++ tests/strands/telemetry/test_tracer.py | 29 ++++++++++++++++++++++++++ 7 files changed, 72 insertions(+), 21 deletions(-) create mode 100644 src/strands/types/multiagent.py diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 7c552b144..0a1628530 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -12,8 +12,8 @@ from .._async import run_async from ..agent import AgentResult -from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage +from ..types.multiagent import MultiAgentInput logger = logging.getLogger(__name__) @@ -173,7 +173,7 @@ class MultiAgentBase(ABC): @abstractmethod async def invoke_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> MultiAgentResult: """Invoke asynchronously. @@ -186,7 +186,7 @@ async def invoke_async( raise NotImplementedError("invoke_async not implemented") async def stream_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[dict[str, Any]]: """Stream events during multi-agent execution. @@ -211,7 +211,7 @@ async def stream_async( yield {"result": result} def __call__( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> MultiAgentResult: """Invoke synchronously. diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9f28876bf..740cbc175 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -45,6 +45,7 @@ ) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage +from ..types.multiagent import MultiAgentInput from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status logger = logging.getLogger(__name__) @@ -67,7 +68,7 @@ class GraphState: """ # Task (with default empty string) - task: str | list[ContentBlock] = "" + task: MultiAgentInput = "" # Execution state status: Status = Status.PENDING @@ -456,7 +457,7 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> GraphResult: """Invoke the graph synchronously. @@ -472,7 +473,7 @@ def __call__( return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> GraphResult: """Invoke the graph asynchronously. @@ -496,7 +497,7 @@ async def invoke_async( return cast(GraphResult, final_event["result"]) async def stream_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[dict[str, Any]]: """Stream events during graph execution. diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 3913cd837..1c447f571 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -45,6 +45,7 @@ ) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage +from ..types.multiagent import MultiAgentInput from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status logger = logging.getLogger(__name__) @@ -145,7 +146,7 @@ class SwarmState: """Current state of swarm execution.""" current_node: SwarmNode | None # The agent currently executing - task: str | list[ContentBlock] # The original task from the user that is being executed + task: MultiAgentInput # The original task from the user that is being executed completion_status: Status = Status.PENDING # Current swarm execution status shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents node_history: list[SwarmNode] = field(default_factory=list) # Complete history of agents that have executed @@ -277,7 +278,7 @@ def __init__( run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: """Invoke the swarm synchronously. @@ -292,7 +293,7 @@ def __call__( return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: """Invoke the swarm asynchronously. @@ -316,7 +317,7 @@ async def invoke_async( return cast(SwarmResult, final_event["result"]) async def stream_async( - self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[dict[str, Any]]: """Stream events during swarm execution. @@ -741,7 +742,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato ) async def _execute_node( - self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] + self, node: SwarmNode, task: MultiAgentInput, invocation_state: dict[str, Any] ) -> AsyncIterator[Any]: """Execute swarm node and yield TypedEvent objects.""" start_time = time.time() diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index c47a10c3f..a75121b88 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -8,7 +8,7 @@ import logging import os from datetime import date, datetime, timezone -from typing import Any, Dict, Mapping, Optional +from typing import Any, Dict, Mapping, Optional, cast import opentelemetry.trace as trace_api from opentelemetry.instrumentation.threading import ThreadingInstrumentor @@ -16,6 +16,8 @@ from ..agent.agent_result import AgentResult from ..types.content import ContentBlock, Message, Messages +from ..types.interrupt import InterruptResponseContent +from ..types.multiagent import MultiAgentInput from ..types.streaming import Metrics, StopReason, Usage from ..types.tools import ToolResult, ToolUse from ..types.traces import Attributes, AttributeValue @@ -675,7 +677,7 @@ def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any] def start_multiagent_span( self, - task: str | list[ContentBlock], + task: MultiAgentInput, instance: str, ) -> Span: """Start a new span for swarm invocation.""" @@ -789,12 +791,23 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: {"content": serialize(message["content"])}, ) - def _map_content_blocks_to_otel_parts(self, content_blocks: list[ContentBlock]) -> list[dict[str, Any]]: - """Map ContentBlock objects to OpenTelemetry parts format.""" + def _map_content_blocks_to_otel_parts( + self, content_blocks: list[ContentBlock] | list[InterruptResponseContent] + ) -> list[dict[str, Any]]: + """Map content blocks to OpenTelemetry parts format.""" parts: list[dict[str, Any]] = [] - for block in content_blocks: - if "text" in block: + for block in cast(list[dict[str, Any]], content_blocks): + if "interruptResponse" in block: + interrupt_response = block["interruptResponse"] + parts.append( + { + "type": "interrupt_response", + "id": interrupt_response["interruptId"], + "response": interrupt_response["response"], + }, + ) + elif "text" in block: # Standard TextPart parts.append({"type": "text", "content": block["text"]}) elif "toolUse" in block: diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index a2a4c7dce..aa69149a6 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -6,6 +6,6 @@ from typing import TypeAlias from .content import ContentBlock, Messages -from .interrupt import InterruptResponse +from .interrupt import InterruptResponseContent -AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None +AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None diff --git a/src/strands/types/multiagent.py b/src/strands/types/multiagent.py new file mode 100644 index 000000000..d9487dbd2 --- /dev/null +++ b/src/strands/types/multiagent.py @@ -0,0 +1,7 @@ +"""Multi-agent related type definitions for the SDK.""" + +from typing import TypeAlias + +from .content import ContentBlock + +MultiAgentInput: TypeAlias = str | list[ContentBlock] diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 98cfb459f..581b8ccd3 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -11,6 +11,7 @@ from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.content import ContentBlock +from strands.types.interrupt import InterruptResponseContent from strands.types.streaming import Metrics, StopReason, Usage @@ -396,6 +397,34 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): assert span is not None +@pytest.mark.parametrize( + "task, expected_parts", + [ + ([ContentBlock(text="Test message")], [{"type": "text", "content": "Test message"}]), + ( + [InterruptResponseContent(interruptResponse={"interruptId": "test-id", "response": "approved"})], + [{"type": "interrupt_response", "id": "test-id", "response": "approved"}], + ), + ], +) +def test_start_multiagent_span_task_part_conversion(mock_tracer, task, expected_parts, monkeypatch): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") + + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + tracer.start_multiagent_span(task, "swarm") + + expected_content = json.dumps([{"role": "user", "parts": expected_parts}]) + mock_span.add_event.assert_any_call( + "gen_ai.client.inference.operation.details", attributes={"gen_ai.input.messages": expected_content} + ) + + def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch): """Test starting a swarm call span with task as list of contentBlock with latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):