Skip to content

Commit 80d5c07

Browse files
authored
Make it possible to override JSON schema generation for tools (#1108)
1 parent 444f5d0 commit 80d5c07

File tree

7 files changed

+128
-42
lines changed

7 files changed

+128
-42
lines changed

docs/tools.md

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -265,19 +265,18 @@ def print_schema(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse
265265
print(tool.parameters_json_schema)
266266
"""
267267
{
268+
'additionalProperties': False,
268269
'properties': {
269-
'a': {'description': 'apple pie', 'title': 'A', 'type': 'integer'},
270-
'b': {'description': 'banana cake', 'title': 'B', 'type': 'string'},
270+
'a': {'description': 'apple pie', 'type': 'integer'},
271+
'b': {'description': 'banana cake', 'type': 'string'},
271272
'c': {
272273
'additionalProperties': {'items': {'type': 'number'}, 'type': 'array'},
273274
'description': 'carrot smoothie',
274-
'title': 'C',
275275
'type': 'object',
276276
},
277277
},
278278
'required': ['a', 'b', 'c'],
279279
'type': 'object',
280-
'additionalProperties': False,
281280
}
282281
"""
283282
return ModelResponse(parts=[TextPart('foobar')])
@@ -328,9 +327,9 @@ print(test_model.last_model_request_parameters.function_tools)
328327
description='This is a Foobar',
329328
parameters_json_schema={
330329
'properties': {
331-
'x': {'title': 'X', 'type': 'integer'},
332-
'y': {'title': 'Y', 'type': 'string'},
333-
'z': {'default': 3.14, 'title': 'Z', 'type': 'number'},
330+
'x': {'type': 'integer'},
331+
'y': {'type': 'string'},
332+
'z': {'default': 3.14, 'type': 'number'},
334333
},
335334
'required': ['x', 'y'],
336335
'title': 'Foobar',
@@ -432,16 +431,12 @@ print(test_model.last_model_request_parameters.function_tools)
432431
name='greet',
433432
description='',
434433
parameters_json_schema={
434+
'additionalProperties': False,
435435
'properties': {
436-
'name': {
437-
'title': 'Name',
438-
'type': 'string',
439-
'description': 'Name of the human to greet.',
440-
}
436+
'name': {'type': 'string', 'description': 'Name of the human to greet.'}
441437
},
442438
'required': ['name'],
443439
'type': 'object',
444-
'additionalProperties': False,
445440
},
446441
outer_typed_dict_key=None,
447442
)

pydantic_ai_slim/pydantic_ai/_pydantic.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def function_schema( # noqa: C901
4444
takes_ctx: bool,
4545
docstring_format: DocstringFormat,
4646
require_parameter_descriptions: bool,
47+
schema_generator: type[GenerateJsonSchema],
4748
) -> FunctionSchema:
4849
"""Build a Pydantic validator and JSON schema from a tool function.
4950
@@ -52,6 +53,7 @@ def function_schema( # noqa: C901
5253
takes_ctx: Whether the function takes a `RunContext` first argument.
5354
docstring_format: The docstring format to use.
5455
require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
56+
schema_generator: The JSON schema generator class to use.
5557
5658
Returns:
5759
A `FunctionSchema` instance.
@@ -150,14 +152,12 @@ def function_schema( # noqa: C901
150152
)
151153
# PluggableSchemaValidator is api compatible with SchemaValidator
152154
schema_validator = cast(SchemaValidator, schema_validator)
153-
json_schema = GenerateJsonSchema().generate(schema)
155+
json_schema = schema_generator().generate(schema)
154156

155157
# workaround for https://github.com/pydantic/pydantic/issues/10785
156-
# if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
158+
# if we build a custom TypedDict schema (matches when `single_arg_name is None`), we manually set
157159
# `additionalProperties` in the JSON Schema
158-
if single_arg_name is None:
159-
json_schema['additionalProperties'] = bool(var_kwargs_schema)
160-
elif not description:
160+
if single_arg_name is not None and not description:
161161
# if the tool description is not set, and we have a single parameter, take the description from that
162162
# and set it on the tool
163163
description = json_schema.pop('description', None)
@@ -218,6 +218,7 @@ def _build_schema(
218218
td_schema = core_schema.typed_dict_schema(
219219
fields,
220220
config=core_config,
221+
total=var_kwargs_schema is None,
221222
extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
222223
)
223224
return td_schema, None

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
1010

1111
from opentelemetry.trace import NoOpTracer, use_span
12+
from pydantic.json_schema import GenerateJsonSchema
1213
from typing_extensions import TypeGuard, TypeVar, deprecated
1314

1415
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -31,6 +32,7 @@
3132
from .tools import (
3233
AgentDepsT,
3334
DocstringFormat,
35+
GenerateToolJsonSchema,
3436
RunContext,
3537
Tool,
3638
ToolFuncContext,
@@ -936,6 +938,7 @@ def tool(
936938
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
937939
docstring_format: DocstringFormat = 'auto',
938940
require_parameter_descriptions: bool = False,
941+
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
939942
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
940943

941944
def tool(
@@ -948,6 +951,7 @@ def tool(
948951
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
949952
docstring_format: DocstringFormat = 'auto',
950953
require_parameter_descriptions: bool = False,
954+
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
951955
) -> Any:
952956
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
953957
@@ -989,6 +993,7 @@ async def spam(ctx: RunContext[str], y: float) -> float:
989993
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
990994
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
991995
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
996+
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
992997
"""
993998
if func is None:
994999

@@ -997,15 +1002,22 @@ def tool_decorator(
9971002
) -> ToolFuncContext[AgentDepsT, ToolParams]:
9981003
# noinspection PyTypeChecker
9991004
self._register_function(
1000-
func_, True, name, retries, prepare, docstring_format, require_parameter_descriptions
1005+
func_,
1006+
True,
1007+
name,
1008+
retries,
1009+
prepare,
1010+
docstring_format,
1011+
require_parameter_descriptions,
1012+
schema_generator,
10011013
)
10021014
return func_
10031015

10041016
return tool_decorator
10051017
else:
10061018
# noinspection PyTypeChecker
10071019
self._register_function(
1008-
func, True, name, retries, prepare, docstring_format, require_parameter_descriptions
1020+
func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
10091021
)
10101022
return func
10111023

@@ -1022,6 +1034,7 @@ def tool_plain(
10221034
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
10231035
docstring_format: DocstringFormat = 'auto',
10241036
require_parameter_descriptions: bool = False,
1037+
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
10251038
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
10261039

10271040
def tool_plain(
@@ -1034,6 +1047,7 @@ def tool_plain(
10341047
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
10351048
docstring_format: DocstringFormat = 'auto',
10361049
require_parameter_descriptions: bool = False,
1050+
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
10371051
) -> Any:
10381052
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
10391053
@@ -1075,20 +1089,28 @@ async def spam(ctx: RunContext[str]) -> float:
10751089
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
10761090
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
10771091
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
1092+
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
10781093
"""
10791094
if func is None:
10801095

10811096
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
10821097
# noinspection PyTypeChecker
10831098
self._register_function(
1084-
func_, False, name, retries, prepare, docstring_format, require_parameter_descriptions
1099+
func_,
1100+
False,
1101+
name,
1102+
retries,
1103+
prepare,
1104+
docstring_format,
1105+
require_parameter_descriptions,
1106+
schema_generator,
10851107
)
10861108
return func_
10871109

10881110
return tool_decorator
10891111
else:
10901112
self._register_function(
1091-
func, False, name, retries, prepare, docstring_format, require_parameter_descriptions
1113+
func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
10921114
)
10931115
return func
10941116

@@ -1101,6 +1123,7 @@ def _register_function(
11011123
prepare: ToolPrepareFunc[AgentDepsT] | None,
11021124
docstring_format: DocstringFormat,
11031125
require_parameter_descriptions: bool,
1126+
schema_generator: type[GenerateJsonSchema],
11041127
) -> None:
11051128
"""Private utility to register a function as a tool."""
11061129
retries_ = retries if retries is not None else self._default_retries
@@ -1112,6 +1135,7 @@ def _register_function(
11121135
prepare=prepare,
11131136
docstring_format=docstring_format,
11141137
require_parameter_descriptions=require_parameter_descriptions,
1138+
schema_generator=schema_generator,
11151139
)
11161140
self._register_tool(tool)
11171141

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
88

99
from pydantic import ValidationError
10-
from pydantic_core import SchemaValidator
10+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
11+
from pydantic_core import SchemaValidator, core_schema
1112
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
1213

1314
from . import _pydantic, _utils, messages as _messages, models
@@ -142,6 +143,22 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
142143
A = TypeVar('A')
143144

144145

146+
class GenerateToolJsonSchema(GenerateJsonSchema):
147+
def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
148+
s = super().typed_dict_schema(schema)
149+
total = schema.get('total')
150+
if total is not None:
151+
s['additionalProperties'] = not total
152+
return s
153+
154+
def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
155+
# Remove largely-useless property titles
156+
s = super()._named_required_fields_schema(named_required_fields)
157+
for p in s.get('properties', {}):
158+
s['properties'][p].pop('title', None)
159+
return s
160+
161+
145162
@dataclass(init=False)
146163
class Tool(Generic[AgentDepsT]):
147164
"""A tool function for an agent."""
@@ -176,6 +193,7 @@ def __init__(
176193
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
177194
docstring_format: DocstringFormat = 'auto',
178195
require_parameter_descriptions: bool = False,
196+
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
179197
):
180198
"""Create a new tool instance.
181199
@@ -225,11 +243,14 @@ async def prep_my_tool(
225243
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
226244
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
227245
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
246+
schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
228247
"""
229248
if takes_ctx is None:
230249
takes_ctx = _pydantic.takes_ctx(function)
231250

232-
f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
251+
f = _pydantic.function_schema(
252+
function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator
253+
)
233254
self.function = function
234255
self.takes_ctx = takes_ctx
235256
self.max_retries = max_retries

tests/graph/test_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]:
393393

394394
assert isinstance(n, BaseNode)
395395
n = await run.next()
396-
assert n == snapshot(End(None))
396+
assert n == snapshot(End(data=None))
397397

398398
with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'):
399399
await run.next()

tests/graph/test_persistence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]:
287287
node = Foo()
288288
async with graph.iter(node, persistence=sp) as run:
289289
end = await run.next()
290-
assert end == snapshot(End(123))
290+
assert end == snapshot(End(data=123))
291291

292292
msg = "Incorrect snapshot status 'success', must be 'created' or 'pending'."
293293
with pytest.raises(GraphNodeStatusError, match=msg):

0 commit comments

Comments
 (0)