diff --git a/.vscode/settings.json b/.vscode/settings.json index f4be6f2a..989a29da 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,12 @@ }, "editor.defaultFormatter": "charliermarsh.ruff", }, - "flake8.args": ["--max-line-length=120"] -} \ No newline at end of file + "flake8.args": [ + "--max-line-length=120" + ], + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ca5d9a7..39482c76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). --> +## [UNRELEASED] + +### Breaking changes + +* `ChatAuto()`'s first (optional) positional parameter has changed from `system_prompt` to `provider_model`, and `system_prompt` is now a keyword parameter. As a result, you may need to change `ChatAuto("[system prompt]")` -> `ChatAuto(system_prompt="[system prompt]")`. In addition, the `provider` and `model` keyword arguments are now deprecated, but continue to work with a warning, as are the previous `CHATLAS_CHAT_PROVIDER` and `CHATLAS_CHAT_MODEL` environment variables. (#159) + +### New features + +* `ChatAuto()`'s new `provider_model` takes both provider and model in a single string in the format `"{provider}/{model}"`, e.g. `"openai/gpt-5"`. If not provided, `ChatAuto()` looks for the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable, defaulting to `"openai"` if neither are provided. Unlike previous versions of `ChatAuto()`, the environment variables are now used *only if function arguments are not provided*. In other words, if `provider_model` is given, the `CHATLAS_CHAT_PROVIDER_MODEL` environment variable is ignored. Similarly, `CHATLAS_CHAT_ARGS` are only used if no `kwargs` are provided. This improves interactive use cases, makes it easier to introduce application-specific environment variables, and puts more control in the hands of the developer. (#159) + +### Bug fixes + +* `ChatAuto()` now supports recently added providers such as `ChatCloudflare()`, `ChatDeepseek()`, `ChatHuggingFace()`, etc. (#159) + ## [0.11.1] - 2025-08-29 ### New features @@ -22,7 +36,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `.register_tool(annotations=annotations)` drops support for `mcp.types.ToolAnnotations()` and instead expects a dictionary of the same info. (#164) - ## [0.11.0] - 2025-08-26 ### New features @@ -42,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features -* Added `ChatCloudflare()` for chatting via [Cloudflare AI](https://developers.cloudflare.com/workers-ai/get-started/rest-api/). (#150) +* Added `ChatCloudflare()` for chatting via [Cloudflare AI](https://developers.cloudflare.com/workers-ai/get-started/rest-api/). (#150) * Added `ChatDeepSeek()` for chatting via [DeepSeek](https://www.deepseek.com/). (#147) * Added `ChatOpenRouter()` for chatting via [Open Router](https://openrouter.ai/). (#148) * Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144) @@ -78,7 +91,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features -* `Chat` gains a handful of new methods: +* `Chat` gains a handful of new methods: * `.register_mcp_tools_http_stream_async()` and `.register_mcp_tools_stdio_async()`: for registering tools from a [MCP server](https://modelcontextprotocol.io/). (#39) * `.get_tools()` and `.set_tools()`: for fine-grained control over registered tools. (#39) * `.set_model_params()`: for setting common LLM parameters in a model-agnostic fashion. (#127) @@ -87,7 +100,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Tool functions passed to `.register_tool()` can now `yield` numerous results. (#39) * A `ContentToolResultImage` content class was added for returning images from tools. It is currently only works with `ChatAnthropic`. (#39) * A `Tool` can now be constructed from a pre-existing tool schema (via a new `__init__` method). (#39) -* The `Chat.app()` method gains a `host` parameter. (#122) +* The `Chat.app()` method gains a `host` parameter. (#122) * `ChatGithub()` now supports the more standard `GITHUB_TOKEN` environment variable for storing the API key. (#123) ### Changes @@ -149,7 +162,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.7.1] - 2025-05-10 -* Added `openai` as a hard dependency, making installation easier for a wide range of use cases. (#91) +* Added `openai` as a hard dependency, making installation easier for a wide range of use cases. (#91) ## [0.7.0] - 2025-04-22 @@ -159,7 +172,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `.stream()` and `.stream_async()` gain a `content` argument. Set this to `"all"` to include `ContentToolResult`/`ContentToolRequest` objects in the stream. (#75) * `ContentToolResult`/`ContentToolRequest` are now exported to `chatlas` namespace. (#75) * `ContentToolResult`/`ContentToolRequest` gain a `.tagify()` method so they render sensibly in a Shiny app. (#75) -* A tool can now return a `ContentToolResult`. This is useful for: +* A tool can now return a `ContentToolResult`. This is useful for: * Specifying the format used for sending the tool result to the chat model (`model_format`). (#87) * Custom rendering of the tool result (by overriding relevant methods in a subclass). (#75) * `Chat` gains a new `.current_display` property. When a `.chat()` or `.stream()` is currently active, this property returns an object with a `.echo()` method (to echo new content to the display). This is primarily useful for displaying custom content during a tool call. (#79) diff --git a/chatlas/_auto.py b/chatlas/_auto.py index 524f3735..84a95970 100644 --- a/chatlas/_auto.py +++ b/chatlas/_auto.py @@ -1,32 +1,46 @@ from __future__ import annotations import os +import warnings from typing import Callable, Literal, Optional import orjson from ._chat import Chat from ._provider_anthropic import ChatAnthropic, ChatBedrockAnthropic +from ._provider_cloudflare import ChatCloudflare from ._provider_databricks import ChatDatabricks +from ._provider_deepseek import ChatDeepSeek from ._provider_github import ChatGithub from ._provider_google import ChatGoogle, ChatVertex from ._provider_groq import ChatGroq +from ._provider_huggingface import ChatHuggingFace +from ._provider_mistral import ChatMistral from ._provider_ollama import ChatOllama from ._provider_openai import ChatAzureOpenAI, ChatOpenAI +from ._provider_openrouter import ChatOpenRouter from ._provider_perplexity import ChatPerplexity +from ._provider_portkey import ChatPortkey from ._provider_snowflake import ChatSnowflake +from ._utils import MISSING_TYPE as DEPRECATED_TYPE AutoProviders = Literal[ "anthropic", "bedrock-anthropic", + "cloudflare", "databricks", + "deep-seek", "github", "google", "groq", + "hugging-face", + "mistral", "ollama", "openai", "azure-openai", + "open-router", "perplexity", + "portkey", "snowflake", "vertex", ] @@ -34,41 +48,40 @@ _provider_chat_model_map: dict[AutoProviders, Callable[..., Chat]] = { "anthropic": ChatAnthropic, "bedrock-anthropic": ChatBedrockAnthropic, + "cloudflare": ChatCloudflare, "databricks": ChatDatabricks, + "deep-seek": ChatDeepSeek, "github": ChatGithub, "google": ChatGoogle, "groq": ChatGroq, + "hugging-face": ChatHuggingFace, + "mistral": ChatMistral, "ollama": ChatOllama, "openai": ChatOpenAI, "azure-openai": ChatAzureOpenAI, + "open-router": ChatOpenRouter, "perplexity": ChatPerplexity, + "portkey": ChatPortkey, "snowflake": ChatSnowflake, "vertex": ChatVertex, } +DEPRECATED = DEPRECATED_TYPE() + def ChatAuto( - system_prompt: Optional[str] = None, + provider_model: Optional[str] = None, *, - provider: Optional[AutoProviders] = None, - model: Optional[str] = None, + system_prompt: Optional[str] = None, + provider: AutoProviders | DEPRECATED_TYPE = DEPRECATED, + model: str | DEPRECATED_TYPE = DEPRECATED, **kwargs, ) -> Chat: """ - Use environment variables (env vars) to configure the Chat provider and model. + Chat with any provider. - Creates a :class:`~chatlas.Chat` instance based on the specified provider. - The provider may be specified through the `provider` parameter and/or the - `CHATLAS_CHAT_PROVIDER` env var. If both are set, the env var takes - precedence. Similarly, the provider's model may be specified through the - `model` parameter and/or the `CHATLAS_CHAT_MODEL` env var. Also, additional - configuration may be provided through the `kwargs` parameter and/or the - `CHATLAS_CHAT_ARGS` env var (as a JSON string). In this case, when both are - set, they are merged, with the env var arguments taking precedence. - - As a result, `ChatAuto()` provides a convenient way to set a default - provider and model in your Python code, while allowing you to override - these settings through env vars (i.e., without modifying your code). + This is a generic interface to all the other `Chat*()` functions, allowing + you to pick the provider (and model) with a simple string. Prerequisites ------------- @@ -86,55 +99,101 @@ def ChatAuto( Python packages. ::: - Examples -------- - First, set the environment variables for the provider, arguments, and API key: - ```bash - export CHATLAS_CHAT_PROVIDER=anthropic - export CHATLAS_CHAT_MODEL=claude-3-haiku-20240229 - export CHATLAS_CHAT_ARGS='{"kwargs": {"max_retries": 3}}' - export ANTHROPIC_API_KEY=your_api_key + `ChatAuto()` makes it easy to switch between different chat providers and models. + + ```python + import pandas as pd + from chatlas import ChatAuto + + # Default provider (OpenAI) & model + chat = ChatAuto() + print(chat.provider.name) + print(chat.provider.model) + + # Different provider (Anthropic) & default model + chat = ChatAuto("anthropic") + + # List models available through the provider + models = chat.list_models() + print(pd.DataFrame(models)) + + # Choose specific provider/model (Claude Sonnet 4) + chat = ChatAuto("anthropic/claude-sonnet-4-0") ``` - Then, you can use the `ChatAuto` function to create a Chat instance: + The default provider/model can also be controlled through an environment variable: + + ```bash + export CHATLAS_CHAT_PROVIDER_MODEL="anthropic/claude-sonnet-4-0" + ``` ```python from chatlas import ChatAuto chat = ChatAuto() - chat.chat("What is the capital of France?") + print(chat.provider.name) # anthropic + print(chat.provider.model) # claude-sonnet-4-0 + ``` + + For application-specific configurations, consider defining your own environment variables: + + ```bash + export MYAPP_PROVIDER_MODEL="google/gemini-2.5-flash" + ``` + + And passing them to `ChatAuto()` as an alternative way to configure the provider/model: + + ```python + import os + from chatlas import ChatAuto + + chat = ChatAuto(os.getenv("MYAPP_PROVIDER_MODEL")) + print(chat.provider.name) # google + print(chat.provider.model) # gemini-2.5-flash ``` Parameters ---------- + provider_model + The name of the provider and model to use in the format + `"{provider}/{model}"`. Providers are strings formatted in kebab-case, + e.g. to use `ChatBedrockAnthropic` set `provider="bedrock-anthropic"`, + and models are the provider-specific model names, e.g. + `"claude-3-7-sonnet-20250219"`. The `/{model}` portion may also be + omitted, in which case, the default model for that provider will be + used. + + If no value is provided, the `CHATLAS_CHAT_PROVIDER_MODEL` environment + variable will be consulted for a fallback value. If this variable is also + not set, a default value of `"openai"` is used. system_prompt A system prompt to set the behavior of the assistant. provider - The name of the default chat provider to use. Providers are strings - formatted in kebab-case, e.g. to use `ChatBedrockAnthropic` set - `provider="bedrock-anthropic"`. - - This value can also be provided via the `CHATLAS_CHAT_PROVIDER` - environment variable, which takes precedence over `provider` - when set. + Deprecated; use `provider_model` instead. model - The name of the default model to use. This value can also be provided - via the `CHATLAS_CHAT_MODEL` environment variable, which takes - precedence over `model` when set. + Deprecated; use `provider_model` instead. **kwargs - Additional keyword arguments to pass to the Chat constructor. See the + Additional keyword arguments to pass to the `Chat` constructor. See the documentation for each provider for more details on the available options. These arguments can also be provided via the `CHATLAS_CHAT_ARGS` - environment variable as a JSON string. When provided, the options - in the `CHATLAS_CHAT_ARGS` envvar take precedence over the options - passed to `kwargs`. + environment variable as a JSON string. When any additional arguments are + provided to `ChatAuto()`, the env var is ignored. + + Note that `system_prompt` and `turns` can't be set via environment variables. + They must be provided/set directly to/on `ChatAuto()`. - Note that `system_prompt` and `turns` in `kwargs` or in - `CHATLAS_CHAT_ARGS` are ignored. + Note + ---- + If you want to work with a specific provider, but don't know what models are + available (or the exact model name), use + `ChatAuto('provider_name').list_models()` to list available models. Another + option is to use the provider more directly (e.g., `ChatAnthropic()`). There, + the `model` parameter may have type hints for available models. Returns ------- @@ -147,32 +206,85 @@ def ChatAuto( If no valid provider is specified either through parameters or environment variables. """ - the_provider = os.environ.get("CHATLAS_CHAT_PROVIDER", provider) + if provider is not DEPRECATED: + warn_deprecated_param("provider") + + if model is not DEPRECATED: + if provider is DEPRECATED: + raise ValueError( + "The `model` parameter is deprecated and cannot be used without the `provider` parameter. " + "Use `provider_model` instead." + ) + warn_deprecated_param("model") + + if provider_model is None: + provider_model = os.environ.get("CHATLAS_CHAT_PROVIDER_MODEL") + + # Backwards compatibility: construct from old env vars as a fallback + if provider_model is None: + env_provider = get_legacy_env_var("CHATLAS_CHAT_PROVIDER", provider) + env_model = get_legacy_env_var("CHATLAS_CHAT_MODEL", model) + + if env_provider: + provider_model = env_provider + if env_model: + provider_model += f"/{env_model}" + + # Fall back to OpenAI if nothing is specified + if provider_model is None: + provider_model = "openai" + + if "/" in provider_model: + the_provider, the_model = provider_model.split("/", 1) + else: + the_provider, the_model = provider_model, None - if the_provider is None: - raise ValueError( - "Provider name is required as parameter or `CHATLAS_CHAT_PROVIDER` must be set." - ) if the_provider not in _provider_chat_model_map: raise ValueError( f"Provider name '{the_provider}' is not a known chatlas provider: " f"{', '.join(_provider_chat_model_map.keys())}" ) - # `system_prompt` and `turns` always come from `ChatAuto()` - base_args = {"system_prompt": system_prompt} - - if env_model := os.environ.get("CHATLAS_CHAT_MODEL"): - model = env_model - - if model: - base_args["model"] = model + # `system_prompt`, `turns` and `model` always come from `ChatAuto()` + base_args = { + "system_prompt": system_prompt, + "turns": None, + "model": the_model, + } + # Environment kwargs, used only if no kwargs provided env_kwargs = {} - if env_kwargs_str := os.environ.get("CHATLAS_CHAT_ARGS"): - env_kwargs = orjson.loads(env_kwargs_str) - - kwargs = {**kwargs, **env_kwargs, **base_args} - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - return _provider_chat_model_map[the_provider](**kwargs) + if not kwargs: + env_kwargs = orjson.loads(os.environ.get("CHATLAS_CHAT_ARGS", "{}")) + + final_kwargs = {**env_kwargs, **kwargs, **base_args} + final_kwargs = {k: v for k, v in final_kwargs.items() if v is not None} + + return _provider_chat_model_map[the_provider](**final_kwargs) + + +def get_legacy_env_var( + env_var_name: str, + default: str | DEPRECATED_TYPE, +) -> str | None: + env_value = os.environ.get(env_var_name) + if env_value: + warnings.warn( + f"The '{env_var_name}' environment variable is deprecated. " + "Use 'CHATLAS_CHAT_PROVIDER_MODEL' instead.", + DeprecationWarning, + stacklevel=3, + ) + return env_value + elif isinstance(default, DEPRECATED_TYPE): + return None + else: + return default + + +def warn_deprecated_param(param_name: str, stacklevel: int = 3) -> None: + warnings.warn( + f"The '{param_name}' parameter is deprecated. Use 'provider_model' instead.", + DeprecationWarning, + stacklevel=stacklevel, + ) diff --git a/chatlas/types/anthropic/_submit.py b/chatlas/types/anthropic/_submit.py index 5dc7bcc5..3dbbe7e7 100644 --- a/chatlas/types/anthropic/_submit.py +++ b/chatlas/types/anthropic/_submit.py @@ -3,7 +3,7 @@ # --------------------------------------------------------- -from typing import Iterable, Literal, Mapping, Optional, TypedDict, Union +from typing import Iterable, Literal, Mapping, Optional, Sequence, TypedDict, Union import anthropic import anthropic.types.message_param @@ -48,7 +48,7 @@ class SubmitInputArgs(TypedDict, total=False): str, ] service_tier: Union[Literal["auto", "standard_only"], anthropic.NotGiven] - stop_sequences: Union[list[str], anthropic.NotGiven] + stop_sequences: Union[Sequence[str], anthropic.NotGiven] stream: Union[Literal[False], Literal[True], anthropic.NotGiven] system: Union[ str, diff --git a/chatlas/types/openai/_client.py b/chatlas/types/openai/_client.py index c81e5b99..4cb900a6 100644 --- a/chatlas/types/openai/_client.py +++ b/chatlas/types/openai/_client.py @@ -3,14 +3,14 @@ # --------------------------------------------------------- -from typing import Mapping, Optional, TypedDict, Union +from typing import Awaitable, Callable, Mapping, Optional, TypedDict, Union import httpx import openai class ChatClientArgs(TypedDict, total=False): - api_key: str | None + api_key: Union[str, Callable[[], Awaitable[str]], None] organization: str | None project: str | None webhook_secret: str | None diff --git a/chatlas/types/openai/_client_azure.py b/chatlas/types/openai/_client_azure.py index e2a2696d..bef42a63 100644 --- a/chatlas/types/openai/_client_azure.py +++ b/chatlas/types/openai/_client_azure.py @@ -2,7 +2,7 @@ # Do not modify this file. It was generated by `scripts/generate_typed_dicts.py`. # --------------------------------------------------------- -from typing import Mapping, Optional, TypedDict +from typing import Awaitable, Callable, Mapping, Optional, TypedDict, Union import httpx import openai @@ -12,7 +12,7 @@ class ChatAzureClientArgs(TypedDict, total=False): azure_endpoint: str | None azure_deployment: str | None api_version: str | None - api_key: str | None + api_key: Union[str, Callable[[], Awaitable[str]], None] azure_ad_token: str | None organization: str | None project: str | None diff --git a/chatlas/types/openai/_submit.py b/chatlas/types/openai/_submit.py index 13654f64..8ade8aaf 100644 --- a/chatlas/types/openai/_submit.py +++ b/chatlas/types/openai/_submit.py @@ -3,7 +3,7 @@ # --------------------------------------------------------- -from typing import Iterable, Literal, Mapping, Optional, TypedDict, Union +from typing import Iterable, Literal, Mapping, Optional, Sequence, TypedDict, Union import openai import openai.types.chat.chat_completion_allowed_tool_choice_param @@ -148,7 +148,7 @@ class SubmitInputArgs(TypedDict, total=False): service_tier: Union[ Literal["auto", "default", "flex", "scale", "priority"], None, openai.NotGiven ] - stop: Union[str, None, list[str], openai.NotGiven] + stop: Union[str, None, Sequence[str], openai.NotGiven] store: Union[bool, None, openai.NotGiven] stream: Union[Literal[False], None, Literal[True], openai.NotGiven] stream_options: Union[ diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd index c1e3f086..abd6288b 100644 --- a/docs/get-started/models.qmd +++ b/docs/get-started/models.qmd @@ -64,12 +64,38 @@ If you're using `chatlas` inside your organisation, you'll be limited to what yo - `ChatOllama()`, which uses [Ollama](https://ollama.com), allows you to run models on your own computer. The biggest models you can run locally aren't as good as the state of the art hosted models, but they also don't share your data and and are effectively free. -### Auto complete +### Model type hints -Some providers like `ChatOpenAI()` and `ChatAnthropic()` provide autocompletion for the `model` parameter. This makes it quick and easy to find the right model id -- just enter `model=""` and you'll get a list of available models to choose from (assuming your IDE supports type hints). +Some providers like `ChatOpenAI()` and `ChatAnthropic()` provide type hints for the `model` parameter. This makes it quick and easy to find the right model id -- just enter `model=""` and you'll get a list of available models to choose from (assuming your IDE supports type hints). ![Screenshot of model autocompletion](/images/model-type-hints.png){class="shadow rounded mb-3" width="67%" } +::: callout-tip +If the provider doesn't provide these type hints, try using the `.list_models()` method (mentioned below) to find available models. +::: + + +### Auto provider + +[`ChatAuto()`](../reference/ChatAuto.qmd) provides access to any provider/model combination through one simple string. +This makes for a nice interactive/prototyping experience, where you can quickly switch between different models and providers, and leverage `chatlas`' smart defaults: + +```python +from chatlas import ChatAuto + +# Default provider (OpenAI) & model +chat = ChatAuto() +print(chat.provider.name) +print(chat.provider.model) + +# Different provider (Anthropic) & default model +chat = ChatAuto("anthropic") + +# Choose specific provider/model (Claude Sonnet 4) +chat = ChatAuto("anthropic/claude-sonnet-4-0") +``` + + ### Listing model info Most providers support the `.list_models()` method, which returns detailed information about all available models, including model IDs, pricing, and metadata. This is particularly useful for: @@ -107,8 +133,3 @@ Different providers may include different metadata fields in the model informati - **`id`**: Model identifier to use with the `Chat` constructor - **`input`/`output`/`cached_input`**: Token pricing in USD per million tokens - - -### Auto provider - -[`ChatAuto()`](../reference/ChatAuto.qmd) is a special model provider that allows one to configure the model provider through environment variables. This is useful for having a single, simple, script that can run on any model provider, without having to change the code. \ No newline at end of file diff --git a/scripts/_generate_openai_types.py b/scripts/_generate_openai_types.py index 990dde70..2fe41620 100644 --- a/scripts/_generate_openai_types.py +++ b/scripts/_generate_openai_types.py @@ -1,10 +1,9 @@ from pathlib import Path +from _utils import generate_typeddict_code, write_code_to_file from openai import AsyncAzureOpenAI, AsyncOpenAI from openai.resources.chat import Completions -from _utils import generate_typeddict_code, write_code_to_file - types_dir = Path(__file__).parent.parent / "chatlas" / "types" provider_dir = types_dir / "openai" @@ -28,12 +27,22 @@ excluded_fields={"self"}, ) + +# Temporary workaround for an issue where a type like +# Callable[[], Awaitable[str]] +# is getting incorrectly transpiled as +# Callable[Awaitable[str]] +def fix_callable_types(text: str): + return text.replace("Callable[Awaitable[str]]", "Callable[[], Awaitable[str]]") + + +init_args = fix_callable_types(init_args) + write_code_to_file( init_args, provider_dir / "_client.py", ) - init_args = generate_typeddict_code( AsyncAzureOpenAI.__init__, "ChatAzureClientArgs", @@ -44,6 +53,8 @@ }, ) +init_args = fix_callable_types(init_args) + write_code_to_file( init_args, provider_dir / "_client_azure.py", diff --git a/tests/test_auto.py b/tests/test_auto.py index 1f5ce5c2..460f27b7 100644 --- a/tests/test_auto.py +++ b/tests/test_auto.py @@ -1,3 +1,6 @@ +import os +import warnings + import pytest import chatlas @@ -10,12 +13,25 @@ from .conftest import assert_turns_existing, assert_turns_system +@pytest.fixture(autouse=True) +def mock_api_keys(monkeypatch): + """Set mock API keys for providers to avoid missing key errors.""" + api_keys = { + "OPENAI_API_KEY": "api-key", + "ANTHROPIC_API_KEY": "api-key", + "GOOGLE_API_KEY": "api-key", + } + + for key, value in api_keys.items(): + monkeypatch.setenv(key, value) + + def test_auto_settings_from_env(monkeypatch): - monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "openai") + """Test the new CHATLAS_CHAT_PROVIDER_MODEL environment variable.""" + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "openai/gpt-4o") monkeypatch.setenv( "CHATLAS_CHAT_ARGS", """{ - "model": "gpt-4o", "system_prompt": "Be as terse as possible; no punctuation", "kwargs": {"max_retries": 2} }""", @@ -27,35 +43,134 @@ def test_auto_settings_from_env(monkeypatch): assert isinstance(chat.provider, OpenAIProvider) -def test_auto_settings_from_env_unknown_arg_fails(monkeypatch): +def test_auto_settings_from_old_env_backwards_compatibility(monkeypatch): + """Test backwards compatibility with old environment variables.""" monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "openai") + monkeypatch.setenv("CHATLAS_CHAT_MODEL", "gpt-4o") monkeypatch.setenv( - "CHATLAS_CHAT_ARGS", '{"model": "gpt-4o", "aws_region": "us-east-1"}' + "CHATLAS_CHAT_ARGS", + """{ + "system_prompt": "Be as terse as possible; no punctuation", + "kwargs": {"max_retries": 2} +}""", ) + with pytest.warns(DeprecationWarning, match="CHATLAS_CHAT_PROVIDER"): + with pytest.warns(DeprecationWarning, match="CHATLAS_CHAT_MODEL"): + chat = ChatAuto() + + assert isinstance(chat, Chat) + assert isinstance(chat.provider, OpenAIProvider) + + +def test_auto_provider_model_parameter(): + """Test using provider_model parameter directly.""" + chat = ChatAuto(provider_model="openai/gpt-4o") + assert isinstance(chat, Chat) + assert isinstance(chat.provider, OpenAIProvider) + + +def test_auto_provider_only_parameter(): + """Test using provider_model with just provider (no model).""" + chat = ChatAuto(provider_model="openai") + assert isinstance(chat, Chat) + assert isinstance(chat.provider, OpenAIProvider) + + +def test_auto_settings_from_env_unknown_arg_fails(monkeypatch): + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "openai/gpt-4o") + monkeypatch.setenv("CHATLAS_CHAT_ARGS", '{"aws_region": "us-east-1"}') + with pytest.raises(TypeError): ChatAuto() -def test_auto_override_provider_with_env(monkeypatch): - monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "openai") - chat = ChatAuto(provider="anthropic") +def test_auto_parameter_overrides_env(monkeypatch): + """Test that direct parameters override environment variables.""" + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "anthropic") + chat = ChatAuto(provider_model="openai") assert isinstance(chat.provider, OpenAIProvider) -def test_auto_missing_provider_raises_exception(): - with pytest.raises(ValueError): - ChatAuto() +def test_auto_falls_back_to_openai_default(): + """Test that ChatAuto falls back to OpenAI when no provider is specified.""" + chat = ChatAuto() + assert isinstance(chat, Chat) + assert isinstance(chat.provider, OpenAIProvider) + + +def test_auto_unknown_provider_raises_exception(): + """Test that unknown provider raises ValueError.""" + with pytest.raises( + ValueError, match="Provider name 'unknown' is not a known chatlas provider" + ): + ChatAuto(provider_model="unknown") def test_auto_respects_turns_interface(monkeypatch): - monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "openai") - monkeypatch.setenv("CHATLAS_CHAT_ARGS", '{"model": "gpt-4o"}') + monkeypatch.delenv("OPENAI_API_KEY") + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "openai/gpt-4o") + + if not os.getenv("OPENAI_API_KEY"): + pytest.skip("Skipping test because OPENAI_API_KEY is not set.") + chat_fun = ChatAuto assert_turns_system(chat_fun) assert_turns_existing(chat_fun) +def test_deprecated_provider_parameter_warning(): + """Test that using deprecated provider parameter raises warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ChatAuto(provider="openai") + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "provider" in str(w[0].message) + assert "provider_model" in str(w[0].message) + + +def test_deprecated_model_parameter_warning(): + """Test that using deprecated model parameter raises warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ChatAuto(provider="openai", model="gpt-4o") + + assert len(w) == 2 # Both provider and model warnings + assert all(issubclass(warning.category, DeprecationWarning) for warning in w) + + +def test_deprecated_model_without_provider_error(): + """Test that using model parameter without provider raises ValueError.""" + with pytest.raises( + ValueError, + match="The `model` parameter is deprecated and cannot be used without the `provider` parameter", + ): + ChatAuto(model="gpt-4o") + + +def test_parse_provider_model_with_model(): + """Test _parse_provider_model with provider/model format.""" + provider = ChatAuto("openai/gpt-4o").provider + assert provider.name.lower() == "openai" + assert provider.model == "gpt-4o" + + +def test_parse_provider_model_without_model(): + """Test _parse_provider_model with just provider.""" + provider = ChatAuto("openai").provider + assert provider.name.lower() == "openai" + assert provider.model is not None + + +def test_parse_provider_model_with_multiple_slashes(): + """Test _parse_provider_model handles multiple slashes correctly.""" + provider = ChatAuto("open-router/model/with/slashes").provider + assert provider.name.lower() == "openrouter" + assert provider.model == "model/with/slashes" + + def chat_to_kebab_case(s): if s == "ChatOpenAI": return "openai" @@ -78,14 +193,15 @@ def chat_to_kebab_case(s): def test_auto_includes_all_providers(): - providers = [ - chat_to_kebab_case(x) - for x in dir(chatlas) - if x.startswith("Chat") and x != "Chat" - ] - providers = set(providers) + providers = set( + [ + chat_to_kebab_case(x) + for x in dir(chatlas) + if x.startswith("Chat") and x not in ["Chat", "ChatAuto"] + ] + ) - missing = set(_provider_chat_model_map.keys()).difference(providers) + missing = providers.difference(_provider_chat_model_map.keys()) assert len(missing) == 0, ( f"Missing chat providers from ChatAuto: {', '.join(missing)}" @@ -93,17 +209,58 @@ def test_auto_includes_all_providers(): def test_provider_instances(monkeypatch): - monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "anthropic") + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "anthropic") chat = ChatAuto() assert isinstance(chat, Chat) assert isinstance(chat.provider, AnthropicProvider) - monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "bedrock-anthropic") + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "bedrock-anthropic") chat = ChatAuto() assert isinstance(chat, Chat) assert isinstance(chat.provider, AnthropicBedrockProvider) - monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "google") + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "google") chat = ChatAuto() assert isinstance(chat, Chat) assert isinstance(chat.provider, GoogleProvider) + + +def test_kwargs_priority_over_env_args(monkeypatch): + """Test that direct kwargs override CHATLAS_CHAT_ARGS.""" + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "openai") + monkeypatch.setenv("CHATLAS_CHAT_ARGS", '{"seed": 12}') + + chatlas.ChatOpenAI() + + chat = ChatAuto(seed=42) + assert isinstance(chat.provider, OpenAIProvider) + assert chat.provider._seed == 42 + + +def test_env_args_ignored_when_kwargs_provided(monkeypatch): + """Test that CHATLAS_CHAT_ARGS is ignored when any kwargs are provided.""" + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "openai") + monkeypatch.setenv("CHATLAS_CHAT_ARGS", '{"seed": -1}') + + # Even providing one kwarg should ignore the entire env args + chat = ChatAuto(base_url="https://api.example.com") + assert isinstance(chat, Chat) + assert isinstance(chat.provider, OpenAIProvider) + assert str(chat.provider._client.base_url).startswith("https://api.example.com") + assert chat.provider._seed != -1 + + +def test_system_prompt_parameter_priority(): + """Test that system_prompt parameter is always respected.""" + chat = ChatAuto(provider_model="openai", system_prompt="Test prompt") + assert isinstance(chat, Chat) + # The system_prompt should be set - this would need verification based on Chat implementation + + +def test_new_env_var_priority_over_old(monkeypatch): + """Test that new env var takes priority over old ones.""" + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER_MODEL", "anthropic") + monkeypatch.setenv("CHATLAS_CHAT_PROVIDER", "openai") # Should be ignored + + chat = ChatAuto() + assert isinstance(chat.provider, AnthropicProvider)