diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 4accf1c..3bd86d3 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -28,4 +28,4 @@ jobs: run: | bash ./bin/publish-pypi env: - PYPI_TOKEN: ${{ secrets.REPLICATE_CLIENT_PYPI_TOKEN || secrets.PYPI_TOKEN }} + PYPI_TOKEN: ${{ secrets.REPLICATE_PYPI_TOKEN || secrets.PYPI_TOKEN }} diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml index b7ee733..8fad131 100644 --- a/.github/workflows/release-doctor.yml +++ b/.github/workflows/release-doctor.yml @@ -18,4 +18,4 @@ jobs: run: | bash ./bin/check-release-environment env: - PYPI_TOKEN: ${{ secrets.REPLICATE_CLIENT_PYPI_TOKEN || secrets.PYPI_TOKEN }} + PYPI_TOKEN: ${{ secrets.REPLICATE_PYPI_TOKEN || secrets.PYPI_TOKEN }} diff --git a/.release-please-manifest.json b/.release-please-manifest.json index c373724..46b9b6b 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.8" + ".": "0.1.0-alpha.9" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 7e839de..66ca18c 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 30 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-0d7d82bff8a18b03e0cd1cbf8609c3026bb07db851bc6f9166032045a9925eea.yml openapi_spec_hash: 8ce211dfa6fece24b1413e91ba55210a -config_hash: 2e6a171ce57a4a6a8e8dcd3dd893d8cc +config_hash: c784c102324b1d027c6ce40e17fe9590 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fced87..2bd87fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## 0.1.0-alpha.9 (2025-05-06) + +Full Changelog: [v0.1.0-alpha.8...v0.1.0-alpha.9](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.8...v0.1.0-alpha.9) + +### Bug Fixes + +* change organization.name to replicate ([04b0797](https://github.com/replicate/replicate-python-stainless/commit/04b079729cd431cad9e992c5c0a0d82ad838f5ef)) + + +### Chores + +* use lazy imports for module level client ([14f6cfc](https://github.com/replicate/replicate-python-stainless/commit/14f6cfcad3045d1bde023d1896b369057d3f6b77)) +* use lazy imports for resources ([b2a0246](https://github.com/replicate/replicate-python-stainless/commit/b2a024612fc8b5a1bc7a15038cd33ab29e728b58)) + ## 0.1.0-alpha.8 (2025-04-30) Full Changelog: [v0.1.0-alpha.7...v0.1.0-alpha.8](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.7...v0.1.0-alpha.8) diff --git a/LICENSE b/LICENSE index 901b4db..633868c 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2025 Replicate Client + Copyright 2025 Replicate Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index ba3cfc7..a90d6c4 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# Replicate Client Python API library +# Replicate Python API library [![PyPI version](https://img.shields.io/pypi/v/replicate-stainless.svg)](https://pypi.org/project/replicate-stainless/) -The Replicate Client Python library provides convenient access to the Replicate Client REST API from any Python 3.8+ +The Replicate Python library provides convenient access to the Replicate REST API from any Python 3.8+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). @@ -25,9 +25,9 @@ The full API of this library can be found in [api.md](api.md). ```python import os -from replicate import ReplicateClient +from replicate import Replicate -client = ReplicateClient( +client = Replicate( bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted ) @@ -42,14 +42,14 @@ so that your Bearer Token is not stored in source control. ## Async usage -Simply import `AsyncReplicateClient` instead of `ReplicateClient` and use `await` with each API call: +Simply import `AsyncReplicate` instead of `Replicate` and use `await` with each API call: ```python import os import asyncio -from replicate import AsyncReplicateClient +from replicate import AsyncReplicate -client = AsyncReplicateClient( +client = AsyncReplicate( bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted ) @@ -75,14 +75,14 @@ Typed requests and responses provide autocomplete and documentation within your ## Pagination -List methods in the Replicate Client API are paginated. +List methods in the Replicate API are paginated. This library provides auto-paginating iterators with each list response, so you do not have to request successive pages manually: ```python -from replicate import ReplicateClient +from replicate import Replicate -client = ReplicateClient() +client = Replicate() all_predictions = [] # Automatically fetches more pages as needed. @@ -96,9 +96,9 @@ Or, asynchronously: ```python import asyncio -from replicate import AsyncReplicateClient +from replicate import AsyncReplicate -client = AsyncReplicateClient() +client = AsyncReplicate() async def main() -> None: @@ -147,9 +147,9 @@ All errors inherit from `replicate.APIError`. ```python import replicate -from replicate import ReplicateClient +from replicate import Replicate -client = ReplicateClient() +client = Replicate() try: client.account.get() @@ -186,10 +186,10 @@ Connection errors (for example, due to a network connectivity problem), 408 Requ You can use the `max_retries` option to configure or disable retry settings: ```python -from replicate import ReplicateClient +from replicate import Replicate # Configure the default for all requests: -client = ReplicateClient( +client = Replicate( # default is 2 max_retries=0, ) @@ -204,16 +204,16 @@ By default requests time out after 1 minute. You can configure this with a `time which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object: ```python -from replicate import ReplicateClient +from replicate import Replicate # Configure the default for all requests: -client = ReplicateClient( +client = Replicate( # 20 seconds (default is 1 minute) timeout=20.0, ) # More granular control: -client = ReplicateClient( +client = Replicate( timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0), ) @@ -231,10 +231,10 @@ Note that requests that time out are [retried twice by default](#retries). We use the standard library [`logging`](https://docs.python.org/3/library/logging.html) module. -You can enable logging by setting the environment variable `REPLICATE_CLIENT_LOG` to `info`. +You can enable logging by setting the environment variable `REPLICATE_LOG` to `info`. ```shell -$ export REPLICATE_CLIENT_LOG=info +$ export REPLICATE_LOG=info ``` Or to `debug` for more verbose logging. @@ -256,9 +256,9 @@ if response.my_field is None: The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call, e.g., ```py -from replicate import ReplicateClient +from replicate import Replicate -client = ReplicateClient() +client = Replicate() response = client.account.with_raw_response.get() print(response.headers.get('X-My-Header')) @@ -330,10 +330,10 @@ You can directly override the [httpx client](https://www.python-httpx.org/api/#c ```python import httpx -from replicate import ReplicateClient, DefaultHttpxClient +from replicate import Replicate, DefaultHttpxClient -client = ReplicateClient( - # Or use the `REPLICATE_CLIENT_BASE_URL` env var +client = Replicate( + # Or use the `REPLICATE_BASE_URL` env var base_url="http://my.test.server.example.com:8083", http_client=DefaultHttpxClient( proxy="http://my.test.proxy.example.com", @@ -353,9 +353,9 @@ client.with_options(http_client=DefaultHttpxClient(...)) By default the library closes underlying HTTP connections whenever the client is [garbage collected](https://docs.python.org/3/reference/datamodel.html#object.__del__). You can manually close the client using the `.close()` method if desired, or with a context manager that closes when exiting. ```py -from replicate import ReplicateClient +from replicate import Replicate -with ReplicateClient() as client: +with Replicate() as client: # make requests here ... diff --git a/SECURITY.md b/SECURITY.md index 9b9fbb6..a7c5e80 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -16,9 +16,9 @@ before making any information public. ## Reporting Non-SDK Related Security Issues If you encounter security issues that are not directly related to SDKs but pertain to the services -or products provided by Replicate Client please follow the respective company's security reporting guidelines. +or products provided by Replicate please follow the respective company's security reporting guidelines. -### Replicate Client Terms and Policies +### Replicate Terms and Policies Please contact team@replicate.com for any questions or concerns regarding security of our services. diff --git a/bin/check-release-environment b/bin/check-release-environment index 8e755de..e6d331c 100644 --- a/bin/check-release-environment +++ b/bin/check-release-environment @@ -3,7 +3,7 @@ errors=() if [ -z "${PYPI_TOKEN}" ]; then - errors+=("The REPLICATE_CLIENT_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") + errors+=("The REPLICATE_PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") fi lenErrors=${#errors[@]} diff --git a/pyproject.toml b/pyproject.toml index be725ab..b282b63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [project] name = "replicate-stainless" -version = "0.1.0-alpha.8" -description = "The official Python library for the replicate-client API" +version = "0.1.0-alpha.9" +description = "The official Python library for the replicate API" dynamic = ["readme"] license = "Apache-2.0" authors = [ -{ name = "Replicate Client", email = "team@replicate.com" }, +{ name = "Replicate", email = "team@replicate.com" }, ] dependencies = [ "httpx>=0.23.0, <1", diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index 0815fe9..ad8bd3f 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -11,12 +11,12 @@ Client, Stream, Timeout, + Replicate, Transport, AsyncClient, AsyncStream, + AsyncReplicate, RequestOptions, - ReplicateClient, - AsyncReplicateClient, ) from ._models import BaseModel from ._version import __title__, __version__ @@ -28,12 +28,12 @@ NotFoundError, APIStatusError, RateLimitError, + ReplicateError, APITimeoutError, BadRequestError, APIConnectionError, AuthenticationError, InternalServerError, - ReplicateClientError, PermissionDeniedError, UnprocessableEntityError, APIResponseValidationError, @@ -51,7 +51,7 @@ "NotGiven", "NOT_GIVEN", "Omit", - "ReplicateClientError", + "ReplicateError", "APIError", "APIStatusError", "APITimeoutError", @@ -71,8 +71,8 @@ "AsyncClient", "Stream", "AsyncStream", - "ReplicateClient", - "AsyncReplicateClient", + "Replicate", + "AsyncReplicate", "file_from_path", "BaseModel", "DEFAULT_TIMEOUT", @@ -119,7 +119,7 @@ http_client: _httpx.Client | None = None -class _ModuleClient(ReplicateClient): +class _ModuleClient(Replicate): # Note: we have to use type: ignores here as overriding class members # with properties is technically unsafe but it is fine for our use case @@ -202,10 +202,10 @@ def _client(self, value: _httpx.Client) -> None: # type: ignore http_client = value -_client: ReplicateClient | None = None +_client: Replicate | None = None -def _load_client() -> ReplicateClient: # type: ignore[reportUnusedFunction] +def _load_client() -> Replicate: # type: ignore[reportUnusedFunction] global _client if _client is None: diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 5d2a69c..88fc0d9 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Union, Mapping from typing_extensions import Self, override import httpx @@ -20,43 +20,40 @@ RequestOptions, ) from ._utils import is_given, get_async_library +from ._compat import cached_property from ._version import __version__ -from .resources import account, hardware, trainings, collections, predictions from ._streaming import Stream as Stream, AsyncStream as AsyncStream -from ._exceptions import APIStatusError, ReplicateClientError +from ._exceptions import APIStatusError, ReplicateError from ._base_client import ( DEFAULT_MAX_RETRIES, SyncAPIClient, AsyncAPIClient, ) -from .resources.models import models -from .resources.webhooks import webhooks -from .resources.deployments import deployments + +if TYPE_CHECKING: + from .resources import models, account, hardware, webhooks, trainings, collections, deployments, predictions + from .resources.account import AccountResource, AsyncAccountResource + from .resources.hardware import HardwareResource, AsyncHardwareResource + from .resources.trainings import TrainingsResource, AsyncTrainingsResource + from .resources.collections import CollectionsResource, AsyncCollectionsResource + from .resources.predictions import PredictionsResource, AsyncPredictionsResource + from .resources.models.models import ModelsResource, AsyncModelsResource + from .resources.webhooks.webhooks import WebhooksResource, AsyncWebhooksResource + from .resources.deployments.deployments import DeploymentsResource, AsyncDeploymentsResource __all__ = [ "Timeout", "Transport", "ProxiesTypes", "RequestOptions", - "ReplicateClient", - "AsyncReplicateClient", + "Replicate", + "AsyncReplicate", "Client", "AsyncClient", ] -class ReplicateClient(SyncAPIClient): - collections: collections.CollectionsResource - deployments: deployments.DeploymentsResource - hardware: hardware.HardwareResource - account: account.AccountResource - models: models.ModelsResource - predictions: predictions.PredictionsResource - trainings: trainings.TrainingsResource - webhooks: webhooks.WebhooksResource - with_raw_response: ReplicateClientWithRawResponse - with_streaming_response: ReplicateClientWithStreamedResponse - +class Replicate(SyncAPIClient): # client options bearer_token: str @@ -83,20 +80,20 @@ def __init__( # part of our public interface in the future. _strict_response_validation: bool = False, ) -> None: - """Construct a new synchronous ReplicateClient client instance. + """Construct a new synchronous Replicate client instance. This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: bearer_token = os.environ.get("REPLICATE_API_TOKEN") if bearer_token is None: - raise ReplicateClientError( + raise ReplicateError( "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" ) self.bearer_token = bearer_token if base_url is None: - base_url = os.environ.get("REPLICATE_CLIENT_BASE_URL") + base_url = os.environ.get("REPLICATE_BASE_URL") if base_url is None: base_url = f"https://api.replicate.com/v1" @@ -111,16 +108,61 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.collections = collections.CollectionsResource(self) - self.deployments = deployments.DeploymentsResource(self) - self.hardware = hardware.HardwareResource(self) - self.account = account.AccountResource(self) - self.models = models.ModelsResource(self) - self.predictions = predictions.PredictionsResource(self) - self.trainings = trainings.TrainingsResource(self) - self.webhooks = webhooks.WebhooksResource(self) - self.with_raw_response = ReplicateClientWithRawResponse(self) - self.with_streaming_response = ReplicateClientWithStreamedResponse(self) + @cached_property + def collections(self) -> CollectionsResource: + from .resources.collections import CollectionsResource + + return CollectionsResource(self) + + @cached_property + def deployments(self) -> DeploymentsResource: + from .resources.deployments import DeploymentsResource + + return DeploymentsResource(self) + + @cached_property + def hardware(self) -> HardwareResource: + from .resources.hardware import HardwareResource + + return HardwareResource(self) + + @cached_property + def account(self) -> AccountResource: + from .resources.account import AccountResource + + return AccountResource(self) + + @cached_property + def models(self) -> ModelsResource: + from .resources.models import ModelsResource + + return ModelsResource(self) + + @cached_property + def predictions(self) -> PredictionsResource: + from .resources.predictions import PredictionsResource + + return PredictionsResource(self) + + @cached_property + def trainings(self) -> TrainingsResource: + from .resources.trainings import TrainingsResource + + return TrainingsResource(self) + + @cached_property + def webhooks(self) -> WebhooksResource: + from .resources.webhooks import WebhooksResource + + return WebhooksResource(self) + + @cached_property + def with_raw_response(self) -> ReplicateWithRawResponse: + return ReplicateWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ReplicateWithStreamedResponse: + return ReplicateWithStreamedResponse(self) @property @override @@ -227,18 +269,7 @@ def _make_status_error( return APIStatusError(err_msg, response=response, body=body) -class AsyncReplicateClient(AsyncAPIClient): - collections: collections.AsyncCollectionsResource - deployments: deployments.AsyncDeploymentsResource - hardware: hardware.AsyncHardwareResource - account: account.AsyncAccountResource - models: models.AsyncModelsResource - predictions: predictions.AsyncPredictionsResource - trainings: trainings.AsyncTrainingsResource - webhooks: webhooks.AsyncWebhooksResource - with_raw_response: AsyncReplicateClientWithRawResponse - with_streaming_response: AsyncReplicateClientWithStreamedResponse - +class AsyncReplicate(AsyncAPIClient): # client options bearer_token: str @@ -265,20 +296,20 @@ def __init__( # part of our public interface in the future. _strict_response_validation: bool = False, ) -> None: - """Construct a new async AsyncReplicateClient client instance. + """Construct a new async AsyncReplicate client instance. This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: bearer_token = os.environ.get("REPLICATE_API_TOKEN") if bearer_token is None: - raise ReplicateClientError( + raise ReplicateError( "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" ) self.bearer_token = bearer_token if base_url is None: - base_url = os.environ.get("REPLICATE_CLIENT_BASE_URL") + base_url = os.environ.get("REPLICATE_BASE_URL") if base_url is None: base_url = f"https://api.replicate.com/v1" @@ -293,16 +324,61 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.collections = collections.AsyncCollectionsResource(self) - self.deployments = deployments.AsyncDeploymentsResource(self) - self.hardware = hardware.AsyncHardwareResource(self) - self.account = account.AsyncAccountResource(self) - self.models = models.AsyncModelsResource(self) - self.predictions = predictions.AsyncPredictionsResource(self) - self.trainings = trainings.AsyncTrainingsResource(self) - self.webhooks = webhooks.AsyncWebhooksResource(self) - self.with_raw_response = AsyncReplicateClientWithRawResponse(self) - self.with_streaming_response = AsyncReplicateClientWithStreamedResponse(self) + @cached_property + def collections(self) -> AsyncCollectionsResource: + from .resources.collections import AsyncCollectionsResource + + return AsyncCollectionsResource(self) + + @cached_property + def deployments(self) -> AsyncDeploymentsResource: + from .resources.deployments import AsyncDeploymentsResource + + return AsyncDeploymentsResource(self) + + @cached_property + def hardware(self) -> AsyncHardwareResource: + from .resources.hardware import AsyncHardwareResource + + return AsyncHardwareResource(self) + + @cached_property + def account(self) -> AsyncAccountResource: + from .resources.account import AsyncAccountResource + + return AsyncAccountResource(self) + + @cached_property + def models(self) -> AsyncModelsResource: + from .resources.models import AsyncModelsResource + + return AsyncModelsResource(self) + + @cached_property + def predictions(self) -> AsyncPredictionsResource: + from .resources.predictions import AsyncPredictionsResource + + return AsyncPredictionsResource(self) + + @cached_property + def trainings(self) -> AsyncTrainingsResource: + from .resources.trainings import AsyncTrainingsResource + + return AsyncTrainingsResource(self) + + @cached_property + def webhooks(self) -> AsyncWebhooksResource: + from .resources.webhooks import AsyncWebhooksResource + + return AsyncWebhooksResource(self) + + @cached_property + def with_raw_response(self) -> AsyncReplicateWithRawResponse: + return AsyncReplicateWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncReplicateWithStreamedResponse: + return AsyncReplicateWithStreamedResponse(self) @property @override @@ -409,54 +485,226 @@ def _make_status_error( return APIStatusError(err_msg, response=response, body=body) -class ReplicateClientWithRawResponse: - def __init__(self, client: ReplicateClient) -> None: - self.collections = collections.CollectionsResourceWithRawResponse(client.collections) - self.deployments = deployments.DeploymentsResourceWithRawResponse(client.deployments) - self.hardware = hardware.HardwareResourceWithRawResponse(client.hardware) - self.account = account.AccountResourceWithRawResponse(client.account) - self.models = models.ModelsResourceWithRawResponse(client.models) - self.predictions = predictions.PredictionsResourceWithRawResponse(client.predictions) - self.trainings = trainings.TrainingsResourceWithRawResponse(client.trainings) - self.webhooks = webhooks.WebhooksResourceWithRawResponse(client.webhooks) - - -class AsyncReplicateClientWithRawResponse: - def __init__(self, client: AsyncReplicateClient) -> None: - self.collections = collections.AsyncCollectionsResourceWithRawResponse(client.collections) - self.deployments = deployments.AsyncDeploymentsResourceWithRawResponse(client.deployments) - self.hardware = hardware.AsyncHardwareResourceWithRawResponse(client.hardware) - self.account = account.AsyncAccountResourceWithRawResponse(client.account) - self.models = models.AsyncModelsResourceWithRawResponse(client.models) - self.predictions = predictions.AsyncPredictionsResourceWithRawResponse(client.predictions) - self.trainings = trainings.AsyncTrainingsResourceWithRawResponse(client.trainings) - self.webhooks = webhooks.AsyncWebhooksResourceWithRawResponse(client.webhooks) - - -class ReplicateClientWithStreamedResponse: - def __init__(self, client: ReplicateClient) -> None: - self.collections = collections.CollectionsResourceWithStreamingResponse(client.collections) - self.deployments = deployments.DeploymentsResourceWithStreamingResponse(client.deployments) - self.hardware = hardware.HardwareResourceWithStreamingResponse(client.hardware) - self.account = account.AccountResourceWithStreamingResponse(client.account) - self.models = models.ModelsResourceWithStreamingResponse(client.models) - self.predictions = predictions.PredictionsResourceWithStreamingResponse(client.predictions) - self.trainings = trainings.TrainingsResourceWithStreamingResponse(client.trainings) - self.webhooks = webhooks.WebhooksResourceWithStreamingResponse(client.webhooks) - - -class AsyncReplicateClientWithStreamedResponse: - def __init__(self, client: AsyncReplicateClient) -> None: - self.collections = collections.AsyncCollectionsResourceWithStreamingResponse(client.collections) - self.deployments = deployments.AsyncDeploymentsResourceWithStreamingResponse(client.deployments) - self.hardware = hardware.AsyncHardwareResourceWithStreamingResponse(client.hardware) - self.account = account.AsyncAccountResourceWithStreamingResponse(client.account) - self.models = models.AsyncModelsResourceWithStreamingResponse(client.models) - self.predictions = predictions.AsyncPredictionsResourceWithStreamingResponse(client.predictions) - self.trainings = trainings.AsyncTrainingsResourceWithStreamingResponse(client.trainings) - self.webhooks = webhooks.AsyncWebhooksResourceWithStreamingResponse(client.webhooks) - - -Client = ReplicateClient - -AsyncClient = AsyncReplicateClient +class ReplicateWithRawResponse: + _client: Replicate + + def __init__(self, client: Replicate) -> None: + self._client = client + + @cached_property + def collections(self) -> collections.CollectionsResourceWithRawResponse: + from .resources.collections import CollectionsResourceWithRawResponse + + return CollectionsResourceWithRawResponse(self._client.collections) + + @cached_property + def deployments(self) -> deployments.DeploymentsResourceWithRawResponse: + from .resources.deployments import DeploymentsResourceWithRawResponse + + return DeploymentsResourceWithRawResponse(self._client.deployments) + + @cached_property + def hardware(self) -> hardware.HardwareResourceWithRawResponse: + from .resources.hardware import HardwareResourceWithRawResponse + + return HardwareResourceWithRawResponse(self._client.hardware) + + @cached_property + def account(self) -> account.AccountResourceWithRawResponse: + from .resources.account import AccountResourceWithRawResponse + + return AccountResourceWithRawResponse(self._client.account) + + @cached_property + def models(self) -> models.ModelsResourceWithRawResponse: + from .resources.models import ModelsResourceWithRawResponse + + return ModelsResourceWithRawResponse(self._client.models) + + @cached_property + def predictions(self) -> predictions.PredictionsResourceWithRawResponse: + from .resources.predictions import PredictionsResourceWithRawResponse + + return PredictionsResourceWithRawResponse(self._client.predictions) + + @cached_property + def trainings(self) -> trainings.TrainingsResourceWithRawResponse: + from .resources.trainings import TrainingsResourceWithRawResponse + + return TrainingsResourceWithRawResponse(self._client.trainings) + + @cached_property + def webhooks(self) -> webhooks.WebhooksResourceWithRawResponse: + from .resources.webhooks import WebhooksResourceWithRawResponse + + return WebhooksResourceWithRawResponse(self._client.webhooks) + + +class AsyncReplicateWithRawResponse: + _client: AsyncReplicate + + def __init__(self, client: AsyncReplicate) -> None: + self._client = client + + @cached_property + def collections(self) -> collections.AsyncCollectionsResourceWithRawResponse: + from .resources.collections import AsyncCollectionsResourceWithRawResponse + + return AsyncCollectionsResourceWithRawResponse(self._client.collections) + + @cached_property + def deployments(self) -> deployments.AsyncDeploymentsResourceWithRawResponse: + from .resources.deployments import AsyncDeploymentsResourceWithRawResponse + + return AsyncDeploymentsResourceWithRawResponse(self._client.deployments) + + @cached_property + def hardware(self) -> hardware.AsyncHardwareResourceWithRawResponse: + from .resources.hardware import AsyncHardwareResourceWithRawResponse + + return AsyncHardwareResourceWithRawResponse(self._client.hardware) + + @cached_property + def account(self) -> account.AsyncAccountResourceWithRawResponse: + from .resources.account import AsyncAccountResourceWithRawResponse + + return AsyncAccountResourceWithRawResponse(self._client.account) + + @cached_property + def models(self) -> models.AsyncModelsResourceWithRawResponse: + from .resources.models import AsyncModelsResourceWithRawResponse + + return AsyncModelsResourceWithRawResponse(self._client.models) + + @cached_property + def predictions(self) -> predictions.AsyncPredictionsResourceWithRawResponse: + from .resources.predictions import AsyncPredictionsResourceWithRawResponse + + return AsyncPredictionsResourceWithRawResponse(self._client.predictions) + + @cached_property + def trainings(self) -> trainings.AsyncTrainingsResourceWithRawResponse: + from .resources.trainings import AsyncTrainingsResourceWithRawResponse + + return AsyncTrainingsResourceWithRawResponse(self._client.trainings) + + @cached_property + def webhooks(self) -> webhooks.AsyncWebhooksResourceWithRawResponse: + from .resources.webhooks import AsyncWebhooksResourceWithRawResponse + + return AsyncWebhooksResourceWithRawResponse(self._client.webhooks) + + +class ReplicateWithStreamedResponse: + _client: Replicate + + def __init__(self, client: Replicate) -> None: + self._client = client + + @cached_property + def collections(self) -> collections.CollectionsResourceWithStreamingResponse: + from .resources.collections import CollectionsResourceWithStreamingResponse + + return CollectionsResourceWithStreamingResponse(self._client.collections) + + @cached_property + def deployments(self) -> deployments.DeploymentsResourceWithStreamingResponse: + from .resources.deployments import DeploymentsResourceWithStreamingResponse + + return DeploymentsResourceWithStreamingResponse(self._client.deployments) + + @cached_property + def hardware(self) -> hardware.HardwareResourceWithStreamingResponse: + from .resources.hardware import HardwareResourceWithStreamingResponse + + return HardwareResourceWithStreamingResponse(self._client.hardware) + + @cached_property + def account(self) -> account.AccountResourceWithStreamingResponse: + from .resources.account import AccountResourceWithStreamingResponse + + return AccountResourceWithStreamingResponse(self._client.account) + + @cached_property + def models(self) -> models.ModelsResourceWithStreamingResponse: + from .resources.models import ModelsResourceWithStreamingResponse + + return ModelsResourceWithStreamingResponse(self._client.models) + + @cached_property + def predictions(self) -> predictions.PredictionsResourceWithStreamingResponse: + from .resources.predictions import PredictionsResourceWithStreamingResponse + + return PredictionsResourceWithStreamingResponse(self._client.predictions) + + @cached_property + def trainings(self) -> trainings.TrainingsResourceWithStreamingResponse: + from .resources.trainings import TrainingsResourceWithStreamingResponse + + return TrainingsResourceWithStreamingResponse(self._client.trainings) + + @cached_property + def webhooks(self) -> webhooks.WebhooksResourceWithStreamingResponse: + from .resources.webhooks import WebhooksResourceWithStreamingResponse + + return WebhooksResourceWithStreamingResponse(self._client.webhooks) + + +class AsyncReplicateWithStreamedResponse: + _client: AsyncReplicate + + def __init__(self, client: AsyncReplicate) -> None: + self._client = client + + @cached_property + def collections(self) -> collections.AsyncCollectionsResourceWithStreamingResponse: + from .resources.collections import AsyncCollectionsResourceWithStreamingResponse + + return AsyncCollectionsResourceWithStreamingResponse(self._client.collections) + + @cached_property + def deployments(self) -> deployments.AsyncDeploymentsResourceWithStreamingResponse: + from .resources.deployments import AsyncDeploymentsResourceWithStreamingResponse + + return AsyncDeploymentsResourceWithStreamingResponse(self._client.deployments) + + @cached_property + def hardware(self) -> hardware.AsyncHardwareResourceWithStreamingResponse: + from .resources.hardware import AsyncHardwareResourceWithStreamingResponse + + return AsyncHardwareResourceWithStreamingResponse(self._client.hardware) + + @cached_property + def account(self) -> account.AsyncAccountResourceWithStreamingResponse: + from .resources.account import AsyncAccountResourceWithStreamingResponse + + return AsyncAccountResourceWithStreamingResponse(self._client.account) + + @cached_property + def models(self) -> models.AsyncModelsResourceWithStreamingResponse: + from .resources.models import AsyncModelsResourceWithStreamingResponse + + return AsyncModelsResourceWithStreamingResponse(self._client.models) + + @cached_property + def predictions(self) -> predictions.AsyncPredictionsResourceWithStreamingResponse: + from .resources.predictions import AsyncPredictionsResourceWithStreamingResponse + + return AsyncPredictionsResourceWithStreamingResponse(self._client.predictions) + + @cached_property + def trainings(self) -> trainings.AsyncTrainingsResourceWithStreamingResponse: + from .resources.trainings import AsyncTrainingsResourceWithStreamingResponse + + return AsyncTrainingsResourceWithStreamingResponse(self._client.trainings) + + @cached_property + def webhooks(self) -> webhooks.AsyncWebhooksResourceWithStreamingResponse: + from .resources.webhooks import AsyncWebhooksResourceWithStreamingResponse + + return AsyncWebhooksResourceWithStreamingResponse(self._client.webhooks) + + +Client = Replicate + +AsyncClient = AsyncReplicate diff --git a/src/replicate/_exceptions.py b/src/replicate/_exceptions.py index e3bfb90..9fbb505 100644 --- a/src/replicate/_exceptions.py +++ b/src/replicate/_exceptions.py @@ -18,11 +18,11 @@ ] -class ReplicateClientError(Exception): +class ReplicateError(Exception): pass -class APIError(ReplicateClientError): +class APIError(ReplicateError): message: str request: httpx.Request diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py index 7733870..c80ea25 100644 --- a/src/replicate/_module_client.py +++ b/src/replicate/_module_client.py @@ -1,64 +1,77 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. +from __future__ import annotations + +from typing import TYPE_CHECKING from typing_extensions import override -from . import resources, _load_client +if TYPE_CHECKING: + from .resources.account import AccountResource + from .resources.hardware import HardwareResource + from .resources.trainings import TrainingsResource + from .resources.collections import CollectionsResource + from .resources.predictions import PredictionsResource + from .resources.models.models import ModelsResource + from .resources.webhooks.webhooks import WebhooksResource + from .resources.deployments.deployments import DeploymentsResource + +from . import _load_client from ._utils import LazyProxy -class ModelsResourceProxy(LazyProxy[resources.ModelsResource]): +class ModelsResourceProxy(LazyProxy["ModelsResource"]): @override - def __load__(self) -> resources.ModelsResource: + def __load__(self) -> ModelsResource: return _load_client().models -class AccountResourceProxy(LazyProxy[resources.AccountResource]): +class AccountResourceProxy(LazyProxy["AccountResource"]): @override - def __load__(self) -> resources.AccountResource: + def __load__(self) -> AccountResource: return _load_client().account -class HardwareResourceProxy(LazyProxy[resources.HardwareResource]): +class HardwareResourceProxy(LazyProxy["HardwareResource"]): @override - def __load__(self) -> resources.HardwareResource: + def __load__(self) -> HardwareResource: return _load_client().hardware -class WebhooksResourceProxy(LazyProxy[resources.WebhooksResource]): +class WebhooksResourceProxy(LazyProxy["WebhooksResource"]): @override - def __load__(self) -> resources.WebhooksResource: + def __load__(self) -> WebhooksResource: return _load_client().webhooks -class TrainingsResourceProxy(LazyProxy[resources.TrainingsResource]): +class TrainingsResourceProxy(LazyProxy["TrainingsResource"]): @override - def __load__(self) -> resources.TrainingsResource: + def __load__(self) -> TrainingsResource: return _load_client().trainings -class CollectionsResourceProxy(LazyProxy[resources.CollectionsResource]): +class CollectionsResourceProxy(LazyProxy["CollectionsResource"]): @override - def __load__(self) -> resources.CollectionsResource: + def __load__(self) -> CollectionsResource: return _load_client().collections -class DeploymentsResourceProxy(LazyProxy[resources.DeploymentsResource]): +class DeploymentsResourceProxy(LazyProxy["DeploymentsResource"]): @override - def __load__(self) -> resources.DeploymentsResource: + def __load__(self) -> DeploymentsResource: return _load_client().deployments -class PredictionsResourceProxy(LazyProxy[resources.PredictionsResource]): +class PredictionsResourceProxy(LazyProxy["PredictionsResource"]): @override - def __load__(self) -> resources.PredictionsResource: + def __load__(self) -> PredictionsResource: return _load_client().predictions -models: resources.ModelsResource = ModelsResourceProxy().__as_proxied__() -account: resources.AccountResource = AccountResourceProxy().__as_proxied__() -hardware: resources.HardwareResource = HardwareResourceProxy().__as_proxied__() -webhooks: resources.WebhooksResource = WebhooksResourceProxy().__as_proxied__() -trainings: resources.TrainingsResource = TrainingsResourceProxy().__as_proxied__() -collections: resources.CollectionsResource = CollectionsResourceProxy().__as_proxied__() -deployments: resources.DeploymentsResource = DeploymentsResourceProxy().__as_proxied__() -predictions: resources.PredictionsResource = PredictionsResourceProxy().__as_proxied__() +models: ModelsResource = ModelsResourceProxy().__as_proxied__() +account: AccountResource = AccountResourceProxy().__as_proxied__() +hardware: HardwareResource = HardwareResourceProxy().__as_proxied__() +webhooks: WebhooksResource = WebhooksResourceProxy().__as_proxied__() +trainings: TrainingsResource = TrainingsResourceProxy().__as_proxied__() +collections: CollectionsResource = CollectionsResourceProxy().__as_proxied__() +deployments: DeploymentsResource = DeploymentsResourceProxy().__as_proxied__() +predictions: PredictionsResource = PredictionsResourceProxy().__as_proxied__() diff --git a/src/replicate/_resource.py b/src/replicate/_resource.py index 81eab3d..b05edea 100644 --- a/src/replicate/_resource.py +++ b/src/replicate/_resource.py @@ -8,13 +8,13 @@ import anyio if TYPE_CHECKING: - from ._client import ReplicateClient, AsyncReplicateClient + from ._client import Replicate, AsyncReplicate class SyncAPIResource: - _client: ReplicateClient + _client: Replicate - def __init__(self, client: ReplicateClient) -> None: + def __init__(self, client: Replicate) -> None: self._client = client self._get = client.get self._post = client.post @@ -29,9 +29,9 @@ def _sleep(self, seconds: float) -> None: class AsyncAPIResource: - _client: AsyncReplicateClient + _client: AsyncReplicate - def __init__(self, client: AsyncReplicateClient) -> None: + def __init__(self, client: AsyncReplicate) -> None: self._client = client self._get = client.get self._post = client.post diff --git a/src/replicate/_response.py b/src/replicate/_response.py index eb2ea3a..2783c74 100644 --- a/src/replicate/_response.py +++ b/src/replicate/_response.py @@ -29,7 +29,7 @@ from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type -from ._exceptions import ReplicateClientError, APIResponseValidationError +from ._exceptions import ReplicateError, APIResponseValidationError if TYPE_CHECKING: from ._models import FinalRequestOptions @@ -560,7 +560,7 @@ def __init__(self) -> None: ) -class StreamAlreadyConsumed(ReplicateClientError): +class StreamAlreadyConsumed(ReplicateError): """ Attempted to read or stream content, but the content has already been streamed. diff --git a/src/replicate/_streaming.py b/src/replicate/_streaming.py index 44556e3..e8206d8 100644 --- a/src/replicate/_streaming.py +++ b/src/replicate/_streaming.py @@ -12,7 +12,7 @@ from ._utils import extract_type_var_from_base if TYPE_CHECKING: - from ._client import ReplicateClient, AsyncReplicateClient + from ._client import Replicate, AsyncReplicate _T = TypeVar("_T") @@ -30,7 +30,7 @@ def __init__( *, cast_to: type[_T], response: httpx.Response, - client: ReplicateClient, + client: Replicate, ) -> None: self.response = response self._cast_to = cast_to @@ -93,7 +93,7 @@ def __init__( *, cast_to: type[_T], response: httpx.Response, - client: AsyncReplicateClient, + client: AsyncReplicate, ) -> None: self.response = response self._cast_to = cast_to diff --git a/src/replicate/_utils/_logs.py b/src/replicate/_utils/_logs.py index f1bd4d2..83d237c 100644 --- a/src/replicate/_utils/_logs.py +++ b/src/replicate/_utils/_logs.py @@ -14,7 +14,7 @@ def _basic_config() -> None: def setup_logging() -> None: - env = os.environ.get("REPLICATE_CLIENT_LOG") + env = os.environ.get("REPLICATE_LOG") if env == "debug": _basic_config() logger.setLevel(logging.DEBUG) diff --git a/src/replicate/_version.py b/src/replicate/_version.py index 514d196..5cd526a 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "0.1.0-alpha.8" # x-release-please-version +__version__ = "0.1.0-alpha.9" # x-release-please-version diff --git a/tests/api_resources/deployments/test_predictions.py b/tests/api_resources/deployments/test_predictions.py index 032765b..33983d4 100644 --- a/tests/api_resources/deployments/test_predictions.py +++ b/tests/api_resources/deployments/test_predictions.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import Prediction @@ -19,7 +19,7 @@ class TestPredictions: @pytest.mark.skip() @parametrize - def test_method_create(self, client: ReplicateClient) -> None: + def test_method_create(self, client: Replicate) -> None: prediction = client.deployments.predictions.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -29,7 +29,7 @@ def test_method_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_create_with_all_params(self, client: ReplicateClient) -> None: + def test_method_create_with_all_params(self, client: Replicate) -> None: prediction = client.deployments.predictions.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -43,7 +43,7 @@ def test_method_create_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: ReplicateClient) -> None: + def test_raw_response_create(self, client: Replicate) -> None: response = client.deployments.predictions.with_raw_response.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -57,7 +57,7 @@ def test_raw_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: ReplicateClient) -> None: + def test_streaming_response_create(self, client: Replicate) -> None: with client.deployments.predictions.with_streaming_response.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -73,7 +73,7 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_create(self, client: ReplicateClient) -> None: + def test_path_params_create(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): client.deployments.predictions.with_raw_response.create( deployment_name="deployment_name", @@ -94,7 +94,7 @@ class TestAsyncPredictions: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create(self, async_client: AsyncReplicate) -> None: prediction = await async_client.deployments.predictions.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -104,7 +104,7 @@ async def test_method_create(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create_with_all_params(self, async_client: AsyncReplicate) -> None: prediction = await async_client.deployments.predictions.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -118,7 +118,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncReplicateC @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_create(self, async_client: AsyncReplicate) -> None: response = await async_client.deployments.predictions.with_raw_response.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -132,7 +132,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_create(self, async_client: AsyncReplicate) -> None: async with async_client.deployments.predictions.with_streaming_response.create( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -148,7 +148,7 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_create(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_create(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): await async_client.deployments.predictions.with_raw_response.create( deployment_name="deployment_name", diff --git a/tests/api_resources/models/test_examples.py b/tests/api_resources/models/test_examples.py index c3d655e..45a1f6a 100644 --- a/tests/api_resources/models/test_examples.py +++ b/tests/api_resources/models/test_examples.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,7 +17,7 @@ class TestExamples: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: example = client.models.examples.list( model_name="model_name", model_owner="model_owner", @@ -26,7 +26,7 @@ def test_method_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.models.examples.with_raw_response.list( model_name="model_name", model_owner="model_owner", @@ -39,7 +39,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.models.examples.with_streaming_response.list( model_name="model_name", model_owner="model_owner", @@ -54,7 +54,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_list(self, client: ReplicateClient) -> None: + def test_path_params_list(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.examples.with_raw_response.list( model_name="model_name", @@ -73,7 +73,7 @@ class TestAsyncExamples: @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: example = await async_client.models.examples.list( model_name="model_name", model_owner="model_owner", @@ -82,7 +82,7 @@ async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.models.examples.with_raw_response.list( model_name="model_name", model_owner="model_owner", @@ -95,7 +95,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.models.examples.with_streaming_response.list( model_name="model_name", model_owner="model_owner", @@ -110,7 +110,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_list(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_list(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.examples.with_raw_response.list( model_name="model_name", diff --git a/tests/api_resources/models/test_predictions.py b/tests/api_resources/models/test_predictions.py index d4f1974..0bb48a6 100644 --- a/tests/api_resources/models/test_predictions.py +++ b/tests/api_resources/models/test_predictions.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import Prediction @@ -19,7 +19,7 @@ class TestPredictions: @pytest.mark.skip() @parametrize - def test_method_create(self, client: ReplicateClient) -> None: + def test_method_create(self, client: Replicate) -> None: prediction = client.models.predictions.create( model_name="model_name", model_owner="model_owner", @@ -29,7 +29,7 @@ def test_method_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_create_with_all_params(self, client: ReplicateClient) -> None: + def test_method_create_with_all_params(self, client: Replicate) -> None: prediction = client.models.predictions.create( model_name="model_name", model_owner="model_owner", @@ -43,7 +43,7 @@ def test_method_create_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: ReplicateClient) -> None: + def test_raw_response_create(self, client: Replicate) -> None: response = client.models.predictions.with_raw_response.create( model_name="model_name", model_owner="model_owner", @@ -57,7 +57,7 @@ def test_raw_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: ReplicateClient) -> None: + def test_streaming_response_create(self, client: Replicate) -> None: with client.models.predictions.with_streaming_response.create( model_name="model_name", model_owner="model_owner", @@ -73,7 +73,7 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_create(self, client: ReplicateClient) -> None: + def test_path_params_create(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.predictions.with_raw_response.create( model_name="model_name", @@ -94,7 +94,7 @@ class TestAsyncPredictions: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create(self, async_client: AsyncReplicate) -> None: prediction = await async_client.models.predictions.create( model_name="model_name", model_owner="model_owner", @@ -104,7 +104,7 @@ async def test_method_create(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create_with_all_params(self, async_client: AsyncReplicate) -> None: prediction = await async_client.models.predictions.create( model_name="model_name", model_owner="model_owner", @@ -118,7 +118,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncReplicateC @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_create(self, async_client: AsyncReplicate) -> None: response = await async_client.models.predictions.with_raw_response.create( model_name="model_name", model_owner="model_owner", @@ -132,7 +132,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_create(self, async_client: AsyncReplicate) -> None: async with async_client.models.predictions.with_streaming_response.create( model_name="model_name", model_owner="model_owner", @@ -148,7 +148,7 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_create(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_create(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.predictions.with_raw_response.create( model_name="model_name", diff --git a/tests/api_resources/models/test_readme.py b/tests/api_resources/models/test_readme.py index 048334b..698affd 100644 --- a/tests/api_resources/models/test_readme.py +++ b/tests/api_resources/models/test_readme.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -18,7 +18,7 @@ class TestReadme: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: readme = client.models.readme.get( model_name="model_name", model_owner="model_owner", @@ -27,7 +27,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.models.readme.with_raw_response.get( model_name="model_name", model_owner="model_owner", @@ -40,7 +40,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.models.readme.with_streaming_response.get( model_name="model_name", model_owner="model_owner", @@ -55,7 +55,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.readme.with_raw_response.get( model_name="model_name", @@ -74,7 +74,7 @@ class TestAsyncReadme: @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: readme = await async_client.models.readme.get( model_name="model_name", model_owner="model_owner", @@ -83,7 +83,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.models.readme.with_raw_response.get( model_name="model_name", model_owner="model_owner", @@ -96,7 +96,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.models.readme.with_streaming_response.get( model_name="model_name", model_owner="model_owner", @@ -111,7 +111,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.readme.with_raw_response.get( model_name="model_name", diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py index af132a2..6aac669 100644 --- a/tests/api_resources/models/test_versions.py +++ b/tests/api_resources/models/test_versions.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,7 +17,7 @@ class TestVersions: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: version = client.models.versions.list( model_name="model_name", model_owner="model_owner", @@ -26,7 +26,7 @@ def test_method_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.models.versions.with_raw_response.list( model_name="model_name", model_owner="model_owner", @@ -39,7 +39,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.models.versions.with_streaming_response.list( model_name="model_name", model_owner="model_owner", @@ -54,7 +54,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_list(self, client: ReplicateClient) -> None: + def test_path_params_list(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.versions.with_raw_response.list( model_name="model_name", @@ -69,7 +69,7 @@ def test_path_params_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_delete(self, client: ReplicateClient) -> None: + def test_method_delete(self, client: Replicate) -> None: version = client.models.versions.delete( version_id="version_id", model_owner="model_owner", @@ -79,7 +79,7 @@ def test_method_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_delete(self, client: ReplicateClient) -> None: + def test_raw_response_delete(self, client: Replicate) -> None: response = client.models.versions.with_raw_response.delete( version_id="version_id", model_owner="model_owner", @@ -93,7 +93,7 @@ def test_raw_response_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_delete(self, client: ReplicateClient) -> None: + def test_streaming_response_delete(self, client: Replicate) -> None: with client.models.versions.with_streaming_response.delete( version_id="version_id", model_owner="model_owner", @@ -109,7 +109,7 @@ def test_streaming_response_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_delete(self, client: ReplicateClient) -> None: + def test_path_params_delete(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.versions.with_raw_response.delete( version_id="version_id", @@ -133,7 +133,7 @@ def test_path_params_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: version = client.models.versions.get( version_id="version_id", model_owner="model_owner", @@ -143,7 +143,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.models.versions.with_raw_response.get( version_id="version_id", model_owner="model_owner", @@ -157,7 +157,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.models.versions.with_streaming_response.get( version_id="version_id", model_owner="model_owner", @@ -173,7 +173,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.versions.with_raw_response.get( version_id="version_id", @@ -201,7 +201,7 @@ class TestAsyncVersions: @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: version = await async_client.models.versions.list( model_name="model_name", model_owner="model_owner", @@ -210,7 +210,7 @@ async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.models.versions.with_raw_response.list( model_name="model_name", model_owner="model_owner", @@ -223,7 +223,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.models.versions.with_streaming_response.list( model_name="model_name", model_owner="model_owner", @@ -238,7 +238,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_list(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_list(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.versions.with_raw_response.list( model_name="model_name", @@ -253,7 +253,7 @@ async def test_path_params_list(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_method_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_method_delete(self, async_client: AsyncReplicate) -> None: version = await async_client.models.versions.delete( version_id="version_id", model_owner="model_owner", @@ -263,7 +263,7 @@ async def test_method_delete(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_delete(self, async_client: AsyncReplicate) -> None: response = await async_client.models.versions.with_raw_response.delete( version_id="version_id", model_owner="model_owner", @@ -277,7 +277,7 @@ async def test_raw_response_delete(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_delete(self, async_client: AsyncReplicate) -> None: async with async_client.models.versions.with_streaming_response.delete( version_id="version_id", model_owner="model_owner", @@ -293,7 +293,7 @@ async def test_streaming_response_delete(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_delete(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.versions.with_raw_response.delete( version_id="version_id", @@ -317,7 +317,7 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: version = await async_client.models.versions.get( version_id="version_id", model_owner="model_owner", @@ -327,7 +327,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.models.versions.with_raw_response.get( version_id="version_id", model_owner="model_owner", @@ -341,7 +341,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.models.versions.with_streaming_response.get( version_id="version_id", model_owner="model_owner", @@ -357,7 +357,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.versions.with_raw_response.get( version_id="version_id", diff --git a/tests/api_resources/test_account.py b/tests/api_resources/test_account.py index 1039fbb..74e7679 100644 --- a/tests/api_resources/test_account.py +++ b/tests/api_resources/test_account.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import AccountGetResponse @@ -19,13 +19,13 @@ class TestAccount: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: account = client.account.get() assert_matches_type(AccountGetResponse, account, path=["response"]) @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.account.with_raw_response.get() assert response.is_closed is True @@ -35,7 +35,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.account.with_streaming_response.get() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -51,13 +51,13 @@ class TestAsyncAccount: @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: account = await async_client.account.get() assert_matches_type(AccountGetResponse, account, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.account.with_raw_response.get() assert response.is_closed is True @@ -67,7 +67,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.account.with_streaming_response.get() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_collections.py b/tests/api_resources/test_collections.py index d231a8d..50ceede 100644 --- a/tests/api_resources/test_collections.py +++ b/tests/api_resources/test_collections.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,13 +17,13 @@ class TestCollections: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: collection = client.collections.list() assert collection is None @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.collections.with_raw_response.list() assert response.is_closed is True @@ -33,7 +33,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.collections.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -45,7 +45,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: collection = client.collections.get( "collection_slug", ) @@ -53,7 +53,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.collections.with_raw_response.get( "collection_slug", ) @@ -65,7 +65,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.collections.with_streaming_response.get( "collection_slug", ) as response: @@ -79,7 +79,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `collection_slug` but received ''"): client.collections.with_raw_response.get( "", @@ -91,13 +91,13 @@ class TestAsyncCollections: @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: collection = await async_client.collections.list() assert collection is None @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.collections.with_raw_response.list() assert response.is_closed is True @@ -107,7 +107,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.collections.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -119,7 +119,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: collection = await async_client.collections.get( "collection_slug", ) @@ -127,7 +127,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.collections.with_raw_response.get( "collection_slug", ) @@ -139,7 +139,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.collections.with_streaming_response.get( "collection_slug", ) as response: @@ -153,7 +153,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `collection_slug` but received ''"): await async_client.collections.with_raw_response.get( "", diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py index 6d6360a..14c32af 100644 --- a/tests/api_resources/test_deployments.py +++ b/tests/api_resources/test_deployments.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import ( DeploymentGetResponse, @@ -25,7 +25,7 @@ class TestDeployments: @pytest.mark.skip() @parametrize - def test_method_create(self, client: ReplicateClient) -> None: + def test_method_create(self, client: Replicate) -> None: deployment = client.deployments.create( hardware="hardware", max_instances=0, @@ -38,7 +38,7 @@ def test_method_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: ReplicateClient) -> None: + def test_raw_response_create(self, client: Replicate) -> None: response = client.deployments.with_raw_response.create( hardware="hardware", max_instances=0, @@ -55,7 +55,7 @@ def test_raw_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: ReplicateClient) -> None: + def test_streaming_response_create(self, client: Replicate) -> None: with client.deployments.with_streaming_response.create( hardware="hardware", max_instances=0, @@ -74,7 +74,7 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_update(self, client: ReplicateClient) -> None: + def test_method_update(self, client: Replicate) -> None: deployment = client.deployments.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -83,7 +83,7 @@ def test_method_update(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_update_with_all_params(self, client: ReplicateClient) -> None: + def test_method_update_with_all_params(self, client: Replicate) -> None: deployment = client.deployments.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -96,7 +96,7 @@ def test_method_update_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_update(self, client: ReplicateClient) -> None: + def test_raw_response_update(self, client: Replicate) -> None: response = client.deployments.with_raw_response.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -109,7 +109,7 @@ def test_raw_response_update(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_update(self, client: ReplicateClient) -> None: + def test_streaming_response_update(self, client: Replicate) -> None: with client.deployments.with_streaming_response.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -124,7 +124,7 @@ def test_streaming_response_update(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_update(self, client: ReplicateClient) -> None: + def test_path_params_update(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): client.deployments.with_raw_response.update( deployment_name="deployment_name", @@ -139,13 +139,13 @@ def test_path_params_update(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: deployment = client.deployments.list() assert_matches_type(SyncCursorURLPage[DeploymentListResponse], deployment, path=["response"]) @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.deployments.with_raw_response.list() assert response.is_closed is True @@ -155,7 +155,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.deployments.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -167,7 +167,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_delete(self, client: ReplicateClient) -> None: + def test_method_delete(self, client: Replicate) -> None: deployment = client.deployments.delete( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -176,7 +176,7 @@ def test_method_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_delete(self, client: ReplicateClient) -> None: + def test_raw_response_delete(self, client: Replicate) -> None: response = client.deployments.with_raw_response.delete( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -189,7 +189,7 @@ def test_raw_response_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_delete(self, client: ReplicateClient) -> None: + def test_streaming_response_delete(self, client: Replicate) -> None: with client.deployments.with_streaming_response.delete( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -204,7 +204,7 @@ def test_streaming_response_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_delete(self, client: ReplicateClient) -> None: + def test_path_params_delete(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): client.deployments.with_raw_response.delete( deployment_name="deployment_name", @@ -219,7 +219,7 @@ def test_path_params_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: deployment = client.deployments.get( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -228,7 +228,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.deployments.with_raw_response.get( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -241,7 +241,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.deployments.with_streaming_response.get( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -256,7 +256,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): client.deployments.with_raw_response.get( deployment_name="deployment_name", @@ -275,7 +275,7 @@ class TestAsyncDeployments: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create(self, async_client: AsyncReplicate) -> None: deployment = await async_client.deployments.create( hardware="hardware", max_instances=0, @@ -288,7 +288,7 @@ async def test_method_create(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_create(self, async_client: AsyncReplicate) -> None: response = await async_client.deployments.with_raw_response.create( hardware="hardware", max_instances=0, @@ -305,7 +305,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_create(self, async_client: AsyncReplicate) -> None: async with async_client.deployments.with_streaming_response.create( hardware="hardware", max_instances=0, @@ -324,7 +324,7 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_method_update(self, async_client: AsyncReplicateClient) -> None: + async def test_method_update(self, async_client: AsyncReplicate) -> None: deployment = await async_client.deployments.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -333,7 +333,7 @@ async def test_method_update(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_method_update_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_update_with_all_params(self, async_client: AsyncReplicate) -> None: deployment = await async_client.deployments.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -346,7 +346,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncReplicateC @pytest.mark.skip() @parametrize - async def test_raw_response_update(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_update(self, async_client: AsyncReplicate) -> None: response = await async_client.deployments.with_raw_response.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -359,7 +359,7 @@ async def test_raw_response_update(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_update(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_update(self, async_client: AsyncReplicate) -> None: async with async_client.deployments.with_streaming_response.update( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -374,7 +374,7 @@ async def test_streaming_response_update(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_update(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_update(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): await async_client.deployments.with_raw_response.update( deployment_name="deployment_name", @@ -389,13 +389,13 @@ async def test_path_params_update(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: deployment = await async_client.deployments.list() assert_matches_type(AsyncCursorURLPage[DeploymentListResponse], deployment, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.deployments.with_raw_response.list() assert response.is_closed is True @@ -405,7 +405,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.deployments.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -417,7 +417,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_method_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_method_delete(self, async_client: AsyncReplicate) -> None: deployment = await async_client.deployments.delete( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -426,7 +426,7 @@ async def test_method_delete(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_delete(self, async_client: AsyncReplicate) -> None: response = await async_client.deployments.with_raw_response.delete( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -439,7 +439,7 @@ async def test_raw_response_delete(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_delete(self, async_client: AsyncReplicate) -> None: async with async_client.deployments.with_streaming_response.delete( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -454,7 +454,7 @@ async def test_streaming_response_delete(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_delete(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): await async_client.deployments.with_raw_response.delete( deployment_name="deployment_name", @@ -469,7 +469,7 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: deployment = await async_client.deployments.get( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -478,7 +478,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.deployments.with_raw_response.get( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -491,7 +491,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.deployments.with_streaming_response.get( deployment_name="deployment_name", deployment_owner="deployment_owner", @@ -506,7 +506,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): await async_client.deployments.with_raw_response.get( deployment_name="deployment_name", diff --git a/tests/api_resources/test_hardware.py b/tests/api_resources/test_hardware.py index 9aa535b..4ec3027 100644 --- a/tests/api_resources/test_hardware.py +++ b/tests/api_resources/test_hardware.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import HardwareListResponse @@ -19,13 +19,13 @@ class TestHardware: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: hardware = client.hardware.list() assert_matches_type(HardwareListResponse, hardware, path=["response"]) @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.hardware.with_raw_response.list() assert response.is_closed is True @@ -35,7 +35,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.hardware.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -51,13 +51,13 @@ class TestAsyncHardware: @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: hardware = await async_client.hardware.list() assert_matches_type(HardwareListResponse, hardware, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.hardware.with_raw_response.list() assert response.is_closed is True @@ -67,7 +67,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.hardware.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index 8c6ad82..ae56127 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import ModelListResponse from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage @@ -20,7 +20,7 @@ class TestModels: @pytest.mark.skip() @parametrize - def test_method_create(self, client: ReplicateClient) -> None: + def test_method_create(self, client: Replicate) -> None: model = client.models.create( hardware="hardware", name="name", @@ -31,7 +31,7 @@ def test_method_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_create_with_all_params(self, client: ReplicateClient) -> None: + def test_method_create_with_all_params(self, client: Replicate) -> None: model = client.models.create( hardware="hardware", name="name", @@ -47,7 +47,7 @@ def test_method_create_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: ReplicateClient) -> None: + def test_raw_response_create(self, client: Replicate) -> None: response = client.models.with_raw_response.create( hardware="hardware", name="name", @@ -62,7 +62,7 @@ def test_raw_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: ReplicateClient) -> None: + def test_streaming_response_create(self, client: Replicate) -> None: with client.models.with_streaming_response.create( hardware="hardware", name="name", @@ -79,13 +79,13 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: model = client.models.list() assert_matches_type(SyncCursorURLPage[ModelListResponse], model, path=["response"]) @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.models.with_raw_response.list() assert response.is_closed is True @@ -95,7 +95,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.models.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -107,7 +107,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_delete(self, client: ReplicateClient) -> None: + def test_method_delete(self, client: Replicate) -> None: model = client.models.delete( model_name="model_name", model_owner="model_owner", @@ -116,7 +116,7 @@ def test_method_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_delete(self, client: ReplicateClient) -> None: + def test_raw_response_delete(self, client: Replicate) -> None: response = client.models.with_raw_response.delete( model_name="model_name", model_owner="model_owner", @@ -129,7 +129,7 @@ def test_raw_response_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_delete(self, client: ReplicateClient) -> None: + def test_streaming_response_delete(self, client: Replicate) -> None: with client.models.with_streaming_response.delete( model_name="model_name", model_owner="model_owner", @@ -144,7 +144,7 @@ def test_streaming_response_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_delete(self, client: ReplicateClient) -> None: + def test_path_params_delete(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.with_raw_response.delete( model_name="model_name", @@ -159,7 +159,7 @@ def test_path_params_delete(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: model = client.models.get( model_name="model_name", model_owner="model_owner", @@ -168,7 +168,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.models.with_raw_response.get( model_name="model_name", model_owner="model_owner", @@ -181,7 +181,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.models.with_streaming_response.get( model_name="model_name", model_owner="model_owner", @@ -196,7 +196,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.models.with_raw_response.get( model_name="model_name", @@ -211,7 +211,7 @@ def test_path_params_get(self, client: ReplicateClient) -> None: @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize - def test_method_search(self, client: ReplicateClient) -> None: + def test_method_search(self, client: Replicate) -> None: model = client.models.search( body="body", ) @@ -219,7 +219,7 @@ def test_method_search(self, client: ReplicateClient) -> None: @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize - def test_raw_response_search(self, client: ReplicateClient) -> None: + def test_raw_response_search(self, client: Replicate) -> None: response = client.models.with_raw_response.search( body="body", ) @@ -231,7 +231,7 @@ def test_raw_response_search(self, client: ReplicateClient) -> None: @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize - def test_streaming_response_search(self, client: ReplicateClient) -> None: + def test_streaming_response_search(self, client: Replicate) -> None: with client.models.with_streaming_response.search( body="body", ) as response: @@ -249,7 +249,7 @@ class TestAsyncModels: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create(self, async_client: AsyncReplicate) -> None: model = await async_client.models.create( hardware="hardware", name="name", @@ -260,7 +260,7 @@ async def test_method_create(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create_with_all_params(self, async_client: AsyncReplicate) -> None: model = await async_client.models.create( hardware="hardware", name="name", @@ -276,7 +276,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncReplicateC @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_create(self, async_client: AsyncReplicate) -> None: response = await async_client.models.with_raw_response.create( hardware="hardware", name="name", @@ -291,7 +291,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_create(self, async_client: AsyncReplicate) -> None: async with async_client.models.with_streaming_response.create( hardware="hardware", name="name", @@ -308,13 +308,13 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: model = await async_client.models.list() assert_matches_type(AsyncCursorURLPage[ModelListResponse], model, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.models.with_raw_response.list() assert response.is_closed is True @@ -324,7 +324,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.models.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -336,7 +336,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_method_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_method_delete(self, async_client: AsyncReplicate) -> None: model = await async_client.models.delete( model_name="model_name", model_owner="model_owner", @@ -345,7 +345,7 @@ async def test_method_delete(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_delete(self, async_client: AsyncReplicate) -> None: response = await async_client.models.with_raw_response.delete( model_name="model_name", model_owner="model_owner", @@ -358,7 +358,7 @@ async def test_raw_response_delete(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_delete(self, async_client: AsyncReplicate) -> None: async with async_client.models.with_streaming_response.delete( model_name="model_name", model_owner="model_owner", @@ -373,7 +373,7 @@ async def test_streaming_response_delete(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_delete(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.with_raw_response.delete( model_name="model_name", @@ -388,7 +388,7 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: model = await async_client.models.get( model_name="model_name", model_owner="model_owner", @@ -397,7 +397,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.models.with_raw_response.get( model_name="model_name", model_owner="model_owner", @@ -410,7 +410,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.models.with_streaming_response.get( model_name="model_name", model_owner="model_owner", @@ -425,7 +425,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.models.with_raw_response.get( model_name="model_name", @@ -440,7 +440,7 @@ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize - async def test_method_search(self, async_client: AsyncReplicateClient) -> None: + async def test_method_search(self, async_client: AsyncReplicate) -> None: model = await async_client.models.search( body="body", ) @@ -448,7 +448,7 @@ async def test_method_search(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize - async def test_raw_response_search(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_search(self, async_client: AsyncReplicate) -> None: response = await async_client.models.with_raw_response.search( body="body", ) @@ -460,7 +460,7 @@ async def test_raw_response_search(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip(reason="Prism doesn't support query methods yet") @parametrize - async def test_streaming_response_search(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_search(self, async_client: AsyncReplicate) -> None: async with async_client.models.with_streaming_response.search( body="body", ) as response: diff --git a/tests/api_resources/test_predictions.py b/tests/api_resources/test_predictions.py index 51bd80d..a43f67e 100644 --- a/tests/api_resources/test_predictions.py +++ b/tests/api_resources/test_predictions.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import Prediction from replicate._utils import parse_datetime @@ -21,7 +21,7 @@ class TestPredictions: @pytest.mark.skip() @parametrize - def test_method_create(self, client: ReplicateClient) -> None: + def test_method_create(self, client: Replicate) -> None: prediction = client.predictions.create( input={}, version="version", @@ -30,7 +30,7 @@ def test_method_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_create_with_all_params(self, client: ReplicateClient) -> None: + def test_method_create_with_all_params(self, client: Replicate) -> None: prediction = client.predictions.create( input={}, version="version", @@ -43,7 +43,7 @@ def test_method_create_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: ReplicateClient) -> None: + def test_raw_response_create(self, client: Replicate) -> None: response = client.predictions.with_raw_response.create( input={}, version="version", @@ -56,7 +56,7 @@ def test_raw_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: ReplicateClient) -> None: + def test_streaming_response_create(self, client: Replicate) -> None: with client.predictions.with_streaming_response.create( input={}, version="version", @@ -71,13 +71,13 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: prediction = client.predictions.list() assert_matches_type(SyncCursorURLPageWithCreatedFilters[Prediction], prediction, path=["response"]) @pytest.mark.skip() @parametrize - def test_method_list_with_all_params(self, client: ReplicateClient) -> None: + def test_method_list_with_all_params(self, client: Replicate) -> None: prediction = client.predictions.list( created_after=parse_datetime("2025-01-01T00:00:00Z"), created_before=parse_datetime("2025-02-01T00:00:00Z"), @@ -86,7 +86,7 @@ def test_method_list_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.predictions.with_raw_response.list() assert response.is_closed is True @@ -96,7 +96,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.predictions.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -108,7 +108,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_cancel(self, client: ReplicateClient) -> None: + def test_method_cancel(self, client: Replicate) -> None: prediction = client.predictions.cancel( "prediction_id", ) @@ -116,7 +116,7 @@ def test_method_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_cancel(self, client: ReplicateClient) -> None: + def test_raw_response_cancel(self, client: Replicate) -> None: response = client.predictions.with_raw_response.cancel( "prediction_id", ) @@ -128,7 +128,7 @@ def test_raw_response_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_cancel(self, client: ReplicateClient) -> None: + def test_streaming_response_cancel(self, client: Replicate) -> None: with client.predictions.with_streaming_response.cancel( "prediction_id", ) as response: @@ -142,7 +142,7 @@ def test_streaming_response_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_cancel(self, client: ReplicateClient) -> None: + def test_path_params_cancel(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): client.predictions.with_raw_response.cancel( "", @@ -150,7 +150,7 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: prediction = client.predictions.get( "prediction_id", ) @@ -158,7 +158,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.predictions.with_raw_response.get( "prediction_id", ) @@ -170,7 +170,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.predictions.with_streaming_response.get( "prediction_id", ) as response: @@ -184,7 +184,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): client.predictions.with_raw_response.get( "", @@ -196,7 +196,7 @@ class TestAsyncPredictions: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create(self, async_client: AsyncReplicate) -> None: prediction = await async_client.predictions.create( input={}, version="version", @@ -205,7 +205,7 @@ async def test_method_create(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create_with_all_params(self, async_client: AsyncReplicate) -> None: prediction = await async_client.predictions.create( input={}, version="version", @@ -218,7 +218,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncReplicateC @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_create(self, async_client: AsyncReplicate) -> None: response = await async_client.predictions.with_raw_response.create( input={}, version="version", @@ -231,7 +231,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_create(self, async_client: AsyncReplicate) -> None: async with async_client.predictions.with_streaming_response.create( input={}, version="version", @@ -246,13 +246,13 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: prediction = await async_client.predictions.list() assert_matches_type(AsyncCursorURLPageWithCreatedFilters[Prediction], prediction, path=["response"]) @pytest.mark.skip() @parametrize - async def test_method_list_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list_with_all_params(self, async_client: AsyncReplicate) -> None: prediction = await async_client.predictions.list( created_after=parse_datetime("2025-01-01T00:00:00Z"), created_before=parse_datetime("2025-02-01T00:00:00Z"), @@ -261,7 +261,7 @@ async def test_method_list_with_all_params(self, async_client: AsyncReplicateCli @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.predictions.with_raw_response.list() assert response.is_closed is True @@ -271,7 +271,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.predictions.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -283,7 +283,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_method_cancel(self, async_client: AsyncReplicate) -> None: prediction = await async_client.predictions.cancel( "prediction_id", ) @@ -291,7 +291,7 @@ async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_cancel(self, async_client: AsyncReplicate) -> None: response = await async_client.predictions.with_raw_response.cancel( "prediction_id", ) @@ -303,7 +303,7 @@ async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_cancel(self, async_client: AsyncReplicate) -> None: async with async_client.predictions.with_streaming_response.cancel( "prediction_id", ) as response: @@ -317,7 +317,7 @@ async def test_streaming_response_cancel(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_cancel(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): await async_client.predictions.with_raw_response.cancel( "", @@ -325,7 +325,7 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: prediction = await async_client.predictions.get( "prediction_id", ) @@ -333,7 +333,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.predictions.with_raw_response.get( "prediction_id", ) @@ -345,7 +345,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.predictions.with_streaming_response.get( "prediction_id", ) as response: @@ -359,7 +359,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): await async_client.predictions.with_raw_response.get( "", diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py index e1b0572..9e006d3 100644 --- a/tests/api_resources/test_trainings.py +++ b/tests/api_resources/test_trainings.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types import ( TrainingGetResponse, @@ -25,7 +25,7 @@ class TestTrainings: @pytest.mark.skip() @parametrize - def test_method_create(self, client: ReplicateClient) -> None: + def test_method_create(self, client: Replicate) -> None: training = client.trainings.create( version_id="version_id", model_owner="model_owner", @@ -37,7 +37,7 @@ def test_method_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_create_with_all_params(self, client: ReplicateClient) -> None: + def test_method_create_with_all_params(self, client: Replicate) -> None: training = client.trainings.create( version_id="version_id", model_owner="model_owner", @@ -51,7 +51,7 @@ def test_method_create_with_all_params(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: ReplicateClient) -> None: + def test_raw_response_create(self, client: Replicate) -> None: response = client.trainings.with_raw_response.create( version_id="version_id", model_owner="model_owner", @@ -67,7 +67,7 @@ def test_raw_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: ReplicateClient) -> None: + def test_streaming_response_create(self, client: Replicate) -> None: with client.trainings.with_streaming_response.create( version_id="version_id", model_owner="model_owner", @@ -85,7 +85,7 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_create(self, client: ReplicateClient) -> None: + def test_path_params_create(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): client.trainings.with_raw_response.create( version_id="version_id", @@ -115,13 +115,13 @@ def test_path_params_create(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_list(self, client: ReplicateClient) -> None: + def test_method_list(self, client: Replicate) -> None: training = client.trainings.list() assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize - def test_raw_response_list(self, client: ReplicateClient) -> None: + def test_raw_response_list(self, client: Replicate) -> None: response = client.trainings.with_raw_response.list() assert response.is_closed is True @@ -131,7 +131,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_list(self, client: ReplicateClient) -> None: + def test_streaming_response_list(self, client: Replicate) -> None: with client.trainings.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -143,7 +143,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_cancel(self, client: ReplicateClient) -> None: + def test_method_cancel(self, client: Replicate) -> None: training = client.trainings.cancel( "training_id", ) @@ -151,7 +151,7 @@ def test_method_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_cancel(self, client: ReplicateClient) -> None: + def test_raw_response_cancel(self, client: Replicate) -> None: response = client.trainings.with_raw_response.cancel( "training_id", ) @@ -163,7 +163,7 @@ def test_raw_response_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_cancel(self, client: ReplicateClient) -> None: + def test_streaming_response_cancel(self, client: Replicate) -> None: with client.trainings.with_streaming_response.cancel( "training_id", ) as response: @@ -177,7 +177,7 @@ def test_streaming_response_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_cancel(self, client: ReplicateClient) -> None: + def test_path_params_cancel(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): client.trainings.with_raw_response.cancel( "", @@ -185,7 +185,7 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: training = client.trainings.get( "training_id", ) @@ -193,7 +193,7 @@ def test_method_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.trainings.with_raw_response.get( "training_id", ) @@ -205,7 +205,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.trainings.with_streaming_response.get( "training_id", ) as response: @@ -219,7 +219,7 @@ def test_streaming_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_path_params_get(self, client: ReplicateClient) -> None: + def test_path_params_get(self, client: Replicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): client.trainings.with_raw_response.get( "", @@ -231,7 +231,7 @@ class TestAsyncTrainings: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create(self, async_client: AsyncReplicate) -> None: training = await async_client.trainings.create( version_id="version_id", model_owner="model_owner", @@ -243,7 +243,7 @@ async def test_method_create(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncReplicateClient) -> None: + async def test_method_create_with_all_params(self, async_client: AsyncReplicate) -> None: training = await async_client.trainings.create( version_id="version_id", model_owner="model_owner", @@ -257,7 +257,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncReplicateC @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_create(self, async_client: AsyncReplicate) -> None: response = await async_client.trainings.with_raw_response.create( version_id="version_id", model_owner="model_owner", @@ -273,7 +273,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_create(self, async_client: AsyncReplicate) -> None: async with async_client.trainings.with_streaming_response.create( version_id="version_id", model_owner="model_owner", @@ -291,7 +291,7 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_create(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_create(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): await async_client.trainings.with_raw_response.create( version_id="version_id", @@ -321,13 +321,13 @@ async def test_path_params_create(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_list(self, async_client: AsyncReplicateClient) -> None: + async def test_method_list(self, async_client: AsyncReplicate) -> None: training = await async_client.trainings.list() assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_list(self, async_client: AsyncReplicate) -> None: response = await async_client.trainings.with_raw_response.list() assert response.is_closed is True @@ -337,7 +337,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No @pytest.mark.skip() @parametrize - async def test_streaming_response_list(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_list(self, async_client: AsyncReplicate) -> None: async with async_client.trainings.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -349,7 +349,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_method_cancel(self, async_client: AsyncReplicate) -> None: training = await async_client.trainings.cancel( "training_id", ) @@ -357,7 +357,7 @@ async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_cancel(self, async_client: AsyncReplicate) -> None: response = await async_client.trainings.with_raw_response.cancel( "training_id", ) @@ -369,7 +369,7 @@ async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) -> @pytest.mark.skip() @parametrize - async def test_streaming_response_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_cancel(self, async_client: AsyncReplicate) -> None: async with async_client.trainings.with_streaming_response.cancel( "training_id", ) as response: @@ -383,7 +383,7 @@ async def test_streaming_response_cancel(self, async_client: AsyncReplicateClien @pytest.mark.skip() @parametrize - async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_cancel(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): await async_client.trainings.with_raw_response.cancel( "", @@ -391,7 +391,7 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: training = await async_client.trainings.get( "training_id", ) @@ -399,7 +399,7 @@ async def test_method_get(self, async_client: AsyncReplicateClient) -> None: @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.trainings.with_raw_response.get( "training_id", ) @@ -411,7 +411,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.trainings.with_streaming_response.get( "training_id", ) as response: @@ -425,7 +425,7 @@ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) @pytest.mark.skip() @parametrize - async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + async def test_path_params_get(self, async_client: AsyncReplicate) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): await async_client.trainings.with_raw_response.get( "", diff --git a/tests/api_resources/webhooks/default/test_secret.py b/tests/api_resources/webhooks/default/test_secret.py index c95ab8a..4ccc3d4 100644 --- a/tests/api_resources/webhooks/default/test_secret.py +++ b/tests/api_resources/webhooks/default/test_secret.py @@ -7,7 +7,7 @@ import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from tests.utils import assert_matches_type from replicate.types.webhooks.default import SecretGetResponse @@ -19,13 +19,13 @@ class TestSecret: @pytest.mark.skip() @parametrize - def test_method_get(self, client: ReplicateClient) -> None: + def test_method_get(self, client: Replicate) -> None: secret = client.webhooks.default.secret.get() assert_matches_type(SecretGetResponse, secret, path=["response"]) @pytest.mark.skip() @parametrize - def test_raw_response_get(self, client: ReplicateClient) -> None: + def test_raw_response_get(self, client: Replicate) -> None: response = client.webhooks.default.secret.with_raw_response.get() assert response.is_closed is True @@ -35,7 +35,7 @@ def test_raw_response_get(self, client: ReplicateClient) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_get(self, client: ReplicateClient) -> None: + def test_streaming_response_get(self, client: Replicate) -> None: with client.webhooks.default.secret.with_streaming_response.get() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -51,13 +51,13 @@ class TestAsyncSecret: @pytest.mark.skip() @parametrize - async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + async def test_method_get(self, async_client: AsyncReplicate) -> None: secret = await async_client.webhooks.default.secret.get() assert_matches_type(SecretGetResponse, secret, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_raw_response_get(self, async_client: AsyncReplicate) -> None: response = await async_client.webhooks.default.secret.with_raw_response.get() assert response.is_closed is True @@ -67,7 +67,7 @@ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> Non @pytest.mark.skip() @parametrize - async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async def test_streaming_response_get(self, async_client: AsyncReplicate) -> None: async with async_client.webhooks.default.secret.with_streaming_response.get() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/conftest.py b/tests/conftest.py index 7d1538b..79342b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import pytest from pytest_asyncio import is_async_test -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate if TYPE_CHECKING: from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] @@ -32,22 +32,22 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: @pytest.fixture(scope="session") -def client(request: FixtureRequest) -> Iterator[ReplicateClient]: +def client(request: FixtureRequest) -> Iterator[Replicate]: strict = getattr(request, "param", True) if not isinstance(strict, bool): raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") - with ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict) as client: + with Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict) as client: yield client @pytest.fixture(scope="session") -async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncReplicateClient]: +async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncReplicate]: strict = getattr(request, "param", True) if not isinstance(strict, bool): raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") - async with AsyncReplicateClient( + async with AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict ) as client: yield client diff --git a/tests/test_client.py b/tests/test_client.py index 129afa4..3ffb5ab 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,11 +21,11 @@ from respx import MockRouter from pydantic import ValidationError -from replicate import ReplicateClient, AsyncReplicateClient, APIResponseValidationError +from replicate import Replicate, AsyncReplicate, APIResponseValidationError from replicate._types import Omit from replicate._models import BaseModel, FinalRequestOptions from replicate._constants import RAW_RESPONSE_HEADER -from replicate._exceptions import APIStatusError, APITimeoutError, ReplicateClientError, APIResponseValidationError +from replicate._exceptions import APIStatusError, ReplicateError, APITimeoutError, APIResponseValidationError from replicate._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, @@ -49,7 +49,7 @@ def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float: return 0.1 -def _get_open_connections(client: ReplicateClient | AsyncReplicateClient) -> int: +def _get_open_connections(client: Replicate | AsyncReplicate) -> int: transport = client._client._transport assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) @@ -57,8 +57,8 @@ def _get_open_connections(client: ReplicateClient | AsyncReplicateClient) -> int return len(pool._requests) -class TestReplicateClient: - client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) +class TestReplicate: + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) @pytest.mark.respx(base_url=base_url) def test_raw_response(self, respx_mock: MockRouter) -> None: @@ -105,7 +105,7 @@ def test_copy_default_options(self) -> None: assert isinstance(self.client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -142,7 +142,7 @@ def test_copy_default_headers(self) -> None: client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) def test_copy_default_query(self) -> None: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, default_query={"foo": "bar"} ) assert _get_params(client)["foo"] == "bar" @@ -267,7 +267,7 @@ def test_request_timeout(self) -> None: assert timeout == httpx.Timeout(100.0) def test_client_timeout_option(self) -> None: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, timeout=httpx.Timeout(0) ) @@ -278,7 +278,7 @@ def test_client_timeout_option(self) -> None: def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, http_client=http_client ) @@ -288,7 +288,7 @@ def test_http_client_timeout_option(self) -> None: # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, http_client=http_client ) @@ -298,7 +298,7 @@ def test_http_client_timeout_option(self) -> None: # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, http_client=http_client ) @@ -309,7 +309,7 @@ def test_http_client_timeout_option(self) -> None: async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: - ReplicateClient( + Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -317,7 +317,7 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -327,7 +327,7 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = ReplicateClient( + client2 = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -341,17 +341,17 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-stainless-lang") == "my-overriding-header" def test_validate_headers(self) -> None: - client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("Authorization") == f"Bearer {bearer_token}" - with pytest.raises(ReplicateClientError): + with pytest.raises(ReplicateError): with update_env(**{"REPLICATE_API_TOKEN": Omit()}): - client2 = ReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True) + client2 = Replicate(base_url=base_url, bearer_token=None, _strict_response_validation=True) _ = client2 def test_default_query_option(self) -> None: - client = ReplicateClient( + client = Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -468,7 +468,7 @@ def test_request_extra_query(self) -> None: params = dict(request.url.params) assert params == {"foo": "2"} - def test_multipart_repeating_array(self, client: ReplicateClient) -> None: + def test_multipart_repeating_array(self, client: Replicate) -> None: request = client._build_request( FinalRequestOptions.construct( method="get", @@ -555,7 +555,7 @@ class Model(BaseModel): assert response.foo == 2 def test_base_url_setter(self) -> None: - client = ReplicateClient( + client = Replicate( base_url="https://example.com/from_init", bearer_token=bearer_token, _strict_response_validation=True ) assert client.base_url == "https://example.com/from_init/" @@ -565,19 +565,19 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" def test_base_url_env(self) -> None: - with update_env(REPLICATE_CLIENT_BASE_URL="http://localhost:5000/from/env"): - client = ReplicateClient(bearer_token=bearer_token, _strict_response_validation=True) + with update_env(REPLICATE_BASE_URL="http://localhost:5000/from/env"): + client = Replicate(bearer_token=bearer_token, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @pytest.mark.parametrize( "client", [ - ReplicateClient( + Replicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, ), - ReplicateClient( + Replicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, @@ -586,7 +586,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: ReplicateClient) -> None: + def test_base_url_trailing_slash(self, client: Replicate) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -599,12 +599,12 @@ def test_base_url_trailing_slash(self, client: ReplicateClient) -> None: @pytest.mark.parametrize( "client", [ - ReplicateClient( + Replicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, ), - ReplicateClient( + Replicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, @@ -613,7 +613,7 @@ def test_base_url_trailing_slash(self, client: ReplicateClient) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: ReplicateClient) -> None: + def test_base_url_no_trailing_slash(self, client: Replicate) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -626,12 +626,12 @@ def test_base_url_no_trailing_slash(self, client: ReplicateClient) -> None: @pytest.mark.parametrize( "client", [ - ReplicateClient( + Replicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, ), - ReplicateClient( + Replicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, @@ -640,7 +640,7 @@ def test_base_url_no_trailing_slash(self, client: ReplicateClient) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: ReplicateClient) -> None: + def test_absolute_request_url(self, client: Replicate) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -651,7 +651,7 @@ def test_absolute_request_url(self, client: ReplicateClient) -> None: assert request.url == "https://myapi.com/foo" def test_copied_client_does_not_close_http(self) -> None: - client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) assert not client.is_closed() copied = client.copy() @@ -662,7 +662,7 @@ def test_copied_client_does_not_close_http(self) -> None: assert not client.is_closed() def test_client_context_manager(self) -> None: - client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) with client as c2: assert c2 is client assert not c2.is_closed() @@ -683,7 +683,7 @@ class Model(BaseModel): def test_client_max_retries_validation(self) -> None: with pytest.raises(TypeError, match=r"max_retries cannot be None"): - ReplicateClient( + Replicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -697,12 +697,12 @@ class Model(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) - strict_client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + strict_client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=False) + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=False) response = client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] @@ -730,7 +730,7 @@ class Model(BaseModel): ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) @@ -763,7 +763,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non @pytest.mark.parametrize("failure_mode", ["status", "exception"]) def test_retries_taken( self, - client: ReplicateClient, + client: Replicate, failures_before_success: int, failure_mode: Literal["status", "exception"], respx_mock: MockRouter, @@ -792,7 +792,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_omit_retry_count_header( - self, client: ReplicateClient, failures_before_success: int, respx_mock: MockRouter + self, client: Replicate, failures_before_success: int, respx_mock: MockRouter ) -> None: client = client.with_options(max_retries=4) @@ -815,7 +815,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_overwrite_retry_count_header( - self, client: ReplicateClient, failures_before_success: int, respx_mock: MockRouter + self, client: Replicate, failures_before_success: int, respx_mock: MockRouter ) -> None: client = client.with_options(max_retries=4) @@ -835,8 +835,8 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" -class TestAsyncReplicateClient: - client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) +class TestAsyncReplicate: + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio @@ -885,7 +885,7 @@ def test_copy_default_options(self) -> None: assert isinstance(self.client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -922,7 +922,7 @@ def test_copy_default_headers(self) -> None: client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) def test_copy_default_query(self) -> None: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, default_query={"foo": "bar"} ) assert _get_params(client)["foo"] == "bar" @@ -1047,7 +1047,7 @@ async def test_request_timeout(self) -> None: assert timeout == httpx.Timeout(100.0) async def test_client_timeout_option(self) -> None: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, timeout=httpx.Timeout(0) ) @@ -1058,7 +1058,7 @@ async def test_client_timeout_option(self) -> None: async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, http_client=http_client ) @@ -1068,7 +1068,7 @@ async def test_http_client_timeout_option(self) -> None: # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, http_client=http_client ) @@ -1078,7 +1078,7 @@ async def test_http_client_timeout_option(self) -> None: # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, http_client=http_client ) @@ -1089,7 +1089,7 @@ async def test_http_client_timeout_option(self) -> None: def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: - AsyncReplicateClient( + AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -1097,7 +1097,7 @@ def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -1107,7 +1107,7 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncReplicateClient( + client2 = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -1121,17 +1121,17 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-stainless-lang") == "my-overriding-header" def test_validate_headers(self) -> None: - client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("Authorization") == f"Bearer {bearer_token}" - with pytest.raises(ReplicateClientError): + with pytest.raises(ReplicateError): with update_env(**{"REPLICATE_API_TOKEN": Omit()}): - client2 = AsyncReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True) + client2 = AsyncReplicate(base_url=base_url, bearer_token=None, _strict_response_validation=True) _ = client2 def test_default_query_option(self) -> None: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -1248,7 +1248,7 @@ def test_request_extra_query(self) -> None: params = dict(request.url.params) assert params == {"foo": "2"} - def test_multipart_repeating_array(self, async_client: AsyncReplicateClient) -> None: + def test_multipart_repeating_array(self, async_client: AsyncReplicate) -> None: request = async_client._build_request( FinalRequestOptions.construct( method="get", @@ -1335,7 +1335,7 @@ class Model(BaseModel): assert response.foo == 2 def test_base_url_setter(self) -> None: - client = AsyncReplicateClient( + client = AsyncReplicate( base_url="https://example.com/from_init", bearer_token=bearer_token, _strict_response_validation=True ) assert client.base_url == "https://example.com/from_init/" @@ -1345,19 +1345,19 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" def test_base_url_env(self) -> None: - with update_env(REPLICATE_CLIENT_BASE_URL="http://localhost:5000/from/env"): - client = AsyncReplicateClient(bearer_token=bearer_token, _strict_response_validation=True) + with update_env(REPLICATE_BASE_URL="http://localhost:5000/from/env"): + client = AsyncReplicate(bearer_token=bearer_token, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @pytest.mark.parametrize( "client", [ - AsyncReplicateClient( + AsyncReplicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, ), - AsyncReplicateClient( + AsyncReplicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, @@ -1366,7 +1366,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncReplicateClient) -> None: + def test_base_url_trailing_slash(self, client: AsyncReplicate) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1379,12 +1379,12 @@ def test_base_url_trailing_slash(self, client: AsyncReplicateClient) -> None: @pytest.mark.parametrize( "client", [ - AsyncReplicateClient( + AsyncReplicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, ), - AsyncReplicateClient( + AsyncReplicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, @@ -1393,7 +1393,7 @@ def test_base_url_trailing_slash(self, client: AsyncReplicateClient) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncReplicateClient) -> None: + def test_base_url_no_trailing_slash(self, client: AsyncReplicate) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1406,12 +1406,12 @@ def test_base_url_no_trailing_slash(self, client: AsyncReplicateClient) -> None: @pytest.mark.parametrize( "client", [ - AsyncReplicateClient( + AsyncReplicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, ), - AsyncReplicateClient( + AsyncReplicate( base_url="http://localhost:5000/custom/path/", bearer_token=bearer_token, _strict_response_validation=True, @@ -1420,7 +1420,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncReplicateClient) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncReplicateClient) -> None: + def test_absolute_request_url(self, client: AsyncReplicate) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1431,7 +1431,7 @@ def test_absolute_request_url(self, client: AsyncReplicateClient) -> None: assert request.url == "https://myapi.com/foo" async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) assert not client.is_closed() copied = client.copy() @@ -1443,7 +1443,7 @@ async def test_copied_client_does_not_close_http(self) -> None: assert not client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) async with client as c2: assert c2 is client assert not c2.is_closed() @@ -1465,7 +1465,7 @@ class Model(BaseModel): async def test_client_max_retries_validation(self) -> None: with pytest.raises(TypeError, match=r"max_retries cannot be None"): - AsyncReplicateClient( + AsyncReplicate( base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True, @@ -1480,14 +1480,12 @@ class Model(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) - strict_client = AsyncReplicateClient( - base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True - ) + strict_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=False) + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=False) response = await client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] @@ -1516,7 +1514,7 @@ class Model(BaseModel): @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @pytest.mark.asyncio async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) @@ -1554,7 +1552,7 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, - async_client: AsyncReplicateClient, + async_client: AsyncReplicate, failures_before_success: int, failure_mode: Literal["status", "exception"], respx_mock: MockRouter, @@ -1584,7 +1582,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_omit_retry_count_header( - self, async_client: AsyncReplicateClient, failures_before_success: int, respx_mock: MockRouter + self, async_client: AsyncReplicate, failures_before_success: int, respx_mock: MockRouter ) -> None: client = async_client.with_options(max_retries=4) @@ -1608,7 +1606,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_overwrite_retry_count_header( - self, async_client: AsyncReplicateClient, failures_before_success: int, respx_mock: MockRouter + self, async_client: AsyncReplicate, failures_before_success: int, respx_mock: MockRouter ) -> None: client = async_client.with_options(max_retries=4) diff --git a/tests/test_response.py b/tests/test_response.py index c935876..30e5c2f 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -6,7 +6,7 @@ import pytest import pydantic -from replicate import BaseModel, ReplicateClient, AsyncReplicateClient +from replicate import BaseModel, Replicate, AsyncReplicate from replicate._response import ( APIResponse, BaseAPIResponse, @@ -56,7 +56,7 @@ def test_extract_response_type_binary_response() -> None: class PydanticModel(pydantic.BaseModel): ... -def test_response_parse_mismatched_basemodel(client: ReplicateClient) -> None: +def test_response_parse_mismatched_basemodel(client: Replicate) -> None: response = APIResponse( raw=httpx.Response(200, content=b"foo"), client=client, @@ -74,7 +74,7 @@ def test_response_parse_mismatched_basemodel(client: ReplicateClient) -> None: @pytest.mark.asyncio -async def test_async_response_parse_mismatched_basemodel(async_client: AsyncReplicateClient) -> None: +async def test_async_response_parse_mismatched_basemodel(async_client: AsyncReplicate) -> None: response = AsyncAPIResponse( raw=httpx.Response(200, content=b"foo"), client=async_client, @@ -91,7 +91,7 @@ async def test_async_response_parse_mismatched_basemodel(async_client: AsyncRepl await response.parse(to=PydanticModel) -def test_response_parse_custom_stream(client: ReplicateClient) -> None: +def test_response_parse_custom_stream(client: Replicate) -> None: response = APIResponse( raw=httpx.Response(200, content=b"foo"), client=client, @@ -106,7 +106,7 @@ def test_response_parse_custom_stream(client: ReplicateClient) -> None: @pytest.mark.asyncio -async def test_async_response_parse_custom_stream(async_client: AsyncReplicateClient) -> None: +async def test_async_response_parse_custom_stream(async_client: AsyncReplicate) -> None: response = AsyncAPIResponse( raw=httpx.Response(200, content=b"foo"), client=async_client, @@ -125,7 +125,7 @@ class CustomModel(BaseModel): bar: int -def test_response_parse_custom_model(client: ReplicateClient) -> None: +def test_response_parse_custom_model(client: Replicate) -> None: response = APIResponse( raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), client=client, @@ -141,7 +141,7 @@ def test_response_parse_custom_model(client: ReplicateClient) -> None: @pytest.mark.asyncio -async def test_async_response_parse_custom_model(async_client: AsyncReplicateClient) -> None: +async def test_async_response_parse_custom_model(async_client: AsyncReplicate) -> None: response = AsyncAPIResponse( raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), client=async_client, @@ -156,7 +156,7 @@ async def test_async_response_parse_custom_model(async_client: AsyncReplicateCli assert obj.bar == 2 -def test_response_parse_annotated_type(client: ReplicateClient) -> None: +def test_response_parse_annotated_type(client: Replicate) -> None: response = APIResponse( raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), client=client, @@ -173,7 +173,7 @@ def test_response_parse_annotated_type(client: ReplicateClient) -> None: assert obj.bar == 2 -async def test_async_response_parse_annotated_type(async_client: AsyncReplicateClient) -> None: +async def test_async_response_parse_annotated_type(async_client: AsyncReplicate) -> None: response = AsyncAPIResponse( raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), client=async_client, @@ -201,7 +201,7 @@ async def test_async_response_parse_annotated_type(async_client: AsyncReplicateC ("FalSe", False), ], ) -def test_response_parse_bool(client: ReplicateClient, content: str, expected: bool) -> None: +def test_response_parse_bool(client: Replicate, content: str, expected: bool) -> None: response = APIResponse( raw=httpx.Response(200, content=content), client=client, @@ -226,7 +226,7 @@ def test_response_parse_bool(client: ReplicateClient, content: str, expected: bo ("FalSe", False), ], ) -async def test_async_response_parse_bool(client: AsyncReplicateClient, content: str, expected: bool) -> None: +async def test_async_response_parse_bool(client: AsyncReplicate, content: str, expected: bool) -> None: response = AsyncAPIResponse( raw=httpx.Response(200, content=content), client=client, @@ -245,7 +245,7 @@ class OtherModel(BaseModel): @pytest.mark.parametrize("client", [False], indirect=True) # loose validation -def test_response_parse_expect_model_union_non_json_content(client: ReplicateClient) -> None: +def test_response_parse_expect_model_union_non_json_content(client: Replicate) -> None: response = APIResponse( raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), client=client, @@ -262,7 +262,7 @@ def test_response_parse_expect_model_union_non_json_content(client: ReplicateCli @pytest.mark.asyncio @pytest.mark.parametrize("async_client", [False], indirect=True) # loose validation -async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncReplicateClient) -> None: +async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncReplicate) -> None: response = AsyncAPIResponse( raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), client=async_client, diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 89e05bb..8c3344f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,13 +5,13 @@ import httpx import pytest -from replicate import ReplicateClient, AsyncReplicateClient +from replicate import Replicate, AsyncReplicate from replicate._streaming import Stream, AsyncStream, ServerSentEvent @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_basic(sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient) -> None: +async def test_basic(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: completion\n" yield b'data: {"foo":true}\n' @@ -28,7 +28,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_data_missing_event(sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient) -> None: +async def test_data_missing_event(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b'data: {"foo":true}\n' yield b"\n" @@ -44,7 +44,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_event_missing_data(sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient) -> None: +async def test_event_missing_data(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: ping\n" yield b"\n" @@ -60,7 +60,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_multiple_events(sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient) -> None: +async def test_multiple_events(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: ping\n" yield b"\n" @@ -82,9 +82,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_multiple_events_with_data( - sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient -) -> None: +async def test_multiple_events_with_data(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: ping\n" yield b'data: {"foo":true}\n' @@ -108,9 +106,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_multiple_data_lines_with_empty_line( - sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient -) -> None: +async def test_multiple_data_lines_with_empty_line(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: ping\n" yield b"data: {\n" @@ -132,9 +128,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_data_json_escaped_double_new_line( - sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient -) -> None: +async def test_data_json_escaped_double_new_line(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: ping\n" yield b'data: {"foo": "my long\\n\\ncontent"}' @@ -151,7 +145,7 @@ def body() -> Iterator[bytes]: @pytest.mark.asyncio @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_multiple_data_lines(sync: bool, client: ReplicateClient, async_client: AsyncReplicateClient) -> None: +async def test_multiple_data_lines(sync: bool, client: Replicate, async_client: AsyncReplicate) -> None: def body() -> Iterator[bytes]: yield b"event: ping\n" yield b"data: {\n" @@ -171,8 +165,8 @@ def body() -> Iterator[bytes]: @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) async def test_special_new_line_character( sync: bool, - client: ReplicateClient, - async_client: AsyncReplicateClient, + client: Replicate, + async_client: AsyncReplicate, ) -> None: def body() -> Iterator[bytes]: yield b'data: {"content":" culpa"}\n' @@ -202,8 +196,8 @@ def body() -> Iterator[bytes]: @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) async def test_multi_byte_character_multiple_chunks( sync: bool, - client: ReplicateClient, - async_client: AsyncReplicateClient, + client: Replicate, + async_client: AsyncReplicate, ) -> None: def body() -> Iterator[bytes]: yield b'data: {"content":"' @@ -243,8 +237,8 @@ def make_event_iterator( content: Iterator[bytes], *, sync: bool, - client: ReplicateClient, - async_client: AsyncReplicateClient, + client: Replicate, + async_client: AsyncReplicate, ) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]: if sync: return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events()