Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- run_stream
- model
- override_deps
- override_model
- last_run_messages
- system_prompt
- retriever_plain
Expand Down
10 changes: 10 additions & 0 deletions docs/api/models/base.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# `pydantic_ai.models`

::: pydantic_ai.models
options:
members:
- Model
- AgentModel
- AbstractToolDefinition
- StreamTextResponse
- StreamStructuredResponse
- ALLOW_MODEL_REQUESTS
- check_allow_model_requests
- override_allow_model_requests
67 changes: 55 additions & 12 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@
import logfire_api
from typing_extensions import assert_never

from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result
from . import (
_result,
_retriever as _r,
_system_prompt,
_utils,
exceptions,
messages as _messages,
models,
result,
)
from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
from .result import ResultData

Expand All @@ -23,6 +32,7 @@
'openai:gpt-3.5-turbo',
'gemini-1.5-flash',
'gemini-1.5-pro',
'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].

Expand All @@ -40,7 +50,7 @@ class Agent(Generic[AgentDeps, ResultData]):
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM."""

# dataclass fields mostly for my sanity — knowing what attributes are available
model: models.Model | None
model: models.Model | KnownModelName | None
"""The default model configured for this agent."""
_result_schema: _result.ResultSchema[ResultData] | None
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
Expand All @@ -52,7 +62,8 @@ class Agent(Generic[AgentDeps, ResultData]):
_deps_type: type[AgentDeps]
_max_result_retries: int
_current_result_retry: int
_override_deps_stack: list[AgentDeps]
_override_deps: _utils.Option[AgentDeps] = None
_override_model: _utils.Option[models.Model] = None
last_run_messages: list[_messages.Message] | None = None
"""The messages from the last run, useful when a run raised an exception.

Expand All @@ -70,6 +81,7 @@ def __init__(
result_tool_name: str = 'final_result',
result_tool_description: str | None = None,
result_retries: int | None = None,
defer_model_check: bool = False,
):
"""Create an agent.

Expand All @@ -87,8 +99,16 @@ def __init__(
result_tool_name: The name of the tool to use for the final result.
result_tool_description: The description of the final result tool.
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
defer_model_check: by default, if you provide a [named][pydantic_ai.agent.KnownModelName] model,
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
which checks for the necessary environment variables. Set this to `true`
to defer the evaluation until the first run. Useful if you want to
[override the model][pydantic_ai.Agent.override_model] for testing.
"""
self.model = models.infer_model(model) if model is not None else None
if model is None or defer_model_check:
self.model = model
else:
self.model = models.infer_model(model)

self._result_schema = _result.ResultSchema[result_type].build(
result_type, result_tool_name, result_tool_description
Expand All @@ -104,7 +124,6 @@ def __init__(
self._max_result_retries = result_retries if result_retries is not None else retries
self._current_result_retry = 0
self._result_validators = []
self._override_deps_stack = []

async def run(
self,
Expand Down Expand Up @@ -281,11 +300,26 @@ def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
Args:
overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
"""
self._override_deps_stack.append(overriding_deps)
override_deps_before = self._override_deps
self._override_deps = _utils.Some(overriding_deps)
try:
yield
finally:
self._override_deps_stack.pop()
self._override_deps = override_deps_before

@contextmanager
def override_model(self, overriding_model: models.Model | KnownModelName) -> Iterator[None]:
    """Context manager to temporarily override the model used by the agent.

    Args:
        overriding_model: The model to use instead of the model passed to the agent run.
    """
    # Remember whatever override (if any) was active so nesting restores it correctly.
    saved_override = self._override_model
    self._override_model = _utils.Some(models.infer_model(overriding_model))
    try:
        yield
    finally:
        self._override_model = saved_override

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
Expand Down Expand Up @@ -386,11 +420,20 @@ async def _get_agent_model(
a tuple of `(model used, custom_model if any, agent_model)`
"""
model_: models.Model
if model is not None:
if some_model := self._override_model:
# we don't want `override_model()` to cover up errors from the model not being defined, hence this check
if model is None and self.model is None:
raise exceptions.UserError(
'`model` must be set either when creating the agent or when calling it. '
'(Even when `override_model()` is customizing the model that will actually be called)'
)
model_ = some_model.value
custom_model = None
elif model is not None:
custom_model = model_ = models.infer_model(model)
elif self.model is not None:
# noinspection PyTypeChecker
model_ = self.model
model_ = self.model = models.infer_model(self.model)
custom_model = None
else:
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
Expand Down Expand Up @@ -573,9 +616,9 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps:

We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
"""
try:
return self._override_deps_stack[-1]
except IndexError:
if some_deps := self._override_deps:
return some_deps.value
else:
return deps


Expand Down
48 changes: 46 additions & 2 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from __future__ import annotations as _annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from functools import cache
from typing import TYPE_CHECKING, Protocol, Union
Expand Down Expand Up @@ -151,10 +151,54 @@ def timestamp(self) -> datetime:
EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]


ALLOW_MODEL_REQUESTS = False
"""Whether to allow requests to models.

This global setting allows you to disable requests to most models, e.g. to make sure you don't accidentally
make costly requests to a model during tests.

The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
[`FunctionModel`][pydantic_ai.models.function.FunctionModel] are not affected by this setting.
"""


def check_allow_model_requests() -> None:
    """Check if model requests are allowed.

    If you're defining your own models that have cost or latency associated with their use, you should call this in
    [`Model.agent_model`][pydantic_ai.models.Model.agent_model].

    Raises:
        RuntimeError: If model requests are not allowed.
    """
    if not ALLOW_MODEL_REQUESTS:
        raise RuntimeError('Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False')


@contextmanager
def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
    """Context manager to temporarily override [`ALLOW_MODEL_REQUESTS`][pydantic_ai.models.ALLOW_MODEL_REQUESTS].

    Args:
        allow_model_requests: Whether to allow model requests within the context.
    """
    global ALLOW_MODEL_REQUESTS
    # Save the current setting so the override unwinds correctly even when nested.
    previous = ALLOW_MODEL_REQUESTS
    ALLOW_MODEL_REQUESTS = allow_model_requests  # pyright: ignore[reportConstantRedefinition]
    try:
        yield
    finally:
        ALLOW_MODEL_REQUESTS = previous  # pyright: ignore[reportConstantRedefinition]


def infer_model(model: Model | KnownModelName) -> Model:
"""Infer the model from the name."""
if isinstance(model, Model):
return model
elif model == 'test':
from .test import TestModel

return TestModel()
elif model.startswith('openai:'):
from .openai import OpenAIModel

Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
StreamStructuredResponse,
StreamTextResponse,
cached_async_http_client,
check_allow_model_requests,
)

GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro']
Expand Down Expand Up @@ -113,6 +114,7 @@ def agent_model(
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> GeminiAgentModel:
check_allow_model_requests()
tools = [_function_from_abstract_tool(t) for t in retrievers.values()]
if result_tools is not None:
tools += [_function_from_abstract_tool(t) for t in result_tools]
Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
StreamStructuredResponse,
StreamTextResponse,
cached_async_http_client,
check_allow_model_requests,
)


Expand Down Expand Up @@ -85,6 +86,7 @@ def agent_model(
allow_text_result: bool,
result_tools: Sequence[AbstractToolDefinition] | None,
) -> AgentModel:
check_allow_model_requests()
tools = [self._map_tool_definition(r) for r in retrievers.values()]
if result_tools is not None:
tools += [self._map_tool_definition(r) for r in result_tools]
Expand Down
34 changes: 24 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import secrets
import sys
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from types import ModuleType
Expand All @@ -15,8 +16,13 @@
from _pytest.assertion.rewrite import AssertionRewritingHook
from typing_extensions import TypeAlias

import pydantic_ai.models

__all__ = 'IsNow', 'TestEnv', 'ClientWithHandler'


pydantic_ai.models.ALLOW_MODEL_REQUESTS = False

if TYPE_CHECKING:

def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
Expand All @@ -38,35 +44,43 @@ class TestEnv:
__test__ = False

def __init__(self):
self.envars: set[str] = set()
self.envars: dict[str, str | None] = {}

def set(self, name: str, value: str) -> None:
self.envars.add(name)
self.envars[name] = os.getenv(name)
os.environ[name] = value

def pop(self, name: str) -> None: # pragma: no cover
self.envars.remove(name)
os.environ.pop(name)
def remove(self, name: str) -> None:
self.envars[name] = os.environ.pop(name, None)

def clear(self) -> None:
for n in self.envars:
os.environ.pop(n)
def reset(self) -> None:
for name, value in self.envars.items():
if value is None:
os.environ.pop(name, None)
else:
os.environ[name] = value


@pytest.fixture
def env():
def env() -> Iterator[TestEnv]:
test_env = TestEnv()

yield test_env

test_env.clear()
test_env.reset()


@pytest.fixture
def anyio_backend():
    """Select the asyncio backend for anyio-based tests."""
    backend_name = 'asyncio'
    return backend_name


@pytest.fixture
def allow_model_requests():
    # Opt a single test into making model requests: conftest sets the module-level
    # ALLOW_MODEL_REQUESTS flag to False globally, and this fixture flips it back
    # to True for the duration of the requesting test only.
    with pydantic_ai.models.override_allow_model_requests(True):
        yield


@pytest.fixture
async def client_with_handler():
client: httpx.AsyncClient | None = None
Expand Down
Loading
Loading