Skip to content

Commit

Permalink
Follow up to #1453: allow user roles when normalizing a dictionary (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Jul 3, 2024
1 parent 36a5308 commit 46f7ee4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
16 changes: 6 additions & 10 deletions shiny/ui/_chat_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ def can_normalize_chunk(self, chunk: Any) -> bool:

class DictNormalizer(BaseMessageNormalizer):
def normalize(self, message: Any) -> ChatMessage:
x = self._check_dict(message)
x = cast("dict[str, Any]", message)
if "content" not in x:
raise ValueError("Message must have 'content' key")
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

def normalize_chunk(self, chunk: Any) -> ChatMessage:
x = self._check_dict(chunk)
x = cast("dict[str, Any]", chunk)
if "content" not in x:
raise ValueError("Message must have 'content' key")
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

def can_normalize(self, message: Any) -> bool:
Expand All @@ -70,14 +74,6 @@ def can_normalize(self, message: Any) -> bool:
def can_normalize_chunk(self, chunk: Any) -> bool:
return isinstance(chunk, dict)

@staticmethod
def _check_dict(x: Any) -> "dict[str, Any]":
if "content" not in x:
raise ValueError("Message must have 'content' key")
if "role" in x and x["role"] not in ["assistant", "system"]:
raise ValueError("Role must be 'assistant' or 'system")
return x


class LangChainNormalizer(BaseMessageNormalizer):
def normalize(self, message: Any) -> ChatMessage:
Expand Down
18 changes: 18 additions & 0 deletions tests/playwright/shiny/components/chat/append_user_msg/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from shiny import reactive
from shiny.express import render, ui

chat = ui.Chat(id="chat")
chat.ui()


@reactive.effect
async def _():
await chat.append_message({"content": "A user message", "role": "user"})


"chat.messages():"


@render.code
def message_state():
return str(chat.messages())
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from playwright.sync_api import Page, expect

from shiny.playwright import controller
from shiny.run import ShinyAppProc


def test_validate_chat_append_user_message(page: Page, local_app: ShinyAppProc) -> None:
page.goto(local_app.url)

chat = controller.Chat(page, "chat")

# Verify starting state
expect(chat.loc).to_be_visible()
chat.expect_latest_message("A user message")

# Verify that the message state is as expected
message_state = controller.OutputCode(page, "message_state")
message_state_expected = ({"content": "A user message", "role": "user"},)
message_state.expect_value(str(message_state_expected))

0 comments on commit 46f7ee4

Please sign in to comment.