diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6aeb4ba2..6859b09c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### New features
 
+* Added `ChatDeepSeek()` for chatting via [DeepSeek](https://www.deepseek.com/).
 * Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
 * Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)
 
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index 133fd13c..9266cfa5 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -8,6 +8,7 @@
 from ._provider import Provider
 from ._provider_anthropic import ChatAnthropic, ChatBedrockAnthropic
 from ._provider_databricks import ChatDatabricks
+from ._provider_deepseek import ChatDeepSeek
 from ._provider_github import ChatGithub
 from ._provider_google import ChatGoogle, ChatVertex
 from ._provider_groq import ChatGroq
@@ -31,6 +32,7 @@
     "ChatAuto",
     "ChatBedrockAnthropic",
     "ChatDatabricks",
+    "ChatDeepSeek",
     "ChatGithub",
     "ChatGoogle",
     "ChatGroq",
diff --git a/chatlas/_provider_deepseek.py b/chatlas/_provider_deepseek.py
new file mode 100644
index 00000000..b60f34d8
--- /dev/null
+++ b/chatlas/_provider_deepseek.py
@@ -0,0 +1,171 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Optional, cast
+
+from ._chat import Chat
+from ._logging import log_model_default
+from ._provider_openai import OpenAIProvider
+from ._turn import Turn
+from ._utils import MISSING, MISSING_TYPE, is_testing
+
+if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+
+    from .types.openai import ChatClientArgs, SubmitInputArgs
+
+
+def ChatDeepSeek(
+    *,
+    system_prompt: Optional[str] = None,
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    base_url: str = "https://api.deepseek.com",
+    seed: Optional[int] | MISSING_TYPE = MISSING,
+    kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat["SubmitInputArgs", ChatCompletion]:
+    """
+    Chat with a model hosted on DeepSeek.
+
+    DeepSeek is a platform for AI inference with competitive pricing
+    and performance.
+
+    Prerequisites
+    -------------
+
+    ::: {.callout-note}
+    ## API key
+
+    Sign up at <https://platform.deepseek.com> to get an API key.
+    :::
+
+    Examples
+    --------
+
+    ```python
+    import os
+    from chatlas import ChatDeepSeek
+
+    chat = ChatDeepSeek(api_key=os.getenv("DEEPSEEK_API_KEY"))
+    chat.chat("What is the capital of France?")
+    ```
+
+    Known limitations
+    -----------------
+
+    * Structured data extraction is not supported.
+    * Images are not supported.
+
+    Parameters
+    ----------
+    system_prompt
+        A system prompt to set the behavior of the assistant.
+    model
+        The model to use for the chat. The default, None, will pick a reasonable
+        default, and warn you about it. We strongly recommend explicitly choosing
+        a model for all but the most casual use.
+    api_key
+        The API key to use for authentication. You generally should not supply
+        this directly, but instead set the `DEEPSEEK_API_KEY` environment variable.
+    base_url
+        The base URL to the endpoint; the default uses DeepSeek's API.
+    seed
+        Optional integer seed that DeepSeek uses to try and make output more
+        reproducible.
+    kwargs
+        Additional arguments to pass to the `openai.OpenAI()` client constructor.
+
+    Returns
+    -------
+    Chat
+        A chat object that retains the state of the conversation.
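+
+    Note
+    ----
+    Tool (function) calling works through the usual chatlas interface, as
+    exercised by this provider's tests. A minimal sketch (the `add` tool below
+    is purely illustrative, not something DeepSeek provides):
+
+    ```python
+    from chatlas import ChatDeepSeek
+
+    def add(x: int, y: int) -> int:
+        '''Add two numbers together.'''
+        return x + y
+
+    chat = ChatDeepSeek()
+    chat.register_tool(add)
+    chat.chat("Use the add tool to compute 2391 plus 2312.")
+    ```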
+ + Note + ---- + This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with + the defaults tweaked for DeepSeek. + + Note + ---- + Pasting an API key into a chat constructor (e.g., `ChatDeepSeek(api_key="...")`) + is the simplest way to get started, and is fine for interactive use, but is + problematic for code that may be shared with others. + + Instead, consider using environment variables or a configuration file to manage + your credentials. One popular way to manage credentials is to use a `.env` file + to store your credentials, and then use the `python-dotenv` package to load them + into your environment. + + ```shell + pip install python-dotenv + ``` + + ```shell + # .env + DEEPSEEK_API_KEY=... + ``` + + ```python + from chatlas import ChatDeepSeek + from dotenv import load_dotenv + + load_dotenv() + chat = ChatDeepSeek() + chat.console() + ``` + + Another, more general, solution is to load your environment variables into the shell + before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file): + + ```shell + export DEEPSEEK_API_KEY=... + ``` + """ + if model is None: + model = log_model_default("deepseek-chat") + + if api_key is None: + api_key = os.getenv("DEEPSEEK_API_KEY") + + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + return Chat( + provider=DeepSeekProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="DeepSeek", + kwargs=kwargs, + ), + system_prompt=system_prompt, + ) + + +class DeepSeekProvider(OpenAIProvider): + @staticmethod + def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]: + from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionUserMessageParam, + ) + + params = OpenAIProvider._as_message_param(turns) + + # Content must be a string + for i, param in enumerate(params): + if param["role"] in ["assistant", "user"]: + param = cast( + ChatCompletionAssistantMessageParam + | ChatCompletionUserMessageParam, + param, + ) + contents = param.get("content", None) + if not isinstance(contents, list): + continue + params[i]["content"] = "".join( + content.get("text", "") for content in contents + ) + + return params diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 2c7bc351..6a0656ae 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -117,6 +117,7 @@ quartodoc: - ChatAzureOpenAI - ChatBedrockAnthropic - ChatDatabricks + - ChatDeepSeek - ChatGithub - ChatGoogle - ChatGroq diff --git a/docs/get-started/models.qmd b/docs/get-started/models.qmd index 9c011d1f..f0df7652 100644 --- a/docs/get-started/models.qmd +++ b/docs/get-started/models.qmd @@ -17,6 +17,7 @@ To see the pre-requisites for a given provider, visit the relevant usage page in | Name | Usage | Enterprise? 
 |--------------------------|----------------------------------------------------------|------------|
 | Anthropic (Claude)       | [`ChatAnthropic()`](../reference/ChatAnthropic.qmd)  |             |
+| DeepSeek                 | [`ChatDeepSeek()`](../reference/ChatDeepSeek.qmd)    |             |
 | GitHub model marketplace | [`ChatGithub()`](../reference/ChatGithub.qmd)        |             |
 | Google (Gemini)          | [`ChatGoogle()`](../reference/ChatGoogle.qmd)        |             |
 | Groq                     | [`ChatGroq()`](../reference/ChatGroq.qmd)            |             |
diff --git a/tests/test_provider_deepseek.py b/tests/test_provider_deepseek.py
new file mode 100644
index 00000000..8d617df1
--- /dev/null
+++ b/tests/test_provider_deepseek.py
@@ -0,0 +1,62 @@
+import os
+
+import pytest
+from chatlas import ChatDeepSeek
+
+from .conftest import (
+    assert_tools_async,
+    assert_tools_simple,
+    assert_turns_existing,
+    assert_turns_system,
+)
+
+api_key = os.getenv("DEEPSEEK_API_KEY")
+if api_key is None:
+    pytest.skip("DEEPSEEK_API_KEY is not set; skipping tests", allow_module_level=True)
+
+
+def test_deepseek_simple_request():
+    chat = ChatDeepSeek(
+        system_prompt="Be as terse as possible; no punctuation",
+    )
+    chat.chat("What is 1 + 1?")
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.tokens is not None
+    assert len(turn.tokens) == 3
+    assert turn.tokens[0] >= 10  # More lenient assertion for DeepSeek
+    assert turn.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_deepseek_simple_streaming_request():
+    chat = ChatDeepSeek(
+        system_prompt="Be as terse as possible; no punctuation",
+    )
+    res = []
+    async for x in await chat.stream_async("What is 1 + 1?"):
+        res.append(x)
+    assert "2" in "".join(res)
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.finish_reason == "stop"
+
+
+def test_deepseek_respects_turns_interface():
+    chat_fun = ChatDeepSeek
+    assert_turns_system(chat_fun)
+    assert_turns_existing(chat_fun)
+
+
+def test_deepseek_tool_variations():
+    chat_fun = ChatDeepSeek
+    assert_tools_simple(chat_fun)
+
+
+@pytest.mark.asyncio
+async def test_deepseek_tool_variations_async():
+    chat_fun = ChatDeepSeek
+    await assert_tools_async(chat_fun)
+
+
+# Doesn't seem to support data extraction or images
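+
+
+# Sketch of the standard extraction check, kept skipped until DeepSeek gains
+# structured-output support; it assumes the shared `assert_data_extraction`
+# helper used by the other provider test modules.
+@pytest.mark.skip(reason="DeepSeek does not support structured data extraction")
+def test_deepseek_data_extraction():
+    from .conftest import assert_data_extraction
+
+    assert_data_extraction(ChatDeepSeek)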