diff --git a/.release-please-manifest.json b/.release-please-manifest.json index d50ed41..7563b5d 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.0.0-alpha.4" + ".": "2.0.0-alpha.5" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 29aa088..c59bcd1 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 35 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-12e7ef40109b6b34f1471a638d09b79f005c8dbf7e1a8aeca9db7e37a334e8eb.yml -openapi_spec_hash: 10b0fc9094dac5d51f46bbdd5fe3de32 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-f5aff38eef8d7b245a7af062bf02920ae75e5b9b3dc822416aeb48966c2c6874.yml +openapi_spec_hash: c0a966beaf5ae95c6bdddd4a933bd4aa config_hash: 12536d2bf978a995771d076a4647c17d diff --git a/CHANGELOG.md b/CHANGELOG.md index 8422405..18f6354 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## 2.0.0-alpha.5 (2025-06-25) + +Full Changelog: [v2.0.0-alpha.4...v2.0.0-alpha.5](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.4...v2.0.0-alpha.5) + +### Features + +* **api:** api update ([6e667da](https://github.com/replicate/replicate-python-stainless/commit/6e667da6c2e80add847e612bbd08db1c865793d7)) +* **api:** api update ([0a187a9](https://github.com/replicate/replicate-python-stainless/commit/0a187a9ba906c0bc5c4e658883266276fc357665)) +* **api:** api update ([edb14b6](https://github.com/replicate/replicate-python-stainless/commit/edb14b65c61203c2e42a1accd384e7b456e33448)) +* **client:** add support for aiohttp ([c802a30](https://github.com/replicate/replicate-python-stainless/commit/c802a30a0569cb25eb700ff5501c5a87291ef4b0)) + + +### Chores + +* **tests:** skip some failing tests on the latest python versions ([d331b72](https://github.com/replicate/replicate-python-stainless/commit/d331b72364eaed6f935f9b23fdc776303ebf57a6)) + + +### Documentation + +* **client:** fix httpx.Timeout documentation reference ([d17c345](https://github.com/replicate/replicate-python-stainless/commit/d17c3454afaa0ae0b022f468515e8478e5ba6568)) + ## 2.0.0-alpha.4 (2025-06-18) Full Changelog: [v2.0.0-alpha.3...v2.0.0-alpha.4](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.3...v2.0.0-alpha.4) diff --git a/README.md b/README.md index b79c6ba..f21ffc3 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,42 @@ asyncio.run(main()) Functionality between the synchronous and asynchronous clients is otherwise identical. +### With aiohttp + +By default, the async client uses `httpx` for HTTP requests. However, for improved concurrency performance you may also use `aiohttp` as the HTTP backend. + +You can enable this by installing `aiohttp`: + +```sh +# install from PyPI +pip install --pre replicate[aiohttp] +``` + +Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`: + +```python +import os +import asyncio +from replicate import DefaultAioHttpClient +from replicate import AsyncReplicate + + +async def main() -> None: + async with AsyncReplicate( + bearer_token=os.environ.get( + "REPLICATE_API_TOKEN" + ), # This is the default and can be omitted + http_client=DefaultAioHttpClient(), + ) as replicate: + prediction = await replicate.predictions.get( + prediction_id="gm3qorzdhgbfurvjtvhg6dckhu", + ) + print(prediction.id) + + +asyncio.run(main()) +``` + ## Using types Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like: @@ -228,7 +264,7 @@ replicate.with_options(max_retries=5).predictions.create( ### Timeouts By default requests time out after 1 minute. You can configure this with a `timeout` option, -which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object: +which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/timeouts/#fine-tuning-the-configuration) object: ```python from replicate import Replicate diff --git a/api.md b/api.md index 38f61b7..a28a485 100644 --- a/api.md +++ b/api.md @@ -61,7 +61,7 @@ Methods: Types: ```python -from replicate.types import ModelListResponse +from replicate.types import ModelListResponse, ModelGetResponse, ModelSearchResponse ``` Methods: @@ -69,8 +69,8 @@ Methods: - replicate.models.create(\*\*params) -> None - replicate.models.list() -> SyncCursorURLPage[ModelListResponse] - replicate.models.delete(\*, model_owner, model_name) -> None -- replicate.models.get(\*, model_owner, model_name) -> None -- replicate.models.search(\*\*params) -> None +- replicate.models.get(\*, model_owner, model_name) -> ModelGetResponse +- replicate.models.search(\*\*params) -> SyncCursorURLPage[ModelSearchResponse] ## Examples diff --git a/pyproject.toml b/pyproject.toml index 8f9560a..318dbaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "replicate" -version = "2.0.0-alpha.4" +version = "2.0.0-alpha.5" description = "The official Python library for the replicate API" dynamic = ["readme"] license = "Apache-2.0" @@ -37,6 +37,8 @@ classifiers = [ Homepage = "https://github.com/replicate/replicate-python-stainless" Repository = "https://github.com/replicate/replicate-python-stainless" +[project.optional-dependencies] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.6"] [tool.rye] managed = true diff --git a/requirements-dev.lock b/requirements-dev.lock index c456756..001b806 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,6 +10,13 @@ # universal: false -e file:. +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.8 + # via httpx-aiohttp + # via replicate +aiosignal==1.3.2 + # via aiohttp annotated-types==0.6.0 # via pydantic anyio==4.4.0 @@ -17,6 +24,10 @@ anyio==4.4.0 # via replicate argcomplete==3.1.2 # via nox +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via aiohttp certifi==2023.7.22 # via httpcore # via httpx @@ -34,16 +45,23 @@ execnet==2.1.1 # via pytest-xdist filelock==3.12.4 # via virtualenv +frozenlist==1.6.2 + # via aiohttp + # via aiosignal h11==0.14.0 # via httpcore httpcore==1.0.2 # via httpx httpx==0.28.1 + # via httpx-aiohttp # via replicate # via respx +httpx-aiohttp==0.1.6 + # via replicate idna==3.4 # via anyio # via httpx + # via yarl importlib-metadata==7.0.0 iniconfig==2.0.0 # via pytest @@ -51,6 +69,9 @@ markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py +multidict==6.4.4 + # via aiohttp + # via yarl mypy==1.14.1 mypy-extensions==1.0.0 # via mypy @@ -65,6 +86,9 @@ platformdirs==3.11.0 # via virtualenv pluggy==1.5.0 # via pytest +propcache==0.3.1 + # via aiohttp + # via yarl pydantic==2.10.3 # via replicate pydantic-core==2.27.1 @@ -97,6 +121,7 @@ tomli==2.0.2 # via pytest typing-extensions==4.12.2 # via anyio + # via multidict # via mypy # via pydantic # via pydantic-core @@ -104,5 +129,7 @@ typing-extensions==4.12.2 # via replicate virtualenv==20.24.5 # via nox +yarl==1.20.0 + # via aiohttp zipp==3.17.0 # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index f022008..d884405 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,11 +10,22 @@ # universal: false -e file:. +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.8 + # via httpx-aiohttp + # via replicate +aiosignal==1.3.2 + # via aiohttp annotated-types==0.6.0 # via pydantic anyio==4.4.0 # via httpx # via replicate +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via aiohttp certifi==2023.7.22 # via httpcore # via httpx @@ -22,15 +33,28 @@ distro==1.8.0 # via replicate exceptiongroup==1.2.2 # via anyio +frozenlist==1.6.2 + # via aiohttp + # via aiosignal h11==0.14.0 # via httpcore httpcore==1.0.2 # via httpx httpx==0.28.1 + # via httpx-aiohttp + # via replicate +httpx-aiohttp==0.1.6 # via replicate idna==3.4 # via anyio # via httpx + # via yarl +multidict==6.4.4 + # via aiohttp + # via yarl +propcache==0.3.1 + # via aiohttp + # via yarl pydantic==2.10.3 # via replicate pydantic-core==2.27.1 @@ -40,6 +64,9 @@ sniffio==1.3.0 # via replicate typing-extensions==4.12.2 # via anyio + # via multidict # via pydantic # via pydantic-core # via replicate +yarl==1.20.0 + # via aiohttp diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index 137a79d..3dc3c80 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -41,7 +41,7 @@ APIResponseValidationError, ) from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier -from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient +from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging __all__ = [ @@ -83,6 +83,7 @@ "DEFAULT_CONNECTION_LIMITS", "DefaultHttpxClient", "DefaultAsyncHttpxClient", + "DefaultAioHttpClient", "FileOutput", "AsyncFileOutput", "Model", diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index af0d340..43f54b6 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -1303,6 +1303,24 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) +try: + import httpx_aiohttp +except ImportError: + + class _DefaultAioHttpClient(httpx.AsyncClient): + def __init__(self, **_kwargs: Any) -> None: + raise RuntimeError("To use the aiohttp client you must have installed the package with the `aiohttp` extra") +else: + + class _DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + + super().__init__(**kwargs) + + if TYPE_CHECKING: DefaultAsyncHttpxClient = httpx.AsyncClient """An alias to `httpx.AsyncClient` that provides the same defaults that this SDK @@ -1311,8 +1329,12 @@ def __init__(self, **kwargs: Any) -> None: This is useful because overriding the `http_client` with your own instance of `httpx.AsyncClient` will result in httpx's defaults being used, not ours. """ + + DefaultAioHttpClient = httpx.AsyncClient + """An alias to `httpx.AsyncClient` that changes the default HTTP transport to `aiohttp`.""" else: DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient + DefaultAioHttpClient = _DefaultAioHttpClient class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient): diff --git a/src/replicate/_version.py b/src/replicate/_version.py index ffe5082..13d48c0 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "2.0.0-alpha.4" # x-release-please-version +__version__ = "2.0.0-alpha.5" # x-release-please-version diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index 3ca0a72..004bd07 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -51,7 +51,9 @@ ) from ...pagination import SyncCursorURLPage, AsyncCursorURLPage from ..._base_client import AsyncPaginator, make_request_options +from ...types.model_get_response import ModelGetResponse from ...types.model_list_response import ModelListResponse +from ...types.model_search_response import ModelSearchResponse __all__ = ["ModelsResource", "AsyncModelsResource"] @@ -306,7 +308,7 @@ def get( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> ModelGetResponse: """ Example cURL request: @@ -395,13 +397,12 @@ def get( raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") if not model_name: raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._get( f"/models/{model_owner}/{model_name}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=ModelGetResponse, ) def search( @@ -414,7 +415,7 @@ def search( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> SyncCursorURLPage[ModelSearchResponse]: """ Get a list of public models matching a search query. @@ -445,14 +446,15 @@ def search( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._query( + return self._get_api_list( "/models", + page=SyncCursorURLPage[ModelSearchResponse], body=maybe_transform(body, model_search_params.ModelSearchParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + model=ModelSearchResponse, + method="query", ) @@ -706,7 +708,7 @@ async def get( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> ModelGetResponse: """ Example cURL request: @@ -795,16 +797,15 @@ async def get( raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") if not model_name: raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._get( f"/models/{model_owner}/{model_name}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=ModelGetResponse, ) - async def search( + def search( self, *, body: str, @@ -814,7 +815,7 @@ async def search( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> AsyncPaginator[ModelSearchResponse, AsyncCursorURLPage[ModelSearchResponse]]: """ Get a list of public models matching a search query. @@ -845,14 +846,15 @@ async def search( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._query( + return self._get_api_list( "/models", - body=await async_maybe_transform(body, model_search_params.ModelSearchParams), + page=AsyncCursorURLPage[ModelSearchResponse], + body=maybe_transform(body, model_search_params.ModelSearchParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + model=ModelSearchResponse, + method="query", ) diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py index 342b2a7..f398100 100644 --- a/src/replicate/types/__init__.py +++ b/src/replicate/types/__init__.py @@ -6,12 +6,14 @@ from .file_get_response import FileGetResponse as FileGetResponse from .file_create_params import FileCreateParams as FileCreateParams from .file_list_response import FileListResponse as FileListResponse +from .model_get_response import ModelGetResponse as ModelGetResponse from .model_create_params import ModelCreateParams as ModelCreateParams from .model_list_response import ModelListResponse as ModelListResponse from .model_search_params import ModelSearchParams as ModelSearchParams from .account_get_response import AccountGetResponse as AccountGetResponse from .file_create_response import FileCreateResponse as FileCreateResponse from .file_download_params import FileDownloadParams as FileDownloadParams +from .model_search_response import ModelSearchResponse as ModelSearchResponse from .training_get_response import TrainingGetResponse as TrainingGetResponse from .hardware_list_response import HardwareListResponse as HardwareListResponse from .prediction_list_params import PredictionListParams as PredictionListParams diff --git a/src/replicate/types/model_get_response.py b/src/replicate/types/model_get_response.py new file mode 100644 index 0000000..9ba9146 --- /dev/null +++ b/src/replicate/types/model_get_response.py @@ -0,0 +1,46 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["ModelGetResponse"] + + +class ModelGetResponse(BaseModel): + cover_image_url: Optional[str] = None + """A URL for the model's cover image""" + + default_example: Optional[object] = None + """The model's default example prediction""" + + description: Optional[str] = None + """A description of the model""" + + github_url: Optional[str] = None + """A URL for the model's source code on GitHub""" + + latest_version: Optional[object] = None + """The model's latest version""" + + license_url: Optional[str] = None + """A URL for the model's license""" + + name: Optional[str] = None + """The name of the model""" + + owner: Optional[str] = None + """The name of the user or organization that owns the model""" + + paper_url: Optional[str] = None + """A URL for the model's paper""" + + run_count: Optional[int] = None + """The number of times the model has been run""" + + url: Optional[str] = None + """The URL of the model on Replicate""" + + visibility: Optional[Literal["public", "private"]] = None + """Whether the model is public or private""" diff --git a/src/replicate/types/model_search_response.py b/src/replicate/types/model_search_response.py new file mode 100644 index 0000000..6b2bde6 --- /dev/null +++ b/src/replicate/types/model_search_response.py @@ -0,0 +1,46 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["ModelSearchResponse"] + + +class ModelSearchResponse(BaseModel): + cover_image_url: Optional[str] = None + """A URL for the model's cover image""" + + default_example: Optional[object] = None + """The model's default example prediction""" + + description: Optional[str] = None + """A description of the model""" + + github_url: Optional[str] = None + """A URL for the model's source code on GitHub""" + + latest_version: Optional[object] = None + """The model's latest version""" + + license_url: Optional[str] = None + """A URL for the model's license""" + + name: Optional[str] = None + """The name of the model""" + + owner: Optional[str] = None + """The name of the user or organization that owns the model""" + + paper_url: Optional[str] = None + """A URL for the model's paper""" + + run_count: Optional[int] = None + """The number of times the model has been run""" + + url: Optional[str] = None + """The URL of the model on Replicate""" + + visibility: Optional[Literal["public", "private"]] = None + """Whether the model is public or private""" diff --git a/tests/api_resources/deployments/test_predictions.py b/tests/api_resources/deployments/test_predictions.py index 72ae47d..f103d7d 100644 --- a/tests/api_resources/deployments/test_predictions.py +++ b/tests/api_resources/deployments/test_predictions.py @@ -108,7 +108,9 @@ def test_path_params_create(self, client: Replicate) -> None: class TestAsyncPredictions: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/models/test_examples.py b/tests/api_resources/models/test_examples.py index 3023f9c..8d499cc 100644 --- a/tests/api_resources/models/test_examples.py +++ b/tests/api_resources/models/test_examples.py @@ -69,7 +69,9 @@ def test_path_params_list(self, client: Replicate) -> None: class TestAsyncExamples: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/models/test_predictions.py b/tests/api_resources/models/test_predictions.py index 93e1377..ced4d87 100644 --- a/tests/api_resources/models/test_predictions.py +++ b/tests/api_resources/models/test_predictions.py @@ -108,7 +108,9 @@ def test_path_params_create(self, client: Replicate) -> None: class TestAsyncPredictions: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/models/test_readme.py b/tests/api_resources/models/test_readme.py index 3e34823..007d919 100644 --- a/tests/api_resources/models/test_readme.py +++ b/tests/api_resources/models/test_readme.py @@ -70,7 +70,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncReadme: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py index 4236863..b827c92 100644 --- a/tests/api_resources/models/test_versions.py +++ b/tests/api_resources/models/test_versions.py @@ -197,7 +197,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncVersions: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_account.py b/tests/api_resources/test_account.py index 74e7679..72f892e 100644 --- a/tests/api_resources/test_account.py +++ b/tests/api_resources/test_account.py @@ -47,7 +47,9 @@ def test_streaming_response_get(self, client: Replicate) -> None: class TestAsyncAccount: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_collections.py b/tests/api_resources/test_collections.py index 0fe096b..b3834f8 100644 --- a/tests/api_resources/test_collections.py +++ b/tests/api_resources/test_collections.py @@ -87,7 +87,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncCollections: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py index 39e1a92..25078b7 100644 --- a/tests/api_resources/test_deployments.py +++ b/tests/api_resources/test_deployments.py @@ -271,7 +271,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncDeployments: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_files.py b/tests/api_resources/test_files.py index 02b6d59..32a9305 100644 --- a/tests/api_resources/test_files.py +++ b/tests/api_resources/test_files.py @@ -255,7 +255,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncFiles: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_hardware.py b/tests/api_resources/test_hardware.py index 4ec3027..a04889d 100644 --- a/tests/api_resources/test_hardware.py +++ b/tests/api_resources/test_hardware.py @@ -47,7 +47,9 @@ def test_streaming_response_list(self, client: Replicate) -> None: class TestAsyncHardware: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index 2d1ac8f..9b07d74 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -9,7 +9,11 @@ from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type -from replicate.types import ModelListResponse +from replicate.types import ( + ModelGetResponse, + ModelListResponse, + ModelSearchResponse, +) from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -164,7 +168,7 @@ def test_method_get(self, client: Replicate) -> None: model_owner="model_owner", model_name="model_name", ) - assert model is None + assert_matches_type(ModelGetResponse, model, path=["response"]) @pytest.mark.skip() @parametrize @@ -177,7 +181,7 @@ def test_raw_response_get(self, client: Replicate) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert model is None + assert_matches_type(ModelGetResponse, model, path=["response"]) @pytest.mark.skip() @parametrize @@ -190,7 +194,7 @@ def test_streaming_response_get(self, client: Replicate) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert model is None + assert_matches_type(ModelGetResponse, model, path=["response"]) assert cast(Any, response.is_closed) is True @@ -215,7 +219,7 @@ def test_method_search(self, client: Replicate) -> None: model = client.models.search( body="body", ) - assert model is None + assert_matches_type(SyncCursorURLPage[ModelSearchResponse], model, path=["response"]) @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize @@ -227,7 +231,7 @@ def test_raw_response_search(self, client: Replicate) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert model is None + assert_matches_type(SyncCursorURLPage[ModelSearchResponse], model, path=["response"]) @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize @@ -239,13 +243,15 @@ def test_streaming_response_search(self, client: Replicate) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert model is None + assert_matches_type(SyncCursorURLPage[ModelSearchResponse], model, path=["response"]) assert cast(Any, response.is_closed) is True class TestAsyncModels: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize @@ -393,7 +399,7 @@ async def test_method_get(self, async_client: AsyncReplicate) -> None: model_owner="model_owner", model_name="model_name", ) - assert model is None + assert_matches_type(ModelGetResponse, model, path=["response"]) @pytest.mark.skip() @parametrize @@ -406,7 +412,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert model is None + assert_matches_type(ModelGetResponse, model, path=["response"]) @pytest.mark.skip() @parametrize @@ -419,7 +425,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicate) -> Non assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert model is None + assert_matches_type(ModelGetResponse, model, path=["response"]) assert cast(Any, response.is_closed) is True @@ -444,7 +450,7 @@ async def test_method_search(self, async_client: AsyncReplicate) -> None: model = await async_client.models.search( body="body", ) - assert model is None + assert_matches_type(AsyncCursorURLPage[ModelSearchResponse], model, path=["response"]) @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize @@ -456,7 +462,7 @@ async def test_raw_response_search(self, async_client: AsyncReplicate) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert model is None + assert_matches_type(AsyncCursorURLPage[ModelSearchResponse], model, path=["response"]) @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize @@ -468,6 +474,6 @@ async def test_streaming_response_search(self, async_client: AsyncReplicate) -> assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert model is None + assert_matches_type(AsyncCursorURLPage[ModelSearchResponse], model, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_predictions.py b/tests/api_resources/test_predictions.py index e1839c2..145f378 100644 --- a/tests/api_resources/test_predictions.py +++ b/tests/api_resources/test_predictions.py @@ -192,7 +192,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncPredictions: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py index 3acb09b..3607d52 100644 --- a/tests/api_resources/test_trainings.py +++ b/tests/api_resources/test_trainings.py @@ -227,7 +227,9 @@ def test_path_params_get(self, client: Replicate) -> None: class TestAsyncTrainings: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/api_resources/webhooks/default/test_secret.py b/tests/api_resources/webhooks/default/test_secret.py index 4ccc3d4..047d7d0 100644 --- a/tests/api_resources/webhooks/default/test_secret.py +++ b/tests/api_resources/webhooks/default/test_secret.py @@ -47,7 +47,9 @@ def test_streaming_response_get(self, client: Replicate) -> None: class TestAsyncSecret: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) @pytest.mark.skip() @parametrize diff --git a/tests/conftest.py b/tests/conftest.py index 7e794cd..c9b1543 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,10 +6,12 @@ import logging from typing import TYPE_CHECKING, Iterator, AsyncIterator +import httpx import pytest from pytest_asyncio import is_async_test -from replicate import Replicate, AsyncReplicate +from replicate import Replicate, AsyncReplicate, DefaultAioHttpClient +from replicate._utils import is_dict if TYPE_CHECKING: from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] @@ -27,6 +29,19 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: for async_test in pytest_asyncio_tests: async_test.add_marker(session_scope_marker, append=False) + # We skip tests that use both the aiohttp client and respx_mock as respx_mock + # doesn't support custom transports. + for item in items: + if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames: + continue + + if not hasattr(item, "callspec"): + continue + + async_client_param = item.callspec.params.get("async_client") + if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp": + item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock")) + base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -45,11 +60,25 @@ def client(request: FixtureRequest) -> Iterator[Replicate]: @pytest.fixture(scope="session") async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncReplicate]: - strict = getattr(request, "param", True) - if not isinstance(strict, bool): - raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") + param = getattr(request, "param", True) + + # defaults + strict = True + http_client: None | httpx.AsyncClient = None + + if isinstance(param, bool): + strict = param + elif is_dict(param): + strict = param.get("strict", True) + assert isinstance(strict, bool) + + http_client_type = param.get("http_client", "httpx") + if http_client_type == "aiohttp": + http_client = DefaultAioHttpClient() + else: + raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict") async with AsyncReplicate( - base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict + base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict, http_client=http_client ) as replicate: yield replicate diff --git a/tests/test_client.py b/tests/test_client.py index 2f73b4e..3af3ff8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -194,6 +194,7 @@ def test_copy_signature(self) -> None: copy_param = copy_signature.parameters.get(name) assert copy_param is not None, f"copy() signature is missing the {name} param" + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") def test_copy_build_request(self) -> None: options = FinalRequestOptions(method="get", url="/foo") @@ -1039,6 +1040,7 @@ def test_copy_signature(self) -> None: copy_param = copy_signature.parameters.get(name) assert copy_param is not None, f"copy() signature is missing the {name} param" + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") def test_copy_build_request(self) -> None: options = FinalRequestOptions(method="get", url="/foo")