Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from ...hooks import AfterToolCallEvent, BeforeToolCallEvent
from ...telemetry.metrics import Trace
from ...telemetry.tracer import get_tracer
from ...telemetry.tracer import get_tracer, serialize
from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent
from ...types.content import Message
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse
Expand Down Expand Up @@ -59,6 +59,14 @@ async def _stream(

tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
tool_spec = tool_func.tool_spec if tool_func is not None else None

current_span = trace_api.get_current_span()
if current_span and tool_spec is not None:
current_span.set_attribute("gen_ai.tool.description", tool_spec["description"])
input_schema = tool_spec["inputSchema"]
if "json" in input_schema:
current_span.set_attribute("gen_ai.tool.json_schema", serialize(input_schema["json"]))

invocation_state.update(
{
Expand Down
87 changes: 87 additions & 0 deletions tests/strands/tools/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,90 @@ def cancel_callback(event):
tru_results = tool_results
exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results


@pytest.mark.asyncio
async def test_executor_stream_sets_span_attributes(
executor, agent, tool_results, invocation_state, weather_tool, alist
):
"""Test that span attributes are set correctly when tool_spec is available."""
with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api:
mock_span = unittest.mock.MagicMock()
mock_trace_api.get_current_span.return_value = mock_span

# Mock tool_spec with inputSchema containing json field
with unittest.mock.patch.object(
type(weather_tool), "tool_spec", new_callable=unittest.mock.PropertyMock
) as mock_tool_spec:
mock_tool_spec.return_value = {
"name": "weather_tool",
"description": "Get weather information",
"inputSchema": {"json": {"type": "object", "properties": {}}, "type": "object"},
}

tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)

await alist(stream)

# Verify set_attribute was called with correct values
calls = mock_span.set_attribute.call_args_list
assert len(calls) == 2

# Check description attribute
assert calls[0][0][0] == "gen_ai.tool.description"
assert calls[0][0][1] == "Get weather information"

# Check json_schema attribute
assert calls[1][0][0] == "gen_ai.tool.json_schema"
# The serialize function should have been called on the json field


@pytest.mark.asyncio
async def test_executor_stream_handles_missing_json_in_input_schema(
executor, agent, tool_results, invocation_state, weather_tool, alist
):
"""Test that span attributes handle inputSchema without json field gracefully."""
with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api:
mock_span = unittest.mock.MagicMock()
mock_trace_api.get_current_span.return_value = mock_span

# Mock tool_spec with inputSchema but no json field
with unittest.mock.patch.object(
type(weather_tool), "tool_spec", new_callable=unittest.mock.PropertyMock
) as mock_tool_spec:
mock_tool_spec.return_value = {
"name": "weather_tool",
"description": "Get weather information",
"inputSchema": {"type": "object", "properties": {}},
}

tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)

# Should not raise an error - json_schema attribute just won't be set
await alist(stream)

# Verify only description attribute was set (not json_schema)
calls = mock_span.set_attribute.call_args_list
assert len(calls) == 1
assert calls[0][0][0] == "gen_ai.tool.description"


@pytest.mark.asyncio
async def test_executor_stream_no_span_attributes_when_no_tool_spec(
executor, agent, tool_results, invocation_state, alist
):
"""Test that no span attributes are set when tool_spec is None."""
with unittest.mock.patch("strands.tools.executors._executor.trace_api") as mock_trace_api:
mock_span = unittest.mock.MagicMock()
mock_trace_api.get_current_span.return_value = mock_span

# Use unknown tool which will have no tool_spec
tool_use: ToolUse = {"name": "unknown_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)

await alist(stream)

# Verify set_attribute was not called since tool_spec is None
mock_span.set_attribute.assert_not_called()