Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_griffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@ def doc_descriptions(
) -> tuple[str, dict[str, str]]:
"""Extract the function description and parameter descriptions from a function's docstring.

The function parses the docstring using the specified format (or infers it if 'auto')
and extracts both the main description and parameter descriptions. If a returns section
is present in the docstring, the main description will be formatted as XML.

Returns:
A tuple of (main function description, parameter descriptions).
A tuple containing:
- str: Main description string, which may be either:
* Plain text if no returns section is present
* XML-formatted if returns section exists, including <summary> and <returns> tags
- dict[str, str]: Dictionary mapping parameter names to their descriptions
"""
doc = func.__doc__
if doc is None:
Expand All @@ -33,7 +41,14 @@ def doc_descriptions(
parent = cast(GriffeObject, sig)

docstring_style = _infer_docstring_style(doc) if docstring_format == 'auto' else docstring_format
docstring = Docstring(doc, lineno=1, parser=docstring_style, parent=parent)
docstring = Docstring(
doc,
lineno=1,
parser=docstring_style,
parent=parent,
# https://mkdocstrings.github.io/griffe/reference/docstrings/#google-options
parser_options={'returns_named_value': False, 'returns_multiple_items': False},
)
with _disable_griffe_logging():
sections = docstring.parse()

Expand All @@ -45,6 +60,18 @@ def doc_descriptions(
if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None):
main_desc = main.value

if return_ := next((p for p in sections if p.kind == DocstringSectionKind.returns), None):
return_statement = return_.value[0]
return_desc = return_statement.description
return_type = return_statement.annotation
type_tag = f'<type>{return_type}</type>\n' if return_type else ''
return_xml = f'<returns>\n{type_tag}<description>{return_desc}</description>\n</returns>'

if main_desc:
main_desc = f'<summary>{main_desc}</summary>\n{return_xml}'
else:
main_desc = return_xml

return main_desc, params


Expand Down
156 changes: 148 additions & 8 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover
"""Sphinx style docstring.

:param foo: The foo thing.
:return: The result.
"""
return str(foo)

Expand Down Expand Up @@ -187,6 +186,152 @@ def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']):
)


def test_google_style_with_returns():
agent = Agent(FunctionModel(get_json_schema))

def my_tool(x: int) -> str: # pragma: no cover
"""A function that does something.

Args:
x: The input value.

Returns:
str: The result as a string.
"""
return str(x)

agent.tool_plain(my_tool)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
assert json_schema == snapshot(
{
'name': 'my_tool',
'description': """\
<summary>A function that does something.</summary>
<returns>
<type>str</type>
<description>The result as a string.</description>
</returns>\
""",
'parameters_json_schema': {
'additionalProperties': False,
'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
'required': ['x'],
'type': 'object',
},
'outer_typed_dict_key': None,
}
)


def test_sphinx_style_with_returns():
agent = Agent(FunctionModel(get_json_schema))

def my_tool(x: int) -> str: # pragma: no cover
"""A sphinx function with returns.

:param x: The input value.
:rtype: str
:return: The result as a string with type.
"""
return str(x)

agent.tool_plain(docstring_format='sphinx')(my_tool)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
assert json_schema == snapshot(
{
'name': 'my_tool',
'description': """\
<summary>A sphinx function with returns.</summary>
<returns>
<type>str</type>
<description>The result as a string with type.</description>
</returns>\
""",
'parameters_json_schema': {
'additionalProperties': False,
'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
'required': ['x'],
'type': 'object',
},
'outer_typed_dict_key': None,
}
)


def test_numpy_style_with_returns():
agent = Agent(FunctionModel(get_json_schema))

def my_tool(x: int) -> str: # pragma: no cover
"""A numpy function with returns.

Parameters
----------
x : int
The input value.

Returns
-------
str
The result as a string with type.
"""
return str(x)

agent.tool_plain(docstring_format='numpy')(my_tool)
result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
assert json_schema == snapshot(
{
'name': 'my_tool',
'description': """\
<summary>A numpy function with returns.</summary>
<returns>
<type>str</type>
<description>The result as a string with type.</description>
</returns>\
""",
'parameters_json_schema': {
'additionalProperties': False,
'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
'required': ['x'],
'type': 'object',
},
'outer_typed_dict_key': None,
}
)


def only_returns_type() -> str: # pragma: no cover
"""

Returns:
str: The result as a string.
"""
return 'foo'


def test_only_returns_type():
agent = Agent(FunctionModel(get_json_schema))
agent.tool_plain(only_returns_type)

result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
assert json_schema == snapshot(
{
'name': 'only_returns_type',
'description': """\
<returns>
<type>str</type>
<description>The result as a string.</description>
</returns>\
""",
'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
'outer_typed_dict_key': None,
}
)


def unknown_docstring(**kwargs: int) -> str: # pragma: no cover
"""Unknown style docstring."""
return str(kwargs)
Expand Down Expand Up @@ -572,11 +717,7 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int:


async def tool_without_return_annotation_in_docstring() -> str: # pragma: no cover
"""A tool that documents what it returns but doesn't have a return annotation in the docstring.

Returns:
A value.
"""
"""A tool that documents what it returns but doesn't have a return annotation in the docstring."""

return ''

Expand All @@ -591,8 +732,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture):
json_schema = json.loads(result.data)
assert json_schema == snapshot(
{
'description': "A tool that documents what it returns but doesn't have a "
'return annotation in the docstring.',
'description': "A tool that documents what it returns but doesn't have a return annotation in the docstring.",
'name': 'tool_without_return_annotation_in_docstring',
'outer_typed_dict_key': None,
'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
Expand Down