diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..e6d412b85 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -441,7 +441,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you pass in a prompt, it will be added to the conversation history along with the structured output result. If you don't pass in a prompt, it will use only the existing conversation history to respond. For smaller models, you may want to use the optional prompt to add additional instructions to explicitly @@ -470,7 +470,7 @@ def execute() -> T: async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you pass in a prompt, it will be added to the conversation history along with the structured output result. If you don't pass in a prompt, it will use only the existing conversation history to respond. For smaller models, you may want to use the optional prompt to add additional instructions to explicitly @@ -479,7 +479,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu Args: output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. - prompt: The prompt to use for the agent (will not be added to conversation history). + prompt: The prompt to use for the agent (will be added to conversation history). Raises: ValueError: If no conversation history or prompt is provided. @@ -492,7 +492,13 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) + # Add prompt to conversation history if provided + if prompt: + prompt_messages = self._convert_prompt_to_messages(prompt) + for message in prompt_messages: + self._append_message(message) + + temp_messages: Messages = self.messages structured_output_span.set_attributes( { @@ -502,16 +508,16 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu "gen_ai.operation.name": "execute_structured_output", } ) - if self.system_prompt: - structured_output_span.add_event( - "gen_ai.system.message", - attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])}, - ) for message in temp_messages: structured_output_span.add_event( f"gen_ai.{message['role']}.message", attributes={"role": message["role"], "content": serialize(message["content"])}, ) + if self.system_prompt: + structured_output_span.add_event( + "gen_ai.system.message", + attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])}, + ) events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: if isinstance(event, TypedEvent): @@ -522,7 +528,16 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu structured_output_span.add_event( "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} ) - return event["output"] + + # Add structured output result to conversation history + result = event["output"] + assistant_message = { + "role": "assistant", + "content": [{"text": f"Structured output ({output_model.__name__}): {result.model_dump_json()}"}] + } + self._append_message(assistant_message) + + return result finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 005eed3df..27080ddcb 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -108,6 +108,157 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] return super().format_request_message_content(content) + def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]: + """Format LiteLLM compatible message contents. + + LiteLLM expects content to be a string for simple text messages, not a list of content blocks. + This method flattens the content structure to be compatible with LiteLLM providers like Cerebras and Groq. + + Args: + role: Message role (e.g., "user", "assistant"). + content: Content block to format. + + Returns: + LiteLLM formatted message contents. + + Raises: + TypeError: If the content block type cannot be converted to a LiteLLM-compatible format. + """ + if "text" in content: + return [{"role": role, "content": content["text"]}] + + if "image" in content: + # For images, we still need to use the structured format + return [{"role": role, "content": [self.format_request_message_content(content)]}] + + if "toolUse" in content: + return [ + { + "role": role, + "tool_calls": [ + { + "function": { + "name": content["toolUse"]["name"], + "arguments": json.dumps(content["toolUse"]["input"]), + }, + "id": content["toolUse"]["toolUseId"], + "type": "function", + } + ], + } + ] + + if "toolResult" in content: + # For tool results, we need to format the content properly + tool_content_parts = [] + for tool_content in content["toolResult"]["content"]: + if "json" in tool_content: + tool_content_parts.append(json.dumps(tool_content["json"])) + elif "text" in tool_content: + tool_content_parts.append(tool_content["text"]) + else: + tool_content_parts.append(str(tool_content)) + + tool_content_string = " ".join(tool_content_parts) + return [ + { + "role": "tool", + "tool_call_id": content["toolResult"]["toolUseId"], + "content": tool_content_string, + } + ] + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @override + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format LiteLLM compatible messages array. + + This method overrides the parent class to ensure compatibility with LiteLLM providers + that expect string content instead of content block arrays. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages: list[dict[str, Any]] = [] + + # Add system prompt if provided + if system_prompt: + formatted_messages.append({"role": "system", "content": system_prompt}) + + for message in messages: + contents = message["content"] + + # Separate different types of content + text_contents = [content for content in contents if "text" in content and not any(block_type in content for block_type in ["toolResult", "toolUse"])] + tool_use_contents = [content for content in contents if "toolUse" in content] + tool_result_contents = [content for content in contents if "toolResult" in content] + other_contents = [content for content in contents if not any(block_type in content for block_type in ["text", "toolResult", "toolUse"])] + + # Handle text content - flatten to string for Cerebras/Groq compatibility + if text_contents: + if len(text_contents) == 1: + # Single text content - use string format + formatted_messages.append({ + "role": message["role"], + "content": text_contents[0]["text"] + }) + else: + # Multiple text contents - concatenate + combined_text = " ".join(content["text"] for content in text_contents) + formatted_messages.append({ + "role": message["role"], + "content": combined_text + }) + + # Handle tool use content + for content in tool_use_contents: + formatted_messages.append({ + "role": message["role"], + "tool_calls": [ + { + "function": { + "name": content["toolUse"]["name"], + "arguments": json.dumps(content["toolUse"]["input"]), + }, + "id": content["toolUse"]["toolUseId"], + "type": "function", + } + ], + }) + + # Handle tool result content + for content in tool_result_contents: + tool_content_parts = [] + for tool_content in content["toolResult"]["content"]: + if "json" in tool_content: + tool_content_parts.append(json.dumps(tool_content["json"])) + elif "text" in tool_content: + tool_content_parts.append(tool_content["text"]) + else: + tool_content_parts.append(str(tool_content)) + + tool_content_string = " ".join(tool_content_parts) + formatted_messages.append({ + "role": "tool", + "tool_call_id": content["toolResult"]["toolUseId"], + "content": tool_content_string, + }) + + # Handle other content types (images, etc.) - use structured format + for content in other_contents: + formatted_messages.append({ + "role": message["role"], + "content": [cls.format_request_message_content(content)] + }) + + return formatted_messages + @override async def stream( self, @@ -211,7 +362,7 @@ async def structured_output( response = await litellm.acompletion( **self.client_args, model=self.get_config()["model_id"], - messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], + messages=self.format_request_messages(prompt, system_prompt=system_prompt), response_format=output_model, ) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..48dad736d 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -18,7 +18,6 @@ from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager -from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException @@ -989,6 +988,9 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): agent.tracer = mock_strands_tracer agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + agent.hooks = unittest.mock.MagicMock() + agent.hooks.invoke_callbacks = unittest.mock.Mock() + agent.callback_handler = unittest.mock.Mock() prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -999,12 +1001,31 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): exp_result = user assert tru_result == exp_result - # Verify conversation history is not polluted - assert len(agent.messages) == initial_message_count + # Verify conversation history is updated with prompt and structured output + assert len(agent.messages) == initial_message_count + 2 + + # Verify the prompt was added to conversation history + user_message_added = any( + msg['role'] == 'user' and prompt in msg['content'][0]['text'] + for msg in agent.messages + ) + assert user_message_added, "User prompt should be added to conversation history" + + # Verify the structured output was added to conversation history + assistant_message_added = any( + msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text'] + for msg in agent.messages + ) + assert assistant_message_added, "Structured output should be added to conversation history" - # Verify the model was called with temporary messages array + # Verify the model was called with all messages (including the added prompt) agent.model.structured_output.assert_called_once_with( - type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt + type(user), + [ + {"role": "user", "content": [{"text": prompt}]}, + {"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]} + ], + system_prompt=system_prompt ) mock_span.set_attributes.assert_called_once_with( @@ -1016,23 +1037,15 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): } ) - # ensure correct otel event messages are emitted - act_event_names = mock_span.add_event.call_args_list - exp_event_names = [ - unittest.mock.call( - "gen_ai.system.message", attributes={"role": "system", "content": serialize([{"text": system_prompt}])} - ), - unittest.mock.call( - "gen_ai.user.message", - attributes={ - "role": "user", - "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]', - }, - ), - unittest.mock.call("gen_ai.choice", attributes={"message": json.dumps(user.model_dump())}), - ] + mock_span.add_event.assert_any_call( + "gen_ai.user.message", + attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]'}, + ) - assert act_event_names == exp_event_names + mock_span.add_event.assert_called_with( + "gen_ai.choice", + attributes={"message": json.dumps(user.model_dump())}, + ) def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator): @@ -1064,12 +1077,31 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a exp_result = user assert tru_result == exp_result - # Verify conversation history is not polluted - assert len(agent.messages) == initial_message_count + # Verify conversation history is updated with prompt and structured output + assert len(agent.messages) == initial_message_count + 2 + + # Verify the multi-modal prompt was added to conversation history + user_message_added = any( + msg['role'] == 'user' and 'Please describe the user in this image' in msg['content'][0]['text'] + for msg in agent.messages + ) + assert user_message_added, "Multi-modal user prompt should be added to conversation history" + + # Verify the structured output was added to conversation history + assistant_message_added = any( + msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text'] + for msg in agent.messages + ) + assert assistant_message_added, "Structured output should be added to conversation history" - # Verify the model was called with temporary messages array + # Verify the model was called with all messages (including the added prompt) agent.model.structured_output.assert_called_once_with( - type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt + type(user), + [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]} + ], + system_prompt=system_prompt ) mock_span.add_event.assert_called_with( @@ -1081,6 +1113,9 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a @pytest.mark.asyncio async def test_agent_structured_output_in_async_context(agent, user, agenerator): agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + agent.hooks = unittest.mock.MagicMock() + agent.hooks.invoke_callbacks = unittest.mock.Mock() + agent.callback_handler = unittest.mock.Mock() prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -1091,13 +1126,30 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator) exp_result = user assert tru_result == exp_result - # Verify conversation history is not polluted - assert len(agent.messages) == initial_message_count + # Verify conversation history is updated with prompt and structured output + assert len(agent.messages) == initial_message_count + 2 + + # Verify the prompt was added to conversation history + user_message_added = any( + msg['role'] == 'user' and prompt in msg['content'][0]['text'] + for msg in agent.messages + ) + assert user_message_added, "User prompt should be added to conversation history" + + # Verify the structured output was added to conversation history + assistant_message_added = any( + msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text'] + for msg in agent.messages + ) + assert assistant_message_added, "Structured output should be added to conversation history" def test_agent_structured_output_without_prompt(agent, system_prompt, user, agenerator): """Test that structured_output works with existing conversation history and no new prompt.""" agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + agent.hooks = unittest.mock.MagicMock() + agent.hooks.invoke_callbacks = unittest.mock.Mock() + agent.callback_handler = unittest.mock.Mock() # Add some existing messages to the agent existing_messages = [ @@ -1112,17 +1164,27 @@ def test_agent_structured_output_without_prompt(agent, system_prompt, user, agen exp_result = user assert tru_result == exp_result - # Verify conversation history is unchanged - assert len(agent.messages) == initial_message_count - assert agent.messages == existing_messages + # Verify conversation history is updated with structured output only (no prompt added) + assert len(agent.messages) == initial_message_count + 1 + + # Verify the structured output was added to conversation history + assistant_message_added = any( + msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text'] + for msg in agent.messages + ) + assert assistant_message_added, "Structured output should be added to conversation history" - # Verify the model was called with existing messages only - agent.model.structured_output.assert_called_once_with(type(user), existing_messages, system_prompt=system_prompt) + # Verify the model was called with existing messages plus the added structured output + expected_messages = existing_messages + [{"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]}] + agent.model.structured_output.assert_called_once_with(type(user), expected_messages, system_prompt=system_prompt) @pytest.mark.asyncio async def test_agent_structured_output_async(agent, system_prompt, user, agenerator): agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + agent.hooks = unittest.mock.MagicMock() + agent.hooks.invoke_callbacks = unittest.mock.Mock() + agent.callback_handler = unittest.mock.Mock() prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -1133,12 +1195,31 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera exp_result = user assert tru_result == exp_result - # Verify conversation history is not polluted - assert len(agent.messages) == initial_message_count + # Verify conversation history is updated with prompt and structured output + assert len(agent.messages) == initial_message_count + 2 + + # Verify the prompt was added to conversation history + user_message_added = any( + msg['role'] == 'user' and prompt in msg['content'][0]['text'] + for msg in agent.messages + ) + assert user_message_added, "User prompt should be added to conversation history" + + # Verify the structured output was added to conversation history + assistant_message_added = any( + msg['role'] == 'assistant' and 'Structured output (User):' in msg['content'][0]['text'] + for msg in agent.messages + ) + assert assistant_message_added, "Structured output should be added to conversation history" - # Verify the model was called with temporary messages array + # Verify the model was called with all messages (including the added prompt) agent.model.structured_output.assert_called_once_with( - type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt + type(user), + [ + {"role": "user", "content": [{"text": prompt}]}, + {"role": "assistant", "content": [{"text": f"Structured output (User): {user.model_dump_json()}"}]} + ], + system_prompt=system_prompt ) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 6c5625e0b..8a9637c55 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -265,12 +265,14 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): length, events = hook_provider.get_events() - assert length == 2 + assert length == 4 # BeforeInvocationEvent, MessageAddedEvent (prompt), MessageAddedEvent (output), AfterInvocationEvent assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) # Prompt added + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) # Output added assert next(events) == AfterInvocationEvent(agent=agent) - assert len(agent.messages) == 0 # no new messages added + assert len(agent.messages) == 2 # prompt and structured output added @pytest.mark.asyncio @@ -282,9 +284,11 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a length, events = hook_provider.get_events() - assert length == 2 + assert length == 4 # BeforeInvocationEvent, MessageAddedEvent (prompt), MessageAddedEvent (output), AfterInvocationEvent assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) # Prompt added + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) # Output added assert next(events) == AfterInvocationEvent(agent=agent) - assert len(agent.messages) == 0 # no new messages added + assert len(agent.messages) == 2 # prompt and structured output added