Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def _stream(

invocation_state.update(
{
"agent": agent,
"model": agent.model,
"messages": agent.messages,
"system_prompt": agent.system_prompt,
Expand Down
18 changes: 18 additions & 0 deletions tests/strands/tools/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,21 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul
tru_results = tool_results
exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results


@pytest.mark.asyncio
async def test_executor_stream_updates_invocation_state_with_agent(
executor, agent, tool_results, invocation_state, weather_tool, alist
):
"""Test that invocation_state is updated with agent reference."""
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}

# Start with empty invocation_state to verify agent is added
empty_invocation_state = {}

stream = executor._stream(agent, tool_use, tool_results, empty_invocation_state)
await alist(stream)

# Verify that the invocation_state was updated with the agent
assert "agent" in empty_invocation_state
assert empty_invocation_state["agent"] is agent
19 changes: 19 additions & 0 deletions tests_integ/test_tool_context_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,22 @@ def test_strands_context_integration_context_custom():
agent("using a tool, write a bad story")

_validate_tool_result_content(agent)


@tool(context=True)
def calculate_sum(a: int, b: int, tool_context: ToolContext) -> int:
result = a + b
tool_context.agent.state.set("last_calculation", result)
return result


def test_agent_state_access_through_tool_context():
"""Test that tools can access agent state through ToolContext."""
agent = Agent(tools=[calculate_sum])
result = agent.tool.calculate_sum(a=1, b=1)

# Verify the tool executed successfully
assert result["status"] == "success"

# Verify the agent state was updated
assert agent.state.get("last_calculation") == 2
Loading