Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.prism.log
.vscode
_dev

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ It is generated with [Stainless](https://www.stainlessapi.com/).

## Documentation

The REST API documentation can be found [on docs.together.ai](https://docs.together.ai/). The full API of this library can be found in [api.md](api.md).
The REST API documentation can be found [on docs.together.ai](https://docs.together.ai). The full API of this library can be found in [api.md](api.md).

## Installation

Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dev-dependencies = [
"nox",
"dirty-equals>=0.6.0",
"importlib-metadata>=6.7.0",
"rich>=13.7.1",

]

Expand Down Expand Up @@ -99,6 +100,21 @@ include = [
[tool.hatch.build.targets.wheel]
packages = ["src/together"]

[tool.hatch.build.targets.sdist]
# Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc)
include = [
"/*.toml",
"/*.json",
"/*.lock",
"/*.md",
"/mypy.ini",
"/noxfile.py",
"bin/*",
"examples/*",
"src/*",
"tests/*",
]

[tool.hatch.metadata.hooks.fancy-pypi-readme]
content-type = "text/markdown"

Expand Down
10 changes: 9 additions & 1 deletion requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
-e file:.
annotated-types==0.6.0
# via pydantic
anyio==4.1.0
anyio==4.4.0
# via httpx
# via together
argcomplete==3.1.2
Expand Down Expand Up @@ -44,6 +44,10 @@ idna==3.4
importlib-metadata==7.0.0
iniconfig==2.0.0
# via pytest
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
mypy==1.7.1
mypy-extensions==1.0.0
# via mypy
Expand All @@ -63,6 +67,8 @@ pydantic==2.7.1
# via together
pydantic-core==2.18.2
# via pydantic
pygments==2.18.0
# via rich
pyright==1.1.364
pytest==7.1.1
# via pytest-asyncio
Expand All @@ -72,6 +78,7 @@ python-dateutil==2.8.2
pytz==2023.3.post1
# via dirty-equals
respx==0.20.2
rich==13.7.1
ruff==0.1.9
setuptools==68.2.2
# via nodeenv
Expand All @@ -86,6 +93,7 @@ tomli==2.0.1
# via mypy
# via pytest
typing-extensions==4.8.0
# via anyio
# via mypy
# via pydantic
# via pydantic-core
Expand Down
3 changes: 2 additions & 1 deletion requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
-e file:.
annotated-types==0.6.0
# via pydantic
anyio==4.1.0
anyio==4.4.0
# via httpx
# via together
certifi==2023.7.22
Expand Down Expand Up @@ -38,6 +38,7 @@ sniffio==1.3.0
# via httpx
# via together
typing-extensions==4.8.0
# via anyio
# via pydantic
# via pydantic-core
# via together
4 changes: 2 additions & 2 deletions src/together/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def qs(self) -> Querystring:
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": api_key}
return {"Authorization": f"Bearer {api_key}"}

@property
@override
Expand Down Expand Up @@ -313,7 +313,7 @@ def qs(self) -> Querystring:
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": api_key}
return {"Authorization": f"Bearer {api_key}"}

@property
@override
Expand Down
27 changes: 27 additions & 0 deletions src/together/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ClassVar,
Protocol,
Required,
ParamSpec,
TypedDict,
TypeGuard,
final,
Expand Down Expand Up @@ -67,6 +68,9 @@
__all__ = ["BaseModel", "GenericModel"]

_T = TypeVar("_T")
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")

P = ParamSpec("P")


@runtime_checkable
Expand Down Expand Up @@ -379,6 +383,29 @@ def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericMo
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
This is useful for cases where you need to instantiate a `BaseModel`
from an API response as this provides type-safe params which isn't supported
by helpers like `construct_type()`.
```py
build(MyModel, my_field_a="foo", my_field_b=123)
```
"""
if args:
raise TypeError(
"Received positional arguments which are not supported; Keyword arguments must be used instead",
)

return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))


def construct_type(*, value: object, type_: object) -> object:
"""Loose coercion to the expected type with construction of nested values.
Expand Down
5 changes: 4 additions & 1 deletion src/together/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,7 @@
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
from ._reflection import function_has_argument as function_has_argument
from ._reflection import (
function_has_argument as function_has_argument,
assert_signatures_in_sync as assert_signatures_in_sync,
)
34 changes: 34 additions & 0 deletions src/together/_utils/_reflection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
from typing import Any, Callable

Expand All @@ -6,3 +8,35 @@ def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
"""Returns whether or not the given function has a specific parameter"""
sig = inspect.signature(func)
return arg_name in sig.parameters


def assert_signatures_in_sync(
source_func: Callable[..., Any],
check_func: Callable[..., Any],
*,
exclude_params: set[str] = set(),
) -> None:
"""Ensure that the signature of the second function matches the first."""

check_sig = inspect.signature(check_func)
source_sig = inspect.signature(source_func)

errors: list[str] = []

for name, source_param in source_sig.parameters.items():
if name in exclude_params:
continue

custom_param = check_sig.parameters.get(name)
if not custom_param:
errors.append(f"the `{name}` param is missing")
continue

if custom_param.annotation != source_param.annotation:
errors.append(
f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(source_param.annotation)}"
)
continue

if errors:
raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_default_headers_option(self) -> None:
def test_validate_headers(self) -> None:
client = Together(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("Authorization") == api_key
assert request.headers.get("Authorization") == f"Bearer {api_key}"

with pytest.raises(TogetherError):
client2 = Together(base_url=base_url, api_key=None, _strict_response_validation=True)
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def test_default_headers_option(self) -> None:
def test_validate_headers(self) -> None:
client = AsyncTogether(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("Authorization") == api_key
assert request.headers.get("Authorization") == f"Bearer {api_key}"

with pytest.raises(TogetherError):
client2 = AsyncTogether(base_url=base_url, api_key=None, _strict_response_validation=True)
Expand Down