diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 4208b5c..6e011e8 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.6.0"
+ ".": "2.0.0-alpha.1"
}
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index 6234e82..d3678f4 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 35
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37cd8ea847eb57706035f766ca549d5b4e2111053af0656a2df9a8150421428e.yml
openapi_spec_hash: a3e4d6fd9aff6de0e4b6d8ad28cbbe05
-config_hash: 8e356248f15e5e54d2aecab141f45228
+config_hash: da444f7a7ac6238fa0bdecaa01ffa4c3
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4cbb647..6787016 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,28 @@
# Changelog
+## 2.0.0-alpha.1 (2025-06-10)
+
+Full Changelog: [v0.6.0...v2.0.0-alpha.1](https://github.com/replicate/replicate-python-stainless/compare/v0.6.0...v2.0.0-alpha.1)
+
+### ⚠ BREAKING CHANGES
+
+* rename package to `replicate`
+
+### Features
+
+* **client:** add follow_redirects request option ([d606061](https://github.com/replicate/replicate-python-stainless/commit/d60606146abbdc778dc33573ccccdf7bedb524e4))
+
+
+### Chores
+
+* **docs:** remove reference to rye shell ([1dfaea4](https://github.com/replicate/replicate-python-stainless/commit/1dfaea4108bee6ea565c48c9f99ed503476fd58f))
+* rename package to `replicate` ([42e30b7](https://github.com/replicate/replicate-python-stainless/commit/42e30b7b0e736fbb39e95ef7744299746a70d1b5))
+
+
+### Documentation
+
+* **internal:** add support for the client config option default_client_example_name to Python ([b320609](https://github.com/replicate/replicate-python-stainless/commit/b3206093c824676a300bfc68da307fab5a0f3718))
+
## 0.6.0 (2025-05-22)
Full Changelog: [v0.5.1...v0.6.0](https://github.com/replicate/replicate-python-stainless/compare/v0.5.1...v0.6.0)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 6a8b625..532e99b 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -17,8 +17,7 @@ $ rye sync --all-features
You can then run scripts using `rye run python script.py` or by activating the virtual environment:
```sh
-$ rye shell
-# or manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work
+# Activate the virtual environment - https://docs.python.org/3/library/venv.html#how-venvs-work
$ source .venv/bin/activate
# now you can omit the `rye run` prefix
diff --git a/README.md b/README.md
index df59580..52fc747 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Replicate Python API library
-[](https://pypi.org/project/replicate-stainless/)
+[](https://pypi.org/project/replicate/)
The Replicate Python library provides convenient access to the Replicate REST API from any Python 3.8+
application. The library includes type definitions for all request params and response fields,
@@ -16,7 +16,7 @@ The REST API documentation can be found on [replicate.com](https://replicate.com
```sh
# install from PyPI
-pip install replicate-stainless
+pip install --pre replicate
```
## Usage
@@ -27,11 +27,11 @@ The full API of this library can be found in [api.md](api.md).
import os
from replicate import Replicate
-client = Replicate(
+replicate = Replicate(
bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)
-prediction = client.predictions.get(
+prediction = replicate.predictions.get(
prediction_id="gm3qorzdhgbfurvjtvhg6dckhu",
)
print(prediction.id)
@@ -51,13 +51,13 @@ import os
import asyncio
from replicate import AsyncReplicate
-client = AsyncReplicate(
+replicate = AsyncReplicate(
bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)
async def main() -> None:
- prediction = await client.predictions.get(
+ prediction = await replicate.predictions.get(
prediction_id="gm3qorzdhgbfurvjtvhg6dckhu",
)
print(prediction.id)
@@ -86,11 +86,11 @@ This library provides auto-paginating iterators with each list response, so you
```python
from replicate import Replicate
-client = Replicate()
+replicate = Replicate()
all_models = []
# Automatically fetches more pages as needed.
-for model in client.models.list():
+for model in replicate.models.list():
# Do something with model here
all_models.append(model)
print(all_models)
@@ -102,13 +102,13 @@ Or, asynchronously:
import asyncio
from replicate import AsyncReplicate
-client = AsyncReplicate()
+replicate = AsyncReplicate()
async def main() -> None:
all_models = []
# Iterate through items across all pages, issuing requests as needed.
- async for model in client.models.list():
+ async for model in replicate.models.list():
all_models.append(model)
print(all_models)
@@ -119,7 +119,7 @@ asyncio.run(main())
Alternatively, you can use the `.has_next_page()`, `.next_page_info()`, or `.get_next_page()` methods for more granular control working with pages:
```python
-first_page = await client.models.list()
+first_page = await replicate.models.list()
if first_page.has_next_page():
print(f"will fetch next page using these details: {first_page.next_page_info()}")
next_page = await first_page.get_next_page()
@@ -131,7 +131,7 @@ if first_page.has_next_page():
Or just work directly with the returned data:
```python
-first_page = await client.models.list()
+first_page = await replicate.models.list()
print(f"next URL: {first_page.next}") # => "next URL: ..."
for model in first_page.results:
@@ -148,9 +148,9 @@ Request parameters that correspond to file uploads can be passed as `bytes`, or
from pathlib import Path
from replicate import Replicate
-client = Replicate()
+replicate = Replicate()
-client.files.create(
+replicate.files.create(
content=Path("/path/to/file"),
)
```
@@ -170,10 +170,10 @@ All errors inherit from `replicate.APIError`.
import replicate
from replicate import Replicate
-client = Replicate()
+replicate = Replicate()
try:
- client.predictions.create(
+ replicate.predictions.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
@@ -213,13 +213,13 @@ You can use the `max_retries` option to configure or disable retry settings:
from replicate import Replicate
# Configure the default for all requests:
-client = Replicate(
+replicate = Replicate(
# default is 2
max_retries=0,
)
# Or, configure per-request:
-client.with_options(max_retries=5).predictions.create(
+replicate.with_options(max_retries=5).predictions.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
@@ -234,18 +234,18 @@ which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advan
from replicate import Replicate
# Configure the default for all requests:
-client = Replicate(
+replicate = Replicate(
# 20 seconds (default is 1 minute)
timeout=20.0,
)
# More granular control:
-client = Replicate(
+replicate = Replicate(
timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0),
)
# Override per-request:
-client.with_options(timeout=5.0).predictions.create(
+replicate.with_options(timeout=5.0).predictions.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
@@ -288,8 +288,8 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to
```py
from replicate import Replicate
-client = Replicate()
-response = client.predictions.with_raw_response.create(
+replicate = Replicate()
+response = replicate.predictions.with_raw_response.create(
input={
"text": "Alice"
},
@@ -312,7 +312,7 @@ The above interface eagerly reads the full response body when you make the reque
To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods.
```python
-with client.predictions.with_streaming_response.create(
+with replicate.predictions.with_streaming_response.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
) as response:
@@ -332,13 +332,13 @@ If you need to access undocumented endpoints, params, or response properties, th
#### Undocumented endpoints
-To make requests to undocumented endpoints, you can make requests using `client.get`, `client.post`, and other
+To make requests to undocumented endpoints, you can use `replicate.get`, `replicate.post`, and other
http verbs. Options on the client will be respected (such as retries) when making this request.
```py
import httpx
-response = client.post(
+response = replicate.post(
"/foo",
cast_to=httpx.Response,
body={"my_param": True},
@@ -370,7 +370,7 @@ You can directly override the [httpx client](https://www.python-httpx.org/api/#c
import httpx
from replicate import Replicate, DefaultHttpxClient
-client = Replicate(
+replicate = Replicate(
# Or use the `REPLICATE_BASE_URL` env var
base_url="http://my.test.server.example.com:8083",
http_client=DefaultHttpxClient(
@@ -383,7 +383,7 @@ client = Replicate(
You can also customize the client on a per-request basis by using `with_options()`:
```python
-client.with_options(http_client=DefaultHttpxClient(...))
+replicate.with_options(http_client=DefaultHttpxClient(...))
```
### Managing HTTP resources
@@ -393,7 +393,7 @@ By default the library closes underlying HTTP connections whenever the client is
```py
from replicate import Replicate
-with Replicate() as client:
+with Replicate() as replicate:
# make requests here
...
diff --git a/api.md b/api.md
index e3acbf5..38f61b7 100644
--- a/api.md
+++ b/api.md
@@ -2,8 +2,8 @@
Methods:
-- client.collections.list() -> None
-- client.collections.get(\*, collection_slug) -> None
+- replicate.collections.list() -> None
+- replicate.collections.get(\*, collection_slug) -> None
# Deployments
@@ -20,17 +20,17 @@ from replicate.types import (
Methods:
-- client.deployments.create(\*\*params) -> DeploymentCreateResponse
-- client.deployments.update(\*, deployment_owner, deployment_name, \*\*params) -> DeploymentUpdateResponse
-- client.deployments.list() -> SyncCursorURLPage[DeploymentListResponse]
-- client.deployments.delete(\*, deployment_owner, deployment_name) -> None
-- client.deployments.get(\*, deployment_owner, deployment_name) -> DeploymentGetResponse
+- replicate.deployments.create(\*\*params) -> DeploymentCreateResponse
+- replicate.deployments.update(\*, deployment_owner, deployment_name, \*\*params) -> DeploymentUpdateResponse
+- replicate.deployments.list() -> SyncCursorURLPage[DeploymentListResponse]
+- replicate.deployments.delete(\*, deployment_owner, deployment_name) -> None
+- replicate.deployments.get(\*, deployment_owner, deployment_name) -> DeploymentGetResponse
## Predictions
Methods:
-- client.deployments.predictions.create(\*, deployment_owner, deployment_name, \*\*params) -> Prediction
+- replicate.deployments.predictions.create(\*, deployment_owner, deployment_name, \*\*params) -> Prediction
# Hardware
@@ -42,7 +42,7 @@ from replicate.types import HardwareListResponse
Methods:
-- client.hardware.list() -> HardwareListResponse
+- replicate.hardware.list() -> HardwareListResponse
# Account
@@ -54,7 +54,7 @@ from replicate.types import AccountGetResponse
Methods:
-- client.account.get() -> AccountGetResponse
+- replicate.account.get() -> AccountGetResponse
# Models
@@ -66,23 +66,23 @@ from replicate.types import ModelListResponse
Methods:
-- client.models.create(\*\*params) -> None
-- client.models.list() -> SyncCursorURLPage[ModelListResponse]
-- client.models.delete(\*, model_owner, model_name) -> None
-- client.models.get(\*, model_owner, model_name) -> None
-- client.models.search(\*\*params) -> None
+- replicate.models.create(\*\*params) -> None
+- replicate.models.list() -> SyncCursorURLPage[ModelListResponse]
+- replicate.models.delete(\*, model_owner, model_name) -> None
+- replicate.models.get(\*, model_owner, model_name) -> None
+- replicate.models.search(\*\*params) -> None
## Examples
Methods:
-- client.models.examples.list(\*, model_owner, model_name) -> None
+- replicate.models.examples.list(\*, model_owner, model_name) -> None
## Predictions
Methods:
-- client.models.predictions.create(\*, model_owner, model_name, \*\*params) -> Prediction
+- replicate.models.predictions.create(\*, model_owner, model_name, \*\*params) -> Prediction
## Readme
@@ -94,15 +94,15 @@ from replicate.types.models import ReadmeGetResponse
Methods:
-- client.models.readme.get(\*, model_owner, model_name) -> str
+- replicate.models.readme.get(\*, model_owner, model_name) -> str
## Versions
Methods:
-- client.models.versions.list(\*, model_owner, model_name) -> None
-- client.models.versions.delete(\*, model_owner, model_name, version_id) -> None
-- client.models.versions.get(\*, model_owner, model_name, version_id) -> None
+- replicate.models.versions.list(\*, model_owner, model_name) -> None
+- replicate.models.versions.delete(\*, model_owner, model_name, version_id) -> None
+- replicate.models.versions.get(\*, model_owner, model_name, version_id) -> None
# Predictions
@@ -114,10 +114,10 @@ from replicate.types import Prediction, PredictionOutput, PredictionRequest
Methods:
-- client.predictions.create(\*\*params) -> Prediction
-- client.predictions.list(\*\*params) -> SyncCursorURLPageWithCreatedFilters[Prediction]
-- client.predictions.cancel(\*, prediction_id) -> Prediction
-- client.predictions.get(\*, prediction_id) -> Prediction
+- replicate.predictions.create(\*\*params) -> Prediction
+- replicate.predictions.list(\*\*params) -> SyncCursorURLPageWithCreatedFilters[Prediction]
+- replicate.predictions.cancel(\*, prediction_id) -> Prediction
+- replicate.predictions.get(\*, prediction_id) -> Prediction
# Trainings
@@ -134,10 +134,10 @@ from replicate.types import (
Methods:
-- client.trainings.create(\*, model_owner, model_name, version_id, \*\*params) -> TrainingCreateResponse
-- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse]
-- client.trainings.cancel(\*, training_id) -> TrainingCancelResponse
-- client.trainings.get(\*, training_id) -> TrainingGetResponse
+- replicate.trainings.create(\*, model_owner, model_name, version_id, \*\*params) -> TrainingCreateResponse
+- replicate.trainings.list() -> SyncCursorURLPage[TrainingListResponse]
+- replicate.trainings.cancel(\*, training_id) -> TrainingCancelResponse
+- replicate.trainings.get(\*, training_id) -> TrainingGetResponse
# Webhooks
@@ -153,7 +153,7 @@ from replicate.types.webhooks.default import SecretGetResponse
Methods:
-- client.webhooks.default.secret.get() -> SecretGetResponse
+- replicate.webhooks.default.secret.get() -> SecretGetResponse
# Files
diff --git a/pyproject.toml b/pyproject.toml
index 198c95f..f9905ba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
-name = "replicate-stainless"
-version = "0.6.0"
+name = "replicate"
+version = "2.0.0-alpha.1"
description = "The official Python library for the replicate API"
dynamic = ["readme"]
license = "Apache-2.0"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 86eea12..f29c3c8 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -14,7 +14,7 @@ annotated-types==0.6.0
# via pydantic
anyio==4.4.0
# via httpx
- # via replicate-stainless
+ # via replicate
argcomplete==3.1.2
# via nox
certifi==2023.7.22
@@ -26,7 +26,7 @@ dirty-equals==0.6.0
distlib==0.3.7
# via virtualenv
distro==1.8.0
- # via replicate-stainless
+ # via replicate
exceptiongroup==1.2.2
# via anyio
# via pytest
@@ -37,7 +37,7 @@ h11==0.14.0
httpcore==1.0.2
# via httpx
httpx==0.28.1
- # via replicate-stainless
+ # via replicate
# via respx
idna==3.4
# via anyio
@@ -64,7 +64,7 @@ platformdirs==3.11.0
pluggy==1.5.0
# via pytest
pydantic==2.10.3
- # via replicate-stainless
+ # via replicate
pydantic-core==2.27.1
# via pydantic
pygments==2.18.0
@@ -86,7 +86,7 @@ six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via anyio
- # via replicate-stainless
+ # via replicate
time-machine==2.9.0
tomli==2.0.2
# via mypy
@@ -97,7 +97,7 @@ typing-extensions==4.12.2
# via pydantic
# via pydantic-core
# via pyright
- # via replicate-stainless
+ # via replicate
virtualenv==20.24.5
# via nox
zipp==3.17.0
diff --git a/requirements.lock b/requirements.lock
index a64eef9..f022008 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -14,12 +14,12 @@ annotated-types==0.6.0
# via pydantic
anyio==4.4.0
# via httpx
- # via replicate-stainless
+ # via replicate
certifi==2023.7.22
# via httpcore
# via httpx
distro==1.8.0
- # via replicate-stainless
+ # via replicate
exceptiongroup==1.2.2
# via anyio
h11==0.14.0
@@ -27,19 +27,19 @@ h11==0.14.0
httpcore==1.0.2
# via httpx
httpx==0.28.1
- # via replicate-stainless
+ # via replicate
idna==3.4
# via anyio
# via httpx
pydantic==2.10.3
- # via replicate-stainless
+ # via replicate
pydantic-core==2.27.1
# via pydantic
sniffio==1.3.0
# via anyio
- # via replicate-stainless
+ # via replicate
typing-extensions==4.12.2
# via anyio
# via pydantic
# via pydantic-core
- # via replicate-stainless
+ # via replicate
diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py
index fec5e9d..131706e 100644
--- a/src/replicate/_base_client.py
+++ b/src/replicate/_base_client.py
@@ -960,6 +960,9 @@ def request(
if self.custom_auth is not None:
kwargs["auth"] = self.custom_auth
+ if options.follow_redirects is not None:
+ kwargs["follow_redirects"] = options.follow_redirects
+
log.debug("Sending HTTP Request: %s %s", request.method, request.url)
response = None
@@ -1474,6 +1477,9 @@ async def request(
if self.custom_auth is not None:
kwargs["auth"] = self.custom_auth
+ if options.follow_redirects is not None:
+ kwargs["follow_redirects"] = options.follow_redirects
+
log.debug("Sending HTTP Request: %s %s", request.method, request.url)
response = None
diff --git a/src/replicate/_models.py b/src/replicate/_models.py
index 798956f..4f21498 100644
--- a/src/replicate/_models.py
+++ b/src/replicate/_models.py
@@ -737,6 +737,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
idempotency_key: str
json_data: Body
extra_json: AnyMapping
+ follow_redirects: bool
@final
@@ -750,6 +751,7 @@ class FinalRequestOptions(pydantic.BaseModel):
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
+ follow_redirects: Union[bool, None] = None
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
diff --git a/src/replicate/_types.py b/src/replicate/_types.py
index e461f0b..bb2fc67 100644
--- a/src/replicate/_types.py
+++ b/src/replicate/_types.py
@@ -100,6 +100,7 @@ class RequestOptions(TypedDict, total=False):
params: Query
extra_json: AnyMapping
idempotency_key: str
+ follow_redirects: bool
# Sentinel class used until PEP 0661 is accepted
@@ -215,3 +216,4 @@ class _GenericAlias(Protocol):
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
+ follow_redirects: bool
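The hunks above wire a new per-request `follow_redirects` option through `RequestOptions`, `FinalRequestOptions`, and the base client's httpx send arguments. A minimal sketch of how a caller might use it on the low-level request helpers, mirroring the tests added later in this diff (the endpoint path is hypothetical):

```python
import httpx

from replicate import Replicate, APIStatusError

replicate = Replicate()  # reads REPLICATE_API_TOKEN from the environment

try:
    # Disable redirect following for this request only; a 302 then
    # surfaces as an APIStatusError instead of being followed.
    replicate.post(
        "/some/redirecting/path",  # hypothetical endpoint
        body={"key": "value"},
        options={"follow_redirects": False},
        cast_to=httpx.Response,
    )
except APIStatusError as err:
    print(err.response.status_code, err.response.headers.get("Location"))
```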
diff --git a/src/replicate/_version.py b/src/replicate/_version.py
index b02d187..60265a9 100644
--- a/src/replicate/_version.py
+++ b/src/replicate/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "replicate"
-__version__ = "0.6.0" # x-release-please-version
+__version__ = "2.0.0-alpha.1" # x-release-please-version
diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py
index c14a944..7c6f485 100644
--- a/src/replicate/lib/_files.py
+++ b/src/replicate/lib/_files.py
@@ -51,9 +51,8 @@ def encode_json(
if file_encoding_strategy == "base64":
return base64_encode_file(obj)
else:
- # todo: support files endpoint
- # return client.files.create(obj).urls["get"]
- raise NotImplementedError("File upload is not supported yet")
+ response = client.files.create(content=obj.read())
+ return response.urls.get
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
@@ -91,9 +90,8 @@ async def async_encode_json(
# TODO: This should ideally use an async based file reader path.
return base64_encode_file(obj)
else:
- # todo: support files endpoint
- # return (await client.files.async_create(obj)).urls["get"]
- raise NotImplementedError("File upload is not supported yet")
+ response = await client.files.create(content=obj.read())
+ return response.urls.get
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
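With the change above, file-like inputs are no longer rejected: when the encoding strategy is not "base64", `encode_json`/`async_encode_json` upload the object through the Files API and substitute the returned download URL. A minimal sketch of the user-facing effect (model reference and file path are illustrative):

```python
from replicate import Replicate

replicate = Replicate()  # reads REPLICATE_API_TOKEN from the environment

with open("path/to/input.png", "rb") as f:
    # The file-like object is uploaded via replicate.files.create() and the
    # prediction receives its download URL rather than inline base64 data.
    output = replicate.run(
        "owner/name:version",  # hypothetical model reference
        input={"image": f},
    )
print(output)
```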
diff --git a/tests/conftest.py b/tests/conftest.py
index 79342b0..008b8b7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,8 +37,8 @@ def client(request: FixtureRequest) -> Iterator[Replicate]:
if not isinstance(strict, bool):
raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
- with Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict) as client:
- yield client
+ with Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict) as replicate:
+ yield replicate
@pytest.fixture(scope="session")
@@ -49,5 +49,5 @@ async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncReplicate]
async with AsyncReplicate(
base_url=base_url, bearer_token=bearer_token, _strict_response_validation=strict
- ) as client:
- yield client
+ ) as replicate:
+ yield replicate
diff --git a/tests/lib/test_run.py b/tests/lib/test_run.py
index 43df10d..168447d 100644
--- a/tests/lib/test_run.py
+++ b/tests/lib/test_run.py
@@ -10,9 +10,11 @@
from respx import MockRouter
from replicate import Replicate, AsyncReplicate
+from replicate._compat import model_dump
from replicate.lib._files import FileOutput, AsyncFileOutput
from replicate._exceptions import ModelError, NotFoundError, BadRequestError
from replicate.lib._models import Model, Version, ModelVersionIdentifier
+from replicate.types.file_create_response import URLs, Checksums, FileCreateResponse
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
bearer_token = "My Bearer Token"
@@ -89,6 +91,16 @@ class TestRun:
# Common model reference format that will work with the new SDK
model_ref = "owner/name:version"
+ file_create_response = FileCreateResponse(
+ id="test_file_id",
+ checksums=Checksums(sha256="test_sha256"),
+ content_type="application/octet-stream",
+ created_at=datetime.datetime.now(),
+ expires_at=datetime.datetime.now() + datetime.timedelta(days=1),
+ metadata={},
+ size=1234,
+ urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"),
+ )
@pytest.mark.respx(base_url=base_url)
def test_run_basic(self, respx_mock: MockRouter) -> None:
@@ -236,6 +248,23 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
assert output == "test output"
+ @pytest.mark.respx(base_url=base_url)
+ def test_run_with_file_upload(self, respx_mock: MockRouter) -> None:
+ """Test run with base64 encoded file input."""
+ # Create a simple file-like object
+ file_obj = io.BytesIO(b"test content")
+
+ # Mock the prediction response
+ respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
+ # Mock the file upload endpoint
+ respx_mock.post("/files").mock(
+ return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json"))
+ )
+
+ output: Any = self.client.run(self.model_ref, input={"file": file_obj})
+
+ assert output == "test output"
+
def test_run_with_prefer_conflict(self) -> None:
"""Test run with conflicting wait and prefer parameters."""
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
@@ -349,6 +378,16 @@ class TestAsyncRun:
# Common model reference format that will work with the new SDK
model_ref = "owner/name:version"
+ file_create_response = FileCreateResponse(
+ id="test_file_id",
+ checksums=Checksums(sha256="test_sha256"),
+ content_type="application/octet-stream",
+ created_at=datetime.datetime.now(),
+ expires_at=datetime.datetime.now() + datetime.timedelta(days=1),
+ metadata={},
+ size=1234,
+ urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"),
+ )
@pytest.mark.respx(base_url=base_url)
async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
@@ -501,6 +540,23 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
assert output == "test output"
+ @pytest.mark.respx(base_url=base_url)
+ async def test_async_run_with_file_upload(self, respx_mock: MockRouter) -> None:
+ """Test async run with base64 encoded file input."""
+ # Create a simple file-like object
+ file_obj = io.BytesIO(b"test content")
+
+ # Mock the prediction response
+ respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
+ # Mock the file upload endpoint
+ respx_mock.post("/files").mock(
+ return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json"))
+ )
+
+ output: Any = await self.client.run(self.model_ref, input={"file": file_obj})
+
+ assert output == "test output"
+
async def test_async_run_with_prefer_conflict(self) -> None:
"""Test async run with conflicting wait and prefer parameters."""
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
diff --git a/tests/test_client.py b/tests/test_client.py
index 89c1b40..6f8a4c1 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -868,6 +868,33 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
assert response.http_request.headers.get("x-stainless-retry-count") == "42"
+ @pytest.mark.respx(base_url=base_url)
+ def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ # Test that the default follow_redirects=True allows following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+ respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
+
+ response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok"}
+
+ @pytest.mark.respx(base_url=base_url)
+ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ # Test that follow_redirects=False prevents following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+
+ with pytest.raises(APIStatusError) as exc_info:
+ self.client.post(
+ "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
+ )
+
+ assert exc_info.value.response.status_code == 302
+ assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
+
class TestAsyncReplicate:
client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
@@ -1731,3 +1758,30 @@ async def test_main() -> None:
raise AssertionError("calling get_platform using asyncify resulted in a hung process")
time.sleep(0.1)
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ # Test that the default follow_redirects=True allows following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+ respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
+
+ response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok"}
+
+ @pytest.mark.respx(base_url=base_url)
+ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ # Test that follow_redirects=False prevents following redirects
+ respx_mock.post("/redirect").mock(
+ return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+ )
+
+ with pytest.raises(APIStatusError) as exc_info:
+ await self.client.post(
+ "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
+ )
+
+ assert exc_info.value.response.status_code == 302
+ assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"