Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
self.type_hints = get_type_hints(func)
self._context_param = context_param

self._validate_signature()

# Parse the docstring with docstring_parser
doc_str = inspect.getdoc(func) or ""
self.doc = docstring_parser.parse(doc_str)
Expand All @@ -111,6 +113,20 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
# Create a Pydantic model for validation
self.input_model = self._create_input_model()

def _validate_signature(self) -> None:
"""Verify that ToolContext is used correctly in the function signature."""
for param in self.signature.parameters.values():
if param.annotation is ToolContext:
if self._context_param is None:
raise ValueError("@tool(context) must be set if passing in ToolContext param")

if param.name != self._context_param:
raise ValueError(
f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'"
)
# Found the parameter, no need to check further
break

def _create_input_model(self) -> Type[BaseModel]:
"""Create a Pydantic model from function signature for input validation.
Expand Down
24 changes: 24 additions & 0 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,3 +1363,27 @@ async def async_generator() -> AsyncGenerator:
]

assert act_results == exp_results


def test_function_tool_metadata_validate_signature_default_context_name_mismatch():
with pytest.raises(ValueError, match=r"param_name=<context> | ToolContext param must be named 'tool_context'"):

@strands.tool(context=True)
def my_tool(context: ToolContext):
pass


def test_function_tool_metadata_validate_signature_custom_context_name_mismatch():
with pytest.raises(ValueError, match=r"param_name=<tool_context> | ToolContext param must be named 'my_context'"):

@strands.tool(context="my_context")
def my_tool(tool_context: ToolContext):
pass


def test_function_tool_metadata_validate_signature_missing_context_config():
with pytest.raises(ValueError, match=r"@tool\(context\) must be set if passing in ToolContext param"):

@strands.tool
def my_tool(tool_context: ToolContext):
pass
Loading