Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
- override_model
- last_run_messages
- system_prompt
- retriever_plain
- retriever
- retriever_plain
- result_validator
204 changes: 182 additions & 22 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations

import asyncio
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, cast, final, overload
Expand All @@ -19,7 +19,7 @@
models,
result,
)
from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
from .result import ResultData

__all__ = ('Agent',)
Expand Down Expand Up @@ -323,29 +323,121 @@ def override_model(self, overriding_model: models.Model | models.KnownModelName)
finally:
self._override_model = override_model_before

@overload
def system_prompt(
self, func: Callable[[CallContext[AgentDeps]], str], /
) -> Callable[[CallContext[AgentDeps]], str]: ...

@overload
def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
self, func: Callable[[CallContext[AgentDeps]], Awaitable[str]], /
) -> Callable[[CallContext[AgentDeps]], Awaitable[str]]: ...

@overload
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...

@overload
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
"""Decorator to register a system prompt function that optionally takes `CallContext` as it's only argument."""
"""Decorator to register a system prompt function.

Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's only argument.
Can decorate a sync or async functions.

Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
the type of the function, see `tests/typed_agent.py` for tests.

Example:
```py
from pydantic_ai import Agent, CallContext

agent = Agent('test', deps_type=str)

@agent.system_prompt
def simple_system_prompt() -> str:
return 'foobar'

@agent.system_prompt
async def async_system_prompt(ctx: CallContext[str]) -> str:
return f'{ctx.deps} is the best'

result = agent.run_sync('foobar', deps='spam')
print(result.data)
#> success (no retriever calls)
```
"""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func

@overload
def result_validator(
self, func: Callable[[CallContext[AgentDeps], ResultData], ResultData], /
) -> Callable[[CallContext[AgentDeps], ResultData], ResultData]: ...

@overload
def result_validator(
self, func: _result.ResultValidatorFunc[AgentDeps, ResultData]
self, func: Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], /
) -> Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...

@overload
def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...

@overload
def result_validator(
self, func: Callable[[ResultData], Awaitable[ResultData]], /
) -> Callable[[ResultData], Awaitable[ResultData]]: ...

def result_validator(
self, func: _result.ResultValidatorFunc[AgentDeps, ResultData], /
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
"""Decorator to register a result validator function."""
"""Decorator to register a result validator function.

Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's first argument.
Can decorate a sync or async functions.

Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
the type of the function, see `tests/typed_agent.py` for tests.

Example:
```py
from pydantic_ai import Agent, CallContext, ModelRetry

agent = Agent('test', deps_type=str)

@agent.result_validator
def result_validator_simple(data: str) -> str:
if 'wrong' in data:
raise ModelRetry('wrong response')
return data

@agent.result_validator
async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
if ctx.deps in data:
raise ModelRetry('wrong response')
return data

result = agent.run_sync('foobar', deps='spam')
print(result.data)
#> success (no retriever calls)
```
"""
self._result_validators.append(_result.ResultValidator(func))
return func

@overload
def retriever(
self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
) -> _r.Retriever[AgentDeps, RetrieverParams]: ...
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...

@overload
def retriever(
self, /, *, retries: int | None = None
) -> Callable[[RetrieverContextFunc[AgentDeps, RetrieverParams]], _r.Retriever[AgentDeps, RetrieverParams]]: ...
) -> Callable[
[RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
]: ...

def retriever(
self,
Expand All @@ -354,49 +446,118 @@ def retriever(
*,
retries: int | None = None,
) -> Any:
"""Decorator to register a retriever function."""
"""Decorator to register a retriever function which takes
[`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.

Can decorate a sync or async functions.

The docstring is inspected to extract both the tool description and description of each parameter,
[learn more](../agents.md#retrievers-tools-and-schema).

We can't add overloads for every possible signature of retriever, since the return type is a recursive union
so the signature of functions decorated with `@agent.retriever` is obscured.

Example:
```py
from pydantic_ai import Agent, CallContext

agent = Agent('test', deps_type=int)

@agent.retriever
def foobar(ctx: CallContext[int], x: int) -> int:
return ctx.deps + x

@agent.retriever(retries=2)
async def spam(ctx: CallContext[str], y: float) -> float:
return ctx.deps + y

result = agent.run_sync('foobar', deps=1)
print(result.data)
#> {"foobar":1,"spam":1.0}
```

Args:
func: The retriever function to register.
retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
which defaults to 1.
""" # noqa: D205
if func is None:

def retriever_decorator(
func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
) -> _r.Retriever[AgentDeps, RetrieverParams]:
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
# noinspection PyTypeChecker
return self._register_retriever(_utils.Either(left=func_), retries)
self._register_retriever(_utils.Either(left=func_), retries)
return func_

return retriever_decorator
else:
# noinspection PyTypeChecker
return self._register_retriever(_utils.Either(left=func), retries)
self._register_retriever(_utils.Either(left=func), retries)
return func

@overload
def retriever_plain(
self, func: RetrieverPlainFunc[RetrieverParams], /
) -> _r.Retriever[AgentDeps, RetrieverParams]: ...
def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...

@overload
def retriever_plain(
self, /, *, retries: int | None = None
) -> Callable[[RetrieverPlainFunc[RetrieverParams]], _r.Retriever[AgentDeps, RetrieverParams]]: ...
) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...

def retriever_plain(
self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
) -> Any:
"""Decorator to register a retriever function."""
"""Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.

Can decorate a sync or async functions.

The docstring is inspected to extract both the tool description and description of each parameter,
[learn more](../agents.md#retrievers-tools-and-schema).

We can't add overloads for every possible signature of retriever, since the return type is a recursive union
so the signature of functions decorated with `@agent.retriever` is obscured.

Example:
```py
from pydantic_ai import Agent, CallContext

agent = Agent('test')

@agent.retriever
def foobar(ctx: CallContext[int]) -> int:
return 123

@agent.retriever(retries=2)
async def spam(ctx: CallContext[str]) -> float:
return 3.14

result = agent.run_sync('foobar', deps=1)
print(result.data)
#> {"foobar":123,"spam":3.14}
```

Args:
func: The retriever function to register.
retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
which defaults to 1.
"""
if func is None:

def retriever_decorator(
func_: RetrieverPlainFunc[RetrieverParams],
) -> _r.Retriever[AgentDeps, RetrieverParams]:
) -> RetrieverPlainFunc[RetrieverParams]:
# noinspection PyTypeChecker
return self._register_retriever(_utils.Either(right=func_), retries)
self._register_retriever(_utils.Either(right=func_), retries)
return func_

return retriever_decorator
else:
return self._register_retriever(_utils.Either(right=func), retries)
self._register_retriever(_utils.Either(right=func), retries)
return func

def _register_retriever(
self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
) -> _r.Retriever[AgentDeps, RetrieverParams]:
) -> None:
"""Private utility to register a retriever function."""
retries_ = retries if retries is not None else self._default_retries
retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_)
Expand All @@ -408,7 +569,6 @@ def _register_retriever(
raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}')

self._retrievers[retriever.name] = retriever
return retriever

async def _get_agent_model(
self, model: models.Model | models.KnownModelName | None
Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai/models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse:
for message in messages:
if isinstance(message, ToolReturn):
output[message.tool_name] = message.content
return ModelTextResponse(content=pydantic_core.to_json(output).decode())
if output:
return ModelTextResponse(content=pydantic_core.to_json(output).decode())
else:
return ModelTextResponse(content='success (no retriever calls)')
else:
return ModelTextResponse(content=response_text.value)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
assert agent._result_schema.allow_text_result is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess]

result = agent.run_sync('Hello')
assert result.data == snapshot('{}')
assert result.data == snapshot('success (no retriever calls)')
assert got_tool_call_name == snapshot(None)

assert m.agent_model_retrievers == snapshot({})
Expand Down
12 changes: 6 additions & 6 deletions tests/test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,26 @@ class MyDeps:


@agent.retriever
async def test_retriever(ctx: CallContext[MyDeps]) -> str:
async def example_retriever(ctx: CallContext[MyDeps]) -> str:
return f'{ctx.deps}'


def test_deps_used():
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}'
assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}'


def test_deps_override():
with agent.override_deps(MyDeps(foo=3, bar=4)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}'
assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}'

with agent.override_deps(MyDeps(foo=5, bar=6)):
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=5, bar=6)"}'
assert result.data == '{"example_retriever":"MyDeps(foo=5, bar=6)"}'

result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=3, bar=4)"}'
assert result.data == '{"example_retriever":"MyDeps(foo=3, bar=4)"}'

result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
assert result.data == '{"test_retriever":"MyDeps(foo=1, bar=2)"}'
assert result.data == '{"example_retriever":"MyDeps(foo=1, bar=2)"}'
2 changes: 2 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,7 @@ async def stream_model_logic(messages: list[Message], info: AgentInfo) -> AsyncI
def mock_infer_model(model: Model | KnownModelName) -> Model:
if isinstance(model, (FunctionModel, TestModel)):
return model
elif model == 'test':
return TestModel()
else:
return FunctionModel(model_logic, stream_function=stream_model_logic)
Loading
Loading