diff --git a/README.md b/README.md index 59d01b5..314defa 100644 --- a/README.md +++ b/README.md @@ -110,9 +110,24 @@ output = replicate.run("...", input={...}, wait=False) When `wait=False`, the method returns immediately after creating the prediction, and you'll need to poll for the result manually. +## Streaming output + +For models that support streaming (particularly language models), use `replicate.use()` with `streaming=True` to stream the output as it's generated: + +```python +import replicate + +claude = replicate.use("anthropic/claude-4.5-sonnet", streaming=True) + +for event in claude(input={"prompt": "Please write a haiku about streaming Python."}): + print(str(event), end="") +``` + +> **Note:** The [legacy `replicate.stream()` method](https://github.com/replicate/replicate-python/blob/d2956ff9c3e26ef434bc839cc5c87a50c49dfe20/README.md#run-a-model-and-stream-its-output) is also available for backwards compatibility with the v1 SDK, but is deprecated and will be removed in a future version. + ## Async usage -Simply import `AsyncReplicate` instead of `Replicate` and use `await` with each API call: +To use the Replicate client asynchronously, import `AsyncReplicate` instead of `Replicate` and use `await` with each API call: ```python import os @@ -136,34 +151,6 @@ asyncio.run(main()) Functionality between the synchronous and asynchronous clients is otherwise identical. 
-### Async run() and stream() - -The async client also supports `run()` and `stream()` methods: - -```python -import asyncio -from replicate import AsyncReplicate - -replicate = AsyncReplicate() - - -async def main(): - # Run a model - output = await replicate.run( - "black-forest-labs/flux-schnell", input={"prompt": "astronaut riding a rocket like a horse"} - ) - print(output) - - # Stream a model's output - async for event in replicate.stream( - "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"} - ): - print(str(event), end="") - - -asyncio.run(main()) -``` - ### With aiohttp By default, the async client uses `httpx` for HTTP requests. However, for improved concurrency performance you may also use `aiohttp` as the HTTP backend. diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index 1cfff56..e91608a 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -109,7 +109,7 @@ if not __name.startswith("__"): try: # Skip symbols that are imported later from _module_client - if __name in ("run", "use"): + if __name in ("run", "use", "stream"): continue __locals[__name].__module__ = "replicate" except (TypeError, AttributeError): @@ -253,6 +253,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction] use as use, files as files, models as models, + stream as stream, # pyright: ignore[reportDeprecated] account as account, hardware as hardware, webhooks as webhooks, diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 390a552..b7bdb01 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -16,7 +16,7 @@ AsyncIterator, overload, ) -from typing_extensions import Self, Unpack, ParamSpec, override +from typing_extensions import Self, Unpack, ParamSpec, override, deprecated import httpx @@ -320,6 +320,38 @@ def use( # TODO: Fix mypy overload matching for streaming parameter return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, 
no-any-return] + @deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") + def stream( + self, + ref: str, + *, + input: dict[str, Any], + ) -> Iterator[str]: + """ + Run a model and stream its output (deprecated). + + .. deprecated:: + Use :meth:`use` with ``streaming=True`` instead: + + .. code-block:: python + + model = replicate.use("anthropic/claude-4.5-sonnet", streaming=True) + for event in model(input={"prompt": "Hello"}): + print(str(event), end="") + + Args: + ref: Reference to the model to run. Can be a string with owner/name format + (e.g., "anthropic/claude-4.5-sonnet") or owner/name:version format. + input: Dictionary of input parameters for the model. The required keys depend + on the specific model being run. + + Returns: + An iterator that yields output strings as they are generated by the model + """ + from .lib._stream import stream as _stream # pyright: ignore[reportDeprecated] + + return _stream(type(self), ref, input=input) # type: ignore[return-value, arg-type] + def copy( self, *, @@ -695,6 +727,38 @@ def use( # TODO: Fix mypy overload matching for streaming parameter return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + @deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") + async def stream( + self, + ref: str, + *, + input: dict[str, Any], + ) -> AsyncIterator[str]: + """ + Run a model and stream its output asynchronously (deprecated). + + .. deprecated:: + Use :meth:`use` with ``streaming=True`` instead: + + .. code-block:: python + + model = replicate.use("anthropic/claude-4.5-sonnet", streaming=True) + async for event in model(input={"prompt": "Hello"}): + print(str(event), end="") + + Args: + ref: Reference to the model to run. Can be a string with owner/name format + (e.g., "anthropic/claude-4.5-sonnet") or owner/name:version format. + input: Dictionary of input parameters for the model. 
The required keys depend + on the specific model being run. + + Returns: + An async iterator that yields output strings as they are generated by the model + """ + from .lib._stream import stream as _stream # pyright: ignore[reportDeprecated] + + return _stream(type(self), ref, input=input) # type: ignore[return-value, arg-type] + def copy( self, *, diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py index a3e8ab4..0f19a61 100644 --- a/src/replicate/_module_client.py +++ b/src/replicate/_module_client.py @@ -82,6 +82,7 @@ def __load__(self) -> PredictionsResource: __client: Replicate = cast(Replicate, {}) run = __client.run use = __client.use + stream = __client.stream # pyright: ignore[reportDeprecated] else: def _run(*args, **kwargs): @@ -100,8 +101,21 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs): return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs) + def _stream(ref, *, input, use_async=False): + from .lib._stream import stream + + if use_async: + from ._client import AsyncReplicate + + return stream(AsyncReplicate, ref, input=input) + + from ._client import Replicate + + return stream(Replicate, ref, input=input) + run = _run use = _use + stream = _stream files: FilesResource = FilesResourceProxy().__as_proxied__() models: ModelsResource = ModelsResourceProxy().__as_proxied__() diff --git a/src/replicate/lib/_stream.py b/src/replicate/lib/_stream.py new file mode 100644 index 0000000..852c2fd --- /dev/null +++ b/src/replicate/lib/_stream.py @@ -0,0 +1,95 @@ +""" +Deprecated streaming functionality for backwards compatibility with v1 SDK. + +This module provides the stream() function which wraps replicate.use() with streaming=True. 
+""" + +from __future__ import annotations + +import warnings +from typing import Any, Dict, Type, Union, Iterator, AsyncIterator, overload +from typing_extensions import deprecated + +from .._client import Client, AsyncClient +from ._predictions_use import use + +__all__ = ["stream"] + + +def _format_deprecation_message(ref: str, input: Dict[str, Any]) -> str: + """Format the deprecation message with a working example.""" + # Format the input dict for display + input_str = "{\n" + for key, value in input.items(): + if isinstance(value, str): + input_str += f' "{key}": "{value}",\n' + else: + input_str += f' "{key}": {value},\n' + input_str += " }" + + return ( + f"replicate.stream() is deprecated and will be removed in a future version. " + f"Use replicate.use() with streaming=True instead:\n\n" + f' model = replicate.use("{ref}", streaming=True)\n' + f" for event in model(input={input_str}):\n" + f' print(str(event), end="")\n' + ) + + +@overload +def stream( + client: Type[Client], + ref: str, + *, + input: Dict[str, Any], +) -> Iterator[str]: ... + + +@overload +def stream( + client: Type[AsyncClient], + ref: str, + *, + input: Dict[str, Any], +) -> AsyncIterator[str]: ... + + +@deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") +def stream( + client: Union[Type[Client], Type[AsyncClient]], + ref: str, + *, + input: Dict[str, Any], +) -> Union[Iterator[str], AsyncIterator[str]]: + """ + Run a model and stream its output (deprecated). + + This function is deprecated. Use replicate.use() with streaming=True instead: + + model = replicate.use("anthropic/claude-4.5-sonnet", streaming=True) + for event in model(input={"prompt": "Hello"}): + print(str(event), end="") + + Args: + client: The Replicate client class (Client or AsyncClient) + ref: Reference to the model to run. Can be a string with owner/name format + (e.g., "anthropic/claude-4.5-sonnet") or owner/name:version format. 
+ input: Dictionary of input parameters for the model. The required keys depend + on the specific model being run. + + Returns: + An iterator (or async iterator) that yields output strings as they are + generated by the model + """ + # Log deprecation warning with helpful migration example + warnings.warn( + _format_deprecation_message(ref, input), + DeprecationWarning, + stacklevel=2, + ) + + # Use the existing use() function with streaming=True + model = use(client, ref, streaming=True) # type: ignore[var-annotated] # pyright: ignore[reportUnknownVariableType] + + # Call the model with the input + return model(**input) # type: ignore[return-value] diff --git a/tests/lib/test_stream.py b/tests/lib/test_stream.py new file mode 100644 index 0000000..5317b1a --- /dev/null +++ b/tests/lib/test_stream.py @@ -0,0 +1,239 @@ +"""Tests for the deprecated stream() function.""" + +from __future__ import annotations + +import warnings +from unittest.mock import Mock, patch + +import pytest + +import replicate +from replicate import Replicate, AsyncReplicate + + +def test_stream_shows_deprecation_warning(): + """Test that stream() shows a deprecation warning.""" + with patch("replicate.lib._stream.use") as mock_use: + # Create a mock function that returns an iterator + mock_function = Mock() + mock_function.return_value = iter(["Hello", " ", "world"]) + mock_use.return_value = mock_function + + # Call stream and capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + client = Replicate(bearer_token="test-token") + _ = list( # pyright: ignore[reportDeprecated] + client.stream( # pyright: ignore[reportDeprecated] + "anthropic/claude-4.5-sonnet", + input={"prompt": "Hello"}, + ) + ) + + # Check that deprecation warnings were raised + assert len(w) > 0 + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) > 0 + + # Verify the @deprecated decorator 
message appears + # (there may be multiple warnings from different decorator levels) + messages = [str(warning.message) for warning in deprecation_warnings] + expected_message = "replicate.stream() is deprecated. Use replicate.use() with streaming=True instead" + assert expected_message in messages + + +def test_stream_calls_use_with_streaming_true(): + """Test that stream() internally calls use() with streaming=True.""" + with patch("replicate.lib._stream.use") as mock_use: + # Create a mock function that returns an iterator + mock_function = Mock() + mock_function.return_value = iter(["Hello", " ", "world"]) + mock_use.return_value = mock_function + + client = Replicate(bearer_token="test-token") + + # Suppress deprecation warnings for this test + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + _ = list( # pyright: ignore[reportDeprecated] + client.stream( # pyright: ignore[reportDeprecated] + "anthropic/claude-4.5-sonnet", + input={"prompt": "Hello"}, + ) + ) + + # Verify use() was called with streaming=True + mock_use.assert_called_once() + call_args = mock_use.call_args + assert call_args.kwargs["streaming"] is True + assert call_args.args[1] == "anthropic/claude-4.5-sonnet" + + # Verify the mock function was called with the input + mock_function.assert_called_once_with(prompt="Hello") + + +def test_stream_returns_iterator(): + """Test that stream() returns an iterator of strings.""" + with patch("replicate.lib._stream.use") as mock_use: + # Create a mock function that returns an iterator + mock_function = Mock() + mock_function.return_value = iter(["Hello", " ", "world", "!"]) + mock_use.return_value = mock_function + + client = Replicate(bearer_token="test-token") + + # Suppress deprecation warnings for this test + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + result = client.stream( # pyright: ignore[reportDeprecated] + "anthropic/claude-4.5-sonnet", + input={"prompt": "Say 
hello"}, + ) + + # Verify we get an iterator + assert hasattr(result, "__iter__") + + # Verify the content + output = list(result) + assert output == ["Hello", " ", "world", "!"] + + +def test_stream_works_same_as_use_with_streaming(): + """Test that stream() produces the same output as use() with streaming=True.""" + with patch("replicate.lib._stream.use") as mock_stream_use, \ + patch("replicate.lib._predictions_use.use") as mock_predictions_use: + # Create a mock function that returns an iterator + mock_function = Mock() + expected_output = ["Test", " ", "output"] + mock_function.return_value = iter(expected_output.copy()) + mock_stream_use.return_value = mock_function + mock_predictions_use.return_value = mock_function + + client = Replicate(bearer_token="test-token") + + # Get output from stream() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + stream_output = list( + client.stream( # pyright: ignore[reportDeprecated] + "test-model", + input={"prompt": "test"}, + ) + ) + + # Reset the mock + mock_function.return_value = iter(expected_output.copy()) + + # Get output from use() with streaming=True + model = client.use("test-model", streaming=True) # pyright: ignore[reportUnknownVariableType] + use_output = list(model(prompt="test")) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + + # Verify they produce the same output + assert stream_output == use_output + + +def test_module_level_stream_function(): + """Test that the module-level replicate.stream() function works.""" + with patch("replicate.lib._stream.use") as mock_use: + # Create a mock function that returns an iterator + mock_function = Mock() + mock_function.return_value = iter(["a", "b", "c"]) + mock_use.return_value = mock_function + + # Suppress deprecation warnings for this test + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + result = list( + replicate.stream( # pyright: 
ignore[reportDeprecated] + "test-model", + input={"prompt": "test"}, + ) + ) + + # Verify we got the expected output + assert result == ["a", "b", "c"] + + # Verify use() was called with streaming=True + mock_use.assert_called_once() + assert mock_use.call_args.kwargs["streaming"] is True + + +@pytest.mark.asyncio +async def test_async_stream_shows_deprecation_warning(): + """Test that async stream() shows a deprecation warning.""" + with patch("replicate.lib._stream.use") as mock_use: + # Create a mock async function that returns an async iterator + async def async_gen(): + yield "Hello" + yield " " + yield "world" + + mock_function = Mock() + mock_function.return_value = async_gen() + mock_use.return_value = mock_function + + # Call stream and capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + client = AsyncReplicate(bearer_token="test-token") + result = [] + async for item in await client.stream( # pyright: ignore[reportDeprecated] + "anthropic/claude-4.5-sonnet", + input={"prompt": "Hello"}, + ): + result.append(item) # pyright: ignore[reportUnknownMemberType] + + # Check that deprecation warnings were raised + assert len(w) > 0 + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) > 0 + + # Verify the @deprecated decorator message appears + messages = [str(warning.message) for warning in deprecation_warnings] + expected_message = "replicate.stream() is deprecated. 
Use replicate.use() with streaming=True instead" + assert expected_message in messages + + +def test_deprecation_message_includes_example(): + """Test that the detailed deprecation message includes a helpful example.""" + with patch("replicate.lib._stream.use") as mock_use: + mock_function = Mock() + mock_function.return_value = iter([]) + mock_use.return_value = mock_function + + client = Replicate(bearer_token="test-token") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + list( + client.stream( # pyright: ignore[reportDeprecated] + "anthropic/claude-4.5-sonnet", + input={"prompt": "Hello", "max_tokens": 100}, + ) + ) + + # Find the detailed warning from _format_deprecation_message + # (should be one of the warnings, as there are multiple deprecation warnings) + detailed_message = None + for warning in w: + msg = str(warning.message) + if "will be removed in a future version" in msg: + detailed_message = msg + break + + assert detailed_message is not None, "Expected detailed deprecation message not found" + + # Verify the complete detailed message format + expected_message = ( + "replicate.stream() is deprecated and will be removed in a future version. " + "Use replicate.use() with streaming=True instead:\n\n" + ' model = replicate.use("anthropic/claude-4.5-sonnet", streaming=True)\n' + " for event in model(input={\n" + ' "prompt": "Hello",\n' + ' "max_tokens": 100,\n' + " }):\n" + ' print(str(event), end="")\n' + ) + assert detailed_message == expected_message