-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Streaming #1874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Streaming #1874
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
7816ba3
dspy.streamify
CyrusNuevoDia d8bc33c
Update docs
CyrusNuevoDia 48a2376
Merge branch 'main' into streaming
CyrusNuevoDia ab8b28e
Fix ruff lint error
CyrusNuevoDia 7c251a9
Bring back send_stream to settings
CyrusNuevoDia 00c6e76
Improve doc
CyrusNuevoDia c02a7fc
Bring back request_cache setting
CyrusNuevoDia 707e5a3
sse => streaming_response
CyrusNuevoDia e852b4c
Simplify dsp.utils.settings diff
CyrusNuevoDia d4064e1
Add load/dump to LRUCache + drop callable request params
CyrusNuevoDia 92f137f
ujson => pickle for dump/load
CyrusNuevoDia 070101f
Stream fix
dbczumar a573ed0
Merge conflict
dbczumar 10c7bbd
test streaming
dbczumar dfa0e3b
fix
dbczumar 68aca74
fix
dbczumar ca928fd
Streaming works
dbczumar 962283d
Fix
dbczumar 93fe33b
fix
dbczumar a234c7b
no ignore change
dbczumar eb7e792
fix
dbczumar 7772a0d
Simple init
dbczumar 4232dc4
Simple init
dbczumar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import copy | ||
| import threading | ||
| from contextlib import contextmanager | ||
|
|
||
| from dspy.dsp.utils.utils import dotdict | ||
|
|
||
| DEFAULT_CONFIG = dotdict( | ||
|
|
@@ -17,6 +18,7 @@ | |
| backoff_time=10, | ||
| callbacks=[], | ||
| async_max_workers=8, | ||
| send_stream=None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the only meaningful change in the file - everything else is a linter adjustment |
||
| ) | ||
|
|
||
| # Global base configuration and owner tracking | ||
|
|
@@ -26,20 +28,22 @@ | |
| # Global lock for settings configuration | ||
| global_lock = threading.Lock() | ||
|
|
||
|
|
||
| class ThreadLocalOverrides(threading.local): | ||
| def __init__(self): | ||
| self.overrides = dotdict() | ||
|
|
||
|
|
||
| thread_local_overrides = ThreadLocalOverrides() | ||
|
|
||
|
|
||
| class Settings: | ||
| """ | ||
| A singleton class for DSPy configuration settings. | ||
| Thread-safe global configuration. | ||
| Thread-safe global configuration. | ||
| - 'configure' can be called by only one 'owner' thread (the first thread that calls it). | ||
| - Other threads see the configured global values from 'main_thread_config'. | ||
| - 'context' sets thread-local overrides. These overrides propagate to threads spawned | ||
| - 'context' sets thread-local overrides. These overrides propagate to threads spawned | ||
| inside that context block, when (and only when!) using a ParallelExecutor that copies overrides. | ||
|
|
||
| 1. Only one unique thread (which can be any thread!) can call dspy.configure. | ||
|
|
@@ -61,7 +65,7 @@ def lock(self): | |
| return global_lock | ||
|
|
||
| def __getattr__(self, name): | ||
| overrides = getattr(thread_local_overrides, 'overrides', dotdict()) | ||
| overrides = getattr(thread_local_overrides, "overrides", dotdict()) | ||
| if name in overrides: | ||
| return overrides[name] | ||
| elif name in main_thread_config: | ||
|
|
@@ -70,7 +74,7 @@ def __getattr__(self, name): | |
| raise AttributeError(f"'Settings' object has no attribute '{name}'") | ||
|
|
||
| def __setattr__(self, name, value): | ||
| if name in ('_instance',): | ||
| if name in ("_instance",): | ||
| super().__setattr__(name, value) | ||
| else: | ||
| self.configure(**{name: value}) | ||
|
|
@@ -82,7 +86,7 @@ def __setitem__(self, key, value): | |
| self.__setattr__(key, value) | ||
|
|
||
| def __contains__(self, key): | ||
| overrides = getattr(thread_local_overrides, 'overrides', dotdict()) | ||
| overrides = getattr(thread_local_overrides, "overrides", dotdict()) | ||
| return key in overrides or key in main_thread_config | ||
|
|
||
| def get(self, key, default=None): | ||
|
|
@@ -92,7 +96,7 @@ def get(self, key, default=None): | |
| return default | ||
|
|
||
| def copy(self): | ||
| overrides = getattr(thread_local_overrides, 'overrides', dotdict()) | ||
| overrides = getattr(thread_local_overrides, "overrides", dotdict()) | ||
| return dotdict({**main_thread_config, **overrides}) | ||
|
|
||
| @property | ||
|
|
@@ -122,7 +126,7 @@ def context(self, **kwargs): | |
| If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides. | ||
| """ | ||
|
|
||
| original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy() | ||
| original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy() | ||
| new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs}) | ||
| thread_local_overrides.overrides = new_overrides | ||
|
|
||
|
|
@@ -132,7 +136,7 @@ def context(self, **kwargs): | |
| thread_local_overrides.overrides = original_overrides | ||
|
|
||
| def __repr__(self): | ||
| overrides = getattr(thread_local_overrides, 'overrides', dotdict()) | ||
| overrides = getattr(thread_local_overrides, "overrides", dotdict()) | ||
| combined_config = {**main_thread_config, **overrides} | ||
| return repr(combined_config) | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| from asyncio import iscoroutinefunction | ||
| from typing import Any, AsyncGenerator, Awaitable, Callable | ||
|
|
||
| import litellm | ||
| import ujson | ||
| from anyio import create_memory_object_stream, create_task_group | ||
| from anyio.streams.memory import MemoryObjectSendStream | ||
|
|
||
| from dspy.primitives.prediction import Prediction | ||
| from dspy.primitives.program import Module | ||
| from dspy.utils.asyncify import asyncify | ||
|
|
||
|
|
||
| def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]: | ||
| """ | ||
| Wrap a DSPy program so that it streams its outputs incrementally, rather than returning them | ||
| all at once. | ||
| Args: | ||
| program: The DSPy program to wrap with streaming functionality. | ||
| Returns: | ||
| A function that takes the same arguments as the original program, but returns an async | ||
| generator that yields the program's outputs incrementally. | ||
| Example: | ||
| >>> class TestSignature(dspy.Signature): | ||
| >>> input_text: str = dspy.InputField() | ||
| >>> output_text: str = dspy.OutputField() | ||
| >>> | ||
| >>> # Create the program and wrap it with streaming functionality | ||
| >>> program = dspy.streamify(dspy.Predict(TestSignature)) | ||
| >>> | ||
| >>> # Use the program with streaming output | ||
| >>> async def use_streaming(): | ||
| >>> output_stream = program(input_text="Test") | ||
| >>> async for value in output_stream: | ||
| >>> print(value) # Print each streamed value incrementally | ||
| """ | ||
| import dspy | ||
|
|
||
| if not iscoroutinefunction(program): | ||
| program = asyncify(program) | ||
|
|
||
| async def generator(args, kwargs, stream: MemoryObjectSendStream): | ||
| with dspy.settings.context(send_stream=stream): | ||
| prediction = await program(*args, **kwargs) | ||
|
|
||
| await stream.send(prediction) | ||
|
|
||
| async def streamer(*args, **kwargs): | ||
| send_stream, receive_stream = create_memory_object_stream(16) | ||
| async with create_task_group() as tg, send_stream, receive_stream: | ||
| tg.start_soon(generator, args, kwargs, send_stream) | ||
|
|
||
| async for value in receive_stream: | ||
| yield value | ||
| if isinstance(value, Prediction): | ||
| return | ||
|
|
||
| return streamer | ||
|
|
||
|
|
||
| async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator: | ||
| """ | ||
| Convert a DSPy program output stream to an OpenAI-compatible output stream that can be | ||
| used by a service as an API response to a streaming request. | ||
| Args: | ||
| streamer: An async generator that yields values from a DSPy program output stream. | ||
| Returns: | ||
| An async generator that yields OpenAI-compatible streaming response chunks. | ||
| """ | ||
| async for value in streamer: | ||
| if isinstance(value, Prediction): | ||
| data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}} | ||
CyrusNuevoDia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| yield f"data: {ujson.dumps(data)}\n\n" | ||
| elif isinstance(value, litellm.ModelResponse): | ||
| data = {"chunk": value.json()} | ||
| yield f"data: {ujson.dumps(data)}\n\n" | ||
| elif isinstance(value, str) and value.startswith("data:"): | ||
| # The chunk value is an OpenAI-compatible streaming chunk value, | ||
| # e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}", | ||
| # so yield it directly | ||
| yield value | ||
CyrusNuevoDia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| raise ValueError(f"Unknown chunk value type: {value}") | ||
| yield "data: [DONE]\n\n" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.