Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ async def main():
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

model_used = self._get_model(model)
del model

Expand Down Expand Up @@ -607,16 +608,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
else:
instrumentation_settings = None
tracer = NoOpTracer()
if builtin_tools:
# Deduplicate builtin tools passed to the agent and the run based on type
builtin_tools = list(
{
**({type(tool): tool for tool in self._builtin_tools or []}),
**({type(tool): tool for tool in builtin_tools}),
}.values()
)
else:
builtin_tools = list(self._builtin_tools)

graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
user_deps=deps,
prompt=user_prompt,
Expand All @@ -629,7 +621,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
output_schema=output_schema,
output_validators=output_validators,
history_processors=self.history_processors,
builtin_tools=builtin_tools,
builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
tool_manager=tool_manager,
tracer=tracer,
get_instructions=get_instructions,
Expand Down
14 changes: 13 additions & 1 deletion pydantic_ai_slim/pydantic_ai/builtin_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ class AbstractBuiltinTool(ABC):
kind: str = 'unknown_builtin_tool'
"""Built-in tool identifier, this should be available on all built-in tools as a discriminator."""

@property
def unique_id(self) -> str:
    """Return the identifier that distinguishes this builtin tool from others.

    The base implementation returns ``kind``, so all instances of the same
    tool type share a single identifier. A subclass whose instances can
    legitimately be passed to the model more than once should override this
    property so that each instance yields a different value.
    """
    return self.kind

def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
_BUILTIN_TOOL_TYPES[cls.kind] = cls
Expand Down Expand Up @@ -275,7 +283,7 @@ class MCPServerTool(AbstractBuiltinTool):
"""

id: str
"""The ID of the MCP server."""
"""A unique identifier for the MCP server."""

url: str
"""The URL of the MCP server to use.
Expand Down Expand Up @@ -321,6 +329,10 @@ class MCPServerTool(AbstractBuiltinTool):

kind: str = 'mcp_server'

@property
def unique_id(self) -> str:
    """Return ``kind`` and the server ``id`` joined by a colon.

    Several MCP server tools can be configured at once, so ``kind`` alone is
    not enough to tell them apart; including ``id`` gives each configured
    server its own identifier (for example ``'mcp_server:deepwiki'``).
    """
    return ':'.join([self.kind, self.id])


def _tool_discriminator(tool_data: dict[str, Any] | AbstractBuiltinTool) -> str:
if isinstance(tool_data, dict):
Expand Down
14 changes: 11 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,17 @@ def prepare_request(
they need to customize the preparation flow further, but most implementations should simply call
``self.prepare_request(...)`` at the start of their ``request`` (and related) methods.
"""
merged_settings = merge_model_settings(self.settings, model_settings)
customized_parameters = self.customize_request_parameters(model_request_parameters)
return merged_settings, customized_parameters
model_settings = merge_model_settings(self.settings, model_settings)

if builtin_tools := model_request_parameters.builtin_tools:
# Deduplicate builtin tools
model_request_parameters = replace(
model_request_parameters,
builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()),
)

model_request_parameters = self.customize_request_parameters(model_request_parameters)
return model_settings, model_request_parameters

@property
@abstractmethod
Expand Down
36 changes: 24 additions & 12 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,14 @@ async def _messages_create(
def _process_response(self, response: BetaMessage) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
items: list[ModelResponsePart] = []
builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
for item in response.content:
if isinstance(item, BetaTextBlock):
items.append(TextPart(content=item.text))
elif isinstance(item, BetaServerToolUseBlock):
items.append(_map_server_tool_use_block(item, self.system))
call_part = _map_server_tool_use_block(item, self.system)
builtin_tool_calls[call_part.tool_call_id] = call_part
items.append(call_part)
elif isinstance(item, BetaWebSearchToolResultBlock):
items.append(_map_web_search_tool_result_block(item, self.system))
elif isinstance(item, BetaCodeExecutionToolResultBlock):
Expand All @@ -340,9 +343,12 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
elif isinstance(item, BetaThinkingBlock):
items.append(ThinkingPart(content=item.thinking, signature=item.signature, provider_name=self.system))
elif isinstance(item, BetaMCPToolUseBlock):
items.append(_map_mcp_server_use_block(item, self.system))
call_part = _map_mcp_server_use_block(item, self.system)
builtin_tool_calls[call_part.tool_call_id] = call_part
items.append(call_part)
elif isinstance(item, BetaMCPToolResultBlock):
items.append(_map_mcp_server_result_block(item, self.system))
call_part = builtin_tool_calls.get(item.tool_use_id)
items.append(_map_mcp_server_result_block(item, call_part, self.system))
else:
assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}'
items.append(
Expand Down Expand Up @@ -545,9 +551,9 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
)
assistant_content_params.append(server_tool_use_block_param)
elif (
response_part.tool_name == MCPServerTool.kind
response_part.tool_name.startswith(MCPServerTool.kind)
and (server_id := response_part.tool_name.split(':', 1)[1])
and (args := response_part.args_as_dict())
and (server_id := args.get('server_id'))
and (tool_name := args.get('tool_name'))
and (tool_args := args.get('tool_args'))
): # pragma: no branch
Expand Down Expand Up @@ -590,7 +596,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
),
)
)
elif response_part.tool_name == MCPServerTool.kind and isinstance(
elif response_part.tool_name.startswith(MCPServerTool.kind) and isinstance(
response_part.content, dict
): # pragma: no branch
assistant_content_params.append(
Expand Down Expand Up @@ -714,6 +720,7 @@ class AnthropicStreamedResponse(StreamedResponse):
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
current_block: BetaContentBlock | None = None

builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
async for event in self._response:
if isinstance(event, BetaRawMessageStartEvent):
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
Expand Down Expand Up @@ -751,9 +758,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if maybe_event is not None: # pragma: no branch
yield maybe_event
elif isinstance(current_block, BetaServerToolUseBlock):
call_part = _map_server_tool_use_block(current_block, self.provider_name)
builtin_tool_calls[call_part.tool_call_id] = call_part
yield self._parts_manager.handle_part(
vendor_part_id=event.index,
part=_map_server_tool_use_block(current_block, self.provider_name),
part=call_part,
)
elif isinstance(current_block, BetaWebSearchToolResultBlock):
yield self._parts_manager.handle_part(
Expand All @@ -767,6 +776,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
)
elif isinstance(current_block, BetaMCPToolUseBlock):
call_part = _map_mcp_server_use_block(current_block, self.provider_name)
builtin_tool_calls[call_part.tool_call_id] = call_part

args_json = call_part.args_as_json_str()
# Drop the final `{}}` so that we can add tool args deltas
Expand All @@ -785,9 +795,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if maybe_event is not None: # pragma: no branch
yield maybe_event
elif isinstance(current_block, BetaMCPToolResultBlock):
call_part = builtin_tool_calls.get(current_block.tool_use_id)
yield self._parts_manager.handle_part(
vendor_part_id=event.index,
part=_map_mcp_server_result_block(current_block, self.provider_name),
part=_map_mcp_server_result_block(current_block, call_part, self.provider_name),
)

elif isinstance(event, BetaRawContentBlockDeltaEvent):
Expand Down Expand Up @@ -908,21 +919,22 @@ def _map_code_execution_tool_result_block(
def _map_mcp_server_use_block(item: BetaMCPToolUseBlock, provider_name: str) -> BuiltinToolCallPart:
    """Convert an Anthropic MCP tool-use block into a ``BuiltinToolCallPart``.

    Args:
        item: The MCP tool-use block taken from the response content.
        provider_name: Stored on the resulting part as ``provider_name``.

    Returns:
        A call part whose ``tool_name`` is ``MCPServerTool.kind`` and the
        block's ``server_name`` joined by ``':'``, and whose ``args`` carry the
        action, the MCP tool name and the tool arguments.
    """
    return BuiltinToolCallPart(
        provider_name=provider_name,
        # The server name is encoded in the tool name rather than in `args`;
        # readers recover it with `tool_name.split(':', 1)[1]`.
        tool_name=':'.join([MCPServerTool.kind, item.server_name]),
        args={
            'action': 'call_tool',
            'tool_name': item.name,
            # NOTE(review): the streaming path appears to rely on `tool_args`
            # being the last key when it trims the serialized args - keep it last.
            'tool_args': cast(dict[str, Any], item.input),
        },
        tool_call_id=item.id,
    )


def _map_mcp_server_result_block(
    item: BetaMCPToolResultBlock, call_part: BuiltinToolCallPart | None, provider_name: str
) -> BuiltinToolReturnPart:
    """Convert an Anthropic MCP tool-result block into a ``BuiltinToolReturnPart``.

    Args:
        item: The MCP tool-result block taken from the response content.
        call_part: The call part previously recorded for ``item.tool_use_id``,
            or ``None`` when no matching call was seen.
        provider_name: Stored on the resulting part as ``provider_name``.

    Returns:
        A return part carrying the block's ``content`` and ``is_error`` fields
        (dumped in JSON mode) under the same ``tool_call_id`` as the call.
    """
    return BuiltinToolReturnPart(
        provider_name=provider_name,
        # Reuse the call's tool name so the call and its result stay paired under
        # the same server-qualified name; without a known call, fall back to the
        # bare MCP server kind.
        tool_name=call_part.tool_name if call_part else MCPServerTool.kind,
        content=item.model_dump(mode='json', include={'content', 'is_error'}),
        tool_call_id=item.tool_use_id,
    )
19 changes: 10 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,11 +1474,11 @@ async def _map_messages( # noqa: C901
)
openai_messages.append(image_generation_item)
elif ( # pragma: no branch
item.tool_name == MCPServerTool.kind
item.tool_name.startswith(MCPServerTool.kind)
and item.tool_call_id
and (server_id := item.tool_name.split(':', 1)[1])
and (args := item.args_as_dict())
and (action := args.get('action'))
and (server_id := args.get('server_id'))
):
if action == 'list_tools':
mcp_list_tools_item = responses.response_input_item_param.McpListTools(
Expand Down Expand Up @@ -1525,7 +1525,7 @@ async def _map_messages( # noqa: C901
elif item.tool_name == ImageGenerationTool.kind:
# Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`.
pass
elif item.tool_name == MCPServerTool.kind: # pragma: no branch
elif item.tool_name.startswith(MCPServerTool.kind): # pragma: no branch
# MCP call result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
pass
elif isinstance(item, FilePart):
Expand Down Expand Up @@ -2257,15 +2257,16 @@ def _map_image_generation_tool_call(
def _map_mcp_list_tools(
item: responses.response_output_item.McpListTools, provider_name: str
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
tool_name = ':'.join([MCPServerTool.kind, item.server_label])
return (
BuiltinToolCallPart(
tool_name=MCPServerTool.kind,
tool_name=tool_name,
tool_call_id=item.id,
provider_name=provider_name,
args={'action': 'list_tools', 'server_id': item.server_label},
args={'action': 'list_tools'},
),
BuiltinToolReturnPart(
tool_name=MCPServerTool.kind,
tool_name=tool_name,
tool_call_id=item.id,
content=item.model_dump(mode='json', include={'tools', 'error'}),
provider_name=provider_name,
Expand All @@ -2276,20 +2277,20 @@ def _map_mcp_list_tools(
def _map_mcp_call(
item: responses.response_output_item.McpCall, provider_name: str
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
tool_name = ':'.join([MCPServerTool.kind, item.server_label])
return (
BuiltinToolCallPart(
tool_name=MCPServerTool.kind,
tool_name=tool_name,
tool_call_id=item.id,
args={
'action': 'call_tool',
'server_id': item.server_label,
'tool_name': item.name,
'tool_args': json.loads(item.arguments) if item.arguments else {},
},
provider_name=provider_name,
),
BuiltinToolReturnPart(
tool_name=MCPServerTool.kind,
tool_name=tool_name,
tool_call_id=item.id,
content={
'output': item.output,
Expand Down
22 changes: 10 additions & 12 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3098,10 +3098,9 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
provider_name='anthropic',
),
BuiltinToolCallPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
args={
'action': 'call_tool',
'server_id': 'deepwiki',
'tool_name': 'ask_question',
'tool_args': {
'repoName': 'pydantic/pydantic-ai',
Expand All @@ -3112,7 +3111,7 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
provider_name='anthropic',
),
BuiltinToolReturnPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
content={
'content': [
{
Expand Down Expand Up @@ -3181,10 +3180,9 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
provider_name='anthropic',
),
BuiltinToolCallPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
args={
'action': 'call_tool',
'server_id': 'deepwiki',
'tool_name': 'ask_question',
'tool_args': {
'repoName': 'pydantic/pydantic',
Expand All @@ -3195,7 +3193,7 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
provider_name='anthropic',
),
BuiltinToolReturnPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
content={
'content': [
{
Expand Down Expand Up @@ -3345,13 +3343,13 @@ async def test_anthropic_mcp_servers_stream(allow_model_requests: None, anthropi
provider_name='anthropic',
),
BuiltinToolCallPart(
tool_name='mcp_server',
args='{"action":"call_tool","server_id":"deepwiki","tool_name":"ask_question","tool_args":{"repoName": "pydantic/pydantic-ai", "question": "What is this repository about? What are its main features and purpose?"}}',
tool_name='mcp_server:deepwiki',
args='{"action":"call_tool","tool_name":"ask_question","tool_args":{"repoName": "pydantic/pydantic-ai", "question": "What is this repository about? What are its main features and purpose?"}}',
tool_call_id='mcptoolu_01FZmJ5UspaX5BB9uU339UT1',
provider_name='anthropic',
),
BuiltinToolReturnPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
content={
'content': [
{
Expand Down Expand Up @@ -3407,15 +3405,15 @@ async def test_anthropic_mcp_servers_stream(allow_model_requests: None, anthropi
PartStartEvent(
index=1,
part=BuiltinToolCallPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
tool_call_id='mcptoolu_01FZmJ5UspaX5BB9uU339UT1',
provider_name='anthropic',
),
),
PartDeltaEvent(
index=1,
delta=ToolCallPartDelta(
args_delta='{"action":"call_tool","server_id":"deepwiki","tool_name":"ask_question","tool_args":',
args_delta='{"action":"call_tool","tool_name":"ask_question","tool_args":',
tool_call_id='mcptoolu_01FZmJ5UspaX5BB9uU339UT1',
),
),
Expand Down Expand Up @@ -3489,7 +3487,7 @@ async def test_anthropic_mcp_servers_stream(allow_model_requests: None, anthropi
PartStartEvent(
index=2,
part=BuiltinToolReturnPart(
tool_name='mcp_server',
tool_name='mcp_server:deepwiki',
content={
'content': [
{
Expand Down
Loading