diff --git a/haystack/tools/from_function.py b/haystack/tools/from_function.py index 67fd476207..c185fab47c 100644 --- a/haystack/tools/from_function.py +++ b/haystack/tools/from_function.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import inspect +import re +import textwrap from typing import Any, Callable, Dict, Optional from pydantic import create_model @@ -10,6 +12,9 @@ from haystack.tools.errors import SchemaGenerationError from haystack.tools.tool import Tool +# Define constants for ReST directives +REST_DIRECTIVES = r"param|return|returns|raise|raises" + def create_tool_from_function( function: Callable, name: Optional[str] = None, description: Optional[str] = None @@ -72,7 +77,53 @@ def get_weather( If there is an error generating the JSON schema for the Tool. """ - tool_description = description if description is not None else (function.__doc__ or "") + tool_description = "" + param_descriptions_from_rest: Dict[str, str] = {} + return_description = "" + raises_descriptions = [] + + if description is not None: + tool_description = description + else: + # Process docstring if available + if function.__doc__: + docstring = textwrap.dedent(function.__doc__).strip() + + # Check if this is a ReST-style docstring + if re.search(rf":({REST_DIRECTIVES})\s+", docstring): + # Extract main description (everything before first directive) + main_parts = re.split(rf":({REST_DIRECTIVES})\s+", docstring, 1) + tool_description = main_parts[0].strip() + + # Parse parameter descriptions (handling both :param name: and :param type name: formats) + param_pattern = re.compile(rf":param\s+(\w+)\s*:(.*?)(?=:(?:{REST_DIRECTIVES})|$)", re.DOTALL) + param_descriptions_from_rest = {name: desc.strip() for name, desc in param_pattern.findall(docstring)} + + # Parse return descriptions + return_pattern = re.compile(rf":return:\s*(.*?)(?=:(?:{REST_DIRECTIVES})|$)", re.DOTALL) + return_matches = return_pattern.findall(docstring) + if return_matches: + return_description = return_matches[0].strip() + + # Parse raises descriptions + raises_pattern = re.compile( + rf":raises?\s+(\w+(?:,\s*\w+)*)\s*:\s*(.*?)(?=:(?:{REST_DIRECTIVES})|$)", re.DOTALL + ) + for exc_types, desc in raises_pattern.findall(docstring): + for exc_type in re.split(r",\s*", exc_types): + raises_descriptions.append(f"{exc_type}: {desc.strip()}") + else: + # Not a ReST-style docstring, use the whole thing + tool_description = docstring.strip() + + # Build a comprehensive description including return values and exceptions + full_description = tool_description + + if return_description: + full_description += f"\n\nReturns: {return_description}" + + if raises_descriptions: + full_description += "\n\nRaises:\n" + "\n".join(f"- {r}" for r in raises_descriptions) signature = inspect.signature(function) @@ -89,8 +140,12 @@ def get_weather( default = param.default if param.default is not param.empty else ... fields[param_name] = (param.annotation, default) + # Priority 1: Get descriptions from Annotated type hints if hasattr(param.annotation, "__metadata__"): descriptions[param_name] = param.annotation.__metadata__[0] + # Priority 2: Get descriptions from ReST docstring + elif param_name in param_descriptions_from_rest: + descriptions[param_name] = param_descriptions_from_rest[param_name] # create Pydantic model and generate JSON schema try: @@ -109,7 +164,7 @@ def get_weather( if param_name in schema["properties"]: schema["properties"][param_name]["description"] = param_description - return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) + return Tool(name=name or function.__name__, description=full_description, parameters=schema, function=function) def tool(function: Callable) -> Tool: diff --git a/releasenotes/notes/tool-from-function-with-rest-docstring-fa2728ec718db524.yaml b/releasenotes/notes/tool-from-function-with-rest-docstring-fa2728ec718db524.yaml new file mode 100644 index 0000000000..f9d302d386 --- /dev/null +++ b/releasenotes/notes/tool-from-function-with-rest-docstring-fa2728ec718db524.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Supports ReST-style docstrings for functions when generating a tool from a function. + The ReST-style docstring will be automatically parsed to infer the tool's description and + argument descriptions. diff --git a/test/tools/test_from_function.py b/test/tools/test_from_function.py index 5834b737d4..da75bbe2df 100644 --- a/test/tools/test_from_function.py +++ b/test/tools/test_from_function.py @@ -10,6 +10,19 @@ def function_with_docstring(city: str) -> str: return f"Weather report for {city}: 20°C, sunny" +def function_with_rest_docstring(city: str) -> str: + """ + Get weather report for a city. + + :param city: The city for which to get the weather. + :return: The weather report for the city. + :raises ValueError: If the city is not found. + """ + if city == "": + raise ValueError("City not found.") + return f"Weather report for {city}: 20°C, sunny" + + def test_from_function_description_from_docstring(): tool = create_tool_from_function(function=function_with_docstring) @@ -19,6 +32,23 @@ def test_from_function_description_from_docstring(): assert tool.function == function_with_docstring +def test_from_function_description_from_rest_docstring(): + tool = create_tool_from_function(function=function_with_rest_docstring) + + assert tool.name == "function_with_rest_docstring" + assert tool.description == ( + "Get weather report for a city.\n\n" + "Returns: The weather report for the city.\n\n" + "Raises:\n- ValueError: If the city is not found." + ) + assert tool.parameters == { + "type": "object", + "properties": {"city": {"type": "string", "description": "The city for which to get the weather."}}, + "required": ["city"], + } + assert tool.function == function_with_rest_docstring + + def test_from_function_with_empty_description(): tool = create_tool_from_function(function=function_with_docstring, description="")