diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 1763f5dec..486f67bf8 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -13,6 +13,7 @@
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
@@ -202,6 +203,10 @@ async def structured_output(
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.
 
+        Some models do not support native structured output via response_format.
+        In the case of proxies, we may not have a way to determine support, so we
+        fall back to using tool calling to achieve structured output.
+
         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
@@ -211,42 +216,69 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        supports_schema = supports_response_schema(self.get_config()["model_id"])
+        if supports_response_schema(self.get_config()["model_id"]):
+            logger.debug("structuring output using response schema")
+            result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt)
+        else:
+            logger.debug("model does not support response schema, structuring output using tool approach")
+            result = await self._structured_output_using_tool(output_model, prompt, system_prompt)
+
+        yield {"output": result}
+
+    async def _structured_output_using_response_schema(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
+    ) -> T:
+        """Get structured output using native response_format support."""
+        response = await litellm.acompletion(
+            **self.client_args,
+            model=self.get_config()["model_id"],
+            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+            response_format=output_model,
+        )
 
-        # If the provider does not support response schemas, we cannot reliably parse structured output.
-        # In that case we must not call the provider and must raise the documented ValueError.
-        if not supports_schema:
-            raise ValueError("Model does not support response_format")
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the response.")
+        if not response.choices or response.choices[0].finish_reason != "tool_calls":
+            raise ValueError("No tool_calls found in response")
 
-        # For providers that DO support response schemas, call litellm and map context-window errors.
+        choice = response.choices[0]
         try:
-            response = await litellm.acompletion(
-                **self.client_args,
-                model=self.get_config()["model_id"],
-                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-                response_format=output_model,
-            )
+            # Parse the message content as JSON
+            tool_call_data = json.loads(choice.message.content)
+            # Instantiate the output model with the parsed data
+            return output_model(**tool_call_data)
         except ContextWindowExceededError as e:
             logger.warning("litellm client raised context window overflow in structured_output")
             raise ContextWindowOverflowException(e) from e
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
+    async def _structured_output_using_tool(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
+    ) -> T:
+        """Get structured output using tool calling fallback."""
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+        request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}}))
+        args = {**self.client_args, **request, "stream": False}
+        response = await litellm.acompletion(**args)
 
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
+        if not response.choices or response.choices[0].finish_reason != "tool_calls":
+            raise ValueError("No tool_calls found in response")
 
-        # Find the first choice with tool_calls
-        for choice in response.choices:
-            if choice.finish_reason == "tool_calls":
-                try:
-                    # Parse the tool call content as JSON
-                    tool_call_data = json.loads(choice.message.content)
-                    # Instantiate the output model with the parsed data
-                    yield {"output": output_model(**tool_call_data)}
-                    return
-                except (json.JSONDecodeError, TypeError, ValueError) as e:
-                    raise ValueError(f"Failed to parse or load content into model: {e}") from e
-
-        # If no tool_calls found, raise an error
-        raise ValueError("No tool_calls found in response")
+        choice = response.choices[0]
+        try:
+            # Parse the tool call arguments as JSON
+            tool_call = choice.message.tool_calls[0]
+            tool_call_data = json.loads(tool_call.function.arguments)
+            # Instantiate the output model with the parsed data
+            return output_model(**tool_call_data)
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow in structured_output")
+            raise ContextWindowOverflowException(e) from e
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
 
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
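# Illustrative sketch, not part of the diff: the two steps at the heart of the
# tool-calling fallback above. A Pydantic model is converted into a tool spec for the
# request (the fallback forces tool use via a ToolChoice of {"any": {}}), and the tool
# call arguments returned by the provider, a JSON string hard-coded here for the
# example, are validated into the output model, mirroring _structured_output_using_tool.
import json

import pydantic

from strands.tools import convert_pydantic_to_tool_spec


class Person(pydantic.BaseModel):
    name: str
    age: int


# Request side: the tool spec that would be included in the litellm request.
tool_spec = convert_pydantic_to_tool_spec(Person)

# Response side: what choices[0].message.tool_calls[0].function.arguments looks like.
arguments = '{"name": "John", "age": 30}'
person = Person(**json.loads(arguments))
assert person == Person(name="John", age=30)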
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 776ae7bae..82023cae3 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -292,15 +292,27 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
 
 
 @pytest.mark.asyncio
-async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls):
+async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
 
+    mock_tool_call = unittest.mock.Mock()
+    mock_tool_call.function.arguments = '{"name": "John", "age": 30}'
+
+    mock_choice = unittest.mock.Mock()
+    mock_choice.finish_reason = "tool_calls"
+    mock_choice.message.tool_calls = [mock_tool_call]
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    litellm_acompletion.return_value = mock_response
+
     with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False):
-        with pytest.raises(ValueError, match="Model does not support response_format"):
-            stream = model.structured_output(test_output_model_cls, messages)
-            await stream.__anext__()
+        stream = model.structured_output(test_output_model_cls, messages)
+        events = await alist(stream)
+        tru_result = events[-1]
 
-    litellm_acompletion.assert_not_called()
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
 
 
 def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py
index 6cfdd3038..c5a09e3e9 100644
--- a/tests_integ/models/test_model_litellm.py
+++ b/tests_integ/models/test_model_litellm.py
@@ -1,3 +1,5 @@
+import unittest.mock
+
 import pydantic
 import pytest
 
@@ -40,6 +42,37 @@ class Weather(pydantic.BaseModel):
     return Weather(time="12:00", weather="sunny")
 
 
+class Location(pydantic.BaseModel):
+    """Location information."""
+
+    city: str = pydantic.Field(description="The city name")
+    country: str = pydantic.Field(description="The country name")
+
+
+class WeatherCondition(pydantic.BaseModel):
+    """Weather condition details."""
+
+    condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")
+    temperature: int = pydantic.Field(description="Temperature in Celsius")
+
+
+class NestedWeather(pydantic.BaseModel):
+    """Weather report with nested location and condition information."""
+
+    time: str = pydantic.Field(description="The time in HH:MM format")
+    location: Location = pydantic.Field(description="Location information")
+    weather: WeatherCondition = pydantic.Field(description="Weather condition details")
+
+
+@pytest.fixture
+def nested_weather():
+    return NestedWeather(
+        time="12:00",
+        location=Location(city="New York", country="USA"),
+        weather=WeatherCondition(condition="sunny", temperature=25),
+    )
+
+
 @pytest.fixture
 def yellow_color():
     class Color(pydantic.BaseModel):
@@ -134,3 +167,31 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
     tru_color = agent.structured_output(type(yellow_color), content)
     exp_color = yellow_color
     assert tru_color == exp_color
+
+
+def test_structured_output_unsupported_model(model, nested_weather):
+    # Mock supports_response_schema to return False to test fallback mechanism
+    with (
+        unittest.mock.patch.multiple(
+            "strands.models.litellm",
+            supports_response_schema=unittest.mock.DEFAULT,
+        ) as mocks,
+        unittest.mock.patch.object(
+            model, "_structured_output_using_tool", wraps=model._structured_output_using_tool
+        ) as mock_tool,
+        unittest.mock.patch.object(
+            model, "_structured_output_using_response_schema", wraps=model._structured_output_using_response_schema
+        ) as mock_schema,
+    ):
+        mocks["supports_response_schema"].return_value = False
+
+        # Test that structured output still works via tool calling fallback
+        agent = Agent(model=model)
+        prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius"
+        tru_weather = agent.structured_output(NestedWeather, prompt)
+        exp_weather = nested_weather
+        assert tru_weather == exp_weather
+
+        # Verify that the tool method was called and schema method was not
+        mock_tool.assert_called_once()
+        mock_schema.assert_not_called()