Skip to content

Commit ffa306e

Browse files
feat(api): update via SDK Studio (#16)
1 parent 31aa7ed commit ffa306e

File tree

6 files changed

+75
-43
lines changed

6 files changed

+75
-43
lines changed

api.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,10 @@ Methods:
8080

8181
# Predictions
8282

83-
Types:
84-
85-
```python
86-
from replicate.types import PredictionListResponse
87-
```
88-
8983
Methods:
9084

9185
- <code title="post /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">create</a>(\*\*<a href="src/replicate/types/prediction_create_params.py">params</a>) -> <a href="./src/replicate/types/prediction_response.py">PredictionResponse</a></code>
92-
- <code title="get /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">list</a>(\*\*<a href="src/replicate/types/prediction_list_params.py">params</a>) -> <a href="./src/replicate/types/prediction_list_response.py">PredictionListResponse</a></code>
86+
- <code title="get /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">list</a>(\*\*<a href="src/replicate/types/prediction_list_params.py">params</a>) -> <a href="./src/replicate/types/prediction_response.py">SyncCursorURLPage[PredictionResponse]</a></code>
9387
- <code title="post /predictions/{prediction_id}/cancel">client.predictions.<a href="./src/replicate/resources/predictions.py">cancel</a>(prediction_id) -> None</code>
9488
- <code title="get /predictions/{prediction_id}">client.predictions.<a href="./src/replicate/resources/predictions.py">list_by_id</a>(prediction_id) -> <a href="./src/replicate/types/prediction_response.py">PredictionResponse</a></code>
9589

src/replicate/pagination.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from typing import List, Generic, TypeVar, Optional
4+
from typing_extensions import override
5+
6+
import httpx
7+
8+
from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage
9+
10+
__all__ = ["SyncCursorURLPage", "AsyncCursorURLPage"]
11+
12+
_T = TypeVar("_T")
13+
14+
15+
class SyncCursorURLPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
16+
results: List[_T]
17+
next: Optional[str] = None
18+
19+
@override
20+
def _get_page_items(self) -> List[_T]:
21+
results = self.results
22+
if not results:
23+
return []
24+
return results
25+
26+
@override
27+
def next_page_info(self) -> Optional[PageInfo]:
28+
url = self.next
29+
if url is None:
30+
return None
31+
32+
return PageInfo(url=httpx.URL(url))
33+
34+
35+
class AsyncCursorURLPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
36+
results: List[_T]
37+
next: Optional[str] = None
38+
39+
@override
40+
def _get_page_items(self) -> List[_T]:
41+
results = self.results
42+
if not results:
43+
return []
44+
return results
45+
46+
@override
47+
def next_page_info(self) -> Optional[PageInfo]:
48+
url = self.next
49+
if url is None:
50+
return None
51+
52+
return PageInfo(url=httpx.URL(url))

src/replicate/resources/predictions.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
async_to_raw_response_wrapper,
2424
async_to_streamed_response_wrapper,
2525
)
26-
from .._base_client import make_request_options
26+
from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
27+
from .._base_client import AsyncPaginator, make_request_options
2728
from ..types.prediction_response import PredictionResponse
28-
from ..types.prediction_list_response import PredictionListResponse
2929

3030
__all__ = ["PredictionsResource", "AsyncPredictionsResource"]
3131

@@ -200,7 +200,7 @@ def list(
200200
extra_query: Query | None = None,
201201
extra_body: Body | None = None,
202202
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
203-
) -> PredictionListResponse:
203+
) -> SyncCursorURLPage[PredictionResponse]:
204204
"""
205205
Get a paginated list of all predictions created by the user or organization
206206
associated with the provided API token.
@@ -285,8 +285,9 @@ def list(
285285
286286
timeout: Override the client-level default timeout for this request, in seconds
287287
"""
288-
return self._get(
288+
return self._get_api_list(
289289
"/predictions",
290+
page=SyncCursorURLPage[PredictionResponse],
290291
options=make_request_options(
291292
extra_headers=extra_headers,
292293
extra_query=extra_query,
@@ -300,7 +301,7 @@ def list(
300301
prediction_list_params.PredictionListParams,
301302
),
302303
),
303-
cast_to=PredictionListResponse,
304+
model=PredictionResponse,
304305
)
305306

306307
def cancel(
@@ -599,7 +600,7 @@ async def create(
599600
cast_to=PredictionResponse,
600601
)
601602

602-
async def list(
603+
def list(
603604
self,
604605
*,
605606
created_after: Union[str, datetime] | NotGiven = NOT_GIVEN,
@@ -610,7 +611,7 @@ async def list(
610611
extra_query: Query | None = None,
611612
extra_body: Body | None = None,
612613
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
613-
) -> PredictionListResponse:
614+
) -> AsyncPaginator[PredictionResponse, AsyncCursorURLPage[PredictionResponse]]:
614615
"""
615616
Get a paginated list of all predictions created by the user or organization
616617
associated with the provided API token.
@@ -695,22 +696,23 @@ async def list(
695696
696697
timeout: Override the client-level default timeout for this request, in seconds
697698
"""
698-
return await self._get(
699+
return self._get_api_list(
699700
"/predictions",
701+
page=AsyncCursorURLPage[PredictionResponse],
700702
options=make_request_options(
701703
extra_headers=extra_headers,
702704
extra_query=extra_query,
703705
extra_body=extra_body,
704706
timeout=timeout,
705-
query=await async_maybe_transform(
707+
query=maybe_transform(
706708
{
707709
"created_after": created_after,
708710
"created_before": created_before,
709711
},
710712
prediction_list_params.PredictionListParams,
711713
),
712714
),
713-
cast_to=PredictionListResponse,
715+
model=PredictionResponse,
714716
)
715717

716718
async def cancel(

src/replicate/types/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
1212
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
1313
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
14-
from .prediction_list_response import PredictionListResponse as PredictionListResponse
1514
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
1615
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse
1716
from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse

src/replicate/types/prediction_list_response.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tests/api_resources/test_predictions.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
from replicate import ReplicateClient, AsyncReplicateClient
1111
from tests.utils import assert_matches_type
12-
from replicate.types import PredictionResponse, PredictionListResponse
12+
from replicate.types import PredictionResponse
1313
from replicate._utils import parse_datetime
14+
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
1415

1516
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
1617

@@ -72,7 +73,7 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
7273
@parametrize
7374
def test_method_list(self, client: ReplicateClient) -> None:
7475
prediction = client.predictions.list()
75-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
76+
assert_matches_type(SyncCursorURLPage[PredictionResponse], prediction, path=["response"])
7677

7778
@pytest.mark.skip()
7879
@parametrize
@@ -81,7 +82,7 @@ def test_method_list_with_all_params(self, client: ReplicateClient) -> None:
8182
created_after=parse_datetime("2025-01-01T00:00:00Z"),
8283
created_before=parse_datetime("2025-02-01T00:00:00Z"),
8384
)
84-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
85+
assert_matches_type(SyncCursorURLPage[PredictionResponse], prediction, path=["response"])
8586

8687
@pytest.mark.skip()
8788
@parametrize
@@ -91,7 +92,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None:
9192
assert response.is_closed is True
9293
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
9394
prediction = response.parse()
94-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
95+
assert_matches_type(SyncCursorURLPage[PredictionResponse], prediction, path=["response"])
9596

9697
@pytest.mark.skip()
9798
@parametrize
@@ -101,7 +102,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None:
101102
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
102103

103104
prediction = response.parse()
104-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
105+
assert_matches_type(SyncCursorURLPage[PredictionResponse], prediction, path=["response"])
105106

106107
assert cast(Any, response.is_closed) is True
107108

@@ -247,7 +248,7 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
247248
@parametrize
248249
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
249250
prediction = await async_client.predictions.list()
250-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
251+
assert_matches_type(AsyncCursorURLPage[PredictionResponse], prediction, path=["response"])
251252

252253
@pytest.mark.skip()
253254
@parametrize
@@ -256,7 +257,7 @@ async def test_method_list_with_all_params(self, async_client: AsyncReplicateCli
256257
created_after=parse_datetime("2025-01-01T00:00:00Z"),
257258
created_before=parse_datetime("2025-02-01T00:00:00Z"),
258259
)
259-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
260+
assert_matches_type(AsyncCursorURLPage[PredictionResponse], prediction, path=["response"])
260261

261262
@pytest.mark.skip()
262263
@parametrize
@@ -266,7 +267,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No
266267
assert response.is_closed is True
267268
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
268269
prediction = await response.parse()
269-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
270+
assert_matches_type(AsyncCursorURLPage[PredictionResponse], prediction, path=["response"])
270271

271272
@pytest.mark.skip()
272273
@parametrize
@@ -276,7 +277,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient)
276277
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
277278

278279
prediction = await response.parse()
279-
assert_matches_type(PredictionListResponse, prediction, path=["response"])
280+
assert_matches_type(AsyncCursorURLPage[PredictionResponse], prediction, path=["response"])
280281

281282
assert cast(Any, response.is_closed) is True
282283

0 commit comments

Comments
 (0)