diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..866a0af3a --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,3 @@ +coverage: + ignore: + - "src/strands/experimental/tools/tool_provider.py" # This is an interface, cannot meaningfully cover diff --git a/src/strands/_async.py b/src/strands/_async.py new file mode 100644 index 000000000..976487c37 --- /dev/null +++ b/src/strands/_async.py @@ -0,0 +1,31 @@ +"""Private async execution utilities.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +def run_async(async_func: Callable[[], Awaitable[T]]) -> T: + """Run an async function in a separate thread to avoid event loop conflicts. + + This utility handles the common pattern of running async code from sync contexts + by using ThreadPoolExecutor to isolate the async execution. + + Args: + async_func: A callable that returns an awaitable + + Returns: + The result of the async function + """ + + async def execute_async() -> T: + return await async_func() + + def execute() -> T: + return asyncio.run(execute_async()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..e428ee0e5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,11 +9,9 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import asyncio import json import logging import random -from concurrent.futures import ThreadPoolExecutor from typing import ( Any, AsyncGenerator, @@ -31,7 +29,9 @@ from pydantic import BaseModel from .. import _identifier +from .._async import run_async from ..event_loop.event_loop import event_loop_cycle +from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -160,12 +160,7 @@ async def acall() -> ToolResult: return tool_results[0] - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() + tool_result = run_async(acall) if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -208,7 +203,7 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], ToolProvider, Any]]] = None, system_prompt: Optional[str] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] @@ -240,7 +235,8 @@ def __init__( - File paths (e.g., "/path/to/tool.py") - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) - - Functions decorated with `@strands.tool` decorator. + - Functions decorated with `@strands.tool` decorator + - ToolProvider instances for managed tool collections If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. @@ -333,6 +329,9 @@ def __init__( else: self.state = AgentState() + # Track cleanup state + self._cleanup_called = False + self.tool_caller = Agent.ToolCaller(self) self.hooks = HookRegistry() @@ -399,13 +398,7 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - - def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, **kwargs)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(prompt, **kwargs)) async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -459,13 +452,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> Raises: ValueError: If no conversation history or prompt is provided. """ - - def execute() -> T: - return asyncio.run(self.structured_output_async(output_model, prompt)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.structured_output_async(output_model, prompt)) async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. @@ -527,6 +514,60 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + def cleanup(self) -> None: + """Clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through __del__ as a fallback, but explicit cleanup is recommended. + """ + if self._cleanup_called: + return + + run_async(self.cleanup_async) + + async def cleanup_async(self) -> None: + """Asynchronously clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through __del__ as a fallback, but explicit cleanup is recommended. + """ + if self._cleanup_called: + return + + logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id) + + await self.tool_registry.cleanup_async() + + self._cleanup_called = True + logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id) + + def __del__(self) -> None: + """Automatic cleanup when agent is garbage collected. + + This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred. + """ + try: + if self._cleanup_called or not self.tool_registry.tool_providers: + return + + logger.warning( + "agent_id=<%s> | Agent cleanup called via __del__. " + "Consider calling agent.cleanup() explicitly for better resource management.", + self.agent_id, + ) + self.cleanup() + except Exception as e: + # Log exceptions during garbage collection cleanup for debugging + logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e) + async def stream_async( self, prompt: AgentInput = None, diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index c40d0fcec..5b9c95582 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -2,3 +2,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ + +from . import tools + +__all__ = ["tools"] diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..ad693f8ac --- /dev/null +++ b/src/strands/experimental/tools/__init__.py @@ -0,0 +1,5 @@ +"""Experimental tools package.""" + +from .tool_provider import ToolProvider + +__all__ = ["ToolProvider"] diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py new file mode 100644 index 000000000..401555368 --- /dev/null +++ b/src/strands/experimental/tools/tool_provider.py @@ -0,0 +1,49 @@ +"""Tool provider interface.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Sequence + +if TYPE_CHECKING: + from ...types.tools import AgentTool + + +class ToolProvider(ABC): + """Interface for providing tools with lifecycle management. + + Provides a way to load a collection of tools and clean them up + when done, with lifecycle managed by the agent. + """ + + @abstractmethod + async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: + """Load and return the tools in this provider. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of tools that are ready to use. + """ + pass + + @abstractmethod + async def add_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Args: + id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass + + @abstractmethod + async def remove_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + Args: + id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + + Provider may clean up resources when no consumers remain. + """ + pass diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 03d7de9b4..b146ec8ec 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,13 +3,12 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ -import asyncio from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union +from .._async import run_async from ..agent import AgentResult from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage @@ -111,9 +110,4 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state, **kwargs)) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 738dc4d4c..d48ec2f0d 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,12 +18,12 @@ import copy import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api +from .._async import run_async from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer @@ -399,12 +399,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 620fa5e24..1b49a3081 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,13 +17,14 @@ import json import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Tuple from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from .._async import run_async +from ..agent import Agent +from ..agent.agent_result import AgentResult from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool @@ -254,12 +255,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index d95c54fed..cfa841c46 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -7,7 +7,7 @@ """ from .mcp_agent_tool import MCPAgentTool -from .mcp_client import MCPClient +from .mcp_client import MCPClient, ToolFilters from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index acc48443c..af0c069a1 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,26 +28,29 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation + name_override: Optional name to use for the agent tool (for disambiguation) + If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client + self._agent_tool_name = name_override or mcp_tool.name @property def tool_name(self) -> str: """Get the name of the tool. Returns: - str: The name of the MCP tool + str: The agent-facing name of the tool (may be disambiguated) """ - return self.mcp_tool.name + return self._agent_tool_name @property def tool_spec(self) -> ToolSpec: @@ -63,7 +66,7 @@ def tool_spec(self) -> ToolSpec: spec: ToolSpec = { "inputSchema": {"json": self.mcp_tool.inputSchema}, - "name": self.mcp_tool.name, + "name": self.tool_name, # Use agent-facing name in spec "description": description, } @@ -100,7 +103,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await self.mcp_client.call_tool_async( tool_use_id=tool_use["toolUseId"], - name=self.tool_name, + name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], ) yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index dec8ec313..f10ee2d33 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,7 +16,7 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast import anyio from mcp import ClientSession, ListToolsResult @@ -24,11 +24,13 @@ from mcp.types import GetPromptResult, ListPromptsResult from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from typing_extensions import Protocol, TypedDict +from ...experimental.tools import ToolProvider from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError +from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat -from ...types.tools import ToolResultContent, ToolResultStatus +from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport @@ -37,6 +39,26 @@ T = TypeVar("T") + +class _ToolFilterCallback(Protocol): + def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... + + +_ToolFilterPattern = str | Pattern[str] | _ToolFilterCallback + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + """ + + allowed: list[_ToolFilterPattern] + rejected: list[_ToolFilterPattern] + + MIME_TO_FORMAT: Dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -52,7 +74,7 @@ ) -class MCPClient: +class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. This class implements a context manager pattern for efficient connection management, @@ -64,15 +86,26 @@ class MCPClient: from MCP tools, it will be returned as the last item in the content array of the ToolResult. """ - def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30): + def __init__( + self, + transport_callable: Callable[[], MCPTransport], + *, + startup_timeout: int = 30, + tool_filters: Optional[ToolFilters] = None, + prefix: Optional[str] = None, + ): """Initialize a new MCP Server connection. Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple startup_timeout: Timeout after which MCP server initialization should be cancelled Defaults to 30. + tool_filters: Optional filters to apply to tools. + prefix: Optional prefix for tool names. """ self._startup_timeout = startup_timeout + self._tool_filters = tool_filters + self._prefix = prefix mcp_instrumentation() self._session_id = uuid.uuid4() @@ -86,6 +119,9 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._background_thread: threading.Thread | None = None self._background_thread_session: ClientSession | None = None self._background_thread_event_loop: AbstractEventLoop | None = None + self._loaded_tools: list[MCPAgentTool] | None = None + self._tool_provider_started = False + self._consumers: set[Any] = set() def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -136,6 +172,92 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client initialization failed") from e return self + # ToolProvider interface methods (experimental, as ToolProvider is experimental) + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: + """Load and return tools from the MCP server. + + This method implements the ToolProvider interface by loading tools + from the MCP server and caching them for reuse. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of AgentTool instances from the MCP server. + """ + logger.debug( + "started=<%s>, cached_tools=<%s> | loading tools", + self._tool_provider_started, + self._loaded_tools is not None, + ) + + if not self._tool_provider_started: + try: + logger.debug("starting MCP client") + self.start() + self._tool_provider_started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._loaded_tools is None: + logger.debug("loading tools from MCP server") + self._loaded_tools = [] + pagination_token = None + page_count = 0 + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + paginated_tools = self.list_tools_sync(pagination_token) + + # Process each tool as we get it + for tool in paginated_tools: + # Apply filters + if self._should_include_tool(tool): + # Apply prefix if needed + processed_tool = self._apply_prefix(tool) + self._loaded_tools.append(processed_tool) + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._loaded_tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) + + return self._loaded_tools + + async def add_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider.""" + self._consumers.add(id) + logger.debug("added provider consumer, count=%d", len(self._consumers)) + + async def remove_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider.""" + self._consumers.discard(id) + logger.debug("removed provider consumer, count=%d", len(self._consumers)) + + if not self._consumers and self._tool_provider_started: + logger.debug("no consumers remaining, cleaning up") + try: + self.stop(None, None, None) + self._tool_provider_started = False + self._loaded_tools = None + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # MCP-specific methods + def stop( self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: @@ -186,6 +308,9 @@ async def _set_close_event() -> None: self._background_thread_session = None self._background_thread_event_loop = None self._session_id = uuid.uuid4() + self._loaded_tools = None + self._tool_provider_started = False + self._consumers = set() def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. @@ -478,5 +603,48 @@ def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures. raise MCPClientInitializationError("the client session was not initialized") return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on allowed/rejected filters.""" + if not self._tool_filters: + return True + + # Apply allowed filter + if "allowed" in self._tool_filters: + if not self._matches_patterns(tool, self._tool_filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in self._tool_filters: + if self._matches_patterns(tool, self._tool_filters["rejected"]): + return False + + return True + + def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply prefix to a single tool if needed.""" + if not self._prefix: + return tool + + # Create new tool with prefixed agent name but preserve original MCP name + old_name = tool.tool_name + new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, name_override=new_agent_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) + return new_tool + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif isinstance(pattern, Pattern): + if pattern.match(tool.tool_name): + return True + elif isinstance(pattern, str): + if pattern == tool.tool_name: + return True + return False + def _is_session_active(self) -> bool: return self._background_thread is not None and self._background_thread.is_alive() diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 3631c9dee..39fd508d0 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,16 +8,19 @@ import logging import os import sys +import uuid import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast from strands.tools.decorator import DecoratedFunctionTool +from .._async import run_async +from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec @@ -36,6 +39,8 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None + self.tool_providers: List[ToolProvider] = [] + self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: """Process tools list. @@ -118,6 +123,21 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): for t in tool: add_tool(t) + + # Case 5: ToolProvider + elif isinstance(tool, ToolProvider): + self.tool_providers.append(tool) + + async def get_tools_and_register_consumer() -> Sequence[AgentTool]: + provider_tools = await tool.load_tools() + await tool.add_consumer(self._registry_id) + return provider_tools + + provider_tools = run_async(get_tools_and_register_consumer) + + for provider_tool in provider_tools: + self.register_tool(provider_tool) + tool_names.append(provider_tool.tool_name) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) @@ -640,3 +660,16 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) return tools + + async def cleanup_async(self, **kwargs: Any) -> None: + """Clean up all tool providers in this registry.""" + for provider in self.tool_providers: + try: + await provider.remove_consumer(self._registry_id) + logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) + except Exception as e: + logger.warning( + "provider=<%s>, error=<%s> | failed to remove provider consumer", + type(provider).__name__, + e, + ) diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 90f2b8d7f..adff7add7 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -75,3 +75,9 @@ class SessionException(Exception): """Exception raised when session operations fail.""" pass + + +class ToolProviderException(Exception): + """Exception raised when a tool provider fails to load or cleanup tools.""" + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..83ed4d324 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -887,8 +887,199 @@ def test_agent_tool_names(tools, agent): assert actual == expected +def test_agent_cleanup(agent): + """Test that agent cleanup method works correctly.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: + agent.cleanup() + + # Verify run_async was called once (for cleanup_async) + mock_run_async.assert_called_once() + # Get the function that was passed to run_async and verify it's cleanup_async + called_func = mock_run_async.call_args[0][0] + assert called_func == agent.cleanup_async + + +@pytest.mark.asyncio +async def test_agent_cleanup_async(agent): + """Test that agent cleanup_async method works correctly.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + await agent.cleanup_async() + + # Verify provider remove_consumer was called + mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_handles_exceptions(agent): + """Test that agent cleanup_async handles exceptions gracefully.""" + # Create mock tool providers, one that raises an exception + mock_provider1 = unittest.mock.MagicMock() + mock_provider1.remove_consumer = unittest.mock.AsyncMock() + mock_provider2 = unittest.mock.MagicMock() + mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + + # Add providers to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] + + # Should not raise exception despite provider2 failing + await agent.cleanup_async() + + # Verify both providers were attempted + mock_provider1.remove_consumer.assert_called_once() + mock_provider2.remove_consumer.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_idempotent(agent): + """Test that calling cleanup_async multiple times is safe.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + # Call cleanup_async twice + await agent.cleanup_async() + await agent.cleanup_async() + + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_with_no_providers(agent): + """Test that agent cleanup_async works when there are no tool providers.""" + # Ensure no providers + agent.tool_registry.tool_providers = [] + + # Should not raise any exceptions + await agent.cleanup_async() + + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + def test_agent__del__(agent): - del agent + """Test that agent destructor calls cleanup.""" + # Add a mock tool provider so cleanup will be called + mock_provider = unittest.mock.MagicMock() + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + mock_cleanup.assert_called_once() + + +def test_agent__del__handles_cleanup_exception(agent): + """Test that agent destructor handles cleanup exceptions.""" + with unittest.mock.patch.object(agent, "cleanup", side_effect=Exception("Cleanup failed")): + # Should not raise exception + agent.__del__() + + +def test_agent_cleanup_idempotent(agent): + """Test that calling cleanup multiple times is safe.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + # Call cleanup twice + agent.cleanup() + agent.cleanup() + + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() + + +def test_agent_cleanup_early_return_avoids_thread_spawn(agent): + """Test that cleanup returns early when already called, avoiding thread spawn cost.""" + # Mark cleanup as already called + agent._cleanup_called = True + + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: + agent.cleanup() + + # Verify run_async was not called since cleanup already happened + mock_run_async.assert_not_called() + + +def test_agent__del__emits_warning_for_automatic_cleanup(): + """Test that __del__ emits warning when cleanup wasn't called manually.""" + # Create a fresh agent for this test to avoid fixture lifecycle issues + agent = Agent() + + # Add a mock tool provider so cleanup will be called + mock_provider = unittest.mock.MagicMock() + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0] + assert "Agent cleanup called via __del__" in warning_call[0] + # Verify cleanup was called + mock_cleanup.assert_called_once() + + +def test_agent__del__no_warning_after_manual_cleanup(): + """Test that __del__ doesn't emit warning if cleanup was called manually.""" + # Create a fresh agent for this test + from strands import Agent + + agent = Agent() + + # Call cleanup manually first + with unittest.mock.patch.object(agent, "cleanup_async"): + agent.cleanup() + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + agent.__del__() + + # Verify no warning was logged + mock_logger.warning.assert_not_called() + + +def test_agent__del__no_warning_when_no_tool_providers(): + """Test that __del__ doesn't emit warning when there are no tool providers.""" + # Create a fresh agent for this test + from strands import Agent + + agent = Agent() + + # Ensure no tool providers + agent.tool_registry.tool_providers = [] + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + + # Verify no warning was logged and cleanup wasn't called + mock_logger.warning.assert_not_called() + mock_cleanup.assert_not_called() def test_agent_init_with_no_model_or_model_id(): diff --git a/tests/strands/experimental/tools/__init__.py b/tests/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/test_async.py b/tests/strands/test_async.py new file mode 100644 index 000000000..2a98a953c --- /dev/null +++ b/tests/strands/test_async.py @@ -0,0 +1,25 @@ +"""Tests for _async module.""" + +import pytest + +from strands._async import run_async + + +def test_run_async_with_return_value(): + """Test run_async returns correct value.""" + + async def async_with_value(): + return 42 + + result = run_async(async_with_value) + assert result == 42 + + +def test_run_async_exception_propagation(): + """Test that exceptions are properly propagated.""" + + async def async_with_exception(): + raise ValueError("test exception") + + with pytest.raises(ValueError, match="test exception"): + run_async(async_with_exception) diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py new file mode 100644 index 000000000..094cc05b1 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -0,0 +1,320 @@ +"""Unit tests for MCPClient ToolProvider functionality.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.mcp.mcp_client import ToolFilters +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_transport(): + """Create a mock transport callable.""" + + def transport(): + read_stream = MagicMock() + write_stream = MagicMock() + return read_stream, write_stream + + return transport + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + return agent_tool + + +def create_mock_tool(name: str) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = name + tool.mcp_tool = MagicMock() + tool.mcp_tool.name = name + return tool + + +def test_init_with_tool_filters_and_prefix(mock_transport): + """Test initialization with tool filters and prefix.""" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" + + client = MCPClient(mock_transport, tool_filters=filters, prefix=prefix) + + assert client._tool_filters == filters + assert client._prefix == prefix + assert client._loaded_tools is None + assert client._tool_provider_started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_transport, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_called_once() + assert client._tool_provider_started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_transport, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_transport): + """Test that load_tools raises ToolProviderException when client start fails.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start: + mock_start.side_effect = Exception("Client start failed") + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await client.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_transport, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + # First call + tools1 = await client.load_tools() + # Second call + tools2 = await client.load_tools() + + # Client should only be called once + mock_list_tools.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_transport): + """Test that load_tools handles pagination correctly.""" + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_list_tools.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + tools = await client.load_tools() + + # Should have called list_tools_sync twice + assert mock_list_tools.call_count == 2 + # First call with no token, second call with "page2" token + mock_list_tools.assert_any_call(None) + mock_list_tools.assert_any_call("page2") + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_transport): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_transport): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + tool2 = create_mock_tool("other_tool") + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_transport): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + tool2 = create_mock_tool("very_long_tool_name") + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter_string_match(mock_transport): + """Test rejected filter with string matching.""" + tool1 = create_mock_tool("good_tool") + tool2 = create_mock_tool("bad_tool") + + filters: ToolFilters = {"rejected": ["bad_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_prefix_renames_tools(mock_transport): + """Test that prefix properly renames tools.""" + original_tool = create_mock_tool("original_name") + original_tool.mcp_client = MagicMock() + + client = MCPClient(mock_transport, prefix="prefix") + client._tool_provider_started = True + + with ( + patch.object(client, "list_tools_sync") as mock_list_tools, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_list_tools.return_value = PaginatedList([original_tool]) + + new_tool = MagicMock(spec=MCPAgentTool) + new_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = new_tool + + tools = await client.load_tools() + + # Should create new MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with( + original_tool.mcp_tool, original_tool.mcp_client, name_override="prefix_original_name" + ) + + assert len(tools) == 1 + assert tools[0] is new_tool + + +@pytest.mark.asyncio +async def test_add_consumer(mock_transport): + """Test adding a provider consumer.""" + client = MCPClient(mock_transport) + + await client.add_consumer("consumer1") + + assert "consumer1" in client._consumers + assert len(client._consumers) == 1 + + +@pytest.mark.asyncio +async def test_remove_consumer_without_cleanup(mock_transport): + """Test removing a provider consumer without triggering cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._consumers.add("consumer2") + client._tool_provider_started = True + + await client.remove_consumer("consumer1") + + assert "consumer1" not in client._consumers + assert "consumer2" in client._consumers + assert client._tool_provider_started is True # Should not cleanup yet + + +@pytest.mark.asyncio +async def test_remove_consumer_with_cleanup(mock_transport): + """Test removing the last provider consumer triggers cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + client._loaded_tools = [MagicMock()] + + with patch.object(client, "stop") as mock_stop: + await client.remove_consumer("consumer1") + + assert len(client._consumers) == 0 + assert client._tool_provider_started is False + assert client._loaded_tools is None + mock_stop.assert_called_once_with(None, None, None) + + +@pytest.mark.asyncio +async def test_remove_consumer_cleanup_failure(mock_transport): + """Test that remove_consumer raises ToolProviderException when cleanup fails.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + + with patch.object(client, "stop") as mock_stop: + mock_stop.side_effect = Exception("Cleanup failed") + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): + await client.remove_consumer("consumer1") diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py new file mode 100644 index 000000000..ca10862dc --- /dev/null +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -0,0 +1,296 @@ +"""Unit tests for ToolRegistry ToolProvider functionality.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools.registry import ToolRegistry +from strands.types.tools import AgentTool + + +class MockToolProvider(ToolProvider): + """Mock ToolProvider for testing.""" + + def __init__(self, tools=None, cleanup_error=None): + self._tools = tools or [] + self._cleanup_error = cleanup_error + self.cleanup_called = False + self.remove_consumer_called = False + self.remove_consumer_id = None + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + return self._tools + + async def cleanup(self): + self.cleanup_called = True + if self._cleanup_error: + raise self._cleanup_error + + async def add_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + async def remove_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + + +class TestToolRegistryToolProvider: + """Test ToolRegistry integration with ToolProvider.""" + + def test_process_tools_with_tool_provider(self): + """Test that process_tools handles ToolProvider correctly.""" + # Create mock tools + mock_tool1 = MagicMock(spec=AgentTool) + mock_tool1.tool_name = "provider_tool_1" + mock_tool2 = MagicMock(spec=AgentTool) + mock_tool2.tool_name = "provider_tool_2" + + # Create mock provider + provider = MockToolProvider([mock_tool1, mock_tool2]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Mock run_async to return the tools directly + mock_run_async.return_value = [mock_tool1, mock_tool2] + + tool_names = registry.process_tools([provider]) + + # Verify run_async was called with the provider's load_tools method + mock_run_async.assert_called_once() + + # Verify tools were registered + assert "provider_tool_1" in tool_names + assert "provider_tool_2" in tool_names + assert len(tool_names) == 2 + + # Verify provider was tracked + assert provider in registry.tool_providers + + # Verify tools are in registry + assert registry.registry["provider_tool_1"] is mock_tool1 + assert registry.registry["provider_tool_2"] is mock_tool2 + + def test_process_tools_with_multiple_providers(self): + """Test that process_tools handles multiple ToolProviders.""" + # Create mock tools for first provider + mock_tool1 = MagicMock(spec=AgentTool) + mock_tool1.tool_name = "provider1_tool" + provider1 = MockToolProvider([mock_tool1]) + + # Create mock tools for second provider + mock_tool2 = MagicMock(spec=AgentTool) + mock_tool2.tool_name = "provider2_tool" + provider2 = MockToolProvider([mock_tool2]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Mock run_async to return appropriate tools for each call + mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] + + tool_names = registry.process_tools([provider1, provider2]) + + # Verify run_async was called twice + assert mock_run_async.call_count == 2 + + # Verify all tools were registered + assert "provider1_tool" in tool_names + assert "provider2_tool" in tool_names + assert len(tool_names) == 2 + + # Verify both providers were tracked + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + assert len(registry.tool_providers) == 2 + + def test_process_tools_with_mixed_tools_and_providers(self): + """Test that process_tools handles mix of regular tools and providers.""" + # Create regular tool + regular_tool = MagicMock(spec=AgentTool) + regular_tool.tool_name = "regular_tool" + + # Create provider tool + provider_tool = MagicMock(spec=AgentTool) + provider_tool.tool_name = "provider_tool" + provider = MockToolProvider([provider_tool]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.return_value = [provider_tool] + + tool_names = registry.process_tools([regular_tool, provider]) + + # Verify both tools were registered + assert "regular_tool" in tool_names + assert "provider_tool" in tool_names + assert len(tool_names) == 2 + + # Verify only provider was tracked + assert provider in registry.tool_providers + assert len(registry.tool_providers) == 1 + + def test_process_tools_with_empty_provider(self): + """Test that process_tools handles provider with no tools.""" + provider = MockToolProvider([]) # Empty tools list + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.return_value = [] + + tool_names = registry.process_tools([provider]) + + # Verify no tools were registered + assert not tool_names + + # Verify provider was still tracked + assert provider in registry.tool_providers + + def test_tool_providers_public_access(self): + """Test that tool_providers can be accessed directly.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry.tool_providers = [provider1, provider2] + + # Verify direct access works + assert len(registry.tool_providers) == 2 + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + + def test_tool_providers_empty_by_default(self): + """Test that tool_providers is empty by default.""" + registry = ToolRegistry() + + assert not registry.tool_providers + assert isinstance(registry.tool_providers, list) + + def test_process_tools_provider_load_exception(self): + """Test that process_tools handles exceptions from provider.load_tools().""" + provider = MockToolProvider() + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Make load_tools raise an exception + mock_run_async.side_effect = Exception("Load tools failed") + + # Should raise the exception from load_tools + with pytest.raises(Exception, match="Load tools failed"): + registry.process_tools([provider]) + + # Provider should still be tracked even if load_tools failed + assert provider in registry.tool_providers + + def test_tool_provider_tracking_persistence(self): + """Test that tool providers are tracked across multiple process_tools calls.""" + provider1 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool1")]) + provider2 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool2")]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.side_effect = [ + [MagicMock(spec=AgentTool, tool_name="tool1")], + [MagicMock(spec=AgentTool, tool_name="tool2")], + ] + + # Process first provider + registry.process_tools([provider1]) + assert len(registry.tool_providers) == 1 + assert provider1 in registry.tool_providers + + # Process second provider + registry.process_tools([provider2]) + assert len(registry.tool_providers) == 2 + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + + def test_process_tools_provider_async_optimization(self): + """Test that load_tools and add_consumer are called in same async context.""" + mock_tool = MagicMock(spec=AgentTool) + mock_tool.tool_name = "test_tool" + + class TestProvider(ToolProvider): + def __init__(self): + self.load_tools_called = False + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + self.load_tools_called = True + return [mock_tool] + + async def add_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + async def remove_consumer(self, consumer_id): + pass + + provider = TestProvider() + registry = ToolRegistry() + + # Process the provider - this should call both methods in same async context + tool_names = registry.process_tools([provider]) + + # Verify both methods were called + assert provider.load_tools_called + assert provider.add_consumer_called + assert provider.add_consumer_id == registry._registry_id + + # Verify tool was registered + assert "test_tool" in tool_names + assert provider in registry.tool_providers + + @pytest.mark.asyncio + async def test_registry_cleanup(self): + """Test that registry cleanup calls remove_consumer on all providers.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry.tool_providers = [provider1, provider2] + + await registry.cleanup_async() + + # Verify both providers had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + @pytest.mark.asyncio + async def test_registry_cleanup_with_provider_consumer_removal(self): + """Test that cleanup removes provider consumers correctly.""" + + class TestProvider(ToolProvider): + def __init__(self): + self.remove_consumer_called = False + self.remove_consumer_id = None + + async def load_tools(self): + return [] + + async def add_consumer(self, consumer_id): + pass + + async def remove_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + + provider = TestProvider() + registry = ToolRegistry() + registry.tool_providers = [provider] + + # Call cleanup + await registry.cleanup_async() + + # Verify remove_consumer was called with correct ID + assert provider.remove_consumer_called + assert provider.remove_consumer_id == registry._registry_id diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..b45b38b86 --- /dev/null +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -0,0 +1,184 @@ +"""Integration tests for MCPClient ToolProvider functionality with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_client import ToolFilters + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_client_tool_provider_filters(): + """Test MCPClient with various filter combinations.""" + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + } + + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="test", + ) + + agent = Agent(tools=[client]) + tool_names = agent.tool_names + + assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_client_tool_provider_execution(): + """Test that MCPClient works with agent execution.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="filtered", + ) + + agent = Agent(tools=[client]) + + assert "filtered_echo" in agent.tool_names + + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_client_tool_provider_reuse(): + """Test that a single MCPClient can be used across multiple agents.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="shared", + ) + + agent1 = Agent(tools=[client]) + assert "shared_echo" in agent1.tool_names + + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + agent2 = Agent(tools=[client]) + assert "shared_echo" in agent2.tool_names + + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + agent2.cleanup() + + +def test_mcp_client_reference_counting(): + """Test that MCPClient uses reference counting - cleanup only happens when last consumer is removed.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="ref", + ) + + # Create two agents with the same client + agent1 = Agent(tools=[client]) + agent2 = Agent(tools=[client]) + + # Both should have the tool + assert "ref_echo" in agent1.tool_names + assert "ref_echo" in agent2.tool_names + + # Agent 1 uses the tool + result1 = agent1.tool.ref_echo(to_echo="Agent 1 Test") + assert "Agent 1 Test" in str(result1) + + # Agent 1 cleans up - client should still be active for agent 2 + agent1.cleanup() + + # Agent 2 should still be able to use the tool + result2 = agent2.tool.ref_echo(to_echo="Agent 2 Test") + assert "Agent 2 Test" in str(result2) + + # Agent 2 cleans up - now client should be fully cleaned up + agent2.cleanup() + + +def test_mcp_client_multiple_servers(): + """Test MCPClient with multiple MCP servers simultaneously.""" + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo"]}, + prefix="server1", + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo_with_structured_content"]}, + prefix="server2", + ) + + agent = Agent(tools=[client1, client2]) + + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_client_server_startup_failure(): + """Test that MCPClient handles server startup failure gracefully without hanging.""" + from strands.types.exceptions import ToolProviderException + + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[failing_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) + + +def test_mcp_client_server_connection_timeout(): + """Test that MCPClient times out gracefully when server hangs during startup.""" + from strands.types.exceptions import ToolProviderException + + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), + startup_timeout=1, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[hanging_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException)