Skip to content

Commit 6df1fd6

Browse files
feat(client): add support for model queries
1 parent 35fe4c5 commit 6df1fd6

File tree

8 files changed

+229
-3
lines changed

8 files changed

+229
-3
lines changed

.stats.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
configured_endpoints: 29
1+
configured_endpoints: 30
22
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
33
openapi_spec_hash: 4423bf747e228484547b441468a9f156
4-
config_hash: f4b37a468a2e67394c9d35f080d37c51
4+
config_hash: 2e6a171ce57a4a6a8e8dcd3dd893d8cc

api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Methods:
7070
- <code title="get /models">client.models.<a href="./src/replicate/resources/models/models.py">list</a>() -> <a href="./src/replicate/types/model_list_response.py">SyncCursorURLPage[ModelListResponse]</a></code>
7171
- <code title="delete /models/{model_owner}/{model_name}">client.models.<a href="./src/replicate/resources/models/models.py">delete</a>(model_name, \*, model_owner) -> None</code>
7272
- <code title="get /models/{model_owner}/{model_name}">client.models.<a href="./src/replicate/resources/models/models.py">get</a>(model_name, \*, model_owner) -> None</code>
73+
- <code title="query /models">client.models.<a href="./src/replicate/resources/models/models.py">search</a>(\*\*<a href="src/replicate/types/model_search_params.py">params</a>) -> None</code>
7374

7475
## Examples
7576

src/replicate/_base_client.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,20 @@ def post(
12211221
)
12221222
return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls))
12231223

1224+
def query(
1225+
self,
1226+
path: str,
1227+
*,
1228+
cast_to: Type[ResponseT],
1229+
body: Body | None = None,
1230+
options: RequestOptions = {},
1231+
files: RequestFiles | None = None,
1232+
) -> ResponseT:
1233+
opts = FinalRequestOptions.construct(
1234+
method="query", url=path, json_data=body, files=to_httpx_files(files), **options
1235+
)
1236+
return self.request(cast_to, opts)
1237+
12241238
def patch(
12251239
self,
12261240
path: str,
@@ -1709,6 +1723,20 @@ async def post(
17091723
)
17101724
return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)
17111725

1726+
async def query(
1727+
self,
1728+
path: str,
1729+
*,
1730+
cast_to: Type[ResponseT],
1731+
body: Body | None = None,
1732+
options: RequestOptions = {},
1733+
files: RequestFiles | None = None,
1734+
) -> ResponseT:
1735+
opts = FinalRequestOptions.construct(
1736+
method="query", url=path, json_data=body, files=await async_to_httpx_files(files), **options
1737+
)
1738+
return await self.request(cast_to, opts)
1739+
17121740
async def patch(
17131741
self,
17141742
path: str,

src/replicate/_resource.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self, client: ReplicateClient) -> None:
2222
self._put = client.put
2323
self._delete = client.delete
2424
self._get_api_list = client.get_api_list
25+
self._query = client.query
2526

2627
def _sleep(self, seconds: float) -> None:
2728
time.sleep(seconds)
@@ -38,6 +39,7 @@ def __init__(self, client: AsyncReplicateClient) -> None:
3839
self._put = client.put
3940
self._delete = client.delete
4041
self._get_api_list = client.get_api_list
42+
self._query = client.query
4143

4244
async def _sleep(self, seconds: float) -> None:
4345
await anyio.sleep(seconds)

src/replicate/resources/models/models.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
ReadmeResourceWithStreamingResponse,
1515
AsyncReadmeResourceWithStreamingResponse,
1616
)
17-
from ...types import model_create_params
17+
from ...types import model_create_params, model_search_params
1818
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
1919
from ..._utils import maybe_transform, async_maybe_transform
2020
from .examples import (
@@ -404,6 +404,57 @@ def get(
404404
cast_to=NoneType,
405405
)
406406

407+
def search(
408+
self,
409+
*,
410+
body: str,
411+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
412+
# The extra values given here take precedence over values defined on the client or passed to this method.
413+
extra_headers: Headers | None = None,
414+
extra_query: Query | None = None,
415+
extra_body: Body | None = None,
416+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
417+
) -> None:
418+
"""
419+
Get a list of public models matching a search query.
420+
421+
Example cURL request:
422+
423+
```console
424+
curl -s -X QUERY \\
425+
-H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
426+
-H "Content-Type: text/plain" \\
427+
-d "hello" \\
428+
https://api.replicate.com/v1/models
429+
```
430+
431+
The response will be a paginated JSON object containing an array of model
432+
objects.
433+
434+
See the [`models.get`](#models.get) docs for more details about the model
435+
object.
436+
437+
Args:
438+
body: The search query
439+
440+
extra_headers: Send extra headers
441+
442+
extra_query: Add additional query parameters to the request
443+
444+
extra_body: Add additional JSON properties to the request
445+
446+
timeout: Override the client-level default timeout for this request, in seconds
447+
"""
448+
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
449+
return self._query(
450+
"/models",
451+
body=maybe_transform(body, model_search_params.ModelSearchParams),
452+
options=make_request_options(
453+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
454+
),
455+
cast_to=NoneType,
456+
)
457+
407458

408459
class AsyncModelsResource(AsyncAPIResource):
409460
@cached_property
@@ -753,6 +804,57 @@ async def get(
753804
cast_to=NoneType,
754805
)
755806

807+
async def search(
808+
self,
809+
*,
810+
body: str,
811+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
812+
# The extra values given here take precedence over values defined on the client or passed to this method.
813+
extra_headers: Headers | None = None,
814+
extra_query: Query | None = None,
815+
extra_body: Body | None = None,
816+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
817+
) -> None:
818+
"""
819+
Get a list of public models matching a search query.
820+
821+
Example cURL request:
822+
823+
```console
824+
curl -s -X QUERY \\
825+
-H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
826+
-H "Content-Type: text/plain" \\
827+
-d "hello" \\
828+
https://api.replicate.com/v1/models
829+
```
830+
831+
The response will be a paginated JSON object containing an array of model
832+
objects.
833+
834+
See the [`models.get`](#models.get) docs for more details about the model
835+
object.
836+
837+
Args:
838+
body: The search query
839+
840+
extra_headers: Send extra headers
841+
842+
extra_query: Add additional query parameters to the request
843+
844+
extra_body: Add additional JSON properties to the request
845+
846+
timeout: Override the client-level default timeout for this request, in seconds
847+
"""
848+
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
849+
return await self._query(
850+
"/models",
851+
body=await async_maybe_transform(body, model_search_params.ModelSearchParams),
852+
options=make_request_options(
853+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
854+
),
855+
cast_to=NoneType,
856+
)
857+
756858

757859
class ModelsResourceWithRawResponse:
758860
def __init__(self, models: ModelsResource) -> None:
@@ -770,6 +872,9 @@ def __init__(self, models: ModelsResource) -> None:
770872
self.get = to_raw_response_wrapper(
771873
models.get,
772874
)
875+
self.search = to_raw_response_wrapper(
876+
models.search,
877+
)
773878

774879
@cached_property
775880
def examples(self) -> ExamplesResourceWithRawResponse:
@@ -804,6 +909,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
804909
self.get = async_to_raw_response_wrapper(
805910
models.get,
806911
)
912+
self.search = async_to_raw_response_wrapper(
913+
models.search,
914+
)
807915

808916
@cached_property
809917
def examples(self) -> AsyncExamplesResourceWithRawResponse:
@@ -838,6 +946,9 @@ def __init__(self, models: ModelsResource) -> None:
838946
self.get = to_streamed_response_wrapper(
839947
models.get,
840948
)
949+
self.search = to_streamed_response_wrapper(
950+
models.search,
951+
)
841952

842953
@cached_property
843954
def examples(self) -> ExamplesResourceWithStreamingResponse:
@@ -872,6 +983,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
872983
self.get = async_to_streamed_response_wrapper(
873984
models.get,
874985
)
986+
self.search = async_to_streamed_response_wrapper(
987+
models.search,
988+
)
875989

876990
@cached_property
877991
def examples(self) -> AsyncExamplesResourceWithStreamingResponse:

src/replicate/types/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .prediction_output import PredictionOutput as PredictionOutput
77
from .model_create_params import ModelCreateParams as ModelCreateParams
88
from .model_list_response import ModelListResponse as ModelListResponse
9+
from .model_search_params import ModelSearchParams as ModelSearchParams
910
from .account_get_response import AccountGetResponse as AccountGetResponse
1011
from .training_get_response import TrainingGetResponse as TrainingGetResponse
1112
from .hardware_list_response import HardwareListResponse as HardwareListResponse
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
from typing_extensions import Required, TypedDict
6+
7+
__all__ = ["ModelSearchParams"]
8+
9+
10+
class ModelSearchParams(TypedDict, total=False):
11+
body: Required[str]
12+
"""The search query"""

tests/api_resources/test_models.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,40 @@ def test_path_params_get(self, client: ReplicateClient) -> None:
209209
model_owner="model_owner",
210210
)
211211

212+
@pytest.mark.skip(reason="Prism doesn't support query methods yet")
213+
@parametrize
214+
def test_method_search(self, client: ReplicateClient) -> None:
215+
model = client.models.search(
216+
body="body",
217+
)
218+
assert model is None
219+
220+
@pytest.mark.skip(reason="Prism doesn't support query methods yet")
221+
@parametrize
222+
def test_raw_response_search(self, client: ReplicateClient) -> None:
223+
response = client.models.with_raw_response.search(
224+
body="body",
225+
)
226+
227+
assert response.is_closed is True
228+
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
229+
model = response.parse()
230+
assert model is None
231+
232+
@pytest.mark.skip(reason="Prism doesn't support query methods yet")
233+
@parametrize
234+
def test_streaming_response_search(self, client: ReplicateClient) -> None:
235+
with client.models.with_streaming_response.search(
236+
body="body",
237+
) as response:
238+
assert not response.is_closed
239+
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
240+
241+
model = response.parse()
242+
assert model is None
243+
244+
assert cast(Any, response.is_closed) is True
245+
212246

213247
class TestAsyncModels:
214248
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -403,3 +437,37 @@ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None
403437
model_name="",
404438
model_owner="model_owner",
405439
)
440+
441+
@pytest.mark.skip(reason="Prism doesn't support query methods yet")
442+
@parametrize
443+
async def test_method_search(self, async_client: AsyncReplicateClient) -> None:
444+
model = await async_client.models.search(
445+
body="body",
446+
)
447+
assert model is None
448+
449+
@pytest.mark.skip(reason="Prism doesn't support query methods yet")
450+
@parametrize
451+
async def test_raw_response_search(self, async_client: AsyncReplicateClient) -> None:
452+
response = await async_client.models.with_raw_response.search(
453+
body="body",
454+
)
455+
456+
assert response.is_closed is True
457+
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
458+
model = await response.parse()
459+
assert model is None
460+
461+
@pytest.mark.skip(reason="Prism doesn't support query methods yet")
462+
@parametrize
463+
async def test_streaming_response_search(self, async_client: AsyncReplicateClient) -> None:
464+
async with async_client.models.with_streaming_response.search(
465+
body="body",
466+
) as response:
467+
assert not response.is_closed
468+
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
469+
470+
model = await response.parse()
471+
assert model is None
472+
473+
assert cast(Any, response.is_closed) is True

0 commit comments

Comments
 (0)