From e260c6ebbddb50d548633a53194dc1c76a329a68 Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 8 Sep 2025 11:08:56 -0500 Subject: [PATCH 1/2] feat: tool function parameters may now include a leading underscore --- chatlas/_tools.py | 11 +++++++++-- tests/conftest.py | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/chatlas/_tools.py b/chatlas/_tools.py index dd125eb3..d8c70c30 100644 --- a/chatlas/_tools.py +++ b/chatlas/_tools.py @@ -326,10 +326,17 @@ def func_to_basemodel(func: Callable) -> type[BaseModel]: ) annotation = Any + # create_model() will error if the field name starts with `_` (since Pydantic + # uses this to indicate private fields). We can work around this by using an alias. + alias = None + if name.startswith("_"): + alias = name + name = name.lstrip("_") + if param.default != inspect.Parameter.empty: - field = Field(default=param.default) + field = Field(default=param.default, alias=alias) else: - field = Field() + field = Field(alias=alias) # Add the field to our fields dict fields[name] = (annotation, field) diff --git a/tests/conftest.py b/tests/conftest.py index 984bb4a8..fa9443ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -162,9 +162,9 @@ def assert_tools_parallel( ): chat = chat_fun(system_prompt="Be very terse, not even punctuation.") - def favorite_color(person: str): + def favorite_color(_person: str): """Returns a person's favourite colour""" - return "sage green" if person == "Joe" else "red" + return "sage green" if _person == "Joe" else "red" chat.register_tool(favorite_color) From 0bdc650445f650f213b0847d269594915ceadac3 Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 8 Sep 2025 11:19:09 -0500 Subject: [PATCH 2/2] Fix lint; update changelog --- CHANGELOG.md | 4 +++- chatlas/_tools.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 222389b3..8186f4fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features * `ChatAuto()`'s new `provider_model` takes both provider and model in a single string in the format `"{provider}/{model}"`, e.g. `"openai/gpt-5"`. If not provided, `ChatAuto()` looks for the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable, defaulting to `"openai"` if neither are provided. Unlike previous versions of `ChatAuto()`, the environment variables are now used *only if function arguments are not provided*. In other words, if `provider_model` is given, the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable is ignored. Similarly, `CHATLAS_CHAT_ARGS` are only used if no `kwargs` are provided. This improves interactive use cases, makes it easier to introduce application-specific environment variables, and puts more control in the hands of the developer. (#159) -* The `.register_tool()` method now accepts a `Tool` instance as input. This is primarily useful for binding things like `annotations` to the `Tool` in one place, and registering it in another. (#172) +* The `.register_tool()` method now: + * Accepts a `Tool` instance as input. This is primarily useful for binding things like `annotations` to the `Tool` in one place, and registering it in another. (#172) + * Supports function parameter names that start with an underscore. (#174) * The `ToolAnnotations` type gains an `extra` key field -- providing a place for providing additional information that other consumers of tool annotations (e.g., [shinychat](https://posit-dev.github.io/shinychat/)) may make use of. ### Bug fixes diff --git a/chatlas/_tools.py b/chatlas/_tools.py index d8c70c30..b71ddf09 100644 --- a/chatlas/_tools.py +++ b/chatlas/_tools.py @@ -330,8 +330,9 @@ def func_to_basemodel(func: Callable) -> type[BaseModel]: # uses this to indicate private fields). We can work around this by using an alias. alias = None if name.startswith("_"): - alias = name - name = name.lstrip("_") + field_name, alias = (name.lstrip("_"), name) + else: + field_name, alias = (name, None) if param.default != inspect.Parameter.empty: field = Field(default=param.default, alias=alias) @@ -339,7 +340,7 @@ def func_to_basemodel(func: Callable) -> type[BaseModel]: field = Field(alias=alias) # Add the field to our fields dict - fields[name] = (annotation, field) + fields[field_name] = (annotation, field) return create_model(func.__name__, **fields)