
Commit 444f5d0

Kludex, dmontagu, samuelcolvin, and Viicos authored
Add support for MCP servers (#1100)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
Co-authored-by: Victorien <65306057+Viicos@users.noreply.github.com>
1 parent 1479995 commit 444f5d0

16 files changed: +827 -14 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ jobs:

       # this must run last as it modifies the environment!
       - name: test lowest versions
+        if: matrix.python-version != '3.9'
         run: |
           unset UV_FROZEN
           uv run --all-extras --resolution lowest-direct coverage run -m pytest

docs/api/mcp.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+::: pydantic_ai.mcp

docs/mcp_servers.md

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+# MCP Servers
+
+**PydanticAI** supports integration with
+[MCP (Model Context Protocol) Servers](https://modelcontextprotocol.io/introduction),
+allowing you to extend agent capabilities through external services. This integration enables
+dynamic tool discovery.
+
+## Install
+
+To use MCP servers, you need to either install [`pydantic-ai`](install.md), or install
+[`pydantic-ai-slim`](install.md#slim-install) with the `mcp` optional group:
+
+```bash
+pip/uv-add 'pydantic-ai-slim[mcp]'
+```
+
+!!! note
+    MCP integration requires Python 3.10 or higher.
+
+## Running MCP Servers
+
+Before diving into how to use MCP servers with PydanticAI, let's look at how to run MCP servers
+with different transports.
+
+To run MCP servers, you'll need to install the MCP CLI package:
+
+```bash
+pip/uv-add 'mcp[cli]'
+```
+
+Here's a simple MCP server that provides a temperature conversion tool. We will later assume this is the server we connect to from our agent:
+
+```python {title="temperature_mcp_server.py" py="3.10"}
+from mcp.server.fastmcp import FastMCP
+
+mcp = FastMCP('Temperature Conversion Server')
+
+
+@mcp.tool()
+async def celsius_to_fahrenheit(celsius: float) -> float:
+    """Convert Celsius to Fahrenheit.
+
+    Args:
+        celsius: Temperature in Celsius
+    """
+    return (celsius * 9 / 5) + 32
+
+
+if __name__ == '__main__':
+    mcp.run('stdio')  # (1)!
+```
+
+1. Run with stdio transport (for subprocess communication).
+
+The same server can be run with [SSE transport](https://modelcontextprotocol.io/docs/concepts/transports#server-sent-events-sse) for HTTP communication:
+
+```python {title="temperature_mcp_server_sse.py" py="3.10"}
+from temperature_mcp_server import mcp
+
+if __name__ == '__main__':
+    mcp.settings.port = 8000  # `FastMCP.run()` takes no `port` argument
+    mcp.run('sse')  # (1)!
+```
+
+1. Run with SSE transport on port 8000.
+
+## Usage
+
+PydanticAI comes with two ways to connect to MCP servers:
+
+- [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] which connects to an MCP server using the [HTTP SSE](https://modelcontextprotocol.io/docs/concepts/transports#server-sent-events-sse) transport
+- [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] which runs the server as a subprocess and connects to it using the [stdio](https://modelcontextprotocol.io/docs/concepts/transports#standard-input%2Foutput-stdio) transport
+
+Examples of both are shown below.
+
+### MCP Remote Server
+
+You can have an MCP server running on a remote server. In this case, you'd use the
+[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] class:
+
+```python {title="mcp_remote_server.py" py="3.10"}
+from pydantic_ai import Agent
+from pydantic_ai.mcp import MCPServerSSE
+
+server = MCPServerSSE(url='http://localhost:8005/sse')
+agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+
+async def main():
+    async with agent.run_mcp_servers():
+        result = await agent.run('Can you convert 30 degrees celsius to fahrenheit?')
+    print(result.data)
+    #> 30 degrees Celsius is equal to 86 degrees Fahrenheit.
+```
+
+This will connect to the MCP server at the given URL and use the
+[SSE transport](https://modelcontextprotocol.io/docs/concepts/transports#server-sent-events-sse).
+
+### MCP Subprocess Server
+
+You can also run the MCP server as a subprocess and communicate with it over stdio.
+In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class.
+With `MCPServerStdio`, you must enter the [`run_mcp_servers`][pydantic_ai.Agent.run_mcp_servers]
+context manager before running the agent, because that is what starts the subprocess.
+
+```python {title="mcp_subprocess_server.py" py="3.10"}
+from pydantic_ai.agent import Agent
+from pydantic_ai.mcp import MCPServerStdio
+
+server = MCPServerStdio('python', ['-m', 'pydantic_ai_examples.mcp_server'])
+agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+
+async def main():
+    async with agent.run_mcp_servers():
+        result = await agent.run('Can you convert 30 degrees celsius to fahrenheit?')
+    print(result.data)
+    #> 30 degrees Celsius is equal to 86 degrees Fahrenheit.
+```
+
+This will start the MCP server in a separate process and connect to it using the
+[stdio transport](https://modelcontextprotocol.io/docs/concepts/transports#standard-input%2Foutput-stdio).
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+"""Simple MCP Server that can be used to test the MCP protocol.
+
+Run with:
+
+    uv run -m pydantic_ai_examples.mcp_server --transport <TRANSPORT>
+
+TRANSPORT can be either `sse` or `stdio`.
+"""
+
+import argparse
+
+from mcp.server.fastmcp import FastMCP
+
+mcp = FastMCP('PydanticAI MCP Server', port=8005)
+
+
+@mcp.tool()
+async def celsius_to_fahrenheit(celsius: float) -> float:
+    """Convert Celsius to Fahrenheit.
+
+    Args:
+        celsius: Temperature in Celsius
+
+    Returns:
+        Temperature in Fahrenheit
+    """
+    return (celsius * 9 / 5) + 32
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--transport', type=str, default='stdio', choices=('sse', 'stdio')
+    )
+    args = parser.parse_args()
+
+    mcp.run(transport=args.transport)
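
The tool body is plain arithmetic, so it can be checked without starting a server. The sketch below copies the formula into an undecorated function (an assumption made for illustration; the example file only defines the decorated coroutine) and evaluates a few well-known reference points, including the 30 °C input used in the docs.

```python
def celsius_to_fahrenheit(celsius: float) -> float:
    """Same arithmetic as the MCP tool, without the `@mcp.tool()` decorator."""
    return (celsius * 9 / 5) + 32


# Reference points: -40 is where both scales meet, 0 and 100 are water's
# freezing and boiling points, and 30 is the value used in the docs example.
samples = {c: celsius_to_fahrenheit(c) for c in (-40, 0, 30, 100)}
print(samples)
```

The 30 °C entry evaluates to 86.0, which matches the answer shown in the documentation examples.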

examples/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ dependencies = [
     "uvicorn>=0.32.0",
     "devtools>=0.12.2",
     "gradio>=5.9.0; python_version>'3.9'",
+    "mcp[cli]>=1.4.1; python_version >= '3.10'"
 ]

 [tool.hatch.build.targets.wheel]

mkdocs.yml

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@ nav:
       - multi-agent-applications.md
       - graph.md
       - input.md
+      - mcp_servers.md
       - cli.md
   - Examples:
       - examples/index.md
@@ -64,6 +65,7 @@ nav:
       - api/models/function.md
       - api/models/fallback.md
       - api/providers.md
+      - api/mcp.md
       - api/pydantic_graph/graph.md
       - api/pydantic_graph/nodes.md
       - api/pydantic_graph/persistence.md

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 61 additions & 7 deletions
@@ -7,7 +7,7 @@
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import Any, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

 from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -27,11 +27,10 @@
 from .models.instrumented import InstrumentedModel
 from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
-from .tools import (
-    RunContext,
-    Tool,
-    ToolDefinition,
-)
+from .tools import RunContext, Tool, ToolDefinition
+
+if TYPE_CHECKING:
+    from .mcp import MCPServer

 __all__ = (
     'GraphAgentState',
@@ -94,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]

     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
+    mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)

     run_span: Span
     tracer: Tracer
@@ -219,7 +219,17 @@ async def add_tool(tool: Tool[DepsT]) -> None:
         if tool_def := await tool.prepare_tool_def(ctx):
             function_tool_defs.append(tool_def)

-    await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
+    async def add_mcp_server_tools(server: MCPServer) -> None:
+        if not server.is_running:
+            raise exceptions.UserError(f'MCP server is not running: {server}')
+        tool_defs = await server.list_tools()
+        # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error.
+        function_tool_defs.extend(tool_defs)
+
+    await asyncio.gather(
+        *map(add_tool, ctx.deps.function_tools.values()),
+        *map(add_mcp_server_tools, ctx.deps.mcp_servers),
+    )

     result_schema = ctx.deps.result_schema
     return models.ModelRequestParameters(
@@ -594,6 +604,21 @@ async def process_function_tools(
                 yield event
                 call_index_to_event_id[len(calls_to_run)] = event.call_id
                 calls_to_run.append((tool, call))
+        elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
+            if stub_function_tools:
+                # TODO(Marcelo): We should add coverage for this part of the code.
+                output_parts.append(  # pragma: no cover
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Tool not executed - a final result was already processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
+                )
+            else:
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((mcp_tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
@@ -641,6 +666,35 @@ async def process_function_tools(
             output_parts.append(results_by_index[k])


+async def _tool_from_mcp_server(
+    tool_name: str,
+    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+) -> Tool[DepsT] | None:
+    """Call each MCP server to find the tool with the given name.
+
+    Args:
+        tool_name: The name of the tool to find.
+        ctx: The current run context.
+
+    Returns:
+        The tool with the given name, or `None` if no tool with the given name is found.
+    """
+
+    async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
+        # There's no normal situation where the server will not be running at this point, we check just in case
+        # some weird edge case occurs.
+        if not server.is_running:  # pragma: no cover
+            raise exceptions.UserError(f'MCP server is not running: {server}')
+        result = await server.call_tool(tool_name, args)
+        return result
+
+    for server in ctx.deps.mcp_servers:
+        tools = await server.list_tools()
+        if tool_name in {tool.name for tool in tools}:
+            return Tool(name=tool_name, function=run_tool, takes_ctx=True)
+    return None
+
+
 def _unknown_tool(
     tool_name: str,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
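
The lookup in `_tool_from_mcp_server` asks each server in turn for its tool list and binds the call to the first server that advertises the requested name. The sketch below reproduces that control flow with a stand-in server class so it runs without the `mcp` package. `FakeServer`, `find_server_for_tool`, and `duplicate_tool_names` are hypothetical names used only for illustration and are not part of PydanticAI. The last helper shows one possible approach to the uniqueness TODO, assuming duplicates should be reported rather than silently resolved.

```python
from __future__ import annotations

import asyncio
from collections import Counter
from dataclasses import dataclass


@dataclass
class FakeServer:
    """Hypothetical stand-in for an MCP server: a name plus the tools it advertises."""

    name: str
    tools: dict  # maps tool name -> plain callable

    async def list_tools(self) -> list[str]:
        return list(self.tools)

    async def call_tool(self, tool_name: str, args: dict):
        return self.tools[tool_name](**args)


async def find_server_for_tool(tool_name: str, servers: list[FakeServer]) -> FakeServer | None:
    """Return the first server that lists `tool_name`, mirroring the loop in the diff."""
    for server in servers:
        if tool_name in set(await server.list_tools()):
            return server
    return None


async def duplicate_tool_names(servers: list[FakeServer]) -> list[str]:
    """Names advertised by more than one server (one way to approach the TODO)."""
    listings = await asyncio.gather(*(server.list_tools() for server in servers))
    counts = Counter(name for listing in listings for name in listing)
    return sorted(name for name, count in counts.items() if count > 1)


async def demo():
    temperature = FakeServer('temperature', {'celsius_to_fahrenheit': lambda celsius: celsius * 9 / 5 + 32})
    other = FakeServer('other', {'celsius_to_fahrenheit': lambda celsius: -1.0, 'echo': lambda text: text})
    servers = [temperature, other]

    # The first server in the list wins when two servers share a tool name.
    server = await find_server_for_tool('celsius_to_fahrenheit', servers)
    result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 30})
    missing = await find_server_for_tool('no_such_tool', servers)
    duplicates = await duplicate_tool_names(servers)
    return server.name, result, missing, duplicates


outcome = asyncio.run(demo())
print(outcome)
```

Because the loop returns on the first match, the order of `mcp_servers` decides which server handles a shared name. That is why a duplicate check like the one sketched here may be worth running before the model ever sees the tool list.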

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 25 additions & 2 deletions
@@ -3,10 +3,10 @@
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
 from copy import deepcopy
 from types import FrameType
-from typing import Any, Callable, ClassVar, Generic, cast, final, overload
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload

 from opentelemetry.trace import NoOpTracer, use_span
 from typing_extensions import TypeGuard, TypeVar, deprecated
@@ -47,6 +47,9 @@
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode

+if TYPE_CHECKING:
+    from pydantic_ai.mcp import MCPServer
+
 __all__ = (
     'Agent',
     'AgentRun',
@@ -129,6 +132,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         repr=False
     )
     _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
+    _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
@@ -148,6 +152,7 @@ def __init__(
         result_tool_description: str | None = None,
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
+        mcp_servers: Sequence[MCPServer] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
@@ -173,6 +178,8 @@ def __init__(
             result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
+            mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
+                for each server you want the agent to connect to.
             defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
                 it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
                 which checks for the necessary environment variables. Set this to `false`
@@ -215,6 +222,7 @@ def __init__(

         self._default_retries = retries
         self._max_result_retries = result_retries if result_retries is not None else retries
+        self._mcp_servers = mcp_servers
         for tool in tools:
             if isinstance(tool, Tool):
                 self._register_tool(tool)
@@ -461,6 +469,7 @@ async def main():
             result_tools=self._result_schema.tool_defs() if self._result_schema else [],
             result_validators=result_validators,
             function_tools=self._function_tools,
+            mcp_servers=self._mcp_servers,
             run_span=run_span,
             tracer=tracer,
         )
@@ -1253,6 +1262,20 @@ def is_end_node(
         """
         return isinstance(node, End)

+    @asynccontextmanager
+    async def run_mcp_servers(self) -> AsyncIterator[None]:
+        """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
+
+        Returns: a context manager to start and shutdown the servers.
+        """
+        exit_stack = AsyncExitStack()
+        try:
+            for mcp_server in self._mcp_servers:
+                await exit_stack.enter_async_context(mcp_server)
+            yield
+        finally:
+            await exit_stack.aclose()
+

 @dataclasses.dataclass(repr=False)
 class AgentRun(Generic[AgentDepsT, ResultDataT]):
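
`run_mcp_servers` relies on `contextlib.AsyncExitStack` to enter a variable number of async context managers and close them all afterwards. The sketch below isolates that pattern with a stand-in server class that only records start and stop events. `FakeServer` and `run_servers` are hypothetical names for illustration, and the real `MCPServer` classes do considerably more on enter and exit.

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


class FakeServer:
    """Hypothetical stand-in for an MCP server that records start/stop events."""

    def __init__(self, name: str, log: list) -> None:
        self.name = name
        self.log = log
        self.is_running = False

    async def __aenter__(self) -> 'FakeServer':
        self.is_running = True
        self.log.append(f'start {self.name}')
        return self

    async def __aexit__(self, *exc_info) -> None:
        self.is_running = False
        self.log.append(f'stop {self.name}')


@asynccontextmanager
async def run_servers(servers):
    """Enter every server, yield, then close them all, as `run_mcp_servers` does."""
    exit_stack = AsyncExitStack()
    try:
        for server in servers:
            await exit_stack.enter_async_context(server)
        yield
    finally:
        # Closes whatever was entered so far, in reverse order of entry.
        await exit_stack.aclose()


async def demo() -> list:
    log: list = []
    servers = [FakeServer('a', log), FakeServer('b', log)]
    async with run_servers(servers):
        log.append(f'running: {all(s.is_running for s in servers)}')
    log.append(f'running: {any(s.is_running for s in servers)}')
    return log


events = asyncio.run(demo())
print(events)
```

The stack unwinds in last-in, first-out order, so servers stop in the reverse of the order they started. Because `aclose()` sits in a `finally` block, a server that fails to start still leaves the previously started ones cleaned up.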
