diff --git a/shiny/ui/_chat_normalize.py b/shiny/ui/_chat_normalize.py index 85da7ed9c..2c8612878 100644 --- a/shiny/ui/_chat_normalize.py +++ b/shiny/ui/_chat_normalize.py @@ -57,11 +57,15 @@ def can_normalize_chunk(self, chunk: Any) -> bool: class DictNormalizer(BaseMessageNormalizer): def normalize(self, message: Any) -> ChatMessage: - x = self._check_dict(message) + x = cast("dict[str, Any]", message) + if "content" not in x: + raise ValueError("Message must have 'content' key") return ChatMessage(content=x["content"], role=x.get("role", "assistant")) def normalize_chunk(self, chunk: Any) -> ChatMessage: - x = self._check_dict(chunk) + x = cast("dict[str, Any]", chunk) + if "content" not in x: + raise ValueError("Message must have 'content' key") return ChatMessage(content=x["content"], role=x.get("role", "assistant")) def can_normalize(self, message: Any) -> bool: @@ -70,14 +74,6 @@ def can_normalize(self, message: Any) -> bool: def can_normalize_chunk(self, chunk: Any) -> bool: return isinstance(chunk, dict) - @staticmethod - def _check_dict(x: Any) -> "dict[str, Any]": - if "content" not in x: - raise ValueError("Message must have 'content' key") - if "role" in x and x["role"] not in ["assistant", "system"]: - raise ValueError("Role must be 'assistant' or 'system") - return x - class LangChainNormalizer(BaseMessageNormalizer): def normalize(self, message: Any) -> ChatMessage: diff --git a/tests/playwright/shiny/components/chat/append_user_msg/app.py b/tests/playwright/shiny/components/chat/append_user_msg/app.py new file mode 100644 index 000000000..835ce7204 --- /dev/null +++ b/tests/playwright/shiny/components/chat/append_user_msg/app.py @@ -0,0 +1,18 @@ +from shiny import reactive +from shiny.express import render, ui + +chat = ui.Chat(id="chat") +chat.ui() + + +@reactive.effect +async def _(): + await chat.append_message({"content": "A user message", "role": "user"}) + + +"chat.messages():" + + +@render.code +def message_state(): + return str(chat.messages()) diff --git a/tests/playwright/shiny/components/chat/append_user_msg/test_chat_append_user_msg.py b/tests/playwright/shiny/components/chat/append_user_msg/test_chat_append_user_msg.py new file mode 100644 index 000000000..1d4e614e6 --- /dev/null +++ b/tests/playwright/shiny/components/chat/append_user_msg/test_chat_append_user_msg.py @@ -0,0 +1,19 @@ +from playwright.sync_api import Page, expect + +from shiny.playwright import controller +from shiny.run import ShinyAppProc + + +def test_validate_chat_append_user_message(page: Page, local_app: ShinyAppProc) -> None: + page.goto(local_app.url) + + chat = controller.Chat(page, "chat") + + # Verify starting state + expect(chat.loc).to_be_visible() + chat.expect_latest_message("A user message") + + # Verify that the message state is as expected + message_state = controller.OutputCode(page, "message_state") + message_state_expected = ({"content": "A user message", "role": "user"},) + message_state.expect_value(str(message_state_expected))