From 9837e6cf43df038e1eb8d4290c73a225ba389d5e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 17 Sep 2025 21:18:02 -0400 Subject: [PATCH 1/7] feat(mcp): add experimental agent managed connection support --- src/strands/_async.py | 31 ++ src/strands/agent/agent.py | 98 +++-- src/strands/experimental/__init__.py | 4 + src/strands/experimental/tools/__init__.py | 5 + .../experimental/tools/mcp/__init__.py | 5 + .../tools/mcp/mcp_tool_provider.py | 180 +++++++++ .../experimental/tools/tool_provider.py | 32 ++ src/strands/multiagent/base.py | 10 +- src/strands/multiagent/graph.py | 9 +- src/strands/multiagent/swarm.py | 12 +- src/strands/tools/mcp/mcp_agent_tool.py | 13 +- src/strands/tools/registry.py | 13 + src/strands/types/exceptions.py | 6 + tests/strands/agent/test_agent.py | 178 ++++++++- tests/strands/experimental/tools/__init__.py | 0 .../experimental/tools/mcp/__init__.py | 0 .../tools/mcp/test_mcp_tool_provider.py | 375 ++++++++++++++++++ tests/strands/test_async.py | 25 ++ .../tools/test_registry_tool_provider.py | 202 ++++++++++ tests_integ/test_mcp_tool_provider.py | 171 ++++++++ 20 files changed, 1316 insertions(+), 53 deletions(-) create mode 100644 src/strands/_async.py create mode 100644 src/strands/experimental/tools/__init__.py create mode 100644 src/strands/experimental/tools/mcp/__init__.py create mode 100644 src/strands/experimental/tools/mcp/mcp_tool_provider.py create mode 100644 src/strands/experimental/tools/tool_provider.py create mode 100644 tests/strands/experimental/tools/__init__.py create mode 100644 tests/strands/experimental/tools/mcp/__init__.py create mode 100644 tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py create mode 100644 tests/strands/test_async.py create mode 100644 tests/strands/tools/test_registry_tool_provider.py create mode 100644 tests_integ/test_mcp_tool_provider.py diff --git a/src/strands/_async.py b/src/strands/_async.py new file mode 100644 index 000000000..976487c37 --- /dev/null +++ b/src/strands/_async.py @@ -0,0 +1,31 @@ +"""Private async execution utilities.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +def run_async(async_func: Callable[[], Awaitable[T]]) -> T: + """Run an async function in a separate thread to avoid event loop conflicts. + + This utility handles the common pattern of running async code from sync contexts + by using ThreadPoolExecutor to isolate the async execution. + + Args: + async_func: A callable that returns an awaitable + + Returns: + The result of the async function + """ + + async def execute_async() -> T: + return await async_func() + + def execute() -> T: + return asyncio.run(execute_async()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..59dbe8a46 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,11 +9,9 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import asyncio import json import logging import random -from concurrent.futures import ThreadPoolExecutor from typing import ( Any, AsyncGenerator, @@ -31,7 +29,9 @@ from pydantic import BaseModel from .. import _identifier +from .._async import run_async from ..event_loop.event_loop import event_loop_cycle +from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -160,12 +160,7 @@ async def acall() -> ToolResult: return tool_results[0] - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() + tool_result = run_async(acall) if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -208,7 +203,7 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], ToolProvider, Any]]] = None, system_prompt: Optional[str] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] @@ -240,7 +235,8 @@ def __init__( - File paths (e.g., "/path/to/tool.py") - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) - - Functions decorated with `@strands.tool` decorator. + - Functions decorated with `@strands.tool` decorator + - ToolProvider instances for managed tool collections If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. @@ -333,6 +329,9 @@ def __init__( else: self.state = AgentState() + # Track cleanup state + self._cleanup_called = False + self.tool_caller = Agent.ToolCaller(self) self.hooks = HookRegistry() @@ -399,13 +398,7 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - - def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, **kwargs)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(prompt, **kwargs)) async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -459,13 +452,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> Raises: ValueError: If no conversation history or prompt is provided. """ - - def execute() -> T: - return asyncio.run(self.structured_output_async(output_model, prompt)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.structured_output_async(output_model, prompt)) async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. @@ -527,6 +514,69 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + def cleanup(self) -> None: + """Clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through __del__ as a fallback, but explicit cleanup is recommended. + """ + run_async(self.cleanup_async) + + async def cleanup_async(self) -> None: + """Asynchronously clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through __del__ as a fallback, but explicit cleanup is recommended. + """ + if self._cleanup_called: + return + + logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id) + + for provider in self.tool_registry.tool_providers: + try: + await provider.cleanup() + logger.debug( + "agent_id=<%s>, provider=<%s> | cleaned up tool provider", self.agent_id, type(provider).__name__ + ) + except Exception as e: + logger.warning( + "agent_id=<%s>, provider=<%s>, error=<%s> | failed to cleanup tool provider", + self.agent_id, + type(provider).__name__, + e, + ) + + self._cleanup_called = True + logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id) + + def __del__(self) -> None: + """Automatic cleanup when agent is garbage collected. + + This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred. + """ + try: + if self._cleanup_called or not self.tool_registry.tool_providers: + return + + logger.warning( + "agent_id=<%s> | Agent cleanup called via __del__. " + "Consider calling agent.cleanup() explicitly for better resource management.", + self.agent_id, + ) + self.cleanup() + except Exception as e: + # Log exceptions during garbage collection cleanup for debugging + logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e) + async def stream_async( self, prompt: AgentInput = None, diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index c40d0fcec..5b9c95582 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -2,3 +2,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ + +from . import tools + +__all__ = ["tools"] diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..ad693f8ac --- /dev/null +++ b/src/strands/experimental/tools/__init__.py @@ -0,0 +1,5 @@ +"""Experimental tools package.""" + +from .tool_provider import ToolProvider + +__all__ = ["ToolProvider"] diff --git a/src/strands/experimental/tools/mcp/__init__.py b/src/strands/experimental/tools/mcp/__init__.py new file mode 100644 index 000000000..ee1ccc542 --- /dev/null +++ b/src/strands/experimental/tools/mcp/__init__.py @@ -0,0 +1,5 @@ +"""Experimental MCP Tool Provider.""" + +from .mcp_tool_provider import MCPToolProvider, ToolFilters + +__all__ = ["MCPToolProvider", "ToolFilters"] diff --git a/src/strands/experimental/tools/mcp/mcp_tool_provider.py b/src/strands/experimental/tools/mcp/mcp_tool_provider.py new file mode 100644 index 000000000..610d642bc --- /dev/null +++ b/src/strands/experimental/tools/mcp/mcp_tool_provider.py @@ -0,0 +1,180 @@ +"""MCP Tool Provider implementation.""" + +import logging +from typing import Callable, Optional, Pattern, Sequence, Union + +from typing_extensions import TypedDict + +from ....tools.mcp.mcp_agent_tool import MCPAgentTool +from ....tools.mcp.mcp_client import MCPClient +from ....types.exceptions import ToolProviderException +from ....types.tools import AgentTool +from ..tool_provider import ToolProvider + +logger = logging.getLogger(__name__) + +_ToolFilterCallback = Callable[[AgentTool], bool] +_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + 3. If the result exceeds 'max_tools', it's truncated + """ + + allowed: list[_ToolFilterPattern] + rejected: list[_ToolFilterPattern] + max_tools: int + + +class MCPToolProvider(ToolProvider): + """Tool provider for MCP clients with managed lifecycle.""" + + def __init__( + self, *, client: MCPClient, tool_filters: Optional[ToolFilters] = None, disambiguator: Optional[str] = None + ) -> None: + """Initialize with an MCP client. + + Args: + client: The MCP client to manage. + tool_filters: Optional filters to apply to tools. + disambiguator: Optional prefix for tool names. + """ + logger.debug( + "tool_filters=<%s>, disambiguator=<%s> | initializing MCPToolProvider", tool_filters, disambiguator + ) + self._client = client + self._tool_filters = tool_filters + self._disambiguator = disambiguator + self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty + self._started = False + + async def load_tools(self) -> Sequence[AgentTool]: + """Load and return tools from the MCP client. + + Returns: + List of tools from the MCP server. + """ + logger.debug("started=<%s>, cached_tools=<%s> | loading tools", self._started, self._tools is not None) + + if not self._started: + try: + logger.debug("starting MCP client") + self._client.start() + self._started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._tools is None: + logger.debug("loading tools from MCP server") + self._tools = [] + pagination_token = None + page_count = 0 + + # Determine max_tools limit for early termination + max_tools_limit = None + if self._tool_filters and "max_tools" in self._tool_filters: + max_tools_limit = self._tool_filters["max_tools"] + logger.debug("max_tools_limit=<%d> | will stop when reached", max_tools_limit) + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + paginated_tools = self._client.list_tools_sync(pagination_token) + + # Process each tool as we get it + for tool in paginated_tools: + # Apply filters + if self._should_include_tool(tool): + # Apply disambiguation if needed + processed_tool = self._apply_disambiguation(tool) + self._tools.append(processed_tool) + + # Check if we've reached max_tools limit + if max_tools_limit is not None and len(self._tools) >= max_tools_limit: + logger.debug("max_tools_reached=<%d> | stopping pagination early", len(self._tools)) + return self._tools + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._tools)) + + return self._tools + + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on allowed/rejected filters.""" + if not self._tool_filters: + return True + + # Apply allowed filter + if "allowed" in self._tool_filters: + if not self._matches_patterns(tool, self._tool_filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in self._tool_filters: + if self._matches_patterns(tool, self._tool_filters["rejected"]): + return False + + return True + + def _apply_disambiguation(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply disambiguation to a single tool if needed.""" + if not self._disambiguator: + return tool + + # Create new tool with disambiguated agent name but preserve original MCP name + old_name = tool.tool_name + new_agent_name = f"{self._disambiguator}_{tool.mcp_tool.name}" + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) + return new_tool + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): + if pattern.match(tool.tool_name): + return True + elif isinstance(pattern, str): + if pattern == tool.tool_name: + return True + return False + + async def cleanup(self) -> None: + """Clean up the MCP client connection.""" + if not self._started: + return + + logger.debug("cleaning up MCP client") + try: + logger.debug("stopping MCP client") + self._client.stop(None, None, None) + logger.debug("MCP client stopped successfully") + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # Only reset state if cleanup succeeded + self._started = False + self._tools = None + logger.debug("MCP client cleanup complete") diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py new file mode 100644 index 000000000..0e8d54dfc --- /dev/null +++ b/src/strands/experimental/tools/tool_provider.py @@ -0,0 +1,32 @@ +"""Tool provider interface.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Sequence + +if TYPE_CHECKING: + from ...types.tools import AgentTool + + +class ToolProvider(ABC): + """Interface for providing tools with lifecycle management. + + Provides a way to load a collection of tools and clean them up + when done, with lifecycle managed by the agent. + """ + + @abstractmethod + async def load_tools(self) -> Sequence["AgentTool"]: + """Load and return the tools in this provider. + + Returns: + List of tools that are ready to use. + """ + pass + + @abstractmethod + async def cleanup(self) -> None: + """Clean up resources used by the tools in this provider. + + Should be called when the tools are no longer needed. + """ + pass diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 03d7de9b4..b146ec8ec 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,13 +3,12 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ -import asyncio from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union +from .._async import run_async from ..agent import AgentResult from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage @@ -111,9 +110,4 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state, **kwargs)) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 738dc4d4c..d48ec2f0d 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,12 +18,12 @@ import copy import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api +from .._async import run_async from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer @@ -399,12 +399,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 620fa5e24..1b49a3081 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,13 +17,14 @@ import json import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Tuple from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from .._async import run_async +from ..agent import Agent +from ..agent.agent_result import AgentResult from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool @@ -254,12 +255,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index acc48443c..91ec6216a 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,26 +28,29 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", agent_facing_tool_name: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation + agent_facing_tool_name: Optional name to use for the agent tool (for disambiguation) + If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client + self._agent_tool_name = agent_facing_tool_name or mcp_tool.name @property def tool_name(self) -> str: """Get the name of the tool. Returns: - str: The name of the MCP tool + str: The agent-facing name of the tool (may be disambiguated) """ - return self.mcp_tool.name + return self._agent_tool_name @property def tool_spec(self) -> ToolSpec: @@ -63,7 +66,7 @@ def tool_spec(self) -> ToolSpec: spec: ToolSpec = { "inputSchema": {"json": self.mcp_tool.inputSchema}, - "name": self.mcp_tool.name, + "name": self.tool_name, # Use agent-facing name in spec "description": description, } @@ -100,7 +103,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await self.mcp_client.call_tool_async( tool_use_id=tool_use["toolUseId"], - name=self.tool_name, + name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], ) yield ToolResultEvent(result) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 3631c9dee..ea71b09a0 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -18,6 +18,8 @@ from strands.tools.decorator import DecoratedFunctionTool +from .._async import run_async +from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec @@ -36,6 +38,7 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None + self.tool_providers: List[ToolProvider] = [] def process_tools(self, tools: List[Any]) -> List[str]: """Process tools list. @@ -118,6 +121,16 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): for t in tool: add_tool(t) + + # Case 5: ToolProvider + elif isinstance(tool, ToolProvider): + self.tool_providers.append(tool) + + provider_tools = run_async(tool.load_tools) + + for provider_tool in provider_tools: + self.register_tool(provider_tool) + tool_names.append(provider_tool.tool_name) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 90f2b8d7f..adff7add7 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -75,3 +75,9 @@ class SessionException(Exception): """Exception raised when session operations fail.""" pass + + +class ToolProviderException(Exception): + """Exception raised when a tool provider fails to load or cleanup tools.""" + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..8e94520b7 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -887,8 +887,184 @@ def test_agent_tool_names(tools, agent): assert actual == expected +def test_agent_cleanup(agent): + """Test that agent cleanup method works correctly.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: + agent.cleanup() + + # Verify run_async was called once (for cleanup_async) + mock_run_async.assert_called_once() + # Get the function that was passed to run_async and verify it's cleanup_async + called_func = mock_run_async.call_args[0][0] + assert called_func == agent.cleanup_async + + +@pytest.mark.asyncio +async def test_agent_cleanup_async(agent): + """Test that agent cleanup_async method works correctly.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + await agent.cleanup_async() + + # Verify provider cleanup was called + mock_provider.cleanup.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_handles_exceptions(agent): + """Test that agent cleanup_async handles exceptions gracefully.""" + # Create mock tool providers, one that raises an exception + mock_provider1 = unittest.mock.MagicMock() + mock_provider1.cleanup = unittest.mock.AsyncMock() + mock_provider2 = unittest.mock.MagicMock() + mock_provider2.cleanup = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + + # Add providers to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] + + # Should not raise exception despite provider2 failing + await agent.cleanup_async() + + # Verify both providers were attempted + mock_provider1.cleanup.assert_called_once() + mock_provider2.cleanup.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_idempotent(agent): + """Test that calling cleanup_async multiple times is safe.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + # Call cleanup_async twice + await agent.cleanup_async() + await agent.cleanup_async() + + # Verify provider cleanup was only called once due to idempotency + mock_provider.cleanup.assert_called_once() + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_with_no_providers(agent): + """Test that agent cleanup_async works when there are no tool providers.""" + # Ensure no providers + agent.tool_registry.tool_providers = [] + + # Should not raise any exceptions + await agent.cleanup_async() + + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + def test_agent__del__(agent): - del agent + """Test that agent destructor calls cleanup.""" + # Add a mock tool provider so cleanup will be called + mock_provider = unittest.mock.MagicMock() + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + mock_cleanup.assert_called_once() + + +def test_agent__del__handles_cleanup_exception(agent): + """Test that agent destructor handles cleanup exceptions.""" + with unittest.mock.patch.object(agent, "cleanup", side_effect=Exception("Cleanup failed")): + # Should not raise exception + agent.__del__() + + +def test_agent_cleanup_idempotent(agent): + """Test that calling cleanup multiple times is safe.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + # Call cleanup twice + agent.cleanup() + agent.cleanup() + + # Verify provider cleanup was only called once due to idempotency + mock_provider.cleanup.assert_called_once() + + +def test_agent__del__emits_warning_for_automatic_cleanup(agent): + """Test that __del__ emits warning when cleanup wasn't called manually.""" + # Add a mock tool provider so cleanup will be called + mock_provider = unittest.mock.MagicMock() + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0] + assert "Agent cleanup called via __del__" in warning_call[0] + # Verify cleanup was called + mock_cleanup.assert_called_once() + + +def test_agent__del__no_warning_after_manual_cleanup(): + """Test that __del__ doesn't emit warning if cleanup was called manually.""" + # Create a fresh agent for this test + from strands import Agent + + agent = Agent() + + # Call cleanup manually first + with unittest.mock.patch.object(agent, "cleanup_async"): + agent.cleanup() + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + agent.__del__() + + # Verify no warning was logged + mock_logger.warning.assert_not_called() + + +def test_agent__del__no_warning_when_no_tool_providers(): + """Test that __del__ doesn't emit warning when there are no tool providers.""" + # Create a fresh agent for this test + from strands import Agent + + agent = Agent() + + # Ensure no tool providers + agent.tool_registry.tool_providers = [] + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + + # Verify no warning was logged and cleanup wasn't called + mock_logger.warning.assert_not_called() + mock_cleanup.assert_not_called() def test_agent_init_with_no_model_or_model_id(): diff --git a/tests/strands/experimental/tools/__init__.py b/tests/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/tools/mcp/__init__.py b/tests/strands/experimental/tools/mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..ae28d51c8 --- /dev/null +++ b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py @@ -0,0 +1,375 @@ +"""Unit tests for MCPToolProvider.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_mcp_client(): + """Create a mock MCP client.""" + client = MagicMock(spec=MCPClient) + client.start = MagicMock() + client.stop = MagicMock() + client.list_tools_sync = MagicMock() + return client + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool, mock_mcp_client): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + agent_tool.mcp_client = mock_mcp_client + return agent_tool + + +def create_mock_tool(name: str) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = name + tool.mcp_tool = MagicMock() + tool.mcp_tool.name = name + return tool + + +def test_init_with_client_only(mock_mcp_client): + """Test initialization with only client.""" + provider = MCPToolProvider(client=mock_mcp_client) + + assert provider._client is mock_mcp_client + assert provider._tool_filters is None + assert provider._disambiguator is None + assert provider._tools is None + assert provider._started is False + + +def test_init_with_all_parameters(mock_mcp_client): + """Test initialization with all parameters.""" + filters = {"allowed": ["tool1"], "max_tools": 5} + disambiguator = "test_prefix" + + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, disambiguator=disambiguator) + + assert provider._client is mock_mcp_client + assert provider._tool_filters == filters + assert provider._disambiguator == disambiguator + assert provider._tools is None + assert provider._started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_mcp_client, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client) + + tools = await provider.load_tools() + + mock_mcp_client.start.assert_called_once() + assert provider._started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_mcp_client, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + + tools = await provider.load_tools() + + mock_mcp_client.start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_mcp_client): + """Test that load_tools raises ToolProviderException when client start fails.""" + mock_mcp_client.start.side_effect = Exception("Client start failed") + + provider = MCPToolProvider(client=mock_mcp_client) + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await provider.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_mcp_client, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client) + + # First call + tools1 = await provider.load_tools() + # Second call + tools2 = await provider.load_tools() + + # Client should only be called once + mock_mcp_client.list_tools_sync.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_mcp_client, mock_agent_tool): + """Test that load_tools handles pagination correctly.""" + tool1 = MagicMock(spec=MCPAgentTool) + tool1.tool_name = "tool1" + tool2 = MagicMock(spec=MCPAgentTool) + tool2.tool_name = "tool2" + + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_mcp_client.list_tools_sync.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + provider = MCPToolProvider(client=mock_mcp_client) + + tools = await provider.load_tools() + + # Should have called list_tools_sync twice + assert mock_mcp_client.list_tools_sync.call_count == 2 + # First call with no token, second call with "page2" token + mock_mcp_client.list_tools_sync.assert_any_call(None) + mock_mcp_client.list_tools_sync.assert_any_call("page2") + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_mcp_client): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_mcp_client): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + tool2 = create_mock_tool("other_tool") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_mcp_client): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + tool2 = create_mock_tool("very_long_tool_name") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter(mock_mcp_client): + """Test rejected filter functionality.""" + tool1 = create_mock_tool("good_tool") + tool2 = create_mock_tool("bad_tool") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + filters: ToolFilters = {"rejected": ["bad_tool"]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_max_tools_filter(mock_mcp_client): + """Test max_tools filter functionality.""" + tools_list = [create_mock_tool(f"tool_{i}") for i in range(5)] + + mock_mcp_client.list_tools_sync.return_value = PaginatedList(tools_list) + + filters: ToolFilters = {"max_tools": 3} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 3 + assert all(tool.tool_name.startswith("tool_") for tool in tools) + + +@pytest.mark.asyncio +async def test_disambiguator_renames_tools(mock_mcp_client): + """Test that disambiguator properly renames tools.""" + original_tool = MagicMock(spec=MCPAgentTool) + original_tool.tool_name = "original_name" + original_tool.mcp_tool = MagicMock() + original_tool.mcp_tool.name = "original_name" + original_tool.mcp_client = mock_mcp_client + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([original_tool]) + + with patch("strands.experimental.tools.mcp.mcp_tool_provider.MCPAgentTool") as mock_agent_tool_class: + new_tool = MagicMock(spec=MCPAgentTool) + new_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = new_tool + + provider = MCPToolProvider(client=mock_mcp_client, disambiguator="prefix") + + tools = await provider.load_tools() + + # Should create new MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with( + original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" + ) + + assert len(tools) == 1 + assert tools[0] is new_tool + + +@pytest.mark.asyncio +async def test_cleanup_stops_client_when_started(mock_mcp_client): + """Test that cleanup stops the client when started.""" + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + provider._tools = [MagicMock()] + + await provider.cleanup() + + mock_mcp_client.stop.assert_called_once_with(None, None, None) + assert provider._started is False + assert provider._tools is None + + +@pytest.mark.asyncio +async def test_cleanup_does_nothing_when_not_started(mock_mcp_client): + """Test that cleanup does nothing when not started.""" + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = False + + await provider.cleanup() + + mock_mcp_client.stop.assert_not_called() + assert provider._started is False + + +@pytest.mark.asyncio +async def test_cleanup_raises_exception_on_client_stop_failure(mock_mcp_client): + """Test that cleanup raises ToolProviderException when client stop fails.""" + mock_mcp_client.stop.side_effect = Exception("Client stop failed") + + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Client stop failed"): + await provider.cleanup() + + # State is not reset when cleanup fails + assert provider._started is True + assert provider._tools is None + + +@pytest.mark.asyncio +async def test_cleanup_does_not_reset_state_on_exception(mock_mcp_client): + """Test that cleanup does not reset state when exception occurs.""" + mock_mcp_client.stop.side_effect = Exception("Client stop failed") + + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + mock_tool = MagicMock() + provider._tools = [mock_tool] + + with pytest.raises(ToolProviderException): + await provider.cleanup() + + # State should not be reset when exception occurs + assert provider._started is True + assert provider._tools == [mock_tool] + + +@pytest.mark.asyncio +async def test_load_tools_with_empty_tool_list(mock_mcp_client): + """Test load_tools with empty tool list from server.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([]) + + provider = MCPToolProvider(client=mock_mcp_client) + + tools = await provider.load_tools() + + assert len(tools) == 0 + assert provider._started is True + + +@pytest.mark.asyncio +async def test_load_tools_with_no_filters(mock_mcp_client, mock_agent_tool): + """Test load_tools with no filters applied.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=None) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_with_empty_filters(mock_mcp_client, mock_agent_tool): + """Test load_tools with empty filters dict.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client, tool_filters={}) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0] is mock_agent_tool diff --git a/tests/strands/test_async.py b/tests/strands/test_async.py new file mode 100644 index 000000000..2a98a953c --- /dev/null +++ b/tests/strands/test_async.py @@ -0,0 +1,25 @@ +"""Tests for _async module.""" + +import pytest + +from strands._async import run_async + + +def test_run_async_with_return_value(): + """Test run_async returns correct value.""" + + async def async_with_value(): + return 42 + + result = run_async(async_with_value) + assert result == 42 + + +def test_run_async_exception_propagation(): + """Test that exceptions are properly propagated.""" + + async def async_with_exception(): + raise ValueError("test exception") + + with pytest.raises(ValueError, match="test exception"): + run_async(async_with_exception) diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py new file mode 100644 index 000000000..f9f9c9ce0 --- /dev/null +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -0,0 +1,202 @@ +"""Unit tests for ToolRegistry ToolProvider functionality.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools.registry import ToolRegistry +from strands.types.tools import AgentTool + + +class MockToolProvider(ToolProvider): + """Mock ToolProvider for testing.""" + + def __init__(self, tools=None, cleanup_error=None): + self._tools = tools or [] + self._cleanup_error = cleanup_error + self.cleanup_called = False + + async def load_tools(self): + return self._tools + + async def cleanup(self): + self.cleanup_called = True + if self._cleanup_error: + raise self._cleanup_error + + +class TestToolRegistryToolProvider: + """Test ToolRegistry integration with ToolProvider.""" + + def test_process_tools_with_tool_provider(self): + """Test that process_tools handles ToolProvider correctly.""" + # Create mock tools + mock_tool1 = MagicMock(spec=AgentTool) + mock_tool1.tool_name = "provider_tool_1" + mock_tool2 = MagicMock(spec=AgentTool) + mock_tool2.tool_name = "provider_tool_2" + + # Create mock provider + provider = MockToolProvider([mock_tool1, mock_tool2]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Mock run_async to return the tools directly + mock_run_async.return_value = [mock_tool1, mock_tool2] + + tool_names = registry.process_tools([provider]) + + # Verify run_async was called with the provider's load_tools method + mock_run_async.assert_called_once() + + # Verify tools were registered + assert "provider_tool_1" in tool_names + assert "provider_tool_2" in tool_names + assert len(tool_names) == 2 + + # Verify provider was tracked + assert provider in registry.tool_providers + + # Verify tools are in registry + assert registry.registry["provider_tool_1"] is mock_tool1 + assert registry.registry["provider_tool_2"] is mock_tool2 + + def test_process_tools_with_multiple_providers(self): + """Test that process_tools handles multiple ToolProviders.""" + # Create mock tools for first provider + mock_tool1 = MagicMock(spec=AgentTool) + mock_tool1.tool_name = "provider1_tool" + provider1 = MockToolProvider([mock_tool1]) + + # Create mock tools for second provider + mock_tool2 = MagicMock(spec=AgentTool) + mock_tool2.tool_name = "provider2_tool" + provider2 = MockToolProvider([mock_tool2]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Mock run_async to return appropriate tools for each call + mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] + + tool_names = registry.process_tools([provider1, provider2]) + + # Verify run_async was called twice + assert mock_run_async.call_count == 2 + + # Verify all tools were registered + assert "provider1_tool" in tool_names + assert "provider2_tool" in tool_names + assert len(tool_names) == 2 + + # Verify both providers were tracked + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + assert len(registry.tool_providers) == 2 + + def test_process_tools_with_mixed_tools_and_providers(self): + """Test that process_tools handles mix of regular tools and providers.""" + # Create regular tool + regular_tool = MagicMock(spec=AgentTool) + regular_tool.tool_name = "regular_tool" + + # Create provider tool + provider_tool = MagicMock(spec=AgentTool) + provider_tool.tool_name = "provider_tool" + provider = MockToolProvider([provider_tool]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.return_value = [provider_tool] + + tool_names = registry.process_tools([regular_tool, provider]) + + # Verify both tools were registered + assert "regular_tool" in tool_names + assert "provider_tool" in tool_names + assert len(tool_names) == 2 + + # Verify only provider was tracked + assert provider in registry.tool_providers + assert len(registry.tool_providers) == 1 + + def test_process_tools_with_empty_provider(self): + """Test that process_tools handles provider with no tools.""" + provider = MockToolProvider([]) # Empty tools list + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.return_value = [] + + tool_names = registry.process_tools([provider]) + + # Verify no tools were registered + assert not tool_names + + # Verify provider was still tracked + assert provider in registry.tool_providers + + def test_tool_providers_public_access(self): + """Test that tool_providers can be accessed directly.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry.tool_providers = [provider1, provider2] + + # Verify direct access works + assert len(registry.tool_providers) == 2 + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + + def test_tool_providers_empty_by_default(self): + """Test that tool_providers is empty by default.""" + registry = ToolRegistry() + + assert not registry.tool_providers + assert isinstance(registry.tool_providers, list) + + def test_process_tools_provider_load_exception(self): + """Test that process_tools handles exceptions from provider.load_tools().""" + provider = MockToolProvider() + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Make load_tools raise an exception + mock_run_async.side_effect = Exception("Load tools failed") + + # Should raise the exception from load_tools + with pytest.raises(Exception, match="Load tools failed"): + registry.process_tools([provider]) + + # Provider should still be tracked even if load_tools failed + assert provider in registry.tool_providers + + def test_tool_provider_tracking_persistence(self): + """Test that tool providers are tracked across multiple process_tools calls.""" + provider1 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool1")]) + provider2 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool2")]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.side_effect = [ + [MagicMock(spec=AgentTool, tool_name="tool1")], + [MagicMock(spec=AgentTool, tool_name="tool2")], + ] + + # Process first provider + registry.process_tools([provider1]) + assert len(registry.tool_providers) == 1 + assert provider1 in registry.tool_providers + + # Process second provider + registry.process_tools([provider2]) + assert len(registry.tool_providers) == 2 + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers diff --git a/tests_integ/test_mcp_tool_provider.py b/tests_integ/test_mcp_tool_provider.py new file mode 100644 index 000000000..4d6a39329 --- /dev/null +++ b/tests_integ/test_mcp_tool_provider.py @@ -0,0 +1,171 @@ +"""Integration tests for MCPToolProvider with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters +from strands.tools.mcp import MCPClient +from strands.types.exceptions import ToolProviderException + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_tool_provider_filters(): + """Test MCPToolProvider with various filter combinations.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + # Test string filter, regex filter, callable filter, max_tools, and disambiguator + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 # Allow most tools + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + "max_tools": 2, + } + + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="test") + agent = Agent(tools=[provider]) + tool_names = agent.tool_names + + # Should have 2 tools max, with test_ prefix, no delay tool + assert len(tool_names) == 2 + assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_tool_provider_execution(): + """Test that MCPToolProvider works with agent execution.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + filters: ToolFilters = {"allowed": ["echo"]} + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="filtered") + agent = Agent( + tools=[provider], + ) + + # Verify the filtered tool exists + assert "filtered_echo" in agent.tool_names + + # # Test direct tool call to verify it works (use correct parameter name from echo server) + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + # # Test agent execution using the tool + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_tool_provider_reuse(): + """Test that a single MCPToolProvider can be used across multiple agents.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + filters: ToolFilters = {"allowed": ["echo"]} + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="shared") + + # Create first agent with the provider + agent1 = Agent(tools=[provider]) + assert "shared_echo" in agent1.tool_names + + # Test first agent (use correct parameter name from echo server) + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + # Create second agent with the same provider + agent2 = Agent(tools=[provider]) + assert "shared_echo" in agent2.tool_names + + # Test second agent (use correct parameter name from echo server) + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + # Both agents should have the same tool count + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + agent2.cleanup() + + +def test_mcp_tool_provider_multiple_servers(): + """Test MCPToolProvider with multiple MCP servers simultaneously.""" + # Create two separate MCP clients + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + # Create providers with different disambiguators + provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, disambiguator="server1") + # Use correct tool name from echo_server.py + provider2 = MCPToolProvider( + client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, disambiguator="server2" + ) + + # Create agent with both providers + agent = Agent(tools=[provider1, provider2]) + + # Should have tools from both servers with different prefixes + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + # Test tools from both servers work + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_tool_provider_server_startup_failure(): + """Test that MCPToolProvider handles server startup failure gracefully without hanging.""" + # Create client with invalid command that will fail to start + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, # Short timeout to avoid hanging + ) + + provider = MCPToolProvider(client=failing_client) + + # Should raise ToolProviderException when trying to load tools + with pytest.raises(ToolProviderException, match="Failed to start MCP client"): + Agent(tools=[provider]) + + +def test_mcp_tool_provider_server_connection_timeout(): + """Test that MCPToolProvider times out gracefully when server hangs during startup.""" + # Create client that will hang during connection + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), # Sleep for 10 seconds + startup_timeout=1, # 1 second timeout + ) + + provider = MCPToolProvider(client=hanging_client) + + # Should raise ToolProviderException due to timeout + with pytest.raises(ToolProviderException, match="Failed to start MCP client"): + Agent(tools=[provider]) From 5d5a800b28e0e37c5264c6919e3884a842f21a76 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 26 Sep 2025 14:48:58 -0300 Subject: [PATCH 2/7] remove max_tools, add kwargs --- src/strands/agent/agent.py | 3 ++ .../tools/mcp/mcp_tool_provider.py | 49 ++++++++----------- .../experimental/tools/tool_provider.py | 12 +++-- tests/strands/agent/test_agent.py | 12 +++++ .../tools/mcp/test_mcp_tool_provider.py | 32 +++--------- tests_integ/test_mcp_tool_provider.py | 14 +++--- 6 files changed, 59 insertions(+), 63 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 59dbe8a46..0255d70db 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -524,6 +524,9 @@ def cleanup(self) -> None: Note: This method uses a "belt and braces" approach with automatic cleanup through __del__ as a fallback, but explicit cleanup is recommended. """ + if self._cleanup_called: + return + run_async(self.cleanup_async) async def cleanup_async(self) -> None: diff --git a/src/strands/experimental/tools/mcp/mcp_tool_provider.py b/src/strands/experimental/tools/mcp/mcp_tool_provider.py index 610d642bc..44e9fb61f 100644 --- a/src/strands/experimental/tools/mcp/mcp_tool_provider.py +++ b/src/strands/experimental/tools/mcp/mcp_tool_provider.py @@ -1,7 +1,7 @@ """MCP Tool Provider implementation.""" import logging -from typing import Callable, Optional, Pattern, Sequence, Union +from typing import Any, Callable, Optional, Pattern, Sequence, Union from typing_extensions import TypedDict @@ -23,37 +23,39 @@ class ToolFilters(TypedDict, total=False): Tools are filtered in this order: 1. If 'allowed' is specified, only tools matching these patterns are included 2. Tools matching 'rejected' patterns are then excluded - 3. If the result exceeds 'max_tools', it's truncated """ allowed: list[_ToolFilterPattern] rejected: list[_ToolFilterPattern] - max_tools: int class MCPToolProvider(ToolProvider): """Tool provider for MCP clients with managed lifecycle.""" def __init__( - self, *, client: MCPClient, tool_filters: Optional[ToolFilters] = None, disambiguator: Optional[str] = None + self, + *, + client: MCPClient, + tool_filters: Optional[ToolFilters] = None, + prefix: Optional[str] = None, + **kwargs: Any, ) -> None: """Initialize with an MCP client. Args: client: The MCP client to manage. tool_filters: Optional filters to apply to tools. - disambiguator: Optional prefix for tool names. + prefix: Optional prefix for tool names. + **kwargs: Additional arguments for future compatibility. """ - logger.debug( - "tool_filters=<%s>, disambiguator=<%s> | initializing MCPToolProvider", tool_filters, disambiguator - ) + logger.debug("tool_filters=<%s>, prefix=<%s> | initializing MCPToolProvider", tool_filters, prefix) self._client = client self._tool_filters = tool_filters - self._disambiguator = disambiguator + self._prefix = prefix self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty self._started = False - async def load_tools(self) -> Sequence[AgentTool]: + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: """Load and return tools from the MCP client. Returns: @@ -77,12 +79,6 @@ async def load_tools(self) -> Sequence[AgentTool]: pagination_token = None page_count = 0 - # Determine max_tools limit for early termination - max_tools_limit = None - if self._tool_filters and "max_tools" in self._tool_filters: - max_tools_limit = self._tool_filters["max_tools"] - logger.debug("max_tools_limit=<%d> | will stop when reached", max_tools_limit) - while True: logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) paginated_tools = self._client.list_tools_sync(pagination_token) @@ -91,15 +87,10 @@ async def load_tools(self) -> Sequence[AgentTool]: for tool in paginated_tools: # Apply filters if self._should_include_tool(tool): - # Apply disambiguation if needed - processed_tool = self._apply_disambiguation(tool) + # Apply prefix if needed + processed_tool = self._apply_prefix(tool) self._tools.append(processed_tool) - # Check if we've reached max_tools limit - if max_tools_limit is not None and len(self._tools) >= max_tools_limit: - logger.debug("max_tools_reached=<%d> | stopping pagination early", len(self._tools)) - return self._tools - logger.debug( "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", page_count, @@ -134,14 +125,14 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: return True - def _apply_disambiguation(self, tool: MCPAgentTool) -> MCPAgentTool: - """Apply disambiguation to a single tool if needed.""" - if not self._disambiguator: + def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply prefix to a single tool if needed.""" + if not self._prefix: return tool - # Create new tool with disambiguated agent name but preserve original MCP name + # Create new tool with prefixed agent name but preserve original MCP name old_name = tool.tool_name - new_agent_name = f"{self._disambiguator}_{tool.mcp_tool.name}" + new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) return new_tool @@ -160,7 +151,7 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPatter return True return False - async def cleanup(self) -> None: + async def cleanup(self, **kwargs: Any) -> None: """Clean up the MCP client connection.""" if not self._started: return diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 0e8d54dfc..5a2bc94c3 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -1,7 +1,7 @@ """Tool provider interface.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence if TYPE_CHECKING: from ...types.tools import AgentTool @@ -15,18 +15,24 @@ class ToolProvider(ABC): """ @abstractmethod - async def load_tools(self) -> Sequence["AgentTool"]: + async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: """Load and return the tools in this provider. + Args: + **kwargs: Additional arguments for future compatibility. + Returns: List of tools that are ready to use. """ pass @abstractmethod - async def cleanup(self) -> None: + async def cleanup(self, **kwargs: Any) -> None: """Clean up resources used by the tools in this provider. + Args: + **kwargs: Additional arguments for future compatibility. + Should be called when the tools are no longer needed. """ pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 8e94520b7..fd8e5a1c8 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1012,6 +1012,18 @@ def test_agent_cleanup_idempotent(agent): mock_provider.cleanup.assert_called_once() +def test_agent_cleanup_early_return_avoids_thread_spawn(agent): + """Test that cleanup returns early when already called, avoiding thread spawn cost.""" + # Mark cleanup as already called + agent._cleanup_called = True + + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: + agent.cleanup() + + # Verify run_async was not called since cleanup already happened + mock_run_async.assert_not_called() + + def test_agent__del__emits_warning_for_automatic_cleanup(agent): """Test that __del__ emits warning when cleanup wasn't called manually.""" # Add a mock tool provider so cleanup will be called diff --git a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py index ae28d51c8..3576fd0b3 100644 --- a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py +++ b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py @@ -55,21 +55,21 @@ def test_init_with_client_only(mock_mcp_client): assert provider._client is mock_mcp_client assert provider._tool_filters is None - assert provider._disambiguator is None + assert provider._prefix is None assert provider._tools is None assert provider._started is False def test_init_with_all_parameters(mock_mcp_client): """Test initialization with all parameters.""" - filters = {"allowed": ["tool1"], "max_tools": 5} - disambiguator = "test_prefix" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, disambiguator=disambiguator) + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, prefix=prefix) assert provider._client is mock_mcp_client assert provider._tool_filters == filters - assert provider._disambiguator == disambiguator + assert provider._prefix == prefix assert provider._tools is None assert provider._started is False @@ -232,24 +232,8 @@ async def test_rejected_filter(mock_mcp_client): @pytest.mark.asyncio -async def test_max_tools_filter(mock_mcp_client): - """Test max_tools filter functionality.""" - tools_list = [create_mock_tool(f"tool_{i}") for i in range(5)] - - mock_mcp_client.list_tools_sync.return_value = PaginatedList(tools_list) - - filters: ToolFilters = {"max_tools": 3} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 3 - assert all(tool.tool_name.startswith("tool_") for tool in tools) - - -@pytest.mark.asyncio -async def test_disambiguator_renames_tools(mock_mcp_client): - """Test that disambiguator properly renames tools.""" +async def test_prefix_renames_tools(mock_mcp_client): + """Test that prefix properly renames tools.""" original_tool = MagicMock(spec=MCPAgentTool) original_tool.tool_name = "original_name" original_tool.mcp_tool = MagicMock() @@ -263,7 +247,7 @@ async def test_disambiguator_renames_tools(mock_mcp_client): new_tool.tool_name = "prefix_original_name" mock_agent_tool_class.return_value = new_tool - provider = MCPToolProvider(client=mock_mcp_client, disambiguator="prefix") + provider = MCPToolProvider(client=mock_mcp_client, prefix="prefix") tools = await provider.load_tools() diff --git a/tests_integ/test_mcp_tool_provider.py b/tests_integ/test_mcp_tool_provider.py index 4d6a39329..5b7bb3ed1 100644 --- a/tests_integ/test_mcp_tool_provider.py +++ b/tests_integ/test_mcp_tool_provider.py @@ -22,7 +22,7 @@ def test_mcp_tool_provider_filters(): lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) - # Test string filter, regex filter, callable filter, max_tools, and disambiguator + # Test string filter, regex filter, callable filter, and prefix def short_names_only(tool) -> bool: return len(tool.tool_name) <= 20 # Allow most tools @@ -32,7 +32,7 @@ def short_names_only(tool) -> bool: "max_tools": 2, } - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="test") + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="test") agent = Agent(tools=[provider]) tool_names = agent.tool_names @@ -51,7 +51,7 @@ def test_mcp_tool_provider_execution(): ) filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="filtered") + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="filtered") agent = Agent( tools=[provider], ) @@ -80,7 +80,7 @@ def test_mcp_tool_provider_reuse(): ) filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="shared") + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="shared") # Create first agent with the provider agent1 = Agent(tools=[provider]) @@ -116,11 +116,11 @@ def test_mcp_tool_provider_multiple_servers(): lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) - # Create providers with different disambiguators - provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, disambiguator="server1") + # Create providers with different prefixes + provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, prefix="server1") # Use correct tool name from echo_server.py provider2 = MCPToolProvider( - client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, disambiguator="server2" + client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, prefix="server2" ) # Create agent with both providers From 3250247507c036335d2ae67b45f888575a2be7ea Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 8 Oct 2025 16:52:42 -0400 Subject: [PATCH 3/7] mcp_client implements tool_provider --- src/strands/agent/agent.py | 14 +- .../experimental/tools/mcp/__init__.py | 5 - .../tools/mcp/mcp_tool_provider.py | 171 --------- .../experimental/tools/tool_provider.py | 17 +- src/strands/tools/mcp/__init__.py | 4 +- src/strands/tools/mcp/mcp_client.py | 174 ++++++++- src/strands/tools/registry.py | 26 +- tests/strands/agent/test_agent.py | 31 +- .../experimental/tools/mcp/__init__.py | 0 .../tools/mcp/test_mcp_tool_provider.py | 359 ------------------ .../mcp/test_mcp_client_tool_provider.py | 320 ++++++++++++++++ .../tools/test_registry_tool_provider.py | 94 +++++ tests_integ/mcp/test_mcp_tool_provider.py | 184 +++++++++ tests_integ/test_mcp_tool_provider.py | 171 --------- 14 files changed, 824 insertions(+), 746 deletions(-) delete mode 100644 src/strands/experimental/tools/mcp/__init__.py delete mode 100644 src/strands/experimental/tools/mcp/mcp_tool_provider.py delete mode 100644 tests/strands/experimental/tools/mcp/__init__.py delete mode 100644 tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py create mode 100644 tests/strands/tools/mcp/test_mcp_client_tool_provider.py create mode 100644 tests_integ/mcp/test_mcp_tool_provider.py delete mode 100644 tests_integ/test_mcp_tool_provider.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0255d70db..e428ee0e5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -544,19 +544,7 @@ async def cleanup_async(self) -> None: logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id) - for provider in self.tool_registry.tool_providers: - try: - await provider.cleanup() - logger.debug( - "agent_id=<%s>, provider=<%s> | cleaned up tool provider", self.agent_id, type(provider).__name__ - ) - except Exception as e: - logger.warning( - "agent_id=<%s>, provider=<%s>, error=<%s> | failed to cleanup tool provider", - self.agent_id, - type(provider).__name__, - e, - ) + await self.tool_registry.cleanup_async() self._cleanup_called = True logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id) diff --git a/src/strands/experimental/tools/mcp/__init__.py b/src/strands/experimental/tools/mcp/__init__.py deleted file mode 100644 index ee1ccc542..000000000 --- a/src/strands/experimental/tools/mcp/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Experimental MCP Tool Provider.""" - -from .mcp_tool_provider import MCPToolProvider, ToolFilters - -__all__ = ["MCPToolProvider", "ToolFilters"] diff --git a/src/strands/experimental/tools/mcp/mcp_tool_provider.py b/src/strands/experimental/tools/mcp/mcp_tool_provider.py deleted file mode 100644 index 44e9fb61f..000000000 --- a/src/strands/experimental/tools/mcp/mcp_tool_provider.py +++ /dev/null @@ -1,171 +0,0 @@ -"""MCP Tool Provider implementation.""" - -import logging -from typing import Any, Callable, Optional, Pattern, Sequence, Union - -from typing_extensions import TypedDict - -from ....tools.mcp.mcp_agent_tool import MCPAgentTool -from ....tools.mcp.mcp_client import MCPClient -from ....types.exceptions import ToolProviderException -from ....types.tools import AgentTool -from ..tool_provider import ToolProvider - -logger = logging.getLogger(__name__) - -_ToolFilterCallback = Callable[[AgentTool], bool] -_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] - - -class ToolFilters(TypedDict, total=False): - """Filters for controlling which MCP tools are loaded and available. - - Tools are filtered in this order: - 1. If 'allowed' is specified, only tools matching these patterns are included - 2. Tools matching 'rejected' patterns are then excluded - """ - - allowed: list[_ToolFilterPattern] - rejected: list[_ToolFilterPattern] - - -class MCPToolProvider(ToolProvider): - """Tool provider for MCP clients with managed lifecycle.""" - - def __init__( - self, - *, - client: MCPClient, - tool_filters: Optional[ToolFilters] = None, - prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Initialize with an MCP client. - - Args: - client: The MCP client to manage. - tool_filters: Optional filters to apply to tools. - prefix: Optional prefix for tool names. - **kwargs: Additional arguments for future compatibility. - """ - logger.debug("tool_filters=<%s>, prefix=<%s> | initializing MCPToolProvider", tool_filters, prefix) - self._client = client - self._tool_filters = tool_filters - self._prefix = prefix - self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty - self._started = False - - async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: - """Load and return tools from the MCP client. - - Returns: - List of tools from the MCP server. - """ - logger.debug("started=<%s>, cached_tools=<%s> | loading tools", self._started, self._tools is not None) - - if not self._started: - try: - logger.debug("starting MCP client") - self._client.start() - self._started = True - logger.debug("MCP client started successfully") - except Exception as e: - logger.error("error=<%s> | failed to start MCP client", e) - raise ToolProviderException(f"Failed to start MCP client: {e}") from e - - if self._tools is None: - logger.debug("loading tools from MCP server") - self._tools = [] - pagination_token = None - page_count = 0 - - while True: - logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) - paginated_tools = self._client.list_tools_sync(pagination_token) - - # Process each tool as we get it - for tool in paginated_tools: - # Apply filters - if self._should_include_tool(tool): - # Apply prefix if needed - processed_tool = self._apply_prefix(tool) - self._tools.append(processed_tool) - - logger.debug( - "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", - page_count, - len(paginated_tools), - len(self._tools), - ) - - pagination_token = paginated_tools.pagination_token - page_count += 1 - - if pagination_token is None: - break - - logger.debug("final_tools=<%d> | loading complete", len(self._tools)) - - return self._tools - - def _should_include_tool(self, tool: MCPAgentTool) -> bool: - """Check if a tool should be included based on allowed/rejected filters.""" - if not self._tool_filters: - return True - - # Apply allowed filter - if "allowed" in self._tool_filters: - if not self._matches_patterns(tool, self._tool_filters["allowed"]): - return False - - # Apply rejected filter - if "rejected" in self._tool_filters: - if self._matches_patterns(tool, self._tool_filters["rejected"]): - return False - - return True - - def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: - """Apply prefix to a single tool if needed.""" - if not self._prefix: - return tool - - # Create new tool with prefixed agent name but preserve original MCP name - old_name = tool.tool_name - new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" - new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) - logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) - return new_tool - - def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: - """Check if tool matches any of the given patterns.""" - for pattern in patterns: - if callable(pattern): - if pattern(tool): - return True - elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): - if pattern.match(tool.tool_name): - return True - elif isinstance(pattern, str): - if pattern == tool.tool_name: - return True - return False - - async def cleanup(self, **kwargs: Any) -> None: - """Clean up the MCP client connection.""" - if not self._started: - return - - logger.debug("cleaning up MCP client") - try: - logger.debug("stopping MCP client") - self._client.stop(None, None, None) - logger.debug("MCP client stopped successfully") - except Exception as e: - logger.error("error=<%s> | failed to cleanup MCP client", e) - raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e - - # Only reset state if cleanup succeeded - self._started = False - self._tools = None - logger.debug("MCP client cleanup complete") diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 5a2bc94c3..5023dd72c 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -27,12 +27,23 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: pass @abstractmethod - async def cleanup(self, **kwargs: Any) -> None: - """Clean up resources used by the tools in this provider. + async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. Args: + id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass + + @abstractmethod + async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + Args: + id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. - Should be called when the tools are no longer needed. + Provider may clean up resources when no consumers remain. """ pass diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index d95c54fed..cfa841c46 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -7,7 +7,7 @@ """ from .mcp_agent_tool import MCPAgentTool -from .mcp_client import MCPClient +from .mcp_client import MCPClient, ToolFilters from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index dec8ec313..a780966cd 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,7 +16,7 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast import anyio from mcp import ClientSession, ListToolsResult @@ -24,11 +24,13 @@ from mcp.types import GetPromptResult, ListPromptsResult from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from typing_extensions import TypedDict +from ...experimental.tools import ToolProvider from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError +from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat -from ...types.tools import ToolResultContent, ToolResultStatus +from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport @@ -37,6 +39,22 @@ T = TypeVar("T") +_ToolFilterCallback = Callable[[AgentTool], bool] +_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + """ + + allowed: list[_ToolFilterPattern] + rejected: list[_ToolFilterPattern] + + MIME_TO_FORMAT: Dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -52,7 +70,7 @@ ) -class MCPClient: +class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. This class implements a context manager pattern for efficient connection management, @@ -64,15 +82,26 @@ class MCPClient: from MCP tools, it will be returned as the last item in the content array of the ToolResult. """ - def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30): + def __init__( + self, + transport_callable: Callable[[], MCPTransport], + *, + startup_timeout: int = 30, + tool_filters: Optional[ToolFilters] = None, + prefix: Optional[str] = None, + ): """Initialize a new MCP Server connection. Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple startup_timeout: Timeout after which MCP server initialization should be cancelled Defaults to 30. + tool_filters: Optional filters to apply to tools. + prefix: Optional prefix for tool names. """ self._startup_timeout = startup_timeout + self._tool_filters = tool_filters + self._prefix = prefix mcp_instrumentation() self._session_id = uuid.uuid4() @@ -86,6 +115,9 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._background_thread: threading.Thread | None = None self._background_thread_session: ClientSession | None = None self._background_thread_event_loop: AbstractEventLoop | None = None + self._loaded_tools: list[MCPAgentTool] | None = None + self._tool_provider_started = False + self._consumers: set[Any] = set() def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -136,6 +168,92 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client initialization failed") from e return self + # ToolProvider interface methods (experimental, as ToolProvider is experimental) + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: + """Load and return tools from the MCP server. + + This method implements the ToolProvider interface by loading tools + from the MCP server and caching them for reuse. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of AgentTool instances from the MCP server. + """ + logger.debug( + "started=<%s>, cached_tools=<%s> | loading tools", + self._tool_provider_started, + self._loaded_tools is not None, + ) + + if not self._tool_provider_started: + try: + logger.debug("starting MCP client") + self.start() + self._tool_provider_started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._loaded_tools is None: + logger.debug("loading tools from MCP server") + self._loaded_tools = [] + pagination_token = None + page_count = 0 + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + paginated_tools = self.list_tools_sync(pagination_token) + + # Process each tool as we get it + for tool in paginated_tools: + # Apply filters + if self._should_include_tool(tool): + # Apply prefix if needed + processed_tool = self._apply_prefix(tool) + self._loaded_tools.append(processed_tool) + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._loaded_tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) + + return self._loaded_tools + + async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider.""" + self._consumers.add(id) + logger.debug("added provider consumer, count=%d", len(self._consumers)) + + async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider.""" + self._consumers.discard(id) + logger.debug("removed provider consumer, count=%d", len(self._consumers)) + + if not self._consumers and self._tool_provider_started: + logger.debug("no consumers remaining, cleaning up") + try: + self.stop(None, None, None) + self._tool_provider_started = False + self._loaded_tools = None + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # MCP-specific methods + def stop( self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: @@ -186,6 +304,9 @@ async def _set_close_event() -> None: self._background_thread_session = None self._background_thread_event_loop = None self._session_id = uuid.uuid4() + self._loaded_tools = None + self._tool_provider_started = False + self._consumers = set() def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. @@ -478,5 +599,48 @@ def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures. raise MCPClientInitializationError("the client session was not initialized") return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on allowed/rejected filters.""" + if not self._tool_filters: + return True + + # Apply allowed filter + if "allowed" in self._tool_filters: + if not self._matches_patterns(tool, self._tool_filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in self._tool_filters: + if self._matches_patterns(tool, self._tool_filters["rejected"]): + return False + + return True + + def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply prefix to a single tool if needed.""" + if not self._prefix: + return tool + + # Create new tool with prefixed agent name but preserve original MCP name + old_name = tool.tool_name + new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) + return new_tool + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): + if pattern.match(tool.tool_name): + return True + elif isinstance(pattern, str): + if pattern == tool.tool_name: + return True + return False + def _is_session_active(self) -> bool: return self._background_thread is not None and self._background_thread.is_alive() diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index ea71b09a0..52028ee32 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,11 +8,12 @@ import logging import os import sys +import uuid import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast @@ -39,6 +40,7 @@ def __init__(self) -> None: self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None self.tool_providers: List[ToolProvider] = [] + self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: """Process tools list. @@ -121,12 +123,17 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): for t in tool: add_tool(t) - + # Case 5: ToolProvider elif isinstance(tool, ToolProvider): self.tool_providers.append(tool) - provider_tools = run_async(tool.load_tools) + async def get_tools_and_register_consumer() -> Sequence[AgentTool]: + provider_tools = await tool.load_tools() + await tool.add_provider_consumer(self._registry_id) + return provider_tools + + provider_tools = run_async(get_tools_and_register_consumer) for provider_tool in provider_tools: self.register_tool(provider_tool) @@ -653,3 +660,16 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) return tools + + async def cleanup_async(self, **kwargs: Any) -> None: + """Clean up all tool providers in this registry.""" + for provider in self.tool_providers: + try: + await provider.remove_provider_consumer(self._registry_id) + logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) + except Exception as e: + logger.warning( + "provider=<%s>, error=<%s> | failed to remove provider consumer", + type(provider).__name__, + e, + ) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index fd8e5a1c8..5e6cee810 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -911,15 +911,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider cleanup was called - mock_provider.cleanup.assert_called_once() + # Verify provider remove_provider_consumer was called + mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -929,9 +929,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.cleanup = unittest.mock.AsyncMock() + mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.cleanup = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -940,8 +940,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.cleanup.assert_called_once() - mock_provider2.cleanup.assert_called_once() + mock_provider1.remove_provider_consumer.assert_called_once() + mock_provider2.remove_provider_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -951,7 +951,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -960,8 +960,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider cleanup was only called once due to idempotency - mock_provider.cleanup.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() @pytest.mark.asyncio @@ -999,7 +999,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1008,8 +1008,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider cleanup was only called once due to idempotency - mock_provider.cleanup.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1024,8 +1024,11 @@ def test_agent_cleanup_early_return_avoids_thread_spawn(agent): mock_run_async.assert_not_called() -def test_agent__del__emits_warning_for_automatic_cleanup(agent): +def test_agent__del__emits_warning_for_automatic_cleanup(): """Test that __del__ emits warning when cleanup wasn't called manually.""" + # Create a fresh agent for this test to avoid fixture lifecycle issues + agent = Agent() + # Add a mock tool provider so cleanup will be called mock_provider = unittest.mock.MagicMock() agent.tool_registry.tool_providers = [mock_provider] diff --git a/tests/strands/experimental/tools/mcp/__init__.py b/tests/strands/experimental/tools/mcp/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py deleted file mode 100644 index 3576fd0b3..000000000 --- a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Unit tests for MCPToolProvider.""" - -import re -from unittest.mock import MagicMock, patch - -import pytest - -from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters -from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_agent_tool import MCPAgentTool -from strands.types import PaginatedList -from strands.types.exceptions import ToolProviderException - - -@pytest.fixture -def mock_mcp_client(): - """Create a mock MCP client.""" - client = MagicMock(spec=MCPClient) - client.start = MagicMock() - client.stop = MagicMock() - client.list_tools_sync = MagicMock() - return client - - -@pytest.fixture -def mock_mcp_tool(): - """Create a mock MCP tool.""" - tool = MagicMock() - tool.name = "test_tool" - return tool - - -@pytest.fixture -def mock_agent_tool(mock_mcp_tool, mock_mcp_client): - """Create a mock MCPAgentTool.""" - agent_tool = MagicMock(spec=MCPAgentTool) - agent_tool.tool_name = "test_tool" - agent_tool.mcp_tool = mock_mcp_tool - agent_tool.mcp_client = mock_mcp_client - return agent_tool - - -def create_mock_tool(name: str) -> MagicMock: - """Helper to create mock tools with specific names.""" - tool = MagicMock(spec=MCPAgentTool) - tool.tool_name = name - tool.mcp_tool = MagicMock() - tool.mcp_tool.name = name - return tool - - -def test_init_with_client_only(mock_mcp_client): - """Test initialization with only client.""" - provider = MCPToolProvider(client=mock_mcp_client) - - assert provider._client is mock_mcp_client - assert provider._tool_filters is None - assert provider._prefix is None - assert provider._tools is None - assert provider._started is False - - -def test_init_with_all_parameters(mock_mcp_client): - """Test initialization with all parameters.""" - filters = {"allowed": ["tool1"]} - prefix = "test_prefix" - - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, prefix=prefix) - - assert provider._client is mock_mcp_client - assert provider._tool_filters == filters - assert provider._prefix == prefix - assert provider._tools is None - assert provider._started is False - - -@pytest.mark.asyncio -async def test_load_tools_starts_client_when_not_started(mock_mcp_client, mock_agent_tool): - """Test that load_tools starts the client when not already started.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client) - - tools = await provider.load_tools() - - mock_mcp_client.start.assert_called_once() - assert provider._started is True - assert len(tools) == 1 - assert tools[0] is mock_agent_tool - - -@pytest.mark.asyncio -async def test_load_tools_does_not_start_client_when_already_started(mock_mcp_client, mock_agent_tool): - """Test that load_tools does not start client when already started.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - - tools = await provider.load_tools() - - mock_mcp_client.start.assert_not_called() - assert len(tools) == 1 - - -@pytest.mark.asyncio -async def test_load_tools_raises_exception_on_client_start_failure(mock_mcp_client): - """Test that load_tools raises ToolProviderException when client start fails.""" - mock_mcp_client.start.side_effect = Exception("Client start failed") - - provider = MCPToolProvider(client=mock_mcp_client) - - with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): - await provider.load_tools() - - -@pytest.mark.asyncio -async def test_load_tools_caches_tools(mock_mcp_client, mock_agent_tool): - """Test that load_tools caches tools and doesn't reload them.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client) - - # First call - tools1 = await provider.load_tools() - # Second call - tools2 = await provider.load_tools() - - # Client should only be called once - mock_mcp_client.list_tools_sync.assert_called_once() - assert tools1 is tools2 - - -@pytest.mark.asyncio -async def test_load_tools_handles_pagination(mock_mcp_client, mock_agent_tool): - """Test that load_tools handles pagination correctly.""" - tool1 = MagicMock(spec=MCPAgentTool) - tool1.tool_name = "tool1" - tool2 = MagicMock(spec=MCPAgentTool) - tool2.tool_name = "tool2" - - # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token - mock_mcp_client.list_tools_sync.side_effect = [ - PaginatedList([tool1], token="page2"), - PaginatedList([tool2], token=None), - ] - - provider = MCPToolProvider(client=mock_mcp_client) - - tools = await provider.load_tools() - - # Should have called list_tools_sync twice - assert mock_mcp_client.list_tools_sync.call_count == 2 - # First call with no token, second call with "page2" token - mock_mcp_client.list_tools_sync.assert_any_call(None) - mock_mcp_client.list_tools_sync.assert_any_call("page2") - - assert len(tools) == 2 - assert tools[0] is tool1 - assert tools[1] is tool2 - - -@pytest.mark.asyncio -async def test_allowed_filter_string_match(mock_mcp_client): - """Test allowed filter with string matching.""" - tool1 = create_mock_tool("allowed_tool") - tool2 = create_mock_tool("rejected_tool") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - filters: ToolFilters = {"allowed": ["allowed_tool"]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "allowed_tool" - - -@pytest.mark.asyncio -async def test_allowed_filter_regex_match(mock_mcp_client): - """Test allowed filter with regex matching.""" - tool1 = create_mock_tool("echo_tool") - tool2 = create_mock_tool("other_tool") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "echo_tool" - - -@pytest.mark.asyncio -async def test_allowed_filter_callable_match(mock_mcp_client): - """Test allowed filter with callable matching.""" - tool1 = create_mock_tool("short") - tool2 = create_mock_tool("very_long_tool_name") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - def short_names_only(tool) -> bool: - return len(tool.tool_name) <= 10 - - filters: ToolFilters = {"allowed": [short_names_only]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "short" - - -@pytest.mark.asyncio -async def test_rejected_filter(mock_mcp_client): - """Test rejected filter functionality.""" - tool1 = create_mock_tool("good_tool") - tool2 = create_mock_tool("bad_tool") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - filters: ToolFilters = {"rejected": ["bad_tool"]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "good_tool" - - -@pytest.mark.asyncio -async def test_prefix_renames_tools(mock_mcp_client): - """Test that prefix properly renames tools.""" - original_tool = MagicMock(spec=MCPAgentTool) - original_tool.tool_name = "original_name" - original_tool.mcp_tool = MagicMock() - original_tool.mcp_tool.name = "original_name" - original_tool.mcp_client = mock_mcp_client - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([original_tool]) - - with patch("strands.experimental.tools.mcp.mcp_tool_provider.MCPAgentTool") as mock_agent_tool_class: - new_tool = MagicMock(spec=MCPAgentTool) - new_tool.tool_name = "prefix_original_name" - mock_agent_tool_class.return_value = new_tool - - provider = MCPToolProvider(client=mock_mcp_client, prefix="prefix") - - tools = await provider.load_tools() - - # Should create new MCPAgentTool with prefixed name - mock_agent_tool_class.assert_called_once_with( - original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" - ) - - assert len(tools) == 1 - assert tools[0] is new_tool - - -@pytest.mark.asyncio -async def test_cleanup_stops_client_when_started(mock_mcp_client): - """Test that cleanup stops the client when started.""" - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - provider._tools = [MagicMock()] - - await provider.cleanup() - - mock_mcp_client.stop.assert_called_once_with(None, None, None) - assert provider._started is False - assert provider._tools is None - - -@pytest.mark.asyncio -async def test_cleanup_does_nothing_when_not_started(mock_mcp_client): - """Test that cleanup does nothing when not started.""" - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = False - - await provider.cleanup() - - mock_mcp_client.stop.assert_not_called() - assert provider._started is False - - -@pytest.mark.asyncio -async def test_cleanup_raises_exception_on_client_stop_failure(mock_mcp_client): - """Test that cleanup raises ToolProviderException when client stop fails.""" - mock_mcp_client.stop.side_effect = Exception("Client stop failed") - - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - - with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Client stop failed"): - await provider.cleanup() - - # State is not reset when cleanup fails - assert provider._started is True - assert provider._tools is None - - -@pytest.mark.asyncio -async def test_cleanup_does_not_reset_state_on_exception(mock_mcp_client): - """Test that cleanup does not reset state when exception occurs.""" - mock_mcp_client.stop.side_effect = Exception("Client stop failed") - - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - mock_tool = MagicMock() - provider._tools = [mock_tool] - - with pytest.raises(ToolProviderException): - await provider.cleanup() - - # State should not be reset when exception occurs - assert provider._started is True - assert provider._tools == [mock_tool] - - -@pytest.mark.asyncio -async def test_load_tools_with_empty_tool_list(mock_mcp_client): - """Test load_tools with empty tool list from server.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([]) - - provider = MCPToolProvider(client=mock_mcp_client) - - tools = await provider.load_tools() - - assert len(tools) == 0 - assert provider._started is True - - -@pytest.mark.asyncio -async def test_load_tools_with_no_filters(mock_mcp_client, mock_agent_tool): - """Test load_tools with no filters applied.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=None) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0] is mock_agent_tool - - -@pytest.mark.asyncio -async def test_load_tools_with_empty_filters(mock_mcp_client, mock_agent_tool): - """Test load_tools with empty filters dict.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client, tool_filters={}) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0] is mock_agent_tool diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py new file mode 100644 index 000000000..187cbeaa5 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -0,0 +1,320 @@ +"""Unit tests for MCPClient ToolProvider functionality.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.mcp.mcp_client import ToolFilters +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_transport(): + """Create a mock transport callable.""" + + def transport(): + read_stream = MagicMock() + write_stream = MagicMock() + return read_stream, write_stream + + return transport + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + return agent_tool + + +def create_mock_tool(name: str) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = name + tool.mcp_tool = MagicMock() + tool.mcp_tool.name = name + return tool + + +def test_init_with_tool_filters_and_prefix(mock_transport): + """Test initialization with tool filters and prefix.""" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" + + client = MCPClient(mock_transport, tool_filters=filters, prefix=prefix) + + assert client._tool_filters == filters + assert client._prefix == prefix + assert client._loaded_tools is None + assert client._tool_provider_started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_transport, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_called_once() + assert client._tool_provider_started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_transport, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_transport): + """Test that load_tools raises ToolProviderException when client start fails.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start: + mock_start.side_effect = Exception("Client start failed") + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await client.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_transport, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + # First call + tools1 = await client.load_tools() + # Second call + tools2 = await client.load_tools() + + # Client should only be called once + mock_list_tools.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_transport): + """Test that load_tools handles pagination correctly.""" + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_list_tools.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + tools = await client.load_tools() + + # Should have called list_tools_sync twice + assert mock_list_tools.call_count == 2 + # First call with no token, second call with "page2" token + mock_list_tools.assert_any_call(None) + mock_list_tools.assert_any_call("page2") + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_transport): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_transport): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + tool2 = create_mock_tool("other_tool") + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_transport): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + tool2 = create_mock_tool("very_long_tool_name") + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter(mock_transport): + """Test rejected filter functionality.""" + tool1 = create_mock_tool("good_tool") + tool2 = create_mock_tool("bad_tool") + + filters: ToolFilters = {"rejected": ["bad_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_prefix_renames_tools(mock_transport): + """Test that prefix properly renames tools.""" + original_tool = create_mock_tool("original_name") + original_tool.mcp_client = MagicMock() + + client = MCPClient(mock_transport, prefix="prefix") + client._tool_provider_started = True + + with ( + patch.object(client, "list_tools_sync") as mock_list_tools, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_list_tools.return_value = PaginatedList([original_tool]) + + new_tool = MagicMock(spec=MCPAgentTool) + new_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = new_tool + + tools = await client.load_tools() + + # Should create new MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with( + original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" + ) + + assert len(tools) == 1 + assert tools[0] is new_tool + + +@pytest.mark.asyncio +async def test_add_provider_consumer(mock_transport): + """Test adding a provider consumer.""" + client = MCPClient(mock_transport) + + await client.add_provider_consumer("consumer1") + + assert "consumer1" in client._consumers + assert len(client._consumers) == 1 + + +@pytest.mark.asyncio +async def test_remove_provider_consumer_without_cleanup(mock_transport): + """Test removing a provider consumer without triggering cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._consumers.add("consumer2") + client._tool_provider_started = True + + await client.remove_provider_consumer("consumer1") + + assert "consumer1" not in client._consumers + assert "consumer2" in client._consumers + assert client._tool_provider_started is True # Should not cleanup yet + + +@pytest.mark.asyncio +async def test_remove_provider_consumer_with_cleanup(mock_transport): + """Test removing the last provider consumer triggers cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + client._loaded_tools = [MagicMock()] + + with patch.object(client, "stop") as mock_stop: + await client.remove_provider_consumer("consumer1") + + assert len(client._consumers) == 0 + assert client._tool_provider_started is False + assert client._loaded_tools is None + mock_stop.assert_called_once_with(None, None, None) + + +@pytest.mark.asyncio +async def test_remove_provider_consumer_cleanup_failure(mock_transport): + """Test that remove_provider_consumer raises ToolProviderException when cleanup fails.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + + with patch.object(client, "stop") as mock_stop: + mock_stop.side_effect = Exception("Cleanup failed") + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): + await client.remove_provider_consumer("consumer1") diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index f9f9c9ce0..c9794326f 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -16,6 +16,10 @@ def __init__(self, tools=None, cleanup_error=None): self._tools = tools or [] self._cleanup_error = cleanup_error self.cleanup_called = False + self.remove_consumer_called = False + self.remove_consumer_id = None + self.add_consumer_called = False + self.add_consumer_id = None async def load_tools(self): return self._tools @@ -25,6 +29,14 @@ async def cleanup(self): if self._cleanup_error: raise self._cleanup_error + async def add_provider_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + async def remove_provider_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + class TestToolRegistryToolProvider: """Test ToolRegistry integration with ToolProvider.""" @@ -200,3 +212,85 @@ def test_tool_provider_tracking_persistence(self): assert len(registry.tool_providers) == 2 assert provider1 in registry.tool_providers assert provider2 in registry.tool_providers + + def test_process_tools_provider_async_optimization(self): + """Test that load_tools and add_provider_consumer are called in same async context.""" + mock_tool = MagicMock(spec=AgentTool) + mock_tool.tool_name = "test_tool" + + class TestProvider(ToolProvider): + def __init__(self): + self.load_tools_called = False + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + self.load_tools_called = True + return [mock_tool] + + async def add_provider_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + async def remove_provider_consumer(self, consumer_id): + pass + + provider = TestProvider() + registry = ToolRegistry() + + # Process the provider - this should call both methods in same async context + tool_names = registry.process_tools([provider]) + + # Verify both methods were called + assert provider.load_tools_called + assert provider.add_consumer_called + assert provider.add_consumer_id == registry._registry_id + + # Verify tool was registered + assert "test_tool" in tool_names + assert provider in registry.tool_providers + + @pytest.mark.asyncio + async def test_registry_cleanup(self): + """Test that registry cleanup calls remove_provider_consumer on all providers.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry.tool_providers = [provider1, provider2] + + await registry.cleanup_async() + + # Verify both providers had remove_provider_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + @pytest.mark.asyncio + async def test_registry_cleanup_with_provider_consumer_removal(self): + """Test that cleanup removes provider consumers correctly.""" + + class TestProvider(ToolProvider): + def __init__(self): + self.remove_consumer_called = False + self.remove_consumer_id = None + + async def load_tools(self): + return [] + + async def add_provider_consumer(self, consumer_id): + pass + + async def remove_provider_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + + provider = TestProvider() + registry = ToolRegistry() + registry.tool_providers = [provider] + + # Call cleanup + await registry.cleanup_async() + + # Verify remove_provider_consumer was called with correct ID + assert provider.remove_consumer_called + assert provider.remove_consumer_id == registry._registry_id diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..b45b38b86 --- /dev/null +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -0,0 +1,184 @@ +"""Integration tests for MCPClient ToolProvider functionality with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_client import ToolFilters + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_client_tool_provider_filters(): + """Test MCPClient with various filter combinations.""" + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + } + + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="test", + ) + + agent = Agent(tools=[client]) + tool_names = agent.tool_names + + assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_client_tool_provider_execution(): + """Test that MCPClient works with agent execution.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="filtered", + ) + + agent = Agent(tools=[client]) + + assert "filtered_echo" in agent.tool_names + + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_client_tool_provider_reuse(): + """Test that a single MCPClient can be used across multiple agents.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="shared", + ) + + agent1 = Agent(tools=[client]) + assert "shared_echo" in agent1.tool_names + + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + agent2 = Agent(tools=[client]) + assert "shared_echo" in agent2.tool_names + + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + agent2.cleanup() + + +def test_mcp_client_reference_counting(): + """Test that MCPClient uses reference counting - cleanup only happens when last consumer is removed.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="ref", + ) + + # Create two agents with the same client + agent1 = Agent(tools=[client]) + agent2 = Agent(tools=[client]) + + # Both should have the tool + assert "ref_echo" in agent1.tool_names + assert "ref_echo" in agent2.tool_names + + # Agent 1 uses the tool + result1 = agent1.tool.ref_echo(to_echo="Agent 1 Test") + assert "Agent 1 Test" in str(result1) + + # Agent 1 cleans up - client should still be active for agent 2 + agent1.cleanup() + + # Agent 2 should still be able to use the tool + result2 = agent2.tool.ref_echo(to_echo="Agent 2 Test") + assert "Agent 2 Test" in str(result2) + + # Agent 2 cleans up - now client should be fully cleaned up + agent2.cleanup() + + +def test_mcp_client_multiple_servers(): + """Test MCPClient with multiple MCP servers simultaneously.""" + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo"]}, + prefix="server1", + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo_with_structured_content"]}, + prefix="server2", + ) + + agent = Agent(tools=[client1, client2]) + + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_client_server_startup_failure(): + """Test that MCPClient handles server startup failure gracefully without hanging.""" + from strands.types.exceptions import ToolProviderException + + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[failing_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) + + +def test_mcp_client_server_connection_timeout(): + """Test that MCPClient times out gracefully when server hangs during startup.""" + from strands.types.exceptions import ToolProviderException + + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), + startup_timeout=1, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[hanging_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) diff --git a/tests_integ/test_mcp_tool_provider.py b/tests_integ/test_mcp_tool_provider.py deleted file mode 100644 index 5b7bb3ed1..000000000 --- a/tests_integ/test_mcp_tool_provider.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Integration tests for MCPToolProvider with real MCP server.""" - -import logging -import re - -import pytest -from mcp import StdioServerParameters, stdio_client - -from strands import Agent -from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters -from strands.tools.mcp import MCPClient -from strands.types.exceptions import ToolProviderException - -logging.basicConfig(level=logging.DEBUG) - -logger = logging.getLogger(__name__) - - -def test_mcp_tool_provider_filters(): - """Test MCPToolProvider with various filter combinations.""" - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - # Test string filter, regex filter, callable filter, and prefix - def short_names_only(tool) -> bool: - return len(tool.tool_name) <= 20 # Allow most tools - - filters: ToolFilters = { - "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], - "rejected": ["echo_with_delay"], - "max_tools": 2, - } - - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="test") - agent = Agent(tools=[provider]) - tool_names = agent.tool_names - - # Should have 2 tools max, with test_ prefix, no delay tool - assert len(tool_names) == 2 - assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] - assert all(name.startswith("test_") for name in tool_names) - - agent.cleanup() - - -def test_mcp_tool_provider_execution(): - """Test that MCPToolProvider works with agent execution.""" - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="filtered") - agent = Agent( - tools=[provider], - ) - - # Verify the filtered tool exists - assert "filtered_echo" in agent.tool_names - - # # Test direct tool call to verify it works (use correct parameter name from echo server) - tool_result = agent.tool.filtered_echo(to_echo="Hello World") - assert "Hello World" in str(tool_result) - - # # Test agent execution using the tool - result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") - assert "Integration Test" in str(result) - - assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 - assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 - - agent.cleanup() - - -def test_mcp_tool_provider_reuse(): - """Test that a single MCPToolProvider can be used across multiple agents.""" - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="shared") - - # Create first agent with the provider - agent1 = Agent(tools=[provider]) - assert "shared_echo" in agent1.tool_names - - # Test first agent (use correct parameter name from echo server) - result1 = agent1.tool.shared_echo(to_echo="Agent 1") - assert "Agent 1" in str(result1) - - # Create second agent with the same provider - agent2 = Agent(tools=[provider]) - assert "shared_echo" in agent2.tool_names - - # Test second agent (use correct parameter name from echo server) - result2 = agent2.tool.shared_echo(to_echo="Agent 2") - assert "Agent 2" in str(result2) - - # Both agents should have the same tool count - assert len(agent1.tool_names) == len(agent2.tool_names) - assert agent1.tool_names == agent2.tool_names - - agent1.cleanup() - agent2.cleanup() - - -def test_mcp_tool_provider_multiple_servers(): - """Test MCPToolProvider with multiple MCP servers simultaneously.""" - # Create two separate MCP clients - client1 = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - client2 = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - # Create providers with different prefixes - provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, prefix="server1") - # Use correct tool name from echo_server.py - provider2 = MCPToolProvider( - client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, prefix="server2" - ) - - # Create agent with both providers - agent = Agent(tools=[provider1, provider2]) - - # Should have tools from both servers with different prefixes - assert "server1_echo" in agent.tool_names - assert "server2_echo_with_structured_content" in agent.tool_names - assert len(agent.tool_names) == 2 - - # Test tools from both servers work - result1 = agent.tool.server1_echo(to_echo="From Server 1") - assert "From Server 1" in str(result1) - - result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") - assert "From Server 2" in str(result2) - - agent.cleanup() - - -def test_mcp_tool_provider_server_startup_failure(): - """Test that MCPToolProvider handles server startup failure gracefully without hanging.""" - # Create client with invalid command that will fail to start - failing_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), - startup_timeout=2, # Short timeout to avoid hanging - ) - - provider = MCPToolProvider(client=failing_client) - - # Should raise ToolProviderException when trying to load tools - with pytest.raises(ToolProviderException, match="Failed to start MCP client"): - Agent(tools=[provider]) - - -def test_mcp_tool_provider_server_connection_timeout(): - """Test that MCPToolProvider times out gracefully when server hangs during startup.""" - # Create client that will hang during connection - hanging_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), # Sleep for 10 seconds - startup_timeout=1, # 1 second timeout - ) - - provider = MCPToolProvider(client=hanging_client) - - # Should raise ToolProviderException due to timeout - with pytest.raises(ToolProviderException, match="Failed to start MCP client"): - Agent(tools=[provider]) From 6d51176b303ee71c09c8ab97eb8ff5a12e79d867 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 8 Oct 2025 17:54:41 -0400 Subject: [PATCH 4/7] fix code coverage skip --- .codecov.yml | 3 +++ src/strands/tools/mcp/mcp_agent_tool.py | 6 +++--- src/strands/tools/mcp/mcp_client.py | 14 +++++++++----- .../tools/mcp/test_mcp_client_tool_provider.py | 6 +++--- 4 files changed, 18 insertions(+), 11 deletions(-) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..36be0c484 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,3 @@ +coverage: + ignore: + - "src/strands/experimental/tools/mcp/mcp_tool_provider.py" diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index 91ec6216a..af0c069a1 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,20 +28,20 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", agent_facing_tool_name: str | None = None) -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation - agent_facing_tool_name: Optional name to use for the agent tool (for disambiguation) + name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client - self._agent_tool_name = agent_facing_tool_name or mcp_tool.name + self._agent_tool_name = name_override or mcp_tool.name @property def tool_name(self) -> str: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index a780966cd..ec3705c02 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -24,7 +24,7 @@ from mcp.types import GetPromptResult, ListPromptsResult from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent -from typing_extensions import TypedDict +from typing_extensions import Protocol, TypedDict from ...experimental.tools import ToolProvider from ...types import PaginatedList @@ -39,8 +39,12 @@ T = TypeVar("T") -_ToolFilterCallback = Callable[[AgentTool], bool] -_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] + +class _ToolFilterCallback(Protocol): + def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... + + +_ToolFilterPattern = str | Pattern[str] | _ToolFilterCallback class ToolFilters(TypedDict, total=False): @@ -624,7 +628,7 @@ def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: # Create new tool with prefixed agent name but preserve original MCP name old_name = tool.tool_name new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" - new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, name_override=new_agent_name) logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) return new_tool @@ -634,7 +638,7 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPatter if callable(pattern): if pattern(tool): return True - elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): + elif isinstance(pattern, Pattern): if pattern.match(tool.tool_name): return True elif isinstance(pattern, str): diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 187cbeaa5..59a9eb81a 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -215,8 +215,8 @@ def short_names_only(tool) -> bool: @pytest.mark.asyncio -async def test_rejected_filter(mock_transport): - """Test rejected filter functionality.""" +async def test_rejected_filter_string_match(mock_transport): + """Test rejected filter with string matching.""" tool1 = create_mock_tool("good_tool") tool2 = create_mock_tool("bad_tool") @@ -256,7 +256,7 @@ async def test_prefix_renames_tools(mock_transport): # Should create new MCPAgentTool with prefixed name mock_agent_tool_class.assert_called_once_with( - original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" + original_tool.mcp_tool, original_tool.mcp_client, name_override="prefix_original_name" ) assert len(tools) == 1 From d0c323c8d90a811213559caff7225dfc84a0e3ab Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 13:58:30 -0400 Subject: [PATCH 5/7] comments --- .codecov.yml | 2 +- src/strands/agent/agent.py | 8 ++- .../experimental/tools/tool_provider.py | 4 +- src/strands/tools/mcp/mcp_client.py | 4 +- src/strands/tools/registry.py | 4 +- tests/strands/agent/test_agent.py | 56 ++++++++++--------- .../mcp/test_mcp_client_tool_provider.py | 18 +++--- .../tools/test_registry_tool_provider.py | 20 +++---- 8 files changed, 61 insertions(+), 55 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 36be0c484..866a0af3a 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,3 +1,3 @@ coverage: ignore: - - "src/strands/experimental/tools/mcp/mcp_tool_provider.py" + - "src/strands/experimental/tools/tool_provider.py" # This is an interface, cannot meaningfully cover diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e428ee0e5..c4cb83fb5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,6 +12,7 @@ import json import logging import random +import warnings from typing import ( Any, AsyncGenerator, @@ -558,10 +559,11 @@ def __del__(self) -> None: if self._cleanup_called or not self.tool_registry.tool_providers: return - logger.warning( - "agent_id=<%s> | Agent cleanup called via __del__. " + warnings.warn( + f"agent_id={self.agent_id} | Agent cleanup called via __del__. " "Consider calling agent.cleanup() explicitly for better resource management.", - self.agent_id, + ResourceWarning, + stacklevel=2, ) self.cleanup() except Exception as e: diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 5023dd72c..401555368 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -27,7 +27,7 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: pass @abstractmethod - async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def add_consumer(self, id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider. Args: @@ -37,7 +37,7 @@ async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: pass @abstractmethod - async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def remove_consumer(self, id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider. Args: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index ec3705c02..f10ee2d33 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -236,12 +236,12 @@ async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: return self._loaded_tools - async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def add_consumer(self, id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider.""" self._consumers.add(id) logger.debug("added provider consumer, count=%d", len(self._consumers)) - async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def remove_consumer(self, id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider.""" self._consumers.discard(id) logger.debug("removed provider consumer, count=%d", len(self._consumers)) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 52028ee32..39fd508d0 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -130,7 +130,7 @@ def add_tool(tool: Any) -> None: async def get_tools_and_register_consumer() -> Sequence[AgentTool]: provider_tools = await tool.load_tools() - await tool.add_provider_consumer(self._registry_id) + await tool.add_consumer(self._registry_id) return provider_tools provider_tools = run_async(get_tools_and_register_consumer) @@ -665,7 +665,7 @@ async def cleanup_async(self, **kwargs: Any) -> None: """Clean up all tool providers in this registry.""" for provider in self.tool_providers: try: - await provider.remove_provider_consumer(self._registry_id) + await provider.remove_consumer(self._registry_id) logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) except Exception as e: logger.warning( diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5e6cee810..37dee266c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +import warnings from uuid import uuid4 import pytest @@ -911,15 +912,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider remove_provider_consumer was called - mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify provider remove_consumer was called + mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -929,9 +930,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider1.remove_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -940,8 +941,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.remove_provider_consumer.assert_called_once() - mock_provider2.remove_provider_consumer.assert_called_once() + mock_provider1.remove_consumer.assert_called_once() + mock_provider2.remove_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -951,7 +952,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -960,8 +961,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() @pytest.mark.asyncio @@ -999,7 +1000,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1008,8 +1009,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1033,14 +1034,15 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): mock_provider = unittest.mock.MagicMock() agent.tool_registry.tool_providers = [mock_provider] - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") agent.__del__() - # Verify warning was logged - mock_logger.warning.assert_called_once() - warning_call = mock_logger.warning.call_args[0] - assert "Agent cleanup called via __del__" in warning_call[0] + # Verify warning was emitted + assert len(w) == 1 + assert issubclass(w[0].category, ResourceWarning) + assert "Agent cleanup called via __del__" in str(w[0].message) # Verify cleanup was called mock_cleanup.assert_called_once() @@ -1056,11 +1058,12 @@ def test_agent__del__no_warning_after_manual_cleanup(): with unittest.mock.patch.object(agent, "cleanup_async"): agent.cleanup() - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") agent.__del__() - # Verify no warning was logged - mock_logger.warning.assert_not_called() + # Verify no warning was emitted + assert len(w) == 0 def test_agent__del__no_warning_when_no_tool_providers(): @@ -1073,12 +1076,13 @@ def test_agent__del__no_warning_when_no_tool_providers(): # Ensure no tool providers agent.tool_registry.tool_providers = [] - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") agent.__del__() - # Verify no warning was logged and cleanup wasn't called - mock_logger.warning.assert_not_called() + # Verify no warning was emitted and cleanup wasn't called + assert len(w) == 0 mock_cleanup.assert_not_called() diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 59a9eb81a..094cc05b1 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -264,25 +264,25 @@ async def test_prefix_renames_tools(mock_transport): @pytest.mark.asyncio -async def test_add_provider_consumer(mock_transport): +async def test_add_consumer(mock_transport): """Test adding a provider consumer.""" client = MCPClient(mock_transport) - await client.add_provider_consumer("consumer1") + await client.add_consumer("consumer1") assert "consumer1" in client._consumers assert len(client._consumers) == 1 @pytest.mark.asyncio -async def test_remove_provider_consumer_without_cleanup(mock_transport): +async def test_remove_consumer_without_cleanup(mock_transport): """Test removing a provider consumer without triggering cleanup.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") client._consumers.add("consumer2") client._tool_provider_started = True - await client.remove_provider_consumer("consumer1") + await client.remove_consumer("consumer1") assert "consumer1" not in client._consumers assert "consumer2" in client._consumers @@ -290,7 +290,7 @@ async def test_remove_provider_consumer_without_cleanup(mock_transport): @pytest.mark.asyncio -async def test_remove_provider_consumer_with_cleanup(mock_transport): +async def test_remove_consumer_with_cleanup(mock_transport): """Test removing the last provider consumer triggers cleanup.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") @@ -298,7 +298,7 @@ async def test_remove_provider_consumer_with_cleanup(mock_transport): client._loaded_tools = [MagicMock()] with patch.object(client, "stop") as mock_stop: - await client.remove_provider_consumer("consumer1") + await client.remove_consumer("consumer1") assert len(client._consumers) == 0 assert client._tool_provider_started is False @@ -307,8 +307,8 @@ async def test_remove_provider_consumer_with_cleanup(mock_transport): @pytest.mark.asyncio -async def test_remove_provider_consumer_cleanup_failure(mock_transport): - """Test that remove_provider_consumer raises ToolProviderException when cleanup fails.""" +async def test_remove_consumer_cleanup_failure(mock_transport): + """Test that remove_consumer raises ToolProviderException when cleanup fails.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") client._tool_provider_started = True @@ -317,4 +317,4 @@ async def test_remove_provider_consumer_cleanup_failure(mock_transport): mock_stop.side_effect = Exception("Cleanup failed") with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): - await client.remove_provider_consumer("consumer1") + await client.remove_consumer("consumer1") diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index c9794326f..ca10862dc 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -29,11 +29,11 @@ async def cleanup(self): if self._cleanup_error: raise self._cleanup_error - async def add_provider_consumer(self, consumer_id): + async def add_consumer(self, consumer_id): self.add_consumer_called = True self.add_consumer_id = consumer_id - async def remove_provider_consumer(self, consumer_id): + async def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id @@ -214,7 +214,7 @@ def test_tool_provider_tracking_persistence(self): assert provider2 in registry.tool_providers def test_process_tools_provider_async_optimization(self): - """Test that load_tools and add_provider_consumer are called in same async context.""" + """Test that load_tools and add_consumer are called in same async context.""" mock_tool = MagicMock(spec=AgentTool) mock_tool.tool_name = "test_tool" @@ -228,11 +228,11 @@ async def load_tools(self): self.load_tools_called = True return [mock_tool] - async def add_provider_consumer(self, consumer_id): + async def add_consumer(self, consumer_id): self.add_consumer_called = True self.add_consumer_id = consumer_id - async def remove_provider_consumer(self, consumer_id): + async def remove_consumer(self, consumer_id): pass provider = TestProvider() @@ -252,7 +252,7 @@ async def remove_provider_consumer(self, consumer_id): @pytest.mark.asyncio async def test_registry_cleanup(self): - """Test that registry cleanup calls remove_provider_consumer on all providers.""" + """Test that registry cleanup calls remove_consumer on all providers.""" provider1 = MockToolProvider() provider2 = MockToolProvider() @@ -261,7 +261,7 @@ async def test_registry_cleanup(self): await registry.cleanup_async() - # Verify both providers had remove_provider_consumer called + # Verify both providers had remove_consumer called assert provider1.remove_consumer_called assert provider2.remove_consumer_called @@ -277,10 +277,10 @@ def __init__(self): async def load_tools(self): return [] - async def add_provider_consumer(self, consumer_id): + async def add_consumer(self, consumer_id): pass - async def remove_provider_consumer(self, consumer_id): + async def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id @@ -291,6 +291,6 @@ async def remove_provider_consumer(self, consumer_id): # Call cleanup await registry.cleanup_async() - # Verify remove_provider_consumer was called with correct ID + # Verify remove_consumer was called with correct ID assert provider.remove_consumer_called assert provider.remove_consumer_id == registry._registry_id From 8bc6dcb554f96f2a7536aa532a3933785371643b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 14:30:47 -0400 Subject: [PATCH 6/7] comments --- src/strands/agent/agent.py | 8 ++--- tests/strands/agent/test_agent.py | 56 ++++++++++++++----------------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index c4cb83fb5..e428ee0e5 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,7 +12,6 @@ import json import logging import random -import warnings from typing import ( Any, AsyncGenerator, @@ -559,11 +558,10 @@ def __del__(self) -> None: if self._cleanup_called or not self.tool_registry.tool_providers: return - warnings.warn( - f"agent_id={self.agent_id} | Agent cleanup called via __del__. " + logger.warning( + "agent_id=<%s> | Agent cleanup called via __del__. " "Consider calling agent.cleanup() explicitly for better resource management.", - ResourceWarning, - stacklevel=2, + self.agent_id, ) self.cleanup() except Exception as e: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 37dee266c..5e6cee810 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,7 +4,6 @@ import os import textwrap import unittest.mock -import warnings from uuid import uuid4 import pytest @@ -912,15 +911,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider remove_consumer was called - mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify provider remove_provider_consumer was called + mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -930,9 +929,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_consumer = unittest.mock.AsyncMock() + mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -941,8 +940,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.remove_consumer.assert_called_once() - mock_provider2.remove_consumer.assert_called_once() + mock_provider1.remove_provider_consumer.assert_called_once() + mock_provider2.remove_provider_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -952,7 +951,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -961,8 +960,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider remove_consumer was only called once due to idempotency - mock_provider.remove_consumer.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() @pytest.mark.asyncio @@ -1000,7 +999,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1009,8 +1008,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider remove_consumer was only called once due to idempotency - mock_provider.remove_consumer.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1034,15 +1033,14 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): mock_provider = unittest.mock.MagicMock() agent.tool_registry.tool_providers = [mock_provider] - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: agent.__del__() - # Verify warning was emitted - assert len(w) == 1 - assert issubclass(w[0].category, ResourceWarning) - assert "Agent cleanup called via __del__" in str(w[0].message) + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0] + assert "Agent cleanup called via __del__" in warning_call[0] # Verify cleanup was called mock_cleanup.assert_called_once() @@ -1058,12 +1056,11 @@ def test_agent__del__no_warning_after_manual_cleanup(): with unittest.mock.patch.object(agent, "cleanup_async"): agent.cleanup() - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: agent.__del__() - # Verify no warning was emitted - assert len(w) == 0 + # Verify no warning was logged + mock_logger.warning.assert_not_called() def test_agent__del__no_warning_when_no_tool_providers(): @@ -1076,13 +1073,12 @@ def test_agent__del__no_warning_when_no_tool_providers(): # Ensure no tool providers agent.tool_registry.tool_providers = [] - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: agent.__del__() - # Verify no warning was emitted and cleanup wasn't called - assert len(w) == 0 + # Verify no warning was logged and cleanup wasn't called + mock_logger.warning.assert_not_called() mock_cleanup.assert_not_called() From 1f29e5e7e0d4c9d26f9b3f8d9ea2b3256020e75e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 14:46:28 -0400 Subject: [PATCH 7/7] comments --- tests/strands/agent/test_agent.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5e6cee810..83ed4d324 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -911,15 +911,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider remove_provider_consumer was called - mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify provider remove_consumer was called + mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -929,9 +929,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider1.remove_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -940,8 +940,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.remove_provider_consumer.assert_called_once() - mock_provider2.remove_provider_consumer.assert_called_once() + mock_provider1.remove_consumer.assert_called_once() + mock_provider2.remove_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -951,7 +951,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -960,8 +960,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() @pytest.mark.asyncio @@ -999,7 +999,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1008,8 +1008,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent):