Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,26 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]

return super().format_request_message_content(content)

def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
"""Handle switching to a new content stream.

Args:
data_type: The next content data type.
prev_data_type: The previous content data type.

Returns:
Tuple containing:
- Stop block for previous content and the start block for the next content.
- Next content data type.
"""
chunks = []
if data_type != prev_data_type:
if prev_data_type is not None:
chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type}))
chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": data_type}))

return chunks, data_type

@override
async def stream(
self,
Expand Down Expand Up @@ -146,38 +166,46 @@ async def stream(

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

tool_calls: dict[int, list[Any]] = {}
data_type: str | None = None

async for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if choice.delta.content:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
)

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
for chunk in chunks:
yield chunk

yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data_type": data_type,
"data": choice.delta.reasoning_content,
}
)

if choice.delta.content:
chunks, data_type = self._stream_switch_content("text", data_type)
for chunk in chunks:
yield chunk

yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
if data_type:
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
break

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

Expand Down
63 changes: 51 additions & 12 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,39 +142,71 @@ def test_format_request_message_content(content, exp_result):

@pytest.mark.asyncio
async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist):
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
mock_delta_1 = unittest.mock.Mock(
reasoning_content="",
content=None,
tool_calls=None,
)

mock_delta_2 = unittest.mock.Mock(
reasoning_content="\nI'm thinking",
content=None,
tool_calls=None,
)
mock_delta_3 = unittest.mock.Mock(
reasoning_content=None,
content="One second",
tool_calls=None,
)
mock_delta_4 = unittest.mock.Mock(
reasoning_content="\nI'm think",
content=None,
tool_calls=None,
)
mock_delta_5 = unittest.mock.Mock(
reasoning_content="ing again",
content=None,
tool_calls=None,
)

mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
mock_delta_6 = unittest.mock.Mock(
content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None
)

mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
mock_delta_4 = unittest.mock.Mock(
mock_delta_7 = unittest.mock.Mock(
content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None
)

mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)
mock_delta_8 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)

mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)])
mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)])
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
mock_event_6 = unittest.mock.Mock()
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_5)])
mock_event_6 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_6)])
mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)])
mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)])
mock_event_9 = unittest.mock.Mock()

litellm_acompletion.side_effect = unittest.mock.AsyncMock(
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
return_value=agenerator(
[
mock_event_1,
mock_event_2,
mock_event_3,
mock_event_4,
mock_event_5,
mock_event_6,
mock_event_7,
mock_event_8,
mock_event_9,
]
)
)

messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]
Expand All @@ -184,6 +216,15 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm thinking"}}}},
{"contentBlockStop": {}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"text": "One second"}}},
{"contentBlockStop": {}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "\nI'm think"}}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "ing again"}}}},
{"contentBlockStop": {}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"text": "I'll calculate"}}},
{"contentBlockDelta": {"delta": {"text": "that for you"}}},
{"contentBlockStop": {}},
Expand Down Expand Up @@ -211,9 +252,9 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
{
"metadata": {
"usage": {
"inputTokens": mock_event_6.usage.prompt_tokens,
"outputTokens": mock_event_6.usage.completion_tokens,
"totalTokens": mock_event_6.usage.total_tokens,
"inputTokens": mock_event_9.usage.prompt_tokens,
"outputTokens": mock_event_9.usage.completion_tokens,
"totalTokens": mock_event_9.usage.total_tokens,
},
"metrics": {"latencyMs": 0},
}
Expand Down Expand Up @@ -253,8 +294,6 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene
tru_events = await alist(response)
exp_events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
]

Expand Down
16 changes: 16 additions & 0 deletions tests_integ/models/test_model_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ async def test_agent_stream_async(agent):
assert all(string in text for string in ["12:00", "sunny"])


def test_agent_invoke_reasoning(agent, model):
    """Enable extended thinking and verify the response starts with a reasoning block."""
    thinking_params = {
        "thinking": {
            "budget_tokens": 1024,
            "type": "enabled",
        },
    }
    model.update_config(params=thinking_params)

    result = agent("Please reason about the equation 2+2.")

    first_block = result.message["content"][0]
    assert "reasoningContent" in first_block
    assert first_block["reasoningContent"]["reasoningText"]["text"]


def test_structured_output(agent, weather):
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
Expand Down