Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import json
import logging
import uuid
from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast

import litellm
Expand Down Expand Up @@ -321,7 +322,11 @@ async def stream(
break

for tool_deltas in tool_calls.values():
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
first_delta = tool_deltas[0]
if not first_delta.id:
first_delta.id = f"call_{uuid.uuid4()}"

yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": first_delta})

for tool_delta in tool_deltas:
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
Expand Down
36 changes: 36 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,39 @@ def test_format_request_messages_cache_point_support():
]

assert result == expected


@pytest.mark.asyncio
async def test_stream_generates_tool_call_id_when_null(litellm_acompletion, model, agenerator, alist):
"""Test that stream generates a tool call ID when LiteLLM returns null."""
mock_tool_call = unittest.mock.Mock(index=0)
mock_tool_call.id = None
mock_tool_call.function.name = "test_tool"
mock_tool_call.function.arguments = '{"arg": "value"}'

mock_delta = unittest.mock.Mock(content=None, tool_calls=[mock_tool_call], reasoning_content=None)

mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
mock_event_2 = unittest.mock.Mock(
choices=[
unittest.mock.Mock(
finish_reason="tool_calls",
delta=unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None),
)
]
)

litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=agenerator([mock_event_1, mock_event_2]))

messages = [{"role": "user", "content": [{"text": "test"}]}]
response = model.stream(messages)
tru_events = await alist(response)

tool_start_event = next(
e for e in tru_events if "contentBlockStart" in e and "toolUse" in e["contentBlockStart"]["start"]
)

tool_id = tool_start_event["contentBlockStart"]["start"]["toolUse"]["toolUseId"]
assert tool_id is not None
assert tool_id.startswith("call_")
assert len(tool_id) > 5