diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..8607a2601 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -13,6 +13,7 @@ import json import logging import random +import warnings from concurrent.futures import ThreadPoolExecutor from typing import ( Any, @@ -374,7 +375,9 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + def __call__( + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -389,7 +392,8 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result object containing: @@ -401,13 +405,15 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """ def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, **kwargs)) + return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + async def invoke_async( + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -422,7 +428,8 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result: object containing: @@ -432,7 +439,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async(prompt, **kwargs) + events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs) async for event in events: _ = event @@ -528,9 +535,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async( - self, - prompt: AgentInput = None, - **kwargs: Any, + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -546,7 +551,8 @@ async def stream_async( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - **kwargs: Additional parameters to pass to the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: An async iterator that yields events. Each event is a dictionary containing @@ -567,7 +573,19 @@ async def stream_async( yield event["data"] ``` """ - callback_handler = kwargs.get("callback_handler", self.callback_handler) + merged_state = {} + if kwargs: + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + merged_state.update(kwargs) + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state + else: + if invocation_state is not None: + merged_state = invocation_state + + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) messages = self._convert_prompt_to_messages(prompt) @@ -576,10 +594,10 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, invocation_state=kwargs) + events = self._run_loop(messages, invocation_state=merged_state) async for event in events: - event.prepare(invocation_state=kwargs) + event.prepare(invocation_state=merged_state) if event.is_callback_event: as_dict = event.as_dict() diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..200584115 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +import warnings from uuid import uuid4 import pytest @@ -1877,3 +1878,58 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text + + +def test_agent__call__handles_none_invocation_state(mock_model, agent): + """Test that agent handles None invocation_state without AttributeError.""" + mock_model.mock_stream.return_value = [ + {"contentBlockDelta": {"delta": {"text": "test response"}}}, + {"contentBlockStop": {}}, + ] + + # This should not raise AttributeError: 'NoneType' object has no attribute 'get' + result = agent("test", invocation_state=None) + + assert result.message["content"][0]["text"] == "test response" + assert result.stop_reason == "end_turn" + + +def test_agent__call__invocation_state_with_kwargs_deprecation_warning(agent, mock_event_loop_cycle): + """Test that kwargs trigger deprecation warning and are merged correctly with invocation_state.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + # Should have nested structure when both invocation_state and kwargs are provided + assert invocation_state["invocation_state"] == {"my": "state"} + assert invocation_state["other_kwarg"] == "foobar" + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar") + + # Verify deprecation warning was issued + assert len(captured_warnings) == 1 + assert issubclass(captured_warnings[0].category, UserWarning) + assert "`**kwargs` parameter is deprecating, use `invocation_state` instead." in str(captured_warnings[0].message) + + +def test_agent__call__invocation_state_only_no_warning(agent, mock_event_loop_cycle): + """Test that using only invocation_state does not trigger warning and passes state directly.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + + assert invocation_state["my"] == "state" + assert "agent" in invocation_state + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}) + + assert len(captured_warnings) == 0