From 85810f5f4c0caf680a90fca80f1bfcd639e76894 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 18:52:13 +0000
Subject: [PATCH 1/3] feat: enable `openapi.code_samples`
To include client snippets in the generated JSON docs.
---
.stats.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.stats.yml b/.stats.yml
index ca7f462..ddc6b98 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 29
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
openapi_spec_hash: 4423bf747e228484547b441468a9f156
-config_hash: 976e20887b4e455f639ee6917de350b8
+config_hash: 936416d05e0c3eb2f1b5ecdaf54ca473
From 8bbddc7a788f4488311b8ed408d4b020db8e006b Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Thu, 24 Apr 2025 00:14:24 +0000
Subject: [PATCH 2/3] feat: add missing resources
This adds operation IDs that were missing or in the wrong place, so they match up with the Operation IDs of the canonical Replicate OpenAPI schema.
---
.stats.yml | 2 +-
api.md | 31 +-
src/replicate/resources/collections.py | 116 ++++++
.../resources/deployments/deployments.py | 100 -----
src/replicate/resources/hardware.py | 118 +-----
src/replicate/resources/models/__init__.py | 14 +
src/replicate/resources/models/models.py | 330 ++-------------
src/replicate/resources/models/predictions.py | 382 ++++++++++++++++++
src/replicate/resources/models/versions.py | 318 ---------------
src/replicate/resources/trainings.py | 318 +++++++++++++++
.../resources/webhooks/default/__init__.py | 33 ++
.../resources/webhooks/default/default.py | 102 +++++
.../{default.py => default/secret.py} | 86 ++--
src/replicate/resources/webhooks/webhooks.py | 6 +-
src/replicate/types/__init__.py | 3 +-
src/replicate/types/models/__init__.py | 3 +-
.../prediction_create_params.py} | 6 +-
...ng_params.py => training_create_params.py} | 4 +-
...esponse.py => training_create_response.py} | 6 +-
src/replicate/types/webhooks/__init__.py | 2 -
.../types/webhooks/default/__init__.py | 5 +
.../secret_get_response.py} | 6 +-
.../api_resources/models/test_predictions.py | 164 ++++++++
tests/api_resources/models/test_versions.py | 182 ---------
tests/api_resources/test_collections.py | 84 ++++
tests/api_resources/test_deployments.py | 56 ---
tests/api_resources/test_hardware.py | 84 ----
tests/api_resources/test_models.py | 144 +------
tests/api_resources/test_trainings.py | 187 ++++++++-
.../webhooks/default/__init__.py | 1 +
.../webhooks/default/test_secret.py | 78 ++++
tests/api_resources/webhooks/test_default.py | 78 ----
32 files changed, 1598 insertions(+), 1451 deletions(-)
create mode 100644 src/replicate/resources/models/predictions.py
create mode 100644 src/replicate/resources/webhooks/default/__init__.py
create mode 100644 src/replicate/resources/webhooks/default/default.py
rename src/replicate/resources/webhooks/{default.py => default/secret.py} (64%)
rename src/replicate/types/{model_create_prediction_params.py => models/prediction_create_params.py} (96%)
rename src/replicate/types/{models/version_create_training_params.py => training_create_params.py} (96%)
rename src/replicate/types/{models/version_create_training_response.py => training_create_response.py} (92%)
create mode 100644 src/replicate/types/webhooks/default/__init__.py
rename src/replicate/types/webhooks/{default_retrieve_secret_response.py => default/secret_get_response.py} (58%)
create mode 100644 tests/api_resources/models/test_predictions.py
create mode 100644 tests/api_resources/webhooks/default/__init__.py
create mode 100644 tests/api_resources/webhooks/default/test_secret.py
delete mode 100644 tests/api_resources/webhooks/test_default.py
diff --git a/.stats.yml b/.stats.yml
index ddc6b98..61aa00b 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 29
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
openapi_spec_hash: 4423bf747e228484547b441468a9f156
-config_hash: 936416d05e0c3eb2f1b5ecdaf54ca473
+config_hash: d820945093fc56fea6d062c90745d7a5
diff --git a/api.md b/api.md
index 65828d6..df97142 100644
--- a/api.md
+++ b/api.md
@@ -3,6 +3,7 @@
Methods:
- client.collections.list() -> None
+- client.collections.get(collection_slug) -> None
# Deployments
@@ -24,7 +25,6 @@ Methods:
- client.deployments.list() -> SyncCursorURLPage[DeploymentListResponse]
- client.deployments.delete(deployment_name, \*, deployment_owner) -> None
- client.deployments.get(deployment_name, \*, deployment_owner) -> DeploymentGetResponse
-- client.deployments.list_em_all() -> None
## Predictions
@@ -43,7 +43,6 @@ from replicate.types import HardwareListResponse
Methods:
- client.hardware.list() -> HardwareListResponse
-- client.hardware.retrieve_collections(collection_slug) -> None
# Accounts
@@ -70,7 +69,6 @@ Methods:
- client.models.create(\*\*params) -> None
- client.models.list() -> SyncCursorURLPage[ModelListResponse]
- client.models.delete(model_name, \*, model_owner) -> None
-- client.models.create_prediction(model_name, \*, model_owner, \*\*params) -> Prediction
- client.models.get(model_name, \*, model_owner) -> None
## Examples
@@ -79,6 +77,12 @@ Methods:
- client.models.examples.list(model_name, \*, model_owner) -> None
+## Predictions
+
+Methods:
+
+- client.models.predictions.create(model_name, \*, model_owner, \*\*params) -> Prediction
+
## Readme
Types:
@@ -93,17 +97,10 @@ Methods:
## Versions
-Types:
-
-```python
-from replicate.types.models import VersionCreateTrainingResponse
-```
-
Methods:
- client.models.versions.list(model_name, \*, model_owner) -> None
- client.models.versions.delete(version_id, \*, model_owner, model_name) -> None
-- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse
- client.models.versions.get(version_id, \*, model_owner, model_name) -> None
# Predictions
@@ -126,11 +123,17 @@ Methods:
Types:
```python
-from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingGetResponse
+from replicate.types import (
+ TrainingCreateResponse,
+ TrainingListResponse,
+ TrainingCancelResponse,
+ TrainingGetResponse,
+)
```
Methods:
+- client.trainings.create(version_id, \*, model_owner, model_name, \*\*params) -> TrainingCreateResponse
- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse]
- client.trainings.cancel(training_id) -> TrainingCancelResponse
- client.trainings.get(training_id) -> TrainingGetResponse
@@ -139,12 +142,14 @@ Methods:
## Default
+### Secret
+
Types:
```python
-from replicate.types.webhooks import DefaultRetrieveSecretResponse
+from replicate.types.webhooks.default import SecretGetResponse
```
Methods:
-- client.webhooks.default.retrieve_secret() -> DefaultRetrieveSecretResponse
+- client.webhooks.default.secret.get() -> SecretGetResponse
diff --git a/src/replicate/resources/collections.py b/src/replicate/resources/collections.py
index 8768630..adf0d9a 100644
--- a/src/replicate/resources/collections.py
+++ b/src/replicate/resources/collections.py
@@ -82,6 +82,58 @@ def list(
cast_to=NoneType,
)
+ def get(
+ self,
+ collection_slug: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/collections/super-resolution
+ ```
+
+ The response will be a collection object with a nested list of the models in
+ that collection:
+
+ ```json
+ {
+ "name": "Super resolution",
+ "slug": "super-resolution",
+ "description": "Upscaling models that create high-quality images from low-quality images.",
+ "models": [...]
+ }
+ ```
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not collection_slug:
+ raise ValueError(f"Expected a non-empty value for `collection_slug` but received {collection_slug!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return self._get(
+ f"/collections/{collection_slug}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class AsyncCollectionsResource(AsyncAPIResource):
@cached_property
@@ -147,6 +199,58 @@ async def list(
cast_to=NoneType,
)
+ async def get(
+ self,
+ collection_slug: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/collections/super-resolution
+ ```
+
+ The response will be a collection object with a nested list of the models in
+ that collection:
+
+ ```json
+ {
+ "name": "Super resolution",
+ "slug": "super-resolution",
+ "description": "Upscaling models that create high-quality images from low-quality images.",
+ "models": [...]
+ }
+ ```
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not collection_slug:
+ raise ValueError(f"Expected a non-empty value for `collection_slug` but received {collection_slug!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return await self._get(
+ f"/collections/{collection_slug}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class CollectionsResourceWithRawResponse:
def __init__(self, collections: CollectionsResource) -> None:
@@ -155,6 +259,9 @@ def __init__(self, collections: CollectionsResource) -> None:
self.list = to_raw_response_wrapper(
collections.list,
)
+ self.get = to_raw_response_wrapper(
+ collections.get,
+ )
class AsyncCollectionsResourceWithRawResponse:
@@ -164,6 +271,9 @@ def __init__(self, collections: AsyncCollectionsResource) -> None:
self.list = async_to_raw_response_wrapper(
collections.list,
)
+ self.get = async_to_raw_response_wrapper(
+ collections.get,
+ )
class CollectionsResourceWithStreamingResponse:
@@ -173,6 +283,9 @@ def __init__(self, collections: CollectionsResource) -> None:
self.list = to_streamed_response_wrapper(
collections.list,
)
+ self.get = to_streamed_response_wrapper(
+ collections.get,
+ )
class AsyncCollectionsResourceWithStreamingResponse:
@@ -182,3 +295,6 @@ def __init__(self, collections: AsyncCollectionsResource) -> None:
self.list = async_to_streamed_response_wrapper(
collections.list,
)
+ self.get = async_to_streamed_response_wrapper(
+ collections.get,
+ )
diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py
index 565a1c6..4ac53ff 100644
--- a/src/replicate/resources/deployments/deployments.py
+++ b/src/replicate/resources/deployments/deployments.py
@@ -451,50 +451,6 @@ def get(
cast_to=DeploymentGetResponse,
)
- def list_em_all(
- self,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/collections
- ```
-
- The response will be a paginated JSON list of collection objects:
-
- ```json
- {
- "next": "null",
- "previous": null,
- "results": [
- {
- "name": "Super resolution",
- "slug": "super-resolution",
- "description": "Upscaling models that create high-quality images from low-quality images."
- }
- ]
- }
- ```
- """
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- "/collections",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
class AsyncDeploymentsResource(AsyncAPIResource):
@cached_property
@@ -914,50 +870,6 @@ async def get(
cast_to=DeploymentGetResponse,
)
- async def list_em_all(
- self,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/collections
- ```
-
- The response will be a paginated JSON list of collection objects:
-
- ```json
- {
- "next": "null",
- "previous": null,
- "results": [
- {
- "name": "Super resolution",
- "slug": "super-resolution",
- "description": "Upscaling models that create high-quality images from low-quality images."
- }
- ]
- }
- ```
- """
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
- "/collections",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
class DeploymentsResourceWithRawResponse:
def __init__(self, deployments: DeploymentsResource) -> None:
@@ -978,9 +890,6 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.get = to_raw_response_wrapper(
deployments.get,
)
- self.list_em_all = to_raw_response_wrapper(
- deployments.list_em_all,
- )
@cached_property
def predictions(self) -> PredictionsResourceWithRawResponse:
@@ -1006,9 +915,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.get = async_to_raw_response_wrapper(
deployments.get,
)
- self.list_em_all = async_to_raw_response_wrapper(
- deployments.list_em_all,
- )
@cached_property
def predictions(self) -> AsyncPredictionsResourceWithRawResponse:
@@ -1034,9 +940,6 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.get = to_streamed_response_wrapper(
deployments.get,
)
- self.list_em_all = to_streamed_response_wrapper(
- deployments.list_em_all,
- )
@cached_property
def predictions(self) -> PredictionsResourceWithStreamingResponse:
@@ -1062,9 +965,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.get = async_to_streamed_response_wrapper(
deployments.get,
)
- self.list_em_all = async_to_streamed_response_wrapper(
- deployments.list_em_all,
- )
@cached_property
def predictions(self) -> AsyncPredictionsResourceWithStreamingResponse:
diff --git a/src/replicate/resources/hardware.py b/src/replicate/resources/hardware.py
index 4bbd58c..3cea6e4 100644
--- a/src/replicate/resources/hardware.py
+++ b/src/replicate/resources/hardware.py
@@ -4,7 +4,7 @@
import httpx
-from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
@@ -77,58 +77,6 @@ def list(
cast_to=HardwareListResponse,
)
- def retrieve_collections(
- self,
- collection_slug: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/collections/super-resolution
- ```
-
- The response will be a collection object with a nested list of the models in
- that collection:
-
- ```json
- {
- "name": "Super resolution",
- "slug": "super-resolution",
- "description": "Upscaling models that create high-quality images from low-quality images.",
- "models": [...]
- }
- ```
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not collection_slug:
- raise ValueError(f"Expected a non-empty value for `collection_slug` but received {collection_slug!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- f"/collections/{collection_slug}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
class AsyncHardwareResource(AsyncAPIResource):
@cached_property
@@ -188,58 +136,6 @@ async def list(
cast_to=HardwareListResponse,
)
- async def retrieve_collections(
- self,
- collection_slug: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/collections/super-resolution
- ```
-
- The response will be a collection object with a nested list of the models in
- that collection:
-
- ```json
- {
- "name": "Super resolution",
- "slug": "super-resolution",
- "description": "Upscaling models that create high-quality images from low-quality images.",
- "models": [...]
- }
- ```
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not collection_slug:
- raise ValueError(f"Expected a non-empty value for `collection_slug` but received {collection_slug!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
- f"/collections/{collection_slug}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
class HardwareResourceWithRawResponse:
def __init__(self, hardware: HardwareResource) -> None:
@@ -248,9 +144,6 @@ def __init__(self, hardware: HardwareResource) -> None:
self.list = to_raw_response_wrapper(
hardware.list,
)
- self.retrieve_collections = to_raw_response_wrapper(
- hardware.retrieve_collections,
- )
class AsyncHardwareResourceWithRawResponse:
@@ -260,9 +153,6 @@ def __init__(self, hardware: AsyncHardwareResource) -> None:
self.list = async_to_raw_response_wrapper(
hardware.list,
)
- self.retrieve_collections = async_to_raw_response_wrapper(
- hardware.retrieve_collections,
- )
class HardwareResourceWithStreamingResponse:
@@ -272,9 +162,6 @@ def __init__(self, hardware: HardwareResource) -> None:
self.list = to_streamed_response_wrapper(
hardware.list,
)
- self.retrieve_collections = to_streamed_response_wrapper(
- hardware.retrieve_collections,
- )
class AsyncHardwareResourceWithStreamingResponse:
@@ -284,6 +171,3 @@ def __init__(self, hardware: AsyncHardwareResource) -> None:
self.list = async_to_streamed_response_wrapper(
hardware.list,
)
- self.retrieve_collections = async_to_streamed_response_wrapper(
- hardware.retrieve_collections,
- )
diff --git a/src/replicate/resources/models/__init__.py b/src/replicate/resources/models/__init__.py
index 8afd09c..6241ca2 100644
--- a/src/replicate/resources/models/__init__.py
+++ b/src/replicate/resources/models/__init__.py
@@ -32,6 +32,14 @@
VersionsResourceWithStreamingResponse,
AsyncVersionsResourceWithStreamingResponse,
)
+from .predictions import (
+ PredictionsResource,
+ AsyncPredictionsResource,
+ PredictionsResourceWithRawResponse,
+ AsyncPredictionsResourceWithRawResponse,
+ PredictionsResourceWithStreamingResponse,
+ AsyncPredictionsResourceWithStreamingResponse,
+)
__all__ = [
"ExamplesResource",
@@ -40,6 +48,12 @@
"AsyncExamplesResourceWithRawResponse",
"ExamplesResourceWithStreamingResponse",
"AsyncExamplesResourceWithStreamingResponse",
+ "PredictionsResource",
+ "AsyncPredictionsResource",
+ "PredictionsResourceWithRawResponse",
+ "AsyncPredictionsResourceWithRawResponse",
+ "PredictionsResourceWithStreamingResponse",
+ "AsyncPredictionsResourceWithStreamingResponse",
"ReadmeResource",
"AsyncReadmeResource",
"ReadmeResourceWithRawResponse",
diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py
index df1bf42..e309f37 100644
--- a/src/replicate/resources/models/models.py
+++ b/src/replicate/resources/models/models.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-from typing import List
from typing_extensions import Literal
import httpx
@@ -15,9 +14,9 @@
ReadmeResourceWithStreamingResponse,
AsyncReadmeResourceWithStreamingResponse,
)
-from ...types import model_create_params, model_create_prediction_params
+from ...types import model_create_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
+from ..._utils import maybe_transform, async_maybe_transform
from .examples import (
ExamplesResource,
AsyncExamplesResource,
@@ -42,9 +41,16 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
+from .predictions import (
+ PredictionsResource,
+ AsyncPredictionsResource,
+ PredictionsResourceWithRawResponse,
+ AsyncPredictionsResourceWithRawResponse,
+ PredictionsResourceWithStreamingResponse,
+ AsyncPredictionsResourceWithStreamingResponse,
+)
from ...pagination import SyncCursorURLPage, AsyncCursorURLPage
from ..._base_client import AsyncPaginator, make_request_options
-from ...types.prediction import Prediction
from ...types.model_list_response import ModelListResponse
__all__ = ["ModelsResource", "AsyncModelsResource"]
@@ -55,6 +61,10 @@ class ModelsResource(SyncAPIResource):
def examples(self) -> ExamplesResource:
return ExamplesResource(self._client)
+ @cached_property
+ def predictions(self) -> PredictionsResource:
+ return PredictionsResource(self._client)
+
@cached_property
def readme(self) -> ReadmeResource:
return ReadmeResource(self._client)
@@ -285,146 +295,6 @@ def delete(
cast_to=NoneType,
)
- def create_prediction(
- self,
- model_name: str,
- *,
- model_owner: str,
- input: object,
- stream: bool | NotGiven = NOT_GIVEN,
- webhook: str | NotGiven = NOT_GIVEN,
- webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
- prefer: str | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Prediction:
- """
- Create a prediction using an
- [official model](https://replicate.com/changelog/2025-01-29-official-models).
-
- If you're _not_ running an official model, use the
- [`predictions.create`](#predictions.create) operation instead.
-
- Example cURL request:
-
- ```console
- curl -s -X POST -H 'Prefer: wait' \\
- -d '{"input": {"prompt": "Write a short poem about the weather."}}' \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- -H 'Content-Type: application/json' \\
- https://api.replicate.com/v1/models/meta/meta-llama-3-70b-instruct/predictions
- ```
-
- The request will wait up to 60 seconds for the model to run. If this time is
- exceeded the prediction will be returned in a `"starting"` state and need to be
- retrieved using the `predictions.get` endpiont.
-
- For a complete overview of the `deployments.predictions.create` API check out
- our documentation on
- [creating a prediction](https://replicate.com/docs/topics/predictions/create-a-prediction)
- which covers a variety of use cases.
-
- Args:
- input: The model's input as a JSON object. The input schema depends on what model you
- are running. To see the available inputs, click the "API" tab on the model you
- are running or [get the model version](#models.versions.get) and look at its
- `openapi_schema` property. For example,
- [stability-ai/sdxl](https://replicate.com/stability-ai/sdxl) takes `prompt` as
- an input.
-
- Files should be passed as HTTP URLs or data URLs.
-
- Use an HTTP URL when:
-
- - you have a large file > 256kb
- - you want to be able to use the file multiple times
- - you want your prediction metadata to be associable with your input files
-
- Use a data URL when:
-
- - you have a small file <= 256kb
- - you don't want to upload and host the file somewhere
- - you don't need to use the file again (Replicate will not store it)
-
- stream: **This field is deprecated.**
-
- Request a URL to receive streaming output using
- [server-sent events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events).
-
- This field is no longer needed as the returned prediction will always have a
- `stream` entry in its `url` property if the model supports streaming.
-
- webhook: An HTTPS URL for receiving a webhook when the prediction has new output. The
- webhook will be a POST request where the request body is the same as the
- response body of the [get prediction](#predictions.get) operation. If there are
- network problems, we will retry the webhook a few times, so make sure it can be
- safely called more than once. Replicate will not follow redirects when sending
- webhook requests to your service, so be sure to specify a URL that will resolve
- without redirecting.
-
- webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
- outputs or the prediction has finished. You can change which events trigger
- webhook requests by specifying `webhook_events_filter` in the prediction
- request:
-
- - `start`: immediately on prediction start
- - `output`: each time a prediction generates an output (note that predictions
- can generate multiple outputs)
- - `logs`: each time log output is generated by a prediction
- - `completed`: when the prediction reaches a terminal state
- (succeeded/canceled/failed)
-
- For example, if you only wanted requests to be sent at the start and end of the
- prediction, you would provide:
-
- ```json
- {
- "input": {
- "text": "Alice"
- },
- "webhook": "https://example.com/my-webhook",
- "webhook_events_filter": ["start", "completed"]
- }
- ```
-
- Requests for event types `output` and `logs` will be sent at most once every
- 500ms. If you request `start` and `completed` webhooks, then they'll always be
- sent regardless of throttling.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- extra_headers = {**strip_not_given({"Prefer": prefer}), **(extra_headers or {})}
- return self._post(
- f"/models/{model_owner}/{model_name}/predictions",
- body=maybe_transform(
- {
- "input": input,
- "stream": stream,
- "webhook": webhook,
- "webhook_events_filter": webhook_events_filter,
- },
- model_create_prediction_params.ModelCreatePredictionParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Prediction,
- )
-
def get(
self,
model_name: str,
@@ -540,6 +410,10 @@ class AsyncModelsResource(AsyncAPIResource):
def examples(self) -> AsyncExamplesResource:
return AsyncExamplesResource(self._client)
+ @cached_property
+ def predictions(self) -> AsyncPredictionsResource:
+ return AsyncPredictionsResource(self._client)
+
@cached_property
def readme(self) -> AsyncReadmeResource:
return AsyncReadmeResource(self._client)
@@ -770,146 +644,6 @@ async def delete(
cast_to=NoneType,
)
- async def create_prediction(
- self,
- model_name: str,
- *,
- model_owner: str,
- input: object,
- stream: bool | NotGiven = NOT_GIVEN,
- webhook: str | NotGiven = NOT_GIVEN,
- webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
- prefer: str | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Prediction:
- """
- Create a prediction using an
- [official model](https://replicate.com/changelog/2025-01-29-official-models).
-
- If you're _not_ running an official model, use the
- [`predictions.create`](#predictions.create) operation instead.
-
- Example cURL request:
-
- ```console
- curl -s -X POST -H 'Prefer: wait' \\
- -d '{"input": {"prompt": "Write a short poem about the weather."}}' \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- -H 'Content-Type: application/json' \\
- https://api.replicate.com/v1/models/meta/meta-llama-3-70b-instruct/predictions
- ```
-
- The request will wait up to 60 seconds for the model to run. If this time is
- exceeded the prediction will be returned in a `"starting"` state and need to be
- retrieved using the `predictions.get` endpiont.
-
- For a complete overview of the `deployments.predictions.create` API check out
- our documentation on
- [creating a prediction](https://replicate.com/docs/topics/predictions/create-a-prediction)
- which covers a variety of use cases.
-
- Args:
- input: The model's input as a JSON object. The input schema depends on what model you
- are running. To see the available inputs, click the "API" tab on the model you
- are running or [get the model version](#models.versions.get) and look at its
- `openapi_schema` property. For example,
- [stability-ai/sdxl](https://replicate.com/stability-ai/sdxl) takes `prompt` as
- an input.
-
- Files should be passed as HTTP URLs or data URLs.
-
- Use an HTTP URL when:
-
- - you have a large file > 256kb
- - you want to be able to use the file multiple times
- - you want your prediction metadata to be associable with your input files
-
- Use a data URL when:
-
- - you have a small file <= 256kb
- - you don't want to upload and host the file somewhere
- - you don't need to use the file again (Replicate will not store it)
-
- stream: **This field is deprecated.**
-
- Request a URL to receive streaming output using
- [server-sent events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events).
-
- This field is no longer needed as the returned prediction will always have a
- `stream` entry in its `url` property if the model supports streaming.
-
- webhook: An HTTPS URL for receiving a webhook when the prediction has new output. The
- webhook will be a POST request where the request body is the same as the
- response body of the [get prediction](#predictions.get) operation. If there are
- network problems, we will retry the webhook a few times, so make sure it can be
- safely called more than once. Replicate will not follow redirects when sending
- webhook requests to your service, so be sure to specify a URL that will resolve
- without redirecting.
-
- webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
- outputs or the prediction has finished. You can change which events trigger
- webhook requests by specifying `webhook_events_filter` in the prediction
- request:
-
- - `start`: immediately on prediction start
- - `output`: each time a prediction generates an output (note that predictions
- can generate multiple outputs)
- - `logs`: each time log output is generated by a prediction
- - `completed`: when the prediction reaches a terminal state
- (succeeded/canceled/failed)
-
- For example, if you only wanted requests to be sent at the start and end of the
- prediction, you would provide:
-
- ```json
- {
- "input": {
- "text": "Alice"
- },
- "webhook": "https://example.com/my-webhook",
- "webhook_events_filter": ["start", "completed"]
- }
- ```
-
- Requests for event types `output` and `logs` will be sent at most once every
- 500ms. If you request `start` and `completed` webhooks, then they'll always be
- sent regardless of throttling.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- extra_headers = {**strip_not_given({"Prefer": prefer}), **(extra_headers or {})}
- return await self._post(
- f"/models/{model_owner}/{model_name}/predictions",
- body=await async_maybe_transform(
- {
- "input": input,
- "stream": stream,
- "webhook": webhook,
- "webhook_events_filter": webhook_events_filter,
- },
- model_create_prediction_params.ModelCreatePredictionParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Prediction,
- )
-
async def get(
self,
model_name: str,
@@ -1033,9 +767,6 @@ def __init__(self, models: ModelsResource) -> None:
self.delete = to_raw_response_wrapper(
models.delete,
)
- self.create_prediction = to_raw_response_wrapper(
- models.create_prediction,
- )
self.get = to_raw_response_wrapper(
models.get,
)
@@ -1044,6 +775,10 @@ def __init__(self, models: ModelsResource) -> None:
def examples(self) -> ExamplesResourceWithRawResponse:
return ExamplesResourceWithRawResponse(self._models.examples)
+ @cached_property
+ def predictions(self) -> PredictionsResourceWithRawResponse:
+ return PredictionsResourceWithRawResponse(self._models.predictions)
+
@cached_property
def readme(self) -> ReadmeResourceWithRawResponse:
return ReadmeResourceWithRawResponse(self._models.readme)
@@ -1066,9 +801,6 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.delete = async_to_raw_response_wrapper(
models.delete,
)
- self.create_prediction = async_to_raw_response_wrapper(
- models.create_prediction,
- )
self.get = async_to_raw_response_wrapper(
models.get,
)
@@ -1077,6 +809,10 @@ def __init__(self, models: AsyncModelsResource) -> None:
def examples(self) -> AsyncExamplesResourceWithRawResponse:
return AsyncExamplesResourceWithRawResponse(self._models.examples)
+ @cached_property
+ def predictions(self) -> AsyncPredictionsResourceWithRawResponse:
+ return AsyncPredictionsResourceWithRawResponse(self._models.predictions)
+
@cached_property
def readme(self) -> AsyncReadmeResourceWithRawResponse:
return AsyncReadmeResourceWithRawResponse(self._models.readme)
@@ -1099,9 +835,6 @@ def __init__(self, models: ModelsResource) -> None:
self.delete = to_streamed_response_wrapper(
models.delete,
)
- self.create_prediction = to_streamed_response_wrapper(
- models.create_prediction,
- )
self.get = to_streamed_response_wrapper(
models.get,
)
@@ -1110,6 +843,10 @@ def __init__(self, models: ModelsResource) -> None:
def examples(self) -> ExamplesResourceWithStreamingResponse:
return ExamplesResourceWithStreamingResponse(self._models.examples)
+ @cached_property
+ def predictions(self) -> PredictionsResourceWithStreamingResponse:
+ return PredictionsResourceWithStreamingResponse(self._models.predictions)
+
@cached_property
def readme(self) -> ReadmeResourceWithStreamingResponse:
return ReadmeResourceWithStreamingResponse(self._models.readme)
@@ -1132,9 +869,6 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.delete = async_to_streamed_response_wrapper(
models.delete,
)
- self.create_prediction = async_to_streamed_response_wrapper(
- models.create_prediction,
- )
self.get = async_to_streamed_response_wrapper(
models.get,
)
@@ -1143,6 +877,10 @@ def __init__(self, models: AsyncModelsResource) -> None:
def examples(self) -> AsyncExamplesResourceWithStreamingResponse:
return AsyncExamplesResourceWithStreamingResponse(self._models.examples)
+ @cached_property
+ def predictions(self) -> AsyncPredictionsResourceWithStreamingResponse:
+ return AsyncPredictionsResourceWithStreamingResponse(self._models.predictions)
+
@cached_property
def readme(self) -> AsyncReadmeResourceWithStreamingResponse:
return AsyncReadmeResourceWithStreamingResponse(self._models.readme)
diff --git a/src/replicate/resources/models/predictions.py b/src/replicate/resources/models/predictions.py
new file mode 100644
index 0000000..03d8e5b
--- /dev/null
+++ b/src/replicate/resources/models/predictions.py
@@ -0,0 +1,382 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import List
+from typing_extensions import Literal
+
+import httpx
+
+from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
+from ..._compat import cached_property
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from ..._response import (
+ to_raw_response_wrapper,
+ to_streamed_response_wrapper,
+ async_to_raw_response_wrapper,
+ async_to_streamed_response_wrapper,
+)
+from ..._base_client import make_request_options
+from ...types.models import prediction_create_params
+from ...types.prediction import Prediction
+
+__all__ = ["PredictionsResource", "AsyncPredictionsResource"]
+
+
+class PredictionsResource(SyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> PredictionsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return PredictionsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> PredictionsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return PredictionsResourceWithStreamingResponse(self)
+
+ def create(
+ self,
+ model_name: str,
+ *,
+ model_owner: str,
+ input: object,
+ stream: bool | NotGiven = NOT_GIVEN,
+ webhook: str | NotGiven = NOT_GIVEN,
+ webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
+ prefer: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Prediction:
+ """
+ Create a prediction using an
+ [official model](https://replicate.com/changelog/2025-01-29-official-models).
+
+ If you're _not_ running an official model, use the
+ [`predictions.create`](#predictions.create) operation instead.
+
+ Example cURL request:
+
+ ```console
+ curl -s -X POST -H 'Prefer: wait' \\
+ -d '{"input": {"prompt": "Write a short poem about the weather."}}' \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ -H 'Content-Type: application/json' \\
+ https://api.replicate.com/v1/models/meta/meta-llama-3-70b-instruct/predictions
+ ```
+
+ The request will wait up to 60 seconds for the model to run. If this time is
+ exceeded the prediction will be returned in a `"starting"` state and need to be
+ retrieved using the `predictions.get` endpiont.
+
+ For a complete overview of the `deployments.predictions.create` API check out
+ our documentation on
+ [creating a prediction](https://replicate.com/docs/topics/predictions/create-a-prediction)
+ which covers a variety of use cases.
+
+ Args:
+ input: The model's input as a JSON object. The input schema depends on what model you
+ are running. To see the available inputs, click the "API" tab on the model you
+ are running or [get the model version](#models.versions.get) and look at its
+ `openapi_schema` property. For example,
+ [stability-ai/sdxl](https://replicate.com/stability-ai/sdxl) takes `prompt` as
+ an input.
+
+ Files should be passed as HTTP URLs or data URLs.
+
+ Use an HTTP URL when:
+
+ - you have a large file > 256kb
+ - you want to be able to use the file multiple times
+ - you want your prediction metadata to be associable with your input files
+
+ Use a data URL when:
+
+ - you have a small file <= 256kb
+ - you don't want to upload and host the file somewhere
+ - you don't need to use the file again (Replicate will not store it)
+
+ stream: **This field is deprecated.**
+
+ Request a URL to receive streaming output using
+ [server-sent events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events).
+
+ This field is no longer needed as the returned prediction will always have a
+ `stream` entry in its `url` property if the model supports streaming.
+
+ webhook: An HTTPS URL for receiving a webhook when the prediction has new output. The
+ webhook will be a POST request where the request body is the same as the
+ response body of the [get prediction](#predictions.get) operation. If there are
+ network problems, we will retry the webhook a few times, so make sure it can be
+ safely called more than once. Replicate will not follow redirects when sending
+ webhook requests to your service, so be sure to specify a URL that will resolve
+ without redirecting.
+
+ webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
+ outputs or the prediction has finished. You can change which events trigger
+ webhook requests by specifying `webhook_events_filter` in the prediction
+ request:
+
+ - `start`: immediately on prediction start
+ - `output`: each time a prediction generates an output (note that predictions
+ can generate multiple outputs)
+ - `logs`: each time log output is generated by a prediction
+ - `completed`: when the prediction reaches a terminal state
+ (succeeded/canceled/failed)
+
+ For example, if you only wanted requests to be sent at the start and end of the
+ prediction, you would provide:
+
+ ```json
+ {
+ "input": {
+ "text": "Alice"
+ },
+ "webhook": "https://example.com/my-webhook",
+ "webhook_events_filter": ["start", "completed"]
+ }
+ ```
+
+ Requests for event types `output` and `logs` will be sent at most once every
+ 500ms. If you request `start` and `completed` webhooks, then they'll always be
+ sent regardless of throttling.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ extra_headers = {**strip_not_given({"Prefer": prefer}), **(extra_headers or {})}
+ return self._post(
+ f"/models/{model_owner}/{model_name}/predictions",
+ body=maybe_transform(
+ {
+ "input": input,
+ "stream": stream,
+ "webhook": webhook,
+ "webhook_events_filter": webhook_events_filter,
+ },
+ prediction_create_params.PredictionCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Prediction,
+ )
+
+
+class AsyncPredictionsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncPredictionsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return AsyncPredictionsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncPredictionsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return AsyncPredictionsResourceWithStreamingResponse(self)
+
+ async def create(
+ self,
+ model_name: str,
+ *,
+ model_owner: str,
+ input: object,
+ stream: bool | NotGiven = NOT_GIVEN,
+ webhook: str | NotGiven = NOT_GIVEN,
+ webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
+ prefer: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Prediction:
+ """
+ Create a prediction using an
+ [official model](https://replicate.com/changelog/2025-01-29-official-models).
+
+ If you're _not_ running an official model, use the
+ [`predictions.create`](#predictions.create) operation instead.
+
+ Example cURL request:
+
+ ```console
+ curl -s -X POST -H 'Prefer: wait' \\
+ -d '{"input": {"prompt": "Write a short poem about the weather."}}' \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ -H 'Content-Type: application/json' \\
+ https://api.replicate.com/v1/models/meta/meta-llama-3-70b-instruct/predictions
+ ```
+
+ The request will wait up to 60 seconds for the model to run. If this time is
+ exceeded the prediction will be returned in a `"starting"` state and need to be
+ retrieved using the `predictions.get` endpiont.
+
+ For a complete overview of the `deployments.predictions.create` API check out
+ our documentation on
+ [creating a prediction](https://replicate.com/docs/topics/predictions/create-a-prediction)
+ which covers a variety of use cases.
+
+ Args:
+ input: The model's input as a JSON object. The input schema depends on what model you
+ are running. To see the available inputs, click the "API" tab on the model you
+ are running or [get the model version](#models.versions.get) and look at its
+ `openapi_schema` property. For example,
+ [stability-ai/sdxl](https://replicate.com/stability-ai/sdxl) takes `prompt` as
+ an input.
+
+ Files should be passed as HTTP URLs or data URLs.
+
+ Use an HTTP URL when:
+
+ - you have a large file > 256kb
+ - you want to be able to use the file multiple times
+ - you want your prediction metadata to be associable with your input files
+
+ Use a data URL when:
+
+ - you have a small file <= 256kb
+ - you don't want to upload and host the file somewhere
+ - you don't need to use the file again (Replicate will not store it)
+
+ stream: **This field is deprecated.**
+
+ Request a URL to receive streaming output using
+ [server-sent events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events).
+
+ This field is no longer needed as the returned prediction will always have a
+ `stream` entry in its `url` property if the model supports streaming.
+
+ webhook: An HTTPS URL for receiving a webhook when the prediction has new output. The
+ webhook will be a POST request where the request body is the same as the
+ response body of the [get prediction](#predictions.get) operation. If there are
+ network problems, we will retry the webhook a few times, so make sure it can be
+ safely called more than once. Replicate will not follow redirects when sending
+ webhook requests to your service, so be sure to specify a URL that will resolve
+ without redirecting.
+
+ webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
+ outputs or the prediction has finished. You can change which events trigger
+ webhook requests by specifying `webhook_events_filter` in the prediction
+ request:
+
+ - `start`: immediately on prediction start
+ - `output`: each time a prediction generates an output (note that predictions
+ can generate multiple outputs)
+ - `logs`: each time log output is generated by a prediction
+ - `completed`: when the prediction reaches a terminal state
+ (succeeded/canceled/failed)
+
+ For example, if you only wanted requests to be sent at the start and end of the
+ prediction, you would provide:
+
+ ```json
+ {
+ "input": {
+ "text": "Alice"
+ },
+ "webhook": "https://example.com/my-webhook",
+ "webhook_events_filter": ["start", "completed"]
+ }
+ ```
+
+ Requests for event types `output` and `logs` will be sent at most once every
+ 500ms. If you request `start` and `completed` webhooks, then they'll always be
+ sent regardless of throttling.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ extra_headers = {**strip_not_given({"Prefer": prefer}), **(extra_headers or {})}
+ return await self._post(
+ f"/models/{model_owner}/{model_name}/predictions",
+ body=await async_maybe_transform(
+ {
+ "input": input,
+ "stream": stream,
+ "webhook": webhook,
+ "webhook_events_filter": webhook_events_filter,
+ },
+ prediction_create_params.PredictionCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Prediction,
+ )
+
+
+class PredictionsResourceWithRawResponse:
+ def __init__(self, predictions: PredictionsResource) -> None:
+ self._predictions = predictions
+
+ self.create = to_raw_response_wrapper(
+ predictions.create,
+ )
+
+
+class AsyncPredictionsResourceWithRawResponse:
+ def __init__(self, predictions: AsyncPredictionsResource) -> None:
+ self._predictions = predictions
+
+ self.create = async_to_raw_response_wrapper(
+ predictions.create,
+ )
+
+
+class PredictionsResourceWithStreamingResponse:
+ def __init__(self, predictions: PredictionsResource) -> None:
+ self._predictions = predictions
+
+ self.create = to_streamed_response_wrapper(
+ predictions.create,
+ )
+
+
+class AsyncPredictionsResourceWithStreamingResponse:
+ def __init__(self, predictions: AsyncPredictionsResource) -> None:
+ self._predictions = predictions
+
+ self.create = async_to_streamed_response_wrapper(
+ predictions.create,
+ )
diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py
index 308d04b..294664e 100644
--- a/src/replicate/resources/models/versions.py
+++ b/src/replicate/resources/models/versions.py
@@ -2,13 +2,9 @@
from __future__ import annotations
-from typing import List
-from typing_extensions import Literal
-
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -18,8 +14,6 @@
async_to_streamed_response_wrapper,
)
from ..._base_client import make_request_options
-from ...types.models import version_create_training_params
-from ...types.models.version_create_training_response import VersionCreateTrainingResponse
__all__ = ["VersionsResource", "AsyncVersionsResource"]
@@ -168,156 +162,6 @@ def delete(
cast_to=NoneType,
)
- def create_training(
- self,
- version_id: str,
- *,
- model_owner: str,
- model_name: str,
- destination: str,
- input: object,
- webhook: str | NotGiven = NOT_GIVEN,
- webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VersionCreateTrainingResponse:
- """
- Start a new training of the model version you specify.
-
- Example request body:
-
- ```json
- {
- "destination": "{new_owner}/{new_name}",
- "input": {
- "train_data": "https://example.com/my-input-images.zip"
- },
- "webhook": "https://example.com/my-webhook"
- }
- ```
-
- Example cURL request:
-
- ```console
- curl -s -X POST \\
- -d '{"destination": "{new_owner}/{new_name}", "input": {"input_images": "https://example.com/my-input-images.zip"}}' \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- -H 'Content-Type: application/json' \\
- https://api.replicate.com/v1/models/stability-ai/sdxl/versions/da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf/trainings
- ```
-
- The response will be the training object:
-
- ```json
- {
- "id": "zz4ibbonubfz7carwiefibzgga",
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
- "input": {
- "input_images": "https://example.com/my-input-images.zip"
- },
- "logs": "",
- "error": null,
- "status": "starting",
- "created_at": "2023-09-08T16:32:56.990893084Z",
- "urls": {
- "cancel": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga/cancel",
- "get": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga"
- }
- }
- ```
-
- As models can take several minutes or more to train, the result will not be
- available immediately. To get the final result of the training you should either
- provide a `webhook` HTTPS URL for us to call when the results are ready, or poll
- the [get a training](#trainings.get) endpoint until it has finished.
-
- When a training completes, it creates a new
- [version](https://replicate.com/docs/how-does-replicate-work#terminology) of the
- model at the specified destination.
-
- To find some models to train on, check out the
- [trainable language models collection](https://replicate.com/collections/trainable-language-models).
-
- Args:
- destination: A string representing the desired model to push to in the format
- `{destination_model_owner}/{destination_model_name}`. This should be an existing
- model owned by the user or organization making the API request. If the
- destination is invalid, the server will return an appropriate 4XX response.
-
- input: An object containing inputs to the Cog model's `train()` function.
-
- webhook: An HTTPS URL for receiving a webhook when the training completes. The webhook
- will be a POST request where the request body is the same as the response body
- of the [get training](#trainings.get) operation. If there are network problems,
- we will retry the webhook a few times, so make sure it can be safely called more
- than once. Replicate will not follow redirects when sending webhook requests to
- your service, so be sure to specify a URL that will resolve without redirecting.
-
- webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
- outputs or the training has finished. You can change which events trigger
- webhook requests by specifying `webhook_events_filter` in the training request:
-
- - `start`: immediately on training start
- - `output`: each time a training generates an output (note that trainings can
- generate multiple outputs)
- - `logs`: each time log output is generated by a training
- - `completed`: when the training reaches a terminal state
- (succeeded/canceled/failed)
-
- For example, if you only wanted requests to be sent at the start and end of the
- training, you would provide:
-
- ```json
- {
- "destination": "my-organization/my-model",
- "input": {
- "text": "Alice"
- },
- "webhook": "https://example.com/my-webhook",
- "webhook_events_filter": ["start", "completed"]
- }
- ```
-
- Requests for event types `output` and `logs` will be sent at most once every
- 500ms. If you request `start` and `completed` webhooks, then they'll always be
- sent regardless of throttling.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- if not version_id:
- raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- return self._post(
- f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
- body=maybe_transform(
- {
- "destination": destination,
- "input": input,
- "webhook": webhook,
- "webhook_events_filter": webhook_events_filter,
- },
- version_create_training_params.VersionCreateTrainingParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=VersionCreateTrainingResponse,
- )
-
def get(
self,
version_id: str,
@@ -558,156 +402,6 @@ async def delete(
cast_to=NoneType,
)
- async def create_training(
- self,
- version_id: str,
- *,
- model_owner: str,
- model_name: str,
- destination: str,
- input: object,
- webhook: str | NotGiven = NOT_GIVEN,
- webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> VersionCreateTrainingResponse:
- """
- Start a new training of the model version you specify.
-
- Example request body:
-
- ```json
- {
- "destination": "{new_owner}/{new_name}",
- "input": {
- "train_data": "https://example.com/my-input-images.zip"
- },
- "webhook": "https://example.com/my-webhook"
- }
- ```
-
- Example cURL request:
-
- ```console
- curl -s -X POST \\
- -d '{"destination": "{new_owner}/{new_name}", "input": {"input_images": "https://example.com/my-input-images.zip"}}' \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- -H 'Content-Type: application/json' \\
- https://api.replicate.com/v1/models/stability-ai/sdxl/versions/da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf/trainings
- ```
-
- The response will be the training object:
-
- ```json
- {
- "id": "zz4ibbonubfz7carwiefibzgga",
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
- "input": {
- "input_images": "https://example.com/my-input-images.zip"
- },
- "logs": "",
- "error": null,
- "status": "starting",
- "created_at": "2023-09-08T16:32:56.990893084Z",
- "urls": {
- "cancel": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga/cancel",
- "get": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga"
- }
- }
- ```
-
- As models can take several minutes or more to train, the result will not be
- available immediately. To get the final result of the training you should either
- provide a `webhook` HTTPS URL for us to call when the results are ready, or poll
- the [get a training](#trainings.get) endpoint until it has finished.
-
- When a training completes, it creates a new
- [version](https://replicate.com/docs/how-does-replicate-work#terminology) of the
- model at the specified destination.
-
- To find some models to train on, check out the
- [trainable language models collection](https://replicate.com/collections/trainable-language-models).
-
- Args:
- destination: A string representing the desired model to push to in the format
- `{destination_model_owner}/{destination_model_name}`. This should be an existing
- model owned by the user or organization making the API request. If the
- destination is invalid, the server will return an appropriate 4XX response.
-
- input: An object containing inputs to the Cog model's `train()` function.
-
- webhook: An HTTPS URL for receiving a webhook when the training completes. The webhook
- will be a POST request where the request body is the same as the response body
- of the [get training](#trainings.get) operation. If there are network problems,
- we will retry the webhook a few times, so make sure it can be safely called more
- than once. Replicate will not follow redirects when sending webhook requests to
- your service, so be sure to specify a URL that will resolve without redirecting.
-
- webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
- outputs or the training has finished. You can change which events trigger
- webhook requests by specifying `webhook_events_filter` in the training request:
-
- - `start`: immediately on training start
- - `output`: each time a training generates an output (note that trainings can
- generate multiple outputs)
- - `logs`: each time log output is generated by a training
- - `completed`: when the training reaches a terminal state
- (succeeded/canceled/failed)
-
- For example, if you only wanted requests to be sent at the start and end of the
- training, you would provide:
-
- ```json
- {
- "destination": "my-organization/my-model",
- "input": {
- "text": "Alice"
- },
- "webhook": "https://example.com/my-webhook",
- "webhook_events_filter": ["start", "completed"]
- }
- ```
-
- Requests for event types `output` and `logs` will be sent at most once every
- 500ms. If you request `start` and `completed` webhooks, then they'll always be
- sent regardless of throttling.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- if not version_id:
- raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- return await self._post(
- f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
- body=await async_maybe_transform(
- {
- "destination": destination,
- "input": input,
- "webhook": webhook,
- "webhook_events_filter": webhook_events_filter,
- },
- version_create_training_params.VersionCreateTrainingParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=VersionCreateTrainingResponse,
- )
-
async def get(
self,
version_id: str,
@@ -814,9 +508,6 @@ def __init__(self, versions: VersionsResource) -> None:
self.delete = to_raw_response_wrapper(
versions.delete,
)
- self.create_training = to_raw_response_wrapper(
- versions.create_training,
- )
self.get = to_raw_response_wrapper(
versions.get,
)
@@ -832,9 +523,6 @@ def __init__(self, versions: AsyncVersionsResource) -> None:
self.delete = async_to_raw_response_wrapper(
versions.delete,
)
- self.create_training = async_to_raw_response_wrapper(
- versions.create_training,
- )
self.get = async_to_raw_response_wrapper(
versions.get,
)
@@ -850,9 +538,6 @@ def __init__(self, versions: VersionsResource) -> None:
self.delete = to_streamed_response_wrapper(
versions.delete,
)
- self.create_training = to_streamed_response_wrapper(
- versions.create_training,
- )
self.get = to_streamed_response_wrapper(
versions.get,
)
@@ -868,9 +553,6 @@ def __init__(self, versions: AsyncVersionsResource) -> None:
self.delete = async_to_streamed_response_wrapper(
versions.delete,
)
- self.create_training = async_to_streamed_response_wrapper(
- versions.create_training,
- )
self.get = async_to_streamed_response_wrapper(
versions.get,
)
diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py
index 5a357af..4b9d127 100644
--- a/src/replicate/resources/trainings.py
+++ b/src/replicate/resources/trainings.py
@@ -2,9 +2,14 @@
from __future__ import annotations
+from typing import List
+from typing_extensions import Literal
+
import httpx
+from ..types import training_create_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
@@ -18,6 +23,7 @@
from ..types.training_get_response import TrainingGetResponse
from ..types.training_list_response import TrainingListResponse
from ..types.training_cancel_response import TrainingCancelResponse
+from ..types.training_create_response import TrainingCreateResponse
__all__ = ["TrainingsResource", "AsyncTrainingsResource"]
@@ -42,6 +48,156 @@ def with_streaming_response(self) -> TrainingsResourceWithStreamingResponse:
"""
return TrainingsResourceWithStreamingResponse(self)
+ def create(
+ self,
+ version_id: str,
+ *,
+ model_owner: str,
+ model_name: str,
+ destination: str,
+ input: object,
+ webhook: str | NotGiven = NOT_GIVEN,
+ webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> TrainingCreateResponse:
+ """
+ Start a new training of the model version you specify.
+
+ Example request body:
+
+ ```json
+ {
+ "destination": "{new_owner}/{new_name}",
+ "input": {
+ "train_data": "https://example.com/my-input-images.zip"
+ },
+ "webhook": "https://example.com/my-webhook"
+ }
+ ```
+
+ Example cURL request:
+
+ ```console
+ curl -s -X POST \\
+ -d '{"destination": "{new_owner}/{new_name}", "input": {"input_images": "https://example.com/my-input-images.zip"}}' \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ -H 'Content-Type: application/json' \\
+ https://api.replicate.com/v1/models/stability-ai/sdxl/versions/da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf/trainings
+ ```
+
+ The response will be the training object:
+
+ ```json
+ {
+ "id": "zz4ibbonubfz7carwiefibzgga",
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+ "input": {
+ "input_images": "https://example.com/my-input-images.zip"
+ },
+ "logs": "",
+ "error": null,
+ "status": "starting",
+ "created_at": "2023-09-08T16:32:56.990893084Z",
+ "urls": {
+ "cancel": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga/cancel",
+ "get": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga"
+ }
+ }
+ ```
+
+ As models can take several minutes or more to train, the result will not be
+ available immediately. To get the final result of the training you should either
+ provide a `webhook` HTTPS URL for us to call when the results are ready, or poll
+ the [get a training](#trainings.get) endpoint until it has finished.
+
+ When a training completes, it creates a new
+ [version](https://replicate.com/docs/how-does-replicate-work#terminology) of the
+ model at the specified destination.
+
+ To find some models to train on, check out the
+ [trainable language models collection](https://replicate.com/collections/trainable-language-models).
+
+ Args:
+ destination: A string representing the desired model to push to in the format
+ `{destination_model_owner}/{destination_model_name}`. This should be an existing
+ model owned by the user or organization making the API request. If the
+ destination is invalid, the server will return an appropriate 4XX response.
+
+ input: An object containing inputs to the Cog model's `train()` function.
+
+ webhook: An HTTPS URL for receiving a webhook when the training completes. The webhook
+ will be a POST request where the request body is the same as the response body
+ of the [get training](#trainings.get) operation. If there are network problems,
+ we will retry the webhook a few times, so make sure it can be safely called more
+ than once. Replicate will not follow redirects when sending webhook requests to
+ your service, so be sure to specify a URL that will resolve without redirecting.
+
+ webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
+ outputs or the training has finished. You can change which events trigger
+ webhook requests by specifying `webhook_events_filter` in the training request:
+
+ - `start`: immediately on training start
+ - `output`: each time a training generates an output (note that trainings can
+ generate multiple outputs)
+ - `logs`: each time log output is generated by a training
+ - `completed`: when the training reaches a terminal state
+ (succeeded/canceled/failed)
+
+ For example, if you only wanted requests to be sent at the start and end of the
+ training, you would provide:
+
+ ```json
+ {
+ "destination": "my-organization/my-model",
+ "input": {
+ "text": "Alice"
+ },
+ "webhook": "https://example.com/my-webhook",
+ "webhook_events_filter": ["start", "completed"]
+ }
+ ```
+
+ Requests for event types `output` and `logs` will be sent at most once every
+ 500ms. If you request `start` and `completed` webhooks, then they'll always be
+ sent regardless of throttling.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ if not version_id:
+ raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
+ return self._post(
+ f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
+ body=maybe_transform(
+ {
+ "destination": destination,
+ "input": input,
+ "webhook": webhook,
+ "webhook_events_filter": webhook_events_filter,
+ },
+ training_create_params.TrainingCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=TrainingCreateResponse,
+ )
+
def list(
self,
*,
@@ -273,6 +429,156 @@ def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse
"""
return AsyncTrainingsResourceWithStreamingResponse(self)
+ async def create(
+ self,
+ version_id: str,
+ *,
+ model_owner: str,
+ model_name: str,
+ destination: str,
+ input: object,
+ webhook: str | NotGiven = NOT_GIVEN,
+ webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> TrainingCreateResponse:
+ """
+ Start a new training of the model version you specify.
+
+ Example request body:
+
+ ```json
+ {
+ "destination": "{new_owner}/{new_name}",
+ "input": {
+ "train_data": "https://example.com/my-input-images.zip"
+ },
+ "webhook": "https://example.com/my-webhook"
+ }
+ ```
+
+ Example cURL request:
+
+ ```console
+ curl -s -X POST \\
+ -d '{"destination": "{new_owner}/{new_name}", "input": {"input_images": "https://example.com/my-input-images.zip"}}' \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ -H 'Content-Type: application/json' \\
+ https://api.replicate.com/v1/models/stability-ai/sdxl/versions/da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf/trainings
+ ```
+
+ The response will be the training object:
+
+ ```json
+ {
+ "id": "zz4ibbonubfz7carwiefibzgga",
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+ "input": {
+ "input_images": "https://example.com/my-input-images.zip"
+ },
+ "logs": "",
+ "error": null,
+ "status": "starting",
+ "created_at": "2023-09-08T16:32:56.990893084Z",
+ "urls": {
+ "cancel": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga/cancel",
+ "get": "https://api.replicate.com/v1/predictions/zz4ibbonubfz7carwiefibzgga"
+ }
+ }
+ ```
+
+ As models can take several minutes or more to train, the result will not be
+ available immediately. To get the final result of the training you should either
+ provide a `webhook` HTTPS URL for us to call when the results are ready, or poll
+ the [get a training](#trainings.get) endpoint until it has finished.
+
+ When a training completes, it creates a new
+ [version](https://replicate.com/docs/how-does-replicate-work#terminology) of the
+ model at the specified destination.
+
+ To find some models to train on, check out the
+ [trainable language models collection](https://replicate.com/collections/trainable-language-models).
+
+ Args:
+ destination: A string representing the desired model to push to in the format
+ `{destination_model_owner}/{destination_model_name}`. This should be an existing
+ model owned by the user or organization making the API request. If the
+ destination is invalid, the server will return an appropriate 4XX response.
+
+ input: An object containing inputs to the Cog model's `train()` function.
+
+ webhook: An HTTPS URL for receiving a webhook when the training completes. The webhook
+ will be a POST request where the request body is the same as the response body
+ of the [get training](#trainings.get) operation. If there are network problems,
+ we will retry the webhook a few times, so make sure it can be safely called more
+ than once. Replicate will not follow redirects when sending webhook requests to
+ your service, so be sure to specify a URL that will resolve without redirecting.
+
+ webhook_events_filter: By default, we will send requests to your webhook URL whenever there are new
+ outputs or the training has finished. You can change which events trigger
+ webhook requests by specifying `webhook_events_filter` in the training request:
+
+ - `start`: immediately on training start
+ - `output`: each time a training generates an output (note that trainings can
+ generate multiple outputs)
+ - `logs`: each time log output is generated by a training
+ - `completed`: when the training reaches a terminal state
+ (succeeded/canceled/failed)
+
+ For example, if you only wanted requests to be sent at the start and end of the
+ training, you would provide:
+
+ ```json
+ {
+ "destination": "my-organization/my-model",
+ "input": {
+ "text": "Alice"
+ },
+ "webhook": "https://example.com/my-webhook",
+ "webhook_events_filter": ["start", "completed"]
+ }
+ ```
+
+ Requests for event types `output` and `logs` will be sent at most once every
+ 500ms. If you request `start` and `completed` webhooks, then they'll always be
+ sent regardless of throttling.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ if not version_id:
+ raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
+ return await self._post(
+ f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
+ body=await async_maybe_transform(
+ {
+ "destination": destination,
+ "input": input,
+ "webhook": webhook,
+ "webhook_events_filter": webhook_events_filter,
+ },
+ training_create_params.TrainingCreateParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=TrainingCreateResponse,
+ )
+
def list(
self,
*,
@@ -488,6 +794,9 @@ class TrainingsResourceWithRawResponse:
def __init__(self, trainings: TrainingsResource) -> None:
self._trainings = trainings
+ self.create = to_raw_response_wrapper(
+ trainings.create,
+ )
self.list = to_raw_response_wrapper(
trainings.list,
)
@@ -503,6 +812,9 @@ class AsyncTrainingsResourceWithRawResponse:
def __init__(self, trainings: AsyncTrainingsResource) -> None:
self._trainings = trainings
+ self.create = async_to_raw_response_wrapper(
+ trainings.create,
+ )
self.list = async_to_raw_response_wrapper(
trainings.list,
)
@@ -518,6 +830,9 @@ class TrainingsResourceWithStreamingResponse:
def __init__(self, trainings: TrainingsResource) -> None:
self._trainings = trainings
+ self.create = to_streamed_response_wrapper(
+ trainings.create,
+ )
self.list = to_streamed_response_wrapper(
trainings.list,
)
@@ -533,6 +848,9 @@ class AsyncTrainingsResourceWithStreamingResponse:
def __init__(self, trainings: AsyncTrainingsResource) -> None:
self._trainings = trainings
+ self.create = async_to_streamed_response_wrapper(
+ trainings.create,
+ )
self.list = async_to_streamed_response_wrapper(
trainings.list,
)
diff --git a/src/replicate/resources/webhooks/default/__init__.py b/src/replicate/resources/webhooks/default/__init__.py
new file mode 100644
index 0000000..f3aa085
--- /dev/null
+++ b/src/replicate/resources/webhooks/default/__init__.py
@@ -0,0 +1,33 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from .secret import (
+ SecretResource,
+ AsyncSecretResource,
+ SecretResourceWithRawResponse,
+ AsyncSecretResourceWithRawResponse,
+ SecretResourceWithStreamingResponse,
+ AsyncSecretResourceWithStreamingResponse,
+)
+from .default import (
+ DefaultResource,
+ AsyncDefaultResource,
+ DefaultResourceWithRawResponse,
+ AsyncDefaultResourceWithRawResponse,
+ DefaultResourceWithStreamingResponse,
+ AsyncDefaultResourceWithStreamingResponse,
+)
+
+__all__ = [
+ "SecretResource",
+ "AsyncSecretResource",
+ "SecretResourceWithRawResponse",
+ "AsyncSecretResourceWithRawResponse",
+ "SecretResourceWithStreamingResponse",
+ "AsyncSecretResourceWithStreamingResponse",
+ "DefaultResource",
+ "AsyncDefaultResource",
+ "DefaultResourceWithRawResponse",
+ "AsyncDefaultResourceWithRawResponse",
+ "DefaultResourceWithStreamingResponse",
+ "AsyncDefaultResourceWithStreamingResponse",
+]
diff --git a/src/replicate/resources/webhooks/default/default.py b/src/replicate/resources/webhooks/default/default.py
new file mode 100644
index 0000000..3c28ace
--- /dev/null
+++ b/src/replicate/resources/webhooks/default/default.py
@@ -0,0 +1,102 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from .secret import (
+ SecretResource,
+ AsyncSecretResource,
+ SecretResourceWithRawResponse,
+ AsyncSecretResourceWithRawResponse,
+ SecretResourceWithStreamingResponse,
+ AsyncSecretResourceWithStreamingResponse,
+)
+from ...._compat import cached_property
+from ...._resource import SyncAPIResource, AsyncAPIResource
+
+__all__ = ["DefaultResource", "AsyncDefaultResource"]
+
+
+class DefaultResource(SyncAPIResource):
+ @cached_property
+ def secret(self) -> SecretResource:
+ return SecretResource(self._client)
+
+ @cached_property
+ def with_raw_response(self) -> DefaultResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return DefaultResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> DefaultResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return DefaultResourceWithStreamingResponse(self)
+
+
+class AsyncDefaultResource(AsyncAPIResource):
+ @cached_property
+ def secret(self) -> AsyncSecretResource:
+ return AsyncSecretResource(self._client)
+
+ @cached_property
+ def with_raw_response(self) -> AsyncDefaultResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return AsyncDefaultResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncDefaultResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return AsyncDefaultResourceWithStreamingResponse(self)
+
+
+class DefaultResourceWithRawResponse:
+ def __init__(self, default: DefaultResource) -> None:
+ self._default = default
+
+ @cached_property
+ def secret(self) -> SecretResourceWithRawResponse:
+ return SecretResourceWithRawResponse(self._default.secret)
+
+
+class AsyncDefaultResourceWithRawResponse:
+ def __init__(self, default: AsyncDefaultResource) -> None:
+ self._default = default
+
+ @cached_property
+ def secret(self) -> AsyncSecretResourceWithRawResponse:
+ return AsyncSecretResourceWithRawResponse(self._default.secret)
+
+
+class DefaultResourceWithStreamingResponse:
+ def __init__(self, default: DefaultResource) -> None:
+ self._default = default
+
+ @cached_property
+ def secret(self) -> SecretResourceWithStreamingResponse:
+ return SecretResourceWithStreamingResponse(self._default.secret)
+
+
+class AsyncDefaultResourceWithStreamingResponse:
+ def __init__(self, default: AsyncDefaultResource) -> None:
+ self._default = default
+
+ @cached_property
+ def secret(self) -> AsyncSecretResourceWithStreamingResponse:
+ return AsyncSecretResourceWithStreamingResponse(self._default.secret)
diff --git a/src/replicate/resources/webhooks/default.py b/src/replicate/resources/webhooks/default/secret.py
similarity index 64%
rename from src/replicate/resources/webhooks/default.py
rename to src/replicate/resources/webhooks/default/secret.py
index bde9966..9430140 100644
--- a/src/replicate/resources/webhooks/default.py
+++ b/src/replicate/resources/webhooks/default/secret.py
@@ -4,42 +4,42 @@
import httpx
-from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._compat import cached_property
-from ..._resource import SyncAPIResource, AsyncAPIResource
-from ..._response import (
+from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from ...._compat import cached_property
+from ...._resource import SyncAPIResource, AsyncAPIResource
+from ...._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
-from ..._base_client import make_request_options
-from ...types.webhooks.default_retrieve_secret_response import DefaultRetrieveSecretResponse
+from ...._base_client import make_request_options
+from ....types.webhooks.default.secret_get_response import SecretGetResponse
-__all__ = ["DefaultResource", "AsyncDefaultResource"]
+__all__ = ["SecretResource", "AsyncSecretResource"]
-class DefaultResource(SyncAPIResource):
+class SecretResource(SyncAPIResource):
@cached_property
- def with_raw_response(self) -> DefaultResourceWithRawResponse:
+ def with_raw_response(self) -> SecretResourceWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
"""
- return DefaultResourceWithRawResponse(self)
+ return SecretResourceWithRawResponse(self)
@cached_property
- def with_streaming_response(self) -> DefaultResourceWithStreamingResponse:
+ def with_streaming_response(self) -> SecretResourceWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
"""
- return DefaultResourceWithStreamingResponse(self)
+ return SecretResourceWithStreamingResponse(self)
- def retrieve_secret(
+ def get(
self,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -48,7 +48,7 @@ def retrieve_secret(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> DefaultRetrieveSecretResponse:
+ ) -> SecretGetResponse:
"""Get the signing secret for the default webhook endpoint.
This is used to verify
@@ -75,31 +75,31 @@ def retrieve_secret(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=DefaultRetrieveSecretResponse,
+ cast_to=SecretGetResponse,
)
-class AsyncDefaultResource(AsyncAPIResource):
+class AsyncSecretResource(AsyncAPIResource):
@cached_property
- def with_raw_response(self) -> AsyncDefaultResourceWithRawResponse:
+ def with_raw_response(self) -> AsyncSecretResourceWithRawResponse:
"""
This property can be used as a prefix for any HTTP method call to return
the raw response object instead of the parsed content.
For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
"""
- return AsyncDefaultResourceWithRawResponse(self)
+ return AsyncSecretResourceWithRawResponse(self)
@cached_property
- def with_streaming_response(self) -> AsyncDefaultResourceWithStreamingResponse:
+ def with_streaming_response(self) -> AsyncSecretResourceWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
"""
- return AsyncDefaultResourceWithStreamingResponse(self)
+ return AsyncSecretResourceWithStreamingResponse(self)
- async def retrieve_secret(
+ async def get(
self,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -108,7 +108,7 @@ async def retrieve_secret(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> DefaultRetrieveSecretResponse:
+ ) -> SecretGetResponse:
"""Get the signing secret for the default webhook endpoint.
This is used to verify
@@ -135,41 +135,41 @@ async def retrieve_secret(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=DefaultRetrieveSecretResponse,
+ cast_to=SecretGetResponse,
)
-class DefaultResourceWithRawResponse:
- def __init__(self, default: DefaultResource) -> None:
- self._default = default
+class SecretResourceWithRawResponse:
+ def __init__(self, secret: SecretResource) -> None:
+ self._secret = secret
- self.retrieve_secret = to_raw_response_wrapper(
- default.retrieve_secret,
+ self.get = to_raw_response_wrapper(
+ secret.get,
)
-class AsyncDefaultResourceWithRawResponse:
- def __init__(self, default: AsyncDefaultResource) -> None:
- self._default = default
+class AsyncSecretResourceWithRawResponse:
+ def __init__(self, secret: AsyncSecretResource) -> None:
+ self._secret = secret
- self.retrieve_secret = async_to_raw_response_wrapper(
- default.retrieve_secret,
+ self.get = async_to_raw_response_wrapper(
+ secret.get,
)
-class DefaultResourceWithStreamingResponse:
- def __init__(self, default: DefaultResource) -> None:
- self._default = default
+class SecretResourceWithStreamingResponse:
+ def __init__(self, secret: SecretResource) -> None:
+ self._secret = secret
- self.retrieve_secret = to_streamed_response_wrapper(
- default.retrieve_secret,
+ self.get = to_streamed_response_wrapper(
+ secret.get,
)
-class AsyncDefaultResourceWithStreamingResponse:
- def __init__(self, default: AsyncDefaultResource) -> None:
- self._default = default
+class AsyncSecretResourceWithStreamingResponse:
+ def __init__(self, secret: AsyncSecretResource) -> None:
+ self._secret = secret
- self.retrieve_secret = async_to_streamed_response_wrapper(
- default.retrieve_secret,
+ self.get = async_to_streamed_response_wrapper(
+ secret.get,
)
diff --git a/src/replicate/resources/webhooks/webhooks.py b/src/replicate/resources/webhooks/webhooks.py
index f00d5b3..67ce252 100644
--- a/src/replicate/resources/webhooks/webhooks.py
+++ b/src/replicate/resources/webhooks/webhooks.py
@@ -2,7 +2,9 @@
from __future__ import annotations
-from .default import (
+from ..._compat import cached_property
+from ..._resource import SyncAPIResource, AsyncAPIResource
+from .default.default import (
DefaultResource,
AsyncDefaultResource,
DefaultResourceWithRawResponse,
@@ -10,8 +12,6 @@
DefaultResourceWithStreamingResponse,
AsyncDefaultResourceWithStreamingResponse,
)
-from ..._compat import cached_property
-from ..._resource import SyncAPIResource, AsyncAPIResource
__all__ = ["WebhooksResource", "AsyncWebhooksResource"]
diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py
index fa8ee5d..3fe2c03 100644
--- a/src/replicate/types/__init__.py
+++ b/src/replicate/types/__init__.py
@@ -10,6 +10,7 @@
from .training_get_response import TrainingGetResponse as TrainingGetResponse
from .hardware_list_response import HardwareListResponse as HardwareListResponse
from .prediction_list_params import PredictionListParams as PredictionListParams
+from .training_create_params import TrainingCreateParams as TrainingCreateParams
from .training_list_response import TrainingListResponse as TrainingListResponse
from .deployment_get_response import DeploymentGetResponse as DeploymentGetResponse
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
@@ -17,6 +18,6 @@
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
+from .training_create_response import TrainingCreateResponse as TrainingCreateResponse
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse
-from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams
diff --git a/src/replicate/types/models/__init__.py b/src/replicate/types/models/__init__.py
index 05fbb6f..b53e1a7 100644
--- a/src/replicate/types/models/__init__.py
+++ b/src/replicate/types/models/__init__.py
@@ -3,5 +3,4 @@
from __future__ import annotations
from .readme_get_response import ReadmeGetResponse as ReadmeGetResponse
-from .version_create_training_params import VersionCreateTrainingParams as VersionCreateTrainingParams
-from .version_create_training_response import VersionCreateTrainingResponse as VersionCreateTrainingResponse
+from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
diff --git a/src/replicate/types/model_create_prediction_params.py b/src/replicate/types/models/prediction_create_params.py
similarity index 96%
rename from src/replicate/types/model_create_prediction_params.py
rename to src/replicate/types/models/prediction_create_params.py
index 32474b7..6b76121 100644
--- a/src/replicate/types/model_create_prediction_params.py
+++ b/src/replicate/types/models/prediction_create_params.py
@@ -5,12 +5,12 @@
from typing import List
from typing_extensions import Literal, Required, Annotated, TypedDict
-from .._utils import PropertyInfo
+from ..._utils import PropertyInfo
-__all__ = ["ModelCreatePredictionParams"]
+__all__ = ["PredictionCreateParams"]
-class ModelCreatePredictionParams(TypedDict, total=False):
+class PredictionCreateParams(TypedDict, total=False):
model_owner: Required[str]
input: Required[object]
diff --git a/src/replicate/types/models/version_create_training_params.py b/src/replicate/types/training_create_params.py
similarity index 96%
rename from src/replicate/types/models/version_create_training_params.py
rename to src/replicate/types/training_create_params.py
index 59e6146..38542fb 100644
--- a/src/replicate/types/models/version_create_training_params.py
+++ b/src/replicate/types/training_create_params.py
@@ -5,10 +5,10 @@
from typing import List
from typing_extensions import Literal, Required, TypedDict
-__all__ = ["VersionCreateTrainingParams"]
+__all__ = ["TrainingCreateParams"]
-class VersionCreateTrainingParams(TypedDict, total=False):
+class TrainingCreateParams(TypedDict, total=False):
model_owner: Required[str]
model_name: Required[str]
diff --git a/src/replicate/types/models/version_create_training_response.py b/src/replicate/types/training_create_response.py
similarity index 92%
rename from src/replicate/types/models/version_create_training_response.py
rename to src/replicate/types/training_create_response.py
index 5b005fe..d29e3fa 100644
--- a/src/replicate/types/models/version_create_training_response.py
+++ b/src/replicate/types/training_create_response.py
@@ -4,9 +4,9 @@
from datetime import datetime
from typing_extensions import Literal
-from ..._models import BaseModel
+from .._models import BaseModel
-__all__ = ["VersionCreateTrainingResponse", "Metrics", "Output", "URLs"]
+__all__ = ["TrainingCreateResponse", "Metrics", "Output", "URLs"]
class Metrics(BaseModel):
@@ -30,7 +30,7 @@ class URLs(BaseModel):
"""URL to get the training details"""
-class VersionCreateTrainingResponse(BaseModel):
+class TrainingCreateResponse(BaseModel):
id: Optional[str] = None
"""The unique ID of the training"""
diff --git a/src/replicate/types/webhooks/__init__.py b/src/replicate/types/webhooks/__init__.py
index f546bf5..f8ee8b1 100644
--- a/src/replicate/types/webhooks/__init__.py
+++ b/src/replicate/types/webhooks/__init__.py
@@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
-
-from .default_retrieve_secret_response import DefaultRetrieveSecretResponse as DefaultRetrieveSecretResponse
diff --git a/src/replicate/types/webhooks/default/__init__.py b/src/replicate/types/webhooks/default/__init__.py
new file mode 100644
index 0000000..70715e0
--- /dev/null
+++ b/src/replicate/types/webhooks/default/__init__.py
@@ -0,0 +1,5 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from .secret_get_response import SecretGetResponse as SecretGetResponse
diff --git a/src/replicate/types/webhooks/default_retrieve_secret_response.py b/src/replicate/types/webhooks/default/secret_get_response.py
similarity index 58%
rename from src/replicate/types/webhooks/default_retrieve_secret_response.py
rename to src/replicate/types/webhooks/default/secret_get_response.py
index 7bd4744..0e4303c 100644
--- a/src/replicate/types/webhooks/default_retrieve_secret_response.py
+++ b/src/replicate/types/webhooks/default/secret_get_response.py
@@ -2,11 +2,11 @@
from typing import Optional
-from ..._models import BaseModel
+from ...._models import BaseModel
-__all__ = ["DefaultRetrieveSecretResponse"]
+__all__ = ["SecretGetResponse"]
-class DefaultRetrieveSecretResponse(BaseModel):
+class SecretGetResponse(BaseModel):
key: Optional[str] = None
"""The signing secret."""
diff --git a/tests/api_resources/models/test_predictions.py b/tests/api_resources/models/test_predictions.py
new file mode 100644
index 0000000..d4f1974
--- /dev/null
+++ b/tests/api_resources/models/test_predictions.py
@@ -0,0 +1,164 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, cast
+
+import pytest
+
+from replicate import ReplicateClient, AsyncReplicateClient
+from tests.utils import assert_matches_type
+from replicate.types import Prediction
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestPredictions:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_create(self, client: ReplicateClient) -> None:
+ prediction = client.models.predictions.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_create_with_all_params(self, client: ReplicateClient) -> None:
+ prediction = client.models.predictions.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ stream=True,
+ webhook="webhook",
+ webhook_events_filter=["start"],
+ prefer="wait=5",
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_create(self, client: ReplicateClient) -> None:
+ response = client.models.predictions.with_raw_response.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ prediction = response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_create(self, client: ReplicateClient) -> None:
+ with client.models.predictions.with_streaming_response.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ prediction = response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_create(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ client.models.predictions.with_raw_response.create(
+ model_name="model_name",
+ model_owner="",
+ input={},
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ client.models.predictions.with_raw_response.create(
+ model_name="",
+ model_owner="model_owner",
+ input={},
+ )
+
+
+class TestAsyncPredictions:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_create(self, async_client: AsyncReplicateClient) -> None:
+ prediction = await async_client.models.predictions.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None:
+ prediction = await async_client.models.predictions.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ stream=True,
+ webhook="webhook",
+ webhook_events_filter=["start"],
+ prefer="wait=5",
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.models.predictions.with_raw_response.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ prediction = await response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.models.predictions.with_streaming_response.create(
+ model_name="model_name",
+ model_owner="model_owner",
+ input={},
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ prediction = await response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_create(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ await async_client.models.predictions.with_raw_response.create(
+ model_name="model_name",
+ model_owner="",
+ input={},
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ await async_client.models.predictions.with_raw_response.create(
+ model_name="",
+ model_owner="model_owner",
+ input={},
+ )
diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py
index d1fb7a8..af132a2 100644
--- a/tests/api_resources/models/test_versions.py
+++ b/tests/api_resources/models/test_versions.py
@@ -8,8 +8,6 @@
import pytest
from replicate import ReplicateClient, AsyncReplicateClient
-from tests.utils import assert_matches_type
-from replicate.types.models import VersionCreateTrainingResponse
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -133,96 +131,6 @@ def test_path_params_delete(self, client: ReplicateClient) -> None:
model_name="model_name",
)
- @pytest.mark.skip()
- @parametrize
- def test_method_create_training(self, client: ReplicateClient) -> None:
- version = client.models.versions.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- )
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_method_create_training_with_all_params(self, client: ReplicateClient) -> None:
- version = client.models.versions.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- webhook="webhook",
- webhook_events_filter=["start"],
- )
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_create_training(self, client: ReplicateClient) -> None:
- response = client.models.versions.with_raw_response.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = response.parse()
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_create_training(self, client: ReplicateClient) -> None:
- with client.models.versions.with_streaming_response.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- version = response.parse()
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_create_training(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- client.models.versions.with_raw_response.create_training(
- version_id="version_id",
- model_owner="",
- model_name="model_name",
- destination="destination",
- input={},
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- client.models.versions.with_raw_response.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="",
- destination="destination",
- input={},
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
- client.models.versions.with_raw_response.create_training(
- version_id="",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- )
-
@pytest.mark.skip()
@parametrize
def test_method_get(self, client: ReplicateClient) -> None:
@@ -407,96 +315,6 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N
model_name="model_name",
)
- @pytest.mark.skip()
- @parametrize
- async def test_method_create_training(self, async_client: AsyncReplicateClient) -> None:
- version = await async_client.models.versions.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- )
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_method_create_training_with_all_params(self, async_client: AsyncReplicateClient) -> None:
- version = await async_client.models.versions.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- webhook="webhook",
- webhook_events_filter=["start"],
- )
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_create_training(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.models.versions.with_raw_response.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = await response.parse()
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_create_training(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.models.versions.with_streaming_response.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- version = await response.parse()
- assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_create_training(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- await async_client.models.versions.with_raw_response.create_training(
- version_id="version_id",
- model_owner="",
- model_name="model_name",
- destination="destination",
- input={},
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- await async_client.models.versions.with_raw_response.create_training(
- version_id="version_id",
- model_owner="model_owner",
- model_name="",
- destination="destination",
- input={},
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
- await async_client.models.versions.with_raw_response.create_training(
- version_id="",
- model_owner="model_owner",
- model_name="model_name",
- destination="destination",
- input={},
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
diff --git a/tests/api_resources/test_collections.py b/tests/api_resources/test_collections.py
index 820e42e..d231a8d 100644
--- a/tests/api_resources/test_collections.py
+++ b/tests/api_resources/test_collections.py
@@ -43,6 +43,48 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ collection = client.collections.get(
+ "collection_slug",
+ )
+ assert collection is None
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.collections.with_raw_response.get(
+ "collection_slug",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ collection = response.parse()
+ assert collection is None
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.collections.with_streaming_response.get(
+ "collection_slug",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ collection = response.parse()
+ assert collection is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `collection_slug` but received ''"):
+ client.collections.with_raw_response.get(
+ "",
+ )
+
class TestAsyncCollections:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -74,3 +116,45 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient)
assert collection is None
assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ collection = await async_client.collections.get(
+ "collection_slug",
+ )
+ assert collection is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.collections.with_raw_response.get(
+ "collection_slug",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ collection = await response.parse()
+ assert collection is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.collections.with_streaming_response.get(
+ "collection_slug",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ collection = await response.parse()
+ assert collection is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `collection_slug` but received ''"):
+ await async_client.collections.with_raw_response.get(
+ "",
+ )
diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py
index 3a01dcb..6d6360a 100644
--- a/tests/api_resources/test_deployments.py
+++ b/tests/api_resources/test_deployments.py
@@ -269,34 +269,6 @@ def test_path_params_get(self, client: ReplicateClient) -> None:
deployment_owner="deployment_owner",
)
- @pytest.mark.skip()
- @parametrize
- def test_method_list_em_all(self, client: ReplicateClient) -> None:
- deployment = client.deployments.list_em_all()
- assert deployment is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_list_em_all(self, client: ReplicateClient) -> None:
- response = client.deployments.with_raw_response.list_em_all()
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- deployment = response.parse()
- assert deployment is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_list_em_all(self, client: ReplicateClient) -> None:
- with client.deployments.with_streaming_response.list_em_all() as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- deployment = response.parse()
- assert deployment is None
-
- assert cast(Any, response.is_closed) is True
-
class TestAsyncDeployments:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -546,31 +518,3 @@ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None
deployment_name="",
deployment_owner="deployment_owner",
)
-
- @pytest.mark.skip()
- @parametrize
- async def test_method_list_em_all(self, async_client: AsyncReplicateClient) -> None:
- deployment = await async_client.deployments.list_em_all()
- assert deployment is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_list_em_all(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.deployments.with_raw_response.list_em_all()
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- deployment = await response.parse()
- assert deployment is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_list_em_all(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.deployments.with_streaming_response.list_em_all() as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- deployment = await response.parse()
- assert deployment is None
-
- assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/test_hardware.py b/tests/api_resources/test_hardware.py
index 1dcb3a7..9aa535b 100644
--- a/tests/api_resources/test_hardware.py
+++ b/tests/api_resources/test_hardware.py
@@ -45,48 +45,6 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve_collections(self, client: ReplicateClient) -> None:
- hardware = client.hardware.retrieve_collections(
- "collection_slug",
- )
- assert hardware is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve_collections(self, client: ReplicateClient) -> None:
- response = client.hardware.with_raw_response.retrieve_collections(
- "collection_slug",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- hardware = response.parse()
- assert hardware is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve_collections(self, client: ReplicateClient) -> None:
- with client.hardware.with_streaming_response.retrieve_collections(
- "collection_slug",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- hardware = response.parse()
- assert hardware is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve_collections(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `collection_slug` but received ''"):
- client.hardware.with_raw_response.retrieve_collections(
- "",
- )
-
class TestAsyncHardware:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -118,45 +76,3 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient)
assert_matches_type(HardwareListResponse, hardware, path=["response"])
assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve_collections(self, async_client: AsyncReplicateClient) -> None:
- hardware = await async_client.hardware.retrieve_collections(
- "collection_slug",
- )
- assert hardware is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve_collections(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.hardware.with_raw_response.retrieve_collections(
- "collection_slug",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- hardware = await response.parse()
- assert hardware is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve_collections(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.hardware.with_streaming_response.retrieve_collections(
- "collection_slug",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- hardware = await response.parse()
- assert hardware is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve_collections(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `collection_slug` but received ''"):
- await async_client.hardware.with_raw_response.retrieve_collections(
- "",
- )
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index a2a9f52..1188e65 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -9,7 +9,7 @@
from replicate import ReplicateClient, AsyncReplicateClient
from tests.utils import assert_matches_type
-from replicate.types import Prediction, ModelListResponse
+from replicate.types import ModelListResponse
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -157,77 +157,6 @@ def test_path_params_delete(self, client: ReplicateClient) -> None:
model_owner="model_owner",
)
- @pytest.mark.skip()
- @parametrize
- def test_method_create_prediction(self, client: ReplicateClient) -> None:
- model = client.models.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- )
- assert_matches_type(Prediction, model, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_method_create_prediction_with_all_params(self, client: ReplicateClient) -> None:
- model = client.models.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- stream=True,
- webhook="webhook",
- webhook_events_filter=["start"],
- prefer="wait=5",
- )
- assert_matches_type(Prediction, model, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_create_prediction(self, client: ReplicateClient) -> None:
- response = client.models.with_raw_response.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- model = response.parse()
- assert_matches_type(Prediction, model, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_create_prediction(self, client: ReplicateClient) -> None:
- with client.models.with_streaming_response.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- model = response.parse()
- assert_matches_type(Prediction, model, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_create_prediction(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- client.models.with_raw_response.create_prediction(
- model_name="model_name",
- model_owner="",
- input={},
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- client.models.with_raw_response.create_prediction(
- model_name="",
- model_owner="model_owner",
- input={},
- )
-
@pytest.mark.skip()
@parametrize
def test_method_get(self, client: ReplicateClient) -> None:
@@ -423,77 +352,6 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N
model_owner="model_owner",
)
- @pytest.mark.skip()
- @parametrize
- async def test_method_create_prediction(self, async_client: AsyncReplicateClient) -> None:
- model = await async_client.models.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- )
- assert_matches_type(Prediction, model, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_method_create_prediction_with_all_params(self, async_client: AsyncReplicateClient) -> None:
- model = await async_client.models.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- stream=True,
- webhook="webhook",
- webhook_events_filter=["start"],
- prefer="wait=5",
- )
- assert_matches_type(Prediction, model, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_create_prediction(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.models.with_raw_response.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- model = await response.parse()
- assert_matches_type(Prediction, model, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_create_prediction(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.models.with_streaming_response.create_prediction(
- model_name="model_name",
- model_owner="model_owner",
- input={},
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- model = await response.parse()
- assert_matches_type(Prediction, model, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_create_prediction(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- await async_client.models.with_raw_response.create_prediction(
- model_name="model_name",
- model_owner="",
- input={},
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- await async_client.models.with_raw_response.create_prediction(
- model_name="",
- model_owner="model_owner",
- input={},
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py
index f2dadb1..e1b0572 100644
--- a/tests/api_resources/test_trainings.py
+++ b/tests/api_resources/test_trainings.py
@@ -9,7 +9,12 @@
from replicate import ReplicateClient, AsyncReplicateClient
from tests.utils import assert_matches_type
-from replicate.types import TrainingGetResponse, TrainingListResponse, TrainingCancelResponse
+from replicate.types import (
+ TrainingGetResponse,
+ TrainingListResponse,
+ TrainingCancelResponse,
+ TrainingCreateResponse,
+)
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -18,6 +23,96 @@
class TestTrainings:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_create(self, client: ReplicateClient) -> None:
+ training = client.trainings.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_create_with_all_params(self, client: ReplicateClient) -> None:
+ training = client.trainings.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ webhook="webhook",
+ webhook_events_filter=["start"],
+ )
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_create(self, client: ReplicateClient) -> None:
+ response = client.trainings.with_raw_response.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ training = response.parse()
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_create(self, client: ReplicateClient) -> None:
+ with client.trainings.with_streaming_response.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ training = response.parse()
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_create(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ client.trainings.with_raw_response.create(
+ version_id="version_id",
+ model_owner="",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ client.trainings.with_raw_response.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="",
+ destination="destination",
+ input={},
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
+ client.trainings.with_raw_response.create(
+ version_id="",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -134,6 +229,96 @@ def test_path_params_get(self, client: ReplicateClient) -> None:
class TestAsyncTrainings:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_create(self, async_client: AsyncReplicateClient) -> None:
+ training = await async_client.trainings.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None:
+ training = await async_client.trainings.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ webhook="webhook",
+ webhook_events_filter=["start"],
+ )
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.trainings.with_raw_response.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ training = await response.parse()
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.trainings.with_streaming_response.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ training = await response.parse()
+ assert_matches_type(TrainingCreateResponse, training, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_create(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ await async_client.trainings.with_raw_response.create(
+ version_id="version_id",
+ model_owner="",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ await async_client.trainings.with_raw_response.create(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="",
+ destination="destination",
+ input={},
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
+ await async_client.trainings.with_raw_response.create(
+ version_id="",
+ model_owner="model_owner",
+ model_name="model_name",
+ destination="destination",
+ input={},
+ )
+
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
diff --git a/tests/api_resources/webhooks/default/__init__.py b/tests/api_resources/webhooks/default/__init__.py
new file mode 100644
index 0000000..fd8019a
--- /dev/null
+++ b/tests/api_resources/webhooks/default/__init__.py
@@ -0,0 +1 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
diff --git a/tests/api_resources/webhooks/default/test_secret.py b/tests/api_resources/webhooks/default/test_secret.py
new file mode 100644
index 0000000..c95ab8a
--- /dev/null
+++ b/tests/api_resources/webhooks/default/test_secret.py
@@ -0,0 +1,78 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, cast
+
+import pytest
+
+from replicate import ReplicateClient, AsyncReplicateClient
+from tests.utils import assert_matches_type
+from replicate.types.webhooks.default import SecretGetResponse
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestSecret:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ secret = client.webhooks.default.secret.get()
+ assert_matches_type(SecretGetResponse, secret, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.webhooks.default.secret.with_raw_response.get()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ secret = response.parse()
+ assert_matches_type(SecretGetResponse, secret, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.webhooks.default.secret.with_streaming_response.get() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ secret = response.parse()
+ assert_matches_type(SecretGetResponse, secret, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+
+class TestAsyncSecret:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ secret = await async_client.webhooks.default.secret.get()
+ assert_matches_type(SecretGetResponse, secret, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.webhooks.default.secret.with_raw_response.get()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ secret = await response.parse()
+ assert_matches_type(SecretGetResponse, secret, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.webhooks.default.secret.with_streaming_response.get() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ secret = await response.parse()
+ assert_matches_type(SecretGetResponse, secret, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/webhooks/test_default.py b/tests/api_resources/webhooks/test_default.py
deleted file mode 100644
index 7b91c21..0000000
--- a/tests/api_resources/webhooks/test_default.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-import os
-from typing import Any, cast
-
-import pytest
-
-from replicate import ReplicateClient, AsyncReplicateClient
-from tests.utils import assert_matches_type
-from replicate.types.webhooks import DefaultRetrieveSecretResponse
-
-base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-
-
-class TestDefault:
- parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
-
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve_secret(self, client: ReplicateClient) -> None:
- default = client.webhooks.default.retrieve_secret()
- assert_matches_type(DefaultRetrieveSecretResponse, default, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve_secret(self, client: ReplicateClient) -> None:
- response = client.webhooks.default.with_raw_response.retrieve_secret()
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- default = response.parse()
- assert_matches_type(DefaultRetrieveSecretResponse, default, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve_secret(self, client: ReplicateClient) -> None:
- with client.webhooks.default.with_streaming_response.retrieve_secret() as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- default = response.parse()
- assert_matches_type(DefaultRetrieveSecretResponse, default, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
-
-class TestAsyncDefault:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve_secret(self, async_client: AsyncReplicateClient) -> None:
- default = await async_client.webhooks.default.retrieve_secret()
- assert_matches_type(DefaultRetrieveSecretResponse, default, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve_secret(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.webhooks.default.with_raw_response.retrieve_secret()
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- default = await response.parse()
- assert_matches_type(DefaultRetrieveSecretResponse, default, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve_secret(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.webhooks.default.with_streaming_response.retrieve_secret() as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- default = await response.parse()
- assert_matches_type(DefaultRetrieveSecretResponse, default, path=["response"])
-
- assert cast(Any, response.is_closed) is True
From 1c25bf4568495fa4b980b06978d2ea52319f7ffc Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Thu, 24 Apr 2025 00:14:40 +0000
Subject: [PATCH 3/3] release: 0.1.0-alpha.5
---
.release-please-manifest.json | 2 +-
CHANGELOG.md | 9 +++++++++
pyproject.toml | 2 +-
src/replicate/_version.py | 2 +-
4 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index b56c3d0..e8285b7 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.1.0-alpha.4"
+ ".": "0.1.0-alpha.5"
}
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5be0765..3b53281 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,14 @@
# Changelog
+## 0.1.0-alpha.5 (2025-04-24)
+
+Full Changelog: [v0.1.0-alpha.4...v0.1.0-alpha.5](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.4...v0.1.0-alpha.5)
+
+### Features
+
+* add missing resources ([8bbddc7](https://github.com/replicate/replicate-python-stainless/commit/8bbddc7a788f4488311b8ed408d4b020db8e006b))
+* enable `openapi.code_samples` ([85810f5](https://github.com/replicate/replicate-python-stainless/commit/85810f5f4c0caf680a90fca80f1bfcd639e76894))
+
## 0.1.0-alpha.4 (2025-04-23)
Full Changelog: [v0.1.0-alpha.3...v0.1.0-alpha.4](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.3...v0.1.0-alpha.4)
diff --git a/pyproject.toml b/pyproject.toml
index 8854bbc..af88d90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "replicate-stainless"
-version = "0.1.0-alpha.4"
+version = "0.1.0-alpha.5"
description = "The official Python library for the replicate-client API"
dynamic = ["readme"]
license = "Apache-2.0"
diff --git a/src/replicate/_version.py b/src/replicate/_version.py
index e966f49..8343825 100644
--- a/src/replicate/_version.py
+++ b/src/replicate/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "replicate"
-__version__ = "0.1.0-alpha.4" # x-release-please-version
+__version__ = "0.1.0-alpha.5" # x-release-please-version