| 
6 | 6 | import strands  | 
7 | 7 | import strands.event_loop  | 
8 | 8 | from strands.types._events import ModelStopReason, TypedEvent  | 
9 |  | -from strands.types.content import Message  | 
 | 9 | +from strands.types.content import Message, Messages  | 
10 | 10 | from strands.types.streaming import (  | 
11 | 11 |     ContentBlockDeltaEvent,  | 
12 | 12 |     ContentBlockStartEvent,  | 
@@ -54,6 +54,59 @@ def test_remove_blank_messages_content_text(messages, exp_result):  | 
54 | 54 |     assert tru_result == exp_result  | 
55 | 55 | 
 
  | 
56 | 56 | 
 
  | 
 | 57 | +@pytest.mark.parametrize(  | 
 | 58 | +    ("messages", "exp_result"),  | 
 | 59 | +    [  | 
 | 60 | +        pytest.param(  | 
 | 61 | +            [  | 
 | 62 | +                {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]},  | 
 | 63 | +                {"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]},  | 
 | 64 | +                {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},  | 
 | 65 | +                {"role": "assistant", "content": []},  | 
 | 66 | +                {"role": "assistant"},  | 
 | 67 | +                {"role": "user", "content": [{"text": " \n"}]},  | 
 | 68 | +            ],  | 
 | 69 | +            [  | 
 | 70 | +                {"role": "assistant", "content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}]},  | 
 | 71 | +                {"role": "assistant", "content": [{"toolUse": {"name": "a_name"}}]},  | 
 | 72 | +                {"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]},  | 
 | 73 | +                {"role": "assistant", "content": [{"text": "[blank text]"}]},  | 
 | 74 | +                {"role": "assistant"},  | 
 | 75 | +                {"role": "user", "content": [{"text": " \n"}]},  | 
 | 76 | +            ],  | 
 | 77 | +            id="blank messages",  | 
 | 78 | +        ),  | 
 | 79 | +        pytest.param(  | 
 | 80 | +            [],  | 
 | 81 | +            [],  | 
 | 82 | +            id="empty messages",  | 
 | 83 | +        ),  | 
 | 84 | +        pytest.param(  | 
 | 85 | +            [  | 
 | 86 | +                {"role": "assistant", "content": [{"toolUse": {"name": "invalid tool"}}]},  | 
 | 87 | +            ],  | 
 | 88 | +            [  | 
 | 89 | +                {"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]},  | 
 | 90 | +            ],  | 
 | 91 | +            id="invalid tool name",  | 
 | 92 | +        ),  | 
 | 93 | +        pytest.param(  | 
 | 94 | +            [  | 
 | 95 | +                {"role": "assistant", "content": [{"toolUse": {}}]},  | 
 | 96 | +            ],  | 
 | 97 | +            [  | 
 | 98 | +                {"role": "assistant", "content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}]},  | 
 | 99 | +            ],  | 
 | 100 | +            id="missing tool name",  | 
 | 101 | +        ),  | 
 | 102 | +    ],  | 
 | 103 | +)  | 
 | 104 | +def test_normalize_blank_messages_content_text(messages, exp_result):  | 
 | 105 | +    tru_result = strands.event_loop.streaming._normalize_messages(messages)  | 
 | 106 | + | 
 | 107 | +    assert tru_result == exp_result  | 
 | 108 | + | 
 | 109 | + | 
57 | 110 | def test_handle_message_start():  | 
58 | 111 |     event: MessageStartEvent = {"role": "test"}  | 
59 | 112 | 
 
  | 
@@ -797,3 +850,43 @@ async def test_stream_messages(agenerator, alist):  | 
797 | 850 |     # Ensure that we're getting typed events coming out of process_stream  | 
798 | 851 |     non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)]  | 
799 | 852 |     assert non_typed_events == []  | 
 | 853 | + | 
 | 854 | + | 
 | 855 | +@pytest.mark.asyncio  | 
 | 856 | +async def test_stream_messages_normalizes_messages(agenerator, alist):  | 
 | 857 | +    mock_model = unittest.mock.MagicMock()  | 
 | 858 | +    mock_model.stream.return_value = agenerator(  | 
 | 859 | +        [  | 
 | 860 | +            {"contentBlockDelta": {"delta": {"text": "test"}}},  | 
 | 861 | +            {"contentBlockStop": {}},  | 
 | 862 | +        ]  | 
 | 863 | +    )  | 
 | 864 | + | 
 | 865 | +    messages: Messages = [  | 
 | 866 | +        # blank text  | 
 | 867 | +        {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {"name": "a_name"}}]},  | 
 | 868 | +        {"role": "assistant", "content": [{"text": ""}, {"toolUse": {"name": "a_name"}}]},  | 
 | 869 | +        {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},  | 
 | 870 | +        # Invalid names  | 
 | 871 | +        {"role": "assistant", "content": [{"toolUse": {"name": "invalid name"}}]},  | 
 | 872 | +        {"role": "assistant", "content": [{"toolUse": {}}]},  | 
 | 873 | +    ]  | 
 | 874 | + | 
 | 875 | +    await alist(  | 
 | 876 | +        strands.event_loop.streaming.stream_messages(  | 
 | 877 | +            mock_model,  | 
 | 878 | +            system_prompt="test prompt",  | 
 | 879 | +            messages=messages,  | 
 | 880 | +            tool_specs=None,  | 
 | 881 | +        )  | 
 | 882 | +    )  | 
 | 883 | + | 
 | 884 | +    assert mock_model.stream.call_args[0][0] == [  | 
 | 885 | +        # blank text  | 
 | 886 | +        {"content": [{"text": "a"}, {"toolUse": {"name": "a_name"}}], "role": "assistant"},  | 
 | 887 | +        {"content": [{"toolUse": {"name": "a_name"}}], "role": "assistant"},  | 
 | 888 | +        {"content": [{"text": "a"}, {"text": "[blank text]"}], "role": "assistant"},  | 
 | 889 | +        # Invalid names  | 
 | 890 | +        {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},  | 
 | 891 | +        {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},  | 
 | 892 | +    ]  | 
0 commit comments