diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 87c38990d..8de6a83fc 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -75,6 +75,7 @@ async def _stream( invocation_state.update( { + "agent": agent, "model": agent.model, "messages": agent.messages, "system_prompt": agent.system_prompt, diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index a11e2eab2..957b3a731 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -459,3 +459,21 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul tru_results = tool_results exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_updates_invocation_state_with_agent( + executor, agent, tool_results, invocation_state, weather_tool, alist +): + """Test that invocation_state is updated with agent reference.""" + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + + # Start with empty invocation_state to verify agent is added + empty_invocation_state = {} + + stream = executor._stream(agent, tool_use, tool_results, empty_invocation_state) + await alist(stream) + + # Verify that the invocation_state was updated with the agent + assert "agent" in empty_invocation_state + assert empty_invocation_state["agent"] is agent diff --git a/tests_integ/test_tool_context_injection.py b/tests_integ/test_tool_context_injection.py index 3098604f1..215286a46 100644 --- a/tests_integ/test_tool_context_injection.py +++ b/tests_integ/test_tool_context_injection.py @@ -54,3 +54,22 @@ def test_strands_context_integration_context_custom(): agent("using a tool, write a bad story") _validate_tool_result_content(agent) + + +@tool(context=True) +def calculate_sum(a: int, b: int, tool_context: ToolContext) -> int: + result = a + b + tool_context.agent.state.set("last_calculation", result) + return result + + +def test_agent_state_access_through_tool_context(): + """Test that tools can access agent state through ToolContext.""" + agent = Agent(tools=[calculate_sum]) + result = agent.tool.calculate_sum(a=1, b=1) + + # Verify the tool executed successfully + assert result["status"] == "success" + + # Verify the agent state was updated + assert agent.state.get("last_calculation") == 2