diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ac3cfeae5c..6e713f6a11 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -226,12 +226,21 @@ class ServerCapabilities: prompts: bool = False """Whether the server offers any prompt templates.""" + prompts_list_changed: bool = False + """Whether the server will emit notifications when the list of prompts changes.""" + resources: bool = False """Whether the server offers any resources to read.""" + resources_list_changed: bool = False + """Whether the server will emit notifications when the list of resources changes.""" + tools: bool = False """Whether the server offers any tools to call.""" + tools_list_changed: bool = False + """Whether the server will emit notifications when the list of tools changes.""" + completions: bool = False """Whether the server offers autocompletion suggestions for prompts and resources.""" @@ -244,12 +253,18 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC Args: mcp_capabilities: The MCP SDK ServerCapabilities object. 
""" + prompts_cap = mcp_capabilities.prompts + resources_cap = mcp_capabilities.resources + tools_cap = mcp_capabilities.tools return cls( experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None, logging=mcp_capabilities.logging is not None, - prompts=mcp_capabilities.prompts is not None, - resources=mcp_capabilities.resources is not None, - tools=mcp_capabilities.tools is not None, + prompts=prompts_cap is not None, + prompts_list_changed=bool(prompts_cap.listChanged) if prompts_cap else False, + resources=resources_cap is not None, + resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False, + tools=tools_cap is not None, + tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False, completions=mcp_capabilities.completions is not None, ) @@ -319,6 +334,26 @@ class MCPServer(AbstractToolset[Any], ABC): elicitation_callback: ElicitationFnT | None = None """Callback function to handle elicitation requests from the server.""" + cache_tools: bool + """Whether to cache the list of tools. + + When enabled (default), tools are fetched once and cached until either: + - The server sends a `notifications/tools/list_changed` notification + - The connection is closed + + Set to `False` for servers that change tools dynamically without sending notifications. + """ + + cache_resources: bool + """Whether to cache the list of resources. + + When enabled (default), resources are fetched once and cached until either: + - The server sends a `notifications/resources/list_changed` notification + - The connection is closed + + Set to `False` for servers that change resources dynamically without sending notifications. 
+ """ + _id: str | None _enter_lock: Lock = field(compare=False) @@ -332,6 +367,9 @@ class MCPServer(AbstractToolset[Any], ABC): _server_capabilities: ServerCapabilities _instructions: str | None + _cached_tools: list[mcp_types.Tool] | None + _cached_resources: list[Resource] | None + def __init__( self, tool_prefix: str | None = None, @@ -344,6 +382,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, *, id: str | None = None, ): @@ -357,6 +397,8 @@ def __init__( self.sampling_model = sampling_model self.max_retries = max_retries self.elicitation_callback = elicitation_callback + self.cache_tools = cache_tools + self.cache_resources = cache_resources self._id = id or tool_prefix @@ -366,6 +408,8 @@ def __post_init__(self): self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None + self._cached_tools = None + self._cached_resources = None @abstractmethod @asynccontextmanager @@ -430,13 +474,22 @@ def instructions(self) -> str | None: async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. - Note: - - We don't cache tools as they might change. - - We also don't subscribe to the server to avoid complexity. + Tools are cached by default, with cache invalidation on: + - `notifications/tools/list_changed` notifications from the server + - Connection close (cache is cleared in `__aexit__`) + + Set `cache_tools=False` for servers that change tools without sending notifications. 
""" - async with self: # Ensure server is running - result = await self._client.list_tools() - return result.tools + async with self: + if self.cache_tools: + if self._cached_tools is not None: + return self._cached_tools + result = await self._client.list_tools() + self._cached_tools = result.tools + return result.tools + else: + result = await self._client.list_tools() + return result.tools async def direct_call_tool( self, @@ -542,21 +595,31 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: async def list_resources(self) -> list[Resource]: """Retrieve resources that are currently present on the server. - Note: - - We don't cache resources as they might change. - - We also don't subscribe to resource changes to avoid complexity. + Resources are cached by default, with cache invalidation on: + - `notifications/resources/list_changed` notifications from the server + - Connection close (cache is cleared in `__aexit__`) + + Set `cache_resources=False` for servers that change resources without sending notifications. Raises: MCPError: If the server returns an error. """ - async with self: # Ensure server is running + async with self: if not self.capabilities.resources: return [] try: - result = await self._client.list_resources() + if self.cache_resources: + if self._cached_resources is not None: + return self._cached_resources + result = await self._client.list_resources() + resources = [Resource.from_mcp_sdk(r) for r in result.resources] + self._cached_resources = resources + return resources + else: + result = await self._client.list_resources() + return [Resource.from_mcp_sdk(r) for r in result.resources] except mcp_exceptions.McpError as e: raise MCPError.from_mcp_sdk(e) from e - return [Resource.from_mcp_sdk(r) for r in result.resources] async def list_resource_templates(self) -> list[ResourceTemplate]: """Retrieve resource templates that are currently present on the server. 
@@ -628,6 +691,7 @@ async def __aenter__(self) -> Self: elicitation_callback=self.elicitation_callback, logging_callback=self.log_handler, read_timeout_seconds=timedelta(seconds=self.read_timeout), + message_handler=self._handle_notification, ) self._client = await exit_stack.enter_async_context(client) @@ -651,6 +715,8 @@ async def __aexit__(self, *args: Any) -> bool | None: if self._running_count == 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None + self._cached_tools = None + self._cached_resources = None @property def is_running(self) -> bool: @@ -680,6 +746,13 @@ async def _sampling_callback( model=self.sampling_model.model_name, ) + async def _handle_notification(self, message: Any) -> None: + """Invalidate caches on list-changed notifications (unwrapping the SDK's `ServerNotification` root model).""" + if isinstance(getattr(message, 'root', message), mcp_types.ToolListChangedNotification): + self._cached_tools = None + elif isinstance(getattr(message, 'root', message), mcp_types.ResourceListChangedNotification): + self._cached_resources = None + async def _map_tool_result_part( self, part: mcp_types.ContentBlock ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: @@ -776,6 +849,8 @@ class MCPServerStdio(MCPServer): sampling_model: models.Model | None max_retries: int elicitation_callback: ElicitationFnT | None = None + cache_tools: bool + cache_resources: bool def __init__( self, @@ -794,6 +869,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, id: str | None = None, ): """Build a new MCP server. @@ -813,6 +890,10 @@ def __init__( sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. + cache_tools: Whether to cache the list of tools. 
+ See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools]. + cache_resources: Whether to cache the list of resources. + See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources]. id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. """ self.command = command @@ -831,6 +912,8 @@ def __init__( sampling_model, max_retries, elicitation_callback, + cache_tools, + cache_resources, id=id, ) @@ -930,6 +1013,8 @@ class _MCPServerHTTP(MCPServer): sampling_model: models.Model | None max_retries: int elicitation_callback: ElicitationFnT | None = None + cache_tools: bool + cache_resources: bool def __init__( self, @@ -948,6 +1033,8 @@ def __init__( sampling_model: models.Model | None = None, max_retries: int = 1, elicitation_callback: ElicitationFnT | None = None, + cache_tools: bool = True, + cache_resources: bool = True, **_deprecated_kwargs: Any, ): """Build a new MCP server. @@ -967,6 +1054,10 @@ def __init__( sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. elicitation_callback: Callback function to handle elicitation requests from the server. + cache_tools: Whether to cache the list of tools. + See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools]. + cache_resources: Whether to cache the list of resources. + See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources]. 
""" if 'sse_read_timeout' in _deprecated_kwargs: if read_timeout is not None: @@ -997,6 +1088,8 @@ def __init__( sampling_model, max_retries, elicitation_callback, + cache_tools, + cache_resources, id=id, ) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 221ad37548..1130c54186 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -2007,3 +2007,138 @@ async def test_custom_http_client_not_closed(): assert len(tools) > 0 assert not custom_http_client.is_closed + + +# ============================================================================ +# Tool and Resource Caching Tests +# ============================================================================ + + +async def test_tools_caching_enabled_by_default() -> None: + """Test that list_tools() caches results by default.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + # First call - should fetch from server and cache + tools1 = await server.list_tools() + assert len(tools1) > 0 + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + # Second call - should return cached value (cache is still populated) + tools2 = await server.list_tools() + assert tools2 == tools1 + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + +async def test_tools_no_caching_when_disabled() -> None: + """Test that list_tools() does not cache when cache_tools=False.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_tools=False) + async with server: + # First call - should not populate cache + tools1 = await server.list_tools() + assert len(tools1) > 0 + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + + # Second call - cache should still be None + tools2 = await server.list_tools() + assert tools2 == tools1 + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + + +async def test_tools_cache_invalidation_on_notification() -> None: + """Test that tools cache is 
invalidated when ToolListChangedNotification is received.""" + from mcp.types import ToolListChangedNotification + + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + # Populate cache + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + # Simulate receiving a tool list changed notification + notification = ToolListChangedNotification() + await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + + # Cache should be invalidated (set to None) + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + + +async def test_resources_caching_enabled_by_default() -> None: + """Test that list_resources() caches results by default.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + assert server.capabilities.resources + + # First call - should fetch from server and cache + resources1 = await server.list_resources() + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + + # Second call - should return cached value (cache is still populated) + resources2 = await server.list_resources() + assert resources2 == resources1 + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + + +async def test_resources_no_caching_when_disabled() -> None: + """Test that list_resources() does not cache when cache_resources=False.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], cache_resources=False) + async with server: + assert server.capabilities.resources + + # First call - should not populate cache + resources1 = await server.list_resources() + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + # Second call - cache should still be None + resources2 = await server.list_resources() + assert resources2 == resources1 + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + +async def 
test_resources_cache_invalidation_on_notification() -> None: + """Test that resources cache is invalidated when ResourceListChangedNotification is received.""" + from mcp.types import ResourceListChangedNotification + + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + assert server.capabilities.resources + + # Populate cache + await server.list_resources() + assert server._cached_resources is not None # pyright: ignore[reportPrivateUsage] + + # Simulate receiving a resource list changed notification + notification = ResourceListChangedNotification() + await server._handle_notification(notification) # pyright: ignore[reportPrivateUsage] + + # Cache should be invalidated + assert server._cached_resources is None # pyright: ignore[reportPrivateUsage] + + +async def test_cache_cleared_on_connection_close() -> None: + """Test that caches are cleared when the connection is closed.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + # First connection + async with server: + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + # After exiting, cache should be cleared by __aexit__ + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + + # Reconnect and verify cache starts empty + async with server: + assert server._cached_tools is None # pyright: ignore[reportPrivateUsage] + # Fetch again to populate + await server.list_tools() + assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage] + + +async def test_server_capabilities_list_changed_fields() -> None: + """Test that ServerCapabilities correctly parses listChanged fields.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + caps = server.capabilities + assert isinstance(caps.prompts_list_changed, bool) + assert isinstance(caps.tools_list_changed, bool) + assert isinstance(caps.resources_list_changed, bool)