Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@


class AnthropicModelSettings(ModelSettings):
"""Settings used for an Anthropic model request."""
"""Settings used for an Anthropic model request.

ALL FIELDS MUST BE `anthropic_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

anthropic_metadata: MetadataParam
"""An object describing metadata about the request.
Expand Down
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@
T = typing.TypeVar('T')


class BedrockModelSettings(ModelSettings):
"""Settings for Bedrock models.

ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""


@dataclass(init=False)
class BedrockConverseModel(Model):
"""A model that uses the Bedrock Converse API."""
Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@


class CohereModelSettings(ModelSettings):
"""Settings used for a Cohere model request."""
"""Settings used for a Cohere model request.

ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

# This class is a placeholder for any future cohere-specific settings

Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@


class GeminiModelSettings(ModelSettings):
"""Settings used for a Gemini model request."""
"""Settings used for a Gemini model request.

ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

gemini_safety_settings: list[GeminiSafetySettings]

Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@


class GroqModelSettings(ModelSettings):
"""Settings used for a Groq model request."""
"""Settings used for a Groq model request.

ALL FIELDS MUST BE `groq_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

# This class is a placeholder for any future groq-specific settings

Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@


class MistralModelSettings(ModelSettings):
"""Settings used for a Mistral model request."""
"""Settings used for a Mistral model request.

ALL FIELDS MUST BE `mistral_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

# This class is a placeholder for any future mistral-specific settings

Expand Down
9 changes: 6 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@


class OpenAIModelSettings(ModelSettings, total=False):
"""Settings used for an OpenAI model request."""
"""Settings used for an OpenAI model request.

ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

openai_reasoning_effort: chat.ChatCompletionReasoningEffort
"""
Expand All @@ -83,7 +86,7 @@ class OpenAIModelSettings(ModelSettings, total=False):
result in faster responses and fewer tokens used on reasoning in a response.
"""

user: str
openai_user: str
"""A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
Expand Down Expand Up @@ -229,7 +232,7 @@ async def _completions_create(
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
user=model_settings.get('user', NOT_GIVEN),
user=model_settings.get('openai_user', NOT_GIVEN),
)
except APIStatusError as e:
if (status_code := e.status_code) >= 400:
Expand Down
8 changes: 4 additions & 4 deletions tests/models/cassettes/test_openai/test_user_id.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ interactions:
openai-organization:
- pydantic-28gund
openai-processing-ms:
- '235'
- '584'
openai-version:
- '2020-10-01'
strict-transport-security:
Expand All @@ -55,12 +55,12 @@ interactions:
content: Hello! How can I assist you today?
refusal: null
role: assistant
created: 1742906323
id: chatcmpl-BExqF10T3Bt833bEyp7OAqjVznPeS
created: 1743073438
id: chatcmpl-BFfJeRdAVFPUVWxV3OYH1tSR5KvrI
model: gpt-4o-2024-08-06
object: chat.completion
service_tier: default
system_fingerprint: fp_90d33c15d4
system_fingerprint: fp_898ac29719
usage:
completion_tokens: 10
completion_tokens_details:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,5 +707,5 @@ async def test_user_id(allow_model_requests: None, openai_api_key: str):
# This test doesn't do anything, it's just here to ensure that calls with `user` don't cause errors, including type.
# Since we use VCR, creating tests with an `httpx.Transport` is not possible.
m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
agent = Agent(m, model_settings=OpenAIModelSettings(user='user_id'))
agent = Agent(m, model_settings=OpenAIModelSettings(openai_user='user_id'))
await agent.run('hello')
26 changes: 26 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import importlib

import pytest

from pydantic_ai.settings import ModelSettings


@pytest.fixture(params=['openai_', 'anthropic_', 'bedrock_', 'groq_', 'gemini_', 'mistral_', 'cohere_'])
def settings(request: pytest.FixtureRequest) -> tuple[type[ModelSettings], str]:
prefix_cls_name = request.param.replace('_', '')
try:
module = importlib.import_module(f'pydantic_ai.models.{prefix_cls_name}')
except ImportError:
pytest.skip(f'{prefix_cls_name} is not installed')
capitalized_prefix = prefix_cls_name.capitalize().replace('Openai', 'OpenAI')
cls = getattr(module, capitalized_prefix + 'ModelSettings')
return cls, request.param


def test_specific_prefix_settings(settings: tuple[type[ModelSettings], str]):
settings_cls, prefix = settings
global_settings = set(ModelSettings.__annotations__.keys())
specific_settings = set(settings_cls.__annotations__.keys()) - global_settings
assert all(setting.startswith(prefix) for setting in specific_settings), (
f'{prefix} is not a prefix for {specific_settings}'
)