In [1]:
# | default_exp fasthtml.chat

In [2]:
# | export
from torch_snippets import AD
from fasthtml.common import *

In [3]:
from torch_snippets import randint
from fasthtml.jupyter import *
from fasthtml.common import *

# Make FT components display as HTML in notebook output (see the rendered link below).
render_ft()

# FastHTML app; `exts="ws"` presumably enables htmx's websocket extension — nothing in
# this notebook appears to use it (ChatInterface polls over plain HTTP).
app = FastHTML(exts="ws")
# `ChatInterface.register_route` falls back to a global named `rt` when no router is
# passed, so keep this name.
rt = app.route

# Port for the dev server started below.
prt = 10979
# Start uvicorn for `app` without blocking the cell. Per the startup log it listens on
# 0.0.0.0, i.e. on all interfaces. NOTE(review): confirm that exposure is intended.
server = JupyUvi(app, port=prt, log_level="info")
# Last expression: clickable link to the running app.
A("visit site", href=f"http://localhost:{prt}")

INFO:     Started server process [5085]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:10979 (Press CTRL+C to quit)
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:10979 (Press CTRL+C to quit)


<div>
<a href="http://localhost:10979">visit site</a><script>if (window.htmx) htmx.process(document.body)</script></div>


INFO:     127.0.0.1:51328 - "GET / HTTP/1.1" 200 OK
INFO:     127.0.0.1:51328 - "POST /chat/add HTTP/1.1" 200 OK
INFO:     127.0.0.1:51328 - "POST /chat/add HTTP/1.1" 200 OK
INFO:     127.0.0.1:51328 - "GET /chat/result?cid=chat-18d63b1d&ph_id=ph-6ed95d7c HTTP/1.1" 200 OK
INFO:     127.0.0.1:51328 - "GET /chat/result?cid=chat-18d63b1d&ph_id=ph-6ed95d7c HTTP/1.1" 200 OK


In [4]:
# | export
from typing import List, Dict, Union, Optional, Callable, Awaitable, Any

import json
import uuid
from pathlib import Path
import inspect
import asyncio

In [5]:
# | export

# Inline CSS shared by every chat card; `{bg}` and `{align}` are filled in per role.
style_template = """
background: {bg};
text-align: {align};
padding: 10px;
"""

# Per-role presentation used by `Card`:
#   styles.css[role]    -> inline style string
#   styles.emojis[role] -> icon shown before the role label
styles = AD(
    css={
        role: style_template.format(bg=bg, align=align)
        for role, bg, align in (
            ("user", "#EEE6CA", "right"),
            ("assistant", "#E5BEB5", "left"),
        )
    },
    emojis=dict(user="🗣️", assistant="🐕‍🦺"),
)


def Card(data: dict):
    """Render one chat message as a styled `Div`.

    `data` must have a "role" key that exists in `styles` ("user" or "assistant";
    anything else raises KeyError). It may also carry "content" (default "") and
    "pending" (default False). A pending message shows a `Span(cls="spinner")` in
    place of its content.

    NOTE(review): no `.spinner` CSS is defined in this notebook; presumably it is
    supplied elsewhere. This name may also shadow a `Card` brought in by
    `from fasthtml.common import *` — confirm that is intended.
    """
    who = data["role"]
    look = styles.css[who]
    icon = styles.emojis[who]
    label = f"{icon} {who}: "

    if data.get("pending", False):
        return Div(label, Span(cls="spinner"), style=look)
    return Div(f"{label}{data.get('content', '')}", style=look)


def Cards(data: List[dict], **kwargs):
    """Render message dicts as one `Div` holding a `Card` per message.

    Extra keyword arguments (e.g. `id`, `cls`) are forwarded to the outer `Div`.
    """
    return Div(*map(Card, data), **kwargs)

In [None]:
# | export
# Registry to map component instances by id
_CHAT_REGISTRY: Dict[str, "ChatInterface"] = {}


class ChatInterface:
    """Simple chat viewer+persistence using FastHTML Cards with optional input to add messages.

    - Accepts messages as list[str] alternating user/assistant, or list[dict] with
      keys {"role","content"}.
    - Renders with `Cards`.
    - When a router (`rt`) is available, the component auto-registers POST endpoints
      to add a message and manage async responses.
    - Form is horizontally centered; role select can be optionally hidden.
    - Optionally accepts a responder callable: (str) -> str | Awaitable[str]. When provided and a user
      message is added via the route, an assistant placeholder appears immediately and will be
      replaced with the computed reply.
    """

    _route_registered: bool = False
    _route_path: str = "/chat/add"
    _result_path: str = "/chat/result"  # polls to replace a pending placeholder

    def __init__(
        self,
        messages: Union[List[str], List[Dict[str, str]], None] = None,
        _id: Optional[str] = None,
        register: bool = True,
        router=None,  # optional explicit router; falls back to global `rt`
        hide_role_select: bool = False,
        responder: Optional[
            Callable[[str], Union[str, Awaitable[str]]]
        ] = None,  # user-provided responder
    ):
        self.messages: List[Dict[str, str]] = self._normalize(messages or [])
        self._id = _id or f"chat-{uuid.uuid4().hex[:8]}"
        self._list_id = f"chatlist-{self._id}"
        self._next_role = "user"  # default for hidden role when select is hidden
        self.hide_role_select = hide_role_select
        self.responder = responder
        # Track pending replies: placeholder_id -> asyncio.Task
        self._pending: Dict[str, asyncio.Task] = {}
        _CHAT_REGISTRY[self._id] = self
        if register:
            self.register_route(router)

    # --- routing ---
    @classmethod
    def register_route(cls, router=None):
        if cls._route_registered:
            return
        if router is None:
            router = globals().get("rt", None)
        if router is None:
            return

        async def _chat_add(role: str, content: str, cid: str):
            ci = _CHAT_REGISTRY.get(cid)
            if not ci:
                return P("Invalid chat id", cls="error")
            # Add the posted message first
            first_msg = ci.add_message(role, content)
            if not first_msg:
                return Span()
            first_card = Card(first_msg)

            # If user message and we have a responder, append a placeholder and start a task
            if role == "user" and ci.responder:
                ph_id = f"ph-{uuid.uuid4().hex[:8]}"
                # Insert a pending assistant message into state
                ci.messages.append({
                    "role": "assistant",
                    "content": "",
                    "pending": True,
                    "id": ph_id,
                })
                # A self-replacing placeholder that polls for result and swaps itself when ready
                placeholder = Div(
                    Card({"role": "assistant", "pending": True}),
                    id=ph_id,
                    hx_get=cls._result_path,
                    hx_vals={"cid": cid, "ph_id": ph_id},
                    hx_trigger="load, every 1s",
                    hx_swap="outerHTML",
                )

                async def run_and_store():
                    try:
                        res = ci.responder(content)
                        reply_text = await res if inspect.isawaitable(res) else res
                    except Exception as e:
                        reply_text = f"Responder error: {e}"
                    # Replace the placeholder in state
                    for m in ci.messages:
                        if m.get("id") == ph_id:
                            m.pop("pending", None)
                            m["content"] = str(reply_text)
                            break
                    # Mark done
                    ci._pending.pop(ph_id, None)

                ci._pending[ph_id] = asyncio.create_task(run_and_store())
                # Return both user card and placeholder wrapped as siblings
                return Div(first_card, placeholder, style="display: contents")

            # Otherwise just return the single card
            return first_card

        async def _chat_result(cid: str, ph_id: str):
            ci = _CHAT_REGISTRY.get(cid)
            if not ci:
                return P("Invalid chat id", cls="error")
            task = ci._pending.get(ph_id)
            if task and not task.done():
                # Not ready yet; return nothing to keep polling
                return Span()
            # Task done or missing; render the resolved assistant card content
            msg = next((m for m in ci.messages if m.get("id") == ph_id), None)
            if not msg:
                return Span()
            # Return replacement content for the placeholder div
            return Card({"role": "assistant", "content": msg.get("content", "")})

        router(cls._route_path)(_chat_add)
        router(cls._result_path)(_chat_result)
        cls._route_registered = True

    # --- data helpers ---
    @staticmethod
    def _normalize(
        messages: Union[List[str], List[Dict[str, str]]],
    ) -> List[Dict[str, str]]:
        if not messages:
            return []
        first = messages[0]
        if isinstance(first, str):
            roles = ("user", "assistant")
            return [
                {"role": roles[i % 2], "content": m} for i, m in enumerate(messages)
            ]
        # assume list of dicts; validate minimally
        out: List[Dict[str, str]] = []
        for m in messages:
            if not isinstance(m, dict) or "role" not in m or "content" not in m:
                raise ValueError(
                    "Each message must be a str or dict with 'role' and 'content'"
                )
            r = m["role"]
            if r not in ("user", "assistant"):
                raise ValueError("role must be 'user' or 'assistant'")
            out.append({"role": r, "content": str(m["content"])})
        return out

    def add_message(self, role: str, content: str) -> Optional[Dict[str, str]]:
        content = (content or "").strip()
        if not content and role != "assistant":
            return None
        if role not in ("user", "assistant"):
            role = "user"
        msg: Dict[str, Any] = {"role": role, "content": content}
        self.messages.append(msg)
        return msg

    # --- rendering ---
    def __ft__(self):
        # Try to register at render time too, in case `rt` became available later
        self.register_route()

        # Chat list container
        chat_list = Div(
            *[Card(d) for d in self.messages],
            id=self._list_id,
            cls="chat-cards",
        )

        # Only show form if the routes have been registered
        if ChatInterface._route_registered:
            # Role widget: either a Select (with default) or a hidden input
            if self.hide_role_select:
                role_widget = Input(type="hidden", name="role", value=self._next_role)
            else:
                role_widget = Select(
                    Option("user", value="user", selected=True),
                    Option("assistant", value="assistant"),
                    name="role",
                )

            # Reset after request; indicator while waiting
            indicator = Div(
                "…",
                id=f"ind-{self._id}",
                cls="htmx-indicator",
                style="margin-left: 0.5rem;",
            )
            hx_on_val = "htmx:afterRequest: this.reset()"

            form = Form(
                Input(type="hidden", name="cid", value=self._id),
                role_widget,
                Input(name="content", placeholder="Type a message...", required=True),
                Button("Send", type="submit"),
                indicator,
                hx_post=ChatInterface._route_path,
                hx_target=f"#{self._list_id}",
                hx_swap="beforeend",
                hx_on=hx_on_val,
                hx_indicator=f"#ind-{self._id}",
                cls="chat-input",
            )
            # Center the form horizontally
            form_row = Div(
                form,
                style="display:flex; justify-content:center; align-items:center; gap: 0.25rem; margin-top: 0.5rem;",
            )
            return Div(chat_list, form_row, id=self._id)
        else:
            return Div(chat_list, id=self._id)

    # --- persistence ---
    def dump(self, fpath: Union[str, Path]) -> Path:
        p = Path(fpath).expanduser()
        p.write_text(json.dumps(self.messages, ensure_ascii=False, indent=2))
        return p

    @classmethod
    def load(cls, fpath: Union[str, Path]) -> "ChatInterface":
        p = Path(fpath).expanduser()
        msgs = json.loads(p.read_text())
        return cls(msgs)

In [7]:
def echo_responder(text: str) -> str:
    """Demo responder: reply with the user's text prefixed by "echo: "."""
    return "echo: {}".format(text)


# Seed history; plain strings alternate user/assistant, starting with "user".
messages = ["hi", "how are you", "I'm fine", "thanks"]

c = ChatInterface(
    messages,
    responder=echo_responder,
    hide_role_select=True,
)
# Last expression: display the interactive component (chat list + input form).
# NOTE(review): the form posts to a relative `/chat/add`, i.e. to whatever origin serves
# this output — confirm that is the JupyUvi server if you expect it to work inline.
c

In [8]:
def demo_responder(t: str) -> str:
    """Demo responder that takes ~3 s synchronously before replying."""
    from time import sleep

    sleep(3)  # stand-in for a long-running model call
    return "user said: {}".format(t)


@rt("/")
def get():
    """Serve a new ChatInterface seeded with a short scripted conversation.

    Presumably FastHTML infers the HTTP method from the handler name, so keep it `get`.
    NOTE(review): each request builds a fresh chat, which is added to `_CHAT_REGISTRY`.
    """
    history = [
        "hi",
        "how are you",
        "i am fine, how are you",
        "i am doing well too!",
    ]
    return ChatInterface(history, responder=demo_responder)

In [9]:
# (Optional) You can still define other routes here if needed.
# The ChatInterface now handles assistant replies internally via the responder callable.