diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 486f67bf8..721dbc93a 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -238,20 +238,25 @@ async def _structured_output_using_response_schema( if len(response.choices) > 1: raise ValueError("Multiple choices found in the response.") - if not response.choices or response.choices[0].finish_reason != "tool_calls": - raise ValueError("No tool_calls found in response") + if not response.choices: + raise ValueError("No choices found in response") choice = response.choices[0] - try: - # Parse the message content as JSON - tool_call_data = json.loads(choice.message.content) - # Instantiate the output model with the parsed data - return output_model(**tool_call_data) - except ContextWindowExceededError as e: - logger.warning("litellm client raised context window overflow in structured_output") - raise ContextWindowOverflowException(e) from e - except (json.JSONDecodeError, TypeError, ValueError) as e: - raise ValueError(f"Failed to parse or load content into model: {e}") from e + if choice.finish_reason == "tool_calls" and getattr(choice.message, "tool_calls", None): + try: + tool_call = choice.message.tool_calls[0] + tool_call_data = json.loads(tool_call.function.arguments) + return output_model(**tool_call_data) + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow in structured_output") + raise ContextWindowOverflowException(e) from e + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse or load content into model: {e}") from e + else: + try: + return output_model.parse_raw(choice.message.content) + except Exception as e: + raise ValueError("No tool_calls found in response and could not parse plain text as model") from e async def _structured_output_using_tool( self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 82023cae3..c906fed50 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -277,6 +277,13 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c mock_choice = unittest.mock.Mock() mock_choice.finish_reason = "tool_calls" mock_choice.message.content = '{"name": "John", "age": 30}' + # PATCH START: mock tool_calls as list with .function.arguments + tool_call_mock = unittest.mock.Mock() + tool_call_function_mock = unittest.mock.Mock() + tool_call_function_mock.arguments = '{"name": "John", "age": 30}' + tool_call_mock.function = tool_call_function_mock + mock_choice.message.tool_calls = [tool_call_mock] + # PATCH END mock_response = unittest.mock.Mock() mock_response.choices = [mock_choice]