From ce20b684a9bf6c15ac9174c316349284ea97f692 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 13:09:48 -0700 Subject: [PATCH 1/8] feat: add support for replicate.stream() This PR adds support for streaming predictions via the `replicate.stream()` method. Changes: - Add `stream()` method to both Replicate and AsyncReplicate clients - Add module-level `stream()` function for convenience - Create new `lib/_predictions_stream.py` module with streaming logic - Add comprehensive tests for sync and async streaming - Update README with documentation and examples using anthropic/claude-4-sonnet The stream method creates a prediction and returns an iterator that yields output chunks as they become available via Server-Sent Events (SSE). This is useful for language models where you want to display output as it's generated. --- README.md | 14 +- src/replicate/__init__.py | 3 +- src/replicate/_client.py | 97 ++++++++++ src/replicate/_module_client.py | 5 + src/replicate/lib/_predictions_stream.py | 188 +++++++++++++++++++ tests/lib/test_stream.py | 222 +++++++++++++++++++++++ 6 files changed, 525 insertions(+), 4 deletions(-) create mode 100644 src/replicate/lib/_predictions_stream.py create mode 100644 tests/lib/test_stream.py diff --git a/README.md b/README.md index d4812d4..85c713c 100644 --- a/README.md +++ b/README.md @@ -118,14 +118,18 @@ For models that support streaming (particularly language models), you can use `r import replicate for event in replicate.stream( - "meta/meta-llama-3-70b-instruct", + "anthropic/claude-4-sonnet", input={ - "prompt": "Please write a haiku about llamas.", + "prompt": "Give me a recipe for tasty smashed avocado on sourdough toast.", + "max_tokens": 8192, + "system_prompt": "You are a helpful assistant", }, ): print(str(event), end="") ``` +The `stream()` method creates a prediction and returns an iterator that yields output chunks as they become available via Server-Sent Events (SSE). This is useful for language models where you want to display output as it's generated rather than waiting for the entire response. + ## Async usage Simply import `AsyncReplicate` instead of `Replicate` and use `await` with each API call: @@ -172,7 +176,11 @@ async def main(): # Stream a model's output async for event in replicate.stream( - "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"} + "anthropic/claude-4-sonnet", + input={ + "prompt": "Write a haiku about coding", + "system_prompt": "You are a helpful assistant", + }, ): print(str(event), end="") diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index 1cfff56..2e2f286 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -109,7 +109,7 @@ if not __name.startswith("__"): try: # Skip symbols that are imported later from _module_client - if __name in ("run", "use"): + if __name in ("run", "use", "stream"): continue __locals[__name].__module__ = "replicate" except (TypeError, AttributeError): @@ -253,6 +253,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction] use as use, files as files, models as models, + stream as stream, account as account, hardware as hardware, webhooks as webhooks, diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 390a552..34ffbb0 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -320,6 +320,54 @@ def use( # TODO: Fix mypy overload matching for streaming parameter return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + def stream( + self, + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], + ) -> Iterator[str]: + """ + Stream output from a model prediction. + + This creates a prediction and returns an iterator that yields output chunks + as they become available via Server-Sent Events (SSE). + + Args: + ref: Reference to the model or version to run. Can be: + - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") + - A string with owner/name format (e.g. "replicate/hello-world") + - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...") + - A Model instance with owner and name attributes + - A Version instance with id attribute + - A ModelVersionIdentifier dictionary with owner, name, and/or version keys + file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url" + **params: Additional parameters to pass to the prediction creation endpoint including + the required "input" dictionary with model-specific parameters + + Yields: + str: Output chunks from the model as they become available + + Raises: + ValueError: If the reference format is invalid or model doesn't support streaming + ReplicateError: If the prediction fails + + Example: + for event in replicate.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Write a haiku about coding"}, + ): + print(str(event), end="") + """ + from .lib._predictions_stream import stream + + return stream( + self, + ref, + file_encoding_strategy=file_encoding_strategy, + **params, + ) + def copy( self, *, @@ -695,6 +743,55 @@ def use( # TODO: Fix mypy overload matching for streaming parameter return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + async def stream( + self, + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], + ) -> AsyncIterator[str]: + """ + Stream output from a model prediction asynchronously. + + This creates a prediction and returns an async iterator that yields output chunks + as they become available via Server-Sent Events (SSE). + + Args: + ref: Reference to the model or version to run. Can be: + - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") + - A string with owner/name format (e.g. "replicate/hello-world") + - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...") + - A Model instance with owner and name attributes + - A Version instance with id attribute + - A ModelVersionIdentifier dictionary with owner, name, and/or version keys + file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url" + **params: Additional parameters to pass to the prediction creation endpoint including + the required "input" dictionary with model-specific parameters + + Yields: + str: Output chunks from the model as they become available + + Raises: + ValueError: If the reference format is invalid or model doesn't support streaming + ReplicateError: If the prediction fails + + Example: + async for event in replicate.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Write a haiku about coding"}, + ): + print(str(event), end="") + """ + from .lib._predictions_stream import async_stream + + async for chunk in async_stream( + self, + ref, + file_encoding_strategy=file_encoding_strategy, + **params, + ): + yield chunk + def copy( self, *, diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py index a3e8ab4..6b7a1f7 100644 --- a/src/replicate/_module_client.py +++ b/src/replicate/_module_client.py @@ -82,6 +82,7 @@ def __load__(self) -> PredictionsResource: __client: Replicate = cast(Replicate, {}) run = __client.run use = __client.use + stream = __client.stream else: def _run(*args, **kwargs): @@ -100,8 +101,12 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs): return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs) + def _stream(*args, **kwargs): + return _load_client().stream(*args, **kwargs) + run = _run use = _use + stream = _stream files: FilesResource = FilesResourceProxy().__as_proxied__() models: ModelsResource = ModelsResourceProxy().__as_proxied__() diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py new file mode 100644 index 0000000..6a44f64 --- /dev/null +++ b/src/replicate/lib/_predictions_stream.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union, Iterator, Optional +from collections.abc import AsyncIterator +from typing_extensions import Unpack + +from replicate.lib._files import FileEncodingStrategy +from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion + +from ..types import PredictionCreateParams +from ._models import Model, Version, ModelVersionIdentifier, resolve_reference + +if TYPE_CHECKING: + from .._client import Replicate, AsyncReplicate + + +def stream( + client: "Replicate", + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], +) -> Iterator[str]: + """ + Stream output from a model prediction. + + This creates a prediction and returns an iterator that yields output chunks + as they become available via Server-Sent Events (SSE). + + Args: + client: The Replicate client instance + ref: Reference to the model or version to run. Can be: + - A string containing a version ID + - A string with owner/name format (e.g. "replicate/hello-world") + - A string with owner/name:version format + - A Model instance + - A Version instance + - A ModelVersionIdentifier dictionary + file_encoding_strategy: Strategy for encoding file inputs + **params: Additional parameters including the required "input" dictionary + + Yields: + str: Output chunks from the model as they become available + + Raises: + ValueError: If the reference format is invalid + ReplicateError: If the prediction fails or streaming is not available + """ + # Resolve ref to its components + try: + version, owner, name, version_id = resolve_reference(ref) + except ValueError: + # If resolution fails, treat it as a version ID if it's a string + if isinstance(ref, str): + version_id = ref + owner = name = None + else: + raise + + # Create prediction + prediction = None + if version_id is not None: + params_with_version: PredictionCreateParams = {**params, "version": version_id} + prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) + elif owner and name: + prediction = client.models.predictions.create( + file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params + ) + else: + if isinstance(ref, str): + params_with_version = {**params, "version": ref} + prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) + else: + raise ValueError( + f"Invalid reference format: {ref}. Expected a model name ('owner/name'), " + "a version ID, a Model object, a Version object, or a ModelVersionIdentifier." + ) + + # Check if streaming URL is available + if not prediction.urls or not prediction.urls.stream: + raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.") + + # Make SSE request to the stream URL + stream_url = prediction.urls.stream + + with client._client.stream( + "GET", + stream_url, + headers={ + "Accept": "text/event-stream", + "Cache-Control": "no-store", + }, + timeout=None, # No timeout for streaming + ) as response: + response.raise_for_status() + + # Parse SSE events and yield output chunks + decoder = client._make_sse_decoder() + for sse in decoder.iter_bytes(response.iter_bytes()): + # The SSE data contains the output chunks + if sse.data: + yield sse.data + + +async def async_stream( + client: "AsyncReplicate", + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], +) -> AsyncIterator[str]: + """ + Async stream output from a model prediction. + + This creates a prediction and returns an async iterator that yields output chunks + as they become available via Server-Sent Events (SSE). + + Args: + client: The AsyncReplicate client instance + ref: Reference to the model or version to run + file_encoding_strategy: Strategy for encoding file inputs + **params: Additional parameters including the required "input" dictionary + + Yields: + str: Output chunks from the model as they become available + + Raises: + ValueError: If the reference format is invalid + ReplicateError: If the prediction fails or streaming is not available + """ + # Resolve ref to its components + try: + version, owner, name, version_id = resolve_reference(ref) + except ValueError: + # If resolution fails, treat it as a version ID if it's a string + if isinstance(ref, str): + version_id = ref + owner = name = None + else: + raise + + # Create prediction + prediction = None + if version_id is not None: + params_with_version: PredictionCreateParams = {**params, "version": version_id} + prediction = await client.predictions.create( + file_encoding_strategy=file_encoding_strategy, **params_with_version + ) + elif owner and name: + prediction = await client.models.predictions.create( + file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params + ) + else: + if isinstance(ref, str): + params_with_version = {**params, "version": ref} + prediction = await client.predictions.create( + file_encoding_strategy=file_encoding_strategy, **params_with_version + ) + else: + raise ValueError( + f"Invalid reference format: {ref}. Expected a model name ('owner/name'), " + "a version ID, a Model object, a Version object, or a ModelVersionIdentifier." + ) + + # Check if streaming URL is available + if not prediction.urls or not prediction.urls.stream: + raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.") + + # Make SSE request to the stream URL + stream_url = prediction.urls.stream + + async with client._client.stream( + "GET", + stream_url, + headers={ + "Accept": "text/event-stream", + "Cache-Control": "no-store", + }, + timeout=None, # No timeout for streaming + ) as response: + response.raise_for_status() + + # Parse SSE events and yield output chunks + decoder = client._make_sse_decoder() + async for sse in decoder.aiter_bytes(response.aiter_bytes()): + # The SSE data contains the output chunks + if sse.data: + yield sse.data diff --git a/tests/lib/test_stream.py b/tests/lib/test_stream.py new file mode 100644 index 0000000..cdd16cd --- /dev/null +++ b/tests/lib/test_stream.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import os +from typing import Iterator + +import httpx +import pytest + +from replicate import Replicate, AsyncReplicate + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +bearer_token = "My Bearer Token" + + +def create_mock_prediction_json(stream_url: str | None = None) -> dict: + """Helper to create a complete prediction JSON response""" + prediction = { + "id": "test-prediction-id", + "created_at": "2023-01-01T00:00:00Z", + "data_removed": False, + "input": {"prompt": "Test"}, + "model": "test-model", + "output": None, + "status": "starting", + "version": "test-version-id", + "urls": { + "get": f"{base_url}/predictions/test-prediction-id", + "cancel": f"{base_url}/predictions/test-prediction-id/cancel", + "web": "https://replicate.com/p/test-prediction-id", + }, + } + if stream_url: + prediction["urls"]["stream"] = stream_url + return prediction + + +def test_stream_with_model_owner_name(respx_mock) -> None: + """Test streaming with owner/name format""" + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation + respx_mock.post(f"{base_url}/models/meta/meta-llama-3-70b-instruct/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"), + ) + ) + + # Mock the SSE stream endpoint + def stream_content() -> Iterator[bytes]: + yield b"data: Hello\n\n" + yield b"data: world\n\n" + yield b"data: !\n\n" + + respx_mock.get(f"{base_url}/stream/test-prediction-id").mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_content(), + ) + ) + + # Stream the model + output = [] + for chunk in client.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Say hello"}, + ): + output.append(chunk) + + assert output == ["Hello", " world", "!"] + + +def test_stream_with_version_id(respx_mock) -> None: + """Test streaming with version ID""" + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + version_id = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" + + # Mock the prediction creation + respx_mock.post(f"{base_url}/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"), + ) + ) + + # Mock the SSE stream endpoint + def stream_content() -> Iterator[bytes]: + yield b"data: Test\n\n" + yield b"data: output\n\n" + + respx_mock.get(f"{base_url}/stream/test-prediction-id").mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_content(), + ) + ) + + # Stream the model + output = [] + for chunk in client.stream( + version_id, + input={"prompt": "Test"}, + ): + output.append(chunk) + + assert output == ["Test", "output"] + + +def test_stream_no_stream_url_raises_error(respx_mock) -> None: + """Test that streaming raises an error when model doesn't support streaming""" + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation without stream URL + respx_mock.post(f"{base_url}/models/owner/model/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=None), + ) + ) + + # Try to stream and expect an error + with pytest.raises(ValueError, match="Model does not support streaming"): + for _ in client.stream("owner/model", input={"prompt": "Test"}): + pass + + +@pytest.mark.asyncio +async def test_async_stream_with_model_owner_name(respx_mock) -> None: + """Test async streaming with owner/name format""" + async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation + respx_mock.post(f"{base_url}/models/meta/meta-llama-3-70b-instruct/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"), + ) + ) + + # Mock the SSE stream endpoint + async def stream_content(): + yield b"data: Async\n\n" + yield b"data: test\n\n" + + respx_mock.get(f"{base_url}/stream/test-prediction-id").mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_content(), + ) + ) + + # Stream the model + output = [] + async for chunk in async_client.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Say hello"}, + ): + output.append(chunk) + + assert output == ["Async", " test"] + + +@pytest.mark.asyncio +async def test_async_stream_no_stream_url_raises_error(respx_mock) -> None: + """Test that async streaming raises an error when model doesn't support streaming""" + async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation without stream URL + respx_mock.post(f"{base_url}/models/owner/model/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=None), + ) + ) + + # Try to stream and expect an error + with pytest.raises(ValueError, match="Model does not support streaming"): + async for _ in async_client.stream("owner/model", input={"prompt": "Test"}): + pass + + +def test_stream_module_level(respx_mock) -> None: + """Test that module-level stream function works""" + import replicate + + # Set up module level client configuration + replicate.base_url = base_url + replicate.bearer_token = bearer_token + + # Mock the prediction creation + respx_mock.post(f"{base_url}/models/meta/meta-llama-3-70b-instruct/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"), + ) + ) + + # Mock the SSE stream endpoint + def stream_content() -> Iterator[bytes]: + yield b"data: Module\n\n" + yield b"data: level\n\n" + + respx_mock.get(f"{base_url}/stream/test-prediction-id").mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_content(), + ) + ) + + # Stream using module-level function + output = [] + for chunk in replicate.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Test"}, + ): + output.append(chunk) + + assert output == ["Module", " level"] From 5e6be60138ec8cc9298765716a8f63c67031e188 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 13:18:26 -0700 Subject: [PATCH 2/8] docs: clarify that stream() yields strings, not SSE events The API uses Server-Sent Events internally, but the Python client yields plain string chunks to the user, not SSE event objects. --- README.md | 2 +- src/replicate/_client.py | 4 ++-- src/replicate/lib/_predictions_stream.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 85c713c..a462b2b 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ for event in replicate.stream( print(str(event), end="") ``` -The `stream()` method creates a prediction and returns an iterator that yields output chunks as they become available via Server-Sent Events (SSE). This is useful for language models where you want to display output as it's generated rather than waiting for the entire response. +The `stream()` method creates a prediction and returns an iterator that yields output chunks as strings. This is useful for language models where you want to display output as it's generated rather than waiting for the entire response. ## Async usage diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 34ffbb0..a35e6e1 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -331,7 +331,7 @@ def stream( Stream output from a model prediction. This creates a prediction and returns an iterator that yields output chunks - as they become available via Server-Sent Events (SSE). + as strings as they become available from the streaming API. Args: ref: Reference to the model or version to run. Can be: @@ -754,7 +754,7 @@ async def stream( Stream output from a model prediction asynchronously. This creates a prediction and returns an async iterator that yields output chunks - as they become available via Server-Sent Events (SSE). + as strings as they become available from the streaming API. Args: ref: Reference to the model or version to run. Can be: diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py index 6a44f64..49a9c87 100644 --- a/src/replicate/lib/_predictions_stream.py +++ b/src/replicate/lib/_predictions_stream.py @@ -25,7 +25,7 @@ def stream( Stream output from a model prediction. This creates a prediction and returns an iterator that yields output chunks - as they become available via Server-Sent Events (SSE). + as strings as they become available from the streaming API. Args: client: The Replicate client instance @@ -113,7 +113,7 @@ async def async_stream( Async stream output from a model prediction. This creates a prediction and returns an async iterator that yields output chunks - as they become available via Server-Sent Events (SSE). + as strings as they become available from the streaming API. Args: client: The AsyncReplicate client instance From 9776bfde7ccfb5aab013f62bd425efd586c48fce Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 13:35:19 -0700 Subject: [PATCH 3/8] refactor: DRY up duplicate reference resolution logic in stream functions --- src/replicate/lib/_predictions_stream.py | 42 ++++++++++-------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py index 49a9c87..d8fd59c 100644 --- a/src/replicate/lib/_predictions_stream.py +++ b/src/replicate/lib/_predictions_stream.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Iterator, Optional +from typing import TYPE_CHECKING, Tuple, Union, Iterator, Optional from collections.abc import AsyncIterator from typing_extensions import Unpack @@ -14,6 +14,20 @@ from .._client import Replicate, AsyncReplicate +def _resolve_reference( + ref: Union[Model, Version, ModelVersionIdentifier, str], +) -> Tuple[Optional[Version], Optional[str], Optional[str], Optional[str]]: + """Resolve a model reference to its components, with fallback for plain version IDs.""" + try: + return resolve_reference(ref) + except ValueError: + # If resolution fails, treat it as a version ID if it's a string + if isinstance(ref, str): + return None, None, None, ref + else: + raise + + def stream( client: "Replicate", ref: Union[Model, Version, ModelVersionIdentifier, str], @@ -46,19 +60,9 @@ def stream( ValueError: If the reference format is invalid ReplicateError: If the prediction fails or streaming is not available """ - # Resolve ref to its components - try: - version, owner, name, version_id = resolve_reference(ref) - except ValueError: - # If resolution fails, treat it as a version ID if it's a string - if isinstance(ref, str): - version_id = ref - owner = name = None - else: - raise + version, owner, name, version_id = _resolve_reference(ref) # Create prediction - prediction = None if version_id is not None: params_with_version: PredictionCreateParams = {**params, "version": version_id} prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) @@ -80,7 +84,6 @@ def stream( if not prediction.urls or not prediction.urls.stream: raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.") - # Make SSE request to the stream URL stream_url = prediction.urls.stream with client._client.stream( @@ -128,19 +131,9 @@ async def async_stream( ValueError: If the reference format is invalid ReplicateError: If the prediction fails or streaming is not available """ - # Resolve ref to its components - try: - version, owner, name, version_id = resolve_reference(ref) - except ValueError: - # If resolution fails, treat it as a version ID if it's a string - if isinstance(ref, str): - version_id = ref - owner = name = None - else: - raise + version, owner, name, version_id = _resolve_reference(ref) # Create prediction - prediction = None if version_id is not None: params_with_version: PredictionCreateParams = {**params, "version": version_id} prediction = await client.predictions.create( @@ -166,7 +159,6 @@ async def async_stream( if not prediction.urls or not prediction.urls.stream: raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.") - # Make SSE request to the stream URL stream_url = prediction.urls.stream async with client._client.stream( From f2d2683ce76ab624633d2c64a71afee9f4396ea8 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 13:44:27 -0700 Subject: [PATCH 4/8] refactor: DRY up duplicate docstrings in stream functions --- src/replicate/_client.py | 56 ++++--------------- src/replicate/lib/_predictions_stream.py | 71 +++++++++--------------- 2 files changed, 37 insertions(+), 90 deletions(-) diff --git a/src/replicate/_client.py b/src/replicate/_client.py index a35e6e1..ab28379 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -330,34 +330,16 @@ def stream( """ Stream output from a model prediction. - This creates a prediction and returns an iterator that yields output chunks - as strings as they become available from the streaming API. - - Args: - ref: Reference to the model or version to run. Can be: - - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") - - A string with owner/name format (e.g. "replicate/hello-world") - - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...") - - A Model instance with owner and name attributes - - A Version instance with id attribute - - A ModelVersionIdentifier dictionary with owner, name, and/or version keys - file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url" - **params: Additional parameters to pass to the prediction creation endpoint including - the required "input" dictionary with model-specific parameters - - Yields: - str: Output chunks from the model as they become available - - Raises: - ValueError: If the reference format is invalid or model doesn't support streaming - ReplicateError: If the prediction fails - Example: - for event in replicate.stream( + ```python + for event in client.stream( "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"}, ): print(str(event), end="") + ``` + + See `replicate.lib._predictions_stream.stream` for full documentation. """ from .lib._predictions_stream import stream @@ -753,34 +735,16 @@ async def stream( """ Stream output from a model prediction asynchronously. - This creates a prediction and returns an async iterator that yields output chunks - as strings as they become available from the streaming API. - - Args: - ref: Reference to the model or version to run. Can be: - - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") - - A string with owner/name format (e.g. "replicate/hello-world") - - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...") - - A Model instance with owner and name attributes - - A Version instance with id attribute - - A ModelVersionIdentifier dictionary with owner, name, and/or version keys - file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url" - **params: Additional parameters to pass to the prediction creation endpoint including - the required "input" dictionary with model-specific parameters - - Yields: - str: Output chunks from the model as they become available - - Raises: - ValueError: If the reference format is invalid or model doesn't support streaming - ReplicateError: If the prediction fails - Example: - async for event in replicate.stream( + ```python + async for event in client.stream( "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"}, ): print(str(event), end="") + ``` + + See `replicate.lib._predictions_stream.async_stream` for full documentation. """ from .lib._predictions_stream import async_stream diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py index d8fd59c..399a317 100644 --- a/src/replicate/lib/_predictions_stream.py +++ b/src/replicate/lib/_predictions_stream.py @@ -13,6 +13,31 @@ if TYPE_CHECKING: from .._client import Replicate, AsyncReplicate +_STREAM_DOCSTRING = """ +Stream output from a model prediction. + +This creates a prediction and returns an iterator that yields output chunks +as strings as they become available from the streaming API. + +Args: + ref: Reference to the model or version to run. Can be: + - A string containing a version ID + - A string with owner/name format (e.g. "replicate/hello-world") + - A string with owner/name:version format + - A Model instance + - A Version instance + - A ModelVersionIdentifier dictionary + file_encoding_strategy: Strategy for encoding file inputs + **params: Additional parameters including the required "input" dictionary + +Yields: + str: Output chunks from the model as they become available + +Raises: + ValueError: If the reference format is invalid + ReplicateError: If the prediction fails or streaming is not available +""" + def _resolve_reference( ref: Union[Model, Version, ModelVersionIdentifier, str], @@ -35,31 +60,7 @@ def stream( file_encoding_strategy: Optional["FileEncodingStrategy"] = None, **params: Unpack[PredictionCreateParamsWithoutVersion], ) -> Iterator[str]: - """ - Stream output from a model prediction. - - This creates a prediction and returns an iterator that yields output chunks - as strings as they become available from the streaming API. - - Args: - client: The Replicate client instance - ref: Reference to the model or version to run. Can be: - - A string containing a version ID - - A string with owner/name format (e.g. "replicate/hello-world") - - A string with owner/name:version format - - A Model instance - - A Version instance - - A ModelVersionIdentifier dictionary - file_encoding_strategy: Strategy for encoding file inputs - **params: Additional parameters including the required "input" dictionary - - Yields: - str: Output chunks from the model as they become available - - Raises: - ValueError: If the reference format is invalid - ReplicateError: If the prediction fails or streaming is not available - """ + __doc__ = _STREAM_DOCSTRING version, owner, name, version_id = _resolve_reference(ref) # Create prediction @@ -112,25 +113,7 @@ async def async_stream( file_encoding_strategy: Optional["FileEncodingStrategy"] = None, **params: Unpack[PredictionCreateParamsWithoutVersion], ) -> AsyncIterator[str]: - """ - Async stream output from a model prediction. - - This creates a prediction and returns an async iterator that yields output chunks - as strings as they become available from the streaming API. - - Args: - client: The AsyncReplicate client instance - ref: Reference to the model or version to run - file_encoding_strategy: Strategy for encoding file inputs - **params: Additional parameters including the required "input" dictionary - - Yields: - str: Output chunks from the model as they become available - - Raises: - ValueError: If the reference format is invalid - ReplicateError: If the prediction fails or streaming is not available - """ + __doc__ = _STREAM_DOCSTRING version, owner, name, version_id = _resolve_reference(ref) # Create prediction From fe452372ab196d014413ed0ed35c6138e0b53cf8 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 14:00:48 -0700 Subject: [PATCH 5/8] fix: remove unused version variable from stream functions --- src/replicate/lib/_predictions_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py index 399a317..396b94c 100644 --- a/src/replicate/lib/_predictions_stream.py +++ b/src/replicate/lib/_predictions_stream.py @@ -61,7 +61,7 @@ def stream( **params: Unpack[PredictionCreateParamsWithoutVersion], ) -> Iterator[str]: __doc__ = _STREAM_DOCSTRING - version, owner, name, version_id = _resolve_reference(ref) + _, owner, name, version_id = _resolve_reference(ref) # Create prediction if version_id is not None: @@ -114,7 +114,7 @@ async def async_stream( **params: Unpack[PredictionCreateParamsWithoutVersion], ) -> AsyncIterator[str]: __doc__ = _STREAM_DOCSTRING - version, owner, name, version_id = _resolve_reference(ref) + _, owner, name, version_id = _resolve_reference(ref) # Create prediction if version_id is not None: From 58ac19fa99b940fa8aa27888e8077ff28095cf13 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 14:04:11 -0700 Subject: [PATCH 6/8] refactor: use _version instead of _ for unused variable --- src/replicate/lib/_predictions_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py index 396b94c..d087f65 100644 --- a/src/replicate/lib/_predictions_stream.py +++ b/src/replicate/lib/_predictions_stream.py @@ -61,7 +61,7 @@ def stream( **params: Unpack[PredictionCreateParamsWithoutVersion], ) -> Iterator[str]: __doc__ = _STREAM_DOCSTRING - _, owner, name, version_id = _resolve_reference(ref) + _version, owner, name, version_id = _resolve_reference(ref) # Create prediction if version_id is not None: @@ -114,7 +114,7 @@ async def async_stream( **params: Unpack[PredictionCreateParamsWithoutVersion], ) -> AsyncIterator[str]: __doc__ = _STREAM_DOCSTRING - _, owner, name, version_id = _resolve_reference(ref) + _version, owner, name, version_id = _resolve_reference(ref) # Create prediction if version_id is not None: From 924ab59cff8dbc173dc0baf9b133f79c9bf1d8ce Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 14:59:41 -0700 Subject: [PATCH 7/8] fix: add type annotations to test_stream.py for linter --- tests/lib/test_stream.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/lib/test_stream.py b/tests/lib/test_stream.py index cdd16cd..2dcec08 100644 --- a/tests/lib/test_stream.py +++ b/tests/lib/test_stream.py @@ -1,10 +1,11 @@ from __future__ import annotations import os -from typing import Iterator +from typing import Any, Iterator import httpx import pytest +from respx import MockRouter from replicate import Replicate, AsyncReplicate @@ -12,7 +13,7 @@ bearer_token = "My Bearer Token" -def create_mock_prediction_json(stream_url: str | None = None) -> dict: +def create_mock_prediction_json(stream_url: str | None = None) -> dict[str, Any]: """Helper to create a complete prediction JSON response""" prediction = { "id": "test-prediction-id", @@ -34,7 +35,7 @@ def create_mock_prediction_json(stream_url: str | None = None) -> dict: return prediction -def test_stream_with_model_owner_name(respx_mock) -> None: +def test_stream_with_model_owner_name(respx_mock: MockRouter) -> None: """Test streaming with owner/name format""" client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) @@ -71,7 +72,7 @@ def stream_content() -> Iterator[bytes]: assert output == ["Hello", " world", "!"] -def test_stream_with_version_id(respx_mock) -> None: +def test_stream_with_version_id(respx_mock: MockRouter) -> None: """Test streaming with version ID""" client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) version_id = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" @@ -108,7 +109,7 @@ def stream_content() -> Iterator[bytes]: assert output == ["Test", "output"] -def test_stream_no_stream_url_raises_error(respx_mock) -> None: +def test_stream_no_stream_url_raises_error(respx_mock: MockRouter) -> None: """Test that streaming raises an error when model doesn't support streaming""" client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) @@ -127,7 +128,7 @@ def test_stream_no_stream_url_raises_error(respx_mock) -> None: @pytest.mark.asyncio -async def test_async_stream_with_model_owner_name(respx_mock) -> None: +async def test_async_stream_with_model_owner_name(respx_mock: MockRouter) -> None: """Test async streaming with owner/name format""" async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) @@ -164,7 +165,7 @@ async def stream_content(): @pytest.mark.asyncio -async def test_async_stream_no_stream_url_raises_error(respx_mock) -> None: +async def test_async_stream_no_stream_url_raises_error(respx_mock: MockRouter) -> None: """Test that async streaming raises an error when model doesn't support streaming""" async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) @@ -182,7 +183,7 @@ async def test_async_stream_no_stream_url_raises_error(respx_mock) -> None: pass -def test_stream_module_level(respx_mock) -> None: +def test_stream_module_level(respx_mock: MockRouter) -> None: """Test that module-level stream function works""" import replicate From 3333e0f957865dd992dffc37c8e903eec633a1ed Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 30 Sep 2025 15:10:20 -0700 Subject: [PATCH 8/8] fix: add type annotations to output lists in stream tests --- tests/lib/test_stream.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/lib/test_stream.py b/tests/lib/test_stream.py index 2dcec08..2d58172 100644 --- a/tests/lib/test_stream.py +++ b/tests/lib/test_stream.py @@ -15,7 +15,7 @@ def create_mock_prediction_json(stream_url: str | None = None) -> dict[str, Any]: """Helper to create a complete prediction JSON response""" - prediction = { + prediction: dict[str, Any] = { "id": "test-prediction-id", "created_at": "2023-01-01T00:00:00Z", "data_removed": False, @@ -31,7 +31,7 @@ def create_mock_prediction_json(stream_url: str | None = None) -> dict[str, Any] }, } if stream_url: - prediction["urls"]["stream"] = stream_url + prediction["urls"]["stream"] = stream_url # type: ignore[index] return prediction @@ -62,7 +62,7 @@ def stream_content() -> Iterator[bytes]: ) # Stream the model - output = [] + output: list[str] = [] for chunk in client.stream( "meta/meta-llama-3-70b-instruct", input={"prompt": "Say hello"}, @@ -99,7 +99,7 @@ def stream_content() -> Iterator[bytes]: ) # Stream the model - output = [] + output: list[str] = [] for chunk in client.stream( version_id, input={"prompt": "Test"}, @@ -154,7 +154,7 @@ async def stream_content(): ) # Stream the model - output = [] + output: list[str] = [] async for chunk in async_client.stream( "meta/meta-llama-3-70b-instruct", input={"prompt": "Say hello"}, @@ -213,7 +213,7 @@ def stream_content() -> Iterator[bytes]: ) # Stream using module-level function - output = [] + output: list[str] = [] for chunk in replicate.stream( "meta/meta-llama-3-70b-instruct", input={"prompt": "Test"},