From 6df1fd6b994373a49b602258a8064998c88a0eca Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Fri, 25 Apr 2025 22:46:47 +0000
Subject: [PATCH 1/3] feat(client): add support for model queries
---
.stats.yml | 4 +-
api.md | 1 +
src/replicate/_base_client.py | 28 +++++
src/replicate/_resource.py | 2 +
src/replicate/resources/models/models.py | 116 ++++++++++++++++++++-
src/replicate/types/__init__.py | 1 +
src/replicate/types/model_search_params.py | 12 +++
tests/api_resources/test_models.py | 68 ++++++++++++
8 files changed, 229 insertions(+), 3 deletions(-)
create mode 100644 src/replicate/types/model_search_params.py
diff --git a/.stats.yml b/.stats.yml
index eb9b16b..9523b41 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
-configured_endpoints: 29
+configured_endpoints: 30
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
openapi_spec_hash: 4423bf747e228484547b441468a9f156
-config_hash: f4b37a468a2e67394c9d35f080d37c51
+config_hash: 2e6a171ce57a4a6a8e8dcd3dd893d8cc
diff --git a/api.md b/api.md
index b223f02..2f6581a 100644
--- a/api.md
+++ b/api.md
@@ -70,6 +70,7 @@ Methods:
- client.models.list() -> SyncCursorURLPage[ModelListResponse]
- client.models.delete(model_name, \*, model_owner) -> None
- client.models.get(model_name, \*, model_owner) -> None
+- client.models.search(\*\*params) -> None
## Examples
diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py
index 84db2c9..fec5e9d 100644
--- a/src/replicate/_base_client.py
+++ b/src/replicate/_base_client.py
@@ -1221,6 +1221,20 @@ def post(
)
return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))
+ def query(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ files: RequestFiles | None = None,
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(
+ method="query", url=path, json_data=body, files=to_httpx_files(files), **options
+ )
+ return self.request(cast_to, opts)
+
def patch(
self,
path: str,
@@ -1709,6 +1723,20 @@ async def post(
)
return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
+ async def query(
+ self,
+ path: str,
+ *,
+ cast_to: Type[ResponseT],
+ body: Body | None = None,
+ options: RequestOptions = {},
+ files: RequestFiles | None = None,
+ ) -> ResponseT:
+ opts = FinalRequestOptions.construct(
+ method="query", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+ )
+ return await self.request(cast_to, opts)
+
async def patch(
self,
path: str,
diff --git a/src/replicate/_resource.py b/src/replicate/_resource.py
index a99e0ae..81eab3d 100644
--- a/src/replicate/_resource.py
+++ b/src/replicate/_resource.py
@@ -22,6 +22,7 @@ def __init__(self, client: ReplicateClient) -> None:
self._put = client.put
self._delete = client.delete
self._get_api_list = client.get_api_list
+ self._query = client.query
def _sleep(self, seconds: float) -> None:
time.sleep(seconds)
@@ -38,6 +39,7 @@ def __init__(self, client: AsyncReplicateClient) -> None:
self._put = client.put
self._delete = client.delete
self._get_api_list = client.get_api_list
+ self._query = client.query
async def _sleep(self, seconds: float) -> None:
await anyio.sleep(seconds)
diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py
index e309f37..fbc48e3 100644
--- a/src/replicate/resources/models/models.py
+++ b/src/replicate/resources/models/models.py
@@ -14,7 +14,7 @@
ReadmeResourceWithStreamingResponse,
AsyncReadmeResourceWithStreamingResponse,
)
-from ...types import model_create_params
+from ...types import model_create_params, model_search_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ..._utils import maybe_transform, async_maybe_transform
from .examples import (
@@ -404,6 +404,57 @@ def get(
cast_to=NoneType,
)
+ def search(
+ self,
+ *,
+ body: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Get a list of public models matching a search query.
+
+ Example cURL request:
+
+ ```console
+ curl -s -X QUERY \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ -H "Content-Type: text/plain" \\
+ -d "hello" \\
+ https://api.replicate.com/v1/models
+ ```
+
+ The response will be a paginated JSON object containing an array of model
+ objects.
+
+ See the [`models.get`](#models.get) docs for more details about the model
+ object.
+
+ Args:
+ body: The search query
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return self._query(
+ "/models",
+ body=maybe_transform(body, model_search_params.ModelSearchParams),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class AsyncModelsResource(AsyncAPIResource):
@cached_property
@@ -753,6 +804,57 @@ async def get(
cast_to=NoneType,
)
+ async def search(
+ self,
+ *,
+ body: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Get a list of public models matching a search query.
+
+ Example cURL request:
+
+ ```console
+ curl -s -X QUERY \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ -H "Content-Type: text/plain" \\
+ -d "hello" \\
+ https://api.replicate.com/v1/models
+ ```
+
+ The response will be a paginated JSON object containing an array of model
+ objects.
+
+ See the [`models.get`](#models.get) docs for more details about the model
+ object.
+
+ Args:
+ body: The search query
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return await self._query(
+ "/models",
+ body=await async_maybe_transform(body, model_search_params.ModelSearchParams),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class ModelsResourceWithRawResponse:
def __init__(self, models: ModelsResource) -> None:
@@ -770,6 +872,9 @@ def __init__(self, models: ModelsResource) -> None:
self.get = to_raw_response_wrapper(
models.get,
)
+ self.search = to_raw_response_wrapper(
+ models.search,
+ )
@cached_property
def examples(self) -> ExamplesResourceWithRawResponse:
@@ -804,6 +909,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.get = async_to_raw_response_wrapper(
models.get,
)
+ self.search = async_to_raw_response_wrapper(
+ models.search,
+ )
@cached_property
def examples(self) -> AsyncExamplesResourceWithRawResponse:
@@ -838,6 +946,9 @@ def __init__(self, models: ModelsResource) -> None:
self.get = to_streamed_response_wrapper(
models.get,
)
+ self.search = to_streamed_response_wrapper(
+ models.search,
+ )
@cached_property
def examples(self) -> ExamplesResourceWithStreamingResponse:
@@ -872,6 +983,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.get = async_to_streamed_response_wrapper(
models.get,
)
+ self.search = async_to_streamed_response_wrapper(
+ models.search,
+ )
@cached_property
def examples(self) -> AsyncExamplesResourceWithStreamingResponse:
diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py
index a7c016f..67a3567 100644
--- a/src/replicate/types/__init__.py
+++ b/src/replicate/types/__init__.py
@@ -6,6 +6,7 @@
from .prediction_output import PredictionOutput as PredictionOutput
from .model_create_params import ModelCreateParams as ModelCreateParams
from .model_list_response import ModelListResponse as ModelListResponse
+from .model_search_params import ModelSearchParams as ModelSearchParams
from .account_get_response import AccountGetResponse as AccountGetResponse
from .training_get_response import TrainingGetResponse as TrainingGetResponse
from .hardware_list_response import HardwareListResponse as HardwareListResponse
diff --git a/src/replicate/types/model_search_params.py b/src/replicate/types/model_search_params.py
new file mode 100644
index 0000000..233d04c
--- /dev/null
+++ b/src/replicate/types/model_search_params.py
@@ -0,0 +1,12 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, TypedDict
+
+__all__ = ["ModelSearchParams"]
+
+
+class ModelSearchParams(TypedDict, total=False):
+ body: Required[str]
+ """The search query"""
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index 1188e65..8c6ad82 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -209,6 +209,40 @@ def test_path_params_get(self, client: ReplicateClient) -> None:
model_owner="model_owner",
)
+ @pytest.mark.skip(reason="Prism doesn't support query methods yet")
+ @parametrize
+ def test_method_search(self, client: ReplicateClient) -> None:
+ model = client.models.search(
+ body="body",
+ )
+ assert model is None
+
+ @pytest.mark.skip(reason="Prism doesn't support query methods yet")
+ @parametrize
+ def test_raw_response_search(self, client: ReplicateClient) -> None:
+ response = client.models.with_raw_response.search(
+ body="body",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert model is None
+
+ @pytest.mark.skip(reason="Prism doesn't support query methods yet")
+ @parametrize
+ def test_streaming_response_search(self, client: ReplicateClient) -> None:
+ with client.models.with_streaming_response.search(
+ body="body",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = response.parse()
+ assert model is None
+
+ assert cast(Any, response.is_closed) is True
+
class TestAsyncModels:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -403,3 +437,37 @@ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None
model_name="",
model_owner="model_owner",
)
+
+ @pytest.mark.skip(reason="Prism doesn't support query methods yet")
+ @parametrize
+ async def test_method_search(self, async_client: AsyncReplicateClient) -> None:
+ model = await async_client.models.search(
+ body="body",
+ )
+ assert model is None
+
+ @pytest.mark.skip(reason="Prism doesn't support query methods yet")
+ @parametrize
+ async def test_raw_response_search(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.models.with_raw_response.search(
+ body="body",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = await response.parse()
+ assert model is None
+
+ @pytest.mark.skip(reason="Prism doesn't support query methods yet")
+ @parametrize
+ async def test_streaming_response_search(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.models.with_streaming_response.search(
+ body="body",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = await response.parse()
+ assert model is None
+
+ assert cast(Any, response.is_closed) is True
From a5aa64a71517fdf74e61e4debe68fba458f2e380 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 30 Apr 2025 14:15:43 +0000
Subject: [PATCH 2/3] feat(api): api update
---
.stats.yml | 4 ++--
src/replicate/resources/predictions.py | 22 +++++++++++++++----
.../types/prediction_create_params.py | 11 +++++++++-
3 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/.stats.yml b/.stats.yml
index 9523b41..7e839de 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 30
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
-openapi_spec_hash: 4423bf747e228484547b441468a9f156
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-0d7d82bff8a18b03e0cd1cbf8609c3026bb07db851bc6f9166032045a9925eea.yml
+openapi_spec_hash: 8ce211dfa6fece24b1413e91ba55210a
config_hash: 2e6a171ce57a4a6a8e8dcd3dd893d8cc
diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py
index abe0ca1..6388e08 100644
--- a/src/replicate/resources/predictions.py
+++ b/src/replicate/resources/predictions.py
@@ -73,7 +73,7 @@ def create(
```console
curl -s -X POST -H 'Prefer: wait' \\
- -d '{"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\
+ -d '{"version": "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\
-H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
-H 'Content-Type: application/json' \\
https://api.replicate.com/v1/predictions
@@ -110,7 +110,14 @@ def create(
- you don't want to upload and host the file somewhere
- you don't need to use the file again (Replicate will not store it)
- version: The ID of the model version that you want to run.
+ version: The ID of the model version that you want to run. This can be specified in two
+ formats:
+
+ 1. Just the 64-character version ID:
+ `9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426`
+ 2. Full model identifier with version ID in the format `{owner}/{model}:{id}`.
+ For example,
+ `replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426`
stream: **This field is deprecated.**
@@ -484,7 +491,7 @@ async def create(
```console
curl -s -X POST -H 'Prefer: wait' \\
- -d '{"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\
+ -d '{"version": "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", "input": {"text": "Alice"}}' \\
-H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
-H 'Content-Type: application/json' \\
https://api.replicate.com/v1/predictions
@@ -521,7 +528,14 @@ async def create(
- you don't want to upload and host the file somewhere
- you don't need to use the file again (Replicate will not store it)
- version: The ID of the model version that you want to run.
+ version: The ID of the model version that you want to run. This can be specified in two
+ formats:
+
+ 1. Just the 64-character version ID:
+ `9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426`
+ 2. Full model identifier with version ID in the format `{owner}/{model}:{id}`.
+ For example,
+ `replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426`
stream: **This field is deprecated.**
diff --git a/src/replicate/types/prediction_create_params.py b/src/replicate/types/prediction_create_params.py
index e5ab7d0..4e3026f 100644
--- a/src/replicate/types/prediction_create_params.py
+++ b/src/replicate/types/prediction_create_params.py
@@ -37,7 +37,16 @@ class PredictionCreateParams(TypedDict, total=False):
"""
version: Required[str]
- """The ID of the model version that you want to run."""
+ """The ID of the model version that you want to run.
+
+ This can be specified in two formats:
+
+ 1. Just the 64-character version ID:
+ `9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426`
+ 2. Full model identifier with version ID in the format `{owner}/{model}:{id}`.
+ For example,
+ `replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426`
+ """
stream: bool
"""**This field is deprecated.**
From b93d8b2b03f408dc77e3bb77f3fb9fa1ebff07cd Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 30 Apr 2025 14:16:17 +0000
Subject: [PATCH 3/3] release: 0.1.0-alpha.8
---
.release-please-manifest.json | 2 +-
CHANGELOG.md | 9 +++++++++
pyproject.toml | 2 +-
src/replicate/_version.py | 2 +-
4 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index b5db7ce..c373724 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.1.0-alpha.7"
+ ".": "0.1.0-alpha.8"
}
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b44d0e0..2fced87 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,14 @@
# Changelog
+## 0.1.0-alpha.8 (2025-04-30)
+
+Full Changelog: [v0.1.0-alpha.7...v0.1.0-alpha.8](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.7...v0.1.0-alpha.8)
+
+### Features
+
+* **api:** api update ([a5aa64a](https://github.com/replicate/replicate-python-stainless/commit/a5aa64a71517fdf74e61e4debe68fba458f2e380))
+* **client:** add support for model queries ([6df1fd6](https://github.com/replicate/replicate-python-stainless/commit/6df1fd6b994373a49b602258a8064998c88a0eca))
+
## 0.1.0-alpha.7 (2025-04-24)
Full Changelog: [v0.1.0-alpha.6...v0.1.0-alpha.7](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.6...v0.1.0-alpha.7)
diff --git a/pyproject.toml b/pyproject.toml
index dea9534..be725ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "replicate-stainless"
-version = "0.1.0-alpha.7"
+version = "0.1.0-alpha.8"
description = "The official Python library for the replicate-client API"
dynamic = ["readme"]
license = "Apache-2.0"
diff --git a/src/replicate/_version.py b/src/replicate/_version.py
index eecdc4a..514d196 100644
--- a/src/replicate/_version.py
+++ b/src/replicate/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "replicate"
-__version__ = "0.1.0-alpha.7" # x-release-please-version
+__version__ = "0.1.0-alpha.8" # x-release-please-version