Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/test-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ jobs:
id: tests
run: hatch test tests --cover
continue-on-error: false

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
lint:
name: Lint
runs-on: ubuntu-latest
Expand Down
152 changes: 148 additions & 4 deletions src/strands/tools/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import os
import sys
import warnings
from importlib.machinery import ModuleSpec
from pathlib import Path
from posixpath import expanduser
from types import ModuleType
from typing import List, cast

from ..types.tools import AgentTool
Expand All @@ -15,16 +18,151 @@
logger = logging.getLogger(__name__)


def load_tool_from_string(tool_string: str) -> List[AgentTool]:
"""Load tools follows strands supported input string formats.

This function can load a tool based on a string in the following ways:
1. Local file path to a module based tool: `./path/to/module/tool.py`
2. Module import path
2.1. Path to a module based tool: `strands_tools.file_read`
2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool`
2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say`
"""
# Case 1: Local file path to a tool
# Ex: ./path/to/my_cool_tool.py
tool_path = expanduser(tool_string)
if os.path.exists(tool_path):
return load_tools_from_file_path(tool_path)

# Case 2: Module import path
# Ex: test.fixtures.say_tool:say (Load specific @tool decorated function)
# Ex: strands_tools.file_read (Load all @tool decorated functions, or module tool)
return load_tools_from_module_path(tool_string)


def load_tools_from_file_path(tool_path: str) -> List[AgentTool]:
"""Load module from specified path, and then load tools from that module.

This function attempts to load the passed in path as a python module, and if it succeeds,
then it tries to import strands tool(s) from that module.
"""
abs_path = str(Path(tool_path).resolve())
logger.debug("tool_path=<%s> | loading python tool from path", abs_path)

# Load the module by spec

# Using this to determine the module name
# ./path/to/my_cool_tool.py -> my_cool_tool
module_name = os.path.basename(tool_path).split(".")[0]

# This function imports a module based on its path, and gives it the provided name

spec: ModuleSpec = cast(ModuleSpec, importlib.util.spec_from_file_location(module_name, abs_path))
if not spec:
raise ImportError(f"Could not create spec for {module_name}")
if not spec.loader:
raise ImportError(f"No loader available for {module_name}")

module = importlib.util.module_from_spec(spec)
# Load, or re-load, the module
sys.modules[module_name] = module
# Execute the module to run any top level code
spec.loader.exec_module(module)

return load_tools_from_module(module, module_name)


def load_tools_from_module_path(module_tool_path: str) -> list[AgentTool]:
"""Load strands tool from a module path.

Example module paths:
my.module.path
my.module.path:tool_name
"""
if ":" in module_tool_path:
module_path, tool_func_name = module_tool_path.split(":")
else:
module_path, tool_func_name = (module_tool_path, None)

try:
module = importlib.import_module(module_path)
except ModuleNotFoundError as e:
raise AttributeError(f'Tool string: "{module_tool_path}" is not a valid tool string.') from e

# If a ':' is present in the string, then its a targeted function in a module
if tool_func_name:
if hasattr(module, tool_func_name):
target_tool = getattr(module, tool_func_name)
if isinstance(target_tool, DecoratedFunctionTool):
return [target_tool]

raise AttributeError(f"Tool {tool_func_name} not found in module {module_path}")

# Else, try to import all of the @tool decorated tools, or the module based tool
module_name = module_path.split(".")[-1]
return load_tools_from_module(module, module_name)


def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTool]:
"""Load tools from a module.

First checks if the passed in module has instances of DecoratedToolFunction classes as atributes to the module.
If so, then it returns them as a list of tools. If not, then it attempts to load the module as a module based tool.
"""
logger.debug("tool_name=<%s>, module=<%s> | loading tools from module", module_name, module_name)

# Try and see if any of the attributes in the module are function-based tools decorated with @tool
# This means that there may be more than one tool available in this module, so we load them all

function_tools: List[AgentTool] = []
# Function tools will appear as attributes in the module
for attr_name in dir(module):
attr = getattr(module, attr_name)
# Check if the module attribute is a DecoratedFunctiontool
if isinstance(attr, DecoratedFunctionTool):
logger.debug("tool_name=<%s>, module=<%s> | found function-based tool in module", attr_name, module_name)
function_tools.append(cast(AgentTool, attr))

if function_tools:
return function_tools

# Finally, if no DecoratedFunctionTools are found in the module, fall back
# to module based tools, and search for TOOL_SPEC + function
module_tool_name = module_name
tool_spec = getattr(module, "TOOL_SPEC", None)
if not tool_spec:
raise AttributeError(
f"The module {module_tool_name} is not a valid module for loading tools."
"This module must contain @tool decorated function(s), or must be a module based tool."
)

# If this is a module based tool, the module should have a function with the same name as the module itself
if not hasattr(module, module_tool_name):
raise AttributeError(f"Module-based tool {module_tool_name} missing function {module_tool_name}")

tool_func = getattr(module, module_tool_name)
if not callable(tool_func):
raise TypeError(f"Tool {module_tool_name} function is not callable")

return [PythonAgentTool(module_tool_name, tool_spec, tool_func)]


class ToolLoader:
"""Handles loading of tools from different sources."""

@staticmethod
def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]:
"""Load a Python tool module and return all discovered function-based tools as a list.
"""DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list.

This method always returns a list of AgentTool (possibly length 1). It is the
canonical API for retrieving multiple tools from a single Python file.
"""
warnings.warn(
"ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. "
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
DeprecationWarning,
stacklevel=2,
)
try:
# Support module:function style (e.g. package.module:function)
if not os.path.exists(tool_path) and ":" in tool_path:
Expand Down Expand Up @@ -108,7 +246,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
"""
warnings.warn(
"ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. "
"Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.",
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
DeprecationWarning,
stacklevel=2,
)
Expand All @@ -127,7 +265,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
"""
warnings.warn(
"ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. "
"Use ToolLoader.load_tools(...) which always returns a list of AgentTool.",
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
DeprecationWarning,
stacklevel=2,
)
Expand All @@ -140,7 +278,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:

@classmethod
def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]:
"""Load tools from a file based on its file extension.
"""DEPRECATED: Load tools from a file based on its file extension.

Args:
tool_path: Path to the tool file.
Expand All @@ -154,6 +292,12 @@ def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]:
ValueError: If the tool file has an unsupported extension.
Exception: For other errors during tool loading.
"""
warnings.warn(
"ToolLoader.load_tools is deprecated and will be removed in Strands SDK 2.0. "
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
DeprecationWarning,
stacklevel=2,
)
ext = Path(tool_path).suffix.lower()
abs_path = str(Path(tool_path).resolve())

Expand Down
Loading
Loading