feat: Add async streaming support in HuggingFaceLocalChatGenerator
#9405
Conversation
Pull Request Test Coverage Report for Build 15588007009

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details
💛 - Coveralls
@@ -566,7 +568,7 @@ async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
         )

         # Set up streaming handler
-        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+        generation_kwargs["streamer"] = AsyncHFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
To make mypy happy, could we add an assert here asserting that streaming_callback is of type AsyncStreamingCallbackT? Or update AsyncHFTokenStreamingHandler such that the type hint for stream_handler is StreamingCallbackT.
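For illustration only (not part of the PR): since AsyncStreamingCallbackT is a callable type alias rather than a class, the runtime check would pair asyncio.iscoroutinefunction with a typing.cast. The helper name as_async_callback is made up, and the import path is assumed to match the one used by other generators.

```python
import asyncio
from typing import cast

# Assumed import path; both aliases live in haystack.dataclasses.streaming_chunk in recent versions.
from haystack.dataclasses.streaming_chunk import AsyncStreamingCallbackT, StreamingCallbackT


def as_async_callback(streaming_callback: StreamingCallbackT) -> AsyncStreamingCallbackT:
    """Narrow a generic streaming callback to the async variant."""
    # The assert guards at runtime; the cast is what actually satisfies mypy,
    # since a Callable alias cannot be isinstance-checked.
    assert asyncio.iscoroutinefunction(streaming_callback), "streaming_callback must be async in run_async"
    return cast(AsyncStreamingCallbackT, streaming_callback)
```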
I've left a comment below!
@@ -608,3 +616,52 @@ def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, to
             },
         }
         assert data["init_parameters"]["tools"] == expected_tools_data

+    @pytest.mark.asyncio
+    async def test_run_async_with_streaming_callback(self, model_info_mock):
Could we also add an integration test for this? So an async version of test_live_run with streaming?
I would also simplify this, removing the asyncio.Event usage and reusing an already available mock. Something like:
@pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, model_info_mock, mock_pipeline_with_tokenizer):
    streaming_chunks = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        streaming_chunks.append(chunk)

    # Create a mock that simulates streaming behavior
    def mock_pipeline_call(*args, **kwargs):
        streamer = kwargs.get("streamer")
        if streamer:
            # Simulate streaming chunks
            streamer.on_finalized_text("Berlin", stream_end=False)
            streamer.on_finalized_text(" is cool", stream_end=True)
        return [{"generated_text": "Berlin is cool"}]

    # Setup the mock pipeline with streaming simulation
    mock_pipeline_with_tokenizer.side_effect = mock_pipeline_call

    generator = HuggingFaceLocalChatGenerator(model="test-model", streaming_callback=streaming_callback)
    generator.pipeline = mock_pipeline_with_tokenizer

    messages = [ChatMessage.from_user("Test message")]
    response = await generator.run_async(messages)

    # Verify streaming chunks were collected
    assert len(streaming_chunks) == 2
    assert streaming_chunks[0].content == "Berlin"
    assert streaming_chunks[1].content == " is cool\n"

    # Verify the final response
    assert isinstance(response, dict)
    assert "replies" in response
    assert len(response["replies"]) == 1
    assert isinstance(response["replies"][0], ChatMessage)
    assert response["replies"][0].text == "Berlin is cool"
WDYT?
- fix breaking tests
- added component_info to AsyncHFTokenStreamingHandler
@sjrl: added a live integration test
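For readers following the thread, a minimal sketch of the general pattern behind such a handler, not the PR's actual implementation: a TextStreamer subclass hands finalized text from the (threaded) generation loop to an asyncio.Queue, and a queue-processor task awaits the async callback. The names AsyncQueueTokenStreamer and process_queue are illustrative, and the ComponentInfo/StreamingChunk import paths and the component_info field are assumed to match current Haystack.

```python
import asyncio
from typing import Awaitable, Callable, Optional

from transformers import TextStreamer

from haystack.dataclasses import ComponentInfo, StreamingChunk


class AsyncQueueTokenStreamer(TextStreamer):
    """Hands finalized text from the generation thread to an asyncio.Queue."""

    def __init__(
        self,
        tokenizer,
        queue: "asyncio.Queue[Optional[StreamingChunk]]",
        loop: asyncio.AbstractEventLoop,
        component_info: Optional[ComponentInfo] = None,
    ):
        super().__init__(tokenizer, skip_prompt=True)
        self.queue = queue
        self.loop = loop
        self.component_info = component_info

    def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
        # component_info is assumed to be an accepted StreamingChunk field (as the fix above suggests)
        chunk = StreamingChunk(content=text, component_info=self.component_info)
        # transformers invokes this from a worker thread, so hand off thread-safely
        self.loop.call_soon_threadsafe(self.queue.put_nowait, chunk)
        if stream_end:
            self.loop.call_soon_threadsafe(self.queue.put_nowait, None)  # sentinel: stream finished


async def process_queue(
    queue: "asyncio.Queue[Optional[StreamingChunk]]",
    streaming_callback: Callable[[StreamingChunk], Awaitable[None]],
) -> None:
    """Drains the queue and awaits the async streaming callback for each chunk."""
    while (chunk := await queue.get()) is not None:
        await streaming_callback(chunk)
```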
I've added some comments for possible improvements. Let me know if they are clear enough!
            for r_index, reply in enumerate(replies)
        ]

        # Remove stop words from replies if present
        for stop_word in stop_words or []:
What about adding a more explicit check here? (Can apply also on line 427.)

if stop_words:
    for stop_word in stop_words:
        replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
generation_kwargs["streamer"] = HFTokenStreamingHandler( | ||
tokenizer, streaming_callback, stop_words, component_info | ||
) | ||
assert asyncio.iscoroutinefunction(streaming_callback), "Streaming callback must be asynchronous" |
Can we use the select_streaming_callback utility here? (We used it in other generators.) You can get it from:

from haystack.dataclasses.streaming_chunk import select_streaming_callback

so we can avoid the assert usage!
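A minimal sketch of how that could look, assuming select_streaming_callback keeps the (init_callback, runtime_callback, requires_async) signature used in the other generators; resolve_async_callback is an illustrative wrapper name, not code from the PR.

```python
from typing import Optional

from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback


def resolve_async_callback(
    init_callback: Optional[StreamingCallbackT], runtime_callback: Optional[StreamingCallbackT]
) -> Optional[StreamingCallbackT]:
    # requires_async=True makes the utility raise a descriptive error when a
    # sync callback is handed to run_async, which replaces the bare assert.
    return select_streaming_callback(init_callback, runtime_callback, requires_async=True)
```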
        # Clean up the queue processor
        queue_processor.cancel()
        with suppress(asyncio.CancelledError):
            await queue_processor
This cleanup logic can be a bit more robust: we can add a short timeout to ensure the queue is drained:

finally:
    try:
        await asyncio.wait_for(queue_processor, timeout=0.1)
    except asyncio.TimeoutError:
        queue_processor.cancel()
        with suppress(asyncio.CancelledError):
            await queue_processor
WDYT?
    @pytest.mark.slow
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    @pytest.mark.asyncio
    async def test_live_run_async_with_streaming(self, monkeypatch):
I think that this test is a bit over-engineered. What about something simpler like the following? No need to use e.g. asyncio.Event to check when the streaming is done.
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.asyncio
async def test_live_run_async_with_streaming(self, monkeypatch):
    monkeypatch.delenv("HF_API_TOKEN", raising=False)
    streaming_chunks = []

    async def streaming_callback(chunk: StreamingChunk) -> None:
        streaming_chunks.append(chunk)

    llm = HuggingFaceLocalChatGenerator(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        generation_kwargs={"max_new_tokens": 50},
        streaming_callback=streaming_callback,
    )
    llm.warm_up()

    response = await llm.run_async(
        messages=[ChatMessage.from_user("Please create a summary about the following topic: Capital of France")]
    )

    # Verify that the response is not None
    assert len(streaming_chunks) > 0
    assert "replies" in response
    assert isinstance(response["replies"][0], ChatMessage)
    assert response["replies"][0].text is not None

    # Verify that the response contains the word "Paris"
    assert "Paris" in response["replies"][0].text

    # Verify streaming chunks contain actual content
    total_streamed_content = "".join(chunk.content for chunk in streaming_chunks)
    assert len(total_streamed_content.strip()) > 0
    assert "Paris" in total_streamed_content
WDYT?
LGTM - we can address other minor nits later!
Related Issues
- HuggingFaceLocalChatGenerator: add an async version of HFTokenStreamHandler and update type signature for async streaming callback #9391

Proposed Changes:
How did you test it?
Notes for the reviewer
Checklist
- The PR title uses one of the conventional commit prefixes fix:, feat:, build:, chore:, ci:, docs:, style:, refactor:, perf:, test:, with ! added in case the PR includes breaking changes.