diff --git a/docs/api/agent.md b/docs/api/agent.md index 2152da3808..2dd569c96f 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -9,6 +9,7 @@ - run_stream - model - override_deps + - override_model - last_run_messages - system_prompt - retriever_plain diff --git a/docs/api/models/base.md b/docs/api/models/base.md index a788db1cd9..48672830f6 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -1,3 +1,13 @@ # `pydantic_ai.models` ::: pydantic_ai.models + options: + members: + - Model + - AgentModel + - AbstractToolDefinition + - StreamTextResponse + - StreamStructuredResponse + - ALLOW_MODEL_REQUESTS + - check_allow_model_requests + - override_allow_model_requests diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 93fe5fd170..b5792ba258 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -9,7 +9,16 @@ import logfire_api from typing_extensions import assert_never -from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result +from . import ( + _result, + _retriever as _r, + _system_prompt, + _utils, + exceptions, + messages as _messages, + models, + result, +) from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc from .result import ResultData @@ -23,6 +32,7 @@ 'openai:gpt-3.5-turbo', 'gemini-1.5-flash', 'gemini-1.5-pro', + 'test', ] """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent]. 
@@ -40,7 +50,7 @@ class Agent(Generic[AgentDeps, ResultData]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.""" # dataclass fields mostly for my sanity — knowing what attributes are available - model: models.Model | None + model: models.Model | KnownModelName | None """The default model configured for this agent.""" _result_schema: _result.ResultSchema[ResultData] | None _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -52,7 +62,8 @@ class Agent(Generic[AgentDeps, ResultData]): _deps_type: type[AgentDeps] _max_result_retries: int _current_result_retry: int - _override_deps_stack: list[AgentDeps] + _override_deps: _utils.Option[AgentDeps] = None + _override_model: _utils.Option[models.Model] = None last_run_messages: list[_messages.Message] | None = None """The messages from the last run, useful when a run raised an exception. @@ -70,6 +81,7 @@ def __init__( result_tool_name: str = 'final_result', result_tool_description: str | None = None, result_retries: int | None = None, + defer_model_check: bool = False, ): """Create an agent. @@ -87,8 +99,16 @@ def __init__( result_tool_name: The name of the tool to use for the final result. result_tool_description: The description of the final result tool. result_retries: The maximum number of retries to allow for result validation, defaults to `retries`. + defer_model_check: by default, if you provide a [named][pydantic_ai.agent.KnownModelName] model, + it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, + which checks for the necessary environment variables. Set this to `true` + to defer the evaluation until the first run. Useful if you want to + [override the model][pydantic_ai.Agent.override_model] for testing. 
""" - self.model = models.infer_model(model) if model is not None else None + if model is None or defer_model_check: + self.model = model + else: + self.model = models.infer_model(model) self._result_schema = _result.ResultSchema[result_type].build( result_type, result_tool_name, result_tool_description @@ -104,7 +124,6 @@ def __init__( self._max_result_retries = result_retries if result_retries is not None else retries self._current_result_retry = 0 self._result_validators = [] - self._override_deps_stack = [] async def run( self, @@ -281,11 +300,26 @@ def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]: Args: overriding_deps: The dependencies to use instead of the dependencies passed to the agent run. """ - self._override_deps_stack.append(overriding_deps) + override_deps_before = self._override_deps + self._override_deps = _utils.Some(overriding_deps) try: yield finally: - self._override_deps_stack.pop() + self._override_deps = override_deps_before + + @contextmanager + def override_model(self, overriding_model: models.Model | KnownModelName) -> Iterator[None]: + """Context manager to temporarily override the model used by the agent. + + Args: + overriding_model: The model to use instead of the model passed to the agent run. + """ + override_model_before = self._override_model + self._override_model = _utils.Some(models.infer_model(overriding_model)) + try: + yield + finally: + self._override_model = override_model_before def system_prompt( self, func: _system_prompt.SystemPromptFunc[AgentDeps] @@ -386,11 +420,20 @@ async def _get_agent_model( a tuple of `(model used, custom_model if any, agent_model)` """ model_: models.Model - if model is not None: + if some_model := self._override_model: + # we don't want `override_model()` to cover up errors from the model not being defined, hence this check + if model is None and self.model is None: + raise exceptions.UserError( + '`model` must be set either when creating the agent or when calling it. 
' '(Even when `override_model()` is customizing the model that will actually be called)' ) + model_ = some_model.value + custom_model = None + elif model is not None: custom_model = model_ = models.infer_model(model) elif self.model is not None: # noinspection PyTypeChecker - model_ = self.model + model_ = self.model = models.infer_model(self.model) custom_model = None else: raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.') @@ -573,9 +616,9 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps: We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. """ - try: - return self._override_deps_stack[-1] - except IndexError: + if some_deps := self._override_deps: + return some_deps.value + else: return deps diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index eca7bf84a3..b35452a46a 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -7,8 +7,8 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence +from contextlib import asynccontextmanager, contextmanager from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Protocol, Union @@ -151,10 +151,54 @@ def timestamp(self) -> datetime: EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse] +ALLOW_MODEL_REQUESTS = False +"""Whether to allow requests to models. + +This global setting allows you to disable requests to most models, e.g. to make sure you don't accidentally +make costly requests to a model during tests. + +The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and +[`FunctionModel`][pydantic_ai.models.function.FunctionModel] are not affected by this setting. 
+""" + + +def check_allow_model_requests() -> None: + """Check if model requests are allowed. + + If you're defining your own models that have cost or latency associated with their use, you should call this in + [`Model.agent_model`][pydantic_ai.models.Model.agent_model]. + + Raises: + RuntimeError: If model requests are not allowed. + """ + if not ALLOW_MODEL_REQUESTS: + raise RuntimeError('Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False') + + +@contextmanager +def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: + """Context manager to temporarily override [`ALLOW_MODEL_REQUESTS`][pydantic_ai.models.ALLOW_MODEL_REQUESTS]. + + Args: + allow_model_requests: Whether to allow model requests within the context. + """ + global ALLOW_MODEL_REQUESTS + old_value = ALLOW_MODEL_REQUESTS + ALLOW_MODEL_REQUESTS = allow_model_requests # pyright: ignore[reportConstantRedefinition] + try: + yield + finally: + ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] + + def infer_model(model: Model | KnownModelName) -> Model: """Infer the model from the name.""" if isinstance(model, Model): return model + elif model == 'test': + from .test import TestModel + + return TestModel() elif model.startswith('openai:'): from .openai import OpenAIModel diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index 9d19b4a136..d944f8ee05 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -55,6 +55,7 @@ StreamStructuredResponse, StreamTextResponse, cached_async_http_client, + check_allow_model_requests, ) GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro'] @@ -113,6 +114,7 @@ def agent_model( allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> GeminiAgentModel: + check_allow_model_requests() tools = [_function_from_abstract_tool(t) for t in retrievers.values()] if result_tools is not None: 
tools += [_function_from_abstract_tool(t) for t in result_tools] diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index a4322ecc20..b3d54e42e1 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -33,6 +33,7 @@ StreamStructuredResponse, StreamTextResponse, cached_async_http_client, + check_allow_model_requests, ) @@ -85,6 +86,7 @@ def agent_model( allow_text_result: bool, result_tools: Sequence[AbstractToolDefinition] | None, ) -> AgentModel: + check_allow_model_requests() tools = [self._map_tool_definition(r) for r in retrievers.values()] if result_tools is not None: tools += [self._map_tool_definition(r) for r in result_tools] diff --git a/tests/conftest.py b/tests/conftest.py index ff86e5af17..75156fae02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import re import secrets import sys +from collections.abc import Iterator from datetime import datetime from pathlib import Path from types import ModuleType @@ -15,8 +16,13 @@ from _pytest.assertion.rewrite import AssertionRewritingHook from typing_extensions import TypeAlias +import pydantic_ai.models + __all__ = 'IsNow', 'TestEnv', 'ClientWithHandler' + +pydantic_ai.models.ALLOW_MODEL_REQUESTS = False + if TYPE_CHECKING: def IsNow(*args: Any, **kwargs: Any) -> datetime: ... 
@@ -38,28 +44,30 @@ class TestEnv: __test__ = False def __init__(self): - self.envars: set[str] = set() + self.envars: dict[str, str | None] = {} def set(self, name: str, value: str) -> None: - self.envars.add(name) + self.envars[name] = os.getenv(name) os.environ[name] = value - def pop(self, name: str) -> None: # pragma: no cover - self.envars.remove(name) - os.environ.pop(name) + def remove(self, name: str) -> None: + self.envars[name] = os.environ.pop(name, None) - def clear(self) -> None: - for n in self.envars: - os.environ.pop(n) + def reset(self) -> None: + for name, value in self.envars.items(): + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value @pytest.fixture -def env(): +def env() -> Iterator[TestEnv]: test_env = TestEnv() yield test_env - test_env.clear() + test_env.reset() @pytest.fixture @@ -67,6 +75,12 @@ def anyio_backend(): return 'asyncio' +@pytest.fixture +def allow_model_requests(): + with pydantic_ai.models.override_allow_model_requests(True): + yield + + @pytest.fixture async def client_with_handler(): client: httpx.AsyncClient | None = None diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 9f30b1c91e..860b27ac18 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -59,6 +59,7 @@ def test_api_key_env_var(env: TestEnv): def test_api_key_not_set(env: TestEnv): + env.remove('GEMINI_API_KEY') with pytest.raises(UserError, match='API key must be provided or set in the GEMINI_API_KEY environment variable'): GeminiModel('gemini-1.5-flash') @@ -69,7 +70,7 @@ def test_api_key_empty(env: TestEnv): GeminiModel('gemini-1.5-flash') -def test_agent_model_simple(): +def test_agent_model_simple(allow_model_requests: None): m = GeminiModel('gemini-1.5-flash', api_key='via-arg') agent_model = m.agent_model({}, True, None) assert isinstance(agent_model.http_client, httpx.AsyncClient) @@ -88,7 +89,7 @@ class TestToolDefinition: outer_typed_dict_key: str | None = None -def 
test_agent_model_tools(): +def test_agent_model_tools(allow_model_requests: None): m = GeminiModel('gemini-1.5-flash', api_key='via-arg') retrievers = { 'foo': TestToolDefinition( @@ -144,7 +145,7 @@ def test_agent_model_tools(): assert agent_model.tool_config is None -def test_require_response_tool(): +def test_require_response_tool(allow_model_requests: None): m = GeminiModel('gemini-1.5-flash', api_key='via-arg') result_tool = TestToolDefinition( 'result', @@ -173,7 +174,7 @@ def test_require_response_tool(): ) -def test_json_def_replaced(): +def test_json_def_replaced(allow_model_requests: None): class Location(BaseModel): lat: float lng: float = 1.1 @@ -238,7 +239,7 @@ class Locations(BaseModel): ) -def test_json_def_replaced_any_of(): +def test_json_def_replaced_any_of(allow_model_requests: None): class Location(BaseModel): lat: float lng: float @@ -282,7 +283,7 @@ class Locations(BaseModel): ) -def test_json_def_recursive(): +def test_json_def_recursive(allow_model_requests: None): class Location(BaseModel): lat: float lng: float @@ -335,7 +336,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]: @pytest.fixture -async def get_gemini_client(client_with_handler: ClientWithHandler, env: TestEnv): +async def get_gemini_client( + client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None +) -> GetGeminiClient: env.set('GEMINI_API_KEY', 'via-env-var') def create_client(response_or_list: ResOrList) -> httpx.AsyncClient: @@ -495,7 +498,7 @@ async def get_location(loc_name: str) -> str: assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9)) -async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv): +async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): env.set('GEMINI_API_KEY', 'via-env-var') def handler(_: httpx.Request): diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 
196d98e518..cf74b0d40a 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -113,7 +113,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage ) -async def test_request_simple_success(): +async def test_request_simple_success(allow_model_requests: None): c = completion_message(ChatCompletionMessage(content='world', role='assistant')) mock_client = MockOpenAI.create_mock(c) m = OpenAIModel('gpt-4', openai_client=mock_client) @@ -139,7 +139,7 @@ async def test_request_simple_success(): ) -async def test_request_simple_usage(): +async def test_request_simple_usage(allow_model_requests: None): c = completion_message( ChatCompletionMessage(content='world', role='assistant'), usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3), @@ -153,7 +153,7 @@ async def test_request_simple_usage(): assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=1, total_tokens=3)) -async def test_request_structured_response(): +async def test_request_structured_response(allow_model_requests: None): c = completion_message( ChatCompletionMessage( content=None, @@ -190,7 +190,7 @@ async def test_request_structured_response(): ) -async def test_request_tool_call(): +async def test_request_tool_call(allow_model_requests: None): responses = [ completion_message( ChatCompletionMessage( @@ -309,7 +309,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.Cha return chunk([ChoiceDelta(content=text, role='assistant')], finish_reason=finish_reason) -async def test_stream_text(): +async def test_stream_text(allow_model_requests: None): stream = text_chunk('hello '), text_chunk('world'), chunk([]) mock_client = MockOpenAI.create_mock_stream(stream) m = OpenAIModel('gpt-4', openai_client=mock_client) @@ -323,7 +323,7 @@ async def test_stream_text(): assert result.cost() == snapshot(Cost(request_tokens=6, response_tokens=3, total_tokens=9)) -async def test_stream_text_finish_reason(): 
+async def test_stream_text_finish_reason(allow_model_requests: None): stream = text_chunk('hello '), text_chunk('world'), text_chunk('.', finish_reason='stop') mock_client = MockOpenAI.create_mock_stream(stream) m = OpenAIModel('gpt-4', openai_client=mock_client) @@ -358,7 +358,7 @@ class MyTypedDict(TypedDict, total=False): second: str -async def test_stream_structured(): +async def test_stream_structured(allow_model_requests: None): stream = ( chunk([ChoiceDelta()]), chunk([ChoiceDelta(tool_calls=[])]), @@ -392,7 +392,7 @@ async def test_stream_structured(): assert result.cost().response_tokens == len(stream) -async def test_stream_structured_finish_reason(): +async def test_stream_structured_finish_reason(allow_model_requests: None): stream = ( struc_chunk('final_result', None), struc_chunk(None, '{"first": "One'), @@ -418,7 +418,7 @@ async def test_stream_structured_finish_reason(): assert result.is_complete -async def test_no_content(): +async def test_no_content(allow_model_requests: None): stream = chunk([ChoiceDelta()]), chunk([ChoiceDelta()]) mock_client = MockOpenAI.create_mock_stream(stream) m = OpenAIModel('gpt-4', openai_client=mock_client) diff --git a/tests/test_agent.py b/tests/test_agent.py index 3fd98280ca..aac21b9b0e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6,7 +6,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, CallContext, ModelRetry, UnexpectedModelBehaviour +from pydantic_ai import Agent, CallContext, ModelRetry, UnexpectedModelBehaviour, UserError from pydantic_ai.messages import ( ArgsJson, ArgsObject, @@ -23,7 +23,7 @@ from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.result import Cost, RunResult -from tests.conftest import IsNow +from tests.conftest import IsNow, TestEnv def test_result_tuple(): @@ -492,3 +492,28 @@ def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse: 
ModelTextResponse(content='success', timestamp=IsNow(tz=timezone.utc)), ] ) + + +def test_model_requests_blocked(env: TestEnv): + env.set('GEMINI_API_KEY', 'foobar') + agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True) + + with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): + agent.run_sync('Hello') + + +def test_override_model(env: TestEnv): + env.set('GEMINI_API_KEY', 'foobar') + agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True) + + with agent.override_model('test'): + result = agent.run_sync('Hello') + assert result.data == snapshot((0, 'a')) + + +def test_override_model_no_model(): + agent = Agent() + + with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override_model\(\)` is customizing'): + with agent.override_model('test'): + agent.run_sync('Hello') diff --git a/tests/test_live.py b/tests/test_live.py index 6f00ff9d21..836defce2c 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -4,6 +4,7 @@ """ import os +from collections.abc import AsyncIterator import httpx import pytest @@ -20,7 +21,7 @@ @pytest.fixture -async def http_client(): +async def http_client(allow_model_requests: None) -> AsyncIterator[httpx.AsyncClient]: async with httpx.AsyncClient(timeout=30) as client: yield client