Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from .._async import run_async
from ..agent import AgentResult
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -173,7 +173,7 @@ class MultiAgentBase(ABC):

@abstractmethod
async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
"""Invoke asynchronously.

Expand All @@ -186,7 +186,7 @@ async def invoke_async(
raise NotImplementedError("invoke_async not implemented")

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during multi-agent execution.

Expand All @@ -211,7 +211,7 @@ async def stream_async(
yield {"result": result}

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
"""Invoke synchronously.

Expand Down
9 changes: 5 additions & 4 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status

logger = logging.getLogger(__name__)
Expand All @@ -67,7 +68,7 @@ class GraphState:
"""

# Task (with default empty string)
task: str | list[ContentBlock] = ""
task: MultiAgentInput = ""

# Execution state
status: Status = Status.PENDING
Expand Down Expand Up @@ -456,7 +457,7 @@ def __init__(
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult:
"""Invoke the graph synchronously.
Expand All @@ -472,7 +473,7 @@ def __call__(
return run_async(lambda: self.invoke_async(task, invocation_state))

async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult:
"""Invoke the graph asynchronously.
Expand All @@ -496,7 +497,7 @@ async def invoke_async(
return cast(GraphResult, final_event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during graph execution.
Expand Down
11 changes: 6 additions & 5 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -145,7 +146,7 @@ class SwarmState:
"""Current state of swarm execution."""

current_node: SwarmNode | None # The agent currently executing
task: str | list[ContentBlock] # The original task from the user that is being executed
task: MultiAgentInput # The original task from the user that is being executed
completion_status: Status = Status.PENDING # Current swarm execution status
shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents
node_history: list[SwarmNode] = field(default_factory=list) # Complete history of agents that have executed
Expand Down Expand Up @@ -277,7 +278,7 @@ def __init__(
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> SwarmResult:
"""Invoke the swarm synchronously.

Expand All @@ -292,7 +293,7 @@ def __call__(
return run_async(lambda: self.invoke_async(task, invocation_state))

async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> SwarmResult:
"""Invoke the swarm asynchronously.

Expand All @@ -316,7 +317,7 @@ async def invoke_async(
return cast(SwarmResult, final_event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during swarm execution.

Expand Down Expand Up @@ -741,7 +742,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
)

async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
self, node: SwarmNode, task: MultiAgentInput, invocation_state: dict[str, Any]
) -> AsyncIterator[Any]:
"""Execute swarm node and yield TypedEvent objects."""
start_time = time.time()
Expand Down
25 changes: 19 additions & 6 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import logging
import os
from datetime import date, datetime, timezone
from typing import Any, Dict, Mapping, Optional
from typing import Any, Dict, Mapping, Optional, cast

import opentelemetry.trace as trace_api
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.trace import Span, StatusCode

from ..agent.agent_result import AgentResult
from ..types.content import ContentBlock, Message, Messages
from ..types.interrupt import InterruptResponseContent
from ..types.multiagent import MultiAgentInput
from ..types.streaming import Metrics, StopReason, Usage
from ..types.tools import ToolResult, ToolUse
from ..types.traces import Attributes, AttributeValue
Expand Down Expand Up @@ -675,7 +677,7 @@ def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]

def start_multiagent_span(
self,
task: str | list[ContentBlock],
task: MultiAgentInput,
instance: str,
) -> Span:
"""Start a new span for swarm invocation."""
Expand Down Expand Up @@ -789,12 +791,23 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None:
{"content": serialize(message["content"])},
)

def _map_content_blocks_to_otel_parts(self, content_blocks: list[ContentBlock]) -> list[dict[str, Any]]:
"""Map ContentBlock objects to OpenTelemetry parts format."""
def _map_content_blocks_to_otel_parts(
self, content_blocks: list[ContentBlock] | list[InterruptResponseContent]
) -> list[dict[str, Any]]:
"""Map content blocks to OpenTelemetry parts format."""
parts: list[dict[str, Any]] = []

for block in content_blocks:
if "text" in block:
for block in cast(list[dict[str, Any]], content_blocks):
if "interruptResponse" in block:
interrupt_response = block["interruptResponse"]
parts.append(
{
"type": "interrupt_response",
"id": interrupt_response["interruptId"],
"response": interrupt_response["response"],
},
)
elif "text" in block:
# Standard TextPart
parts.append({"type": "text", "content": block["text"]})
elif "toolUse" in block:
Expand Down
4 changes: 2 additions & 2 deletions src/strands/types/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from typing import TypeAlias

from .content import ContentBlock, Messages
from .interrupt import InterruptResponse
from .interrupt import InterruptResponseContent

AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None
7 changes: 7 additions & 0 deletions src/strands/types/multiagent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Multi-agent related type definitions for the SDK."""

from typing import TypeAlias

from .content import ContentBlock

MultiAgentInput: TypeAlias = str | list[ContentBlock]
29 changes: 29 additions & 0 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize
from strands.types.content import ContentBlock
from strands.types.interrupt import InterruptResponseContent
from strands.types.streaming import Metrics, StopReason, Usage


Expand Down Expand Up @@ -396,6 +397,34 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer):
assert span is not None


@pytest.mark.parametrize(
"task, expected_parts",
[
([ContentBlock(text="Test message")], [{"type": "text", "content": "Test message"}]),
(
[InterruptResponseContent(interruptResponse={"interruptId": "test-id", "response": "approved"})],
[{"type": "interrupt_response", "id": "test-id", "response": "approved"}],
),
],
)
def test_start_multiagent_span_task_part_conversion(mock_tracer, task, expected_parts, monkeypatch):
monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental")

with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
tracer = Tracer()
tracer.tracer = mock_tracer

mock_span = mock.MagicMock()
mock_tracer.start_span.return_value = mock_span

tracer.start_multiagent_span(task, "swarm")

expected_content = json.dumps([{"role": "user", "parts": expected_parts}])
mock_span.add_event.assert_any_call(
"gen_ai.client.inference.operation.details", attributes={"gen_ai.input.messages": expected_content}
)


def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch):
"""Test starting a swarm call span with task as list of contentBlock with latest semantic conventions."""
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
Expand Down
Loading