1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `ChatDeepSeek()` for chatting via [DeepSeek](https://www.deepseek.com/). (#147)
* Added `ChatOpenRouter()` for chatting via [Open Router](https://openrouter.ai/). (#148)
* Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
* Added `ChatMistral()` for chatting via [Mistral AI](https://mistral.ai/). (#145)
* Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)

### Bug fixes
2 changes: 2 additions & 0 deletions chatlas/__init__.py
@@ -14,6 +14,7 @@
from ._provider_google import ChatGoogle, ChatVertex
from ._provider_groq import ChatGroq
from ._provider_huggingface import ChatHuggingFace
from ._provider_mistral import ChatMistral
from ._provider_ollama import ChatOllama
from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
from ._provider_openrouter import ChatOpenRouter
@@ -40,6 +41,7 @@
"ChatGoogle",
"ChatGroq",
"ChatHuggingFace",
"ChatMistral",
"ChatOllama",
"ChatOpenAI",
"ChatOpenRouter",
11 changes: 11 additions & 0 deletions chatlas/_provider_databricks.py
@@ -127,3 +127,14 @@ def __init__(
api_key="no-token", # A placeholder to pass validations, this will not be used
http_client=httpx.AsyncClient(auth=client._client.auth),
)

# Databricks doesn't support stream_options
def _chat_perform_args(
self, stream, turns, tools, data_model=None, kwargs=None
) -> "SubmitInputArgs":
kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs)

if "stream_options" in kwargs2:
del kwargs2["stream_options"]

return kwargs2
181 changes: 181 additions & 0 deletions chatlas/_provider_mistral.py
@@ -0,0 +1,181 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional

from ._chat import Chat
from ._logging import log_model_default
from ._provider_openai import OpenAIProvider
from ._utils import MISSING, MISSING_TYPE, is_testing

if TYPE_CHECKING:
from openai.types.chat import ChatCompletion

from .types.openai import ChatClientArgs, SubmitInputArgs


def ChatMistral(
*,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
base_url: str = "https://api.mistral.ai/v1/",
seed: int | None | MISSING_TYPE = MISSING,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
"""
Chat with a model hosted on Mistral's La Plateforme.

Mistral AI provides high-performance language models through their API platform.

Prerequisites
-------------

::: {.callout-note}
## API credentials

Get your API key from https://console.mistral.ai/api-keys.
:::

Examples
--------
```python
import os
from chatlas import ChatMistral

chat = ChatMistral(api_key=os.getenv("MISTRAL_API_KEY"))
chat.chat("Tell me three jokes about statisticians")
```

Known limitations
-----------------

* Tool calling may be unstable.
* Images require a model that supports vision (see the sketch below).
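
For example, image input should work when a vision-capable model is selected. This is only a minimal sketch: the model name and image URL below are illustrative placeholders, so check Mistral's current model list before relying on them.

```python
from chatlas import ChatMistral, content_image_url

# "pixtral-large-latest" is an assumed example of a vision-capable model,
# and the image URL is a placeholder -- substitute real values.
chat = ChatMistral(model="pixtral-large-latest")
chat.chat(
    "Describe this image in one sentence.",
    content_image_url("https://example.com/photo.jpg"),
)
```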

Parameters
----------
system_prompt
A system prompt to set the behavior of the assistant.
model
The model to use for the chat. The default, None, will pick a reasonable
default, and warn you about it. We strongly recommend explicitly
choosing a model for all but the most casual use.
api_key
The API key to use for authentication. You generally should not supply
this directly, but instead set the `MISTRAL_API_KEY` environment
variable.
base_url
The base URL to the endpoint; the default uses Mistral AI.
seed
Optional integer seed that Mistral uses to try and make output more
reproducible.
kwargs
Additional arguments to pass to the `openai.OpenAI()` client
constructor (Mistral uses OpenAI-compatible API).

Returns
-------
Chat
A chat object that retains the state of the conversation.

Note
----
Pasting an API key into a chat constructor (e.g., `ChatMistral(api_key="...")`)
is the simplest way to get started, and is fine for interactive use, but is
problematic for code that may be shared with others.

Instead, consider using environment variables or a configuration file to manage
your credentials. One popular way to manage credentials is to use a `.env` file
to store your credentials, and then use the `python-dotenv` package to load them
into your environment.

```shell
pip install python-dotenv
```

```shell
# .env
MISTRAL_API_KEY=...
```

```python
from chatlas import ChatMistral
from dotenv import load_dotenv

load_dotenv()
chat = ChatMistral()
chat.console()
```

Another, more general, solution is to load your environment variables into the shell
before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file):

```shell
export MISTRAL_API_KEY=...
```
"""
if isinstance(seed, MISSING_TYPE):
seed = 1014 if is_testing() else None

if model is None:
model = log_model_default("mistral-large-latest")

if api_key is None:
api_key = os.getenv("MISTRAL_API_KEY")

return Chat(
provider=MistralProvider(
api_key=api_key,
model=model,
base_url=base_url,
seed=seed,
kwargs=kwargs,
),
system_prompt=system_prompt,
)


class MistralProvider(OpenAIProvider):
def __init__(
self,
*,
api_key: Optional[str] = None,
model: str,
base_url: str = "https://api.mistral.ai/v1/",
seed: Optional[int] = None,
name: str = "Mistral",
kwargs: Optional["ChatClientArgs"] = None,
):
super().__init__(
api_key=api_key,
model=model,
base_url=base_url,
seed=seed,
name=name,
kwargs=kwargs,
)

# Mistral is essentially OpenAI-compatible, with a couple small differences.
# We _could_ bring in the Mistral SDK and use it directly for more precise typing,
# etc., but for now that doesn't seem worth it.
def _chat_perform_args(
self, stream, turns, tools, data_model=None, kwargs=None
) -> "SubmitInputArgs":
# Get the base arguments from OpenAI provider
kwargs2 = super()._chat_perform_args(stream, turns, tools, data_model, kwargs)

# Mistral doesn't support stream_options
if "stream_options" in kwargs2:
del kwargs2["stream_options"]

# Mistral wants random_seed, not seed
seed = kwargs2.pop("seed", None)
if seed is not None:
    if not isinstance(seed, int):
        raise ValueError(
            "MistralProvider only accepts an integer seed, or None."
        )
    kwargs2["extra_body"] = {"random_seed": seed}

return kwargs2
9 changes: 5 additions & 4 deletions chatlas/_provider_openai.py
@@ -310,8 +310,7 @@ def _chat_perform_args(
del kwargs_full["tools"]

if stream and "stream_options" not in kwargs_full:
if self.__class__.__name__ != "DatabricksProvider":
kwargs_full["stream_options"] = {"include_usage": True}
kwargs_full["stream_options"] = {"include_usage": True}

return kwargs_full

@@ -411,7 +410,9 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
if isinstance(x, ContentText):
content_parts.append({"type": "text", "text": x.text})
elif isinstance(x, ContentJson):
content_parts.append({"type": "text", "text": ""})
content_parts.append(
{"type": "text", "text": "<structured data/>"}
)
elif isinstance(x, ContentToolRequest):
tool_calls.append(
{
Expand Down Expand Up @@ -450,7 +451,7 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
if isinstance(x, ContentText):
contents.append({"type": "text", "text": x.text})
elif isinstance(x, ContentJson):
contents.append({"type": "text", "text": ""})
contents.append({"type": "text", "text": "<structured data/>"})
elif isinstance(x, ContentPDF):
contents.append(
{
1 change: 1 addition & 0 deletions docs/_quarto.yml
@@ -124,6 +124,7 @@ quartodoc:
- ChatGoogle
- ChatGroq
- ChatHuggingFace
- ChatMistral
- ChatOllama
- ChatOpenAI
- ChatOpenRouter
13 changes: 9 additions & 4 deletions docs/get-started/models.qmd
@@ -21,29 +21,34 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
| OpenAI | [`ChatOpenAI()`](../reference/ChatOpenAI.qmd) | |
| Azure OpenAI | [`ChatAzureOpenAI()`](../reference/ChatAzureOpenAI.qmd) | ✅ |
| Google (Gemini) | [`ChatGoogle()`](../reference/ChatGoogle.qmd) | |
| Vertex AI | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ |
| Google (Vertex) | [`ChatVertex()`](../reference/ChatVertex.qmd) | ✅ |
| GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd) | |
| Ollama (local models) | [`ChatOllama()`](../reference/ChatOllama.qmd) | |
| Open Router | [`ChatOpenRouter()`](../reference/ChatOpenRouter.qmd) | |
| DeepSeek | [`ChatDeepSeek()`](../reference/ChatDeepSeek.qmd) | |
| Hugging Face | [`ChatHuggingFace()`](../reference/ChatHuggingFace.qmd) | |
| Databricks | [`ChatDatabricks()`](../reference/ChatDatabricks.qmd) | ✅ |
| Snowflake Cortex | [`ChatSnowflake()`](../reference/ChatSnowflake.qmd) | ✅ |
| Mistral | [`ChatMistral()`](../reference/ChatMistral.qmd) | ✅ |
| Groq | [`ChatGroq()`](../reference/ChatGroq.qmd) | |
| perplexity.ai | [`ChatPerplexity()`](../reference/ChatPerplexity.qmd) | |
| Cloudflare | [`ChatCloudflare()`](../reference/ChatCloudflare.qmd) | |
| Portkey | [`ChatPortkey()`](../reference/ChatPortkey.qmd) | ✅ |


::: callout-note

::: callout-tip
### Other providers

If you want to use a model provider that isn't listed in the table above, you have two options:
To use chatlas with a provider not listed in the table above, you have two options:

1. If the model provider is OpenAI-compatible (i.e., it can be used with the [`openai` Python SDK](https://github.com/openai/openai-python?tab=readme-ov-file#configuring-the-http-client)), use `ChatOpenAI()` with the appropriate `base_url` and `api_key`, as sketched after this callout.
2. If you're motivated, implement a new provider by subclassing [`Provider`](https://github.com/posit-dev/chatlas/blob/main/chatlas/_provider.py) and implementing the required methods.
:::
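
For instance, here is a minimal sketch of option 1. The base URL, environment variable, and model name are placeholders for whatever your provider documents:

```python
import os

from chatlas import ChatOpenAI

# Placeholder endpoint, key variable, and model name -- substitute the
# values documented by your OpenAI-compatible provider.
chat = ChatOpenAI(
    base_url="https://api.example.com/v1",
    api_key=os.getenv("EXAMPLE_API_KEY"),
    model="example-model",
)
chat.chat("Hello!")
```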

::: callout-warning
### Known limitations

Some providers have limited support for features such as tool calling, structured data extraction, or images. In those cases, the provider's reference page should include a known limitations section describing them.
:::

### Model choice
73 changes: 73 additions & 0 deletions tests/test_provider_mistral.py
@@ -0,0 +1,73 @@
import os

import pytest

from chatlas import ChatMistral

from .conftest import (
assert_data_extraction,
assert_images_inline,
assert_images_remote,
assert_turns_existing,
assert_turns_system,
)

api_key = os.getenv("MISTRAL_API_KEY")
if api_key is None:
pytest.skip("MISTRAL_API_KEY is not set; skipping tests", allow_module_level=True)


def test_mistral_simple_request():
chat = ChatMistral(
system_prompt="Be as terse as possible; no punctuation",
)
chat.chat("What is 1 + 1?")
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens is not None
assert len(turn.tokens) == 3
assert turn.tokens[0] > 0 # prompt tokens
assert turn.tokens[1] > 0 # completion tokens
assert turn.finish_reason == "stop"


@pytest.mark.asyncio
async def test_mistral_simple_streaming_request():
chat = ChatMistral(
system_prompt="Be as terse as possible; no punctuation",
)
res = []
async for x in await chat.stream_async("What is 1 + 1?"):
res.append(x)
assert "2" in "".join(res)
turn = chat.get_last_turn()
assert turn is not None
assert turn.finish_reason == "stop"


def test_mistral_respects_turns_interface():
chat_fun = ChatMistral
assert_turns_system(chat_fun)
assert_turns_existing(chat_fun)


# Tool calling is poorly supported
# def test_mistral_tool_variations():
# chat_fun = ChatMistral
# assert_tools_simple(chat_fun)
# assert_tools_simple_stream_content(chat_fun)

# Tool calling is poorly supported
# @pytest.mark.asyncio
# async def test_mistral_tool_variations_async():
# await assert_tools_async(ChatMistral)


def test_data_extraction():
assert_data_extraction(ChatMistral)


def test_mistral_images():
    assert_images_inline(ChatMistral)
    assert_images_remote(ChatMistral)