Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions chatlas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import types
from ._api_headers import ApiHeaders
from ._auto import ChatAuto
from ._batch_chat import (
batch_chat,
Expand Down Expand Up @@ -46,6 +47,7 @@
__version__ = "0.0.0" # stub value for docs

__all__ = (
"ApiHeaders",
"batch_chat",
"batch_chat_completed",
"batch_chat_structured",
Expand Down
36 changes: 36 additions & 0 deletions chatlas/_api_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import Callable, Union

ApiHeadersValue = dict[str, str]
"""A dict of HTTP header name-value pairs."""

ApiHeaders = Union[ApiHeadersValue, Callable[[], ApiHeadersValue]]
"""
Extra HTTP headers to include with every API request.

Can be:

* A **dict** of ``{header_name: header_value}`` — sent as-is on every request.
* A **zero-argument callable** returning such a dict — called on every
request, enabling token refresh and other dynamic auth patterns.
"""


def resolve_api_headers(api_headers: ApiHeaders | None) -> dict[str, str] | None:
"""
Resolve api_headers into a dict of HTTP headers (or None).

Called at request time so that callables can return fresh values.
"""
if api_headers is None:
return None

value = api_headers() if callable(api_headers) else api_headers

if isinstance(value, dict):
return value

raise TypeError(
f"api_headers must be (or return) a dict, got {type(value).__name__}"
)
12 changes: 10 additions & 2 deletions chatlas/_provider_cloudflare.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

from ._api_headers import ApiHeaders
from ._chat import Chat
from ._logging import log_model_default
from ._provider_openai_completions import OpenAICompletionsProvider
Expand All @@ -18,7 +19,8 @@ def ChatCloudflare(
account: Optional[str] = None,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
api_headers: Optional[ApiHeaders] = None,
seed: Optional[int] | MISSING_TYPE = MISSING,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
Expand Down Expand Up @@ -73,6 +75,11 @@ def ChatCloudflare(
The API key to use for authentication. You generally should not supply
this directly, but instead set the `CLOUDFLARE_API_KEY` environment
variable.
api_headers
Extra HTTP headers to include with every chat API request. Can be a dict
of ``{header_name: header_value}`` pairs, or a zero-argument callable
returning such a dict. A callable is invoked on every request,
enabling dynamic auth patterns like token refresh.
seed
Optional integer seed that ChatGPT uses to try and make output more
reproducible.
Expand Down Expand Up @@ -159,6 +166,7 @@ def ChatCloudflare(
base_url=base_url,
seed=seed,
name="Cloudflare",
api_headers=api_headers,
kwargs=kwargs,
),
system_prompt=system_prompt,
Expand Down
12 changes: 10 additions & 2 deletions chatlas/_provider_deepseek.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Callable, Optional, cast

from ._api_headers import ApiHeaders
from ._chat import Chat
from ._logging import log_model_default
from ._provider_openai_completions import OpenAICompletionsProvider
Expand All @@ -19,7 +20,8 @@ def ChatDeepSeek(
*,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
api_headers: Optional[ApiHeaders] = None,
base_url: str = "https://api.deepseek.com",
seed: Optional[int] | MISSING_TYPE = MISSING,
kwargs: Optional["ChatClientArgs"] = None,
Expand Down Expand Up @@ -67,6 +69,11 @@ def ChatDeepSeek(
api_key
The API key to use for authentication. You generally should not supply
this directly, but instead set the `DEEPSEEK_API_KEY` environment variable.
api_headers
Extra HTTP headers to include with every chat API request. Can be a dict
of ``{header_name: header_value}`` pairs, or a zero-argument callable
returning such a dict. A callable is invoked on every request,
enabling dynamic auth patterns like token refresh.
base_url
The base URL to the endpoint; the default uses DeepSeek's API.
seed
Expand Down Expand Up @@ -138,6 +145,7 @@ def ChatDeepSeek(
seed=seed,
preserve_thinking=True,
name="DeepSeek",
api_headers=api_headers,
kwargs=kwargs,
),
system_prompt=system_prompt,
Expand Down
16 changes: 12 additions & 4 deletions chatlas/_provider_github.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

import requests

from ._api_headers import ApiHeaders
from ._chat import Chat
from ._logging import log_model_default
from ._provider import ModelInfo
Expand All @@ -20,7 +21,8 @@ def ChatGithub(
*,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
api_headers: Optional[ApiHeaders] = None,
base_url: str = "https://models.github.ai/inference/",
seed: Optional[int] | MISSING_TYPE = MISSING,
kwargs: Optional["ChatClientArgs"] = None,
Expand Down Expand Up @@ -64,6 +66,11 @@ def ChatGithub(
api_key
The API key to use for authentication. You generally should not supply
this directly, but instead set the `GITHUB_TOKEN` environment variable.
api_headers
Extra HTTP headers to include with every chat API request. Can be a dict
of ``{header_name: header_value}`` pairs, or a zero-argument callable
returning such a dict. A callable is invoked on every request,
enabling dynamic auth patterns like token refresh.
base_url
The base URL to the endpoint; the default uses Github's API.
seed
Expand Down Expand Up @@ -134,15 +141,16 @@ def ChatGithub(
base_url=base_url,
seed=seed,
name="GitHub",
api_headers=api_headers,
kwargs=kwargs,
),
system_prompt=system_prompt,
)


class GitHubProvider(OpenAICompletionsProvider):
def __init__(self, base_url: str, **kwargs):
super().__init__(base_url=base_url, **kwargs)
def __init__(self, base_url: str, api_headers: Optional[ApiHeaders] = None, **kwargs):
super().__init__(base_url=base_url, api_headers=api_headers, **kwargs)
self._base_url = base_url

def list_models(self) -> list[ModelInfo]:
Expand Down
12 changes: 10 additions & 2 deletions chatlas/_provider_groq.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

from ._api_headers import ApiHeaders
from ._chat import Chat
from ._logging import log_model_default
from ._provider_openai_completions import OpenAICompletionsProvider
Expand All @@ -17,7 +18,8 @@ def ChatGroq(
*,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
api_headers: Optional[ApiHeaders] = None,
base_url: str = "https://api.groq.com/openai/v1",
seed: Optional[int] | MISSING_TYPE = MISSING,
kwargs: Optional["ChatClientArgs"] = None,
Expand Down Expand Up @@ -58,6 +60,11 @@ def ChatGroq(
api_key
The API key to use for authentication. You generally should not supply
this directly, but instead set the `GROQ_API_KEY` environment variable.
api_headers
Extra HTTP headers to include with every chat API request. Can be a dict
of ``{header_name: header_value}`` pairs, or a zero-argument callable
returning such a dict. A callable is invoked on every request,
enabling dynamic auth patterns like token refresh.
base_url
The base URL to the endpoint; the default uses Groq's API.
seed
Expand Down Expand Up @@ -128,6 +135,7 @@ def ChatGroq(
base_url=base_url,
seed=seed,
name="Groq",
api_headers=api_headers,
kwargs=kwargs,
),
system_prompt=system_prompt,
Expand Down
16 changes: 13 additions & 3 deletions chatlas/_provider_huggingface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

from ._api_headers import ApiHeaders
from ._chat import Chat
from ._logging import log_model_default
from ._provider_openai_completions import OpenAICompletionsProvider
Expand All @@ -17,7 +18,8 @@ def ChatHuggingFace(
*,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
api_headers: Optional[ApiHeaders] = None,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
"""
Expand Down Expand Up @@ -63,6 +65,11 @@ def ChatHuggingFace(
The API key to use for authentication. You generally should not supply
this directly, but instead set the `HUGGINGFACE_API_KEY` environment
variable.
api_headers
Extra HTTP headers to include with every chat API request. Can be a dict
of ``{header_name: header_value}`` pairs, or a zero-argument callable
returning such a dict. A callable is invoked on every request,
enabling dynamic auth patterns like token refresh.
kwargs
Additional arguments to pass to the underlying OpenAI client
constructor.
Expand Down Expand Up @@ -131,6 +138,7 @@ def ChatHuggingFace(
provider=HuggingFaceProvider(
api_key=api_key,
model=model,
api_headers=api_headers,
kwargs=kwargs,
),
system_prompt=system_prompt,
Expand All @@ -141,8 +149,9 @@ class HuggingFaceProvider(OpenAICompletionsProvider):
def __init__(
self,
*,
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
model: str,
api_headers: Optional[ApiHeaders] = None,
kwargs: Optional["ChatClientArgs"] = None,
):
# https://huggingface.co/docs/inference-providers/en/index?python-clients=requests#http--curl
Expand All @@ -151,5 +160,6 @@ def __init__(
model=model,
api_key=api_key,
base_url="https://router.huggingface.co/v1",
api_headers=api_headers,
kwargs=kwargs,
)
21 changes: 16 additions & 5 deletions chatlas/_provider_lmstudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import os
import re
import urllib.request
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

import orjson

from ._api_headers import ApiHeaders
from ._chat import Chat
from ._provider import ModelInfo
from ._provider_openai_completions import OpenAICompletionsProvider
Expand All @@ -22,7 +23,8 @@ def ChatLMStudio(
*,
system_prompt: Optional[str] = None,
base_url: str = "http://localhost:1234",
api_key: Optional[str] = None,
api_key: Optional[str | Callable[[], str]] = None,
api_headers: Optional[ApiHeaders] = None,
seed: int | None | MISSING_TYPE = MISSING,
kwargs: Optional["ChatClientArgs"] = None,
) -> "Chat[SubmitInputArgs, ChatCompletion]":
Expand Down Expand Up @@ -79,6 +81,11 @@ def ChatLMStudio(
usage. If you're accessing an LM Studio instance behind a reverse proxy
or secured endpoint that enforces bearer-token authentication, you can
set the `LMSTUDIO_API_KEY` environment variable or provide a value here.
api_headers
Extra HTTP headers to include with every chat API request. Can be a dict
of ``{header_name: header_value}`` pairs, or a zero-argument callable
returning such a dict. A callable is invoked on every request,
enabling dynamic auth patterns like token refresh.
seed
Optional integer seed that helps to make output more reproducible.
kwargs
Expand All @@ -95,10 +102,12 @@ def ChatLMStudio(
if api_key is None:
api_key = os.getenv("LMSTUDIO_API_KEY", "")

if not has_lmstudio(base_url, api_key=api_key):
resolved_key = api_key() if callable(api_key) else api_key

if not has_lmstudio(base_url, api_key=resolved_key):
raise RuntimeError("Can't find locally running LM Studio.")

models = lmstudio_model_info(base_url, api_key=api_key)
models = lmstudio_model_info(base_url, api_key=resolved_key)
model_ids = [m["id"] for m in models]

if model is None:
Expand All @@ -123,20 +132,22 @@ def ChatLMStudio(
base_url=base_url,
seed=seed,
name="LM Studio",
api_headers=api_headers,
kwargs=kwargs,
),
system_prompt=system_prompt,
)


class LMStudioProvider(OpenAICompletionsProvider):
def __init__(self, *, api_key, model, base_url, seed, name, kwargs):
def __init__(self, *, api_key: str | Callable[[], str], model, base_url, seed, name, api_headers=None, kwargs):
super().__init__(
api_key=api_key,
model=model,
base_url=f"{base_url}/v1",
seed=seed,
name=name,
api_headers=api_headers,
kwargs=kwargs,
)
self.base_url = base_url
Expand Down
Loading
Loading