diff --git a/CHANGELOG.md b/CHANGELOG.md index 222389b3..8186f4fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features * `ChatAuto()`'s new `provider_model` takes both provider and model in a single string in the format `"{provider}/{model}"`, e.g. `"openai/gpt-5"`. If not provided, `ChatAuto()` looks for the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable, defaulting to `"openai"` if neither are provided. Unlike previous versions of `ChatAuto()`, the environment variables are now used *only if function arguments are not provided*. In other words, if `provider_model` is given, the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable is ignored. Similarly, `CHATLAS_CHAT_ARGS` are only used if no `kwargs` are provided. This improves interactive use cases, makes it easier to introduce application-specific environment variables, and puts more control in the hands of the developer. (#159) -* The `.register_tool()` method now accepts a `Tool` instance as input. This is primarily useful for binding things like `annotations` to the `Tool` in one place, and registering it in another. (#172) +* The `.register_tool()` method now: + * Accepts a `Tool` instance as input. This is primarily useful for binding things like `annotations` to the `Tool` in one place, and registering it in another. (#172) + * Supports function parameter names that start with an underscore. (#174) * The `ToolAnnotations` type gains an `extra` key field -- providing a place for providing additional information that other consumers of tool annotations (e.g., [shinychat](https://posit-dev.github.io/shinychat/)) may make use of. ### Bug fixes diff --git a/chatlas/_tools.py b/chatlas/_tools.py index dd125eb3..b71ddf09 100644 --- a/chatlas/_tools.py +++ b/chatlas/_tools.py @@ -326,13 +326,21 @@ def func_to_basemodel(func: Callable) -> type[BaseModel]: ) annotation = Any + # create_model() will error if the field name starts with `_` (since Pydantic + # uses this to indicate private fields). We can work around this by using an alias. + alias = None + if name.startswith("_"): + field_name, alias = (name.lstrip("_"), name) + else: + field_name, alias = (name, None) + if param.default != inspect.Parameter.empty: - field = Field(default=param.default) + field = Field(default=param.default, alias=alias) else: - field = Field() + field = Field(alias=alias) # Add the field to our fields dict - fields[name] = (annotation, field) + fields[field_name] = (annotation, field) return create_model(func.__name__, **fields) diff --git a/tests/conftest.py b/tests/conftest.py index 984bb4a8..fa9443ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -162,9 +162,9 @@ def assert_tools_parallel( ): chat = chat_fun(system_prompt="Be very terse, not even punctuation.") - def favorite_color(person: str): + def favorite_color(_person: str): """Returns a person's favourite colour""" - return "sage green" if person == "Joe" else "red" + return "sage green" if _person == "Joe" else "red" chat.register_tool(favorite_color)