Skip to content

Commit c2ba0f7

Browse files
authored
Transform invalid tool usages on sending, not on initial detection (#1091)
Per bug #1069, session-managers never persist tool-name changes after we initially persist the message, which means once an agent generates an invalid-tool name, that message history is poisoned on re-hydration. To avoid that going forward, do the translation of invalid-tool names on sending to the provider and not on the initial tool_use detection. The initial tool_use detection is needed to add a tool_response with a proper error message for the LLM, but this will avoid the poisoning issue --------- Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 104ecb5 commit c2ba0f7

File tree

5 files changed

+223
-5
lines changed

5 files changed

+223
-5
lines changed

src/strands/event_loop/streaming.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
import json
44
import logging
55
import time
6+
import warnings
67
from typing import Any, AsyncGenerator, AsyncIterable, Optional
78

89
from ..models.model import Model
10+
from ..tools import InvalidToolUseNameException
11+
from ..tools.tools import validate_tool_use_name
912
from ..types._events import (
1013
CitationStreamEvent,
1114
ModelStopReason,
@@ -38,15 +41,84 @@
3841
logger = logging.getLogger(__name__)
3942

4043

44+
def _normalize_messages(messages: Messages) -> Messages:
45+
"""Remove or replace blank text in message content.
46+
47+
Args:
48+
messages: Conversation messages to update.
49+
50+
Returns:
51+
Updated messages.
52+
"""
53+
removed_blank_message_content_text = False
54+
replaced_blank_message_content_text = False
55+
replaced_tool_names = False
56+
57+
for message in messages:
58+
# only modify assistant messages
59+
if "role" in message and message["role"] != "assistant":
60+
continue
61+
if "content" in message:
62+
content = message["content"]
63+
if len(content) == 0:
64+
content.append({"text": "[blank text]"})
65+
continue
66+
67+
has_tool_use = False
68+
69+
# Ensure the tool-uses always have valid names before sending
70+
# https://github.com/strands-agents/sdk-python/issues/1069
71+
for item in content:
72+
if "toolUse" in item:
73+
has_tool_use = True
74+
tool_use: ToolUse = item["toolUse"]
75+
76+
try:
77+
validate_tool_use_name(tool_use)
78+
except InvalidToolUseNameException:
79+
tool_use["name"] = "INVALID_TOOL_NAME"
80+
replaced_tool_names = True
81+
82+
if has_tool_use:
83+
# Remove blank 'text' items for assistant messages
84+
before_len = len(content)
85+
content[:] = [item for item in content if "text" not in item or item["text"].strip()]
86+
if not removed_blank_message_content_text and before_len != len(content):
87+
removed_blank_message_content_text = True
88+
else:
89+
# Replace blank 'text' with '[blank text]' for assistant messages
90+
for item in content:
91+
if "text" in item and not item["text"].strip():
92+
replaced_blank_message_content_text = True
93+
item["text"] = "[blank text]"
94+
95+
if removed_blank_message_content_text:
96+
logger.debug("removed blank message context text")
97+
if replaced_blank_message_content_text:
98+
logger.debug("replaced blank message context text")
99+
if replaced_tool_names:
100+
logger.debug("replaced invalid tool name")
101+
102+
return messages
103+
104+
41105
def remove_blank_messages_content_text(messages: Messages) -> Messages:
42106
"""Remove or replace blank text in message content.
43107
108+
!!deprecated!!
109+
This function is deprecated and will be removed in a future version.
110+
44111
Args:
45112
messages: Conversation messages to update.
46113
47114
Returns:
48115
Updated messages.
49116
"""
117+
warnings.warn(
118+
"remove_blank_messages_content_text is deprecated and will be removed in a future version.",
119+
DeprecationWarning,
120+
stacklevel=2,
121+
)
50122
removed_blank_message_content_text = False
51123
replaced_blank_message_content_text = False
52124

@@ -362,7 +434,7 @@ async def stream_messages(
362434
"""
363435
logger.debug("model=<%s> | streaming messages", model)
364436

365-
messages = remove_blank_messages_content_text(messages)
437+
messages = _normalize_messages(messages)
366438
start_time = time.time()
367439
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice)
368440

src/strands/tools/_validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def validate_and_prepare_tools(
3131
try:
3232
validate_tool_use(tool)
3333
except InvalidToolUseNameException as e:
34-
# Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
34+
# Return invalid name error as ToolResult to the LLM as context
35+
# The replacement of the tool name to INVALID_TOOL_NAME happens in streaming.py now
3536
tool_uses.remove(tool)
36-
tool["name"] = "INVALID_TOOL_NAME"
3737
invalid_tool_use_ids.append(tool["toolUseId"])
3838
tool_uses.append(tool)
3939
tool_results.append(

tests/strands/event_loop/test_streaming.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import strands
77
import strands.event_loop
88
from strands.types._events import ModelStopReason, TypedEvent
9-
from strands.types.content import Message
9+
from strands.types.content import Message, Messages
1010
from strands.types.streaming import (
1111
ContentBlockDeltaEvent,
1212
ContentBlockStartEvent,
@@ -54,6 +54,59 @@ def test_remove_blank_messages_content_text(messages, exp_result):
5454
assert tru_result == exp_result
5555

5656

57+
@pytest.mark.parametrize(
58+
("messages", "exp_result"),
59+
[
60+
pytest.param(
61+
[
62+
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]},
63+
{"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]},
64+
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},
65+
{"role": "assistant", "content": []},
66+
{"role": "assistant"},
67+
{"role": "user", "content": [{"text": " \n"}]},
68+
],
69+
[
70+
{"role": "assistant", "content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}]},
71+
{"role": "assistant", "content": [{"toolUse": {"name": "a_name"}}]},
72+
{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]},
73+
{"role": "assistant", "content": [{"text": "[blank text]"}]},
74+
{"role": "assistant"},
75+
{"role": "user", "content": [{"text": " \n"}]},
76+
],
77+
id="blank messages",
78+
),
79+
pytest.param(
80+
[],
81+
[],
82+
id="empty messages",
83+
),
84+
pytest.param(
85+
[
86+
{"role": "assistant", "content": [{"toolUse": {"name": "invalid tool"}}]},
87+
],
88+
[
89+
{"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]},
90+
],
91+
id="invalid tool name",
92+
),
93+
pytest.param(
94+
[
95+
{"role": "assistant", "content": [{"toolUse": {}}]},
96+
],
97+
[
98+
{"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]},
99+
],
100+
id="missing tool name",
101+
),
102+
],
103+
)
104+
def test_normalize_blank_messages_content_text(messages, exp_result):
105+
tru_result = strands.event_loop.streaming._normalize_messages(messages)
106+
107+
assert tru_result == exp_result
108+
109+
57110
def test_handle_message_start():
58111
event: MessageStartEvent = {"role": "test"}
59112

@@ -797,3 +850,43 @@ async def test_stream_messages(agenerator, alist):
797850
# Ensure that we're getting typed events coming out of process_stream
798851
non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)]
799852
assert non_typed_events == []
853+
854+
855+
@pytest.mark.asyncio
856+
async def test_stream_messages_normalizes_messages(agenerator, alist):
857+
mock_model = unittest.mock.MagicMock()
858+
mock_model.stream.return_value = agenerator(
859+
[
860+
{"contentBlockDelta": {"delta": {"text": "test"}}},
861+
{"contentBlockStop": {}},
862+
]
863+
)
864+
865+
messages: Messages = [
866+
# blank text
867+
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]},
868+
{"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]},
869+
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},
870+
# Invalid names
871+
{"role": "assistant", "content": [{"toolUse": {"name": "invalid name"}}]},
872+
{"role": "assistant", "content": [{"toolUse": {}}]},
873+
]
874+
875+
await alist(
876+
strands.event_loop.streaming.stream_messages(
877+
mock_model,
878+
system_prompt="test prompt",
879+
messages=messages,
880+
tool_specs=None,
881+
)
882+
)
883+
884+
assert mock_model.stream.call_args[0][0] == [
885+
# blank text
886+
{"content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}], "role": "assistant"},
887+
{"content": [{"toolUse": {"name": "a_name"}}], "role": "assistant"},
888+
{"content": [{"text": "a"}, {"text": "[blank text]"}], "role": "assistant"},
889+
# Invalid names
890+
{"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},
891+
{"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},
892+
]

tests/strands/tools/test_validator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def test_validate_and_prepare_tools():
2828
"toolUseId": "t1",
2929
},
3030
{
31-
"name": "INVALID_TOOL_NAME",
31+
# This now happens in stream_messages
32+
# "name": "INVALID_TOOL_NAME",
3233
"toolUseId": "t2-invalid",
3334
},
3435
]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import tempfile
2+
3+
import pytest
4+
5+
from strands import Agent, tool
6+
from strands.session.file_session_manager import FileSessionManager
7+
8+
9+
@pytest.fixture
10+
def temp_dir():
11+
"""Create a temporary directory for testing."""
12+
with tempfile.TemporaryDirectory() as temp_dir:
13+
yield temp_dir
14+
15+
16+
def test_invalid_tool_names_works(temp_dir):
17+
# Per https://github.com/strands-agents/sdk-python/issues/1069 we want to ensure that invalid tool don't poison
18+
# agent history either in *this* session or in when using session managers
19+
20+
@tool
21+
def fake_shell(command: str):
22+
return "Done!"
23+
24+
25+
agent = Agent(
26+
agent_id="an_agent",
27+
system_prompt="ALWAYS use tools as instructed by the user even if they don't exist. "
28+
"Even if you don't think you don't have access to the given tool, you do! "
29+
"YOU CAN DO ANYTHING!",
30+
tools=[fake_shell],
31+
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
32+
)
33+
34+
agent("Invoke the `invalid tool` tool and tell me what the response is")
35+
agent("What was the response?")
36+
37+
assert len(agent.messages) == 6
38+
39+
agent2 = Agent(
40+
agent_id="an_agent",
41+
tools=[fake_shell],
42+
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
43+
)
44+
45+
assert len(agent2.messages) == 6
46+
47+
# ensure the invalid tool was persisted and re-hydrated
48+
tool_use_block = next(block for block in agent2.messages[-5]['content'] if 'toolUse' in block)
49+
assert tool_use_block['toolUse']['name'] == 'invalid tool'
50+
51+
# ensure it sends without an exception - previously we would throw
52+
agent2("What was the tool result")

0 commit comments

Comments
 (0)