diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 4fd5e4d1..66e067e1 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -7,6 +7,10 @@ on:
       - 'integrated/**'
       - 'stl-preview-head/**'
       - 'stl-preview-base/**'
+  pull_request:
+    branches-ignore:
+      - 'stl-preview-head/**'
+      - 'stl-preview-base/**'
 
 jobs:
   lint:
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index fd0ccba9..000572ec 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
 {
-  ".": "0.1.0-alpha.12"
+  ".": "0.1.0-alpha.13"
 }
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index 144618cc..fcc638d2 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
-configured_endpoints: 30
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-653a4aa26fdd2b335d1ead9c2ea0672cbe48a7616b76bf350a2421a8def4e08d.yml
-openapi_spec_hash: 1d5af8ab9d8c11d7f5225e19ebd1654a
-config_hash: d15dd709dd3f87b0a8b83b00b4abc881
+configured_endpoints: 33
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-9ca24c17ccf9b0b4c2d27c09881dc74bf4cb44efc7a5ccb7d54fa15caee095d1.yml
+openapi_spec_hash: 306c08678a0677f1deb1d35def6f8713
+config_hash: fa59b4c1ab0a2e74aa855e7227b10c7d
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7dc4298a..78685d07 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,36 @@
 # Changelog
 
+## 0.1.0-alpha.13 (2025-06-20)
+
+Full Changelog: [v0.1.0-alpha.12...v0.1.0-alpha.13](https://github.com/togethercomputer/together-py/compare/v0.1.0-alpha.12...v0.1.0-alpha.13)
+
+### Features
+
+* **api:** add batch api to config ([07299cc](https://github.com/togethercomputer/together-py/commit/07299cc337cb356076643df7fc070b2fd8e85c54))
+* **api:** api update ([249669c](https://github.com/togethercomputer/together-py/commit/249669c03db384d71c04fe69f78a579b5235c54c))
+* **client:** add support for aiohttp ([8e4cedf](https://github.com/togethercomputer/together-py/commit/8e4cedf646520031811a97f65460f41b61894dd9))
+
+
+### Bug Fixes
+
+* **client:** correctly parse binary response | stream ([7b9486c](https://github.com/togethercomputer/together-py/commit/7b9486c29ef0eeb862460d1ee82417db9a8f801f))
+* **tests:** fix: tests which call HTTP endpoints directly with the example parameters ([82b2dcb](https://github.com/togethercomputer/together-py/commit/82b2dcb43af96a7339b2305d02486d3084850303))
+
+
+### Chores
+
+* change publish docs url ([8fac9f3](https://github.com/togethercomputer/together-py/commit/8fac9f3e12630ed88b68c6cb7d798ebcc6a88833))
+* **ci:** enable for pull requests ([6e4d972](https://github.com/togethercomputer/together-py/commit/6e4d972a3a3094fb2d8d468d1e3e89b173ce6ffd))
+* **internal:** update conftest.py ([2b13ac4](https://github.com/togethercomputer/together-py/commit/2b13ac4298cc44c0515a3aa348cfdb4bc63d9cb2))
+* **readme:** update badges ([acfabb5](https://github.com/togethercomputer/together-py/commit/acfabb57a60aab2853283f62d72897a8bb95a778))
+* **tests:** add tests for httpx client instantiation & proxies ([30ba23e](https://github.com/togethercomputer/together-py/commit/30ba23e549ed87a82a7e49164b1809388486754b))
+* **tests:** run tests in parallel ([7efb923](https://github.com/togethercomputer/together-py/commit/7efb923a6802382cdfe676c1124e6b9dafd8e233))
+
+
+### Documentation
+
+* **client:** fix httpx.Timeout documentation reference ([bed4e88](https://github.com/togethercomputer/together-py/commit/bed4e88653ff35029c1921bd2d940abade5b00c0))
+
 ## 0.1.0-alpha.12 (2025-06-10)
 
 Full Changelog: [v0.1.0-alpha.11...v0.1.0-alpha.12](https://github.com/togethercomputer/together-py/compare/v0.1.0-alpha.11...v0.1.0-alpha.12)
diff --git a/README.md b/README.md
index 35403102..3e330c27 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Together Python API library
 
-[![PyPI version](https://img.shields.io/pypi/v/together.svg)](https://pypi.org/project/together/)
+[![PyPI version]()](https://pypi.org/project/together/)
 
 The Together Python library provides convenient access to the Together REST API from any Python 3.8+ application.
 The library includes type definitions for all request params and response fields,
@@ -20,7 +20,7 @@ pip install git+ssh://git@github.com/togethercomputer/together-py.git
 ```
 
 > [!NOTE]
-> Once this package is [published to PyPI](https://app.stainless.com/docs/guides/publish), this will become: `pip install --pre together`
+> Once this package is [published to PyPI](https://www.stainless.com/docs/guides/publish), this will become: `pip install --pre together`
 
 ## Usage
 
@@ -83,6 +83,46 @@ asyncio.run(main())
 
 Functionality between the synchronous and asynchronous clients is otherwise identical.
 
+### With aiohttp
+
+By default, the async client uses `httpx` for HTTP requests. However, for improved concurrency performance you may also use `aiohttp` as the HTTP backend.
+
+You can enable this by installing `aiohttp`:
+
+```sh
+# install from the production repo
+pip install 'together[aiohttp] @ git+ssh://git@github.com/togethercomputer/together-py.git'
+```
+
+Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`:
+
+```python
+import os
+import asyncio
+from together import DefaultAioHttpClient
+from together import AsyncTogether
+
+
+async def main() -> None:
+    async with AsyncTogether(
+        api_key=os.environ.get("TOGETHER_API_KEY"),  # This is the default and can be omitted
+        http_client=DefaultAioHttpClient(),
+    ) as client:
+        chat_completion = await client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Say this is a test!",
+                }
+            ],
+            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        )
+        print(chat_completion.choices)
+
+
+asyncio.run(main())
+```
+
 ## Streaming responses
 
 We provide support for streaming responses using Server Side Events (SSE).
@@ -258,7 +298,7 @@ client.with_options(max_retries=5).chat.completions.create(
 ### Timeouts
 
 By default requests time out after 1 minute. You can configure this with a `timeout` option,
-which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object:
+which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/timeouts/#fine-tuning-the-configuration) object:
 
 ```python
 from together import Together
diff --git a/api.md b/api.md
index 6a918eec..729fb65f 100644
--- a/api.md
+++ b/api.md
@@ -220,3 +220,17 @@ from together.types import HardwareListResponse
 Methods:
 
 - client.hardware.list(\*\*params) -> HardwareListResponse
+
+# Batches
+
+Types:
+
+```python
+from together.types import BatchCreateResponse, BatchRetrieveResponse, BatchListResponse
+```
+
+Methods:
+
+- client.batches.create(\*\*params) -> BatchCreateResponse
+- client.batches.retrieve(id) -> BatchRetrieveResponse
+- client.batches.list() -> BatchListResponse
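The `api.md` entries above map one-to-one onto the generated resource in `src/together/resources/batches.py` further down. As a quick orientation, here is a minimal sketch, not part of the diff, assuming a JSONL input file has already been uploaded (the file ID below is a placeholder):

```python
from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

# Submit a batch job against an already-uploaded input file.
created = client.batches.create(
    endpoint="/v1/chat/completions",
    input_file_id="file-abc123def456ghi789",
)
assert created.job is not None and created.job.id is not None

# Check on the job; `status` moves through VALIDATING / IN_PROGRESS / COMPLETED etc.
job = client.batches.retrieve(created.job.id)
print(job.status, job.progress)

# Enumerate every batch job owned by the authenticated user.
for item in client.batches.list():
    print(item.id, item.status)
```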
diff --git a/pyproject.toml b/pyproject.toml
index 60483ff9..a4dd32f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "together"
-version = "0.1.0-alpha.12"
+version = "0.1.0-alpha.13"
 description = "The official Python library for the together API"
 dynamic = ["readme"]
 license = "Apache-2.0"
@@ -45,6 +45,9 @@ classifiers = [
 Homepage = "https://github.com/togethercomputer/together-py"
 Repository = "https://github.com/togethercomputer/together-py"
 
+[project.optional-dependencies]
+aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.6"]
+
 [project.scripts]
 together = "together.lib.cli.cli:main"
 
@@ -64,6 +67,7 @@ dev-dependencies = [
     "importlib-metadata>=6.7.0",
     "rich>=13.7.1",
     "nest_asyncio==1.6.0",
+    "pytest-xdist>=3.6.1",
    "pytest-mock>=3.14.0",
 ]
 
@@ -136,7 +140,7 @@ replacement = '[\1](https://github.com/togethercomputer/together-py/tree/main/\g<2>)'
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
-addopts = "--tb=short"
+addopts = "--tb=short -n auto"
 xfail_strict = true
 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "session"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 81b0961a..6d7ab8dc 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -10,6 +10,13 @@
 #   universal: false
 
 -e file:.
+aiohappyeyeballs==2.6.1
+  # via aiohttp
+aiohttp==3.12.13
+  # via httpx-aiohttp
+  # via together
+aiosignal==1.3.2
+  # via aiohttp
 annotated-types==0.6.0
   # via pydantic
 anyio==4.4.0
@@ -17,6 +24,10 @@ anyio==4.4.0
   # via together
 argcomplete==3.1.2
   # via nox
+async-timeout==5.0.1
+  # via aiohttp
+attrs==25.3.0
+  # via aiohttp
 certifi==2023.7.22
   # via httpcore
   # via httpx
@@ -32,18 +43,27 @@ distro==1.8.0
 exceptiongroup==1.2.2
   # via anyio
   # via pytest
+execnet==2.1.1
+  # via pytest-xdist
 filelock==3.12.4
   # via virtualenv
+frozenlist==1.7.0
+  # via aiohttp
+  # via aiosignal
 h11==0.14.0
   # via httpcore
 httpcore==1.0.2
   # via httpx
 httpx==0.28.1
+  # via httpx-aiohttp
   # via respx
   # via together
+httpx-aiohttp==0.1.6
+  # via together
 idna==3.4
   # via anyio
   # via httpx
+  # via yarl
 importlib-metadata==7.0.0
 iniconfig==2.0.0
   # via pytest
@@ -51,6 +71,9 @@ markdown-it-py==3.0.0
   # via rich
 mdurl==0.1.2
   # via markdown-it-py
+multidict==6.5.0
+  # via aiohttp
+  # via yarl
 mypy==1.14.1
 mypy-extensions==1.0.0
   # via mypy
@@ -69,6 +92,9 @@ platformdirs==3.11.0
   # via virtualenv
 pluggy==1.5.0
   # via pytest
+propcache==0.3.2
+  # via aiohttp
+  # via yarl
 pyarrow==16.1.0
   # via together
 pyarrow-stubs==10.0.1.7
@@ -83,8 +109,10 @@ pyright==1.1.399
 pytest==8.3.3
   # via pytest-asyncio
   # via pytest-mock
+  # via pytest-xdist
 pytest-asyncio==0.24.0
 pytest-mock==3.14.0
+pytest-xdist==3.7.0
 python-dateutil==2.8.2
   # via time-machine
 pytz==2023.3.post1
@@ -115,6 +143,7 @@ types-tqdm==4.67.0.20250516
   # via together
 typing-extensions==4.12.2
   # via anyio
+  # via multidict
   # via mypy
   # via pydantic
   # via pydantic-core
@@ -124,5 +153,7 @@ urllib3==2.4.0
   # via types-requests
 virtualenv==20.24.5
   # via nox
+yarl==1.20.1
+  # via aiohttp
 zipp==3.17.0
   # via importlib-metadata
diff --git a/requirements.lock b/requirements.lock
index d7f9180b..ab4313ce 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -10,11 +10,22 @@
 #   universal: false
 
 -e file:.
+aiohappyeyeballs==2.6.1
+  # via aiohttp
+aiohttp==3.12.13
+  # via httpx-aiohttp
+  # via together
+aiosignal==1.3.2
+  # via aiohttp
 annotated-types==0.6.0
   # via pydantic
 anyio==4.4.0
   # via httpx
   # via together
+async-timeout==5.0.1
+  # via aiohttp
+attrs==25.3.0
+  # via aiohttp
 certifi==2023.7.22
   # via httpcore
   # via httpx
@@ -24,19 +35,32 @@ distro==1.8.0
   # via together
 exceptiongroup==1.2.2
   # via anyio
+frozenlist==1.7.0
+  # via aiohttp
+  # via aiosignal
 h11==0.14.0
   # via httpcore
 httpcore==1.0.2
   # via httpx
 httpx==0.28.1
+  # via httpx-aiohttp
+  # via together
+httpx-aiohttp==0.1.6
   # via together
 idna==3.4
   # via anyio
   # via httpx
+  # via yarl
+multidict==6.5.0
+  # via aiohttp
+  # via yarl
 numpy==2.0.0
   # via pyarrow
 pillow==10.4.0
   # via together
+propcache==0.3.2
+  # via aiohttp
+  # via yarl
 pyarrow==16.1.0
   # via together
 pyarrow-stubs==10.0.1.7
@@ -60,8 +84,11 @@ types-tqdm==4.67.0.20250516
   # via together
 typing-extensions==4.12.2
   # via anyio
+  # via multidict
   # via pydantic
   # via pydantic-core
   # via together
 urllib3==2.4.0
   # via types-requests
+yarl==1.20.1
+  # via aiohttp
diff --git a/src/together/__init__.py b/src/together/__init__.py
index 50964a8c..2fd17bda 100644
--- a/src/together/__init__.py
+++ b/src/together/__init__.py
@@ -42,7 +42,7 @@
     UnprocessableEntityError,
     APIResponseValidationError,
 )
-from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
+from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
 from ._utils._logs import setup_logging as _setup_logging
 
 __all__ = [
@@ -84,6 +84,7 @@
     "DEFAULT_CONNECTION_LIMITS",
     "DefaultHttpxClient",
     "DefaultAsyncHttpxClient",
+    "DefaultAioHttpClient",
     "create_finetune_request",
     "FinetuneTrainingLimits",
     "DownloadError",
diff --git a/src/together/_base_client.py b/src/together/_base_client.py
index ad80ed98..b73339b5 100644
--- a/src/together/_base_client.py
+++ b/src/together/_base_client.py
@@ -1071,7 +1071,14 @@ def _process_response(
     ) -> ResponseT:
         origin = get_origin(cast_to) or cast_to
 
-        if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
+        if (
+            inspect.isclass(origin)
+            and issubclass(origin, BaseAPIResponse)
+            # we only want to actually return the custom BaseAPIResponse class if we're
+            # returning the raw response, or if we're not streaming SSE, as if we're streaming
+            # SSE then `cast_to` doesn't actively reflect the type we need to parse into
+            and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER)))
+        ):
             if not issubclass(origin, APIResponse):
                 raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}")
 
@@ -1282,6 +1289,24 @@ def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
 
 
+try:
+    import httpx_aiohttp
+except ImportError:
+
+    class _DefaultAioHttpClient(httpx.AsyncClient):
+        def __init__(self, **_kwargs: Any) -> None:
+            raise RuntimeError("To use the aiohttp client you must have installed the package with the `aiohttp` extra")
+else:
+
+    class _DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient):  # type: ignore
+        def __init__(self, **kwargs: Any) -> None:
+            kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+            kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+            kwargs.setdefault("follow_redirects", True)
+
+            super().__init__(**kwargs)
+
+
 if TYPE_CHECKING:
     DefaultAsyncHttpxClient = httpx.AsyncClient
     """An alias to `httpx.AsyncClient` that provides the same defaults that this SDK
     uses internally.
 
     This is useful because overriding the `http_client` with your own instance of
     `httpx.AsyncClient` will result in httpx's defaults being used, not ours.
     """
+
+    DefaultAioHttpClient = httpx.AsyncClient
+    """An alias to `httpx.AsyncClient` that changes the default HTTP transport to `aiohttp`."""
 else:
     DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient
+    DefaultAioHttpClient = _DefaultAioHttpClient
 
 
 class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient):
@@ -1574,7 +1603,14 @@ async def _process_response(
     ) -> ResponseT:
         origin = get_origin(cast_to) or cast_to
 
-        if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
+        if (
+            inspect.isclass(origin)
+            and issubclass(origin, BaseAPIResponse)
+            # we only want to actually return the custom BaseAPIResponse class if we're
+            # returning the raw response, or if we're not streaming SSE, as if we're streaming
+            # SSE then `cast_to` doesn't actively reflect the type we need to parse into
+            and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER)))
+        ):
             if not issubclass(origin, AsyncAPIResponse):
                 raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}")
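A consequence of the `try`/`except ImportError` fallback above: `DefaultAioHttpClient` is always importable, and a missing `aiohttp` extra only surfaces as a `RuntimeError` when the client is constructed. A hedged sketch of guarding for that at startup (falling back to httpx here is an illustrative choice, not SDK behavior):

```python
from typing import Optional

import httpx
from together import AsyncTogether, DefaultAioHttpClient

http_client: Optional[httpx.AsyncClient]
try:
    # Raises RuntimeError, not ImportError, when installed without the
    # `aiohttp` extra, because the stub class defers failure to __init__.
    http_client = DefaultAioHttpClient()
except RuntimeError:
    http_client = None  # fall back to the SDK's default httpx transport

client = AsyncTogether(http_client=http_client)
```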
diff --git a/src/together/_client.py b/src/together/_client.py
index f2e64aed..0690aa27 100644
--- a/src/together/_client.py
+++ b/src/together/_client.py
@@ -36,7 +36,19 @@
     async_to_raw_response_wrapper,
     async_to_streamed_response_wrapper,
 )
-from .resources import jobs, audio, files, images, models, hardware, endpoints, fine_tune, embeddings, completions
+from .resources import (
+    jobs,
+    audio,
+    files,
+    images,
+    models,
+    batches,
+    hardware,
+    endpoints,
+    fine_tune,
+    embeddings,
+    completions,
+)
 from ._streaming import Stream as Stream, AsyncStream as AsyncStream
 from ._exceptions import TogetherError, APIStatusError
 from ._base_client import (
@@ -74,6 +86,7 @@ class Together(SyncAPIClient):
     jobs: jobs.JobsResource
     endpoints: endpoints.EndpointsResource
     hardware: hardware.HardwareResource
+    batches: batches.BatchesResource
     with_raw_response: TogetherWithRawResponse
     with_streaming_response: TogetherWithStreamedResponse
 
@@ -145,6 +158,7 @@ def __init__(
         self.jobs = jobs.JobsResource(self)
         self.endpoints = endpoints.EndpointsResource(self)
         self.hardware = hardware.HardwareResource(self)
+        self.batches = batches.BatchesResource(self)
         self.with_raw_response = TogetherWithRawResponse(self)
         self.with_streaming_response = TogetherWithStreamedResponse(self)
 
@@ -328,6 +342,7 @@ class AsyncTogether(AsyncAPIClient):
     jobs: jobs.AsyncJobsResource
     endpoints: endpoints.AsyncEndpointsResource
     hardware: hardware.AsyncHardwareResource
+    batches: batches.AsyncBatchesResource
     with_raw_response: AsyncTogetherWithRawResponse
     with_streaming_response: AsyncTogetherWithStreamedResponse
 
@@ -399,6 +414,7 @@ def __init__(
         self.jobs = jobs.AsyncJobsResource(self)
         self.endpoints = endpoints.AsyncEndpointsResource(self)
         self.hardware = hardware.AsyncHardwareResource(self)
+        self.batches = batches.AsyncBatchesResource(self)
         self.with_raw_response = AsyncTogetherWithRawResponse(self)
         self.with_streaming_response = AsyncTogetherWithStreamedResponse(self)
 
@@ -583,6 +599,7 @@ def __init__(self, client: Together) -> None:
         self.jobs = jobs.JobsResourceWithRawResponse(client.jobs)
         self.endpoints = endpoints.EndpointsResourceWithRawResponse(client.endpoints)
         self.hardware = hardware.HardwareResourceWithRawResponse(client.hardware)
+        self.batches = batches.BatchesResourceWithRawResponse(client.batches)
 
         self.rerank = to_raw_response_wrapper(
             client.rerank,
@@ -603,6 +620,7 @@ def __init__(self, client: AsyncTogether) -> None:
         self.jobs = jobs.AsyncJobsResourceWithRawResponse(client.jobs)
         self.endpoints = endpoints.AsyncEndpointsResourceWithRawResponse(client.endpoints)
         self.hardware = hardware.AsyncHardwareResourceWithRawResponse(client.hardware)
+        self.batches = batches.AsyncBatchesResourceWithRawResponse(client.batches)
 
         self.rerank = async_to_raw_response_wrapper(
             client.rerank,
@@ -623,6 +641,7 @@ def __init__(self, client: Together) -> None:
         self.jobs = jobs.JobsResourceWithStreamingResponse(client.jobs)
         self.endpoints = endpoints.EndpointsResourceWithStreamingResponse(client.endpoints)
         self.hardware = hardware.HardwareResourceWithStreamingResponse(client.hardware)
+        self.batches = batches.BatchesResourceWithStreamingResponse(client.batches)
 
         self.rerank = to_streamed_response_wrapper(
             client.rerank,
@@ -645,6 +664,7 @@ def __init__(self, client: AsyncTogether) -> None:
         self.jobs = jobs.AsyncJobsResourceWithStreamingResponse(client.jobs)
         self.endpoints = endpoints.AsyncEndpointsResourceWithStreamingResponse(client.endpoints)
         self.hardware = hardware.AsyncHardwareResourceWithStreamingResponse(client.hardware)
+        self.batches = batches.AsyncBatchesResourceWithStreamingResponse(client.batches)
 
         self.rerank = async_to_streamed_response_wrapper(
             client.rerank,
diff --git a/src/together/_version.py b/src/together/_version.py
index 08f0eaca..444bb44a 100644
--- a/src/together/_version.py
+++ b/src/together/_version.py
@@ -1,4 +1,4 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
 __title__ = "together"
-__version__ = "0.1.0-alpha.12"  # x-release-please-version
+__version__ = "0.1.0-alpha.13"  # x-release-please-version
diff --git a/src/together/resources/__init__.py b/src/together/resources/__init__.py
index bd3e4c51..c94aa657 100644
--- a/src/together/resources/__init__.py
+++ b/src/together/resources/__init__.py
@@ -48,6 +48,14 @@
     ModelsResourceWithStreamingResponse,
     AsyncModelsResourceWithStreamingResponse,
 )
+from .batches import (
+    BatchesResource,
+    AsyncBatchesResource,
+    BatchesResourceWithRawResponse,
+    AsyncBatchesResourceWithRawResponse,
+    BatchesResourceWithStreamingResponse,
+    AsyncBatchesResourceWithStreamingResponse,
+)
 from .hardware import (
     HardwareResource,
     AsyncHardwareResource,
@@ -170,4 +178,10 @@
     "AsyncHardwareResourceWithRawResponse",
     "HardwareResourceWithStreamingResponse",
     "AsyncHardwareResourceWithStreamingResponse",
+    "BatchesResource",
+    "AsyncBatchesResource",
+    "BatchesResourceWithRawResponse",
+    "AsyncBatchesResourceWithRawResponse",
+    "BatchesResourceWithStreamingResponse",
+    "AsyncBatchesResourceWithStreamingResponse",
 ]
diff --git a/src/together/resources/batches.py b/src/together/resources/batches.py
new file mode 100644
index 00000000..48fafe00
--- /dev/null
+++ b/src/together/resources/batches.py
@@ -0,0 +1,339 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import httpx
+
+from ..types import batch_create_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import maybe_transform, async_maybe_transform
+from .._compat import cached_property
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import (
+    to_raw_response_wrapper,
+    to_streamed_response_wrapper,
+    async_to_raw_response_wrapper,
+    async_to_streamed_response_wrapper,
+)
+from .._base_client import make_request_options
+from ..types.batch_list_response import BatchListResponse
+from ..types.batch_create_response import BatchCreateResponse
+from ..types.batch_retrieve_response import BatchRetrieveResponse
+
+__all__ = ["BatchesResource", "AsyncBatchesResource"]
+
+
+class BatchesResource(SyncAPIResource):
+    @cached_property
+    def with_raw_response(self) -> BatchesResourceWithRawResponse:
+        """
+        This property can be used as a prefix for any HTTP method call to return
+        the raw response object instead of the parsed content.
+
+        For more information, see https://www.github.com/togethercomputer/together-py#accessing-raw-response-data-eg-headers
+        """
+        return BatchesResourceWithRawResponse(self)
+
+    @cached_property
+    def with_streaming_response(self) -> BatchesResourceWithStreamingResponse:
+        """
+        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+        For more information, see https://www.github.com/togethercomputer/together-py#with_streaming_response
+        """
+        return BatchesResourceWithStreamingResponse(self)
+
+    def create(
+        self,
+        *,
+        endpoint: str,
+        input_file_id: str,
+        completion_window: str | NotGiven = NOT_GIVEN,
+        model_id: str | NotGiven = NOT_GIVEN,
+        priority: int | NotGiven = NOT_GIVEN,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BatchCreateResponse:
+        """
+        Create a new batch job with the given input file and endpoint
+
+        Args:
+          endpoint: The endpoint to use for batch processing
+
+          input_file_id: ID of the uploaded input file containing batch requests
+
+          completion_window: Time window for batch completion (optional)
+
+          model_id: Model to use for processing batch requests
+
+          priority: Priority for batch processing (optional)
+
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        return self._post(
+            "/batches",
+            body=maybe_transform(
+                {
+                    "endpoint": endpoint,
+                    "input_file_id": input_file_id,
+                    "completion_window": completion_window,
+                    "model_id": model_id,
+                    "priority": priority,
+                },
+                batch_create_params.BatchCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BatchCreateResponse,
+        )
+
+    def retrieve(
+        self,
+        id: str,
+        *,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BatchRetrieveResponse:
+        """
+        Get details of a batch job by ID
+
+        Args:
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        if not id:
+            raise ValueError(f"Expected a non-empty value for `id` but received {id!r}")
+        return self._get(
+            f"/batches/{id}",
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BatchRetrieveResponse,
+        )
+
+    def list(
+        self,
+        *,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BatchListResponse:
+        """List all batch jobs for the authenticated user"""
+        return self._get(
+            "/batches",
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BatchListResponse,
+        )
+
+
+class AsyncBatchesResource(AsyncAPIResource):
+    @cached_property
+    def with_raw_response(self) -> AsyncBatchesResourceWithRawResponse:
+        """
+        This property can be used as a prefix for any HTTP method call to return
+        the raw response object instead of the parsed content.
+
+        For more information, see https://www.github.com/togethercomputer/together-py#accessing-raw-response-data-eg-headers
+        """
+        return AsyncBatchesResourceWithRawResponse(self)
+
+    @cached_property
+    def with_streaming_response(self) -> AsyncBatchesResourceWithStreamingResponse:
+        """
+        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+        For more information, see https://www.github.com/togethercomputer/together-py#with_streaming_response
+        """
+        return AsyncBatchesResourceWithStreamingResponse(self)
+
+    async def create(
+        self,
+        *,
+        endpoint: str,
+        input_file_id: str,
+        completion_window: str | NotGiven = NOT_GIVEN,
+        model_id: str | NotGiven = NOT_GIVEN,
+        priority: int | NotGiven = NOT_GIVEN,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BatchCreateResponse:
+        """
+        Create a new batch job with the given input file and endpoint
+
+        Args:
+          endpoint: The endpoint to use for batch processing
+
+          input_file_id: ID of the uploaded input file containing batch requests
+
+          completion_window: Time window for batch completion (optional)
+
+          model_id: Model to use for processing batch requests
+
+          priority: Priority for batch processing (optional)
+
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        return await self._post(
+            "/batches",
+            body=await async_maybe_transform(
+                {
+                    "endpoint": endpoint,
+                    "input_file_id": input_file_id,
+                    "completion_window": completion_window,
+                    "model_id": model_id,
+                    "priority": priority,
+                },
+                batch_create_params.BatchCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BatchCreateResponse,
+        )
+
+    async def retrieve(
+        self,
+        id: str,
+        *,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BatchRetrieveResponse:
+        """
+        Get details of a batch job by ID
+
+        Args:
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        if not id:
+            raise ValueError(f"Expected a non-empty value for `id` but received {id!r}")
+        return await self._get(
+            f"/batches/{id}",
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BatchRetrieveResponse,
+        )
+
+    async def list(
+        self,
+        *,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BatchListResponse:
+        """List all batch jobs for the authenticated user"""
+        return await self._get(
+            "/batches",
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BatchListResponse,
+        )
+
+
+class BatchesResourceWithRawResponse:
+    def __init__(self, batches: BatchesResource) -> None:
+        self._batches = batches
+
+        self.create = to_raw_response_wrapper(
+            batches.create,
+        )
+        self.retrieve = to_raw_response_wrapper(
+            batches.retrieve,
+        )
+        self.list = to_raw_response_wrapper(
+            batches.list,
+        )
+
+
+class AsyncBatchesResourceWithRawResponse:
+    def __init__(self, batches: AsyncBatchesResource) -> None:
+        self._batches = batches
+
+        self.create = async_to_raw_response_wrapper(
+            batches.create,
+        )
+        self.retrieve = async_to_raw_response_wrapper(
+            batches.retrieve,
+        )
+        self.list = async_to_raw_response_wrapper(
+            batches.list,
+        )
+
+
+class BatchesResourceWithStreamingResponse:
+    def __init__(self, batches: BatchesResource) -> None:
+        self._batches = batches
+
+        self.create = to_streamed_response_wrapper(
+            batches.create,
+        )
+        self.retrieve = to_streamed_response_wrapper(
+            batches.retrieve,
+        )
+        self.list = to_streamed_response_wrapper(
+            batches.list,
+        )
+
+
+class AsyncBatchesResourceWithStreamingResponse:
+    def __init__(self, batches: AsyncBatchesResource) -> None:
+        self._batches = batches
+
+        self.create = async_to_streamed_response_wrapper(
+            batches.create,
+        )
+        self.retrieve = async_to_streamed_response_wrapper(
+            batches.retrieve,
+        )
+        self.list = async_to_streamed_response_wrapper(
+            batches.list,
+        )
diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py
index e3b99c70..401ca925 100644
--- a/src/together/types/__init__.py
+++ b/src/together/types/__init__.py
@@ -27,6 +27,8 @@
 from .full_training_type import FullTrainingType as FullTrainingType
 from .lr_scheduler_param import LrSchedulerParam as LrSchedulerParam
 from .audio_create_params import AudioCreateParams as AudioCreateParams
+from .batch_create_params import BatchCreateParams as BatchCreateParams
+from .batch_list_response import BatchListResponse as BatchListResponse
 from .image_create_params import ImageCreateParams as ImageCreateParams
 from .lo_ra_training_type import LoRaTrainingType as LoRaTrainingType
 from .model_list_response import ModelListResponse as ModelListResponse
@@ -38,6 +40,7 @@
 from .file_delete_response import FileDeleteResponse as FileDeleteResponse
 from .file_upload_response import FileUploadResponse as FileUploadResponse
 from .hardware_list_params import HardwareListParams as HardwareListParams
+from .batch_create_response import BatchCreateResponse as BatchCreateResponse
 from .job_retrieve_response import JobRetrieveResponse as JobRetrieveResponse
 from .model_upload_response import ModelUploadResponse as ModelUploadResponse
 from .endpoint_create_params import EndpointCreateParams as EndpointCreateParams
@@ -45,6 +48,7 @@
 from .endpoint_update_params import EndpointUpdateParams as EndpointUpdateParams
 from .file_retrieve_response import FileRetrieveResponse as FileRetrieveResponse
 from .hardware_list_response import HardwareListResponse as HardwareListResponse
+from .batch_retrieve_response import BatchRetrieveResponse as BatchRetrieveResponse
 from .embedding_create_params import EmbeddingCreateParams as EmbeddingCreateParams
 from .fine_tune_create_params import FineTuneCreateParams as FineTuneCreateParams
 from .fine_tune_list_response import FineTuneListResponse as FineTuneListResponse
diff --git a/src/together/types/batch_create_params.py b/src/together/types/batch_create_params.py
new file mode 100644
index 00000000..8b696489
--- /dev/null
+++ b/src/together/types/batch_create_params.py
@@ -0,0 +1,24 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, TypedDict
+
+__all__ = ["BatchCreateParams"]
+
+
+class BatchCreateParams(TypedDict, total=False):
+    endpoint: Required[str]
+    """The endpoint to use for batch processing"""
+
+    input_file_id: Required[str]
+    """ID of the uploaded input file containing batch requests"""
+
+    completion_window: str
+    """Time window for batch completion (optional)"""
+
+    model_id: str
+    """Model to use for processing batch requests"""
+
+    priority: int
+    """Priority for batch processing (optional)"""
diff --git a/src/together/types/batch_create_response.py b/src/together/types/batch_create_response.py
new file mode 100644
index 00000000..382f1548
--- /dev/null
+++ b/src/together/types/batch_create_response.py
@@ -0,0 +1,51 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from pydantic import Field as FieldInfo
+
+from .._models import BaseModel
+
+__all__ = ["BatchCreateResponse", "Job"]
+
+
+class Job(BaseModel):
+    id: Optional[str] = None
+
+    completed_at: Optional[datetime] = None
+
+    created_at: Optional[datetime] = None
+
+    endpoint: Optional[str] = None
+
+    error: Optional[str] = None
+
+    error_file_id: Optional[str] = None
+
+    file_size_bytes: Optional[int] = None
+    """Size of input file in bytes"""
+
+    input_file_id: Optional[str] = None
+
+    job_deadline: Optional[datetime] = None
+
+    x_model_id: Optional[str] = FieldInfo(alias="model_id", default=None)
+    """Model used for processing requests"""
+
+    output_file_id: Optional[str] = None
+
+    progress: Optional[float] = None
+    """Completion progress (0.0 to 100)"""
+
+    status: Optional[Literal["VALIDATING", "IN_PROGRESS", "COMPLETED", "FAILED", "EXPIRED", "CANCELLED"]] = None
+    """Current status of the batch job"""
+
+    user_id: Optional[str] = None
+
+
+class BatchCreateResponse(BaseModel):
+    job: Optional[Job] = None
+
+    warning: Optional[str] = None
diff --git a/src/together/types/batch_list_response.py b/src/together/types/batch_list_response.py
new file mode 100644
index 00000000..11b453c8
--- /dev/null
+++ b/src/together/types/batch_list_response.py
@@ -0,0 +1,48 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import List, Optional
+from datetime import datetime
+from typing_extensions import Literal, TypeAlias
+
+from pydantic import Field as FieldInfo
+
+from .._models import BaseModel
+
+__all__ = ["BatchListResponse", "BatchListResponseItem"]
+
+
+class BatchListResponseItem(BaseModel):
+    id: Optional[str] = None
+
+    completed_at: Optional[datetime] = None
+
+    created_at: Optional[datetime] = None
+
+    endpoint: Optional[str] = None
+
+    error: Optional[str] = None
+
+    error_file_id: Optional[str] = None
+
+    file_size_bytes: Optional[int] = None
+    """Size of input file in bytes"""
+
+    input_file_id: Optional[str] = None
+
+    job_deadline: Optional[datetime] = None
+
+    x_model_id: Optional[str] = FieldInfo(alias="model_id", default=None)
+    """Model used for processing requests"""
+
+    output_file_id: Optional[str] = None
+
+    progress: Optional[float] = None
+    """Completion progress (0.0 to 100)"""
+
+    status: Optional[Literal["VALIDATING", "IN_PROGRESS", "COMPLETED", "FAILED", "EXPIRED", "CANCELLED"]] = None
+    """Current status of the batch job"""
+
+    user_id: Optional[str] = None
+
+
+BatchListResponse: TypeAlias = List[BatchListResponseItem]
diff --git a/src/together/types/batch_retrieve_response.py b/src/together/types/batch_retrieve_response.py
new file mode 100644
index 00000000..81483615
--- /dev/null
+++ b/src/together/types/batch_retrieve_response.py
@@ -0,0 +1,45 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from pydantic import Field as FieldInfo
+
+from .._models import BaseModel
+
+__all__ = ["BatchRetrieveResponse"]
+
+
+class BatchRetrieveResponse(BaseModel):
+    id: Optional[str] = None
+
+    completed_at: Optional[datetime] = None
+
+    created_at: Optional[datetime] = None
+
+    endpoint: Optional[str] = None
+
+    error: Optional[str] = None
+
+    error_file_id: Optional[str] = None
+
+    file_size_bytes: Optional[int] = None
+    """Size of input file in bytes"""
+
+    input_file_id: Optional[str] = None
+
+    job_deadline: Optional[datetime] = None
+
+    x_model_id: Optional[str] = FieldInfo(alias="model_id", default=None)
+    """Model used for processing requests"""
+
+    output_file_id: Optional[str] = None
+
+    progress: Optional[float] = None
+    """Completion progress (0.0 to 100)"""
+
+    status: Optional[Literal["VALIDATING", "IN_PROGRESS", "COMPLETED", "FAILED", "EXPIRED", "CANCELLED"]] = None
+    """Current status of the batch job"""
+
+    user_id: Optional[str] = None
diff --git a/src/together/types/training_method_dpo.py b/src/together/types/training_method_dpo.py
index 2b633178..6ded8b31 100644
--- a/src/together/types/training_method_dpo.py
+++ b/src/together/types/training_method_dpo.py
@@ -12,3 +12,11 @@ class TrainingMethodDpo(BaseModel):
     method: Literal["dpo"]
 
     dpo_beta: Optional[float] = None
+
+    dpo_normalize_logratios_by_length: Optional[bool] = None
+
+    dpo_reference_free: Optional[bool] = None
+
+    rpo_alpha: Optional[float] = None
+
+    simpo_gamma: Optional[float] = None
diff --git a/src/together/types/training_method_dpo_param.py b/src/together/types/training_method_dpo_param.py
index 812deb77..cd776600 100644
--- a/src/together/types/training_method_dpo_param.py
+++ b/src/together/types/training_method_dpo_param.py
@@ -11,3 +11,11 @@ class TrainingMethodDpoParam(TypedDict, total=False):
     method: Required[Literal["dpo"]]
 
     dpo_beta: float
+
+    dpo_normalize_logratios_by_length: bool
+
+    dpo_reference_free: bool
+
+    rpo_alpha: float
+
+    simpo_gamma: float
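The four new DPO knobs are plain optional fields on the `dpo` training-method types. A minimal sketch of building the param dict; the values are placeholders, not recommendations, and the import path is the module added just above:

```python
from together.types.training_method_dpo_param import TrainingMethodDpoParam

# Only `method` is required; every other key may be omitted (total=False).
dpo_method: TrainingMethodDpoParam = {
    "method": "dpo",
    "dpo_beta": 0.1,
    "dpo_normalize_logratios_by_length": True,
    "dpo_reference_free": False,
    "rpo_alpha": 1.0,
    "simpo_gamma": 0.5,
}
```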
diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py
index 82987709..8b4c6fb6 100644
--- a/tests/api_resources/chat/test_completions.py
+++ b/tests/api_resources/chat/test_completions.py
@@ -219,7 +219,9 @@ def test_streaming_response_create_overload_2(self, client: Together) -> None:
 
 
 class TestAsyncCompletions:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_create_overload_1(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/code_interpreter/test_sessions.py b/tests/api_resources/code_interpreter/test_sessions.py
index e53d7a4a..19313230 100644
--- a/tests/api_resources/code_interpreter/test_sessions.py
+++ b/tests/api_resources/code_interpreter/test_sessions.py
@@ -53,7 +53,9 @@ def test_streaming_response_list(self, client: Together) -> None:
 
 
 class TestAsyncSessions:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @pytest.mark.skip(
         reason="currently no good way to test endpoints defining callbacks, Prism mock server will fail trying to reach the provided callback url"
diff --git a/tests/api_resources/test_audio.py b/tests/api_resources/test_audio.py
index 5f06c217..0e01c6a9 100644
--- a/tests/api_resources/test_audio.py
+++ b/tests/api_resources/test_audio.py
@@ -163,7 +163,9 @@ def test_streaming_response_create_overload_2(self, client: Together, respx_mock
 
 
 class TestAsyncAudio:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @pytest.mark.skip(reason="AttributeError: BinaryAPIResponse object has no attribute response")
     @parametrize
diff --git a/tests/api_resources/test_batches.py b/tests/api_resources/test_batches.py
new file mode 100644
index 00000000..5b1b0a90
--- /dev/null
+++ b/tests/api_resources/test_batches.py
@@ -0,0 +1,240 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, cast
+
+import pytest
+
+from together import Together, AsyncTogether
+from tests.utils import assert_matches_type
+from together.types import BatchListResponse, BatchCreateResponse, BatchRetrieveResponse
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestBatches:
+    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+    @parametrize
+    def test_method_create(self, client: Together) -> None:
+        batch = client.batches.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+        )
+        assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+    @parametrize
+    def test_method_create_with_all_params(self, client: Together) -> None:
+        batch = client.batches.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+            completion_window="24h",
+            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+            priority=1,
+        )
+        assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+    @parametrize
+    def test_raw_response_create(self, client: Together) -> None:
+        response = client.batches.with_raw_response.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        batch = response.parse()
+        assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+    @parametrize
+    def test_streaming_response_create(self, client: Together) -> None:
+        with client.batches.with_streaming_response.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            batch = response.parse()
+            assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    def test_method_retrieve(self, client: Together) -> None:
+        batch = client.batches.retrieve(
+            "id",
+        )
+        assert_matches_type(BatchRetrieveResponse, batch, path=["response"])
+
+    @parametrize
+    def test_raw_response_retrieve(self, client: Together) -> None:
+        response = client.batches.with_raw_response.retrieve(
+            "id",
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        batch = response.parse()
+        assert_matches_type(BatchRetrieveResponse, batch, path=["response"])
+
+    @parametrize
+    def test_streaming_response_retrieve(self, client: Together) -> None:
+        with client.batches.with_streaming_response.retrieve(
+            "id",
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            batch = response.parse()
+            assert_matches_type(BatchRetrieveResponse, batch, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    def test_path_params_retrieve(self, client: Together) -> None:
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
+            client.batches.with_raw_response.retrieve(
+                "",
+            )
+
+    @parametrize
+    def test_method_list(self, client: Together) -> None:
+        batch = client.batches.list()
+        assert_matches_type(BatchListResponse, batch, path=["response"])
+
+    @parametrize
+    def test_raw_response_list(self, client: Together) -> None:
+        response = client.batches.with_raw_response.list()
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        batch = response.parse()
+        assert_matches_type(BatchListResponse, batch, path=["response"])
+
+    @parametrize
+    def test_streaming_response_list(self, client: Together) -> None:
+        with client.batches.with_streaming_response.list() as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            batch = response.parse()
+            assert_matches_type(BatchListResponse, batch, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+
+class TestAsyncBatches:
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
+
+    @parametrize
+    async def test_method_create(self, async_client: AsyncTogether) -> None:
+        batch = await async_client.batches.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+        )
+        assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_method_create_with_all_params(self, async_client: AsyncTogether) -> None:
+        batch = await async_client.batches.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+            completion_window="24h",
+            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+            priority=1,
+        )
+        assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_raw_response_create(self, async_client: AsyncTogether) -> None:
+        response = await async_client.batches.with_raw_response.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        batch = await response.parse()
+        assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_streaming_response_create(self, async_client: AsyncTogether) -> None:
+        async with async_client.batches.with_streaming_response.create(
+            endpoint="/v1/chat/completions",
+            input_file_id="file-abc123def456ghi789",
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            batch = await response.parse()
+            assert_matches_type(BatchCreateResponse, batch, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    async def test_method_retrieve(self, async_client: AsyncTogether) -> None:
+        batch = await async_client.batches.retrieve(
+            "id",
+        )
+        assert_matches_type(BatchRetrieveResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_raw_response_retrieve(self, async_client: AsyncTogether) -> None:
+        response = await async_client.batches.with_raw_response.retrieve(
+            "id",
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        batch = await response.parse()
+        assert_matches_type(BatchRetrieveResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_streaming_response_retrieve(self, async_client: AsyncTogether) -> None:
+        async with async_client.batches.with_streaming_response.retrieve(
+            "id",
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            batch = await response.parse()
+            assert_matches_type(BatchRetrieveResponse, batch, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    async def test_path_params_retrieve(self, async_client: AsyncTogether) -> None:
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
+            await async_client.batches.with_raw_response.retrieve(
+                "",
+            )
+
+    @parametrize
+    async def test_method_list(self, async_client: AsyncTogether) -> None:
+        batch = await async_client.batches.list()
+        assert_matches_type(BatchListResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_raw_response_list(self, async_client: AsyncTogether) -> None:
+        response = await async_client.batches.with_raw_response.list()
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        batch = await response.parse()
+        assert_matches_type(BatchListResponse, batch, path=["response"])
+
+    @parametrize
+    async def test_streaming_response_list(self, async_client: AsyncTogether) -> None:
+        async with async_client.batches.with_streaming_response.list() as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            batch = await response.parse()
+            assert_matches_type(BatchListResponse, batch, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/test_client.py b/tests/api_resources/test_client.py
index e057e9ba..691c42c4 100644
--- a/tests/api_resources/test_client.py
+++ b/tests/api_resources/test_client.py
@@ -136,7 +136,9 @@ def test_streaming_response_rerank(self, client: Together) -> None:
 
 
 class TestAsyncClient:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_rerank(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_code_interpreter.py b/tests/api_resources/test_code_interpreter.py
index 17c1928c..f3f405b2 100644
--- a/tests/api_resources/test_code_interpreter.py
+++ b/tests/api_resources/test_code_interpreter.py
@@ -81,7 +81,9 @@ def test_streaming_response_execute(self, client: Together) -> None:
 
 
 class TestAsyncCodeInterpreter:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @pytest.mark.skip(
         reason="currently no good way to test endpoints defining callbacks, Prism mock server will fail trying to reach the provided callback url"
diff --git a/tests/api_resources/test_completions.py b/tests/api_resources/test_completions.py
index ef05bb50..1440c691 100644
--- a/tests/api_resources/test_completions.py
+++ b/tests/api_resources/test_completions.py
@@ -143,7 +143,9 @@ def test_streaming_response_create_overload_2(self, client: Together) -> None:
 
 
 class TestAsyncCompletions:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_create_overload_1(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_embeddings.py b/tests/api_resources/test_embeddings.py
index 084ad480..779e505a 100644
--- a/tests/api_resources/test_embeddings.py
+++ b/tests/api_resources/test_embeddings.py
@@ -53,7 +53,9 @@ def test_streaming_response_create(self, client: Together) -> None:
 
 
 class TestAsyncEmbeddings:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_create(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_endpoints.py b/tests/api_resources/test_endpoints.py
index 59cbc6ab..5ab5c225 100644
--- a/tests/api_resources/test_endpoints.py
+++ b/tests/api_resources/test_endpoints.py
@@ -247,7 +247,9 @@ def test_path_params_delete(self, client: Together) -> None:
 
 
 class TestAsyncEndpoints:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_create(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_files.py b/tests/api_resources/test_files.py
index 27f7bb27..ec1b4dfe 100644
--- a/tests/api_resources/test_files.py
+++ b/tests/api_resources/test_files.py
@@ -230,7 +230,9 @@ def test_streaming_response_upload(self, client: Together) -> None:
 
 
 class TestAsyncFiles:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_retrieve(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_fine_tune.py b/tests/api_resources/test_fine_tune.py
index a8a64f3e..cf40ef0f 100644
--- a/tests/api_resources/test_fine_tune.py
+++ b/tests/api_resources/test_fine_tune.py
@@ -312,7 +312,9 @@ def test_path_params_retrieve_checkpoints(self, client: Together) -> None:
 
 
 class TestAsyncFineTune:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_create(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_hardware.py b/tests/api_resources/test_hardware.py
index aafe18f0..737d10f5 100644
--- a/tests/api_resources/test_hardware.py
+++ b/tests/api_resources/test_hardware.py
@@ -51,7 +51,9 @@ def test_streaming_response_list(self, client: Together) -> None:
 
 
 class TestAsyncHardware:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_list(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_images.py b/tests/api_resources/test_images.py
index d95acfcd..16851aaa 100644
--- a/tests/api_resources/test_images.py
+++ b/tests/api_resources/test_images.py
@@ -77,7 +77,9 @@ def test_streaming_response_create(self, client: Together) -> None:
 
 
 class TestAsyncImages:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_create(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_jobs.py b/tests/api_resources/test_jobs.py
index 110600d7..70711d0a 100644
--- a/tests/api_resources/test_jobs.py
+++ b/tests/api_resources/test_jobs.py
@@ -82,7 +82,9 @@ def test_streaming_response_list(self, client: Together) -> None:
 
 
 class TestAsyncJobs:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_retrieve(self, async_client: AsyncTogether) -> None:
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index fbf910a0..c3689674 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -91,7 +91,9 @@ def test_streaming_response_upload(self, client: Together) -> None:
 
 
 class TestAsyncModels:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @parametrize
     async def test_method_list(self, async_client: AsyncTogether) -> None:
diff --git a/tests/conftest.py b/tests/conftest.py
index b7e86792..97bce53e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,13 +1,17 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
 from __future__ import annotations
 
 import os
 import logging
 from typing import TYPE_CHECKING, Iterator, AsyncIterator
 
+import httpx
 import pytest
 from pytest_asyncio import is_async_test
 
-from together import Together, AsyncTogether
+from together import Together, AsyncTogether, DefaultAioHttpClient
+from together._utils import is_dict
 
 if TYPE_CHECKING:
     from _pytest.fixtures import FixtureRequest  # pyright: ignore[reportPrivateImportUsage]
@@ -25,6 +29,19 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
     for async_test in pytest_asyncio_tests:
         async_test.add_marker(session_scope_marker, append=False)
 
+    # We skip tests that use both the aiohttp client and respx_mock as respx_mock
+    # doesn't support custom transports.
diff --git a/tests/conftest.py b/tests/conftest.py
index b7e86792..97bce53e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,13 +1,17 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
 from __future__ import annotations
 
 import os
 import logging
 from typing import TYPE_CHECKING, Iterator, AsyncIterator
 
+import httpx
 import pytest
 from pytest_asyncio import is_async_test
 
-from together import Together, AsyncTogether
+from together import Together, AsyncTogether, DefaultAioHttpClient
+from together._utils import is_dict
 
 if TYPE_CHECKING:
     from _pytest.fixtures import FixtureRequest  # pyright: ignore[reportPrivateImportUsage]
@@ -25,6 +29,19 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
     for async_test in pytest_asyncio_tests:
         async_test.add_marker(session_scope_marker, append=False)
 
+    # We skip tests that use both the aiohttp client and respx_mock as respx_mock
+    # doesn't support custom transports.
+    for item in items:
+        if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames:
+            continue
+
+        if not hasattr(item, "callspec"):
+            continue
+
+        async_client_param = item.callspec.params.get("async_client")
+        if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp":
+            item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock"))
+
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 
@@ -43,9 +60,25 @@ def client(request: FixtureRequest) -> Iterator[Together]:
 
 
 @pytest.fixture(scope="session")
 async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncTogether]:
-    strict = getattr(request, "param", True)
-    if not isinstance(strict, bool):
-        raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
-
-    async with AsyncTogether(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client:
+    param = getattr(request, "param", True)
+
+    # defaults
+    strict = True
+    http_client: None | httpx.AsyncClient = None
+
+    if isinstance(param, bool):
+        strict = param
+    elif is_dict(param):
+        strict = param.get("strict", True)
+        assert isinstance(strict, bool)
+
+        http_client_type = param.get("http_client", "httpx")
+        if http_client_type == "aiohttp":
+            http_client = DefaultAioHttpClient()
+    else:
+        raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict")
+
+    async with AsyncTogether(
+        base_url=base_url, api_key=api_key, _strict_response_validation=strict, http_client=http_client
+    ) as client:
         yield client
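Editor's note: condensed, the new `async_client` fixture resolves each parametrized value to a concrete client configuration. The sketch below is illustrative only and assumes the `together[aiohttp]` extra is installed; the `"placeholder"` api key stands in for the test-environment values used in conftest.py:

```python
from __future__ import annotations

from together import AsyncTogether, DefaultAioHttpClient


def build_client(param: bool | dict) -> AsyncTogether:
    # bool params only toggle strict response validation (httpx backend).
    if isinstance(param, bool):
        return AsyncTogether(api_key="placeholder", _strict_response_validation=param)
    # dict params may additionally select the aiohttp backend.
    return AsyncTogether(
        api_key="placeholder",
        _strict_response_validation=param.get("strict", True),
        http_client=DefaultAioHttpClient() if param.get("http_client") == "aiohttp" else None,
    )
```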
model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + ).__enter__() assert _get_open_connections(self.client) == 0 @mock.patch("together._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Together) -> None: respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): - self.client.post( - "/chat/completions", - body=cast( - object, - maybe_transform( - dict( - messages=[ - { - "role": "user", - "content": "Say this is a test", - } - ], - model="mistralai/Mixtral-8x7B-Instruct-v0.1", - ), - CompletionCreateParamsNonStreaming, - ), - ), - cast_to=httpx.Response, - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) - + client.chat.completions.with_streaming_response.create( + messages=[ + { + "content": "content", + "role": "system", + } + ], + model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + ).__enter__() assert _get_open_connections(self.client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @@ -884,6 +860,28 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" + def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + + client = DefaultHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + @pytest.mark.respx(base_url=base_url) def test_follow_redirects(self, respx_mock: MockRouter) -> None: # Test that the default follow_redirects=True allows following redirects @@ -1594,60 +1592,41 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte @mock.patch("together._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + async def test_retrying_timeout_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncTogether + ) -> None: respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): - await self.client.post( - "/chat/completions", - body=cast( - object, - maybe_transform( - dict( - messages=[ - { - "role": "user", - "content": "Say this is a test", - } - ], - model="mistralai/Mixtral-8x7B-Instruct-v0.1", - ), - CompletionCreateParamsNonStreaming, - ), - ), - cast_to=httpx.Response, - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) + await async_client.chat.completions.with_streaming_response.create( + messages=[ + { + "content": "content", + "role": "system", + } + ], + model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + ).__aenter__() assert _get_open_connections(self.client) == 0 
@@ -1801,6 +1780,28 @@ async def test_main() -> None:
 
         time.sleep(0.1)
 
+    async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        # Test that the proxy environment variables are set correctly
+        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+        client = DefaultAsyncHttpxClient()
+
+        mounts = tuple(client._mounts.items())
+        assert len(mounts) == 1
+        assert mounts[0][0].pattern == "https://"
+
+    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+    async def test_default_client_creation(self) -> None:
+        # Ensure that the client can be initialized without any exceptions
+        DefaultAsyncHttpxClient(
+            verify=True,
+            cert=None,
+            trust_env=True,
+            http1=True,
+            http2=False,
+            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+        )
+
     @pytest.mark.respx(base_url=base_url)
     async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
         # Test that the default follow_redirects=True allows following redirects
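Editor's note: the `test_proxy_environment_variables` additions rely on httpx's standard proxy handling: with `trust_env=True` (the default), each `*_PROXY` environment variable becomes one transport mount keyed by a URL pattern. A minimal sketch against plain httpx follows; `_mounts` is private API, accessed here exactly as the tests above access it:

```python
import os

import httpx

os.environ["HTTPS_PROXY"] = "https://example.org"

# trust_env=True (the default) makes httpx read HTTP_PROXY/HTTPS_PROXY/ALL_PROXY
# and register a proxy transport mount per matching URL pattern.
client = httpx.Client(trust_env=True)

for pattern, transport in client._mounts.items():
    # e.g. prints: https:// HTTPTransport
    print(pattern.pattern, type(transport).__name__)
```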