Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import dataclasses
import hashlib
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
Expand Down Expand Up @@ -92,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):

function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
default_retries: int

tracer: Tracer

Expand Down Expand Up @@ -546,6 +548,13 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
)


def multi_modal_content_identifier(identifier: str | bytes) -> str:
    """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses.

    The identifier is the first 6 hex digits of the SHA-1 of the input, so the
    same file data or URL always maps to the same short tag.
    """
    raw = identifier.encode('utf-8') if isinstance(identifier, str) else identifier
    return hashlib.sha1(raw).hexdigest()[:6]


async def process_function_tools( # noqa C901
tool_calls: list[_messages.ToolCallPart],
output_tool_name: str | None,
Expand Down Expand Up @@ -648,8 +657,6 @@ async def process_function_tools( # noqa C901
for tool, call in calls_to_run
]

file_index = 1

pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
Expand All @@ -661,17 +668,38 @@ async def process_function_tools( # noqa C901
if isinstance(result, _messages.RetryPromptPart):
results_by_index[index] = result
elif isinstance(result, _messages.ToolReturnPart):
if isinstance(result.content, _messages.MultiModalContentTypes):
user_parts.append(
_messages.UserPromptPart(
content=[f'This is file {file_index}:', result.content],
timestamp=result.timestamp,
part_kind='user-prompt',
contents: list[Any]
single_content: bool
if isinstance(result.content, list):
contents = result.content # type: ignore
single_content = False
else:
contents = [result.content]
single_content = True

processed_contents: list[Any] = []
for content in contents:
if isinstance(content, _messages.MultiModalContentTypes):
if isinstance(content, _messages.BinaryContent):
identifier = multi_modal_content_identifier(content.data)
else:
identifier = multi_modal_content_identifier(content.url)

user_parts.append(
_messages.UserPromptPart(
content=[f'This is file {identifier}:', content],
timestamp=result.timestamp,
part_kind='user-prompt',
)
)
)
processed_contents.append(f'See file {identifier}')
else:
processed_contents.append(content)

result.content = f'See file {file_index}.'
file_index += 1
if single_content:
result.content = processed_contents[0]
else:
result.content = processed_contents

results_by_index[index] = result
else:
Expand Down Expand Up @@ -710,7 +738,7 @@ async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
for server in ctx.deps.mcp_servers:
tools = await server.list_tools()
if tool_name in {tool.name for tool in tools}:
return Tool(name=tool_name, function=run_tool, takes_ctx=True)
return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries)
return None


Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
output_validators=output_validators,
function_tools=self._function_tools,
mcp_servers=self._mcp_servers,
default_retries=self._default_retries,
tracer=tracer,
get_instructions=get_instructions,
)
Expand Down
67 changes: 61 additions & 6 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import base64
import json
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
Expand All @@ -9,16 +11,25 @@
from typing import Any

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.types import JSONRPCMessage, LoggingLevel
from typing_extensions import Self

from mcp.types import (
BlobResourceContents,
EmbeddedResource,
ImageContent,
JSONRPCMessage,
LoggingLevel,
TextContent,
TextResourceContents,
)
from typing_extensions import Self, assert_never

from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.messages import BinaryContent
from pydantic_ai.tools import ToolDefinition

try:
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import CallToolResult
except ImportError as _import_error:
raise ImportError(
'Please install the `mcp` package to use the MCP server, '
Expand Down Expand Up @@ -74,7 +85,9 @@ async def list_tools(self) -> list[ToolDefinition]:
for tool in tools.tools
]

async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
async def call_tool(
self, tool_name: str, arguments: dict[str, Any]
) -> str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]:
"""Call a tool on the server.

Args:
Expand All @@ -83,8 +96,21 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallTool

Returns:
The result of the tool call.

Raises:
ModelRetry: If the tool call fails.
"""
return await self._client.call_tool(tool_name, arguments)
result = await self._client.call_tool(tool_name, arguments)

content = [self._map_tool_result_part(part) for part in result.content]

if result.isError:
text = '\n'.join(str(part) for part in content)
raise ModelRetry(text)

if len(content) == 1:
return content[0]
return content

async def __aenter__(self) -> Self:
self._exit_stack = AsyncExitStack()
Expand All @@ -105,6 +131,35 @@ async def __aexit__(
await self._exit_stack.aclose()
self.is_running = False

def _map_tool_result_part(
    self, part: TextContent | ImageContent | EmbeddedResource
) -> str | BinaryContent | dict[str, Any] | list[Any]:
    """Convert one MCP tool-result content part into a Python-native value.

    Text parts that look like JSON are decoded; images and binary resources
    become `BinaryContent`; plain text (including text resources) stays `str`.
    """
    # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

    if isinstance(part, TextContent):
        text = part.text
        # Heuristic: structured results arrive as JSON text; decode when it
        # looks like a JSON object or array, otherwise keep the raw string.
        if text.startswith(('[', '{')):
            try:
                return json.loads(text)
            except ValueError:
                pass  # not actually JSON — fall through and return the text
        return text
    if isinstance(part, ImageContent):
        return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
    if isinstance(part, EmbeddedResource):
        resource = part.resource
        if isinstance(resource, TextResourceContents):
            return resource.text
        if isinstance(resource, BlobResourceContents):
            # Blob resources may omit a MIME type; default to a generic binary one.
            return BinaryContent(
                data=base64.b64decode(resource.blob),
                media_type=resource.mimeType or 'application/octet-stream',
            )
        assert_never(resource)
    assert_never(part)


@dataclass
class MCPServerStdio(MCPServer):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ snap = ["create"]

[tool.codespell]
# Ref: https://github.com/codespell-project/codespell#using-a-config-file
skip = '.git*,*.svg,*.lock,*.css'
skip = '.git*,*.svg,*.lock,*.css,*.yaml'
check-hidden = true
# Ignore "formatting" like **L**anguage
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
Expand Down
Loading