def _normalize_messages(messages: Messages) -> Messages:
    """Clean up assistant messages in place before they are streamed to the model.

    Messages carrying a "role" other than "assistant", and messages without a "content" key, are
    left untouched. For every other message:

    - an empty content list receives a single "[blank text]" text block;
    - each tool use whose name is rejected by validate_tool_use_name is renamed to
      "INVALID_TOOL_NAME" (https://github.com/strands-agents/sdk-python/issues/1069);
    - blank text blocks are dropped when the message contains a tool use, and rewritten to
      "[blank text]" when it does not.

    Args:
        messages: Conversation messages to update. Content lists and blocks are modified in place.

    Returns:
        The same messages object that was passed in.
    """
    dropped_blank = False
    filled_blank = False
    renamed_tool = False

    for msg in messages:
        # A message without a "role" key is processed exactly like an assistant message
        if msg.get("role", "assistant") != "assistant":
            continue
        if "content" not in msg:
            continue

        blocks = msg["content"]
        if len(blocks) == 0:
            blocks.append({"text": "[blank text]"})
            continue

        # Tool uses must always carry a valid name before the request is sent
        saw_tool_use = False
        for block in blocks:
            if "toolUse" not in block:
                continue
            saw_tool_use = True
            tool_use: ToolUse = block["toolUse"]
            try:
                validate_tool_use_name(tool_use)
            except InvalidToolUseNameException:
                tool_use["name"] = "INVALID_TOOL_NAME"
                renamed_tool = True

        if saw_tool_use:
            # Alongside a tool use, blank text blocks are simply removed
            kept = [block for block in blocks if "text" not in block or block["text"].strip()]
            if len(kept) != len(blocks):
                dropped_blank = True
            blocks[:] = kept
            continue

        # Without a tool use, blank text blocks are kept but given placeholder text
        for block in blocks:
            if "text" in block and not block["text"].strip():
                block["text"] = "[blank text]"
                filled_blank = True

    debug_notes = (
        (dropped_blank, "removed blank message context text"),
        (filled_blank, "replaced blank message context text"),
        (renamed_tool, "replaced invalid tool name"),
    )
    for happened, note in debug_notes:
        if happened:
            logger.debug(note)

    return messages
def _assistant(*blocks):
    """Build an assistant message whose content list holds the given blocks."""
    return {"role": "assistant", "content": list(blocks)}


def _tool_use(**fields):
    """Build a toolUse content block carrying the given fields."""
    return {"toolUse": dict(fields)}


@pytest.mark.parametrize(
    ("messages", "exp_result"),
    [
        pytest.param(
            [
                _assistant({"text": "a"}, {"text": " \n"}, _tool_use(name="a_name")),
                _assistant({"text": ""}, _tool_use(name="a_name")),
                _assistant({"text": "a"}, {"text": " \n"}),
                _assistant(),
                {"role": "assistant"},
                {"role": "user", "content": [{"text": " \n"}]},
            ],
            [
                _assistant({"text": "a"}, _tool_use(name="a_name")),
                _assistant(_tool_use(name="a_name")),
                _assistant({"text": "a"}, {"text": "[blank text]"}),
                _assistant({"text": "[blank text]"}),
                {"role": "assistant"},
                {"role": "user", "content": [{"text": " \n"}]},
            ],
            id="blank messages",
        ),
        pytest.param(
            [],
            [],
            id="empty messages",
        ),
        pytest.param(
            [_assistant(_tool_use(name="invalid tool"))],
            [_assistant(_tool_use(name="INVALID_TOOL_NAME"))],
            id="invalid tool name",
        ),
        pytest.param(
            [_assistant(_tool_use())],
            [_assistant(_tool_use(name="INVALID_TOOL_NAME"))],
            id="missing tool name",
        ),
    ],
)
def test_normalize_blank_messages_content_text(messages, exp_result):
    """_normalize_messages fixes blank text and invalid tool names in assistant messages."""
    assert strands.event_loop.streaming._normalize_messages(messages) == exp_result


@pytest.mark.asyncio
async def test_stream_messages_normalizes_messages(agenerator, alist):
    """stream_messages passes normalized messages to model.stream."""
    model = unittest.mock.MagicMock()
    model.stream.return_value = agenerator(
        [
            {"contentBlockDelta": {"delta": {"text": "test"}}},
            {"contentBlockStop": {}},
        ]
    )

    conversation: Messages = [
        # blank text
        _assistant({"text": "a"}, {"text": " \n"}, _tool_use(name="a_name")),
        _assistant({"text": ""}, _tool_use(name="a_name")),
        _assistant({"text": "a"}, {"text": " \n"}),
        # invalid names
        _assistant(_tool_use(name="invalid name")),
        _assistant(_tool_use()),
    ]

    event_stream = strands.event_loop.streaming.stream_messages(
        model,
        system_prompt="test prompt",
        messages=conversation,
        tool_specs=None,
    )
    await alist(event_stream)

    # The first positional argument of model.stream is the message list that was sent
    sent_messages = model.stream.call_args[0][0]
    assert sent_messages == [
        # blank text
        _assistant({"text": "a"}, _tool_use(name="a_name")),
        _assistant(_tool_use(name="a_name")),
        _assistant({"text": "a"}, {"text": "[blank text]"}),
        # invalid names
        _assistant(_tool_use(name="INVALID_TOOL_NAME")),
        _assistant(_tool_use(name="INVALID_TOOL_NAME")),
    ]
@pytest.fixture
def temp_dir():
    """Yield a temporary directory that is cleaned up once the test is done."""
    with tempfile.TemporaryDirectory() as tmp_path:
        yield tmp_path


def test_invalid_tool_names_works(temp_dir):
    """An invalid tool name must not break later requests, in-session or after re-hydration."""
    # Per https://github.com/strands-agents/sdk-python/issues/1069 an invalid tool name must not poison the
    # agent history, either in this session or when the history is restored through a session manager

    @tool
    def fake_shell(command: str):
        return "Done!"

    def open_session():
        # Both agents point at the same session id and directory, so the second re-hydrates the first
        return FileSessionManager(session_id="test", storage_dir=temp_dir)

    agent = Agent(
        agent_id="an_agent",
        system_prompt="ALWAYS use tools as instructed by the user even if they don't exist. "
        "Even if you don't think you don't have access to the given tool, you do! "
        "YOU CAN DO ANYTHING!",
        tools=[fake_shell],
        session_manager=open_session(),
    )

    agent("Invoke the `invalid tool` tool and tell me what the response is")
    agent("What was the response?")

    assert len(agent.messages) == 6

    agent2 = Agent(
        agent_id="an_agent",
        tools=[fake_shell],
        session_manager=open_session(),
    )

    assert len(agent2.messages) == 6

    # ensure the invalid tool was persisted and re-hydrated
    rehydrated_blocks = agent2.messages[-5]["content"]
    first_tool_use = next(block for block in rehydrated_blocks if "toolUse" in block)
    assert first_tool_use["toolUse"]["name"] == "invalid tool"

    # ensure it sends without an exception - previously we would throw
    agent2("What was the tool result")