Merged
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -9,7 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

* Added a new `ChatPortkey()` which integrates with [Portkey AI](https://portkey.ai/). (#)
### New features

* Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
* Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)


## [0.9.2] - 2025-08-08
3 changes: 2 additions & 1 deletion chatlas/__init__.py
@@ -11,6 +11,7 @@
from ._provider_github import ChatGithub
from ._provider_google import ChatGoogle, ChatVertex
from ._provider_groq import ChatGroq
from ._provider_huggingface import ChatHuggingFace
from ._provider_ollama import ChatOllama
from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
from ._provider_perplexity import ChatPerplexity
@@ -33,6 +34,7 @@
"ChatGithub",
"ChatGoogle",
"ChatGroq",
"ChatHuggingFace",
"ChatOllama",
"ChatOpenAI",
"ChatAzureOpenAI",
@@ -58,4 +60,3 @@
"Turn",
"types",
)

155 changes: 155 additions & 0 deletions chatlas/_provider_huggingface.py
@@ -0,0 +1,155 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional

from ._chat import Chat
from ._logging import log_model_default
from ._provider_openai import OpenAIProvider

if TYPE_CHECKING:
from openai.types.chat import ChatCompletion

from .types.openai import ChatClientArgs, SubmitInputArgs


def ChatHuggingFace(
*,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
"""
Chat with a model hosted on the Hugging Face Inference API.

[Hugging Face](https://huggingface.co/) hosts a variety of open-source
and proprietary AI models available via their Inference API.
To use the Hugging Face API, you must have an Access Token, which you can obtain
from your [Hugging Face account](https://huggingface.co/settings/tokens).
Ensure that at least the "Make calls to Inference Providers" and
"Make calls to your Inference Endpoints" permissions are enabled.

Prerequisites
--------------

::: {.callout-note}
## API key

You will need to create a Hugging Face account and generate an API token
from your [account settings](https://huggingface.co/settings/tokens).
Make sure to enable the "Make calls to Inference Providers" permission.
:::

Examples
--------
```python
import os
from chatlas import ChatHuggingFace

chat = ChatHuggingFace(api_key=os.getenv("HUGGINGFACE_API_KEY"))
chat.chat("What is the capital of France?")
```

Parameters
----------
system_prompt
A system prompt to set the behavior of the assistant.
model
The model to use for the chat. The default, None, will pick a reasonable
default, and warn you about it. We strongly recommend explicitly
choosing a model for all but the most casual use.
api_key
The API key to use for authentication. You generally should not supply
this directly, but instead set the `HUGGINGFACE_API_KEY` environment
variable.
kwargs
Additional arguments to pass to the underlying OpenAI client
constructor.

Returns
-------
Chat
A chat object that retains the state of the conversation.
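
For example (a sketch; `timeout` and `max_retries` are arguments of the
underlying OpenAI client constructor, not chatlas-specific options), you can
pick a model explicitly and tune the client via `kwargs`:

```python
from chatlas import ChatHuggingFace

# Explicit model choice plus extra OpenAI-client settings passed via `kwargs`
# (request timeout in seconds, and number of automatic retries).
chat = ChatHuggingFace(
    model="meta-llama/Llama-3.1-8B-Instruct",
    kwargs={"timeout": 30, "max_retries": 2},
)
chat.chat("What is the capital of France?")
```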

Known limitations
-----------------

* Some models do not support the chat interface, or parts of it; for example,
`google/gemma-2-2b-it` does not accept a system prompt. Choose your model
carefully.
* Tool calling support varies by model; many models do not support it at all
(see the sketch below).
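
A minimal sketch of working within these limits: it uses the model from this
package's own tests (which accepts a system prompt and tool calls), assumes
chatlas's `.register_tool()` method for tool registration, and the tool
function itself is purely hypothetical.

```python
from chatlas import ChatHuggingFace


def get_current_time(timezone: str) -> str:
    """Return the current time in the given timezone."""
    # Hypothetical tool; a real implementation would look up the time.
    return "12:00"


chat = ChatHuggingFace(
    system_prompt="Answer as briefly as possible.",
    model="meta-llama/Llama-3.1-8B-Instruct",
)
chat.register_tool(get_current_time)
chat.chat("What time is it in UTC?")
```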

Note
----
This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`), with
the defaults tweaked for Hugging Face.
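
Concretely, a roughly equivalent configuration (a sketch of the idea, not the
exact internals; it assumes `ChatOpenAI()`'s `base_url` argument) would be:

```python
import os

from chatlas import ChatOpenAI

# Point the stock OpenAI provider at Hugging Face's OpenAI-compatible router.
chat = ChatOpenAI(
    model="meta-llama/Llama-3.1-8B-Instruct",
    api_key=os.getenv("HUGGINGFACE_API_KEY"),
    base_url="https://router.huggingface.co/v1",
)
```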

Note
----
Pasting an API key into a chat constructor (e.g., `ChatHuggingFace(api_key="...")`)
is the simplest way to get started, and is fine for interactive use, but is
problematic for code that may be shared with others.

Instead, consider using environment variables or a configuration file to manage
your credentials. One popular approach is to store them in a `.env` file and
load them into your environment with the `python-dotenv` package.

```shell
pip install python-dotenv
```

```shell
# .env
HUGGINGFACE_API_KEY=...
```

```python
from chatlas import ChatHuggingFace
from dotenv import load_dotenv

load_dotenv()
chat = ChatHuggingFace()
chat.console()
```

Another, more general solution is to set the environment variable in your shell
before starting Python (e.g., in your `.bashrc` or `.zshrc`):

```shell
export HUGGINGFACE_API_KEY=...
```
"""
if api_key is None:
api_key = os.getenv("HUGGINGFACE_API_KEY")

if model is None:
model = log_model_default("meta-llama/Llama-3.1-8B-Instruct")

return Chat(
provider=HuggingFaceProvider(
api_key=api_key,
model=model,
kwargs=kwargs,
),
system_prompt=system_prompt,
)


class HuggingFaceProvider(OpenAIProvider):
def __init__(
self,
*,
api_key: Optional[str] = None,
model: str,
kwargs: Optional["ChatClientArgs"] = None,
):
# Hugging Face's OpenAI-compatible router endpoint; see
# https://huggingface.co/docs/inference-providers/en/index?python-clients=requests#http--curl
super().__init__(
name="HuggingFace",
model=model,
api_key=api_key,
base_url="https://router.huggingface.co/v1",
kwargs=kwargs,
)
1 change: 1 addition & 0 deletions docs/_quarto.yml
@@ -120,6 +120,7 @@ quartodoc:
- ChatGithub
- ChatGoogle
- ChatGroq
- ChatHuggingFace
- ChatOllama
- ChatOpenAI
- ChatPerplexity
1 change: 1 addition & 0 deletions docs/get-started/models.qmd
@@ -20,6 +20,7 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
| GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd) | |
| Google (Gemini) | [`ChatGoogle()`](../reference/ChatGoogle.qmd) | |
| Groq | [`ChatGroq()`](../reference/ChatGroq.qmd) | |
| Hugging Face | [`ChatHuggingFace()`](../reference/ChatHuggingFace.qmd) | |
| Ollama local models | [`ChatOllama()`](../reference/ChatOllama.qmd) | |
| OpenAI | [`ChatOpenAI()`](../reference/ChatOpenAI.qmd) | |
| perplexity.ai | [`ChatPerplexity()`](../reference/ChatPerplexity.qmd) | |
103 changes: 103 additions & 0 deletions tests/test_provider_huggingface.py
@@ -0,0 +1,103 @@
import os

import pytest
from chatlas import ChatHuggingFace

from .conftest import (
assert_data_extraction,
assert_images_inline,
assert_images_remote,
assert_tools_async,
assert_tools_simple,
assert_turns_existing,
assert_turns_system,
)

# Running these tests against the live API likely requires a paid Hugging Face plan.
api_key = os.getenv("HUGGINGFACE_API_KEY")
if api_key is None:
pytest.skip(
"HUGGINGFACE_API_KEY is not set; skipping tests", allow_module_level=True
)


def test_huggingface_simple_request():
chat = ChatHuggingFace(
system_prompt="Be as terse as possible; no punctuation",
model="meta-llama/Llama-3.1-8B-Instruct",
)
chat.chat("What is 1 + 1?")
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens is not None
assert len(turn.tokens) == 3
assert turn.tokens[0] > 0 # input tokens
assert turn.tokens[1] > 0 # output tokens
assert turn.finish_reason == "stop"


@pytest.mark.asyncio
async def test_huggingface_simple_streaming_request():
chat = ChatHuggingFace(
system_prompt="Be as terse as possible; no punctuation",
model="meta-llama/Llama-3.1-8B-Instruct",
)
res = []
async for x in await chat.stream_async("What is 1 + 1?"):
res.append(x)
assert "2" in "".join(res)
turn = chat.get_last_turn()
assert turn is not None
assert turn.finish_reason == "stop"


def test_huggingface_respects_turns_interface():
chat_fun = ChatHuggingFace
assert_turns_system(chat_fun)
assert_turns_existing(chat_fun)


def test_huggingface_tools():
def chat_fun(**kwargs):
return ChatHuggingFace(model="meta-llama/Llama-3.1-8B-Instruct", **kwargs)

assert_tools_simple(chat_fun)


@pytest.mark.asyncio
async def test_huggingface_tools_async():
def chat_fun(**kwargs):
return ChatHuggingFace(model="meta-llama/Llama-3.1-8B-Instruct", **kwargs)

await assert_tools_async(chat_fun)


def test_huggingface_data_extraction():
def chat_fun(**kwargs):
return ChatHuggingFace(model="meta-llama/Llama-3.1-8B-Instruct", **kwargs)

assert_data_extraction(chat_fun)


def test_huggingface_images():
# Use a vision model that supports images
def chat_fun(**kwargs):
return ChatHuggingFace(model="Qwen/Qwen2.5-VL-7B-Instruct", **kwargs)

assert_images_inline(chat_fun)
assert_images_remote(chat_fun)


def test_huggingface_custom_model():
chat = ChatHuggingFace(model="microsoft/DialoGPT-medium")
assert chat.provider.model == "microsoft/DialoGPT-medium"


def test_huggingface_base_url():
chat = ChatHuggingFace()
assert "huggingface.co" in str(chat.provider._client.base_url)


def test_huggingface_provider_name():
chat = ChatHuggingFace()
assert chat.provider.name == "HuggingFace"
6 changes: 6 additions & 0 deletions tests/test_provider_portkey.py
@@ -1,3 +1,5 @@
import os

import pytest
from chatlas import ChatPortkey

@@ -14,6 +16,10 @@
assert_turns_system,
)

api_key = os.getenv("PORTKEY_API_KEY")
if api_key is None:
pytest.skip("PORTKEY_API_KEY is not set; skipping tests", allow_module_level=True)


def _chat_portkey_test(**kwargs):
model = kwargs.pop("model", "@openai/gpt-4o-mini")