diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5787db62..a5b01ef4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `ChatDeepSeek()` for chatting via [DeepSeek](https://www.deepseek.com/). (#147)
* Added `ChatOpenRouter()` for chatting via [Open Router](https://openrouter.ai/). (#148)
* Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
+* Added `ChatMistral()` for chatting via [Mistral AI](https://mistral.ai/). (#145)
* Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)
### Bug fixes
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index a04d3bb8..a4f83a2f 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -14,6 +14,7 @@
from ._provider_google import ChatGoogle, ChatVertex
from ._provider_groq import ChatGroq
from ._provider_huggingface import ChatHuggingFace
+from ._provider_mistral import ChatMistral
from ._provider_ollama import ChatOllama
from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
from ._provider_openrouter import ChatOpenRouter
@@ -40,6 +41,7 @@
"ChatGoogle",
"ChatGroq",
"ChatHuggingFace",
+ "ChatMistral",
"ChatOllama",
"ChatOpenAI",
"ChatOpenRouter",
diff --git a/chatlas/_provider_databricks.py b/chatlas/_provider_databricks.py
index 5f46f2e0..1ea4a5af 100644
--- a/chatlas/_provider_databricks.py
+++ b/chatlas/_provider_databricks.py
@@ -127,3 +127,15 @@ def __init__(
api_key="no-token", # A placeholder to pass validations, this will not be used
http_client=httpx.AsyncClient(auth=client._client.auth),
)
+
+    # Databricks' OpenAI-compatible endpoint doesn't support stream_options,
+    # so strip it from the arguments before submitting the request.
+ def _chat_perform_args(
+ self, stream, turns, tools, data_model=None, kwargs=None
+ ) -> "SubmitInputArgs":
+ kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs)
+
+ if "stream_options" in kwargs2:
+ del kwargs2["stream_options"]
+
+ return kwargs2
diff --git a/chatlas/_provider_mistral.py b/chatlas/_provider_mistral.py
new file mode 100644
index 00000000..e7341e3c
--- /dev/null
+++ b/chatlas/_provider_mistral.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Optional
+
+from ._chat import Chat
+from ._logging import log_model_default
+from ._provider_openai import OpenAIProvider
+from ._utils import MISSING, MISSING_TYPE, is_testing
+
+if TYPE_CHECKING:
+ from openai.types.chat import ChatCompletion
+
+ from .types.openai import ChatClientArgs, SubmitInputArgs
+
+
+def ChatMistral(
+ *,
+ system_prompt: Optional[str] = None,
+ model: Optional[str] = None,
+ api_key: Optional[str] = None,
+ base_url: str = "https://api.mistral.ai/v1/",
+ seed: int | None | MISSING_TYPE = MISSING,
+ kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat["SubmitInputArgs", ChatCompletion]:
+ """
+ Chat with a model hosted on Mistral's La Plateforme.
+
+    Mistral AI provides both open-weight and commercial language models through its API platform.
+
+ Prerequisites
+ -------------
+
+ ::: {.callout-note}
+ ## API credentials
+
+ Get your API key from https://console.mistral.ai/api-keys.
+ :::
+
+ Examples
+ --------
+ ```python
+ import os
+ from chatlas import ChatMistral
+
+ chat = ChatMistral(api_key=os.getenv("MISTRAL_API_KEY"))
+ chat.chat("Tell me three jokes about statisticians")
+ ```
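+
+    Responses can also be streamed chunk by chunk. A minimal sketch, assuming
+    `MISTRAL_API_KEY` is set in the environment:
+
+    ```python
+    chat = ChatMistral()
+    for chunk in chat.stream("Explain base rates in two sentences"):
+        print(chunk, end="")
+    ```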
+
+ Known limitations
+ -----------------
+
+ * Tool calling may be unstable.
+    * Images require a model that supports vision; see the sketch below.
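+
+    To send an image, choose a vision-capable model explicitly
+    (`pixtral-large-latest` below is one such model at the time of writing):
+
+    ```python
+    from chatlas import ChatMistral, content_image_url
+
+    chat = ChatMistral(model="pixtral-large-latest")
+    chat.chat(
+        content_image_url("https://www.python.org/static/img/python-logo.png"),
+        "Describe this image",
+    )
+    ```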
+
+ Parameters
+ ----------
+ system_prompt
+ A system prompt to set the behavior of the assistant.
+ model
+ The model to use for the chat. The default, None, will pick a reasonable
+ default, and warn you about it. We strongly recommend explicitly
+ choosing a model for all but the most casual use.
+ api_key
+ The API key to use for authentication. You generally should not supply
+ this directly, but instead set the `MISTRAL_API_KEY` environment
+ variable.
+ base_url
+ The base URL to the endpoint; the default uses Mistral AI.
+ seed
+        Optional integer seed that Mistral uses to make output more
+        reproducible.
+ kwargs
+ Additional arguments to pass to the `openai.OpenAI()` client
+        constructor (Mistral exposes an OpenAI-compatible API).
+
+ Returns
+ -------
+ Chat
+ A chat object that retains the state of the conversation.
+
+ Note
+ ----
+ Pasting an API key into a chat constructor (e.g., `ChatMistral(api_key="...")`)
+ is the simplest way to get started, and is fine for interactive use, but is
+ problematic for code that may be shared with others.
+
+ Instead, consider using environment variables or a configuration file to manage
+ your credentials. One popular way to manage credentials is to use a `.env` file
+ to store your credentials, and then use the `python-dotenv` package to load them
+ into your environment.
+
+ ```shell
+ pip install python-dotenv
+ ```
+
+ ```shell
+ # .env
+ MISTRAL_API_KEY=...
+ ```
+
+ ```python
+ from chatlas import ChatMistral
+ from dotenv import load_dotenv
+
+ load_dotenv()
+ chat = ChatMistral()
+ chat.console()
+ ```
+
+ Another, more general, solution is to load your environment variables into the shell
+ before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file):
+
+ ```shell
+ export MISTRAL_API_KEY=...
+ ```
+ """
+ if isinstance(seed, MISSING_TYPE):
+ seed = 1014 if is_testing() else None
+
+ if model is None:
+ model = log_model_default("mistral-large-latest")
+
+ if api_key is None:
+ api_key = os.getenv("MISTRAL_API_KEY")
+
+ return Chat(
+ provider=MistralProvider(
+ api_key=api_key,
+ model=model,
+ base_url=base_url,
+ seed=seed,
+ kwargs=kwargs,
+ ),
+ system_prompt=system_prompt,
+ )
+
+
+class MistralProvider(OpenAIProvider):
+ def __init__(
+ self,
+ *,
+ api_key: Optional[str] = None,
+ model: str,
+ base_url: str = "https://api.mistral.ai/v1/",
+ seed: Optional[int] = None,
+ name: str = "Mistral",
+ kwargs: Optional["ChatClientArgs"] = None,
+ ):
+ super().__init__(
+ api_key=api_key,
+ model=model,
+ base_url=base_url,
+ seed=seed,
+ name=name,
+ kwargs=kwargs,
+ )
+
+    # Mistral is essentially OpenAI-compatible, with a couple of small differences.
+ # We _could_ bring in the Mistral SDK and use it directly for more precise typing,
+ # etc., but for now that doesn't seem worth it.
+ def _chat_perform_args(
+ self, stream, turns, tools, data_model=None, kwargs=None
+ ) -> "SubmitInputArgs":
+ # Get the base arguments from OpenAI provider
+ kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs)
+
+ # Mistral doesn't support stream_options
+ if "stream_options" in kwargs2:
+ del kwargs2["stream_options"]
+
+ # Mistral wants random_seed, not seed
+        seed = kwargs2.pop("seed", None)
+        if isinstance(seed, int):
+            # The openai client merges extra_body into the request JSON, so
+            # Mistral receives random_seed as a top-level field.
+            kwargs2["extra_body"] = {"random_seed": seed}
+        elif seed is not None:
+            raise ValueError(
+                "MistralProvider only accepts an integer seed, or None."
+            )
+
+ return kwargs2
diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index 284cb84a..8e8e0ff9 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -310,8 +310,7 @@ def _chat_perform_args(
del kwargs_full["tools"]
if stream and "stream_options" not in kwargs_full:
- if self.__class__.__name__ != "DatabricksProvider":
- kwargs_full["stream_options"] = {"include_usage": True}
+ kwargs_full["stream_options"] = {"include_usage": True}
return kwargs_full
@@ -411,7 +410,9 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
if isinstance(x, ContentText):
content_parts.append({"type": "text", "text": x.text})
elif isinstance(x, ContentJson):
- content_parts.append({"type": "text", "text": ""})
+ content_parts.append(
+ {"type": "text", "text": ""}
+ )
elif isinstance(x, ContentToolRequest):
tool_calls.append(
{
@@ -450,7 +451,7 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
if isinstance(x, ContentText):
contents.append({"type": "text", "text": x.text})
elif isinstance(x, ContentJson):
- contents.append({"type": "text", "text": ""})
+ contents.append({"type": "text", "text": ""})
elif isinstance(x, ContentPDF):
contents.append(
{
diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index 867be28b..7979f3e9 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -124,6 +124,7 @@ quartodoc:
- ChatGoogle
- ChatGroq
- ChatHuggingFace
+ - ChatMistral
- ChatOllama
- ChatOpenAI
- ChatOpenRouter
diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd
index 8d0bc619..deaa662c 100644
--- a/docs/get-started/models.qmd
+++ b/docs/get-started/models.qmd
@@ -21,7 +21,7 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
| OpenAI | [`ChatOpenAI()`](../reference/ChatOpenAI.qmd) | |
| Azure OpenAI | [`ChatAzureOpenAI()`](../reference/ChatAzureOpenAI.qmd) | ✅ |
| Google (Gemini) | [`ChatGoogle()`](../reference/ChatGoogle.qmd) | |
-| Vertex AI | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ |
+| Google (Vertex) | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ |
| GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd) | |
| Ollama (local models) | [`ChatOllama()`](../reference/ChatOllama.qmd) | |
| Open Router | [`ChatOpenRouter()`](../reference/ChatOpenRouter.qmd) | |
@@ -29,21 +29,40 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
| Hugging Face | [`ChatHuggingFace()`](../reference/ChatHuggingFace.qmd) | |
| Databricks | [`ChatDatabricks()`](../reference/ChatDatabricks.qmd) | ✅ |
| Snowflake Cortex | [`ChatSnowflake()`](../reference/ChatSnowflake.qmd) | ✅ |
+| Mistral | [`ChatMistral()`](../reference/ChatMistral.qmd) | ✅ |
| Groq | [`ChatGroq()`](../reference/ChatGroq.qmd) | |
| perplexity.ai | [`ChatPerplexity()`](../reference/ChatPerplexity.qmd) | |
| Cloudflare | [`ChatCloudflare()`](../reference/ChatCloudflare.qmd) | |
| Portkey | [`ChatPortkey()`](../reference/ChatPortkey.qmd) | ✅ |
-::: callout-note
-
+::: callout-tip
### Other providers
-If you want to use a model provider that isn't listed in the table above, you have two options:
+To use chatlas with a provider not listed in the table above, you have two options:
1. If the model provider is OpenAI compatible (i.e., it can be used with the [`openai` Python SDK](https://github.com/openai/openai-python?tab=readme-ov-file#configuring-the-http-client)), use `ChatOpenAI()` with the appropriate `base_url` and `api_key`.
2. If you're motivated, implement a new provider by subclassing [`Provider`](https://github.com/posit-dev/chatlas/blob/main/chatlas/_provider.py) and implementing the required methods.
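+
+For example, a provider with an OpenAI-compatible endpoint can typically be reached with a sketch like the following (the `base_url`, environment variable, and model name are placeholders for your provider's values, not a real service):
+
+```python
+import os
+
+from chatlas import ChatOpenAI
+
+chat = ChatOpenAI(
+    base_url="https://api.example-provider.com/v1",
+    api_key=os.getenv("EXAMPLE_PROVIDER_API_KEY"),
+    model="provider-model-name",
+)
+```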
+:::
+
+::: callout-warning
+### Known limitations
+Some providers have limited support for features such as tool calling, structured data extraction, and images. In those cases, the provider's reference page includes a "Known limitations" section describing what to expect.
:::
### Model choice
diff --git a/tests/test_provider_mistral.py b/tests/test_provider_mistral.py
new file mode 100644
index 00000000..e1396506
--- /dev/null
+++ b/tests/test_provider_mistral.py
@@ -0,0 +1,83 @@
+import os
+
+import pytest
+
+from chatlas import ChatMistral
+
+from .conftest import (
+ assert_data_extraction,
+ assert_images_inline,
+ assert_images_remote,
+ assert_turns_existing,
+ assert_turns_system,
+)
+
+api_key = os.getenv("MISTRAL_API_KEY")
+if api_key is None:
+ pytest.skip("MISTRAL_API_KEY is not set; skipping tests", allow_module_level=True)
+
+
+def test_mistral_simple_request():
+ chat = ChatMistral(
+ system_prompt="Be as terse as possible; no punctuation",
+ )
+ chat.chat("What is 1 + 1?")
+ turn = chat.get_last_turn()
+ assert turn is not None
+ assert turn.tokens is not None
+ assert len(turn.tokens) == 3
+ assert turn.tokens[0] > 0 # prompt tokens
+ assert turn.tokens[1] > 0 # completion tokens
+ assert turn.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_mistral_simple_streaming_request():
+ chat = ChatMistral(
+ system_prompt="Be as terse as possible; no punctuation",
+ )
+ res = []
+ async for x in await chat.stream_async("What is 1 + 1?"):
+ res.append(x)
+ assert "2" in "".join(res)
+ turn = chat.get_last_turn()
+ assert turn is not None
+ assert turn.finish_reason == "stop"
+
+
+def test_mistral_respects_turns_interface():
+ chat_fun = ChatMistral
+ assert_turns_system(chat_fun)
+ assert_turns_existing(chat_fun)
+
+
+# Tool calling is poorly supported
+# def test_mistral_tool_variations():
+# chat_fun = ChatMistral
+# assert_tools_simple(chat_fun)
+# assert_tools_simple_stream_content(chat_fun)
+
+# Tool calling is poorly supported
+# @pytest.mark.asyncio
+# async def test_mistral_tool_variations_async():
+# await assert_tools_async(ChatMistral)
+
+
+def test_mistral_data_extraction():
+ assert_data_extraction(ChatMistral)
+
+
+def test_mistral_images():
+    assert_images_inline(ChatMistral)
+    assert_images_remote(ChatMistral)
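+
+
+# A sketch of how the seed plumbing could be exercised (hypothetical and left
+# disabled: random_seed makes output more reproducible, not guaranteed
+# identical, so this could be flaky against the live API):
+# def test_mistral_seed_reproducibility():
+#     chat1 = ChatMistral(seed=42)
+#     chat1.chat("Pick a number between 1 and 10")
+#     chat2 = ChatMistral(seed=42)
+#     chat2.chat("Pick a number between 1 and 10")
+#     assert chat1.get_last_turn().text == chat2.get_last_turn().text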