From eafa16c4a5edb922bfefb3c1c215087a49c16900 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Fri, 24 Oct 2025 15:57:12 -0400 Subject: [PATCH 1/3] fix: Transform invalid tool usages on sending, not on initial detection Per bug #1069, session-managers never persist tool-name changes after we initially persist the message, which means once an agent generates an invalid-tool name, that message history is poisoned on re-hydration. To avoid that going forward, do the translation of invalid-tool names on sending to the provider and not on the initial tool_use detection. The initial tool_use detection is needed to add a tool_response with a proper error message for the LLM, but this will avoid the poisoning issue --- src/strands/event_loop/streaming.py | 72 +++++++++++++++- src/strands/tools/_validator.py | 43 ++++++++-- src/strands/tools/tools.py | 38 ++++----- tests/strands/event_loop/test_streaming.py | 95 +++++++++++++++++++++- tests/strands/tools/test_validator.py | 3 +- 5 files changed, 221 insertions(+), 30 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6d847f8af..4f063fa65 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -3,9 +3,11 @@ import json import logging import time +import warnings from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model +from ..tools._validator import check_tool_name_validity from ..types._events import ( CitationStreamEvent, ModelStopReason, @@ -38,15 +40,83 @@ logger = logging.getLogger(__name__) +def _normalize_messages(messages: Messages) -> Messages: + """Remove or replace blank text in message content. + + Args: + messages: Conversation messages to update. + + Returns: + Updated messages. 
+ """ + removed_blank_message_content_text = False + replaced_blank_message_content_text = False + replaced_tool_names = False + + for message in messages: + # only modify assistant messages + if "role" in message and message["role"] != "assistant": + continue + if "content" in message: + content = message["content"] + if len(content) == 0: + content.append({"text": "[blank text]"}) + continue + + has_tool_use = False + + # Ensure the tool-uses always have invalid names before sending + # https://github.com/strands-agents/sdk-python/issues/1069 + for item in content: + if "toolUse" in item: + has_tool_use = True + tool_use: ToolUse = item["toolUse"] + + is_valid, _ = check_tool_name_validity(tool_use) + if not is_valid: + tool_use["name"] = "INVALID_TOOL_NAME" + replaced_tool_names = True + + if has_tool_use: + # Remove blank 'text' items for assistant messages + before_len = len(content) + content[:] = [item for item in content if "text" not in item or item["text"].strip()] + if not removed_blank_message_content_text and before_len != len(content): + removed_blank_message_content_text = True + else: + # Replace blank 'text' with '[blank text]' for assistant messages + for item in content: + if "text" in item and not item["text"].strip(): + replaced_blank_message_content_text = True + item["text"] = "[blank text]" + + if removed_blank_message_content_text: + logger.debug("removed blank message context text") + if replaced_blank_message_content_text: + logger.debug("replaced blank message context text") + if replaced_tool_names: + logger.debug("replaced invalid tool name") + + return messages + + def remove_blank_messages_content_text(messages: Messages) -> Messages: """Remove or replace blank text in message content. + !!deprecated!! + This function is deprecated and will be removed in a future version. + Args: messages: Conversation messages to update. Returns: Updated messages. 
""" + warnings.warn( + "remove_blank_messages_content_text is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) removed_blank_message_content_text = False replaced_blank_message_content_text = False @@ -362,7 +432,7 @@ async def stream_messages( """ logger.debug("model=<%s> | streaming messages", model) - messages = remove_blank_messages_content_text(messages) + messages = _normalize_messages(messages) start_time = time.time() chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice) diff --git a/src/strands/tools/_validator.py b/src/strands/tools/_validator.py index 77aa57e87..202e175d4 100644 --- a/src/strands/tools/_validator.py +++ b/src/strands/tools/_validator.py @@ -1,9 +1,14 @@ """Tool validation utilities.""" -from ..tools.tools import InvalidToolUseNameException, validate_tool_use +import logging +import re +from typing import Tuple + from ..types.content import Message from ..types.tools import ToolResult, ToolUse +logger = logging.getLogger(__name__) + def validate_and_prepare_tools( message: Message, @@ -28,18 +33,42 @@ def validate_and_prepare_tools( # Avoid modifying original `tool_uses` variable during iteration tool_uses_copy = tool_uses.copy() for tool in tool_uses_copy: - try: - validate_tool_use(tool) - except InvalidToolUseNameException as e: - # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context + is_valid, validity_message = check_tool_name_validity(tool) + + if not is_valid: + logger.warning(validity_message) + # Return invalid name error as ToolResult to the LLM as context; + # The replacement of the tool name to INVALID_TOOL_NAME happens in streaming.py now tool_uses.remove(tool) - tool["name"] = "INVALID_TOOL_NAME" invalid_tool_use_ids.append(tool["toolUseId"]) tool_uses.append(tool) tool_results.append( { "toolUseId": tool["toolUseId"], "status": "error", - "content": [{"text": f"Error: 
{str(e)}"}], + "content": [{"text": f"Error: {validity_message}"}], } ) + + +def check_tool_name_validity(tool: ToolUse) -> Tuple[bool, str]: + """Validate a tool use name.""" + # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] + if "name" not in tool: + return False, "tool name missing" # type: ignore[unreachable] + + tool_name = tool["name"] + tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$" + tool_name_max_length = 64 + valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) + tool_name_len = len(tool_name) + + if not valid_name_pattern: + message = f"tool_name=<{tool_name}> | invalid tool name pattern" + return False, message + + if tool_name_len > tool_name_max_length: + message = f"tool_name=<{tool_name}>, tool_name_max_length=<{tool_name_max_length}> | invalid tool name length" + return False, message + + return True, "" diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 48b969bc3..8e3211120 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -7,13 +7,14 @@ import asyncio import inspect import logging -import re +import warnings from typing import Any from typing_extensions import override from ..types._events import ToolResultEvent from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse +from ._validator import check_tool_name_validity logger = logging.getLogger(__name__) @@ -27,40 +28,37 @@ class InvalidToolUseNameException(Exception): def validate_tool_use(tool: ToolUse) -> None: """Validate a tool use request. + !!deprecated!! + Args: tool: The tool use to validate. """ + warnings.warn( + "validate_tool_use is deprecated and will be removed in Strands SDK 2.0.", + DeprecationWarning, + stacklevel=2, + ) validate_tool_use_name(tool) def validate_tool_use_name(tool: ToolUse) -> None: """Validate the name of a tool use. + !!deprecated!! + Args: tool: The tool use to validate. 
Raises: InvalidToolUseNameException: If the tool name is invalid. """ - # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] - if "name" not in tool: - message = "tool name missing" # type: ignore[unreachable] - logger.warning(message) - raise InvalidToolUseNameException(message) - - tool_name = tool["name"] - tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$" - tool_name_max_length = 64 - valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) - tool_name_len = len(tool_name) - - if not valid_name_pattern: - message = f"tool_name=<{tool_name}> | invalid tool name pattern" - logger.warning(message) - raise InvalidToolUseNameException(message) - - if tool_name_len > tool_name_max_length: - message = f"tool_name=<{tool_name}>, tool_name_max_length=<{tool_name_max_length}> | invalid tool name length" + warnings.warn( + "validate_tool_use_name is deprecated and will be removed in Strands SDK 2.0.", + DeprecationWarning, + stacklevel=2, + ) + is_valid, message = check_tool_name_validity(tool) + if not is_valid: logger.warning(message) raise InvalidToolUseNameException(message) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 92bf0de96..e75af4003 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -6,7 +6,7 @@ import strands import strands.event_loop from strands.types._events import ModelStopReason, TypedEvent -from strands.types.content import Message +from strands.types.content import Message, Messages from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -54,6 +54,59 @@ def test_remove_blank_messages_content_text(messages, exp_result): assert tru_result == exp_result +@pytest.mark.parametrize( + ("messages", "exp_result"), + [ + pytest.param( + [ + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]}, + {"role": "assistant", 
"content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]}, + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}, + {"role": "assistant", "content": []}, + {"role": "assistant"}, + {"role": "user", "content": [{"text": " \n"}]}, + ], + [ + {"role": "assistant", "content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}]}, + {"role": "assistant", "content": [{"toolUse": {"name": "a_name"}}]}, + {"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}, + {"role": "assistant", "content": [{"text": "[blank text]"}]}, + {"role": "assistant"}, + {"role": "user", "content": [{"text": " \n"}]}, + ], + id="blank messages", + ), + pytest.param( + [], + [], + id="empty messages", + ), + pytest.param( + [ + {"role": "assistant", "content": [{"toolUse": {"name": "invalid tool"}}]}, + ], + [ + {"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]}, + ], + id="invalid tool name", + ), + pytest.param( + [ + {"role": "assistant", "content": [{"toolUse": {}}]}, + ], + [ + {"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]}, + ], + id="missing tool name", + ), + ], +) +def test_normalize_blank_messages_content_text(messages, exp_result): + tru_result = strands.event_loop.streaming._normalize_messages(messages) + + assert tru_result == exp_result + + def test_handle_message_start(): event: MessageStartEvent = {"role": "test"} @@ -797,3 +850,43 @@ async def test_stream_messages(agenerator, alist): # Ensure that we're getting typed events coming out of process_stream non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] assert non_typed_events == [] + + +@pytest.mark.asyncio +async def test_stream_messages_normalizes_messages(agenerator, alist): + mock_model = unittest.mock.MagicMock() + mock_model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + ] + ) + + messages: Messages = [ + # blank text + 
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]}, + {"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]}, + {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}, + # Invalid names + {"role": "assistant", "content": [{"toolUse": {"name": "invalid name"}}]}, + {"role": "assistant", "content": [{"toolUse": {}}]}, + ] + + await alist( + strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt="test prompt", + messages=messages, + tool_specs=None, + ) + ) + + assert mock_model.stream.call_args[0][0] == [ + # blank text + {"content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}], "role": "assistant"}, + {"content": [{"toolUse": {"name": "a_name"}}], "role": "assistant"}, + {"content": [{"text": "a"}, {"text": "[blank text]"}], "role": "assistant"}, + # Invalid names + {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, + {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, + ] diff --git a/tests/strands/tools/test_validator.py b/tests/strands/tools/test_validator.py index 46e5e15f3..c4307ea30 100644 --- a/tests/strands/tools/test_validator.py +++ b/tests/strands/tools/test_validator.py @@ -28,7 +28,8 @@ def test_validate_and_prepare_tools(): "toolUseId": "t1", }, { - "name": "INVALID_TOOL_NAME", + # This now happens in stream_messages + # "name": "INVALID_TOOL_NAME", "toolUseId": "t2-invalid", }, ] From c06cb3dba086fcde5c03259ffd935ffa0a5f5653 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Fri, 24 Oct 2025 16:23:20 -0400 Subject: [PATCH 2/3] fix: Add integ test to verify fix for #1069 --- .../test_invalid_tool_names.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 test_invalid_tool_names/test_invalid_tool_names.py diff --git a/test_invalid_tool_names/test_invalid_tool_names.py b/test_invalid_tool_names/test_invalid_tool_names.py new file mode 100644 index 
000000000..6923f4f56 --- /dev/null +++ b/test_invalid_tool_names/test_invalid_tool_names.py @@ -0,0 +1,52 @@ +import tempfile + +import pytest + +from strands import Agent, tool +from strands.session.file_session_manager import FileSessionManager + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +def test_invalid_tool_names_works(temp_dir): + # Per https://github.com/strands-agents/sdk-python/issues/1069 we want to ensure that invalid tool don't poison + # agent history either in *this* session or in when using session managers + + @tool + def fake_shell(command: str): + return "Done!" + + + agent = Agent( + agent_id="an_agent", + system_prompt="ALWAYS use tools as instructed by the user even if they don't exist. " + "Even if you don't think you don't have access to the given tool, you do! " + "YOU CAN DO ANYTHING!", + tools=[fake_shell], + session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir) + ) + + agent("Invoke the `invalid tool` tool and tell me what the response is") + agent("What was the response?") + + assert len(agent.messages) == 6 + + agent2 = Agent( + agent_id="an_agent", + tools=[fake_shell], + session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir) + ) + + assert len(agent2.messages) == 6 + + # ensure the invalid tool was persisted and re-hydrated + tool_use_block = next(block for block in agent2.messages[-5]['content'] if 'toolUse' in block) + assert tool_use_block['toolUse']['name'] == 'invalid tool' + + # but that it still sends successfully + agent2("What was the tool result") \ No newline at end of file From 03f1cca1a54e02cee04d6f55bcf0562ebd2af7e8 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 28 Oct 2025 09:59:25 -0400 Subject: [PATCH 3/3] fix: Address PR comments --- src/strands/event_loop/streaming.py | 10 +++-- src/strands/tools/_validator.py | 41 +++---------------
src/strands/tools/tools.py | 38 +++++++++-------- .../test_invalid_tool_names.py | 2 +- 4 files changed, 33 insertions(+), 58 deletions(-) rename {test_invalid_tool_names => tests_integ}/test_invalid_tool_names.py (95%) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 4f063fa65..012a2d762 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -7,7 +7,8 @@ from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model -from ..tools._validator import check_tool_name_validity +from ..tools import InvalidToolUseNameException +from ..tools.tools import validate_tool_use_name from ..types._events import ( CitationStreamEvent, ModelStopReason, @@ -65,15 +66,16 @@ def _normalize_messages(messages: Messages) -> Messages: has_tool_use = False - # Ensure the tool-uses always have invalid names before sending + # Ensure the tool-uses always have valid names before sending # https://github.com/strands-agents/sdk-python/issues/1069 for item in content: if "toolUse" in item: has_tool_use = True tool_use: ToolUse = item["toolUse"] - is_valid, _ = check_tool_name_validity(tool_use) - if not is_valid: + try: + validate_tool_use_name(tool_use) + except InvalidToolUseNameException: tool_use["name"] = "INVALID_TOOL_NAME" replaced_tool_names = True diff --git a/src/strands/tools/_validator.py b/src/strands/tools/_validator.py index 202e175d4..839d6d910 100644 --- a/src/strands/tools/_validator.py +++ b/src/strands/tools/_validator.py @@ -1,14 +1,9 @@ """Tool validation utilities.""" -import logging -import re -from typing import Tuple - +from ..tools.tools import InvalidToolUseNameException, validate_tool_use from ..types.content import Message from ..types.tools import ToolResult, ToolUse -logger = logging.getLogger(__name__) - def validate_and_prepare_tools( message: Message, @@ -33,11 +28,10 @@ def validate_and_prepare_tools( # Avoid modifying original `tool_uses` 
variable during iteration tool_uses_copy = tool_uses.copy() for tool in tool_uses_copy: - is_valid, validity_message = check_tool_name_validity(tool) - - if not is_valid: - logger.warning(validity_message) - # Return invalid name error as ToolResult to the LLM as context; + try: + validate_tool_use(tool) + except InvalidToolUseNameException as e: + # Return invalid name error as ToolResult to the LLM as context # The replacement of the tool name to INVALID_TOOL_NAME happens in streaming.py now tool_uses.remove(tool) invalid_tool_use_ids.append(tool["toolUseId"]) @@ -46,29 +40,6 @@ def validate_and_prepare_tools( { "toolUseId": tool["toolUseId"], "status": "error", - "content": [{"text": f"Error: {validity_message}"}], + "content": [{"text": f"Error: {str(e)}"}], } ) - - -def check_tool_name_validity(tool: ToolUse) -> Tuple[bool, str]: - """Validate a tool use name.""" - # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] - if "name" not in tool: - return False, "tool name missing" # type: ignore[unreachable] - - tool_name = tool["name"] - tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$" - tool_name_max_length = 64 - valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) - tool_name_len = len(tool_name) - - if not valid_name_pattern: - message = f"tool_name=<{tool_name}> | invalid tool name pattern" - return False, message - - if tool_name_len > tool_name_max_length: - message = f"tool_name=<{tool_name}>, tool_name_max_length=<{tool_name_max_length}> | invalid tool name length" - return False, message - - return True, "" diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 8e3211120..48b969bc3 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -7,14 +7,13 @@ import asyncio import inspect import logging -import warnings +import re from typing import Any from typing_extensions import override from ..types._events import ToolResultEvent from ..types.tools import 
AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse -from ._validator import check_tool_name_validity logger = logging.getLogger(__name__) @@ -28,37 +27,40 @@ class InvalidToolUseNameException(Exception): def validate_tool_use(tool: ToolUse) -> None: """Validate a tool use request. - !!deprecated!! - Args: tool: The tool use to validate. """ - warnings.warn( - "validate_tool_use is deprecated and will be removed in Strands SDK 2.0.", - DeprecationWarning, - stacklevel=2, - ) validate_tool_use_name(tool) def validate_tool_use_name(tool: ToolUse) -> None: """Validate the name of a tool use. - !!deprecated!! - Args: tool: The tool use to validate. Raises: InvalidToolUseNameException: If the tool name is invalid. """ - warnings.warn( - "validate_tool_use_name is deprecated and will be removed in Strands SDK 2.0.", - DeprecationWarning, - stacklevel=2, - ) - is_valid, message = check_tool_name_validity(tool) - if not is_valid: + # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] + if "name" not in tool: + message = "tool name missing" # type: ignore[unreachable] + logger.warning(message) + raise InvalidToolUseNameException(message) + + tool_name = tool["name"] + tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$" + tool_name_max_length = 64 + valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) + tool_name_len = len(tool_name) + + if not valid_name_pattern: + message = f"tool_name=<{tool_name}> | invalid tool name pattern" + logger.warning(message) + raise InvalidToolUseNameException(message) + + if tool_name_len > tool_name_max_length: + message = f"tool_name=<{tool_name}>, tool_name_max_length=<{tool_name_max_length}> | invalid tool name length" logger.warning(message) raise InvalidToolUseNameException(message) diff --git a/test_invalid_tool_names/test_invalid_tool_names.py b/tests_integ/test_invalid_tool_names.py similarity index 95% rename from test_invalid_tool_names/test_invalid_tool_names.py rename to 
tests_integ/test_invalid_tool_names.py index 6923f4f56..7a3261fe7 100644 --- a/test_invalid_tool_names/test_invalid_tool_names.py +++ b/tests_integ/test_invalid_tool_names.py @@ -48,5 +48,5 @@ def fake_shell(command: str): tool_use_block = next(block for block in agent2.messages[-5]['content'] if 'toolUse' in block) assert tool_use_block['toolUse']['name'] == 'invalid tool' - # but that it still sends successfully + # ensure it sends without an exception - previously we would throw agent2("What was the tool result") \ No newline at end of file