14 changes: 11 additions & 3 deletions README.md
@@ -118,14 +118,18 @@ For models that support streaming (particularly language models), you can use `r
 import replicate
 
 for event in replicate.stream(
-    "meta/meta-llama-3-70b-instruct",
+    "anthropic/claude-4-sonnet",
     input={
-        "prompt": "Please write a haiku about llamas.",
+        "prompt": "Give me a recipe for tasty smashed avocado on sourdough toast.",
+        "max_tokens": 8192,
+        "system_prompt": "You are a helpful assistant",
     },
 ):
     print(str(event), end="")
 ```
 
+The `stream()` method creates a prediction and returns an iterator that yields output chunks as strings. This is useful for language models where you want to display output as it's generated rather than waiting for the entire response.
+
 ## Async usage
 
 Simply import `AsyncReplicate` instead of `Replicate` and use `await` with each API call:
@@ -172,7 +176,11 @@ async def main():

     # Stream a model's output
     async for event in replicate.stream(
-        "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"}
+        "anthropic/claude-4-sonnet",
+        input={
+            "prompt": "Write a haiku about coding",
+            "system_prompt": "You are a helpful assistant",
+        },
     ):
         print(str(event), end="")
 
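As a complement to the README paragraph above, which notes that `stream()` yields output chunks as strings, the sketch below shows how a caller can accumulate the chunks into a full response instead of printing them as they arrive. It is illustrative rather than part of this diff: it reuses the model and prompt from the examples above and assumes an API token is configured in the environment.

```python
import replicate

# Accumulate the streamed chunks instead of printing them one by one.
chunks: list[str] = []
for event in replicate.stream(
    "meta/meta-llama-3-70b-instruct",
    input={"prompt": "Write a haiku about coding"},
):
    chunks.append(str(event))

print("".join(chunks))
```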
3 changes: 2 additions & 1 deletion src/replicate/__init__.py
@@ -109,7 +109,7 @@
     if not __name.startswith("__"):
         try:
             # Skip symbols that are imported later from _module_client
-            if __name in ("run", "use"):
+            if __name in ("run", "use", "stream"):
                 continue
             __locals[__name].__module__ = "replicate"
         except (TypeError, AttributeError):
@@ -253,6 +253,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction]
     use as use,
     files as files,
     models as models,
+    stream as stream,
     account as account,
     hardware as hardware,
     webhooks as webhooks,
61 changes: 61 additions & 0 deletions src/replicate/_client.py
@@ -320,6 +320,36 @@ def use(
         # TODO: Fix mypy overload matching for streaming parameter
         return _use(self, ref, hint=hint, streaming=streaming)  # type: ignore[call-overload, no-any-return]
 
+    def stream(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> Iterator[str]:
+        """
+        Stream output from a model prediction.
+
+        Example:
+        ```python
+        for event in client.stream(
+            "meta/meta-llama-3-70b-instruct",
+            input={"prompt": "Write a haiku about coding"},
+        ):
+            print(str(event), end="")
+        ```
+
+        See `replicate.lib._predictions_stream.stream` for full documentation.
+        """
+        from .lib._predictions_stream import stream
+
+        return stream(
+            self,
+            ref,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        )
+
     def copy(
         self,
         *,
@@ -695,6 +725,37 @@ def use(
         # TODO: Fix mypy overload matching for streaming parameter
         return _use(self, ref, hint=hint, streaming=streaming)  # type: ignore[call-overload, no-any-return]
 
+    async def stream(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> AsyncIterator[str]:
+        """
+        Stream output from a model prediction asynchronously.
+
+        Example:
+        ```python
+        async for event in client.stream(
+            "meta/meta-llama-3-70b-instruct",
+            input={"prompt": "Write a haiku about coding"},
+        ):
+            print(str(event), end="")
+        ```
+
+        See `replicate.lib._predictions_stream.async_stream` for full documentation.
+        """
+        from .lib._predictions_stream import async_stream
+
+        async for chunk in async_stream(
+            self,
+            ref,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        ):
+            yield chunk
+
     def copy(
         self,
         *,
5 changes: 5 additions & 0 deletions src/replicate/_module_client.py
@@ -82,6 +82,7 @@ def __load__(self) -> PredictionsResource:
     __client: Replicate = cast(Replicate, {})
     run = __client.run
     use = __client.use
+    stream = __client.stream
 else:
 
     def _run(*args, **kwargs):
@@ -100,8 +101,12 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):

         return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)
 
+    def _stream(*args, **kwargs):
+        return _load_client().stream(*args, **kwargs)
+
     run = _run
     use = _use
+    stream = _stream
 
 files: FilesResource = FilesResourceProxy().__as_proxied__()
 models: ModelsResource = ModelsResourceProxy().__as_proxied__()
163 changes: 163 additions & 0 deletions src/replicate/lib/_predictions_stream.py
@@ -0,0 +1,163 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Tuple, Union, Iterator, Optional
from collections.abc import AsyncIterator
from typing_extensions import Unpack

from replicate.lib._files import FileEncodingStrategy
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion

from ..types import PredictionCreateParams
from ._models import Model, Version, ModelVersionIdentifier, resolve_reference

if TYPE_CHECKING:
    from .._client import Replicate, AsyncReplicate

_STREAM_DOCSTRING = """
Stream output from a model prediction.

This creates a prediction and returns an iterator that yields output chunks
as strings as they become available from the streaming API.

Args:
    ref: Reference to the model or version to run. Can be:
        - A string containing a version ID
        - A string with owner/name format (e.g. "replicate/hello-world")
        - A string with owner/name:version format
        - A Model instance
        - A Version instance
        - A ModelVersionIdentifier dictionary
    file_encoding_strategy: Strategy for encoding file inputs
    **params: Additional parameters including the required "input" dictionary

Yields:
    str: Output chunks from the model as they become available

Raises:
    ValueError: If the reference format is invalid
    ReplicateError: If the prediction fails or streaming is not available
"""


def _resolve_reference(
    ref: Union[Model, Version, ModelVersionIdentifier, str],
) -> Tuple[Optional[Version], Optional[str], Optional[str], Optional[str]]:
    """Resolve a model reference to its components, with fallback for plain version IDs."""
    try:
        return resolve_reference(ref)
    except ValueError:
        # If resolution fails, treat it as a version ID if it's a string
        if isinstance(ref, str):
            return None, None, None, ref
        else:
            raise
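# Illustrative only (not part of this module): the reference forms documented in
# _STREAM_DOCSTRING, as they would be passed to stream() below. ``client`` and the
# inputs are assumed for the sketch, and "<version-id>" stands in for a real version hash.
#
#     stream(client, "meta/meta-llama-3-70b-instruct", input={"prompt": "Hi"})      # owner/name
#     stream(client, "meta/meta-llama-3-70b-instruct:<version-id>", input={...})    # owner/name:version
#     stream(client, "<version-id>", input={...})                                   # bare version ID (fallback above)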


def stream(
    client: "Replicate",
    ref: Union[Model, Version, ModelVersionIdentifier, str],
    *,
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
    **params: Unpack[PredictionCreateParamsWithoutVersion],
) -> Iterator[str]:
    _version, owner, name, version_id = _resolve_reference(ref)

    # Create prediction
    if version_id is not None:
        params_with_version: PredictionCreateParams = {**params, "version": version_id}
        prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
    elif owner and name:
        prediction = client.models.predictions.create(
            file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
        )
    else:
        if isinstance(ref, str):
            params_with_version = {**params, "version": ref}
            prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
        else:
            raise ValueError(
                f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
                "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
            )

    # Check if streaming URL is available
    if not prediction.urls or not prediction.urls.stream:
        raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")

    stream_url = prediction.urls.stream

    with client._client.stream(
        "GET",
        stream_url,
        headers={
            "Accept": "text/event-stream",
            "Cache-Control": "no-store",
        },
        timeout=None,  # No timeout for streaming
    ) as response:
        response.raise_for_status()

        # Parse SSE events and yield output chunks
        decoder = client._make_sse_decoder()
        for sse in decoder.iter_bytes(response.iter_bytes()):
            # The SSE data contains the output chunks
            if sse.data:
                yield sse.data


# Assigning to a local ``__doc__`` inside the function body would not set the
# function's docstring, so attach the shared docstring to the function object.
stream.__doc__ = _STREAM_DOCSTRING


async def async_stream(
    client: "AsyncReplicate",
    ref: Union[Model, Version, ModelVersionIdentifier, str],
    *,
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
    **params: Unpack[PredictionCreateParamsWithoutVersion],
) -> AsyncIterator[str]:
    _version, owner, name, version_id = _resolve_reference(ref)

    # Create prediction
    if version_id is not None:
        params_with_version: PredictionCreateParams = {**params, "version": version_id}
        prediction = await client.predictions.create(
            file_encoding_strategy=file_encoding_strategy, **params_with_version
        )
    elif owner and name:
        prediction = await client.models.predictions.create(
            file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
        )
    else:
        if isinstance(ref, str):
            params_with_version = {**params, "version": ref}
            prediction = await client.predictions.create(
                file_encoding_strategy=file_encoding_strategy, **params_with_version
            )
        else:
            raise ValueError(
                f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
                "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
            )

    # Check if streaming URL is available
    if not prediction.urls or not prediction.urls.stream:
        raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")

    stream_url = prediction.urls.stream

    async with client._client.stream(
        "GET",
        stream_url,
        headers={
            "Accept": "text/event-stream",
            "Cache-Control": "no-store",
        },
        timeout=None,  # No timeout for streaming
    ) as response:
        response.raise_for_status()

        # Parse SSE events and yield output chunks
        decoder = client._make_sse_decoder()
        async for sse in decoder.aiter_bytes(response.aiter_bytes()):
            # The SSE data contains the output chunks
            if sse.data:
                yield sse.data


async_stream.__doc__ = _STREAM_DOCSTRING
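Taken together, the changes in this diff wire module-level `replicate.stream()` through `Replicate.stream()` / `AsyncReplicate.stream()` down to the `stream()` and `async_stream()` helpers above, which create the prediction and re-yield SSE data chunks as strings. The sketch below is illustrative rather than part of the diff; it assumes an API token is configured in the environment and reuses the model from the docstring examples.

```python
import asyncio

from replicate import AsyncReplicate


async def main() -> None:
    client = AsyncReplicate()

    # AsyncReplicate.stream() delegates to async_stream() above.
    async for chunk in client.stream(
        "meta/meta-llama-3-70b-instruct",
        input={"prompt": "Write a haiku about coding"},
    ):
        print(chunk, end="")


asyncio.run(main())
```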