Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 108 additions & 15 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,21 @@ class ServerCapabilities:
prompts: bool = False
"""Whether the server offers any prompt templates."""

prompts_list_changed: bool = False
"""Whether the server will emit notifications when the list of prompts changes."""

resources: bool = False
"""Whether the server offers any resources to read."""

resources_list_changed: bool = False
"""Whether the server will emit notifications when the list of resources changes."""

tools: bool = False
"""Whether the server offers any tools to call."""

tools_list_changed: bool = False
"""Whether the server will emit notifications when the list of tools changes."""

completions: bool = False
"""Whether the server offers autocompletion suggestions for prompts and resources."""

Expand All @@ -244,12 +253,18 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC
Args:
mcp_capabilities: The MCP SDK ServerCapabilities object.
"""
prompts_cap = mcp_capabilities.prompts
resources_cap = mcp_capabilities.resources
tools_cap = mcp_capabilities.tools
return cls(
experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None,
logging=mcp_capabilities.logging is not None,
prompts=mcp_capabilities.prompts is not None,
resources=mcp_capabilities.resources is not None,
tools=mcp_capabilities.tools is not None,
prompts=prompts_cap is not None,
prompts_list_changed=bool(prompts_cap.listChanged) if prompts_cap else False,
resources=resources_cap is not None,
resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False,
tools=tools_cap is not None,
tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False,
completions=mcp_capabilities.completions is not None,
)

Expand Down Expand Up @@ -319,6 +334,26 @@ class MCPServer(AbstractToolset[Any], ABC):
elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

cache_tools: bool
"""Whether to cache the list of tools.

When enabled (default), tools are fetched once and cached until either:
- The server sends a `notifications/tools/list_changed` notification
- The connection is closed

Set to `False` for servers that change tools dynamically without sending notifications.
"""

cache_resources: bool
"""Whether to cache the list of resources.

When enabled (default), resources are fetched once and cached until either:
- The server sends a `notifications/resources/list_changed` notification
- The connection is closed

Set to `False` for servers that change resources dynamically without sending notifications.
"""

_id: str | None

_enter_lock: Lock = field(compare=False)
Expand All @@ -332,6 +367,9 @@ class MCPServer(AbstractToolset[Any], ABC):
_server_capabilities: ServerCapabilities
_instructions: str | None

_cached_tools: list[mcp_types.Tool] | None
_cached_resources: list[Resource] | None

def __init__(
self,
tool_prefix: str | None = None,
Expand All @@ -344,6 +382,8 @@ def __init__(
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
cache_tools: bool = True,
cache_resources: bool = True,
*,
id: str | None = None,
):
Expand All @@ -357,6 +397,8 @@ def __init__(
self.sampling_model = sampling_model
self.max_retries = max_retries
self.elicitation_callback = elicitation_callback
self.cache_tools = cache_tools
self.cache_resources = cache_resources

self._id = id or tool_prefix

Expand All @@ -366,6 +408,8 @@ def __post_init__(self):
self._enter_lock = Lock()
self._running_count = 0
self._exit_stack = None
self._cached_tools = None
self._cached_resources = None

@abstractmethod
@asynccontextmanager
Expand Down Expand Up @@ -430,13 +474,22 @@ def instructions(self) -> str | None:
async def list_tools(self) -> list[mcp_types.Tool]:
"""Retrieve tools that are currently active on the server.

Note:
- We don't cache tools as they might change.
- We also don't subscribe to the server to avoid complexity.
Tools are cached by default, with cache invalidation on:
- `notifications/tools/list_changed` notifications from the server
- Connection close (cache is cleared in `__aexit__`)

Set `cache_tools=False` for servers that change tools without sending notifications.
"""
async with self: # Ensure server is running
result = await self._client.list_tools()
return result.tools
async with self:
if self.cache_tools:
if self._cached_tools is not None:
return self._cached_tools
result = await self._client.list_tools()
self._cached_tools = result.tools
return result.tools
else:
result = await self._client.list_tools()
return result.tools

async def direct_call_tool(
self,
Expand Down Expand Up @@ -542,21 +595,31 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]:
async def list_resources(self) -> list[Resource]:
"""Retrieve resources that are currently present on the server.

Note:
- We don't cache resources as they might change.
- We also don't subscribe to resource changes to avoid complexity.
Resources are cached by default, with cache invalidation on:
- `notifications/resources/list_changed` notifications from the server
- Connection close (cache is cleared in `__aexit__`)

Set `cache_resources=False` for servers that change resources without sending notifications.

Raises:
MCPError: If the server returns an error.
"""
async with self: # Ensure server is running
async with self:
if not self.capabilities.resources:
return []
try:
result = await self._client.list_resources()
if self.cache_resources:
if self._cached_resources is not None:
return self._cached_resources
result = await self._client.list_resources()
resources = [Resource.from_mcp_sdk(r) for r in result.resources]
self._cached_resources = resources
return resources
else:
result = await self._client.list_resources()
return [Resource.from_mcp_sdk(r) for r in result.resources]
except mcp_exceptions.McpError as e:
raise MCPError.from_mcp_sdk(e) from e
return [Resource.from_mcp_sdk(r) for r in result.resources]

async def list_resource_templates(self) -> list[ResourceTemplate]:
"""Retrieve resource templates that are currently present on the server.
Expand Down Expand Up @@ -628,6 +691,7 @@ async def __aenter__(self) -> Self:
elicitation_callback=self.elicitation_callback,
logging_callback=self.log_handler,
read_timeout_seconds=timedelta(seconds=self.read_timeout),
message_handler=self._handle_notification,
)
self._client = await exit_stack.enter_async_context(client)

Expand All @@ -651,6 +715,8 @@ async def __aexit__(self, *args: Any) -> bool | None:
if self._running_count == 0 and self._exit_stack is not None:
await self._exit_stack.aclose()
self._exit_stack = None
self._cached_tools = None
self._cached_resources = None

@property
def is_running(self) -> bool:
Expand Down Expand Up @@ -680,6 +746,13 @@ async def _sampling_callback(
model=self.sampling_model.model_name,
)

async def _handle_notification(self, message: Any) -> None:
"""Handle notifications from the MCP server, invalidating caches as needed."""
if isinstance(message, mcp_types.ToolListChangedNotification):
self._cached_tools = None
elif isinstance(message, mcp_types.ResourceListChangedNotification):
self._cached_resources = None

async def _map_tool_result_part(
self, part: mcp_types.ContentBlock
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
Expand Down Expand Up @@ -776,6 +849,8 @@ class MCPServerStdio(MCPServer):
sampling_model: models.Model | None
max_retries: int
elicitation_callback: ElicitationFnT | None = None
cache_tools: bool
cache_resources: bool

def __init__(
self,
Expand All @@ -794,6 +869,8 @@ def __init__(
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
cache_tools: bool = True,
cache_resources: bool = True,
id: str | None = None,
):
"""Build a new MCP server.
Expand All @@ -813,6 +890,10 @@ def __init__(
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
elicitation_callback: Callback function to handle elicitation requests from the server.
cache_tools: Whether to cache the list of tools.
See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools].
cache_resources: Whether to cache the list of resources.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's link to the full description as it contains very crucial information about when to use this and its behavior. You can use the [...][pydantic_ai.mcp.MCPServer.cache_tools] format

See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources].
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
"""
self.command = command
Expand All @@ -831,6 +912,8 @@ def __init__(
sampling_model,
max_retries,
elicitation_callback,
cache_tools,
cache_resources,
id=id,
)

Expand Down Expand Up @@ -930,6 +1013,8 @@ class _MCPServerHTTP(MCPServer):
sampling_model: models.Model | None
max_retries: int
elicitation_callback: ElicitationFnT | None = None
cache_tools: bool
cache_resources: bool

def __init__(
self,
Expand All @@ -948,6 +1033,8 @@ def __init__(
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
cache_tools: bool = True,
cache_resources: bool = True,
**_deprecated_kwargs: Any,
):
"""Build a new MCP server.
Expand All @@ -967,6 +1054,10 @@ def __init__(
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
elicitation_callback: Callback function to handle elicitation requests from the server.
cache_tools: Whether to cache the list of tools.
See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools].
cache_resources: Whether to cache the list of resources.
See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources].
"""
if 'sse_read_timeout' in _deprecated_kwargs:
if read_timeout is not None:
Expand Down Expand Up @@ -997,6 +1088,8 @@ def __init__(
sampling_model,
max_retries,
elicitation_callback,
cache_tools,
cache_resources,
id=id,
)

Expand Down
Loading