Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### New features

* `ChatAuto()`'s new `provider_model` takes both provider and model in a single string in the format `"{provider}/{model}"`, e.g. `"openai/gpt-5"`. If not provided, `ChatAuto()` looks for the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable, defaulting to `"openai"` if neither are provided. Unlike previous versions of `ChatAuto()`, the environment variables are now used *only if function arguments are not provided*. In other words, if `provider_model` is given, the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable is ignored. Similarly, `CHATLAS_CHAT_ARGS` are only used if no `kwargs` are provided. This improves interactive use cases, makes it easier to introduce application-specific environment variables, and puts more control in the hands of the developer. (#159)
* The `.register_tool()` method now accepts a `Tool` instance as input. This is primarily useful for binding things like `annotations` to the `Tool` in one place, and registering it in another. (#172)

### Bug fixes

Expand Down
11 changes: 10 additions & 1 deletion chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):

def register_tool(
self,
func: Callable[..., Any] | Callable[..., Awaitable[Any]],
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
*,
force: bool = False,
name: Optional[str] = None,
Expand Down Expand Up @@ -1629,6 +1629,15 @@ def add(a: int, b: int) -> int:
ValueError
If a tool with the same name already exists and `force` is `False`.
"""
if isinstance(func, Tool):
name = name or func.name
annotations = annotations or func.annotations
if model is not None:
func = Tool.from_func(
func.func, name=name, model=model, annotations=annotations
)
func = func.func

tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
if tool.name in self._tools and not force:
raise ValueError(
Expand Down
131 changes: 131 additions & 0 deletions tests/test_tools_enhanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,137 @@ def another_func(x: int) -> int:
assert chat._tools["test_tool"].func == another_func


class TestRegisterToolInstance:
"""Test register_tool() with Tool instances."""

def test_register_tool_instance_basic(self):
"""Test registering a Tool instance directly."""
chat = ChatOpenAI()

def add(x: int, y: int) -> int:
"""Add two numbers."""
return x + y

# Create a Tool instance
tool = Tool.from_func(add)

# Register the Tool instance
chat.register_tool(tool)

# Verify it was registered correctly
tools = chat.get_tools()
assert len(tools) == 1
registered_tool = tools[0]
assert registered_tool.name == "add"
assert registered_tool.func == add

# Check the schema
func_schema = registered_tool.schema["function"]
assert func_schema["name"] == "add"
assert func_schema.get("description") == "Add two numbers."

def test_register_tool_instance_with_custom_name(self):
"""Test registering a Tool instance with a custom name override."""
chat = ChatOpenAI()

def multiply(x: int, y: int) -> int:
"""Multiply two numbers."""
return x * y

# Create a Tool instance
tool = Tool.from_func(multiply)

# Register with custom name
chat.register_tool(tool, name="custom_multiply")

# Verify it was registered with the custom name
tools = chat.get_tools()
assert len(tools) == 1
registered_tool = tools[0]
assert registered_tool.name == "custom_multiply"
assert registered_tool.func == multiply

def test_register_tool_instance_with_model_override(self):
"""Test registering a Tool instance with a model override."""
from pydantic import BaseModel, Field

chat = ChatOpenAI()

def divide(x: int, y: int) -> float:
"""Divide two numbers."""
return x / y

class DivideParams(BaseModel):
"""Parameters for division with detailed descriptions."""

x: int = Field(description="The dividend")
y: int = Field(description="The divisor (must not be zero)")

# Create a Tool instance
tool = Tool.from_func(divide)

# Register with model override
chat.register_tool(tool, model=DivideParams)

# Verify it was registered with the new model
tools = chat.get_tools()
assert len(tools) == 1
registered_tool = tools[0]
assert registered_tool.name == "divide"
assert registered_tool.func == divide

# Check that Field descriptions are preserved
func_schema = registered_tool.schema["function"]
params: dict = func_schema.get("parameters", {})
props = params["properties"]
assert props["x"]["description"] == "The dividend"
assert props["y"]["description"] == "The divisor (must not be zero)"

def test_register_tool_instance_force_overwrite(self):
"""Test force overwriting an existing tool with a Tool instance."""
chat = ChatOpenAI()

def original_func(x: int) -> int:
"""Original function."""
return x

def new_func(x: int) -> int:
"""New function."""
return x * 2

# Register original function
chat.register_tool(original_func)

# Create Tool instance with same name
new_tool = Tool.from_func(new_func)
new_tool = Tool(
func=new_func,
name="original_func", # Use same name as original
description="New function.",
parameters={
"type": "object",
"properties": {"x": {"type": "integer"}},
"required": ["x"],
"additionalProperties": False,
},
)

# Should fail without force
with pytest.raises(
ValueError, match="Tool with name 'original_func' is already registered"
):
chat.register_tool(new_tool)

# Should succeed with force=True
chat.register_tool(new_tool, force=True)

tools = chat.get_tools()
assert len(tools) == 1
registered_tool = tools[0]
assert registered_tool.name == "original_func"
assert registered_tool.func == new_func


class TestToolYielding:
"""Test tool functions that yield multiple results."""

Expand Down