diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5787db62..a5b01ef4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Added `ChatDeepSeek()` for chatting via [DeepSeek](https://www.deepseek.com/). (#147)
 * Added `ChatOpenRouter()` for chatting via [Open Router](https://openrouter.ai/). (#148)
 * Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
+* Added `ChatMistral()` for chatting via [Mistral AI](https://mistral.ai/). (#145)
 * Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)
 
 ### Bug fixes
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index a04d3bb8..a4f83a2f 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -14,6 +14,7 @@
 from ._provider_google import ChatGoogle, ChatVertex
 from ._provider_groq import ChatGroq
 from ._provider_huggingface import ChatHuggingFace
+from ._provider_mistral import ChatMistral
 from ._provider_ollama import ChatOllama
 from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
 from ._provider_openrouter import ChatOpenRouter
@@ -40,6 +41,7 @@
     "ChatGoogle",
     "ChatGroq",
     "ChatHuggingFace",
+    "ChatMistral",
     "ChatOllama",
     "ChatOpenAI",
     "ChatOpenRouter",
diff --git a/chatlas/_provider_databricks.py b/chatlas/_provider_databricks.py
index 5f46f2e0..1ea4a5af 100644
--- a/chatlas/_provider_databricks.py
+++ b/chatlas/_provider_databricks.py
@@ -127,3 +127,14 @@ def __init__(
             api_key="no-token",  # A placeholder to pass validations, this will not be used
             http_client=httpx.AsyncClient(auth=client._client.auth),
         )
+
+    # Databricks doesn't support stream_options
+    def _chat_perform_args(
+        self, stream, turns, tools, data_model=None, kwargs=None
+    ) -> "SubmitInputArgs":
+        kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs)
+
+        if "stream_options" in kwargs2:
+            del kwargs2["stream_options"]
+
+        return kwargs2
diff --git a/chatlas/_provider_mistral.py b/chatlas/_provider_mistral.py
new file mode 100644
index 00000000..e7341e3c
--- /dev/null
+++ b/chatlas/_provider_mistral.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Optional
+
+from ._chat import Chat
+from ._logging import log_model_default
+from ._provider_openai import OpenAIProvider
+from ._utils import MISSING, MISSING_TYPE, is_testing
+
+if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletion
+
+    from .types.openai import ChatClientArgs, SubmitInputArgs
+
+
+def ChatMistral(
+    *,
+    system_prompt: Optional[str] = None,
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    base_url: str = "https://api.mistral.ai/v1/",
+    seed: int | None | MISSING_TYPE = MISSING,
+    kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat["SubmitInputArgs", ChatCompletion]:
+    """
+    Chat with a model hosted on Mistral's La Plateforme.
+
+    Mistral AI provides high-performance language models through its API platform.
+
+    Prerequisites
+    -------------
+
+    ::: {.callout-note}
+    ## API credentials
+
+    Get your API key from https://console.mistral.ai/api-keys.
+    :::
+
+    Examples
+    --------
+    ```python
+    import os
+    from chatlas import ChatMistral
+
+    chat = ChatMistral(api_key=os.getenv("MISTRAL_API_KEY"))
+    chat.chat("Tell me three jokes about statisticians")
+    ```
+
+    Known limitations
+    -----------------
+
+    * Tool calling may be unstable.
+    * Images require a model that supports vision.
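+
+    For example, image input should only be sent to a vision-capable model.
+    (The model name below is an assumption for illustration; check Mistral's
+    docs for the current list of vision models.)
+
+    ```python
+    from chatlas import ChatMistral, content_image_url
+
+    chat = ChatMistral(model="pixtral-large-latest")
+    chat.chat(
+        "Describe this image",
+        content_image_url("https://www.python.org/static/img/python-logo.png"),
+    )
+    ```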
+
+    Parameters
+    ----------
+    system_prompt
+        A system prompt to set the behavior of the assistant.
+    model
+        The model to use for the chat. The default, None, will pick a reasonable
+        default, and warn you about it. We strongly recommend explicitly
+        choosing a model for all but the most casual use.
+    api_key
+        The API key to use for authentication. You generally should not supply
+        this directly, but instead set the `MISTRAL_API_KEY` environment
+        variable.
+    base_url
+        The base URL to the endpoint; the default uses Mistral AI.
+    seed
+        Optional integer seed that Mistral uses to try to make output more
+        reproducible.
+    kwargs
+        Additional arguments to pass to the `openai.OpenAI()` client
+        constructor (Mistral exposes an OpenAI-compatible API).
+
+    Returns
+    -------
+    Chat
+        A chat object that retains the state of the conversation.
+
+    Note
+    ----
+    Pasting an API key into a chat constructor (e.g., `ChatMistral(api_key="...")`)
+    is the simplest way to get started, and is fine for interactive use, but is
+    problematic for code that may be shared with others.
+
+    Instead, consider using environment variables or a configuration file to manage
+    your credentials. One popular way to manage credentials is to use a `.env` file
+    to store your credentials, and then use the `python-dotenv` package to load them
+    into your environment.
+
+    ```shell
+    pip install python-dotenv
+    ```
+
+    ```shell
+    # .env
+    MISTRAL_API_KEY=...
+    ```
+
+    ```python
+    from chatlas import ChatMistral
+    from dotenv import load_dotenv
+
+    load_dotenv()
+    chat = ChatMistral()
+    chat.console()
+    ```
+
+    Another, more general, solution is to load your environment variables into the shell
+    before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file):
+
+    ```shell
+    export MISTRAL_API_KEY=...
+    ```
+    """
+    if isinstance(seed, MISSING_TYPE):
+        seed = 1014 if is_testing() else None
+
+    if model is None:
+        model = log_model_default("mistral-large-latest")
+
+    if api_key is None:
+        api_key = os.getenv("MISTRAL_API_KEY")
+
+    return Chat(
+        provider=MistralProvider(
+            api_key=api_key,
+            model=model,
+            base_url=base_url,
+            seed=seed,
+            kwargs=kwargs,
+        ),
+        system_prompt=system_prompt,
+    )
+
+
+class MistralProvider(OpenAIProvider):
+    def __init__(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        model: str,
+        base_url: str = "https://api.mistral.ai/v1/",
+        seed: Optional[int] = None,
+        name: str = "Mistral",
+        kwargs: Optional["ChatClientArgs"] = None,
+    ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            base_url=base_url,
+            seed=seed,
+            name=name,
+            kwargs=kwargs,
+        )
+
+    # Mistral is essentially OpenAI-compatible, with a couple of small differences.
+    # We _could_ bring in the Mistral SDK and use it directly for more precise typing,
+    # etc., but for now that doesn't seem worth it.
+    def _chat_perform_args(
+        self, stream, turns, tools, data_model=None, kwargs=None
+    ) -> "SubmitInputArgs":
+        # Get the base arguments from the OpenAI provider
+        kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs)
+
+        # Mistral doesn't support stream_options
+        if "stream_options" in kwargs2:
+            del kwargs2["stream_options"]
+
+        # Mistral wants random_seed, not seed
+        seed = kwargs2.pop("seed", None)
+        if seed is not None:
+            if not isinstance(seed, int):
+                raise ValueError(
+                    "MistralProvider only accepts an integer seed, or None."
+                )
+            kwargs2["extra_body"] = {"random_seed": seed}
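+
+        # e.g., ChatMistral(seed=42) ends up sending {"random_seed": 42} in the
+        # request body, rather than OpenAI's top-level `seed` parameter.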
+
+        return kwargs2
diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index 284cb84a..8e8e0ff9 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -310,8 +310,7 @@ def _chat_perform_args(
             del kwargs_full["tools"]
 
         if stream and "stream_options" not in kwargs_full:
-            if self.__class__.__name__ != "DatabricksProvider":
-                kwargs_full["stream_options"] = {"include_usage": True}
+            kwargs_full["stream_options"] = {"include_usage": True}
 
         return kwargs_full
 
@@ -411,7 +410,9 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
         if isinstance(x, ContentText):
             content_parts.append({"type": "text", "text": x.text})
         elif isinstance(x, ContentJson):
-            content_parts.append({"type": "text", "text": ""})
+            content_parts.append(
+                {"type": "text", "text": ""}
+            )
         elif isinstance(x, ContentToolRequest):
             tool_calls.append(
                 {
@@ -450,7 +451,7 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
         if isinstance(x, ContentText):
             contents.append({"type": "text", "text": x.text})
         elif isinstance(x, ContentJson):
-            contents.append({"type": "text", "text": ""})
+            contents.append({"type": "text", "text": ""})
         elif isinstance(x, ContentPDF):
             contents.append(
                 {
diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index 867be28b..7979f3e9 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -124,6 +124,7 @@ quartodoc:
         - ChatGoogle
         - ChatGroq
         - ChatHuggingFace
+        - ChatMistral
        - ChatOllama
         - ChatOpenAI
         - ChatOpenRouter
diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd
index 8d0bc619..deaa662c 100644
--- a/docs/get-started/models.qmd
+++ b/docs/get-started/models.qmd
@@ -21,7 +21,7 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
 | OpenAI | [`ChatOpenAI()`](../reference/ChatOpenAI.qmd) | |
 | Azure OpenAI | [`ChatAzureOpenAI()`](../reference/ChatAzureOpenAI.qmd) | ✅ |
 | Google (Gemini) | [`ChatGoogle()`](../reference/ChatGoogle.qmd) | |
-| Vertex AI | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ |
+| Google (Vertex) | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ |
 | GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd) | |
 | Ollama (local models) | [`ChatOllama()`](../reference/ChatOllama.qmd) | |
 | Open Router | [`ChatOpenRouter()`](../reference/ChatOpenRouter.qmd) | |
@@ -29,21 +29,43 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
 | Hugging Face | [`ChatHuggingFace()`](../reference/ChatHuggingFace.qmd) | |
 | Databricks | [`ChatDatabricks()`](../reference/ChatDatabricks.qmd) | ✅ |
 | Snowflake Cortex | [`ChatSnowflake()`](../reference/ChatSnowflake.qmd) | ✅ |
+| Mistral | [`ChatMistral()`](../reference/ChatMistral.qmd) | ✅ |
 | Groq | [`ChatGroq()`](../reference/ChatGroq.qmd) | |
 | perplexity.ai | [`ChatPerplexity()`](../reference/ChatPerplexity.qmd) | |
 | Cloudflare | [`ChatCloudflare()`](../reference/ChatCloudflare.qmd) | |
 | Portkey | [`ChatPortkey()`](../reference/ChatPortkey.qmd) | ✅ |
 
-::: callout-note
-
+::: callout-tip
 ### Other providers
 
-If you want to use a model provider that isn't listed in the table above, you have two options:
+To use chatlas with a provider not listed in the table above, you have two options:
 
-1. If the model provider is OpenAI compatible (i.e., it can be used with the [`openai` Python SDK](https://github.com/openai/openai-python?tab=readme-ov-file#configuring-the-http-client)), use `ChatOpenAI()` with the appropriate `base_url` and `api_key`.
+1. If the model provider is OpenAI compatible (i.e., it can be used with the [`openai` Python SDK](https://github.com/openai/openai-python?tab=readme-ov-file#configuring-the-http-client)), use `ChatOpenAI()` with the appropriate `base_url` and `api_key`, as in the example below.
 2. If you're motivated, implement a new provider by subclassing [`Provider`](https://github.com/posit-dev/chatlas/blob/main/chatlas/_provider.py) and implementing the required methods.
 
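+For example, to point `ChatOpenAI()` at a hypothetical OpenAI-compatible provider
+(the URL, environment variable, and model name below are placeholders):
+
+```python
+import os
+
+from chatlas import ChatOpenAI
+
+chat = ChatOpenAI(
+    base_url="https://api.example.com/v1",
+    api_key=os.getenv("EXAMPLE_API_KEY"),
+    model="example-model",
+)
+chat.chat("Hello!")
+```
+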
+:::
+
+::: callout-warning
+### Known limitations
+
+Some providers have limited support for features such as tool calling, structured data extraction, and images. In these cases, the provider's reference page should include a "Known limitations" section describing them.
 :::
 
 ### Model choice
 
diff --git a/tests/test_provider_mistral.py b/tests/test_provider_mistral.py
new file mode 100644
index 00000000..e1396506
--- /dev/null
+++ b/tests/test_provider_mistral.py
@@ -0,0 +1,73 @@
+import os
+
+import pytest
+
+from chatlas import ChatMistral
+
+from .conftest import (
+    assert_data_extraction,
+    assert_images_inline,
+    assert_images_remote,
+    assert_turns_existing,
+    assert_turns_system,
+)
+
+api_key = os.getenv("MISTRAL_API_KEY")
+if api_key is None:
+    pytest.skip("MISTRAL_API_KEY is not set; skipping tests", allow_module_level=True)
+
+
+def test_mistral_simple_request():
+    chat = ChatMistral(
+        system_prompt="Be as terse as possible; no punctuation",
+    )
+    chat.chat("What is 1 + 1?")
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.tokens is not None
+    assert len(turn.tokens) == 3
+    assert turn.tokens[0] > 0  # prompt tokens
+    assert turn.tokens[1] > 0  # completion tokens
+    assert turn.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_mistral_simple_streaming_request():
+    chat = ChatMistral(
+        system_prompt="Be as terse as possible; no punctuation",
+    )
+    res = []
+    async for x in await chat.stream_async("What is 1 + 1?"):
+        res.append(x)
+    assert "2" in "".join(res)
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.finish_reason == "stop"
+
+
+def test_mistral_respects_turns_interface():
+    chat_fun = ChatMistral
+    assert_turns_system(chat_fun)
+    assert_turns_existing(chat_fun)
+
+
+# Tool calling is poorly supported
+# def test_mistral_tool_variations():
+#     chat_fun = ChatMistral
+#     assert_tools_simple(chat_fun)
+#     assert_tools_simple_stream_content(chat_fun)
+
+
+# Tool calling is poorly supported
+# @pytest.mark.asyncio
+# async def test_mistral_tool_variations_async():
+#     await assert_tools_async(ChatMistral)
+
+
+def test_data_extraction():
+    assert_data_extraction(ChatMistral)
+
+
+def test_mistral_images():
+    assert_images_inline(ChatMistral)
+    assert_images_remote(ChatMistral)