From 6df1fd6b994373a49b602258a8064998c88a0eca Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 25 Apr 2025 22:46:47 +0000 Subject: [PATCH 1/3] feat(client): add support for model queries --- .stats.yml | 4 +- api.md | 1 + src/replicate/_base_client.py | 28 +++++ src/replicate/_resource.py | 2 + src/replicate/resources/models/models.py | 116 ++++++++++++++++++++- src/replicate/types/__init__.py | 1 + src/replicate/types/model_search_params.py | 12 +++ tests/api_resources/test_models.py | 68 ++++++++++++ 8 files changed, 229 insertions(+), 3 deletions(-) create mode 100644 src/replicate/types/model_search_params.py diff --git a/.stats.yml b/.stats.yml index eb9b16b..9523b41 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 29 +configured_endpoints: 30 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml openapi_spec_hash: 4423bf747e228484547b441468a9f156 -config_hash: f4b37a468a2e67394c9d35f080d37c51 +config_hash: 2e6a171ce57a4a6a8e8dcd3dd893d8cc diff --git a/api.md b/api.md index b223f02..2f6581a 100644 --- a/api.md +++ b/api.md @@ -70,6 +70,7 @@ Methods: - client.models.list() -> SyncCursorURLPage[ModelListResponse] - client.models.delete(model_name, \*, model_owner) -> None - client.models.get(model_name, \*, model_owner) -> None +- client.models.search(\*\*params) -> None ## Examples diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index 84db2c9..fec5e9d 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -1221,6 +1221,20 @@ def post( ) return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) + def query( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + options: RequestOptions = {}, + files: RequestFiles | None = None, + ) -> ResponseT: + opts = FinalRequestOptions.construct( + method="query", url=path, json_data=body, files=to_httpx_files(files), **options + ) + return self.request(cast_to, opts) + def patch( self, path: str, @@ -1709,6 +1723,20 @@ async def post( ) return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls) + async def query( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + options: RequestOptions = {}, + files: RequestFiles | None = None, + ) -> ResponseT: + opts = FinalRequestOptions.construct( + method="query", url=path, json_data=body, files=await async_to_httpx_files(files), **options + ) + return await self.request(cast_to, opts) + async def patch( self, path: str, diff --git a/src/replicate/_resource.py b/src/replicate/_resource.py index a99e0ae..81eab3d 100644 --- a/src/replicate/_resource.py +++ b/src/replicate/_resource.py @@ -22,6 +22,7 @@ def __init__(self, client: ReplicateClient) -> None: self._put = client.put self._delete = client.delete self._get_api_list = client.get_api_list + self._query = client.query def _sleep(self, seconds: float) -> None: time.sleep(seconds) @@ -38,6 +39,7 @@ def __init__(self, client: AsyncReplicateClient) -> None: self._put = client.put self._delete = client.delete self._get_api_list = client.get_api_list + self._query = client.query async def _sleep(self, seconds: float) -> None: await anyio.sleep(seconds) diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index e309f37..fbc48e3 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -14,7 +14,7 @@ ReadmeResourceWithStreamingResponse, AsyncReadmeResourceWithStreamingResponse, ) -from ...types import model_create_params +from ...types import model_create_params, model_search_params from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ..._utils import maybe_transform, async_maybe_transform from .examples import ( @@ -404,6 +404,57 @@ def get( cast_to=NoneType, ) + def search( + self, + *, + body: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Get a list of public models matching a search query. + + Example cURL request: + + ```console + curl -s -X QUERY \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + -H "Content-Type: text/plain" \\ + -d "hello" \\ + https://api.replicate.com/v1/models + ``` + + The response will be a paginated JSON object containing an array of model + objects. + + See the [`models.get`](#models.get) docs for more details about the model + object. + + Args: + body: The search query + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return self._query( + "/models", + body=maybe_transform(body, model_search_params.ModelSearchParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class AsyncModelsResource(AsyncAPIResource): @cached_property @@ -753,6 +804,57 @@ async def get( cast_to=NoneType, ) + async def search( + self, + *, + body: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Get a list of public models matching a search query. + + Example cURL request: + + ```console + curl -s -X QUERY \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + -H "Content-Type: text/plain" \\ + -d "hello" \\ + https://api.replicate.com/v1/models + ``` + + The response will be a paginated JSON object containing an array of model + objects. + + See the [`models.get`](#models.get) docs for more details about the model + object. + + Args: + body: The search query + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return await self._query( + "/models", + body=await async_maybe_transform(body, model_search_params.ModelSearchParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class ModelsResourceWithRawResponse: def __init__(self, models: ModelsResource) -> None: @@ -770,6 +872,9 @@ def __init__(self, models: ModelsResource) -> None: self.get = to_raw_response_wrapper( models.get, ) + self.search = to_raw_response_wrapper( + models.search, + ) @cached_property def examples(self) -> ExamplesResourceWithRawResponse: @@ -804,6 +909,9 @@ def __init__(self, models: AsyncModelsResource) -> None: self.get = async_to_raw_response_wrapper( models.get, ) + self.search = async_to_raw_response_wrapper( + models.search, + ) @cached_property def examples(self) -> AsyncExamplesResourceWithRawResponse: @@ -838,6 +946,9 @@ def __init__(self, models: ModelsResource) -> None: self.get = to_streamed_response_wrapper( models.get, ) + self.search = to_streamed_response_wrapper( + models.search, + ) @cached_property def examples(self) -> ExamplesResourceWithStreamingResponse: @@ -872,6 +983,9 @@ def __init__(self, models: AsyncModelsResource) -> None: self.get = async_to_streamed_response_wrapper( models.get, ) + self.search = async_to_streamed_response_wrapper( + models.search, + ) @cached_property def examples(self) -> AsyncExamplesResourceWithStreamingResponse: diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py index a7c016f..67a3567 100644 --- a/src/replicate/types/__init__.py +++ b/src/replicate/types/__init__.py @@ -6,6 +6,7 @@ from .prediction_output import PredictionOutput as PredictionOutput from .model_create_params import ModelCreateParams as ModelCreateParams from .model_list_response import ModelListResponse as ModelListResponse +from .model_search_params import ModelSearchParams as ModelSearchParams from .account_get_response import AccountGetResponse as AccountGetResponse from .training_get_response import TrainingGetResponse as TrainingGetResponse from .hardware_list_response import HardwareListResponse as HardwareListResponse diff --git a/src/replicate/types/model_search_params.py b/src/replicate/types/model_search_params.py new file mode 100644 index 0000000..233d04c --- /dev/null +++ b/src/replicate/types/model_search_params.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["ModelSearchParams"] + + +class ModelSearchParams(TypedDict, total=False): + body: Required[str] + """The search query""" diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index 1188e65..8c6ad82 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -209,6 +209,40 @@ def test_path_params_get(self, client: ReplicateClient) -> None: model_owner="model_owner", ) + @pytest.mark.skip(reason="Prism doesn't support query methods yet") + @parametrize + def test_method_search(self, client: ReplicateClient) -> None: + model = client.models.search( + body="body", + ) + assert model is None + + @pytest.mark.skip(reason="Prism doesn't support query methods yet") + @parametrize + def test_raw_response_search(self, client: ReplicateClient) -> None: + response = client.models.with_raw_response.search( + body="body", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert model is None + + @pytest.mark.skip(reason="Prism doesn't support query methods yet") + @parametrize + def test_streaming_response_search(self, client: ReplicateClient) -> None: + with client.models.with_streaming_response.search( + body="body", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = response.parse() + assert model is None + + assert cast(Any, response.is_closed) is True + class TestAsyncModels: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -403,3 +437,37 @@ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None model_name="", model_owner="model_owner", ) + + @pytest.mark.skip(reason="Prism doesn't support query methods yet") + @parametrize + async def test_method_search(self, async_client: AsyncReplicateClient) -> None: + model = await async_client.models.search( + body="body", + ) + assert model is None + + @pytest.mark.skip(reason="Prism doesn't support query methods yet") + @parametrize + async def test_raw_response_search(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.models.with_raw_response.search( + body="body", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert model is None + + @pytest.mark.skip(reason="Prism doesn't support query methods yet") + @parametrize + async def test_streaming_response_search(self, async_client: AsyncReplicateClient) -> None: + async with async_client.models.with_streaming_response.search( + body="body", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = await response.parse() + assert model is None + + assert cast(Any, response.is_closed) is True From a5aa64a71517fdf74e61e4debe68fba458f2e380 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:15:43 +0000 Subject: [PATCH 2/3] feat(api): api update --- .stats.yml | 4 ++-- src/replicate/resources/predictions.py | 22 +++++++++++++++---- .../types/prediction_create_params.py | 11 +++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/.stats.yml b/.stats.yml index 9523b41..7e839de 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 30 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml -openapi_spec_hash: 4423bf747e228484547b441468a9f156 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-0d7d82bff8a18b03e0cd1cbf8609c3026bb07db851bc6f9166032045a9925eea.yml +openapi_spec_hash: 8ce211dfa6fece24b1413e91ba55210a config_hash: 2e6a171ce57a4a6a8e8dcd3dd893d8cc diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py index abe0ca1..6388e08 100644 --- a/src/replicate/resources/predictions.py +++ b/src/replicate/resources/predictions.py @@ -73,7 +73,7 @@ def create( ```console curl -s -X POST -H 'Prefer: wait' \\ - -d '{"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\ + -d '{"version": "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ -H 'Content-Type: application/json' \\ https://api.replicate.com/v1/predictions @@ -110,7 +110,14 @@ def create( - you don't want to upload and host the file somewhere - you don't need to use the file again (Replicate will not store it) - version: The ID of the model version that you want to run. + version: The ID of the model version that you want to run. This can be specified in two + formats: + + 1. Just the 64-character version ID: + `9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426` + 2. Full model identifier with version ID in the format `{owner}/{model}:{id}`. + For example, + `replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426` stream: **This field is deprecated.** @@ -484,7 +491,7 @@ async def create( ```console curl -s -X POST -H 'Prefer: wait' \\ - -d '{"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\ + -d '{"version": "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ -H 'Content-Type: application/json' \\ https://api.replicate.com/v1/predictions @@ -521,7 +528,14 @@ async def create( - you don't want to upload and host the file somewhere - you don't need to use the file again (Replicate will not store it) - version: The ID of the model version that you want to run. + version: The ID of the model version that you want to run. This can be specified in two + formats: + + 1. Just the 64-character version ID: + `9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426` + 2. Full model identifier with version ID in the format `{owner}/{model}:{id}`. + For example, + `replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426` stream: **This field is deprecated.** diff --git a/src/replicate/types/prediction_create_params.py b/src/replicate/types/prediction_create_params.py index e5ab7d0..4e3026f 100644 --- a/src/replicate/types/prediction_create_params.py +++ b/src/replicate/types/prediction_create_params.py @@ -37,7 +37,16 @@ class PredictionCreateParams(TypedDict, total=False): """ version: Required[str] - """The ID of the model version that you want to run.""" + """The ID of the model version that you want to run. + + This can be specified in two formats: + + 1. Just the 64-character version ID: + `9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426` + 2. Full model identifier with version ID in the format `{owner}/{model}:{id}`. + For example, + `replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426` + """ stream: bool """**This field is deprecated.** From b93d8b2b03f408dc77e3bb77f3fb9fa1ebff07cd Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:16:17 +0000 Subject: [PATCH 3/3] release: 0.1.0-alpha.8 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 9 +++++++++ pyproject.toml | 2 +- src/replicate/_version.py | 2 +- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b5db7ce..c373724 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.7" + ".": "0.1.0-alpha.8" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index b44d0e0..2fced87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## 0.1.0-alpha.8 (2025-04-30) + +Full Changelog: [v0.1.0-alpha.7...v0.1.0-alpha.8](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.7...v0.1.0-alpha.8) + +### Features + +* **api:** api update ([a5aa64a](https://github.com/replicate/replicate-python-stainless/commit/a5aa64a71517fdf74e61e4debe68fba458f2e380)) +* **client:** add support for model queries ([6df1fd6](https://github.com/replicate/replicate-python-stainless/commit/6df1fd6b994373a49b602258a8064998c88a0eca)) + ## 0.1.0-alpha.7 (2025-04-24) Full Changelog: [v0.1.0-alpha.6...v0.1.0-alpha.7](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.6...v0.1.0-alpha.7) diff --git a/pyproject.toml b/pyproject.toml index dea9534..be725ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "replicate-stainless" -version = "0.1.0-alpha.7" +version = "0.1.0-alpha.8" description = "The official Python library for the replicate-client API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/replicate/_version.py b/src/replicate/_version.py index eecdc4a..514d196 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "0.1.0-alpha.7" # x-release-please-version +__version__ = "0.1.0-alpha.8" # x-release-please-version