diff --git a/CHANGELOG.md b/CHANGELOG.md
index a467098d..6aeb4ba2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,7 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [UNRELEASED]
 
-* Added a new `ChatPortkey()` which integrates with [Portkey AI](https://portkey.ai/). (#)
+### New features
+
+* Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
+* Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)
 
 ## [0.9.2] - 2025-08-08
 
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index a60a5636..133fd13c 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -11,6 +11,7 @@
 from ._provider_github import ChatGithub
 from ._provider_google import ChatGoogle, ChatVertex
 from ._provider_groq import ChatGroq
+from ._provider_huggingface import ChatHuggingFace
 from ._provider_ollama import ChatOllama
 from ._provider_openai import ChatAzureOpenAI, ChatOpenAI
 from ._provider_perplexity import ChatPerplexity
@@ -33,6 +34,7 @@
     "ChatGithub",
     "ChatGoogle",
     "ChatGroq",
+    "ChatHuggingFace",
     "ChatOllama",
     "ChatOpenAI",
     "ChatAzureOpenAI",
@@ -58,4 +60,3 @@
     "Turn",
     "types",
 )
-
diff --git a/chatlas/_provider_huggingface.py b/chatlas/_provider_huggingface.py
new file mode 100644
index 00000000..4bfcb007
--- /dev/null
+++ b/chatlas/_provider_huggingface.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Optional
+
+from ._chat import Chat
+from ._logging import log_model_default
+from ._provider_openai import OpenAIProvider
+
+if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletion
+
+    from .types.openai import ChatClientArgs, SubmitInputArgs
+
+
+def ChatHuggingFace(
+    *,
+    system_prompt: Optional[str] = None,
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat["SubmitInputArgs", ChatCompletion]:
+    """
+    Chat with a model hosted on the Hugging Face Inference API.
+
+    [Hugging Face](https://huggingface.co/) hosts a variety of open-source
+    and proprietary AI models available via their Inference API.
+    To use the Hugging Face API, you must have an Access Token, which you can
+    obtain from your [Hugging Face account](https://huggingface.co/settings/tokens).
+    Ensure that at least the "Make calls to Inference Providers" and
+    "Make calls to your Inference Endpoints" permissions are checked.
+
+    Prerequisites
+    -------------
+
+    ::: {.callout-note}
+    ## API key
+
+    You will need to create a Hugging Face account and generate an API token
+    from your [account settings](https://huggingface.co/settings/tokens).
+    Make sure to enable the "Make calls to Inference Providers" permission.
+    :::
+
+    Examples
+    --------
+    ```python
+    import os
+    from chatlas import ChatHuggingFace
+
+    chat = ChatHuggingFace(api_key=os.getenv("HUGGINGFACE_API_KEY"))
+    chat.chat("What is the capital of France?")
+    ```
+
+    Parameters
+    ----------
+    system_prompt
+        A system prompt to set the behavior of the assistant.
+    model
+        The model to use for the chat. The default, None, will pick a
+        reasonable default, and warn you about it. We strongly recommend
+        explicitly choosing a model for all but the most casual use.
+    api_key
+        The API key to use for authentication. You generally should not supply
+        this directly, but instead set the `HUGGINGFACE_API_KEY` environment
+        variable.
+    kwargs
+        Additional arguments to pass to the underlying OpenAI client
+        constructor.
+
+    Returns
+    -------
+    Chat
+        A chat object that retains the state of the conversation.
+
+    Known limitations
+    -----------------
+
+    * Some models do not support the chat interface, or parts of it; for
+      example, `google/gemma-2-2b-it` does not support a system prompt.
+      Choose your model carefully.
+    * Tool calling support varies by model; many models do not support it.
+
+    Note
+    ----
+    This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`), with
+    the defaults tweaked for Hugging Face.
+
+    Note
+    ----
+    Pasting an API key into a chat constructor (e.g., `ChatHuggingFace(api_key="...")`)
+    is the simplest way to get started, and is fine for interactive use, but is
+    problematic for code that may be shared with others.
+
+    Instead, consider using environment variables or a configuration file to
+    manage your credentials. One popular approach is to store your credentials
+    in a `.env` file and use the `python-dotenv` package to load them into
+    your environment.
+
+    ```shell
+    pip install python-dotenv
+    ```
+
+    ```shell
+    # .env
+    HUGGINGFACE_API_KEY=...
+    ```
+
+    ```python
+    from chatlas import ChatHuggingFace
+    from dotenv import load_dotenv
+
+    load_dotenv()
+    chat = ChatHuggingFace()
+    chat.console()
+    ```
+
+    Another, more general, solution is to load your environment variables into
+    the shell before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file):
+
+    ```shell
+    export HUGGINGFACE_API_KEY=...
+    ```
+    """
+    if api_key is None:
+        api_key = os.getenv("HUGGINGFACE_API_KEY")
+
+    if model is None:
+        model = log_model_default("meta-llama/Llama-3.1-8B-Instruct")
+
+    return Chat(
+        provider=HuggingFaceProvider(
+            api_key=api_key,
+            model=model,
+            kwargs=kwargs,
+        ),
+        system_prompt=system_prompt,
+    )
+
+
+class HuggingFaceProvider(OpenAIProvider):
+    def __init__(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        model: str,
+        kwargs: Optional["ChatClientArgs"] = None,
+    ):
+        # https://huggingface.co/docs/inference-providers/en/index?python-clients=requests#http--curl
+        super().__init__(
+            name="HuggingFace",
+            model=model,
+            api_key=api_key,
+            base_url="https://router.huggingface.co/v1",
+            kwargs=kwargs,
+        )
diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index 307ea253..2c7bc351 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -120,6 +120,7 @@ quartodoc:
       - ChatGithub
       - ChatGoogle
       - ChatGroq
+      - ChatHuggingFace
       - ChatOllama
       - ChatOpenAI
       - ChatPerplexity
diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd
index f6421b2d..9c011d1f 100644
--- a/docs/get-started/models.qmd
+++ b/docs/get-started/models.qmd
@@ -20,6 +20,7 @@ To see the pre-requisites for a given provider, visit the relevant usage page in
 | GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd) | |
 | Google (Gemini) | [`ChatGoogle()`](../reference/ChatGoogle.qmd) | |
 | Groq | [`ChatGroq()`](../reference/ChatGroq.qmd) | |
+| Hugging Face | [`ChatHuggingFace()`](../reference/ChatHuggingFace.qmd) | |
 | Ollama local models | [`ChatOllama()`](../reference/ChatOllama.qmd) | |
 | OpenAI | [`ChatOpenAI()`](../reference/ChatOpenAI.qmd) | |
 | perplexity.ai | [`ChatPerplexity()`](../reference/ChatPerplexity.qmd) | |
diff --git a/tests/test_provider_huggingface.py b/tests/test_provider_huggingface.py
new file mode 100644
index 00000000..f229ac56
--- /dev/null
+++ b/tests/test_provider_huggingface.py
@@ -0,0 +1,103 @@
+import os
+
+import pytest
+from chatlas import ChatHuggingFace
+
+from .conftest import (
+    assert_data_extraction,
+    assert_images_inline,
+    assert_images_remote,
+    assert_tools_async,
+    assert_tools_simple,
+    assert_turns_existing,
+    assert_turns_system,
+)
+
+# I think we would need to pay Hugging Face to actually run these tests?
+api_key = os.getenv("HUGGINGFACE_API_KEY")
+if api_key is None:
+    pytest.skip(
+        "HUGGINGFACE_API_KEY is not set; skipping tests", allow_module_level=True
+    )
+
+
+def test_huggingface_simple_request():
+    chat = ChatHuggingFace(
+        system_prompt="Be as terse as possible; no punctuation",
+        model="meta-llama/Llama-3.1-8B-Instruct",
+    )
+    chat.chat("What is 1 + 1?")
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.tokens is not None
+    assert len(turn.tokens) == 3
+    assert turn.tokens[0] > 0  # input tokens
+    assert turn.tokens[1] > 0  # output tokens
+    assert turn.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_huggingface_simple_streaming_request():
+    chat = ChatHuggingFace(
+        system_prompt="Be as terse as possible; no punctuation",
+        model="meta-llama/Llama-3.1-8B-Instruct",
+    )
+    res = []
+    async for x in await chat.stream_async("What is 1 + 1?"):
+        res.append(x)
+    assert "2" in "".join(res)
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.finish_reason == "stop"
+
+
+def test_huggingface_respects_turns_interface():
+    chat_fun = ChatHuggingFace
+    assert_turns_system(chat_fun)
+    assert_turns_existing(chat_fun)
+
+
+def test_huggingface_tools():
+    def chat_fun(**kwargs):
+        return ChatHuggingFace(model="meta-llama/Llama-3.1-8B-Instruct", **kwargs)
+
+    assert_tools_simple(chat_fun)
+
+
+@pytest.mark.asyncio
+async def test_huggingface_tools_async():
+    def chat_fun(**kwargs):
+        return ChatHuggingFace(model="meta-llama/Llama-3.1-8B-Instruct", **kwargs)
+
+    await assert_tools_async(chat_fun)
+
+
+def test_huggingface_data_extraction():
+    def chat_fun(**kwargs):
+        return ChatHuggingFace(model="meta-llama/Llama-3.1-8B-Instruct", **kwargs)
+
+    assert_data_extraction(chat_fun)
+
+
+def test_huggingface_images():
+    # Use a vision model that supports images
+    def chat_fun(**kwargs):
+        return ChatHuggingFace(model="Qwen/Qwen2.5-VL-7B-Instruct", **kwargs)
+
+    assert_images_inline(chat_fun)
+    assert_images_remote(chat_fun)
+
+
+def test_huggingface_custom_model():
+    chat = ChatHuggingFace(model="microsoft/DialoGPT-medium")
+    assert chat.provider.model == "microsoft/DialoGPT-medium"
+
+
+def test_huggingface_base_url():
+    chat = ChatHuggingFace()
+    assert "huggingface.co" in str(chat.provider._client.base_url)
+
+
+def test_huggingface_provider_name():
+    chat = ChatHuggingFace()
+    assert chat.provider.name == "HuggingFace"
diff --git a/tests/test_provider_portkey.py b/tests/test_provider_portkey.py
index 96bb7bc0..02f36c89 100644
--- a/tests/test_provider_portkey.py
+++ b/tests/test_provider_portkey.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 
 from chatlas import ChatPortkey
@@ -14,6 +16,10 @@
     assert_turns_system,
 )
 
+api_key = os.getenv("PORTKEY_API_KEY")
+if api_key is None:
+    pytest.skip("PORTKEY_API_KEY is not set; skipping tests", allow_module_level=True)
+
 
 def _chat_portkey_test(**kwargs):
     model = kwargs.pop("model", "@openai/gpt-4o-mini")
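
For reference, a minimal sketch of how the new provider is exercised end to end. This mirrors the API introduced in the diff above (nothing here is new surface area); it assumes `HUGGINGFACE_API_KEY` is set in the environment, and the model name matches the default registered in `_provider_huggingface.py`:

```python
from chatlas import ChatHuggingFace

# Assumes HUGGINGFACE_API_KEY is exported in the environment;
# ChatHuggingFace() reads it automatically when api_key is not supplied.
# The model matches the default used by log_model_default() above.
chat = ChatHuggingFace(
    system_prompt="Be concise.",
    model="meta-llama/Llama-3.1-8B-Instruct",
)
chat.chat("What is the capital of France?")
```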