Avoid unexpected error when stream chat doesn't yield (#13422)
Fix nonyielding stream chat bug

Co-authored-by: Logan Markewich <logan.markewich@live.com>
joelrorseth and logan-markewich committed May 14, 2024
1 parent 662e0f6 commit 6011233
Showing 4 changed files with 41 additions and 8 deletions.
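
For context: the chat-callback stream wrappers in llama-index-core/llama_index/core/llms/callbacks.py dispatched the final LLMChatEndEvent with the streaming loop variable (response=x), so a stream_chat generator that never yielded left that variable unbound and the wrapper failed with an unexpected error (an UnboundLocalError in plain Python) instead of ending cleanly. The changes below track the last streamed response separately and make LLMChatEndEvent.response optional. The sketch that follows is illustrative only, not the actual llama_index wrapper; the function names and the last_response = None initialization are assumptions based on the diff.

from typing import Iterator, Optional


def buggy_wrapper(stream: Iterator[str]) -> Iterator[str]:
    # Pre-fix pattern: the loop variable is the only record of the last chunk.
    for x in stream:
        yield x
    # If `stream` never yields, `x` was never assigned, so this line
    # raises UnboundLocalError when the generator is drained.
    print("final response:", x)


def fixed_wrapper(stream: Iterator[str]) -> Iterator[str]:
    # Post-fix pattern: track the last chunk explicitly; a non-yielding
    # stream ends with None, mirroring Optional[ChatResponse] on the event.
    last_response: Optional[str] = None
    for x in stream:
        yield x
        last_response = x
    print("final response:", last_response)


list(fixed_wrapper(iter([])))    # prints "final response: None"
# list(buggy_wrapper(iter([])))  # raises UnboundLocalError
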
@@ -1,5 +1,4 @@
from typing import Any, List, Optional

from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.base.llms.types import (
    ChatMessage,
@@ -138,7 +137,7 @@ class LLMChatInProgressEvent(BaseEvent):
    Args:
        messages (List[ChatMessage]): List of chat messages.
-        response (ChatResponse): Chat response currently beiung streamed.
+        response (ChatResponse): Chat response currently being streamed.
    """

    messages: List[ChatMessage]
@@ -155,11 +154,11 @@ class LLMChatEndEvent(BaseEvent):
    Args:
        messages (List[ChatMessage]): List of chat messages.
-        response (ChatResponse): Chat response.
+        response (Optional[ChatResponse]): Last chat response.
    """

    messages: List[ChatMessage]
-    response: ChatResponse
+    response: Optional[ChatResponse]

    @classmethod
    def class_name(cls):
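
Since LLMChatEndEvent.response is now Optional[ChatResponse], event handlers that read the final response should tolerate None. A minimal hedged sketch (the helper below is illustrative and not part of this commit; it assumes the standard ChatResponse.message.content fields):

from typing import Optional

from llama_index.core.base.llms.types import ChatResponse


def final_response_text(response: Optional[ChatResponse]) -> str:
    # A non-yielding stream now ends the event with response=None,
    # so only read message content when a response actually arrived.
    if response is None:
        return ""
    return response.message.content or ""
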
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/llms/callbacks.py
@@ -97,7 +97,7 @@ async def wrapped_gen() -> ChatResponseAsyncGen:
            dispatcher.event(
                LLMChatEndEvent(
                    messages=messages,
-                    response=x,
+                    response=last_response,
                    span_id=span_id,
                )
            )
@@ -173,7 +173,7 @@ def wrapped_gen() -> ChatResponseGen:
            dispatcher.event(
                LLMChatEndEvent(
                    messages=messages,
-                    response=x,
+                    response=last_response,
                    span_id=span_id,
                )
            )
12 changes: 10 additions & 2 deletions llama-index-core/llama_index/core/llms/mock.py
@@ -1,13 +1,13 @@
from typing import Any, Callable, Optional, Sequence

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.callbacks import CallbackManager
-from llama_index.core.llms.callbacks import llm_completion_callback
+from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
from llama_index.core.llms.custom import CustomLLM
from llama_index.core.types import PydanticProgramMode

@@ -76,3 +76,11 @@ def gen_response(max_tokens: int) -> CompletionResponseGen:
                )

        return gen_response(self.max_tokens) if self.max_tokens else gen_prompt()


class MockLLMWithNonyieldingChatStream(MockLLM):
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        yield from []
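
A quick usage sketch of the new mock (illustrative, not part of the commit): draining its stream yields nothing, and with the wrapper fix above the call finishes cleanly, with the end event carrying response=None rather than raising.

from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.mock import MockLLMWithNonyieldingChatStream

llm = MockLLMWithNonyieldingChatStream()
# stream_chat returns a generator that never yields a chunk.
chunks = list(llm.stream_chat([ChatMessage(role="user", content="test prompt")]))
assert chunks == []  # no chunks, and no unexpected error at stream end
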
26 changes: 26 additions & 0 deletions llama-index-core/tests/llms/test_callbacks.py
@@ -1,6 +1,13 @@
import pytest
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.llm import LLM
from llama_index.core.llms.mock import MockLLM
+from llama_index.core.llms.mock import MockLLMWithNonyieldingChatStream


@pytest.fixture()
def nonyielding_llm() -> LLM:
    return MockLLMWithNonyieldingChatStream()


@pytest.fixture()
@@ -13,6 +20,25 @@ def prompt() -> str:
    return "test prompt"


def test_llm_stream_chat_handles_nonyielding_stream(
    nonyielding_llm: LLM, prompt: str
) -> None:
    response = nonyielding_llm.stream_chat([ChatMessage(role="user", content=prompt)])
    for _ in response:
        pass


@pytest.mark.asyncio()
async def test_llm_astream_chat_handles_nonyielding_stream(
    nonyielding_llm: LLM, prompt: str
) -> None:
    response = await nonyielding_llm.astream_chat(
        [ChatMessage(role="user", content=prompt)]
    )
    async for _ in response:
        pass


def test_llm_complete_prompt_arg(llm: LLM, prompt: str) -> None:
    res = llm.complete(prompt)
    expected_res_text = prompt
