diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 86ca69ad1..cc11be043 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,10 +20,9 @@ from opentelemetry import trace from pydantic import BaseModel -from ..event_loop.event_loop import event_loop_cycle +from ..event_loop.event_loop import event_loop_cycle, run_tool from ..experimental.hooks import AgentInitializedEvent, EndRequestEvent, HookRegistry, StartRequestEvent from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler -from ..handlers.tool_handler import AgentToolHandler from ..models.bedrock import BedrockModel from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer @@ -130,14 +129,7 @@ def caller( } # Execute the tool - events = self._agent.tool_handler.process( - tool=tool_use, - model=self._agent.model, - system_prompt=self._agent.system_prompt, - messages=self._agent.messages, - tool_config=self._agent.tool_config, - kwargs=kwargs, - ) + events = run_tool(agent=self._agent, tool=tool_use, kwargs=kwargs) try: while True: @@ -283,7 +275,6 @@ def __init__( self.load_tools_from_directory = load_tools_from_directory self.tool_registry = ToolRegistry() - self.tool_handler = AgentToolHandler(tool_registry=self.tool_registry) # Process tool list if provided if tools is not None: @@ -563,14 +554,7 @@ async def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> AsyncGenera try: # Execute the main event loop cycle events = event_loop_cycle( - model=self.model, - system_prompt=self.system_prompt, - messages=self.messages, # will be modified by event_loop_cycle - tool_config=self.tool_config, - tool_handler=self.tool_handler, - thread_pool=self.thread_pool, - event_loop_metrics=self.event_loop_metrics, - event_loop_parent_span=self.trace_span, + agent=self, kwargs=kwargs, ) async for event in events: diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 37ef6309a..effb32e54 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,23 +11,21 @@ import logging import time import uuid -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from typing import Any, AsyncGenerator, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator -from opentelemetry import trace - -from ..telemetry.metrics import EventLoopMetrics, Trace +from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools -from ..types.content import Message, Messages +from ..types.content import Message from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException -from ..types.models import Model from ..types.streaming import Metrics, StopReason -from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse +from ..types.tools import ToolGenerator, ToolResult, ToolUse from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages +if TYPE_CHECKING: + from ..agent import Agent + logger = logging.getLogger(__name__) MAX_ATTEMPTS = 6 @@ -35,17 +33,7 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle( - model: Model, - system_prompt: Optional[str], - messages: Messages, - tool_config: Optional[ToolConfig], - tool_handler: Optional[ToolHandler], - thread_pool: Optional[ThreadPoolExecutor], - event_loop_metrics: EventLoopMetrics, - event_loop_parent_span: Optional[trace.Span], - kwargs: dict[str, Any], -) -> AsyncGenerator[dict[str, Any], None]: +async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Execute a single cycle of the event loop. This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -60,14 +48,7 @@ async def event_loop_cycle( 7. Error handling and recovery Args: - model: Provider for running model inference. - system_prompt: System prompt instructions for the model. - messages: Conversation history messages. - tool_config: Configuration for available tools. - tool_handler: Handler for executing tools. - thread_pool: Optional thread pool for parallel tool execution. - event_loop_metrics: Metrics tracking object for the event loop. - event_loop_parent_span: Span for the parent of this event loop. + agent: The agent for which the cycle is being executed. kwargs: Additional arguments including: - request_state: State maintained across cycles @@ -93,7 +74,7 @@ async def event_loop_cycle( if "request_state" not in kwargs: kwargs["request_state"] = {} attributes = {"event_loop_cycle_id": str(kwargs.get("event_loop_cycle_id"))} - cycle_start_time, cycle_trace = event_loop_metrics.start_cycle(attributes=attributes) + cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) kwargs["event_loop_cycle_trace"] = cycle_trace yield {"callback": {"start": True}} @@ -102,7 +83,7 @@ async def event_loop_cycle( # Create tracer span for this event loop cycle tracer = get_tracer() cycle_span = tracer.start_event_loop_cycle_span( - event_loop_kwargs=kwargs, messages=messages, parent_span=event_loop_parent_span + event_loop_kwargs=kwargs, messages=agent.messages, parent_span=agent.trace_span ) kwargs["event_loop_cycle_span"] = cycle_span @@ -111,7 +92,7 @@ async def event_loop_cycle( cycle_trace.add_child(stream_trace) # Clean up orphaned empty tool uses - clean_orphaned_empty_tool_uses(messages) + clean_orphaned_empty_tool_uses(agent.messages) # Process messages with exponential backoff for throttling message: Message @@ -122,17 +103,17 @@ async def event_loop_cycle( # Retry loop for handling throttling exceptions current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): - model_id = model.config.get("model_id") if hasattr(model, "config") else None + model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( - messages=messages, + messages=agent.messages, parent_span=cycle_span, model_id=model_id, ) try: - # TODO: To maintain backwards compatability, we need to combine the stream event with kwargs before yielding + # TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding # to the callback handler. This will be revisited when migrating to strongly typed events. - async for event in stream_messages(model, system_prompt, messages, tool_config): + async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, agent.tool_config): if "callback" in event: yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}} @@ -180,22 +161,16 @@ async def event_loop_cycle( stream_trace.end() # Add the response message to the conversation - messages.append(message) + agent.messages.append(message) yield {"callback": {"message": message}} # Update metrics - event_loop_metrics.update_usage(usage) - event_loop_metrics.update_metrics(metrics) + agent.event_loop_metrics.update_usage(usage) + agent.event_loop_metrics.update_metrics(metrics) # If the model is requesting to use tools if stop_reason == "tool_use": - if not tool_handler: - raise EventLoopException( - Exception("Model requested tool use but no tool handler provided"), - kwargs["request_state"], - ) - - if tool_config is None: + if agent.tool_config is None: raise EventLoopException( Exception("Model requested tool use but no tool config provided"), kwargs["request_state"], @@ -205,18 +180,11 @@ async def event_loop_cycle( events = _handle_tool_execution( stop_reason, message, - model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, - event_loop_metrics, - event_loop_parent_span, - cycle_trace, - cycle_span, - cycle_start_time, - kwargs, + agent=agent, + cycle_trace=cycle_trace, + cycle_span=cycle_span, + cycle_start_time=cycle_start_time, + kwargs=kwargs, ) async for event in events: yield event @@ -224,7 +192,7 @@ async def event_loop_cycle( return # End the cycle and return results - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) if cycle_span: tracer.end_event_loop_cycle_span( span=cycle_span, @@ -250,33 +218,16 @@ async def event_loop_cycle( logger.exception("cycle failed") raise EventLoopException(e, kwargs["request_state"]) from e - yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])} + yield {"stop": (stop_reason, message, agent.event_loop_metrics, kwargs["request_state"])} -async def recurse_event_loop( - model: Model, - system_prompt: Optional[str], - messages: Messages, - tool_config: Optional[ToolConfig], - tool_handler: Optional[ToolHandler], - thread_pool: Optional[ThreadPoolExecutor], - event_loop_metrics: EventLoopMetrics, - event_loop_parent_span: Optional[trace.Span], - kwargs: dict[str, Any], -) -> AsyncGenerator[dict[str, Any], None]: +async def recurse_event_loop(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. Args: - model: Provider for running model inference - system_prompt: System prompt instructions for the model - messages: Conversation history messages - tool_config: Configuration for available tools - tool_handler: Handler for tool execution - thread_pool: Optional thread pool for parallel tool execution. - event_loop_metrics: Metrics tracking object for the event loop. - event_loop_parent_span: Span for the parent of this event loop. + agent: Agent for which the recursive call is being made. kwargs: Arguments to pass through event_loop_cycle @@ -296,34 +247,77 @@ async def recurse_event_loop( yield {"callback": {"start": True}} - events = event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=event_loop_metrics, - event_loop_parent_span=event_loop_parent_span, - kwargs=kwargs, - ) + events = event_loop_cycle(agent=agent, kwargs=kwargs) async for event in events: yield event recursive_trace.end() +def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGenerator: + """Process a tool invocation. + + Looks up the tool in the registry and invokes it with the provided parameters. + + Args: + agent: The agent for which the tool is being executed. + tool: The tool object to process, containing name and parameters. + kwargs: Additional keyword arguments passed to the tool. + + Yields: + Events of the tool invocation. + + Returns: + The final tool result or an error response if the tool fails or is not found. + """ + logger.debug("tool=<%s> | invoking", tool) + tool_use_id = tool["toolUseId"] + tool_name = tool["name"] + + # Get the tool info + tool_info = agent.tool_registry.dynamic_tools.get(tool_name) + tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) + + try: + # Check if tool exists + if not tool_func: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } + # Add standard arguments to kwargs for Python tools + kwargs.update( + { + "model": agent.model, + "system_prompt": agent.system_prompt, + "messages": agent.messages, + "tool_config": agent.tool_config, + } + ) + + result = tool_func.invoke(tool, **kwargs) + yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from + return result + + except Exception as e: + logger.exception("tool_name=<%s> | failed to process tool", tool_name) + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } + + async def _handle_tool_execution( stop_reason: StopReason, message: Message, - model: Model, - system_prompt: Optional[str], - messages: Messages, - tool_config: ToolConfig, - tool_handler: ToolHandler, - thread_pool: Optional[ThreadPoolExecutor], - event_loop_metrics: EventLoopMetrics, - event_loop_parent_span: Optional[trace.Span], + agent: "Agent", cycle_trace: Trace, cycle_span: Any, cycle_start_time: float, @@ -339,12 +333,6 @@ async def _handle_tool_execution( Args: stop_reason: The reason the model stopped generating. message: The message from the model that may contain tool use requests. - model: The model provider instance. - system_prompt: The system prompt instructions for the model. - messages: The conversation history messages. - tool_config: Configuration for available tools. - tool_handler: Handler for tool execution. - thread_pool: Optional thread pool for parallel tool execution. event_loop_metrics: Metrics tracking object for the event loop. event_loop_parent_span: Span for the parent of this event loop. cycle_trace: Trace object for the current event loop cycle. @@ -363,27 +351,21 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) if not tool_uses: - yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])} + yield {"stop": (stop_reason, message, agent.event_loop_metrics, kwargs["request_state"])} return - tool_handler_process = partial( - tool_handler.process, - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - kwargs=kwargs, - ) + def tool_handler(tool_use: ToolUse) -> ToolGenerator: + return run_tool(agent=agent, kwargs=kwargs, tool=tool_use) tool_events = run_tools( - handler=tool_handler_process, + handler=tool_handler, tool_uses=tool_uses, - event_loop_metrics=event_loop_metrics, + event_loop_metrics=agent.event_loop_metrics, invalid_tool_use_ids=invalid_tool_use_ids, tool_results=tool_results, cycle_trace=cycle_trace, parent_span=cycle_span, - thread_pool=thread_pool, + thread_pool=agent.thread_pool, ) for tool_event in tool_events: yield tool_event @@ -396,7 +378,7 @@ async def _handle_tool_execution( "content": [{"toolResult": result} for result in tool_results], } - messages.append(tool_result_message) + agent.messages.append(tool_result_message) yield {"callback": {"message": tool_result_message}} if cycle_span: @@ -404,20 +386,10 @@ async def _handle_tool_execution( tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) if kwargs["request_state"].get("stop_event_loop", False): - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])} + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield {"stop": (stop_reason, message, agent.event_loop_metrics, kwargs["request_state"])} return - events = recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=event_loop_metrics, - event_loop_parent_span=event_loop_parent_span, - kwargs=kwargs, - ) + events = recurse_event_loop(agent=agent, kwargs=kwargs) async for event in events: yield event diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py deleted file mode 100644 index 4f93edf76..000000000 --- a/src/strands/handlers/tool_handler.py +++ /dev/null @@ -1,98 +0,0 @@ -"""This module provides handlers for managing tool invocations.""" - -import logging -from typing import Any, Optional - -from ..tools.registry import ToolRegistry -from ..types.content import Messages -from ..types.models import Model -from ..types.tools import ToolConfig, ToolGenerator, ToolHandler, ToolUse - -logger = logging.getLogger(__name__) - - -class AgentToolHandler(ToolHandler): - """Handler for processing tool invocations in agent. - - This class implements the ToolHandler interface and provides functionality for looking up tools in a registry and - invoking them with the appropriate parameters. - """ - - def __init__(self, tool_registry: ToolRegistry) -> None: - """Initialize handler. - - Args: - tool_registry: Registry of available tools. - """ - self.tool_registry = tool_registry - - def process( - self, - tool: ToolUse, - *, - model: Model, - system_prompt: Optional[str], - messages: Messages, - tool_config: ToolConfig, - kwargs: dict[str, Any], - ) -> ToolGenerator: - """Process a tool invocation. - - Looks up the tool in the registry and invokes it with the provided parameters. - - Args: - tool: The tool object to process, containing name and parameters. - model: The model being used for the agent. - system_prompt: The system prompt for the agent. - messages: The conversation history. - tool_config: Configuration for the tool. - kwargs: Additional keyword arguments passed to the tool. - - Yields: - Events of the tool invocation. - - Returns: - The final tool result or an error response if the tool fails or is not found. - """ - logger.debug("tool=<%s> | invoking", tool) - tool_use_id = tool["toolUseId"] - tool_name = tool["name"] - - # Get the tool info - tool_info = self.tool_registry.dynamic_tools.get(tool_name) - tool_func = tool_info if tool_info is not None else self.tool_registry.registry.get(tool_name) - - try: - # Check if tool exists - if not tool_func: - logger.error( - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", - tool_name, - list(self.tool_registry.registry.keys()), - ) - return { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], - } - # Add standard arguments to kwargs for Python tools - kwargs.update( - { - "model": model, - "system_prompt": system_prompt, - "messages": messages, - "tool_config": tool_config, - } - ) - - result = tool_func.invoke(tool, **kwargs) - yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from - return result - - except Exception as e: - logger.exception("tool_name=<%s> | failed to process tool", tool_name) - return { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 01c749498..2291e0ff4 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -5,7 +5,7 @@ import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Generator, Optional, cast +from typing import Any, Generator, Optional, cast from opentelemetry import trace @@ -13,13 +13,13 @@ from ..telemetry.tracer import get_tracer from ..tools.tools import InvalidToolUseNameException, validate_tool_use from ..types.content import Message -from ..types.tools import ToolGenerator, ToolResult, ToolUse +from ..types.tools import RunToolHandler, ToolGenerator, ToolResult, ToolUse logger = logging.getLogger(__name__) def run_tools( - handler: Callable[[ToolUse], Generator[dict[str, Any], None, ToolResult]], + handler: RunToolHandler, tool_uses: list[ToolUse], event_loop_metrics: EventLoopMetrics, invalid_tool_use_ids: list[str], diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 652024175..798cbc185 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -6,16 +6,12 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Union +from typing import Any, Callable, Generator, Literal, Union from typing_extensions import TypedDict from .media import DocumentContent, ImageContent -if TYPE_CHECKING: - from .content import Messages - from .models import Model - JSONSchema = dict """Type alias for JSON Schema dictionaries.""" @@ -134,6 +130,8 @@ class ToolChoiceTool(TypedDict): - "tool": The model must use the specified tool """ +RunToolHandler = Callable[[ToolUse], Generator[dict[str, Any], None, ToolResult]] +"""Callback that runs a single tool and streams back results.""" ToolGenerator = Generator[dict[str, Any], None, ToolResult] """Generator of tool events and a returned tool result.""" @@ -239,36 +237,3 @@ def get_display_properties(self) -> dict[str, str]: "Name": self.tool_name, "Type": self.tool_type, } - - -class ToolHandler(ABC): - """Abstract base class for handling tool execution within the agent framework.""" - - @abstractmethod - def process( - self, - tool: ToolUse, - *, - model: "Model", - system_prompt: Optional[str], - messages: "Messages", - tool_config: ToolConfig, - kwargs: dict[str, Any], - ) -> ToolGenerator: - """Process a tool use request and execute the tool. - - Args: - tool: The tool use request to process. - messages: The current conversation history. - model: The model being used for the conversation. - system_prompt: The system prompt for the conversation. - tool_config: The tool configuration for the current session. - kwargs: Additional context-specific arguments. - - Yields: - Events of the tool invocation. - - Returns: - The final tool result. - """ - ... diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 21df8014d..b49e294e2 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -58,6 +58,12 @@ def mock_event_loop_cycle(): yield mock +@pytest.fixture +def mock_run_tool(): + with unittest.mock.patch("strands.agent.agent.run_tool") as mock: + yield mock + + @pytest.fixture def tool_registry(): return strands.tools.registry.ToolRegistry() @@ -832,8 +838,8 @@ def test_agent_init_with_no_model_or_model_id(): assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID -def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint): - agent.tool_handler = unittest.mock.Mock(process=unittest.mock.Mock(return_value=iter([]))) +def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool): + mock_run_tool.return_value = iter([]) @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: @@ -845,22 +851,19 @@ def function(system_prompt: str) -> str: agent.tool.system_prompter(system_prompt="tool prompt") - agent.tool_handler.process.assert_called_with( + mock_run_tool.assert_called_with( + agent=agent, tool={ "toolUseId": "tooluse_system_prompter_1", "name": "system_prompter", "input": {"system_prompt": "tool prompt"}, }, - model=unittest.mock.ANY, - system_prompt="You are a helpful assistant.", - messages=unittest.mock.ANY, - tool_config=unittest.mock.ANY, kwargs={"system_prompt": "tool prompt"}, ) -def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint): - agent.tool_handler = unittest.mock.Mock(process=unittest.mock.Mock(return_value=iter([]))) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool): + mock_run_tool.return_value = iter([]) tool_name = "system-prompter" @@ -875,8 +878,8 @@ def function(system_prompt: str) -> str: agent.tool.system_prompter(system_prompt="tool prompt") # Verify the correct tool was invoked - assert agent.tool_handler.process.call_count == 1 - tool_call = agent.tool_handler.process.call_args.kwargs.get("tool") + assert mock_run_tool.call_count == 1 + tool_call = mock_run_tool.call_args.kwargs.get("tool") assert tool_call == { # Note that the tool-use uses the "python safe" name @@ -1267,31 +1270,6 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) -@unittest.mock.patch("strands.agent.agent.get_tracer") -def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model, agenerator): - """Test that event_loop_cycle is called with the parent span.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_agent_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Setup mock for event_loop_cycle - mock_event_loop_cycle.return_value = agenerator( - [{"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})}] - ) - - # Create agent and make a call - agent = Agent(model=mock_model) - agent("test prompt") - - # Verify event_loop_cycle was called with the span - mock_event_loop_cycle.assert_called_once() - kwargs = mock_event_loop_cycle.call_args[1] - assert "event_loop_parent_span" in kwargs - assert kwargs["event_loop_parent_span"] == mock_span - - def test_non_dict_throws_error(): with pytest.raises(ValueError, match="state must be an AgentState object or a dict"): agent = Agent(state={"object", object()}) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 291b7be30..1b37fc106 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,7 +6,7 @@ import strands import strands.telemetry -from strands.handlers.tool_handler import AgentToolHandler +from strands.event_loop.event_loop import run_tool from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException @@ -43,11 +43,6 @@ def tool_registry(): return ToolRegistry() -@pytest.fixture -def tool_handler(tool_registry): - return AgentToolHandler(tool_registry) - - @pytest.fixture def thread_pool(): return concurrent.futures.ThreadPoolExecutor(max_workers=1) @@ -84,9 +79,16 @@ def tool_stream(tool): @pytest.fixture -def agent(): - mock = unittest.mock.Mock() +def agent(model, system_prompt, messages, tool_config, tool_registry, thread_pool): + mock = unittest.mock.Mock(name="agent") mock.config.cache_points = [] + mock.model = model + mock.system_prompt = system_prompt + mock.messages = messages + mock.tool_config = tool_config + mock.tool_registry = tool_registry + mock.thread_pool = thread_pool + mock.event_loop_metrics = EventLoopMetrics() return mock @@ -101,12 +103,8 @@ def mock_tracer(): @pytest.mark.asyncio async def test_event_loop_cycle_text_response( + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, agenerator, alist, ): @@ -118,14 +116,7 @@ async def test_event_loop_cycle_text_response( ) stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) events = await alist(stream) @@ -141,12 +132,8 @@ async def test_event_loop_cycle_text_response( @pytest.mark.asyncio async def test_event_loop_cycle_text_response_throttling( mock_time, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, agenerator, alist, ): @@ -161,14 +148,7 @@ async def test_event_loop_cycle_text_response_throttling( ] stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) events = await alist(stream) @@ -186,12 +166,8 @@ async def test_event_loop_cycle_text_response_throttling( @pytest.mark.asyncio async def test_event_loop_cycle_exponential_backoff( mock_time, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, agenerator, alist, ): @@ -210,14 +186,7 @@ async def test_event_loop_cycle_exponential_backoff( ] stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) events = await alist(stream) @@ -237,12 +206,8 @@ async def test_event_loop_cycle_exponential_backoff( @pytest.mark.asyncio async def test_event_loop_cycle_text_response_throttling_exceeded( mock_time, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, alist, ): model.converse.side_effect = [ @@ -256,14 +221,7 @@ async def test_event_loop_cycle_text_response_throttling_exceeded( with pytest.raises(ModelThrottledException): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -281,26 +239,15 @@ async def test_event_loop_cycle_text_response_throttling_exceeded( @pytest.mark.asyncio async def test_event_loop_cycle_text_response_error( + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, alist, ): model.converse.side_effect = RuntimeError("Unhandled error") with pytest.raises(RuntimeError): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -308,12 +255,10 @@ async def test_event_loop_cycle_text_response_error( @pytest.mark.asyncio async def test_event_loop_cycle_tool_result( + agent, model, system_prompt, messages, - tool_config, - tool_handler, - thread_pool, tool_stream, agenerator, alist, @@ -329,14 +274,7 @@ async def test_event_loop_cycle_tool_result( ] stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) events = await alist(stream) @@ -384,12 +322,8 @@ async def test_event_loop_cycle_tool_result( @pytest.mark.asyncio async def test_event_loop_cycle_tool_result_error( + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, tool_stream, agenerator, alist, @@ -398,14 +332,7 @@ async def test_event_loop_cycle_tool_result_error( with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -413,27 +340,19 @@ async def test_event_loop_cycle_tool_result_error( @pytest.mark.asyncio async def test_event_loop_cycle_tool_result_no_tool_handler( + agent, model, - system_prompt, - messages, - tool_config, - thread_pool, tool_stream, agenerator, alist, ): model.converse.side_effect = [agenerator(tool_stream)] + # Set tool_handler to None for this test + agent.tool_handler = None with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=None, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -441,27 +360,19 @@ async def test_event_loop_cycle_tool_result_no_tool_handler( @pytest.mark.asyncio async def test_event_loop_cycle_tool_result_no_tool_config( + agent, model, - system_prompt, - messages, - tool_handler, - thread_pool, tool_stream, agenerator, alist, ): model.converse.side_effect = [agenerator(tool_stream)] + # Set tool_config to None for this test + agent.tool_config = None with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=None, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -469,12 +380,8 @@ async def test_event_loop_cycle_tool_result_no_tool_config( @pytest.mark.asyncio async def test_event_loop_cycle_stop( + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, tool, agenerator, alist, @@ -499,14 +406,7 @@ async def test_event_loop_cycle_stop( ] stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) @@ -532,12 +432,8 @@ async def test_event_loop_cycle_stop( @pytest.mark.asyncio async def test_cycle_exception( + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, tool_stream, agenerator, ): @@ -553,14 +449,7 @@ async def test_cycle_exception( with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) async for event in stream: @@ -573,12 +462,8 @@ async def test_cycle_exception( @pytest.mark.asyncio async def test_event_loop_cycle_creates_spans( mock_get_tracer, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, mock_tracer, agenerator, alist, @@ -599,14 +484,7 @@ async def test_event_loop_cycle_creates_spans( # Call event_loop_cycle stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -623,12 +501,8 @@ async def test_event_loop_cycle_creates_spans( @pytest.mark.asyncio async def test_event_loop_tracing_with_model_error( mock_get_tracer, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, mock_tracer, alist, ): @@ -645,14 +519,7 @@ async def test_event_loop_tracing_with_model_error( # Call event_loop_cycle, expecting it to handle the exception with pytest.raises(ContextWindowOverflowException): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -665,12 +532,8 @@ async def test_event_loop_tracing_with_model_error( @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( mock_get_tracer, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, tool_stream, mock_tracer, agenerator, @@ -696,14 +559,7 @@ async def test_event_loop_tracing_with_tool_execution( # Call event_loop_cycle which should execute a tool stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -718,12 +574,8 @@ async def test_event_loop_tracing_with_tool_execution( @pytest.mark.asyncio async def test_event_loop_tracing_with_throttling_exception( mock_get_tracer, + agent, model, - system_prompt, - messages, - tool_config, - tool_handler, - thread_pool, mock_tracer, agenerator, alist, @@ -749,14 +601,7 @@ async def test_event_loop_tracing_with_throttling_exception( # Mock the time.sleep function to speed up the test with patch("strands.event_loop.event_loop.time.sleep"): stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -772,12 +617,9 @@ async def test_event_loop_tracing_with_throttling_exception( @pytest.mark.asyncio async def test_event_loop_cycle_with_parent_span( mock_get_tracer, + agent, model, - system_prompt, messages, - tool_config, - tool_handler, - thread_pool, mock_tracer, agenerator, alist, @@ -795,16 +637,12 @@ async def test_event_loop_cycle_with_parent_span( ] ) + # Set the parent span for this test + agent.trace_span = parent_span + # Call event_loop_cycle with a parent span stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - tool_handler=tool_handler, - thread_pool=thread_pool, - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=parent_span, + agent=agent, kwargs={}, ) await alist(stream) @@ -817,16 +655,13 @@ async def test_event_loop_cycle_with_parent_span( @pytest.mark.asyncio async def test_request_state_initialization(alist): + # Create a mock agent + mock_agent = MagicMock() + mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) + # Call without providing request_state stream = strands.event_loop.event_loop.event_loop_cycle( - model=MagicMock(), - system_prompt=MagicMock(), - messages=MagicMock(), - tool_config=MagicMock(), - tool_handler=MagicMock(), - thread_pool=MagicMock(), - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=mock_agent, kwargs={}, ) events = await alist(stream) @@ -838,14 +673,7 @@ async def test_request_state_initialization(alist): # Call with pre-existing request_state initial_request_state = {"key": "value"} stream = strands.event_loop.event_loop.event_loop_cycle( - model=MagicMock(), - system_prompt=MagicMock(), - messages=MagicMock(), - tool_config=MagicMock(), - tool_handler=MagicMock(), - thread_pool=MagicMock(), - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=mock_agent, kwargs={"request_state": initial_request_state}, ) events = await alist(stream) @@ -856,7 +684,7 @@ async def test_request_state_initialization(alist): @pytest.mark.asyncio -async def test_prepare_next_cycle_in_tool_execution(model, tool_stream, agenerator, alist): +async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, agenerator, alist): """Test that cycle ID and metrics are properly updated during tool execution.""" model.converse.side_effect = [ agenerator(tool_stream), @@ -883,14 +711,7 @@ async def test_prepare_next_cycle_in_tool_execution(model, tool_stream, agenerat # Call event_loop_cycle which should execute a tool and then call recurse_event_loop stream = strands.event_loop.event_loop.event_loop_cycle( - model=model, - system_prompt=MagicMock(), - messages=MagicMock(), - tool_config=MagicMock(), - tool_handler=MagicMock(), - thread_pool=MagicMock(), - event_loop_metrics=EventLoopMetrics(), - event_loop_parent_span=None, + agent=agent, kwargs={}, ) await alist(stream) @@ -901,3 +722,33 @@ async def test_prepare_next_cycle_in_tool_execution(model, tool_stream, agenerat recursive_args = mock_recurse.call_args[1] assert "event_loop_parent_cycle_id" in recursive_args["kwargs"] assert recursive_args["kwargs"]["event_loop_parent_cycle_id"] == recursive_args["kwargs"]["event_loop_cycle_id"] + + +def test_run_tool(agent, tool, generate): + process = run_tool( + agent=agent, + tool={"toolUseId": "tool_use_id", "name": tool.tool_name, "input": {"random_string": "a_string"}}, + kwargs={}, + ) + + _, tru_result = generate(process) + exp_result = {"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]} + + assert tru_result == exp_result + + +def test_run_tool_missing_tool(agent, generate): + process = run_tool( + agent=agent, + tool={"toolUseId": "missing", "name": "missing", "input": {}}, + kwargs={}, + ) + + _, tru_result = generate(process) + exp_result = { + "toolUseId": "missing", + "status": "error", + "content": [{"text": "Unknown tool: missing"}], + } + + assert tru_result == exp_result diff --git a/tests/strands/handlers/test_tool_handler.py b/tests/strands/handlers/test_tool_handler.py deleted file mode 100644 index c4e5aae8f..000000000 --- a/tests/strands/handlers/test_tool_handler.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest.mock - -import pytest - -import strands - - -@pytest.fixture -def tool_registry(): - return strands.tools.registry.ToolRegistry() - - -@pytest.fixture -def tool_handler(tool_registry): - return strands.handlers.tool_handler.AgentToolHandler(tool_registry) - - -@pytest.fixture -def tool_use_identity(tool_registry): - @strands.tools.tool - def identity(a: int) -> int: - return a - - tool_registry.register_tool(identity) - - return {"toolUseId": "identity", "name": "identity", "input": {"a": 1}} - - -def test_process(tool_handler, tool_use_identity, generate): - process = tool_handler.process( - tool_use_identity, - model=unittest.mock.Mock(), - system_prompt="p1", - messages=[], - tool_config={}, - kwargs={}, - ) - - _, tru_result = generate(process) - exp_result = {"toolUseId": "identity", "status": "success", "content": [{"text": "1"}]} - - assert tru_result == exp_result - - -def test_process_missing_tool(tool_handler, generate): - process = tool_handler.process( - tool={"toolUseId": "missing", "name": "missing", "input": {}}, - model=unittest.mock.Mock(), - system_prompt="p1", - messages=[], - tool_config={}, - kwargs={}, - ) - - _, tru_result = generate(process) - exp_result = { - "toolUseId": "missing", - "status": "error", - "content": [{"text": "Unknown tool: missing"}], - } - - assert tru_result == exp_result