diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 291874dce..e38942b2c 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -66,6 +66,11 @@ jobs: id: tests run: hatch test tests --cover continue-on-error: false + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} lint: name: Lint runs-on: ubuntu-latest diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 5935077db..31e8dc788 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -5,7 +5,10 @@ import os import sys import warnings +from importlib.machinery import ModuleSpec from pathlib import Path +from posixpath import expanduser +from types import ModuleType from typing import List, cast from ..types.tools import AgentTool @@ -15,16 +18,151 @@ logger = logging.getLogger(__name__) +def load_tool_from_string(tool_string: str) -> List[AgentTool]: + """Load tools follows strands supported input string formats. + + This function can load a tool based on a string in the following ways: + 1. Local file path to a module based tool: `./path/to/module/tool.py` + 2. Module import path + 2.1. Path to a module based tool: `strands_tools.file_read` + 2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool` + 2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say` + """ + # Case 1: Local file path to a tool + # Ex: ./path/to/my_cool_tool.py + tool_path = expanduser(tool_string) + if os.path.exists(tool_path): + return load_tools_from_file_path(tool_path) + + # Case 2: Module import path + # Ex: test.fixtures.say_tool:say (Load specific @tool decorated function) + # Ex: strands_tools.file_read (Load all @tool decorated functions, or module tool) + return load_tools_from_module_path(tool_string) + + +def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: + """Load module from specified path, and then load tools from that module. + + This function attempts to load the passed in path as a python module, and if it succeeds, + then it tries to import strands tool(s) from that module. + """ + abs_path = str(Path(tool_path).resolve()) + logger.debug("tool_path=<%s> | loading python tool from path", abs_path) + + # Load the module by spec + + # Using this to determine the module name + # ./path/to/my_cool_tool.py -> my_cool_tool + module_name = os.path.basename(tool_path).split(".")[0] + + # This function imports a module based on its path, and gives it the provided name + + spec: ModuleSpec = cast(ModuleSpec, importlib.util.spec_from_file_location(module_name, abs_path)) + if not spec: + raise ImportError(f"Could not create spec for {module_name}") + if not spec.loader: + raise ImportError(f"No loader available for {module_name}") + + module = importlib.util.module_from_spec(spec) + # Load, or re-load, the module + sys.modules[module_name] = module + # Execute the module to run any top level code + spec.loader.exec_module(module) + + return load_tools_from_module(module, module_name) + + +def load_tools_from_module_path(module_tool_path: str) -> list[AgentTool]: + """Load strands tool from a module path. + + Example module paths: + my.module.path + my.module.path:tool_name + """ + if ":" in module_tool_path: + module_path, tool_func_name = module_tool_path.split(":") + else: + module_path, tool_func_name = (module_tool_path, None) + + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + raise AttributeError(f'Tool string: "{module_tool_path}" is not a valid tool string.') from e + + # If a ':' is present in the string, then its a targeted function in a module + if tool_func_name: + if hasattr(module, tool_func_name): + target_tool = getattr(module, tool_func_name) + if isinstance(target_tool, DecoratedFunctionTool): + return [target_tool] + + raise AttributeError(f"Tool {tool_func_name} not found in module {module_path}") + + # Else, try to import all of the @tool decorated tools, or the module based tool + module_name = module_path.split(".")[-1] + return load_tools_from_module(module, module_name) + + +def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTool]: + """Load tools from a module. + + First checks if the passed in module has instances of DecoratedToolFunction classes as atributes to the module. + If so, then it returns them as a list of tools. If not, then it attempts to load the module as a module based tool. + """ + logger.debug("tool_name=<%s>, module=<%s> | loading tools from module", module_name, module_name) + + # Try and see if any of the attributes in the module are function-based tools decorated with @tool + # This means that there may be more than one tool available in this module, so we load them all + + function_tools: List[AgentTool] = [] + # Function tools will appear as attributes in the module + for attr_name in dir(module): + attr = getattr(module, attr_name) + # Check if the module attribute is a DecoratedFunctiontool + if isinstance(attr, DecoratedFunctionTool): + logger.debug("tool_name=<%s>, module=<%s> | found function-based tool in module", attr_name, module_name) + function_tools.append(cast(AgentTool, attr)) + + if function_tools: + return function_tools + + # Finally, if no DecoratedFunctionTools are found in the module, fall back + # to module based tools, and search for TOOL_SPEC + function + module_tool_name = module_name + tool_spec = getattr(module, "TOOL_SPEC", None) + if not tool_spec: + raise AttributeError( + f"The module {module_tool_name} is not a valid module for loading tools." + "This module must contain @tool decorated function(s), or must be a module based tool." + ) + + # If this is a module based tool, the module should have a function with the same name as the module itself + if not hasattr(module, module_tool_name): + raise AttributeError(f"Module-based tool {module_tool_name} missing function {module_tool_name}") + + tool_func = getattr(module, module_tool_name) + if not callable(tool_func): + raise TypeError(f"Tool {module_tool_name} function is not callable") + + return [PythonAgentTool(module_tool_name, tool_spec, tool_func)] + + class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: - """Load a Python tool module and return all discovered function-based tools as a list. + """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). It is the canonical API for retrieving multiple tools from a single Python file. """ + warnings.warn( + "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", + DeprecationWarning, + stacklevel=2, + ) try: # Support module:function style (e.g. package.module:function) if not os.path.exists(tool_path) and ":" in tool_path: @@ -108,7 +246,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: """ warnings.warn( "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " - "Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.", + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", DeprecationWarning, stacklevel=2, ) @@ -127,7 +265,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: """ warnings.warn( "ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. " - "Use ToolLoader.load_tools(...) which always returns a list of AgentTool.", + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", DeprecationWarning, stacklevel=2, ) @@ -140,7 +278,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: @classmethod def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: - """Load tools from a file based on its file extension. + """DEPRECATED: Load tools from a file based on its file extension. Args: tool_path: Path to the tool file. @@ -154,6 +292,12 @@ def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: ValueError: If the tool file has an unsupported extension. Exception: For other errors during tool loading. """ + warnings.warn( + "ToolLoader.load_tools is deprecated and will be removed in Strands SDK 2.0. " + "Use the `load_tools_from_string` or `load_tools_from_module` methods instead.", + DeprecationWarning, + stacklevel=2, + ) ext = Path(tool_path).suffix.lower() abs_path = str(Path(tool_path).resolve()) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 0660337a2..3631c9dee 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,6 +8,7 @@ import logging import os import sys +import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path @@ -18,6 +19,7 @@ from strands.tools.decorator import DecoratedFunctionTool from ..types.tools import AgentTool, ToolSpec +from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -36,18 +38,23 @@ def __init__(self) -> None: self.tool_config: Optional[Dict[str, Any]] = None def process_tools(self, tools: List[Any]) -> List[str]: - """Process tools list that can contain tool names, paths, imported modules, or functions. + """Process tools list. + + Process list of tools that can contain local file path string, module import path string, + imported modules, @tool decorated functions, or instances of AgentTool. Args: tools: List of tool specifications. Can be: + 1. Local file path to a module based tool: `./path/to/module/tool.py` + 2. Module import path + 2.1. Path to a module based tool: `strands_tools.file_read` + 2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool` + 2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say` + 3. A module for a module based tool + 4. Instances of AgentTool (@tool decorated functions) + 5. Dictionaries with name/path keys (deprecated) - - String tool names (e.g., "calculator") - - File paths (e.g., "/path/to/tool.py") - - Imported Python modules (e.g., a module object) - - Functions decorated with @tool - - Dictionaries with name/path keys - - Instance of an AgentTool Returns: List of tool names that were processed. @@ -55,62 +62,76 @@ def process_tools(self, tools: List[Any]) -> List[str]: tool_names = [] def add_tool(tool: Any) -> None: - # Case 1: String file path - if isinstance(tool, str): - # Extract tool name from path - tool_name = os.path.basename(tool).split(".")[0] - self.load_tool_from_filepath(tool_name=tool_name, tool_path=tool) - tool_names.append(tool_name) - - # Case 2: Dictionary with name and path - elif isinstance(tool, dict) and "name" in tool and "path" in tool: - self.load_tool_from_filepath(tool_name=tool["name"], tool_path=tool["path"]) - tool_names.append(tool["name"]) - - # Case 3: Dictionary with path only - elif isinstance(tool, dict) and "path" in tool: - tool_name = os.path.basename(tool["path"]).split(".")[0] - self.load_tool_from_filepath(tool_name=tool_name, tool_path=tool["path"]) - tool_names.append(tool_name) - - # Case 4: Imported Python module - elif hasattr(tool, "__file__") and inspect.ismodule(tool): - # Get the module file path - module_path = tool.__file__ - # Extract the tool name from the module name - tool_name = tool.__name__.split(".")[-1] - - # Check for TOOL_SPEC in module to validate it's a Strands tool - if hasattr(tool, "TOOL_SPEC") and hasattr(tool, tool_name) and module_path: - self.load_tool_from_filepath(tool_name=tool_name, tool_path=module_path) - tool_names.append(tool_name) + try: + # String based tool + # Can be a file path, a module path, or a module path with a targeted function. Examples: + # './path/to/tool.py' + # 'my.module.tool' + # 'my.module.tool:tool_name' + if isinstance(tool, str): + tools = load_tool_from_string(tool) + for a_tool in tools: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Dictionary with name and path + elif isinstance(tool, dict) and "name" in tool and "path" in tool: + tools = load_tool_from_string(tool["path"]) + + tool_found = False + for a_tool in tools: + if a_tool.tool_name == tool["name"]: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + tool_found = True + + if not tool_found: + raise ValueError(f'Tool "{tool["name"]}" not found in "{tool["path"]}"') + + # Dictionary with path only + elif isinstance(tool, dict) and "path" in tool: + tools = load_tool_from_string(tool["path"]) + + for a_tool in tools: + a_tool.mark_dynamic() + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Imported Python module + elif hasattr(tool, "__file__") and inspect.ismodule(tool): + # Extract the tool name from the module name + module_tool_name = tool.__name__.split(".")[-1] + + tools = load_tools_from_module(tool, module_tool_name) + for a_tool in tools: + self.register_tool(a_tool) + tool_names.append(a_tool.tool_name) + + # Case 5: AgentTools (which also covers @tool) + elif isinstance(tool, AgentTool): + self.register_tool(tool) + tool_names.append(tool.tool_name) + + # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool + elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): + for t in tool: + add_tool(t) else: - function_tools = self._scan_module_for_tools(tool) - for function_tool in function_tools: - self.register_tool(function_tool) - tool_names.append(function_tool.tool_name) - - if not function_tools: - logger.warning("tool_name=<%s>, module_path=<%s> | invalid agent tool", tool_name, module_path) - - # Case 5: AgentTools (which also covers @tool) - elif isinstance(tool, AgentTool): - self.register_tool(tool) - tool_names.append(tool.tool_name) - # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool - elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): - for t in tool: - add_tool(t) - else: - logger.warning("tool=<%s> | unrecognized tool specification", tool) + logger.warning("tool=<%s> | unrecognized tool specification", tool) - for a_tool in tools: - add_tool(a_tool) + except Exception as e: + exception_str = str(e) + logger.exception("tool_name=<%s> | failed to load tool", tool) + raise ValueError(f"Failed to load tool {tool}: {exception_str}") from e + for tool in tools: + add_tool(tool) return tool_names def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: - """Load a tool from a file path. + """DEPRECATED: Load a tool from a file path. Args: tool_name: Name of the tool. @@ -120,6 +141,13 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: FileNotFoundError: If the tool file is not found. ValueError: If the tool cannot be loaded. """ + warnings.warn( + "load_tool_from_filepath is deprecated and will be removed in Strands SDK 2.0. " + "`process_tools` automatically handles loading tools from a filepath.", + DeprecationWarning, + stacklevel=2, + ) + from .loader import ToolLoader try: diff --git a/tests/fixtures/say_tool.py b/tests/fixtures/say_tool.py new file mode 100644 index 000000000..4607b2501 --- /dev/null +++ b/tests/fixtures/say_tool.py @@ -0,0 +1,17 @@ +from strands import tool + + +@tool +def say(input: str) -> str: + """Say something.""" + return f"Hello {input}!" + + +@tool +def dont_say(input: str) -> str: + """Dont say something.""" + return "Didnt say anything!" + + +def not_a_tool() -> str: + return "Not a tool!" diff --git a/tests/fixtures/tool_with_spec_but_no_function.py b/tests/fixtures/tool_with_spec_but_no_function.py new file mode 100644 index 000000000..75f8bf6f6 --- /dev/null +++ b/tests/fixtures/tool_with_spec_but_no_function.py @@ -0,0 +1 @@ +TOOL_SPEC = {"hello": "world!"} diff --git a/tests/fixtures/tool_with_spec_but_non_callable_function.py b/tests/fixtures/tool_with_spec_but_non_callable_function.py new file mode 100644 index 000000000..0ca2f092c --- /dev/null +++ b/tests/fixtures/tool_with_spec_but_non_callable_function.py @@ -0,0 +1,3 @@ +TOOL_SPEC = {"hello": "world"} + +tool_with_spec_but_non_callable_function = "not a function!" diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 6b86d00ee..13aca90c3 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -1,11 +1,12 @@ import os import re +import tempfile import textwrap import pytest from strands.tools.decorator import DecoratedFunctionTool -from strands.tools.loader import ToolLoader +from strands.tools.loader import ToolLoader, load_tools_from_file_path from strands.tools.tools import PythonAgentTool @@ -310,3 +311,9 @@ def test_load_tool_path_returns_single_tool(tool_path): assert loaded_python_tool.tool_name == "alpha" assert loaded_tool.tool_name == "alpha" + + +def test_load_tools_from_file_path_module_spec_missing(): + with tempfile.NamedTemporaryFile() as f: + with pytest.raises(ImportError, match=f"Could not create spec for {os.path.basename(f.name)}"): + load_tools_from_file_path(f.name) diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index f0759ea07..ee0098adc 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -26,7 +26,10 @@ def test_process_tools_with_invalid_path(): tool_registry = ToolRegistry() invalid_path = "not a filepath" - with pytest.raises(ValueError, match=f"Failed to load tool {invalid_path.split('.')[0]}: Tool file not found:.*"): + with pytest.raises( + ValueError, + match=f'Failed to load tool {invalid_path}: Tool string: "{invalid_path}" is not a valid tool string', + ): tool_registry.process_tools([invalid_path]) @@ -164,3 +167,96 @@ def test_register_tool_duplicate_name_with_hot_reload(): # Verify the second tool replaced the first assert tool_registry.registry["hot_reload_tool"] == tool_2 + + +def test_register_strands_tools_from_module(): + tool_registry = ToolRegistry() + tool_registry.process_tools(["tests.fixtures.say_tool"]) + + assert len(tool_registry.registry) == 2 + assert "say" in tool_registry.registry + assert "dont_say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_from_module(): + tool_registry = ToolRegistry() + tool_registry.process_tools(["tests.fixtures.say_tool:say"]) + + assert len(tool_registry.registry) == 1 + assert "say" in tool_registry.registry + assert "dont_say" not in tool_registry.registry + + +def test_register_strands_tools_specific_tool_from_module_tool_missing(): + tool_registry = ToolRegistry() + + with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:nay: "): + tool_registry.process_tools(["tests.fixtures.say_tool:nay"]) + + +def test_register_strands_tools_specific_tool_from_module_not_a_tool(): + tool_registry = ToolRegistry() + + with pytest.raises(ValueError, match="Failed to load tool tests.fixtures.say_tool:not_a_tool: "): + tool_registry.process_tools(["tests.fixtures.say_tool:not_a_tool"]) + + +def test_register_strands_tools_with_dict(): + tool_registry = ToolRegistry() + tool_registry.process_tools([{"path": "tests.fixtures.say_tool"}]) + + assert len(tool_registry.registry) == 2 + assert "say" in tool_registry.registry + assert "dont_say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_with_dict(): + tool_registry = ToolRegistry() + tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "say"}]) + + assert len(tool_registry.registry) == 1 + assert "say" in tool_registry.registry + + +def test_register_strands_tools_specific_tool_with_dict_not_found(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool {'path': 'tests.fixtures.say_tool'" + ", 'name': 'nay'}: Tool \"nay\" not found in \"tests.fixtures.say_tool\"", + ): + tool_registry.process_tools([{"path": "tests.fixtures.say_tool", "name": "nay"}]) + + +def test_register_strands_tools_module_no_spec(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.mocked_model_provider: " + "The module mocked_model_provider is not a valid module", + ): + tool_registry.process_tools(["tests.fixtures.mocked_model_provider"]) + + +def test_register_strands_tools_module_no_function(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.tool_with_spec_but_no_function: " + "Module-based tool tool_with_spec_but_no_function missing function tool_with_spec_but_no_function", + ): + tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_no_function"]) + + +def test_register_strands_tools_module_non_callable_function(): + tool_registry = ToolRegistry() + + with pytest.raises( + ValueError, + match="Failed to load tool tests.fixtures.tool_with_spec_but_non_callable_function:" + " Tool tool_with_spec_but_non_callable_function function is not callable", + ): + tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_non_callable_function"])