From 80d8abaa0e8900efd7c151640dee21b496214acc Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 11 Aug 2025 17:17:54 -0500 Subject: [PATCH 1/8] Quick and dirty start on ChatMistral() --- chatlas/__init__.py | 2 + chatlas/_provider_mistral.py | 172 +++++++++++++++++++++++++++++++++ docs/_quarto.yml | 1 + docs/get-started/models.qmd | 1 + tests/test_provider_mistral.py | 77 +++++++++++++++ 5 files changed, 253 insertions(+) create mode 100644 chatlas/_provider_mistral.py create mode 100644 tests/test_provider_mistral.py diff --git a/chatlas/__init__.py b/chatlas/__init__.py index 133fd13c..0ccb9076 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -12,6 +12,7 @@ from ._provider_google import ChatGoogle, ChatVertex from ._provider_groq import ChatGroq from ._provider_huggingface import ChatHuggingFace +from ._provider_mistral import ChatMistral from ._provider_ollama import ChatOllama from ._provider_openai import ChatAzureOpenAI, ChatOpenAI from ._provider_perplexity import ChatPerplexity @@ -35,6 +36,7 @@ "ChatGoogle", "ChatGroq", "ChatHuggingFace", + "ChatMistral", "ChatOllama", "ChatOpenAI", "ChatAzureOpenAI", diff --git a/chatlas/_provider_mistral.py b/chatlas/_provider_mistral.py new file mode 100644 index 00000000..4ffecf54 --- /dev/null +++ b/chatlas/_provider_mistral.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Optional + +from ._chat import Chat +from ._logging import log_model_default +from ._provider_openai import OpenAIProvider +from ._utils import MISSING, MISSING_TYPE, is_testing + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + + from .types.openai import ChatClientArgs, SubmitInputArgs + + +def ChatMistral( + *, + system_prompt: Optional[str] = None, + model: Optional[str] = None, + api_key: Optional[str] = None, + base_url: str = "https://api.mistral.ai/v1/", + seed: int | None | MISSING_TYPE = MISSING, + kwargs: Optional["ChatClientArgs"] = None, +) -> Chat["SubmitInputArgs", ChatCompletion]: + """ + Chat with a model hosted on Mistral's La Plateforme. + + Mistral AI provides high-performance language models through their API platform. + + Prerequisites + -------------- + + Get your API key from https://console.mistral.ai/api-keys. + + Known limitations + ----------------- + + * Tool calling may be unstable. + * Images require a model that supports vision. + + Examples + -------- + ```python + import os + from chatlas import ChatMistral + + chat = ChatMistral(api_key=os.getenv("MISTRAL_API_KEY")) + chat.chat("Tell me three jokes about statisticians") + ``` + + Parameters + ---------- + system_prompt + A system prompt to set the behavior of the assistant. + model + The model to use for the chat. The default, None, will pick a reasonable + default, and warn you about it. We strongly recommend explicitly + choosing a model for all but the most casual use. + api_key + The API key to use for authentication. You generally should not supply + this directly, but instead set the `MISTRAL_API_KEY` environment + variable. + base_url + The base URL to the endpoint; the default uses Mistral AI. + seed + Optional integer seed that Mistral uses to try and make output more + reproducible. + kwargs + Additional arguments to pass to the `openai.OpenAI()` client + constructor (Mistral uses OpenAI-compatible API). + + Returns + ------- + Chat + A chat object that retains the state of the conversation. + + Note + ---- + Pasting an API key into a chat constructor (e.g., `ChatMistral(api_key="...")`) + is the simplest way to get started, and is fine for interactive use, but is + problematic for code that may be shared with others. + + Instead, consider using environment variables or a configuration file to manage + your credentials. One popular way to manage credentials is to use a `.env` file + to store your credentials, and then use the `python-dotenv` package to load them + into your environment. + + ```shell + pip install python-dotenv + ``` + + ```shell + # .env + MISTRAL_API_KEY=... + ``` + + ```python + from chatlas import ChatMistral + from dotenv import load_dotenv + + load_dotenv() + chat = ChatMistral() + chat.console() + ``` + + Another, more general, solution is to load your environment variables into the shell + before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file): + + ```shell + export MISTRAL_API_KEY=... + ``` + """ + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + if model is None: + model = log_model_default("mistral-large-latest") + + if api_key is None: + api_key = os.getenv("MISTRAL_API_KEY") + + return Chat( + provider=MistralProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + kwargs=kwargs, + ), + system_prompt=system_prompt, + ) + + +class MistralProvider(OpenAIProvider): + def __init__( + self, + *, + api_key: Optional[str] = None, + model: str, + base_url: str = "https://api.mistral.ai/v1/", + seed: Optional[int] = None, + name: str = "Mistral", + kwargs: Optional["ChatClientArgs"] = None, + ): + super().__init__( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name=name, + kwargs=kwargs, + ) + + def _chat_perform_args( + self, + stream: bool, + turns, + tools, + data_model=None, + kwargs=None, + ) -> "SubmitInputArgs": + # Get the base arguments from OpenAI provider + kwargs_full = super()._chat_perform_args( + stream, turns, tools, data_model, kwargs + ) + + # Mistral doesn't support stream_options + if "stream_options" in kwargs_full: + del kwargs_full["stream_options"] + + return kwargs_full diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 2c7bc351..27446e6d 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -121,6 +121,7 @@ quartodoc: - ChatGoogle - ChatGroq - ChatHuggingFace + - ChatMistral - ChatOllama - ChatOpenAI - ChatPerplexity diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd index 9c011d1f..adb19a54 100644 --- a/docs/get-started/models.qmd +++ b/docs/get-started/models.qmd @@ -27,6 +27,7 @@ To see the pre-requisites for a given provider, visit the relevant usage page in | AWS Bedrock | [`ChatBedrockAnthropic()`](../reference/ChatBedrockAnthropic.qmd) | ✅ | | Azure OpenAI | [`ChatAzureOpenAI()`](../reference/ChatAzureOpenAI.qmd) | ✅ | | Databricks | [`ChatDatabricks()`](../reference/ChatDatabricks.qmd) | ✅ | +| Mistral | [`ChatMistral()`](../reference/ChatMistral.qmd) | ✅ | | Portkey | [`ChatPortkey()`](../reference/ChatPortkey.qmd) | ✅ | | Snowflake Cortex | [`ChatSnowflake()`](../reference/ChatSnowflake.qmd) | ✅ | | Vertex AI | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ | diff --git a/tests/test_provider_mistral.py b/tests/test_provider_mistral.py new file mode 100644 index 00000000..d486031a --- /dev/null +++ b/tests/test_provider_mistral.py @@ -0,0 +1,77 @@ +import os + +import pytest +from chatlas import ChatMistral + +do_test = os.getenv("TEST_MISTRAL", "true") +if do_test.lower() == "false": + pytest.skip("Skipping Mistral tests", allow_module_level=True) + +from .conftest import ( + assert_data_extraction, + assert_images_inline, + assert_images_remote, + assert_tools_async, + assert_tools_simple, + assert_tools_simple_stream_content, + assert_turns_existing, + assert_turns_system, +) + + +def test_mistral_simple_request(): + chat = ChatMistral( + system_prompt="Be as terse as possible; no punctuation", + ) + chat.chat("What is 1 + 1?") + turn = chat.get_last_turn() + assert turn is not None + assert turn.tokens is not None + assert len(turn.tokens) == 3 + assert turn.tokens[0] > 0 # prompt tokens + assert turn.tokens[1] > 0 # completion tokens + assert turn.finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_mistral_simple_streaming_request(): + chat = ChatMistral( + system_prompt="Be as terse as possible; no punctuation", + ) + res = [] + async for x in await chat.stream_async("What is 1 + 1?"): + res.append(x) + assert "2" in "".join(res) + turn = chat.get_last_turn() + assert turn is not None + assert turn.finish_reason == "stop" + + +def test_mistral_respects_turns_interface(): + chat_fun = ChatMistral + assert_turns_system(chat_fun) + assert_turns_existing(chat_fun) + + +def test_mistral_tool_variations(): + """Note: Tool calling may be unstable with Mistral.""" + chat_fun = ChatMistral + assert_tools_simple(chat_fun) + assert_tools_simple_stream_content(chat_fun) + + +@pytest.mark.asyncio +async def test_mistral_tool_variations_async(): + """Note: Tool calling may be unstable with Mistral.""" + await assert_tools_async(ChatMistral) + + +def test_data_extraction(): + assert_data_extraction(ChatMistral) + + +def test_mistral_images(): + """Note: Images require a model that supports vision.""" + chat_fun = lambda **kwargs: ChatMistral(model="pixtral-12b-latest", **kwargs) + assert_images_inline(chat_fun) + assert_images_remote(chat_fun) From dc7f87ad7642d9798e163f78ad1a1d1bc3b7ede4 Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 11 Aug 2025 17:20:51 -0500 Subject: [PATCH 2/8] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6aeb4ba2..72cd5f44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features * Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144) +* Added `ChatMistral()` for chatting via [Mistral AI](https://mistral.ai/). (#145) * Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143) From 8eae5846abc1ffe266f5e1c4b517fe739a3b5a01 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 12 Aug 2025 10:23:38 -0500 Subject: [PATCH 3/8] Better handling of kwarg differences --- chatlas/_provider_databricks.py | 11 +++++++++++ chatlas/_provider_mistral.py | 31 ++++++++++++++++++------------- chatlas/_provider_openai.py | 3 +-- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/chatlas/_provider_databricks.py b/chatlas/_provider_databricks.py index 5f46f2e0..1ea4a5af 100644 --- a/chatlas/_provider_databricks.py +++ b/chatlas/_provider_databricks.py @@ -127,3 +127,14 @@ def __init__( api_key="no-token", # A placeholder to pass validations, this will not be used http_client=httpx.AsyncClient(auth=client._client.auth), ) + + # Databricks doesn't support stream_options + def _chat_perform_args( + self, stream, turns, tools, data_model=None, kwargs=None + ) -> "SubmitInputArgs": + kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs) + + if "stream_options" in kwargs2: + del kwargs2["stream_options"] + + return kwargs2 diff --git a/chatlas/_provider_mistral.py b/chatlas/_provider_mistral.py index 4ffecf54..c7f59aba 100644 --- a/chatlas/_provider_mistral.py +++ b/chatlas/_provider_mistral.py @@ -152,21 +152,26 @@ def __init__( kwargs=kwargs, ) + # Mistral is essentially OpenAI-compatible, with a couple small differences. + # We _could_ bring in the Mistral SDK and use it directly for more precise typing, + # etc., but for now that doesn't seem worth it. def _chat_perform_args( - self, - stream: bool, - turns, - tools, - data_model=None, - kwargs=None, + self, stream, turns, tools, data_model=None, kwargs=None ) -> "SubmitInputArgs": # Get the base arguments from OpenAI provider - kwargs_full = super()._chat_perform_args( - stream, turns, tools, data_model, kwargs - ) + kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs) # Mistral doesn't support stream_options - if "stream_options" in kwargs_full: - del kwargs_full["stream_options"] - - return kwargs_full + if "stream_options" in kwargs2: + del kwargs2["stream_options"] + + # Mistral wants random_seed, not seed + if seed := kwargs2.pop("seed", None): + if isinstance(seed, int): + kwargs2["extra_body"] = {"random_seed": seed} + elif seed is not None: + raise ValueError( + "MistralProvider only accepts an integer seed, or None." + ) + + return kwargs2 diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 215bb28f..7cd624d3 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -310,8 +310,7 @@ def _chat_perform_args( del kwargs_full["tools"] if stream and "stream_options" not in kwargs_full: - if self.__class__.__name__ != "DatabricksProvider": - kwargs_full["stream_options"] = {"include_usage": True} + kwargs_full["stream_options"] = {"include_usage": True} return kwargs_full From 19ed22ed15a9c9215511cdf81052a5d5e61ab48c Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 12 Aug 2025 10:38:26 -0500 Subject: [PATCH 4/8] Tool calling is known to be poorly supported --- tests/test_provider_mistral.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/test_provider_mistral.py b/tests/test_provider_mistral.py index d486031a..e4deb3da 100644 --- a/tests/test_provider_mistral.py +++ b/tests/test_provider_mistral.py @@ -11,9 +11,6 @@ assert_data_extraction, assert_images_inline, assert_images_remote, - assert_tools_async, - assert_tools_simple, - assert_tools_simple_stream_content, assert_turns_existing, assert_turns_system, ) @@ -53,17 +50,16 @@ def test_mistral_respects_turns_interface(): assert_turns_existing(chat_fun) -def test_mistral_tool_variations(): - """Note: Tool calling may be unstable with Mistral.""" - chat_fun = ChatMistral - assert_tools_simple(chat_fun) - assert_tools_simple_stream_content(chat_fun) - +# Tool calling is poorly supported +# def test_mistral_tool_variations(): +# chat_fun = ChatMistral +# assert_tools_simple(chat_fun) +# assert_tools_simple_stream_content(chat_fun) -@pytest.mark.asyncio -async def test_mistral_tool_variations_async(): - """Note: Tool calling may be unstable with Mistral.""" - await assert_tools_async(ChatMistral) +# Tool calling is poorly supported +# @pytest.mark.asyncio +# async def test_mistral_tool_variations_async(): +# await assert_tools_async(ChatMistral) def test_data_extraction(): @@ -71,7 +67,8 @@ def test_data_extraction(): def test_mistral_images(): - """Note: Images require a model that supports vision.""" - chat_fun = lambda **kwargs: ChatMistral(model="pixtral-12b-latest", **kwargs) + def chat_fun(**kwargs): + return ChatMistral(model="pixtral-12b-latest", **kwargs) + assert_images_inline(chat_fun) assert_images_remote(chat_fun) From 116360844cc51b4269f9e942de77785117b00ee6 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 14 Aug 2025 14:46:08 -0500 Subject: [PATCH 5/8] Cleanup docs/tests --- chatlas/_provider_mistral.py | 18 +++++++++++------- tests/test_provider_mistral.py | 9 +++++---- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/chatlas/_provider_mistral.py b/chatlas/_provider_mistral.py index c7f59aba..e7341e3c 100644 --- a/chatlas/_provider_mistral.py +++ b/chatlas/_provider_mistral.py @@ -29,15 +29,13 @@ def ChatMistral( Mistral AI provides high-performance language models through their API platform. Prerequisites - -------------- + ------------- - Get your API key from https://console.mistral.ai/api-keys. - - Known limitations - ----------------- + ::: {.callout-note} + ## API credentials - * Tool calling may be unstable. - * Images require a model that supports vision. + Get your API key from https://console.mistral.ai/api-keys. + ::: Examples -------- @@ -49,6 +47,12 @@ def ChatMistral( chat.chat("Tell me three jokes about statisticians") ``` + Known limitations + ----------------- + + * Tool calling may be unstable. + * Images require a model that supports vision. + Parameters ---------- system_prompt diff --git a/tests/test_provider_mistral.py b/tests/test_provider_mistral.py index e4deb3da..0939ac3b 100644 --- a/tests/test_provider_mistral.py +++ b/tests/test_provider_mistral.py @@ -1,11 +1,8 @@ import os import pytest -from chatlas import ChatMistral -do_test = os.getenv("TEST_MISTRAL", "true") -if do_test.lower() == "false": - pytest.skip("Skipping Mistral tests", allow_module_level=True) +from chatlas import ChatMistral from .conftest import ( assert_data_extraction, @@ -15,6 +12,10 @@ assert_turns_system, ) +api_key = os.getenv("MISTRAL_API_KEY") +if api_key is None: + pytest.skip("MISTRAL_API_KEY is not set; skipping tests", allow_module_level=True) + def test_mistral_simple_request(): chat = ChatMistral( From 17d0f56336729efa9d7e2ac9619c82ed31a3be7f Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 14 Aug 2025 15:50:01 -0500 Subject: [PATCH 6/8] Add callout about known limitations --- docs/get-started/models.qmd | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd index d3bd8b71..deaa662c 100644 --- a/docs/get-started/models.qmd +++ b/docs/get-started/models.qmd @@ -20,7 +20,8 @@ To see the pre-requisites for a given provider, visit the relevant usage page in | AWS Bedrock | [`ChatBedrockAnthropic()`](../reference/ChatBedrockAnthropic.qmd) | ✅ | | OpenAI | [`ChatOpenAI()`](../reference/ChatOpenAI.qmd) | | | Azure OpenAI | [`ChatAzureOpenAI()`](../reference/ChatAzureOpenAI.qmd) | ✅ | -| Vertex AI | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ | +| Google (Gemini) | [`ChatGoogle()`](../reference/ChatGoogle.qmd) | | +| Google (Vertex) | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ | | GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd) | | | Ollama (local models) | [`ChatOllama()`](../reference/ChatOllama.qmd) | | | Open Router | [`ChatOpenRouter()`](../reference/ChatOpenRouter.qmd) | | @@ -35,15 +36,19 @@ To see the pre-requisites for a given provider, visit the relevant usage page in | Portkey | [`ChatPortkey()`](../reference/ChatPortkey.qmd) | ✅ | -::: callout-note - +::: callout-tip ### Other providers -If you want to use a model provider that isn't listed in the table above, you have two options: +To use chatlas with a provider not listed in the table above, you have two options: 1. If the model provider is OpenAI compatible (i.e., it can be used with the [`openai` Python SDK](https://github.com/openai/openai-python?tab=readme-ov-file#configuring-the-http-client)), use `ChatOpenAI()` with the appropriate `base_url` and `api_key`. 2. If you're motivated, implement a new provider by subclassing [`Provider`](https://github.com/posit-dev/chatlas/blob/main/chatlas/_provider.py) and implementing the required methods. +::: + +::: callout-warning +### Known limitations +Some providers may have a limited amount of support for things like tool calling, structured data extraction, images, etc. In this case, the provider's reference page should include a known limitations section describing the limitations. ::: ### Model choice From 9a65723cea2499083124fb6a5912c9dda639c3d0 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 14 Aug 2025 15:51:10 -0500 Subject: [PATCH 7/8] fix: avoid error when structured data is in conversation history --- chatlas/_provider_openai.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py index 708e7bd7..8e8e0ff9 100644 --- a/chatlas/_provider_openai.py +++ b/chatlas/_provider_openai.py @@ -410,7 +410,9 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]: if isinstance(x, ContentText): content_parts.append({"type": "text", "text": x.text}) elif isinstance(x, ContentJson): - content_parts.append({"type": "text", "text": ""}) + content_parts.append( + {"type": "text", "text": ""} + ) elif isinstance(x, ContentToolRequest): tool_calls.append( { @@ -449,7 +451,7 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]: if isinstance(x, ContentText): contents.append({"type": "text", "text": x.text}) elif isinstance(x, ContentJson): - contents.append({"type": "text", "text": ""}) + contents.append({"type": "text", "text": ""}) elif isinstance(x, ContentPDF): contents.append( { From a66bc196fffb513f0bb6cae6bf9f5a30ba9a70d9 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 14 Aug 2025 16:05:01 -0500 Subject: [PATCH 8/8] default model works better for the image tests --- tests/test_provider_mistral.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_provider_mistral.py b/tests/test_provider_mistral.py index 0939ac3b..e1396506 100644 --- a/tests/test_provider_mistral.py +++ b/tests/test_provider_mistral.py @@ -68,8 +68,6 @@ def test_data_extraction(): def test_mistral_images(): - def chat_fun(**kwargs): - return ChatMistral(model="pixtral-12b-latest", **kwargs) - assert_images_inline(chat_fun) - assert_images_remote(chat_fun) + assert_images_inline(ChatMistral) + assert_images_remote(ChatMistral)