Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import json
import logging
import time
import warnings
from typing import Any, AsyncGenerator, AsyncIterable, Optional

from ..models.model import Model
from ..tools import InvalidToolUseNameException
from ..tools.tools import validate_tool_use_name
from ..types._events import (
CitationStreamEvent,
ModelStopReason,
Expand Down Expand Up @@ -38,15 +41,84 @@
logger = logging.getLogger(__name__)


def _normalize_messages(messages: Messages) -> Messages:
"""Remove or replace blank text in message content.

Args:
messages: Conversation messages to update.

Returns:
Updated messages.
"""
removed_blank_message_content_text = False
replaced_blank_message_content_text = False
replaced_tool_names = False

for message in messages:
# only modify assistant messages
if "role" in message and message["role"] != "assistant":
continue
if "content" in message:
content = message["content"]
if len(content) == 0:
content.append({"text": "[blank text]"})
continue

has_tool_use = False

# Ensure the tool-uses always have valid names before sending
# https://github.com/strands-agents/sdk-python/issues/1069
for item in content:
if "toolUse" in item:
has_tool_use = True
tool_use: ToolUse = item["toolUse"]

try:
validate_tool_use_name(tool_use)
except InvalidToolUseNameException:
tool_use["name"] = "INVALID_TOOL_NAME"
replaced_tool_names = True

if has_tool_use:
# Remove blank 'text' items for assistant messages
before_len = len(content)
content[:] = [item for item in content if "text" not in item or item["text"].strip()]
if not removed_blank_message_content_text and before_len != len(content):
removed_blank_message_content_text = True
else:
# Replace blank 'text' with '[blank text]' for assistant messages
for item in content:
if "text" in item and not item["text"].strip():
replaced_blank_message_content_text = True
item["text"] = "[blank text]"

if removed_blank_message_content_text:
logger.debug("removed blank message context text")
if replaced_blank_message_content_text:
logger.debug("replaced blank message context text")
if replaced_tool_names:
logger.debug("replaced invalid tool name")

return messages


def remove_blank_messages_content_text(messages: Messages) -> Messages:
"""Remove or replace blank text in message content.

!!deprecated!!
This function is deprecated and will be removed in a future version.

Args:
messages: Conversation messages to update.

Returns:
Updated messages.
"""
warnings.warn(
"remove_blank_messages_content_text is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
removed_blank_message_content_text = False
replaced_blank_message_content_text = False

Expand Down Expand Up @@ -362,7 +434,7 @@ async def stream_messages(
"""
logger.debug("model=<%s> | streaming messages", model)

messages = remove_blank_messages_content_text(messages)
messages = _normalize_messages(messages)
start_time = time.time()
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice)

Expand Down
4 changes: 2 additions & 2 deletions src/strands/tools/_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def validate_and_prepare_tools(
try:
validate_tool_use(tool)
except InvalidToolUseNameException as e:
# Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
# Return invalid name error as ToolResult to the LLM as context
# The replacement of the tool name to INVALID_TOOL_NAME happens in streaming.py now
tool_uses.remove(tool)
tool["name"] = "INVALID_TOOL_NAME"
invalid_tool_use_ids.append(tool["toolUseId"])
tool_uses.append(tool)
tool_results.append(
Expand Down
95 changes: 94 additions & 1 deletion tests/strands/event_loop/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import strands
import strands.event_loop
from strands.types._events import ModelStopReason, TypedEvent
from strands.types.content import Message
from strands.types.content import Message, Messages
from strands.types.streaming import (
ContentBlockDeltaEvent,
ContentBlockStartEvent,
Expand Down Expand Up @@ -54,6 +54,59 @@ def test_remove_blank_messages_content_text(messages, exp_result):
assert tru_result == exp_result


@pytest.mark.parametrize(
("messages", "exp_result"),
[
pytest.param(
[
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]},
{"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]},
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},
{"role": "assistant", "content": []},
{"role": "assistant"},
{"role": "user", "content": [{"text": " \n"}]},
],
[
{"role": "assistant", "content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}]},
{"role": "assistant", "content": [{"toolUse": {"name": "a_name"}}]},
{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]},
{"role": "assistant", "content": [{"text": "[blank text]"}]},
{"role": "assistant"},
{"role": "user", "content": [{"text": " \n"}]},
],
id="blank messages",
),
pytest.param(
[],
[],
id="empty messages",
),
pytest.param(
[
{"role": "assistant", "content": [{"toolUse": {"name": "invalid tool"}}]},
],
[
{"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]},
],
id="invalid tool name",
),
pytest.param(
[
{"role": "assistant", "content": [{"toolUse": {}}]},
],
[
{"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]},
],
id="missing tool name",
),
],
)
def test_normalize_blank_messages_content_text(messages, exp_result):
tru_result = strands.event_loop.streaming._normalize_messages(messages)

assert tru_result == exp_result


def test_handle_message_start():
event: MessageStartEvent = {"role": "test"}

Expand Down Expand Up @@ -797,3 +850,43 @@ async def test_stream_messages(agenerator, alist):
# Ensure that we're getting typed events coming out of process_stream
non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)]
assert non_typed_events == []


@pytest.mark.asyncio
async def test_stream_messages_normalizes_messages(agenerator, alist):
mock_model = unittest.mock.MagicMock()
mock_model.stream.return_value = agenerator(
[
{"contentBlockDelta": {"delta": {"text": "test"}}},
{"contentBlockStop": {}},
]
)

messages: Messages = [
# blank text
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]},
{"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]},
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},
# Invalid names
{"role": "assistant", "content": [{"toolUse": {"name": "invalid name"}}]},
{"role": "assistant", "content": [{"toolUse": {}}]},
]

await alist(
strands.event_loop.streaming.stream_messages(
mock_model,
system_prompt="test prompt",
messages=messages,
tool_specs=None,
)
)

assert mock_model.stream.call_args[0][0] == [
# blank text
{"content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}], "role": "assistant"},
{"content": [{"toolUse": {"name": "a_name"}}], "role": "assistant"},
{"content": [{"text": "a"}, {"text": "[blank text]"}], "role": "assistant"},
# Invalid names
{"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},
{"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},
]
3 changes: 2 additions & 1 deletion tests/strands/tools/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def test_validate_and_prepare_tools():
"toolUseId": "t1",
},
{
"name": "INVALID_TOOL_NAME",
# This now happens in stream_messages
# "name": "INVALID_TOOL_NAME",
"toolUseId": "t2-invalid",
},
]
Expand Down
52 changes: 52 additions & 0 deletions tests_integ/test_invalid_tool_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import tempfile

import pytest

from strands import Agent, tool
from strands.session.file_session_manager import FileSessionManager


@pytest.fixture
def temp_dir():
"""Create a temporary directory for testing."""
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir


def test_invalid_tool_names_works(temp_dir):
# Per https://github.com/strands-agents/sdk-python/issues/1069 we want to ensure that invalid tool don't poison
# agent history either in *this* session or in when using session managers

@tool
def fake_shell(command: str):
return "Done!"


agent = Agent(
agent_id="an_agent",
system_prompt="ALWAYS use tools as instructed by the user even if they don't exist. "
"Even if you don't think you don't have access to the given tool, you do! "
"YOU CAN DO ANYTHING!",
tools=[fake_shell],
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
)

agent("Invoke the `invalid tool` tool and tell me what the response is")
agent("What was the response?")

assert len(agent.messages) == 6

agent2 = Agent(
agent_id="an_agent",
tools=[fake_shell],
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
)

assert len(agent2.messages) == 6

# ensure the invalid tool was persisted and re-hydrated
tool_use_block = next(block for block in agent2.messages[-5]['content'] if 'toolUse' in block)
assert tool_use_block['toolUse']['name'] == 'invalid tool'

# ensure it sends without an exception - previously we would throw
agent2("What was the tool result")