diff --git a/docs/docs/tutorials/deployment/index.md b/docs/docs/tutorials/deployment/index.md
index 6c6cd3d048..5117a3bdeb 100644
--- a/docs/docs/tutorials/deployment/index.md
+++ b/docs/docs/tutorials/deployment/index.md
@@ -41,9 +41,11 @@ class Question(BaseModel):
 # Configure your language model and 'asyncify' your DSPy program.
 lm = dspy.LM("openai/gpt-4o-mini")
 dspy.settings.configure(lm=lm, async_max_workers=4) # default is 8
-dspy_program = dspy.ChainOfThought("question -> answer")
-dspy_program = dspy.asyncify(dspy_program)
+dspy_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))
+streaming_dspy_program = dspy.streamify(dspy_program)
+
+# Define an endpoint (no streaming)
 @app.post("/predict")
 async def predict(question: Question):
     try:
@@ -54,14 +56,45 @@ async def predict(question: Question):
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+# Define an endpoint (streaming)
+from fastapi.responses import StreamingResponse
+
+@app.post("/predict/stream")
+async def stream(question: Question):
+    async def generate():
+        async for value in streaming_dspy_program(question=question.text):
+            if isinstance(value, dspy.Prediction):
+                data = {"prediction": value.labels().toDict()}
+            elif isinstance(value, litellm.ModelResponse):
+                data = {"chunk": value.json()}
+            yield f"data: {ujson.dumps(data)}\n\n"
+        yield "data: [DONE]\n\n"
+
+    return StreamingResponse(generate(), media_type="text/event-stream")
+
+# Since you'll often want to stream the result of a DSPy program as server-sent events,
+# we've included a helper function for that, which is equivalent to the code above.
+
+from dspy.utils.streaming import streaming_response
+
+@app.post("/predict/stream")
+async def stream(question: Question):
+    stream = streaming_dspy_program(question=question.text)
+    return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
 ```

 In the code above, we call `dspy.asyncify` to convert the dspy program to run in async mode for high-throughput FastAPI
-deployments. Currently, this runs the dspy program in a
-separate thread and awaits its result. By default, the limit of spawned threads is 8. Think of this like a worker pool.
+deployments. Currently, this runs the dspy program in a separate thread and awaits its result.
+
+By default, the limit of spawned threads is 8. Think of this like a worker pool.
 If you have 8 in-flight programs and call it once more, the 9th call will wait until one of the 8 returns.
 You can configure the async capacity using the new `async_max_workers` setting.

+We also use `dspy.streamify` to convert the dspy program to streaming mode. This is useful when you want to stream
+the intermediate outputs (e.g., O1-style reasoning) to the client before the final prediction is ready. This uses
+`dspy.asyncify` under the hood and inherits its execution semantics.
+
 Write your code to a file, e.g., `fastapi_dspy.py`.
Then you can serve the app with: ```bash diff --git a/dspy/__init__.py b/dspy/__init__.py index 6f4428525e..fea48caca8 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -12,6 +12,7 @@ from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.saving import load +from dspy.utils.streaming import streamify from dspy.dsp.utils.settings import settings diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index f70259e537..66cd7ea2cb 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -5,14 +5,17 @@ import uuid from datetime import datetime from hashlib import sha256 -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, cast import litellm import pydantic import ujson +from anyio.streams.memory import MemoryObjectSendStream +from asyncer import syncify from cachetools import LRUCache, cached from litellm import RetryPolicy +import dspy from dspy.adapters.base import Adapter from dspy.clients.openai import OpenAIProvider from dspy.clients.provider import Provider, TrainingJob @@ -309,16 +312,41 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int): def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): - return litellm.completion( - cache=cache, + retry_kwargs = dict( retry_policy=_get_litellm_retry_policy(num_retries), # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument # to completion()), the default value of max_retries is non-zero for certain providers, and # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0 max_retries=0, - **request, ) + stream = dspy.settings.send_stream + if stream is None: + return litellm.completion( + cache=cache, + **retry_kwargs, + **request, + ) + + # The stream is already opened, and will be closed by the caller. + stream = cast(MemoryObjectSendStream, stream) + + @syncify + async def stream_completion(): + response = await litellm.acompletion( + cache=cache, + stream=True, + **retry_kwargs, + **request, + ) + chunks = [] + async for chunk in response: + chunks.append(chunk) + await stream.send(chunk) + return litellm.stream_chunk_builder(chunks) + + return stream_completion() + @request_cache(maxsize=None) def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py index 79295ea029..2f3b619aea 100644 --- a/dspy/dsp/utils/settings.py +++ b/dspy/dsp/utils/settings.py @@ -1,6 +1,7 @@ import copy import threading from contextlib import contextmanager + from dspy.dsp.utils.utils import dotdict DEFAULT_CONFIG = dotdict( @@ -17,6 +18,7 @@ backoff_time=10, callbacks=[], async_max_workers=8, + send_stream=None, ) # Global base configuration and owner tracking @@ -26,20 +28,22 @@ # Global lock for settings configuration global_lock = threading.Lock() + class ThreadLocalOverrides(threading.local): def __init__(self): self.overrides = dotdict() + thread_local_overrides = ThreadLocalOverrides() class Settings: """ A singleton class for DSPy configuration settings. - Thread-safe global configuration. + Thread-safe global configuration. - 'configure' can be called by only one 'owner' thread (the first thread that calls it). - Other threads see the configured global values from 'main_thread_config'. - - 'context' sets thread-local overrides. 
These overrides propagate to threads spawned + - 'context' sets thread-local overrides. These overrides propagate to threads spawned inside that context block, when (and only when!) using a ParallelExecutor that copies overrides. 1. Only one unique thread (which can be any thread!) can call dspy.configure. @@ -61,7 +65,7 @@ def lock(self): return global_lock def __getattr__(self, name): - overrides = getattr(thread_local_overrides, 'overrides', dotdict()) + overrides = getattr(thread_local_overrides, "overrides", dotdict()) if name in overrides: return overrides[name] elif name in main_thread_config: @@ -70,7 +74,7 @@ def __getattr__(self, name): raise AttributeError(f"'Settings' object has no attribute '{name}'") def __setattr__(self, name, value): - if name in ('_instance',): + if name in ("_instance",): super().__setattr__(name, value) else: self.configure(**{name: value}) @@ -82,7 +86,7 @@ def __setitem__(self, key, value): self.__setattr__(key, value) def __contains__(self, key): - overrides = getattr(thread_local_overrides, 'overrides', dotdict()) + overrides = getattr(thread_local_overrides, "overrides", dotdict()) return key in overrides or key in main_thread_config def get(self, key, default=None): @@ -92,7 +96,7 @@ def get(self, key, default=None): return default def copy(self): - overrides = getattr(thread_local_overrides, 'overrides', dotdict()) + overrides = getattr(thread_local_overrides, "overrides", dotdict()) return dotdict({**main_thread_config, **overrides}) @property @@ -122,7 +126,7 @@ def context(self, **kwargs): If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides. """ - original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy() + original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy() new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs}) thread_local_overrides.overrides = new_overrides @@ -132,7 +136,7 @@ def context(self, **kwargs): thread_local_overrides.overrides = original_overrides def __repr__(self): - overrides = getattr(thread_local_overrides, 'overrides', dotdict()) + overrides = getattr(thread_local_overrides, "overrides", dotdict()) combined_config = {**main_thread_config, **overrides} return repr(combined_config) diff --git a/dspy/utils/streaming.py b/dspy/utils/streaming.py new file mode 100644 index 0000000000..bf1f4f5bd6 --- /dev/null +++ b/dspy/utils/streaming.py @@ -0,0 +1,87 @@ +from asyncio import iscoroutinefunction +from typing import Any, AsyncGenerator, Awaitable, Callable + +import litellm +import ujson +from anyio import create_memory_object_stream, create_task_group +from anyio.streams.memory import MemoryObjectSendStream + +from dspy.primitives.prediction import Prediction +from dspy.primitives.program import Module +from dspy.utils.asyncify import asyncify + + +def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]: + """ + Wrap a DSPy program so that it streams its outputs incrementally, rather than returning them + all at once. + + Args: + program: The DSPy program to wrap with streaming functionality. + Returns: + A function that takes the same arguments as the original program, but returns an async + generator that yields the program's outputs incrementally. 
+ + Example: + >>> class TestSignature(dspy.Signature): + >>> input_text: str = dspy.InputField() + >>> output_text: str = dspy.OutputField() + >>> + >>> # Create the program and wrap it with streaming functionality + >>> program = dspy.streamify(dspy.Predict(TestSignature)) + >>> + >>> # Use the program with streaming output + >>> async def use_streaming(): + >>> output_stream = program(input_text="Test") + >>> async for value in output_stream: + >>> print(value) # Print each streamed value incrementally + """ + import dspy + + if not iscoroutinefunction(program): + program = asyncify(program) + + async def generator(args, kwargs, stream: MemoryObjectSendStream): + with dspy.settings.context(send_stream=stream): + prediction = await program(*args, **kwargs) + + await stream.send(prediction) + + async def streamer(*args, **kwargs): + send_stream, receive_stream = create_memory_object_stream(16) + async with create_task_group() as tg, send_stream, receive_stream: + tg.start_soon(generator, args, kwargs, send_stream) + + async for value in receive_stream: + yield value + if isinstance(value, Prediction): + return + + return streamer + + +async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator: + """ + Convert a DSPy program output stream to an OpenAI-compatible output stream that can be + used by a service as an API response to a streaming request. + + Args: + streamer: An async generator that yields values from a DSPy program output stream. + Returns: + An async generator that yields OpenAI-compatible streaming response chunks. + """ + async for value in streamer: + if isinstance(value, Prediction): + data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}} + yield f"data: {ujson.dumps(data)}\n\n" + elif isinstance(value, litellm.ModelResponse): + data = {"chunk": value.json()} + yield f"data: {ujson.dumps(data)}\n\n" + elif isinstance(value, str) and value.startswith("data:"): + # The chunk value is an OpenAI-compatible streaming chunk value, + # e.g. 
"data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}", + # so yield it directly + yield value + else: + raise ValueError(f"Unknown chunk value type: {value}") + yield "data: [DONE]\n\n" diff --git a/tests/test_utils/server/litellm_server.py b/tests/test_utils/server/litellm_server.py index 18cb5c5e44..419ea45782 100644 --- a/tests/test_utils/server/litellm_server.py +++ b/tests/test_utils/server/litellm_server.py @@ -1,8 +1,11 @@ import json import os +import time +from typing import AsyncIterator, Iterator import litellm from litellm import CustomLLM +from litellm.types.utils import GenericStreamingChunk LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR = "LITELLM_TEST_SERVER_LOG_FILE_PATH" @@ -16,6 +19,28 @@ async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: _append_request_to_log_file(kwargs) return _get_mock_llm_response(kwargs) + def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: + generic_streaming_chunk: GenericStreamingChunk = { + "finish_reason": "stop", + "index": 0, + "is_finished": True, + "text": '{"output_text": "Hello!"}', + "tool_use": None, + "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}, + } + return generic_streaming_chunk # type: ignore + + async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: + generic_streaming_chunk: GenericStreamingChunk = { + "finish_reason": "stop", + "index": 0, + "is_finished": True, + "text": '{"output_text": "Hello!"}', + "tool_use": None, + "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}, + } + yield generic_streaming_chunk + def _get_mock_llm_response(request_kwargs): _throw_exception_based_on_content_if_applicable(request_kwargs) diff --git a/tests/utils/test_streaming.py b/tests/utils/test_streaming.py new file mode 100644 index 0000000000..dbaa8be663 --- /dev/null +++ b/tests/utils/test_streaming.py @@ -0,0 +1,60 @@ +import pytest + +import dspy +from dspy.utils.streaming import streaming_response +from tests.test_utils.server import litellm_test_server + + +@pytest.mark.asyncio +async def test_streamify_yields_expected_response_chunks(litellm_test_server): + api_base, _ = litellm_test_server + lm = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + ) + with dspy.context(lm=lm): + + class TestSignature(dspy.Signature): + input_text: str = dspy.InputField() + output_text: str = dspy.OutputField() + + program = dspy.streamify(dspy.Predict(TestSignature)) + output_stream1 = program(input_text="Test") + output_chunks1 = [chunk async for chunk in output_stream1] + assert len(output_chunks1) > 1 + last_chunk1 = output_chunks1[-1] + assert isinstance(last_chunk1, dspy.Prediction) + assert last_chunk1.output_text == "Hello!" + + output_stream2 = program(input_text="Test") + output_chunks2 = [chunk async for chunk in output_stream2] + # Since the input is cached, only one chunk should be + # yielded containing the prediction + assert len(output_chunks2) == 1 + last_chunk2 = output_chunks2[-1] + assert isinstance(last_chunk2, dspy.Prediction) + assert last_chunk2.output_text == "Hello!" 
+ + +@pytest.mark.asyncio +async def test_streaming_response_yields_expected_response_chunks(litellm_test_server): + api_base, _ = litellm_test_server + lm = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + ) + with dspy.context(lm=lm): + + class TestSignature(dspy.Signature): + input_text: str = dspy.InputField() + output_text: str = dspy.OutputField() + + program = dspy.streamify(dspy.Predict(TestSignature)) + output_stream_from_program = streaming_response(program(input_text="Test")) + output_stream_for_server_response = streaming_response(output_stream_from_program) + output_chunks = [chunk async for chunk in output_stream_for_server_response] + assert all(chunk.startswith("data: ") for chunk in output_chunks) + assert 'data: {"prediction":{"output_text":"Hello!"}}\n\n' in output_chunks + assert output_chunks[-1] == "data: [DONE]\n\n"
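
For reference, here is a minimal client-side sketch (not part of the diff above) showing one way to consume the `/predict/stream` SSE endpoint added in the tutorial. It assumes the app is served locally, e.g. with `uvicorn fastapi_dspy:app`, so the URL and port are assumptions, and it uses `httpx` purely as an example async HTTP client; the request body follows the tutorial's `Question` model.

```python
import asyncio
import json

import httpx  # example async HTTP client; any client that supports response streaming works


async def consume_stream():
    # Assumed local deployment of the tutorial's FastAPI app; adjust URL/port as needed.
    url = "http://127.0.0.1:8000/predict/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json={"text": "What is the capital of France?"}) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between SSE events
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                event = json.loads(payload)
                if "chunk" in event:
                    print("chunk:", event["chunk"])            # incremental LM output
                elif "prediction" in event:
                    print("prediction:", event["prediction"])  # final DSPy prediction


if __name__ == "__main__":
    asyncio.run(consume_stream())
```

Each `data:` event carries either an incremental `chunk` from the LM or the final `prediction`, mirroring the output format produced by the `streaming_response` helper.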