From 652496a0d5b431719e59ee4c7c36175e7f731754 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 11 Nov 2024 18:38:31 -0700 Subject: [PATCH 1/2] Add dependency injection support --- pydantic_ai/__init__.py | 13 +- pydantic_ai/_depends/__init__.py | 15 ++ pydantic_ai/_depends/build.py | 118 +++++++++ pydantic_ai/_depends/compat.py | 13 + pydantic_ai/_depends/depends.py | 207 +++++++++++++++ pydantic_ai/_depends/model.py | 0 pydantic_ai/_depends/models.py | 427 +++++++++++++++++++++++++++++++ pydantic_ai/_depends/provider.py | 40 +++ pydantic_ai/_depends/utils.py | 179 +++++++++++++ pydantic_ai/_system_prompt.py | 31 ++- pydantic_ai/agent.py | 130 +++++++++- pydantic_ai/call_context.py | 26 ++ 12 files changed, 1170 insertions(+), 29 deletions(-) create mode 100644 pydantic_ai/_depends/__init__.py create mode 100644 pydantic_ai/_depends/build.py create mode 100644 pydantic_ai/_depends/compat.py create mode 100644 pydantic_ai/_depends/depends.py create mode 100644 pydantic_ai/_depends/model.py create mode 100644 pydantic_ai/_depends/models.py create mode 100644 pydantic_ai/_depends/provider.py create mode 100644 pydantic_ai/_depends/utils.py create mode 100644 pydantic_ai/call_context.py diff --git a/pydantic_ai/__init__.py b/pydantic_ai/__init__.py index 132b11da08..964f93a228 100644 --- a/pydantic_ai/__init__.py +++ b/pydantic_ai/__init__.py @@ -1,8 +1,19 @@ from importlib.metadata import version +from ._depends import Depends, DependsType, inject from .agent import Agent from .dependencies import CallContext from .exceptions import ModelRetry, UnexpectedModelBehaviour, UserError -__all__ = 'Agent', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__' +__all__ = ( + 'Agent', + 'CallContext', + 'ModelRetry', + 'UnexpectedModelBehaviour', + 'UserError', + '__version__', + 'Depends', + 'DependsType', + 'inject', +) __version__ = version('pydantic_ai') diff --git a/pydantic_ai/_depends/__init__.py b/pydantic_ai/_depends/__init__.py new file mode 100644 index 0000000000..43b734225f --- /dev/null +++ b/pydantic_ai/_depends/__init__.py @@ -0,0 +1,15 @@ +"""Draws heavily from the fast_depends library, but with some important differences in behavior. + +In particular: +* No pydantic validation is performed on inputs to/outputs from function calls +* No support for extra_dependencies +* No support for custom field types +* You can call injected functions and pass values for arguments that would have been injected. + When this happens, the dependency function is not called and the passed value is used instead. + In fast_depends, the dependency function is always called and provided arguments are ignored. 
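+
+For example, a minimal sketch of that last difference (`get_db` and `handler` are
+hypothetical, not part of this package):
+
+    from pydantic_ai._depends import Depends, inject
+
+    def get_db() -> str:
+        return 'db-connection'
+
+    @inject
+    def handler(query: str, db: str = Depends(get_db)) -> str:
+        return f'{db}: {query}'
+
+    handler('select 1')          # `db` is resolved by calling get_db()
+    handler('select 1', db='x')  # the explicit value wins; get_db() is never called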
+""" + +from .depends import Depends, inject +from .models import Depends as DependsType + +__all__ = ('Depends', 'inject', 'DependsType') diff --git a/pydantic_ai/_depends/build.py b/pydantic_ai/_depends/build.py new file mode 100644 index 0000000000..8fa3eb5481 --- /dev/null +++ b/pydantic_ai/_depends/build.py @@ -0,0 +1,118 @@ +import inspect +from collections.abc import Awaitable +from typing import ( + Annotated, + Any, + Callable, + TypeVar, + Union, +) + +from typing_extensions import ( + ParamSpec, + get_args, + get_origin, +) + +from .models import CallModel, Depends as DependsType +from .utils import ( + get_evaluated_signature, + is_async_gen_callable, + is_coroutine_callable, + is_gen_callable, +) + +P = ParamSpec('P') +T = TypeVar('T') + + +def build_call_model( # noqa C901 + call: Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + *, + use_cache: bool = True, + is_sync: bool | None = None, +) -> CallModel[P, T]: + name = getattr(call, '__name__', type(call).__name__) + + is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call) + if is_sync is None: + is_sync = not is_call_async + else: + assert not (is_sync and is_call_async), f'You cannot use async dependency `{name}` at sync main' + is_call_generator = is_gen_callable(call) + + signature = get_evaluated_signature(call) + + class_fields: dict[str, tuple[Any, Any]] = {} + dependencies: dict[str, CallModel[..., Any]] = {} + positional_args: list[str] = [] + keyword_args: list[str] = [] + var_positional_arg: str | None = None + var_keyword_arg: str | None = None + + for param_name, param in signature.parameters.items(): + dep: DependsType | None = None + + if param.annotation is inspect.Parameter.empty: + annotation = Any + else: + annotation = param.annotation + + if get_origin(param.annotation) is Annotated: + annotated_args = get_args(param.annotation) + for arg in annotated_args[1:]: + if isinstance(arg, DependsType): + if dep: + raise ValueError(f'Cannot specify multiple `Depends` arguments for `{param_name}`!') + dep = arg + + default: Any + if param.kind == inspect.Parameter.VAR_POSITIONAL: + default = () + var_positional_arg = param_name + elif param.kind == inspect.Parameter.VAR_KEYWORD: + default = {} + var_keyword_arg = param_name + elif param.default is inspect.Parameter.empty: + default = inspect.Parameter.empty + else: + default = param.default + + if isinstance(default, DependsType): + if dep: + raise ValueError(f'Cannot use `Depends` with `Annotated` and a default value for `{param_name}`!') + dep, default = default, inspect.Parameter.empty + + else: + class_fields[param_name] = (annotation, default) + + if dep: + dependencies[param_name] = build_call_model( + dep.dependency, + use_cache=dep.use_cache, + is_sync=is_sync, + ) + + keyword_args.append(param_name) + + else: + if param.kind is param.KEYWORD_ONLY: + keyword_args.append(param_name) + elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + positional_args.append(param_name) + + return CallModel( + call=call, + params=class_fields, + use_cache=use_cache, + is_async=is_call_async, + is_generator=is_call_generator, + dependencies=dependencies, + positional_args=positional_args, + keyword_args=keyword_args, + var_positional_arg=var_positional_arg, + var_keyword_arg=var_keyword_arg, + ) diff --git a/pydantic_ai/_depends/compat.py b/pydantic_ai/_depends/compat.py new file mode 100644 index 0000000000..a92476b365 --- /dev/null +++ b/pydantic_ai/_depends/compat.py @@ -0,0 +1,13 @@ +import sys +from 
importlib.metadata import version as get_version + +__all__ = ('ExceptionGroup',) +ANYIO_V3 = get_version('anyio').startswith('3.') + +if ANYIO_V3: + from anyio import ExceptionGroup as ExceptionGroup # type: ignore +else: + if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup as ExceptionGroup + else: + ExceptionGroup = ExceptionGroup diff --git a/pydantic_ai/_depends/depends.py b/pydantic_ai/_depends/depends.py new file mode 100644 index 0000000000..3a448ed822 --- /dev/null +++ b/pydantic_ai/_depends/depends.py @@ -0,0 +1,207 @@ +from collections.abc import AsyncIterator, Iterator +from contextlib import AsyncExitStack, ExitStack +from functools import partial, wraps +from typing import ( + Any, + Callable, + TypeVar, + cast, + overload, +) + +from typing_extensions import ParamSpec + +from .build import build_call_model +from .models import CallModel, Depends as DependsType +from .provider import Provider, dependency_provider + +P = ParamSpec('P') +T = TypeVar('T') + + +def Depends( + dependency: Callable[P, T], + *, + use_cache: bool = True, +) -> T: + result = DependsType(dependency=dependency, use_cache=use_cache) + # We lie about the return type here to get better type-checking + return result # type: ignore + + +@overload +def inject( + *, + dependency_overrides_provider: Provider | None = dependency_provider, +) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + +@overload +def inject( + func: Callable[P, T], +) -> Callable[P, T]: ... + + +def inject( + func: Callable[P, T] | None = None, + dependency_overrides_provider: Provider | None = dependency_provider, +) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[P, T]: + if func is None: + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + return _inject_decorator(func, dependency_overrides_provider) + + return decorator + + return _inject_decorator(func, dependency_overrides_provider) + + +def _inject_decorator( + func: Callable[P, T], dependency_overrides_provider: Provider | None = dependency_provider +) -> Callable[P, T]: + overrides: dict[Callable[..., Any], Callable[..., Any]] | None = ( + dependency_overrides_provider.dependency_overrides if dependency_overrides_provider else None + ) + + def func_wrapper(func: Callable[P, T]) -> Callable[P, T]: + call_model = build_call_model(call=func) + + if call_model.is_async: + if call_model.is_generator: + return partial(solve_async_gen, call_model, overrides) # type: ignore[assignment] + + else: + + @wraps(func) + async def async_injected_wrapper(*args: P.args, **kwargs: P.kwargs): + async with AsyncExitStack() as stack: + r = await call_model.asolve( + args=args, + kwargs=kwargs, + stack=stack, + dependency_overrides=overrides, + cache_dependencies={}, + nested=False, + ) + return r + raise AssertionError('unreachable') + + return async_injected_wrapper # type: ignore # + + else: + if call_model.is_generator: + return partial(solve_gen, call_model, overrides) # type: ignore[assignment] + + else: + + @wraps(func) + def sync_injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with ExitStack() as stack: + r = call_model.solve( + args=args, + kwargs=kwargs, + stack=stack, + dependency_overrides=overrides, + cache_dependencies={}, + nested=False, + ) + return r + raise AssertionError('unreachable') + + return sync_injected_wrapper + + return func_wrapper(func) + + +class solve_async_gen: + _iter: AsyncIterator[Any] | None = None + + def __init__( + self, + model: 'CallModel[..., Any]', + overrides: dict[Any, Any] | None, + *args: Any, + 
**kwargs: Any,
+    ):
+        self.call = model
+        self.args = args
+        self.kwargs = kwargs
+        self.overrides = overrides
+
+    def __aiter__(self) -> 'solve_async_gen':
+        self._iter = None
+        self.stack = AsyncExitStack()
+        return self
+
+    async def __anext__(self) -> Any:
+        if self._iter is None:
+            stack = self.stack = AsyncExitStack()
+            await self.stack.__aenter__()
+            self._iter = cast(
+                AsyncIterator[Any],
+                (
+                    await self.call.asolve(
+                        args=self.args,
+                        kwargs=self.kwargs,
+                        stack=stack,
+                        dependency_overrides=self.overrides,
+                        cache_dependencies={},
+                        nested=False,
+                    )
+                ).__aiter__(),
+            )
+
+        try:
+            r = await self._iter.__anext__()
+        except StopAsyncIteration as e:
+            await self.stack.__aexit__(None, None, None)
+            raise e
+        else:
+            return r
+
+
+class solve_gen:
+    _iter: Iterator[Any] | None = None
+
+    def __init__(
+        self,
+        model: 'CallModel[..., Any]',
+        overrides: dict[Any, Any] | None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        self.call = model
+        self.args = args
+        self.kwargs = kwargs
+        self.overrides = overrides
+
+    def __iter__(self) -> 'solve_gen':
+        self._iter = None
+        self.stack = ExitStack()
+        return self
+
+    def __next__(self) -> Any:
+        if self._iter is None:
+            stack = self.stack = ExitStack()
+            self.stack.__enter__()
+            self._iter = cast(
+                Iterator[Any],
+                iter(
+                    self.call.solve(
+                        args=self.args,
+                        kwargs=self.kwargs,
+                        stack=stack,
+                        dependency_overrides=self.overrides,
+                        cache_dependencies={},
+                        nested=False,
+                    )
+                ),
+            )
+
+        try:
+            r = next(self._iter)
+        except StopIteration as e:
+            self.stack.__exit__(None, None, None)
+            raise e
+        else:
+            return r
diff --git a/pydantic_ai/_depends/model.py b/pydantic_ai/_depends/model.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pydantic_ai/_depends/models.py b/pydantic_ai/_depends/models.py
new file mode 100644
index 0000000000..860942692e
--- /dev/null
+++ b/pydantic_ai/_depends/models.py
@@ -0,0 +1,427 @@
+from collections.abc import Awaitable, Generator, Iterable, Sequence
+from contextlib import AsyncExitStack, ExitStack
+from dataclasses import dataclass
+from inspect import Parameter, unwrap
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    TypeVar,
+    Union,
+)
+
+from typing_extensions import ParamSpec
+
+from .utils import (
+    is_async_gen_callable,
+    is_coroutine_callable,
+    is_gen_callable,
+    run_async,
+    solve_generator_async,
+    solve_generator_sync,
+)
+
+P = ParamSpec('P')
+T = TypeVar('T')
+
+
+@dataclass
+class Depends:
+    dependency: Callable[..., Any]
+    use_cache: bool = True
+
+
+class CallModel(Generic[P, T]):
+    call: Union[
+        Callable[P, T],
+        Callable[P, Awaitable[T]],
+    ]
+    is_async: bool
+    is_generator: bool
+
+    params: dict[str, tuple[Any, Any]]
+
+    dependencies: dict[str, 'CallModel[..., Any]']
+    sorted_dependencies: tuple[tuple['CallModel[..., Any]', int], ...]
+    keyword_args: tuple[str, ...]
+    positional_args: tuple[str, ...] 
+ var_positional_arg: str | None + var_keyword_arg: str | None + + use_cache: bool + + __slots__ = ( + 'call', + 'is_async', + 'is_generator', + 'params', + 'dependencies', + 'sorted_dependencies', + 'keyword_args', + 'positional_args', + 'var_positional_arg', + 'var_keyword_arg', + 'use_cache', + ) + + @property + def call_name(self) -> str: + call = unwrap(self.call) + return getattr(call, '__name__', type(call).__name__) + + @property + def flat_params(self) -> dict[str, tuple[Any, Any]]: + params = self.params + for d in self.dependencies.values(): + params.update(d.flat_params) + return params + + @property + def flat_dependencies( + self, + ) -> dict[ + Callable[..., Any], + tuple[ + 'CallModel[..., Any]', + tuple[Callable[..., Any], ...], + ], + ]: + flat: dict[ + Callable[..., Any], + tuple[ + CallModel[..., Any], + tuple[Callable[..., Any], ...], + ], + ] = {} + + for i in self.dependencies.values(): + flat.update( + { + i.call: ( + i, + tuple(j.call for j in i.dependencies.values()), + ) + } + ) + + flat.update(i.flat_dependencies) + + return flat + + def __init__( + self, + *, + call: Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + params: dict[str, tuple[Any, Any]], + use_cache: bool = True, + is_async: bool = False, + is_generator: bool = False, + dependencies: dict[str, 'CallModel[..., Any]'] | None = None, + keyword_args: list[str] | None = None, + positional_args: list[str] | None = None, + var_positional_arg: str | None = None, + var_keyword_arg: str | None = None, + ): + self.call = call + + self.keyword_args = tuple(keyword_args or ()) + self.positional_args = tuple(positional_args or ()) + self.var_positional_arg = var_positional_arg + self.var_keyword_arg = var_keyword_arg + self.use_cache = use_cache + self.is_async = is_async or is_coroutine_callable(call) or is_async_gen_callable(call) + self.is_generator = is_generator or is_gen_callable(call) or is_async_gen_callable(call) + + self.dependencies = dependencies or {} + + sorted_dep: list[CallModel[..., Any]] = [] + flat = self.flat_dependencies + for calls in flat.values(): + _sort_dep(sorted_dep, calls, flat) + + for name in self.dependencies.keys(): + params.pop(name, None) + self.params = params + + def _solve( # noqa C901 + self, + *, + args: tuple[Any, ...], + kwargs: dict[str, Any], + cache_dependencies: dict[ + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + T, + ], + dependency_overrides: dict[ + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + ] + | None = None, + ) -> Generator[ + tuple[ + tuple[Any, ...], + dict[str, Any], + Callable[..., Any], + ], + Any, + T, + ]: + if dependency_overrides: + call = dependency_overrides.get(self.call, self.call) + assert self.is_async or not is_coroutine_callable( + call + ), f'You cannot use async dependency `{self.call_name}` at sync main' + + else: + call = self.call + + if self.use_cache and call in cache_dependencies: + return cache_dependencies[call] + + kw: dict[str, Any] = {} + + for arg in self.keyword_args: + if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty: + kw[arg] = v + + if self.var_keyword_arg is not None: + kw[self.var_keyword_arg] = kwargs + else: + kw.update(kwargs) + + positional_arg_index = 0 + for arg in self.positional_args: + if args: + kw[arg], args = args[0], args[1:] + positional_arg_index += 1 + else: + break + + keyword_args: Iterable[str] + if self.var_positional_arg is not None: + kw[self.var_positional_arg] = args + 
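+            # All leftover positional values were just captured into the *args slot,
+            # so only the declared keyword parameters still need defaults or dependencies.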
keyword_args = self.keyword_args + else: + if args: + remaining_args = (self.positional_args + self.keyword_args)[positional_arg_index:] + for name, arg in zip(remaining_args, args): + kw[name] = arg + + keyword_args = self.keyword_args + self.positional_args + for arg in keyword_args: + if arg in self.params: + default = self.params[arg][1] + if default is not Parameter.empty: + kw[arg] = self.params[arg][1] + + if not args: + break + + if arg not in self.dependencies: + kw[arg], args = args[0], args[1:] + + solved_kw: dict[str, Any] + solved_kw = yield args, kw, call + + args_: Sequence[Any] + + kwargs_ = {arg: solved_kw[arg] for arg in keyword_args if arg in solved_kw} + + if self.var_positional_arg is not None: + args_ = tuple(map(solved_kw.get, self.positional_args)) + else: + args_ = () + + response: T + response = yield args_, kwargs_, call + + if self.use_cache: # pragma: no branch + cache_dependencies[call] = response + + return response + + def solve( + self, + *, + args: tuple[Any, ...], + kwargs: dict[str, Any], + stack: ExitStack, + cache_dependencies: dict[ + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + T, + ], + dependency_overrides: dict[ + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + ] + | None = None, + nested: bool = False, + ) -> T: + cast_gen = self._solve( + args=args, + kwargs=kwargs, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + ) + try: + args, kwargs, _ = next(cast_gen) + except StopIteration as e: + cached_value: T = e.value + return cached_value + + for dep_arg, dep in self.dependencies.items(): + if not nested and dep_arg in kwargs: + continue + kwargs[dep_arg] = dep.solve( + args=(), + kwargs=kwargs, + stack=stack, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + nested=True, + ) + + final_args, final_kwargs, call = cast_gen.send(kwargs) + + if self.is_generator and nested: + response = solve_generator_sync( + sub_args=final_args, + sub_values=final_kwargs, + call=call, + stack=stack, + ) + + else: + response = call(*final_args, **final_kwargs) + + try: + cast_gen.send(response) + except StopIteration as e: + value: T = e.value + return value + + raise AssertionError('unreachable') + + async def asolve( + self, + *, + args: tuple[Any, ...], + kwargs: dict[str, Any], + stack: AsyncExitStack, + cache_dependencies: dict[ + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + T, + ], + dependency_overrides: dict[ + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + ] + | None = None, + nested: bool = False, + ) -> T: + cast_gen = self._solve( + args=args, + kwargs=kwargs, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + ) + try: + args, kwargs, _ = next(cast_gen) + except StopIteration as e: + cached_value: T = e.value + return cached_value + + for dep_arg, dep in self.dependencies.items(): + if not nested and dep_arg in kwargs: + continue + kwargs[dep_arg] = await dep.asolve( + args=args, + kwargs=kwargs, + stack=stack, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + nested=True, + ) + + final_args, final_kwargs, call = cast_gen.send(kwargs) + + if self.is_generator and nested: + response = await solve_generator_async( + final_args, + final_kwargs, + call=call, + stack=stack, + ) + else: + response = await run_async(call, *final_args, 
**final_kwargs) + + try: + cast_gen.send(response) + except StopIteration as e: + value: T = e.value + return value + + raise AssertionError('unreachable') + + +def _sort_dep( + collector: list['CallModel[..., Any]'], + items: tuple[ + 'CallModel[..., Any]', + tuple[Callable[..., Any], ...], + ], + flat: dict[ + Callable[..., Any], + tuple[ + 'CallModel[..., Any]', + tuple[Callable[..., Any], ...], + ], + ], +) -> None: + model, calls = items + + if model in collector: + return + + if not calls: + position = -1 + + else: + for i in calls: + sub_model, _ = flat[i] + if sub_model not in collector: # pragma: no branch + _sort_dep(collector, flat[i], flat) + + position = max(collector.index(flat[i][0]) for i in calls) + + collector.insert(position + 1, model) diff --git a/pydantic_ai/_depends/provider.py b/pydantic_ai/_depends/provider.py new file mode 100644 index 0000000000..8d4d466256 --- /dev/null +++ b/pydantic_ai/_depends/provider.py @@ -0,0 +1,40 @@ +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Callable, cast + +_sentinel = object() + + +@dataclass +class Provider: + dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] = field(default_factory=dict) + + def clear(self) -> None: + self.dependency_overrides = {} + + def override( + self, + original: Callable[..., Any], + override: Callable[..., Any], + ) -> None: + self.dependency_overrides[original] = override + + @contextmanager + def scope( + self, + original: Callable[..., Any], + override: Callable[..., Any], + ) -> Iterator[None]: + before_scope = self.dependency_overrides.pop(original, _sentinel) + self.dependency_overrides[original] = override + try: + yield + finally: + if before_scope is _sentinel: + self.dependency_overrides.pop(original, None) + else: + self.dependency_overrides[original] = cast(Callable[..., Any], before_scope) + + +dependency_provider = Provider() # default provider diff --git a/pydantic_ai/_depends/utils.py b/pydantic_ai/_depends/utils.py new file mode 100644 index 0000000000..cc7ade82c0 --- /dev/null +++ b/pydantic_ai/_depends/utils.py @@ -0,0 +1,179 @@ +import asyncio +import functools +import inspect +from collections.abc import AsyncGenerator, AsyncIterable, Awaitable +from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + ForwardRef, + TypeVar, + Union, + cast, +) + +import anyio +import anyio.to_thread +from pydantic._internal._typing_extra import eval_type +from typing_extensions import ( + ParamSpec, + get_args, + get_origin, +) + +if TYPE_CHECKING: + from types import FrameType + +P = ParamSpec('P') +T = TypeVar('T') + + +async def run_async( + func: Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + if is_coroutine_callable(func): + return await cast(Callable[P, Awaitable[T]], func)(*args, **kwargs) + else: + return await run_in_threadpool(cast(Callable[P, T], func), *args, **kwargs) + + +async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + if kwargs: + func = functools.partial(func, **kwargs) + return await anyio.to_thread.run_sync(func, *args) # type: ignore + + +async def solve_generator_async( + sub_args: tuple[Any, ...], sub_values: dict[str, Any], call: Callable[..., Any], stack: AsyncExitStack +) -> Any: + if is_gen_callable(call): + cm = 
contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + elif is_async_gen_callable(call): # pragma: no branch + cm = asynccontextmanager(call)(*sub_args, **sub_values) + else: + raise AssertionError(f'Unknown generator type {call}') + return await stack.enter_async_context(cm) + + +def solve_generator_sync( + sub_args: tuple[Any, ...], sub_values: dict[str, Any], call: Callable[..., Any], stack: ExitStack +) -> Any: + cm = contextmanager(call)(*sub_args, **sub_values) + return stack.enter_context(cm) + + +def get_evaluated_signature(call: Callable[..., Any]) -> inspect.Signature: + signature = inspect.signature(call) + + locals = collect_outer_stack_locals() + + # We unwrap call to get the original unwrapped function + call = inspect.unwrap(call) + + globalns = getattr(call, '__globals__', {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation( + param.annotation, + globalns, + locals, + ), + ) + for param in signature.parameters.values() + ] + typed_return_annotation = get_typed_annotation( + signature.return_annotation, + globalns, + locals, + ) + + return signature.replace(parameters=typed_params, return_annotation=typed_return_annotation) + + +def collect_outer_stack_locals() -> dict[str, Any]: + frame = inspect.currentframe() + + frames: list[FrameType] = [] + while frame is not None: + if 'fast_depends' not in frame.f_code.co_filename: + frames.append(frame) + frame = frame.f_back + + locals: dict[str, Any] = {} + for f in frames[::-1]: + locals.update(f.f_locals) + + return locals + + +def get_typed_annotation( + annotation: Any, + globalns: dict[str, Any], + locals: dict[str, Any], +) -> Any: + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + + if isinstance(annotation, ForwardRef): + annotation = eval_type(annotation, globalns, locals, lenient=True) + + if get_origin(annotation) is Annotated and (args := get_args(annotation)): + solved_args = [get_typed_annotation(x, globalns, locals) for x in args] + annotation.__origin__, annotation.__metadata__ = solved_args[0], tuple(solved_args[1:]) + + return annotation + + +@asynccontextmanager +async def contextmanager_in_threadpool( + cm: AbstractContextManager[T], +) -> AsyncGenerator[T, None]: + exit_limiter = anyio.CapacityLimiter(1) + try: + yield await run_in_threadpool(cm.__enter__) + except Exception as e: + ok = bool(await anyio.to_thread.run_sync(cm.__exit__, type(e), e, None, limiter=exit_limiter)) + if not ok: # pragma: no branch + raise e + else: + await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter) + + +def is_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isgeneratorfunction(call): + return True + dunder_call = getattr(call, '__call__', None) + return inspect.isgeneratorfunction(dunder_call) + + +def is_async_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isasyncgenfunction(call): + return True + dunder_call = getattr(call, '__call__', None) + return inspect.isasyncgenfunction(dunder_call) + + +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + if inspect.isclass(call): + return False + + if asyncio.iscoroutinefunction(call): + return True + + dunder_call = getattr(call, '__call__', None) + return asyncio.iscoroutinefunction(dunder_call) + + +async def async_map(func: Callable[..., T], async_iterable: AsyncIterable[Any]) -> AsyncIterable[T]: + async for i in async_iterable: + yield func(i) diff --git a/pydantic_ai/_system_prompt.py 
b/pydantic_ai/_system_prompt.py
index e44588d722..b0a36693f2 100644
--- a/pydantic_ai/_system_prompt.py
+++ b/pydantic_ai/_system_prompt.py
@@ -3,31 +3,30 @@
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, cast
+from typing import Callable, Generic, cast
 
 from . import _utils
-from .dependencies import AgentDeps, CallContext, SystemPromptFunc
+from .call_context import call_context
+from .dependencies import AgentDeps, CallContext
 
 
 @dataclass
 class SystemPromptRunner(Generic[AgentDeps]):
-    function: SystemPromptFunc[AgentDeps]
-    _takes_ctx: bool = field(init=False)
+    function: Callable[[], str | Awaitable[str]]
     _is_async: bool = field(init=False)
 
     def __post_init__(self):
-        self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
         self._is_async = inspect.iscoroutinefunction(self.function)
 
     async def run(self, deps: AgentDeps) -> str:
-        if self._takes_ctx:
-            args = (CallContext(deps, 0, None),)
-        else:
-            args = ()
-
-        if self._is_async:
-            function = cast(Callable[[Any], Awaitable[str]], self.function)
-            return await function(*args)
-        else:
-            function = cast(Callable[[Any], str], self.function)
-            return await _utils.run_in_executor(function, *args)
+        # TODO: Need to set the agent call context appropriately when running retries, etc.; not sure where that happens
+        with call_context(CallContext(deps, 0, None)):
+            # Thanks to dependency injection, we can assume self.function accepts zero arguments.
+            # If that's wrong, the user will get a pydantic error; we should eventually improve
+            # the error messages, though.
+            if self._is_async:
+                function = cast(Callable[[], Awaitable[str]], self.function)
+                return await function()
+            else:
+                function = cast(Callable[[], str], self.function)
+                return await _utils.run_in_executor(function)
diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index 0c97abe413..48f7a0bdff 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -1,20 +1,47 @@
 from __future__ import annotations as _annotations
 
 import asyncio
-from collections.abc import AsyncIterator, Sequence
+import inspect
+from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, Callable, Generic, Literal, cast, final, overload
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    Concatenate,
+    Generic,
+    Literal,
+    ParamSpec,
+    TypeVar,
+    cast,
+    final,
+    get_args,
+    get_origin,
+    overload,
+)
 
 import logfire_api
 from typing_extensions import assert_never
 
-from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result
-from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
+from . 
import ( + _result, + _retriever as _r, + _system_prompt, + _utils, + exceptions, + messages as _messages, + models, + result, +) +from ._depends import Depends, DependsType, inject +from .call_context import get_call_context +from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc from .result import ResultData __all__ = 'Agent', 'KnownModelName' + KnownModelName = Literal[ 'openai:gpt-4o', 'openai:gpt-4o-mini', @@ -30,6 +57,9 @@ """ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') +SystemPromptReturnT = TypeVar('SystemPromptReturnT', bound=str | Awaitable[str]) +T = TypeVar('T') +P = ParamSpec('P') @final @@ -272,20 +302,65 @@ async def run_stream( # the model_response should have been fully streamed by now, we can add it's cost cost += model_response.cost() - def system_prompt( - self, func: _system_prompt.SystemPromptFunc[AgentDeps] - ) -> _system_prompt.SystemPromptFunc[AgentDeps]: - """Decorator to register a system prompt function that takes `CallContext` as it's only argument.""" - self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func)) + def system_prompt(self, func: Callable[P, SystemPromptReturnT]) -> Callable[P, SystemPromptReturnT]: + """Decorator to register a system prompt function that takes any number of dependency-injected arguments.""" + # TODO: Do runtime validation that all arguments are injected and, if present, the CallContext type matches + # Note: this may require changes to how dependency_injection works + func = inject_for_agent(func) + validated_func = cast(Callable[[], SystemPromptReturnT], func) # note: no validation has been done yet + self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDeps](validated_func)) return func + def system_prompt_checked( + self, func: Callable[Concatenate[CallContext[AgentDeps], P], SystemPromptReturnT] + ) -> Callable[Concatenate[CallContext[AgentDeps], P], SystemPromptReturnT]: + """Decorator to register a system prompt function that takes `CallContext` as its first argument. + + The function may also accept any number of dependency-injected arguments. The upshot of this API is the ability + to type-check agreement of the AgentDeps parameter between the agent and the system prompt function. + """ + return self.system_prompt(func) + + # Note: we need to use overloads here rather than the typevar bound trick used for system_prompt because we would + # need to reference the ResultData parameter in the bound of the typevar, and Python's generics don't support this. + @overload + def result_validator(self, func: Callable[P, ResultData]) -> Callable[P, ResultData]: ... + + @overload + def result_validator(self, func: Callable[P, Awaitable[ResultData]]) -> Callable[P, Awaitable[ResultData]]: ... + + @overload def result_validator( - self, func: _result.ResultValidatorFunc[AgentDeps, ResultData] - ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]: + self, func: Callable[P, ResultData | Awaitable[ResultData]] + ) -> Callable[P, ResultData | Awaitable[ResultData]]: ... + + def result_validator( + self, func: Callable[P, ResultData | Awaitable[ResultData]] + ) -> Callable[P, ResultData | Awaitable[ResultData]]: """Decorator to register a result validator function.""" - self._result_validators.append(_result.ResultValidator(func)) + # TODO: Do runtime validation that all arguments except one are injected and, + # if present, the CallContext type matches. 
+ # Note: this may require changes to how dependency_injection works + validated_func = cast(Callable[[ResultData], ResultData], func) # note: no validation has been done yet + self._result_validators.append(_result.ResultValidator(validated_func)) return func + @overload + def result_validator_checked( + self, func: Callable[Concatenate[CallContext[AgentDeps], P], ResultData] + ) -> Callable[Concatenate[CallContext[AgentDeps], P], ResultData]: ... + + @overload + def result_validator_checked( + self, func: Callable[Concatenate[CallContext[AgentDeps], P], Awaitable[ResultData]] + ) -> Callable[Concatenate[CallContext[AgentDeps], P], Awaitable[ResultData]]: ... + + def result_validator_checked( + self, func: Callable[Concatenate[CallContext[AgentDeps], P], ResultData | Awaitable[ResultData]] + ) -> Callable[Concatenate[CallContext[AgentDeps], P], ResultData | Awaitable[ResultData]]: + """Decorator to register a result validator function.""" + return self.result_validator(func) + @overload def retriever_context( self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], / @@ -546,3 +621,34 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: else: msg = 'No tools available.' return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}') + + +def inject_for_agent(func: Callable[..., T]) -> Callable[..., T]: + signature = inspect.signature(func) + new_params: list[inspect.Parameter] = [] + for param in signature.parameters.values(): + if isinstance(param.default, DependsType): + # Explicitly has an = Depends(...) + new_params.append(param) + continue + + if get_origin(param.annotation) is Annotated: + annotated_args = get_args(param.annotation) + if any(isinstance(arg, DependsType) for arg in annotated_args): + # Annotated with Depends(...) + new_params.append(param) + continue + type_hint = annotated_args[0] + else: + type_hint = param.annotation + + if get_origin(type_hint) is CallContext: + # Set a default depends context + param = param.replace(annotation=Annotated[param.annotation, Depends(get_call_context)]) + + new_params.append(param) + + func.__signature__ = signature.replace(parameters=new_params) # type: ignore + func = inject(func) + + return func diff --git a/pydantic_ai/call_context.py b/pydantic_ai/call_context.py new file mode 100644 index 0000000000..4b01ae13c6 --- /dev/null +++ b/pydantic_ai/call_context.py @@ -0,0 +1,26 @@ +from __future__ import annotations as _annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import ( + Any, +) + +from .dependencies import CallContext + +_CALL_CONTEXT_VAR: ContextVar[CallContext[Any]] = ContextVar('AgentCallContext') + + +def get_call_context() -> CallContext[Any]: + """A helper to get the current agent call context. 
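+
+    The value is set by the `call_context` manager below for the duration of each run,
+    so an injected function can declare a parameter that resolves to the live context,
+    e.g. (a hypothetical dependency, assuming `Depends` is in scope):
+
+        def current_deps(ctx: CallContext[Any] = Depends(get_call_context)) -> Any:
+            return ctx.deps
+
+    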
Can be used with Depends(...).""" + return _CALL_CONTEXT_VAR.get() + + +@contextmanager +def call_context(ctx: CallContext[Any]): + """Temporarily set the agent call context for the duration of the context manager.""" + token = _CALL_CONTEXT_VAR.set(ctx) + try: + yield + finally: + _CALL_CONTEXT_VAR.reset(token) From 2b353d54098572280ca4c5d6b0b375abb5045de6 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 11 Nov 2024 23:02:46 -0700 Subject: [PATCH 2/2] Add an example --- pydantic_ai/__init__.py | 5 +- pydantic_ai/_depends/__init__.py | 3 +- .../dependency_injection_evals.py | 298 ++++++++++++++++++ 3 files changed, 303 insertions(+), 3 deletions(-) create mode 100644 pydantic_ai_examples/dependency_injection_evals.py diff --git a/pydantic_ai/__init__.py b/pydantic_ai/__init__.py index 964f93a228..cf1103b396 100644 --- a/pydantic_ai/__init__.py +++ b/pydantic_ai/__init__.py @@ -1,6 +1,6 @@ from importlib.metadata import version -from ._depends import Depends, DependsType, inject +from ._depends import Depends, DependsType, Provider, dependency_provider from .agent import Agent from .dependencies import CallContext from .exceptions import ModelRetry, UnexpectedModelBehaviour, UserError @@ -14,6 +14,7 @@ '__version__', 'Depends', 'DependsType', - 'inject', + 'Provider', + 'dependency_provider', ) __version__ = version('pydantic_ai') diff --git a/pydantic_ai/_depends/__init__.py b/pydantic_ai/_depends/__init__.py index 43b734225f..77bafbd0df 100644 --- a/pydantic_ai/_depends/__init__.py +++ b/pydantic_ai/_depends/__init__.py @@ -11,5 +11,6 @@ from .depends import Depends, inject from .models import Depends as DependsType +from .provider import Provider, dependency_provider -__all__ = ('Depends', 'inject', 'DependsType') +__all__ = ('Depends', 'inject', 'DependsType', 'Provider', 'dependency_provider') diff --git a/pydantic_ai_examples/dependency_injection_evals.py b/pydantic_ai_examples/dependency_injection_evals.py new file mode 100644 index 0000000000..744afb0fdc --- /dev/null +++ b/pydantic_ai_examples/dependency_injection_evals.py @@ -0,0 +1,298 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from functools import cache +from typing import Literal, get_args + +from pydantic_ai import Agent, Depends, dependency_provider +from pydantic_ai.models.gemini import GeminiModel + +# Define possible tags +ReviewTag = Literal[ + 'good food quality', + 'average food quality', + 'bad food quality', + 'good customer service', + 'bad customer service', + 'long waits', + 'order mix-ups', + 'good atmosphere', + 'other complaint', + 'other compliment', +] + + +# Define the response type +@dataclass +class ReviewAnalysisResponse: + score: int # Between -5 (most negative) and +5 (most positive) + tags: list[ReviewTag] # Extracted tags from the review + + +Response = ReviewAnalysisResponse + +# Create the agent +agent = Agent[None, Response]( + result_type=Response, + retries=2, + model=GeminiModel('gemini-1.5-flash', api_key=''), +) + + +# Define training examples +@dataclass +class TrainingExample: + text: str + score: int + tags: list[ReviewTag] + + +@cache +def get_examples() -> list[TrainingExample]: + """Get training examples to use when building the system prompt. + + Note 1: we could use an external file or database to store examples in a real-world scenario. 
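+    For instance, a sketch backed by sqlite (the `examples` table and `conn` are
+    hypothetical):
+
+        import json, sqlite3
+
+        def get_examples_from_db(conn: sqlite3.Connection) -> list[TrainingExample]:
+            rows = conn.execute('SELECT text, score, tags FROM examples')
+            return [TrainingExample(text=t, score=s, tags=json.loads(tags)) for t, s, tags in rows]
+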
+ In this case, it may be useful to use CallContext and/or dependency injection to obtain resources used to obtain + the examples, especially if the examples or other input we want to provide to the prompt might be user-managed or + otherwise updating in real time. + + Note 2: Returning examples in a function allows us to use dependency injection to easily override the set of + examples used while building the prompt. This allows us to evaluate our agent's performance using cross-validation. + """ + return [ + TrainingExample( + text='The food was amazing and the service was excellent.', + score=5, + tags=['good food quality', 'good customer service'], + ), + TrainingExample( + text='I waited an hour for my food, and it was cold when it arrived.', + score=-4, + tags=['long waits', 'bad food quality'], + ), + TrainingExample( + text='The waiter mixed up our orders and was very rude.', + score=-3, + tags=['order mix-ups', 'bad customer service'], + ), + TrainingExample( + text='Great atmosphere, but the food was just okay.', + score=1, + tags=['good atmosphere', 'average food quality'], + ), + TrainingExample( + text='The food was terrible, but the staff were friendly.', + score=-2, + tags=['bad food quality', 'good customer service'], + ), + TrainingExample( + text='Amazing dishes and wonderful service!', + score=5, + tags=['good food quality', 'good customer service'], + ), + ] + + +@dataclass +class SystemPromptBuilder: + """This class builds the system prompt for the agent. + + We use a class to build the system prompt so that we can easily override different aspects of the prompt building, + making it easier to compare different approaches. While this might feel like overkill in this example, this can + be useful in more complex scenarios where the prompt-building might make use of lots of different context about the + user or other resources. + + We'll see below how we can leverage dependency injection to easily change the behavior of the system prompt builder. + """ + + examples: list[TrainingExample] + + def build_prompt(self) -> str: + return f'{self._get_task_description()}\n\n{self._get_examples_text()}\n\n{self._final_notes()}' + + @staticmethod + def _get_task_description(): + tags_list = get_args(ReviewTag) + return ( + 'You are a restaurant review analysis assistant.\n' + 'For each review, provide a score between -5 (most negative) and +5 (most positive), ' + 'and extract relevant tags from the following list:\n' + f'{tags_list}.' + ) + + def _get_examples_text(self): + return 'Here are some examples:\n' + '\n\n'.join( + [f"Review: '{ex.text}'\nScore: {ex.score}\nLabels: {', '.join(ex.tags)}" for ex in self.examples] + ) + + def _final_notes(self) -> str: + return 'Now, analyze the following review.' + + +def build_prompt(examples: list[TrainingExample] = Depends(get_examples)) -> str: + """Build the system prompt from the examples. + + Obtaining the examples through dependency injection will make it easier to override the examples used + during evaluation. + """ + return SystemPromptBuilder(examples).build_prompt() + + +# Set the system prompt for the agent +@agent.system_prompt +def get_system_prompt(prompt: str = Depends(build_prompt)) -> str: + """Get the system prompt. + + While this function may seem simple enough to be unnecessary, obtaining the prompt through dependency injection + will make it easier to override the logic used to produce the prompt. 
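+    For example, evaluation code can swap the prompt wholesale (a minimal sketch; note
+    that the override keeps `build_prompt`'s signature, so the solved `examples`
+    argument is accepted and simply ignored):
+
+        def fixed_prompt(examples: list[TrainingExample] = Depends(get_examples)) -> str:
+            return 'You are a terse assistant.'
+
+        with dependency_provider.scope(build_prompt, fixed_prompt):
+            ...  # agent runs in this block build their system prompt from fixed_prompt
+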
In particular, it ends up being + straightforward to replace `build_prompt` with a function that uses a different prompt builder. + + That said, it might be nice if we could override this get_system_prompt function more directly. + (It should be straightforward to automatically wrap the @agent.system_prompt decorator in a way + that you can override the function it decorates directly, if we want to go down that route.) + """ + return prompt + + +async def handle_user_request(text: str) -> Response: + """Handle a user request by running the agent on the provided text. + + While this function is simple in this example, it could be more complex in a real-world scenario. + For example, it might make requests to various services to obtain additional context for the agent. + + In our evaluation code below, we make direct use of this function, and this is important, because in + real-world examples it may not always be as practical to make calls with the agent in isolation. + (E.g., if the agent requires a great deal of data for a single run, and that data is easily queried from a database, + but painful to move into a fully-isolated evaluation harness.) + """ + result = await agent.run(text) + return result.data + + +@dataclass +class EvaluationResult: + """A simple container for the results of evaluating the agent on a single example.""" + + text: str + expected_score: int + actual_score: int + expected_tags: list[ReviewTag] + actual_tags: list[ReviewTag] + + +async def get_agent_evaluation_results() -> list[EvaluationResult]: + """Get the evaluation results for the agent. + + This function evaluates the agent on each example in the training set, comparing the agent's output to the + expected output. It returns a list of EvaluationResult objects, which can be used to compute metrics. 
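+    Concretely, this is leave-one-out cross-validation: each example is held out in turn,
+    the prompt is built from the remaining N - 1 examples via a scoped dependency
+    override, and the held-out example is scored. Roughly:
+
+        for i, held_out in enumerate(all_examples):
+            train = all_examples[:i] + all_examples[i + 1:]
+            # the prompt sees only `train`; `held_out.text` is what gets analyzed
+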
+ """ + all_examples = get_examples() + results: list[EvaluationResult] = [] + for i, validation_example in enumerate(all_examples): + # Remove the validation example from the training examples: + training_examples = all_examples.copy() + training_examples.pop(i) + + # Run the agent on the validation example, using only the training examples to build the prompt: + with dependency_provider.scope(get_examples, lambda: training_examples): + actual_response = await handle_user_request(validation_example.text) + + # Store the results for evaluation: + results.append( + EvaluationResult( + text=validation_example.text, + expected_score=validation_example.score, + actual_score=actual_response.score, + expected_tags=validation_example.tags, + actual_tags=actual_response.tags, + ) + ) + return results + + +# Now we translate the evaluation results into quantitative metrics +def f1_score(expected: Sequence[str], actual: Sequence[str]) -> float: + """Compute the F1 score for the tags extracted by the agent.""" + expected_set = set(expected) + actual_set = set(actual) + true_positives = len(expected_set & actual_set) + if not expected_set and not actual_set: + return 1.0 # Perfect match if both are empty + if not actual_set: + return 0.0 + precision = true_positives / len(actual_set) + recall = true_positives / len(expected_set) + if precision + recall == 0: + return 0.0 + return 2 * (precision * recall) / (precision + recall) + + +def compute_metrics(results: list[EvaluationResult]): + """Compute the average score error and average F1 score for the tags extracted by the agent.""" + total_score_error = 0.0 + total_f1 = 0.0 + for result in results: + score_error = abs(result.expected_score - result.actual_score) + total_score_error += score_error + f1 = f1_score(result.expected_tags, result.actual_tags) + total_f1 += f1 + average_score_error = total_score_error / len(results) + average_f1 = total_f1 / len(results) + return average_score_error, average_f1 + + +async def evaluate_agent(): + """Evaluate the agent and print the results.""" + results = await get_agent_evaluation_results() + score_error, f1 = compute_metrics(results) + print('Average Score Error:', score_error) + print('Average F1 Score for Labels:', f1) + + +# Okay, now let's leverage the modular design of our SystemPromptBuilder and our dependency injection system to +# evaluate the impact on our agent of making changes deeper in the system prompt building process: + + +class AlternateSystemPromptBuilder(SystemPromptBuilder): + """Add some additional notes to the system prompt to emphasize the importance of food quality in the reviews.""" + + def _final_notes(self) -> str: + return ( + 'Note: We are primarily concerned with the quality of the food. ' + 'Place the highest emphasis on comments about food quality when determining the score, ' + 'with comparatively lower emphasis on complaints about service, wait times, etc.\n\n' + + super()._final_notes() + ) + + +async def compare_system_prompt_builders(): + """Evaluate the agent using two different system prompt builders. + + We first use the base SystemPromptBuilder, then override the build_prompt dependency to use the + AlternateSystemPromptBuilder and evaluate the agent again to see the impact on the agent's performance. 
+ """ + # Evaluate with the base prompt builder + print('Evaluating base SystemPromptBuilder:') + await evaluate_agent() + + # Override the build_prompt dependency to use the alternate prompt builder + def build_alternate_prompt(examples: list[TrainingExample] = Depends(get_examples)) -> str: + return AlternateSystemPromptBuilder(examples).build_prompt() + + with dependency_provider.scope(build_prompt, build_alternate_prompt): + print('\nEvaluating AlternateSystemPromptBuilder:') + await evaluate_agent() # This will print the results using the AlternateSystemPromptBuilder + + +if __name__ == '__main__': + import anyio + + anyio.run(compare_system_prompt_builders) + """ + Evaluating base SystemPromptBuilder: + Average Score Error: 0.3333333333333333 + Average F1 Score for Labels: 1.0 + + Evaluating AlternateSystemPromptBuilder: + Average Score Error: 0.5 + Average F1 Score for Labels: 1.0 + """