docs/agents.md (2 changes: 1 addition & 1 deletion)
@@ -57,7 +57,7 @@ print(result.output)
4. `result.output` will be a boolean indicating if the square is a winner. Pydantic performs the output validation, and it'll be typed as a `bool` since its type is derived from the `output_type` generic parameter of the agent.

!!! tip "Agents are designed for reuse, like FastAPI Apps"
-Agents are intended to be instantiated once (frequently as module globals) and reused throughout your application, similar to a small [FastAPI][fastapi.FastAPI] app or an [APIRouter][fastapi.APIRouter].
Agents can be instantiated once as module globals and reused throughout your application, similar to a small [FastAPI][fastapi.FastAPI] app or an [APIRouter][fastapi.APIRouter], or created dynamically by a factory function like `get_agent('agent-type')`, whichever pattern suits your application.

## Running Agents

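To make the two patterns concrete, here is a minimal sketch based on the roulette agent used earlier on this page; `get_agent` is a hypothetical factory, not a pydantic-ai API:

```python
from pydantic_ai import Agent

# Pattern 1: a module global, instantiated once and reused across the
# application, much like a FastAPI app or an APIRouter.
roulette_agent = Agent(
    'openai:gpt-4o',
    deps_type=int,
    output_type=bool,
    system_prompt=(
        'Use the `roulette_wheel` function to see if the '
        'customer has won based on the number they provide.'
    ),
)


# Pattern 2: a factory that creates or looks up agents dynamically.
# `get_agent` is illustrative only.
def get_agent(agent_type: str) -> Agent:
    if agent_type == 'roulette':
        return roulette_agent
    raise ValueError(f'Unknown agent type: {agent_type!r}')
```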
pydantic_ai_slim/pydantic_ai/models/__init__.py (82 changes: 82 additions & 0 deletions)
@@ -349,6 +349,88 @@ def prompted_output_instructions(self) -> str | None:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass
class ResolvedToolChoice:
"""Provider-agnostic resolved tool choice.

This is the result of validating and resolving the user's `tool_choice` setting.
Providers should map this to their API-specific format.
"""

mode: Literal['none', 'auto', 'required', 'specific']
"""The resolved tool choice mode."""

tool_names: list[str] | None = None
"""For 'specific' mode, the list of tool names to force."""

output_tools_fallback: bool = False
"""True if we need to fall back to output tools only (when 'none' was requested but output tools exist)."""


_TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING = (
"tool_choice='none' is set but output tools are required for structured output. "
'The output tools will remain available. Consider using native or prompted output modes '
"if you need tool_choice='none' with structured output."
)


def resolve_tool_choice(
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
*,
stacklevel: int = 6,
) -> ResolvedToolChoice | None:
"""Resolve and validate tool_choice from model settings.

This centralizes the common logic for handling tool_choice across all providers:
- Validates tool names in list[str] against available function_tools
- Issues warnings for conflicting settings (tool_choice='none' with output tools)
- Returns a provider-agnostic ResolvedToolChoice for the provider to map to their API format

Args:
model_settings: The model settings containing tool_choice.
model_request_parameters: The request parameters containing tool definitions.
stacklevel: The stack level for warnings (default 6 works for most provider call stacks).

Returns:
ResolvedToolChoice if an explicit tool_choice was provided and validated,
None if tool_choice was not set (provider should use default behavior based on allow_text_output).

Raises:
UserError: If tool names in list[str] are invalid.
"""
user_tool_choice = (model_settings or {}).get('tool_choice')

if user_tool_choice is None:
return None

if user_tool_choice == 'none':
if model_request_parameters.output_tools:
warnings.warn(_TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING, UserWarning, stacklevel=stacklevel)
return ResolvedToolChoice(mode='none', output_tools_fallback=True)
return ResolvedToolChoice(mode='none')

if user_tool_choice == 'auto':
return ResolvedToolChoice(mode='auto')

if user_tool_choice == 'required':
return ResolvedToolChoice(mode='required')

if isinstance(user_tool_choice, list):
if not user_tool_choice:
raise UserError('tool_choice cannot be an empty list. Use None for default behavior.')
function_tool_names = {t.name for t in model_request_parameters.function_tools}
invalid_names = set(user_tool_choice) - function_tool_names
if invalid_names:
raise UserError(
f'Invalid tool names in tool_choice: {invalid_names}. '
f'Available function tools: {function_tool_names or "none"}'
)
return ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice))

return None # pragma: no cover


class Model(ABC):
"""Abstract class for a model."""

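To see how a backend is expected to consume the helper, here is a rough provider-side sketch; `map_tool_choice_for_my_api` and its string return values are placeholders for a provider's own API format, not any real provider's types:

```python
from pydantic_ai.models import ModelRequestParameters, resolve_tool_choice
from pydantic_ai.settings import ModelSettings


def map_tool_choice_for_my_api(
    model_settings: ModelSettings | None,
    model_request_parameters: ModelRequestParameters,
) -> str | None:
    # May raise UserError (unknown tool names, empty list) and may emit a
    # UserWarning (tool_choice='none' while output tools are required).
    resolved = resolve_tool_choice(model_settings, model_request_parameters)
    if resolved is None:
        # tool_choice was not set: keep the provider default, typically
        # derived from model_request_parameters.allow_text_output.
        return None
    if resolved.mode == 'specific':
        assert resolved.tool_names  # guaranteed non-empty by the helper
        return f'force:{resolved.tool_names[0]}'
    if resolved.mode == 'none' and resolved.output_tools_fallback:
        # 'none' was requested, but output tools must stay available for
        # structured output.
        return 'output-tools-only'
    return resolved.mode  # 'none', 'auto', or 'required'
```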
pydantic_ai_slim/pydantic_ai/models/anthropic.py (85 changes: 77 additions & 8 deletions)
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import io
import warnings
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
@@ -42,7 +43,15 @@
from ..providers.anthropic import AsyncAnthropicClient
from ..settings import ModelSettings, merge_model_settings
from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
from . import (
Model,
ModelRequestParameters,
StreamedResponse,
check_allow_model_requests,
download_item,
get_user_agent,
resolve_tool_choice,
)

_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
'end_turn': 'stop',
@@ -640,18 +649,77 @@ def _infer_tool_choice(
) -> BetaToolChoiceParam | None:
if not tools:
return None
-else:
-tool_choice: BetaToolChoiceParam

thinking_enabled = model_settings.get('anthropic_thinking') is not None
tool_choice: BetaToolChoiceParam

resolved = resolve_tool_choice(model_settings, model_request_parameters)

if resolved is not None:
if resolved.mode == 'none':
if resolved.output_tools_fallback:
output_tool_names = [t.name for t in model_request_parameters.output_tools]
if len(output_tool_names) == 1:
tool_choice = {'type': 'tool', 'name': output_tool_names[0]}
else:
warnings.warn(
'Anthropic only supports forcing a single tool. '
"Falling back to 'auto' for multiple output tools.",
UserWarning,
stacklevel=6,
)
tool_choice = {'type': 'auto'}
else:
tool_choice = {'type': 'none'}

elif resolved.mode == 'auto':
tool_choice = {'type': 'auto'}

elif resolved.mode == 'required':
if thinking_enabled:
warnings.warn(
"tool_choice='required' is not supported with Anthropic thinking mode. Falling back to 'auto'.",
UserWarning,
stacklevel=6,
)
tool_choice = {'type': 'auto'}
else:
tool_choice = {'type': 'any'}

elif resolved.mode == 'specific':
assert resolved.tool_names # Guaranteed non-empty by resolve_tool_choice()
if thinking_enabled:
warnings.warn(
"Forcing specific tools is not supported with Anthropic thinking mode. Falling back to 'auto'.",
UserWarning,
stacklevel=6,
)
tool_choice = {'type': 'auto'}
elif len(resolved.tool_names) == 1:
tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]}
else:
warnings.warn(
'Anthropic only supports forcing a single tool. '
"Falling back to 'any' (required) for multiple tools.",
UserWarning,
stacklevel=6,
)
tool_choice = {'type': 'any'}
else:
assert_never(resolved.mode)

else:
# Default behavior: infer from allow_text_output
if not model_request_parameters.allow_text_output:
tool_choice = {'type': 'any'}
else:
tool_choice = {'type': 'auto'}

-if 'parallel_tool_calls' in model_settings:
-tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']
if 'parallel_tool_calls' in model_settings and tool_choice.get('type') != 'none':
# only `BetaToolChoiceNoneParam` doesn't have this field
tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] # pyright: ignore[reportGeneralTypeIssues]

-return tool_choice
return tool_choice

async def _map_message( # noqa: C901
self,
@@ -856,9 +924,10 @@ async def _map_message( # noqa: C901
system_prompt_parts.insert(0, instructions)
system_prompt = '\n\n'.join(system_prompt_parts)

ttl: Literal['5m', '1h']
# Add cache_control to the last message content if anthropic_cache_messages is enabled
if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')):
-ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages
ttl = '5m' if cache_messages is True else cache_messages
m = anthropic_messages[-1]
content = m['content']
if isinstance(content, str):
@@ -878,7 +947,7 @@
# If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control
if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')):
# If True, use '5m'; otherwise use the specified ttl value
-ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions
ttl = '5m' if cache_instructions is True else cache_instructions
system_prompt_blocks = [
BetaTextBlockParam(
type='text',
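From the user's side, the mapping above means a `tool_choice` model setting reaches Anthropic roughly as follows; a sketch assuming the `tool_choice` key this PR adds to `ModelSettings`, with an illustrative tool:

```python
from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-sonnet-latest')


@agent.tool_plain
def get_weather(city: str) -> str:
    """Return a fake forecast for the given city."""
    return f'It is sunny in {city}.'


# 'required' maps to Anthropic's {'type': 'any'} (with thinking enabled, the
# code above warns and falls back to 'auto' instead):
result = agent.run_sync(
    'What is the weather in Paris?',
    model_settings={'tool_choice': 'required'},
)

# A one-element list maps to {'type': 'tool', 'name': 'get_weather'}; longer
# lists fall back to {'type': 'any'} with a UserWarning:
result = agent.run_sync(
    'What is the weather in Paris?',
    model_settings={'tool_choice': ['get_weather']},
)
```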
pydantic_ai_slim/pydantic_ai/models/bedrock.py (50 changes: 44 additions & 6 deletions)
@@ -2,6 +2,7 @@

import functools
import typing
import warnings
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
@@ -41,7 +42,7 @@
)
from pydantic_ai._run_context import RunContext
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
-from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item, resolve_tool_choice
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai.providers.bedrock import BedrockModelProfile
from pydantic_ai.settings import ModelSettings
@@ -429,7 +430,7 @@ async def _messages_create(
'inferenceConfig': inference_config,
}

-tool_config = self._map_tool_config(model_request_parameters)
tool_config = self._map_tool_config(model_request_parameters, model_settings)
if tool_config:
params['toolConfig'] = tool_config

@@ -485,16 +486,53 @@ def _map_inference_config(

return inference_config

-def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
def _map_tool_config(
self,
model_request_parameters: ModelRequestParameters,
model_settings: BedrockModelSettings | None,
) -> ToolConfigurationTypeDef | None:
tools = self._get_tools(model_request_parameters)
if not tools:
return None

resolved = resolve_tool_choice(model_settings, model_request_parameters)
tool_choice: ToolChoiceTypeDef
-if not model_request_parameters.allow_text_output:
-tool_choice = {'any': {}}

if resolved is not None:
if resolved.mode == 'none':
warnings.warn(
"Bedrock does not support tool_choice='none'. Falling back to 'auto'.",
UserWarning,
stacklevel=6,
)
tool_choice = {'auto': {}}

elif resolved.mode == 'auto':
tool_choice = {'auto': {}}

elif resolved.mode == 'required':
tool_choice = {'any': {}}

elif resolved.mode == 'specific':
assert resolved.tool_names # Guaranteed non-empty by resolve_tool_choice()
if len(resolved.tool_names) == 1:
tool_choice = {'tool': {'name': resolved.tool_names[0]}}
else:
warnings.warn(
'Bedrock only supports forcing a single tool. '
"Falling back to 'any' (required) for multiple tools.",
UserWarning,
stacklevel=6,
)
tool_choice = {'any': {}}
else:
assert_never(resolved.mode)
else:
-tool_choice = {'auto': {}}
# Default behavior: infer from allow_text_output
if not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
else:
tool_choice = {'auto': {}}

tool_config: ToolConfigurationTypeDef = {'tools': tools}
if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
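For reference, these are the Converse-style `toolChoice` shapes the branch above can emit; a summary sketch, with the tool name purely illustrative:

```python
# tool_choice='auto', and the fallback for tool_choice='none'
# (not supported by Bedrock, so a UserWarning is emitted):
auto_choice = {'auto': {}}

# tool_choice='required', and the fallback for multi-tool lists
# (Bedrock can only force a single tool, so a UserWarning is emitted):
any_choice = {'any': {}}

# tool_choice=['get_weather']: force one specific tool.
specific_choice = {'tool': {'name': 'get_weather'}}
```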
pydantic_ai_slim/pydantic_ai/models/google.py (61 changes: 55 additions & 6 deletions)
@@ -49,6 +49,7 @@
check_allow_model_requests,
download_item,
get_user_agent,
resolve_tool_choice,
)

try:
@@ -364,17 +365,65 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
return tools or None

def _get_tool_config(
-self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None
self,
model_request_parameters: ModelRequestParameters,
tools: list[ToolDict] | None,
model_settings: GoogleModelSettings,
) -> ToolConfigDict | None:
-if not model_request_parameters.allow_text_output and tools:
-names: list[str] = []
if not tools:
return None

resolved = resolve_tool_choice(model_settings, model_request_parameters)

if resolved is not None:
if resolved.mode == 'none':
if resolved.output_tools_fallback:
output_tool_names = [t.name for t in model_request_parameters.output_tools]
return ToolConfigDict(
function_calling_config=FunctionCallingConfigDict(
mode=FunctionCallingConfigMode.ANY,
allowed_function_names=output_tool_names,
)
)
return ToolConfigDict(
function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE)
)

if resolved.mode == 'auto':
return ToolConfigDict(
function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO)
)

if resolved.mode == 'required':
names: list[str] = []
for tool in tools:
for function_declaration in tool.get('function_declarations') or []:
if name := function_declaration.get('name'): # pragma: no branch
names.append(name)
return ToolConfigDict(
function_calling_config=FunctionCallingConfigDict(
mode=FunctionCallingConfigMode.ANY,
allowed_function_names=names,
)
)

if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch
return ToolConfigDict(
function_calling_config=FunctionCallingConfigDict(
mode=FunctionCallingConfigMode.ANY,
allowed_function_names=resolved.tool_names,
)
)

# Default behavior: infer from allow_text_output
if not model_request_parameters.allow_text_output:
names = []
for tool in tools:
for function_declaration in tool.get('function_declarations') or []:
if name := function_declaration.get('name'): # pragma: no branch
names.append(name)
return _tool_config(names)
-else:
-return None
return None

@overload
async def _generate_content(
@@ -440,7 +489,7 @@ async def _build_content_and_config(
raise UserError('JSON output is not supported by this model.')
response_mime_type = 'application/json'

-tool_config = self._get_tool_config(model_request_parameters, tools)
tool_config = self._get_tool_config(model_request_parameters, tools, model_settings)
system_instruction, contents = await self._map_messages(messages, model_request_parameters)

modalities = [Modality.TEXT.value]
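The Google branch produces the corresponding google-genai `ToolConfigDict` shapes; a sketch using the same types this module already imports, with an illustrative tool name:

```python
from google.genai.types import (
    FunctionCallingConfigDict,
    FunctionCallingConfigMode,
    ToolConfigDict,
)

# tool_choice='none' with no output tools: disable function calling entirely.
none_config = ToolConfigDict(
    function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE)
)

# tool_choice='auto': leave the decision to the model.
auto_config = ToolConfigDict(
    function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO)
)

# tool_choice='required', a specific tool list, or the output-tools fallback
# for 'none': ANY mode restricted to the allowed names.
forced_config = ToolConfigDict(
    function_calling_config=FunctionCallingConfigDict(
        mode=FunctionCallingConfigMode.ANY,
        allowed_function_names=['get_weather'],  # illustrative
    )
)
```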