From 3e6d53018ad634cd0751436c5a909bdd8ce7d86c Mon Sep 17 00:00:00 2001 From: Ratish1 Date: Tue, 14 Oct 2025 17:07:56 +0400 Subject: [PATCH 1/3] fix(tool/decorator): validate ToolContext parameter name to avoid opaque Pydantic error --- src/strands/tools/decorator.py | 31 ++++++++++++++++++++ tests/strands/tools/test_decorator.py | 41 +++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 99aa7e372..54596960d 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -54,6 +54,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: TypeVar, Union, cast, + get_args, + get_origin, get_type_hints, overload, ) @@ -103,6 +105,35 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) + def _contains_tool_context(tp: Any) -> bool: + """Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext.""" + if tp is None: + return False + origin = get_origin(tp) + if origin is Union: + return any(_contains_tool_context(a) for a in get_args(tp)) + # Handle direct ToolContext type + return tp is ToolContext + + for param in self.signature.parameters.values(): + # Prefer resolved type hints (handles forward refs); fall back to annotation + ann = self.type_hints.get(param.name, param.annotation) + if ann is inspect._empty: + continue + + if _contains_tool_context(ann): + # If decorator didn't opt-in to context injection, complain + if self._context_param is None: + raise TypeError( + f"Parameter '{param.name}' is of type 'ToolContext' but '@tool(context=True)' is missing." + ) + # If decorator specified a different param name, complain + if param.name != self._context_param: + raise TypeError( + f"Parameter '{param.name}' is of type 'ToolContext' but has the wrong name. " + f"It should be named '{self._context_param}'." + ) + # Get parameter descriptions from parsed docstring self.param_descriptions = { param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 5b4b5cdda..38db783d7 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1363,3 +1363,44 @@ async def async_generator() -> AsyncGenerator: ] assert act_results == exp_results + + +def test_tool_with_mismatched_tool_context_param_name_raises_error(): + """Verify that a TypeError is raised for a mismatched tool_context parameter name.""" + with pytest.raises(TypeError) as excinfo: + + @strands.tool(context=True) + def my_tool(context: ToolContext): + pass + + assert ( + "Parameter 'context' is of type 'ToolContext' but has the wrong name. It should be named 'tool_context'." + in str(excinfo.value) + ) + + +def test_tool_with_tool_context_but_no_context_flag_raises_error(): + """Verify that a TypeError is raised if ToolContext is used without context=True.""" + with pytest.raises(TypeError) as excinfo: + + @strands.tool + def my_tool(tool_context: ToolContext): + pass + + assert "Parameter 'tool_context' is of type 'ToolContext' but '@tool(context=True)' is missing." in str( + excinfo.value + ) + + +def test_tool_with_tool_context_named_custom_context_raises_error_if_mismatched(): + """Verify that a TypeError is raised when context param name doesn't match the decorator value.""" + with pytest.raises(TypeError) as excinfo: + + @strands.tool(context="my_context") + def my_tool(tool_context: ToolContext): + pass + + assert ( + "Parameter 'tool_context' is of type 'ToolContext' but has the wrong name. It should be named 'my_context'." + in str(excinfo.value) + ) From 1fa7640e0f517aa64561474124c3bdd779d61880 Mon Sep 17 00:00:00 2001 From: Ratish1 Date: Tue, 14 Oct 2025 19:10:02 +0400 Subject: [PATCH 2/3] fix(tool/decorator): simplify validation logic --- src/strands/tools/decorator.py | 48 ++++++++++----------------- tests/strands/tools/test_decorator.py | 28 ++++++---------- 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 54596960d..669c60138 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -54,8 +54,6 @@ def my_tool(param1: str, param2: int = 42) -> dict: TypeVar, Union, cast, - get_args, - get_origin, get_type_hints, overload, ) @@ -101,39 +99,12 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - self.type_hints = get_type_hints(func) self._context_param = context_param + self._validate_signature() + # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) - def _contains_tool_context(tp: Any) -> bool: - """Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext.""" - if tp is None: - return False - origin = get_origin(tp) - if origin is Union: - return any(_contains_tool_context(a) for a in get_args(tp)) - # Handle direct ToolContext type - return tp is ToolContext - - for param in self.signature.parameters.values(): - # Prefer resolved type hints (handles forward refs); fall back to annotation - ann = self.type_hints.get(param.name, param.annotation) - if ann is inspect._empty: - continue - - if _contains_tool_context(ann): - # If decorator didn't opt-in to context injection, complain - if self._context_param is None: - raise TypeError( - f"Parameter '{param.name}' is of type 'ToolContext' but '@tool(context=True)' is missing." - ) - # If decorator specified a different param name, complain - if param.name != self._context_param: - raise TypeError( - f"Parameter '{param.name}' is of type 'ToolContext' but has the wrong name. " - f"It should be named '{self._context_param}'." - ) - # Get parameter descriptions from parsed docstring self.param_descriptions = { param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params @@ -142,6 +113,21 @@ def _contains_tool_context(tp: Any) -> bool: # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _validate_signature(self) -> None: + """Verify that ToolContext is used correctly in the function signature.""" + # Find and validate the ToolContext parameter + for param in self.signature.parameters.values(): + if param.annotation is ToolContext: + if self._context_param is None: + raise ValueError("@tool(context=True) must be set if passing in ToolContext param") + + if param.name != self._context_param: + raise ValueError( + f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'" + ) + # Found the parameter, no need to check further + break + def _create_input_model(self) -> Type[BaseModel]: """Create a Pydantic model from function signature for input validation. diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 38db783d7..bbdcb06a6 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1366,41 +1366,35 @@ async def async_generator() -> AsyncGenerator: def test_tool_with_mismatched_tool_context_param_name_raises_error(): - """Verify that a TypeError is raised for a mismatched tool_context parameter name.""" - with pytest.raises(TypeError) as excinfo: + """Verify that a ValueError is raised for a mismatched tool_context parameter name.""" + with pytest.raises(ValueError) as excinfo: @strands.tool(context=True) def my_tool(context: ToolContext): pass - assert ( - "Parameter 'context' is of type 'ToolContext' but has the wrong name. It should be named 'tool_context'." - in str(excinfo.value) - ) + assert "ToolContext param must be named 'tool_context'" in str(excinfo.value) + assert "param_name=" in str(excinfo.value) def test_tool_with_tool_context_but_no_context_flag_raises_error(): - """Verify that a TypeError is raised if ToolContext is used without context=True.""" - with pytest.raises(TypeError) as excinfo: + """Verify that a ValueError is raised if ToolContext is used without context=True.""" + with pytest.raises(ValueError) as excinfo: @strands.tool def my_tool(tool_context: ToolContext): pass - assert "Parameter 'tool_context' is of type 'ToolContext' but '@tool(context=True)' is missing." in str( - excinfo.value - ) + assert "@tool(context=True) must be set" in str(excinfo.value) def test_tool_with_tool_context_named_custom_context_raises_error_if_mismatched(): - """Verify that a TypeError is raised when context param name doesn't match the decorator value.""" - with pytest.raises(TypeError) as excinfo: + """Verify that a ValueError is raised when context param name doesn't match the decorator value.""" + with pytest.raises(ValueError) as excinfo: @strands.tool(context="my_context") def my_tool(tool_context: ToolContext): pass - assert ( - "Parameter 'tool_context' is of type 'ToolContext' but has the wrong name. It should be named 'my_context'." - in str(excinfo.value) - ) + assert "ToolContext param must be named 'my_context'" in str(excinfo.value) + assert "param_name=" in str(excinfo.value) From c85bb0ed96c8dec92b5f5624379492231be016aa Mon Sep 17 00:00:00 2001 From: Ratish1 Date: Tue, 14 Oct 2025 20:01:33 +0400 Subject: [PATCH 3/3] fix(tool/decorator): comments --- src/strands/tools/decorator.py | 3 +-- tests/strands/tools/test_decorator.py | 27 ++++++++------------------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 669c60138..72109dbef 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -115,11 +115,10 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - def _validate_signature(self) -> None: """Verify that ToolContext is used correctly in the function signature.""" - # Find and validate the ToolContext parameter for param in self.signature.parameters.values(): if param.annotation is ToolContext: if self._context_param is None: - raise ValueError("@tool(context=True) must be set if passing in ToolContext param") + raise ValueError("@tool(context) must be set if passing in ToolContext param") if param.name != self._context_param: raise ValueError( diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index bbdcb06a6..658a34052 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1365,36 +1365,25 @@ async def async_generator() -> AsyncGenerator: assert act_results == exp_results -def test_tool_with_mismatched_tool_context_param_name_raises_error(): - """Verify that a ValueError is raised for a mismatched tool_context parameter name.""" - with pytest.raises(ValueError) as excinfo: +def test_function_tool_metadata_validate_signature_default_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'tool_context'"): @strands.tool(context=True) def my_tool(context: ToolContext): pass - assert "ToolContext param must be named 'tool_context'" in str(excinfo.value) - assert "param_name=" in str(excinfo.value) +def test_function_tool_metadata_validate_signature_custom_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'my_context'"): -def test_tool_with_tool_context_but_no_context_flag_raises_error(): - """Verify that a ValueError is raised if ToolContext is used without context=True.""" - with pytest.raises(ValueError) as excinfo: - - @strands.tool + @strands.tool(context="my_context") def my_tool(tool_context: ToolContext): pass - assert "@tool(context=True) must be set" in str(excinfo.value) - -def test_tool_with_tool_context_named_custom_context_raises_error_if_mismatched(): - """Verify that a ValueError is raised when context param name doesn't match the decorator value.""" - with pytest.raises(ValueError) as excinfo: +def test_function_tool_metadata_validate_signature_missing_context_config(): + with pytest.raises(ValueError, match=r"@tool\(context\) must be set if passing in ToolContext param"): - @strands.tool(context="my_context") + @strands.tool def my_tool(tool_context: ToolContext): pass - - assert "ToolContext param must be named 'my_context'" in str(excinfo.value) - assert "param_name=" in str(excinfo.value)