diff --git a/.gitignore b/.gitignore index 740f3993c..a06bad8c9 100644 --- a/.gitignore +++ b/.gitignore @@ -268,6 +268,9 @@ renv.lock # Planning documents (local only) docs/plans/ +# Screenshot capture script (local only) +pkg-py/docs/_screenshots/ + # Playwright MCP .playwright-mcp/ diff --git a/CLAUDE.md b/CLAUDE.md index 4170fa304..98873f00e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -69,13 +69,15 @@ make py-build make py-docs ``` -Before finishing your implementation or committing any code, you should run: +Before committing any Python code, you must run all three checks and confirm they pass: ```bash uv run ruff check --fix pkg-py --config pyproject.toml +make py-check-types +make py-check-tests ``` -To get help with making sure code adheres to project standards. +Do not commit or push until all three pass. ### R Package diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index a02fe5f68..ff0e3f084 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -9,16 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* Added a `"visualize"` tool that lets the LLM create inline Altair charts from natural language requests using [ggsql](https://github.com/posit-dev/ggsql) — a SQL extension for declarative data visualization. Include it via `tools=("query", "visualize")` (or alongside `"update"`). Charts render inline in the chat with fullscreen support, a "Show Query" toggle, and Save as PNG/SVG. Install the optional dependencies with `pip install "querychat[viz]"`. (#219) + +* The `querychat_query` tool now accepts an optional `collapsed` parameter. When `collapsed=True`, the result card starts collapsed so preparatory or exploratory queries don't clutter the conversation. The LLM is guided to use this automatically when running queries before a visualization. + +* Added support for Snowflake Semantic Views.
When connected to Snowflake (via SQLAlchemy or Ibis), querychat automatically discovers available Semantic Views and includes their definitions in the system prompt. This helps the LLM generate correct queries using the `SEMANTIC_VIEW()` table function with certified business metrics and dimensions. (#200) + * `QueryChat()` now supports deferred chat client initialization. Pass `client=` to `server()` to provide a session-scoped chat client, enabling use cases where API credentials are only available at session time (e.g., Posit Connect managed OAuth tokens). When no `client` is specified anywhere, querychat resolves a sensible default from the `QUERYCHAT_CLIENT` environment variable (or `"openai"`). (#205) ### Improvements * When a custom `prompt_template` is provided that doesn't contain Mustache references to `{{schema}}`, the expensive `get_schema()` call is now skipped entirely. This allows users with large databases to avoid slow startup by providing their own prompt that includes schema information inline (or omits it). (#208) -### New features - -* Added support for Snowflake Semantic Views. When connected to Snowflake (via SQLAlchemy or Ibis), querychat automatically discovers available Semantic Views and includes their definitions in the system prompt. This helps the LLM generate correct queries using the `SEMANTIC_VIEW()` table function with certified business metrics and dimensions. 
(#200) - ## [0.5.1] - 2026-01-23 ### New features diff --git a/pkg-py/docs/_quarto.yml b/pkg-py/docs/_quarto.yml index df2576e49..927859ba4 100644 --- a/pkg-py/docs/_quarto.yml +++ b/pkg-py/docs/_quarto.yml @@ -50,6 +50,7 @@ website: - models.qmd - data-sources.qmd - context.qmd + - visualize.qmd - section: "Build custom apps" contents: - build-intro.qmd @@ -114,6 +115,8 @@ quartodoc: signature_name: short - name: tools.tool_reset_dashboard signature_name: short + - name: tools.tool_visualize + signature_name: short filters: - "interlinks" diff --git a/pkg-py/docs/build.qmd b/pkg-py/docs/build.qmd index 009f6cfd0..460d36cc7 100644 --- a/pkg-py/docs/build.qmd +++ b/pkg-py/docs/build.qmd @@ -31,6 +31,14 @@ from querychat.data import titanic qc = QueryChat(titanic(), "titanic") ``` +::: {.callout-tip} +### Visualization support + +querychat supports an optional visualization tool that lets the LLM create inline charts. +Enable it by including `"visualize"` in the `tools` parameter. +See [Visualizations](visualize.qmd) for details. 
+::: + ::: {.callout-note collapse="true"} ## Quick start with `.app()` diff --git a/pkg-py/docs/images/viz-bar-chart.png b/pkg-py/docs/images/viz-bar-chart.png new file mode 100644 index 000000000..0a7033651 Binary files /dev/null and b/pkg-py/docs/images/viz-bar-chart.png differ diff --git a/pkg-py/docs/images/viz-fullscreen.png b/pkg-py/docs/images/viz-fullscreen.png new file mode 100644 index 000000000..fbefca3fb Binary files /dev/null and b/pkg-py/docs/images/viz-fullscreen.png differ diff --git a/pkg-py/docs/images/viz-scatter.png b/pkg-py/docs/images/viz-scatter.png new file mode 100644 index 000000000..db25bfe2a Binary files /dev/null and b/pkg-py/docs/images/viz-scatter.png differ diff --git a/pkg-py/docs/images/viz-show-query.png b/pkg-py/docs/images/viz-show-query.png new file mode 100644 index 000000000..fc9ae6384 Binary files /dev/null and b/pkg-py/docs/images/viz-show-query.png differ diff --git a/pkg-py/docs/index.qmd b/pkg-py/docs/index.qmd index 88b0da5c5..8d119ae3b 100644 --- a/pkg-py/docs/index.qmd +++ b/pkg-py/docs/index.qmd @@ -75,6 +75,11 @@ querychat can also handle more general questions about the data that require cal ![](/images/quickstart-summary.png){fig-alt="Screenshot of the querychat's app with a summary statistic inlined in the chat." class="lightbox shadow rounded mb-3"} +querychat can also create visualizations, powered by [ggsql](https://ggsql.org/) and [Altair](https://altair-viz.github.io/). +With the [visualization tool](visualize.qmd) enabled, ask for a chart and it appears inline in the conversation: + +![](/images/viz-bar-chart.png){fig-alt="Screenshot of querychat with an inline bar chart showing survival rate by passenger class." class="lightbox shadow rounded mb-3"} + ## Web frameworks While the examples above use [Shiny](https://shiny.posit.co/py/), querychat also supports [Streamlit](https://streamlit.io/), [Gradio](https://gradio.app/), and [Dash](https://dash.plotly.com/). 
Each framework has its own `QueryChat` class under the relevant sub-module, but the methods and properties are mostly consistent across all of them. diff --git a/pkg-py/docs/tools.qmd b/pkg-py/docs/tools.qmd index e438e1bde..44301f1d4 100644 --- a/pkg-py/docs/tools.qmd +++ b/pkg-py/docs/tools.qmd @@ -6,7 +6,7 @@ querychat combines [tool calling](https://posit-dev.github.io/chatlas/get-starte One important thing to understand generally about querychat's tools is they are Python functions, and that execution happens on _your machine_, not on the LLM provider's side. In other words, the SQL queries generated by the LLM are executed locally in the Python process running the app. -querychat provides the LLM access two tool groups: +querychat provides the LLM access to three tool groups: 1. **Data updating** - Filter and sort data (without sending results to the LLM). 2. **Data analysis** - Calculate summaries and return results for interpretation by the LLM. @@ -52,6 +52,40 @@ app = qc.app() ![](/images/quickstart-summary.png){fig-alt="Screenshot of the querychat's app with a summary statistic inlined in the chat." class="lightbox shadow rounded mb-3"} +## Data visualization + +When a user asks for a chart or visualization, the LLM generates a [ggsql](https://ggsql.org/) query — standard SQL extended with a `VISUALISE` clause — and requests a call to the `visualize` tool. +This tool: + +1. Executes the SQL portion of the query +2. Renders the `VISUALISE` clause as an Altair chart +3. Displays the chart inline in the chat + +Unlike the data updating tools, visualization queries don't affect the dashboard filter. +They query the full dataset independently, and each call produces a new inline chart message in the chat. + +The inline chart includes controls for fullscreen viewing, saving as PNG/SVG, and a "Show Query" toggle that reveals the underlying ggsql code. 
+ +To use the visualization tool, first install the `viz` extras: + +```bash +pip install "querychat[viz]" +``` + +Then include `"visualize"` in the `tools` parameter (it is not enabled by default): + +```{.python filename="viz-app.py"} +from querychat import QueryChat +from querychat.data import titanic + +qc = QueryChat(titanic(), "titanic", tools=("query", "update", "visualize")) +app = qc.app() +``` + +![](/images/viz-scatter.png){fig-alt="Screenshot of querychat with an inline scatter plot." class="lightbox shadow rounded mb-3"} + +See [Visualizations](visualize.qmd) for more details. + ## View the source If you'd like to better understand how the tools work and how the LLM is prompted to use them, check out the following resources: @@ -65,3 +99,4 @@ If you'd like to better understand how the tools work and how the LLM is prompte - [`prompts/tool-update-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-update-dashboard.md) - [`prompts/tool-reset-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-reset-dashboard.md) - [`prompts/tool-query.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-query.md) +- [`prompts/tool-visualize.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-visualize.md) diff --git a/pkg-py/docs/visualize.qmd b/pkg-py/docs/visualize.qmd new file mode 100644 index 000000000..e8d720b68 --- /dev/null +++ b/pkg-py/docs/visualize.qmd @@ -0,0 +1,104 @@ +--- +title: Visualizations +lightbox: true +--- + +querychat can create charts inline in the chat. +When you ask a question that benefits from a visualization, the LLM writes a query using [ggsql](https://ggsql.org/) — a SQL-like visualization grammar — and renders an [Altair](https://altair-viz.github.io/) chart directly in the conversation. + +## Getting started + +Visualization requires two steps: + +1. 
**Install the `viz` extras:** + + ```bash + pip install "querychat[viz]" + ``` + +2. **Include `"visualize"` in the `tools` parameter:** + + ```{.python filename="app.py"} + from querychat import QueryChat + from querychat.data import titanic + + qc = QueryChat(titanic(), "titanic", tools=("query", "update", "visualize")) + app = qc.app() + ``` + +Ask something like "Show me survival rate by passenger class as a bar chart" and querychat will generate and display the chart inline: + +![](/images/viz-bar-chart.png){fig-alt="Bar chart showing survival rate by passenger class." class="lightbox shadow rounded mb-3"} + +## Choosing tools + +The `tools` parameter controls which capabilities the LLM has access to. +By default, only `"query"` and `"update"` are enabled — visualization must be opted into explicitly. + +To enable only query and visualization (no dashboard filtering): + +```{.python} +qc = QueryChat(titanic(), "titanic", tools=("query", "visualize")) +``` + +See [Tools](tools.qmd) for a full reference on available tools and what each one does. + +## Custom apps + +The example below shows a minimal custom Shiny app using only the `"query"` and `"visualize"` tools. +It omits `"update"` to focus entirely on data analysis and visualization rather than dashboard filtering: + +```{.python filename="app.py"} +{{< include /../examples/10-viz-app.py >}} +``` + +## What you can ask for + +querychat can generate a wide range of chart types. +Some example prompts: + +- "Show me a bar chart of survival rate by passenger class" +- "Scatter plot of age vs fare, colored by survival" +- "Line chart of average fare over time" +- "Histogram of passenger ages" +- "Facet survival rate by class and sex" + +The LLM chooses an appropriate chart type based on your question, but you can always be specific. +If you ask for a bar chart, you'll get a bar chart. + +![](/images/viz-scatter.png){fig-alt="Scatter plot of age vs fare colored by survival status." 
class="lightbox shadow rounded mb-3"} + +::: {.callout-tip} +If you don't like the chart, ask the LLM to adjust it — for example, "make the dots bigger" or "use a log scale on the y-axis". +::: + +## Chart controls + +Each chart has controls in its footer: + +**Fullscreen** — Click the expand icon to view the chart in fullscreen mode. + +![](/images/viz-fullscreen.png){fig-alt="A chart displayed in fullscreen mode." class="lightbox shadow rounded mb-3"} + +**Save** — Download the chart as a PNG or SVG file. + +**Show Query** — Expand the footer to see the ggsql query used to generate the chart. + +![](/images/viz-show-query.png){fig-alt="A chart with the Show Query footer expanded, showing the ggsql query." class="lightbox shadow rounded mb-3"} + +## How it works + +1. **The LLM generates a ggsql query** — a SQL-like grammar that describes both data transformation and visual encoding in a single statement. +2. **The SQL is executed** — querychat runs the data portion of the query against your data source locally. +3. **The VISUALISE clause is rendered** — the result is passed to Altair, which produces a Vega-Lite chart specification. +4. **The chart appears inline** — the chart is streamed back into the conversation as an interactive widget. + +Note that visualization queries are independent of any active dashboard filter set by the `update` tool. +They always run against the full dataset. + +Learn more about the ggsql grammar at [ggsql.org](https://ggsql.org/). 
+ +## See also + +- [Tools](tools.qmd) — Understand what querychat can do under the hood +- [Provide context](context.qmd) — Help the LLM understand your data better diff --git a/pkg-py/examples/10-viz-app.py b/pkg-py/examples/10-viz-app.py new file mode 100644 index 000000000..fe9ef6dc8 --- /dev/null +++ b/pkg-py/examples/10-viz-app.py @@ -0,0 +1,17 @@ +from querychat.express import QueryChat +from querychat.data import titanic + +from shiny.express import ui, app_opts + +# Omits "update" tool — this demo focuses on query + visualization only +qc = QueryChat( + titanic(), + "titanic", + tools=("query", "visualize") +) + +qc.ui() + +ui.page_opts(fillable=True, title="QueryChat Visualization Demo") + +app_opts(bookmark_store="url") diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py index 2b7683da0..fc484c9c0 100644 --- a/pkg-py/src/querychat/_icons.py +++ b/pkg-py/src/querychat/_icons.py @@ -2,19 +2,35 @@ from shiny import ui -ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"] +ICON_NAMES = Literal[ + "arrow-counterclockwise", + "bar-chart-fill", + "chevron-down", + "download", + "funnel-fill", + "graph-up", + "terminal-fill", + "table", +] -def bs_icon(name: ICON_NAMES) -> ui.HTML: +def bs_icon(name: ICON_NAMES, cls: str = "") -> ui.HTML: """Get Bootstrap icon SVG by name.""" if name not in BS_ICONS: raise ValueError(f"Unknown Bootstrap icon: {name}") - return ui.HTML(BS_ICONS[name]) + svg = BS_ICONS[name] + if cls: + svg = svg.replace('class="', f'class="{cls} ', 1) + return ui.HTML(svg) BS_ICONS = { "arrow-counterclockwise": '', + "bar-chart-fill": '', + "chevron-down": '', + "download": '', "funnel-fill": '', + "graph-up": '', "terminal-fill": '', "table": '', } diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index d3bf29e26..ffa325aed 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -23,11 +23,13 @@ 
from ._shiny_module import GREETING_PROMPT from ._system_prompt import QueryChatSystemPrompt from ._utils import MISSING, MISSING_TYPE, is_ibis_table +from ._viz_utils import has_viz_deps, has_viz_tool from .tools import ( UpdateDashboardData, tool_query, tool_reset_dashboard, tool_update_dashboard, + tool_visualize, ) if TYPE_CHECKING: @@ -35,8 +37,10 @@ from narwhals.stable.v1.typing import IntoFrame -TOOL_GROUPS = Literal["update", "query"] + from ._viz_tools import VisualizeData +TOOL_GROUPS = Literal["update", "query", "visualize"] +DEFAULT_TOOLS: tuple[TOOL_GROUPS, ...] = ("update", "query") class QueryChatBase(Generic[IntoFrameT]): """ @@ -58,7 +62,7 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -72,7 +76,7 @@ def __init__( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) - self.tools = normalize_tools(tools, default=("update", "query")) + self.tools = normalize_tools(tools, default=DEFAULT_TOOLS) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting # Store init parameters for deferred system prompt building @@ -128,6 +132,7 @@ def _create_session_client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize: Callable[[VisualizeData], None] | None = None, ) -> chatlas.Chat: """Create a fresh, fully-configured Chat.""" spec = self._client_spec if isinstance(client_spec, MISSING_TYPE) else client_spec @@ -152,6 +157,10 @@ def _create_session_client( if "query" in resolved_tools: chat.register_tool(tool_query(data_source)) + if "visualize" in resolved_tools: + viz_fn = visualize or (lambda _: None) + chat.register_tool(tool_visualize(data_source, viz_fn)) + return chat def client( @@ -160,6 +169,7 @@ def client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize: Callable[[VisualizeData], None] | None = None, ) -> chatlas.Chat: """ Create a chat client with registered tools. @@ -167,11 +177,14 @@ def client( Parameters ---------- tools - Which tools to include: `"update"`, `"query"`, or both. + Which tools to include: `"update"`, `"query"`, `"visualize"`, + or a combination. update_dashboard Callback when update_dashboard tool succeeds. reset_dashboard Callback when reset_dashboard tool is invoked. + visualize + Callback when visualize tool succeeds. Returns ------- @@ -184,6 +197,7 @@ def client( tools=tools, update_dashboard=update_dashboard, reset_dashboard=reset_dashboard, + visualize=visualize, ) def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: @@ -293,14 +307,24 @@ def create_client(client: str | chatlas.Chat | None) -> chatlas.Chat: def normalize_tools( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE, default: tuple[TOOL_GROUPS, ...] | None, + *, + check_deps: bool = True, ) -> tuple[TOOL_GROUPS, ...] 
| None: if tools is None or tools == (): - return None + result = None elif isinstance(tools, MISSING_TYPE): - return default + result = default elif isinstance(tools, str): - return (tools,) + result = (tools,) elif isinstance(tools, tuple): - return tools + result = tools else: - return tuple(tools) + result = tuple(tools) + if not check_deps: + return result + if has_viz_tool(result) and not has_viz_deps(): + raise ImportError( + "Visualization tools require ggsql, altair, shinywidgets, and " + "vl-convert-python. Install them with: pip install querychat[viz]" + ) + return result diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index af0685e01..1dd132631 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -165,6 +165,8 @@ def format_tool_result(result: ContentToolResult) -> str: return str(result) + + def format_query_error(e: Exception) -> str: """Format a query error with helpful guidance.""" error_msg = str(e).lower() diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index c25b923fc..f915e79f4 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -10,9 +10,10 @@ from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui from ._icons import bs_icon -from ._querychat_base import TOOL_GROUPS, QueryChatBase +from ._querychat_base import DEFAULT_TOOLS, TOOL_GROUPS, QueryChatBase from ._shiny_module import ServerValues, mod_server, mod_ui from ._utils import MISSING, MISSING_TYPE, as_narwhals +from ._viz_utils import has_viz_tool if TYPE_CHECKING: from pathlib import Path @@ -97,10 +98,11 @@ class QueryChat(QueryChatBase[IntoFrameT]): tools Which querychat tools to include in the chat client by default. 
Can be: - A single tool string: `"update"` or `"query"` - - A tuple of tools: `("update", "query")` + - A tuple of tools: `("update", "query", "visualize")` - `None` or `()` to disable all tools - Default is `("update", "query")` (both tools enabled). + Default is `("update", "query")`. The visualization tool (`"visualize"`) + can be opted into by including it in the tuple. Set to `"update"` to prevent the LLM from accessing data values, only allowing dashboard filtering without answering questions. @@ -156,7 +158,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -172,7 +174,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -188,7 +190,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -204,7 +206,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -219,7 +221,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -245,7 +247,7 @@ def app( """ Quickly chat with a dataset. - Creates a Shiny app with a chat sidebar and data table view -- providing a + Creates a Shiny app with a chat sidebar and data view -- providing a quick-and-easy way to start chatting with your data. Parameters @@ -301,6 +303,7 @@ def app_server(input: Inputs, output: Outputs, session: Session): greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @render.text @@ -399,7 +402,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs): A UI component. """ - return mod_ui(id or self.id, **kwargs) + return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs) def server( self, @@ -506,6 +509,7 @@ def create_session_client(**kwargs) -> chatlas.Chat: greeting=self.greeting, client=create_session_client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @@ -616,6 +620,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = ("update", "query"), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -632,6 +637,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -648,6 +654,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -664,6 +671,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -680,6 +688,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -695,6 +704,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] 
| None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -714,6 +724,7 @@ def __init__( table_name, greeting=greeting, client=client, + tools=tools, data_description=data_description, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, @@ -743,6 +754,7 @@ def __init__( greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable, + tools=self.tools, ) def sidebar( @@ -804,7 +816,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs): A UI component. """ - return mod_ui(id or self.id, **kwargs) + return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs) def df(self) -> IntoFrameT: """ diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 4264285bd..7b568afa9 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -1,10 +1,9 @@ from __future__ import annotations -import copy import warnings from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Generic, Union +from typing import TYPE_CHECKING, Generic, TypedDict, Union import chatlas import shinychat @@ -13,7 +12,9 @@ from shiny import module, reactive, ui from ._querychat_core import GREETING_PROMPT -from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard +from ._viz_altair_widget import AltairWidget +from ._viz_ggsql import execute_ggsql +from ._viz_utils import has_viz_tool, preload_viz_deps_server, preload_viz_deps_ui if TYPE_CHECKING: from collections.abc import Callable @@ -23,6 +24,8 @@ from shiny import Inputs, Outputs, Session from ._datasource import DataSource + from ._querychat_base import TOOL_GROUPS + from ._viz_tools import VisualizeData from .types import UpdateDashboardData ReactiveString = reactive.Value[str] @@ -30,11 +33,31 @@ ReactiveStringOrNone = reactive.Value[Union[str, None]] 
"""A reactive string (or None) value.""" + +class VizWidgetEntry(TypedDict): + """A bookmarked visualization widget: enough state to re-register on restore.""" + + widget_id: str + ggsql: str + + CHAT_ID = "chat" +class _DeferredStubChatClient: + """Placeholder chat client for deferred stub sessions.""" + + def __getattr__(self, _name: str): + raise RuntimeError( + "Chat client is unavailable during stub session before data_source is set." + ) + + +ServerClient = chatlas.Chat | _DeferredStubChatClient + + @module.ui -def mod_ui(**kwargs): +def mod_ui(*, preload_viz: bool = False, **kwargs): css_path = Path(__file__).parent / "static" / "css" / "styles.css" js_path = Path(__file__).parent / "static" / "js" / "querychat.js" @@ -47,6 +70,7 @@ def mod_ui(**kwargs): ui.include_js(js_path), ), tag, + preload_viz_deps_ui() if preload_viz else None, ) @@ -76,18 +100,17 @@ class ServerValues(Generic[IntoFrameT]): `.title()`, or set it with `.title.set("...")`. Returns `None` if no title has been set. client - The session-specific chat client instance. This is a deep copy of the - base client configured for this specific session, containing the chat - history and tool registrations for this session only. This may be - `None` during stub sessions when the client depends on deferred, - session-scoped state. + Session chat client value. + For real sessions this is a `chatlas.Chat` created by the client + factory. For deferred stub sessions (where `data_source` is not set + yet), this is a placeholder client that raises when accessed. """ df: Callable[[], IntoFrameT] sql: ReactiveStringOrNone title: ReactiveStringOrNone - client: chatlas.Chat | None + client: ServerClient @module.server @@ -98,14 +121,39 @@ def mod_server( *, data_source: DataSource[IntoFrameT] | None, greeting: str | None, - client: chatlas.Chat | Callable, + client: Callable[..., chatlas.Chat], enable_bookmarking: bool, + tools: tuple[TOOL_GROUPS, ...] 
| None = None, ) -> ServerValues[IntoFrameT]: # Reactive values to store state sql = ReactiveStringOrNone(None) title = ReactiveStringOrNone(None) has_greeted = reactive.value[bool](False) # noqa: FBT003 + if not callable(client): + raise TypeError("mod_server() requires a callable client factory.") + + def update_dashboard(data: UpdateDashboardData): + sql.set(data["query"]) + title.set(data["title"]) + + def reset_dashboard(): + sql.set(None) + title.set(None) + + viz_widgets: list[VizWidgetEntry] = [] + + def on_visualize(data: VisualizeData): + viz_widgets.append({"widget_id": data["widget_id"], "ggsql": data["ggsql"]}) + + def build_chat_client() -> chatlas.Chat: + return client( + update_dashboard=update_dashboard, + reset_dashboard=reset_dashboard, + visualize=on_visualize, + tools=tools, + ) + # Short-circuit for stub sessions (e.g. 1st run of an Express app) # data_source may be None during stub session for deferred pattern if session.is_stub_session(): @@ -113,11 +161,15 @@ def mod_server( def _stub_df(): raise RuntimeError("RuntimeError: No current reactive context") + stub_client = ( + _DeferredStubChatClient() if data_source is None else build_chat_client() + ) + return ServerValues( df=_stub_df, sql=sql, title=title, - client=client if isinstance(client, chatlas.Chat) else None, + client=stub_client, ) # Real session requires data_source @@ -127,27 +179,11 @@ def _stub_df(): "Set it via the data_source property before users connect." 
) - def update_dashboard(data: UpdateDashboardData): - sql.set(data["query"]) - title.set(data["title"]) - - def reset_dashboard(): - sql.set(None) - title.set(None) - - # Set up the chat object for this session - # Support both a callable that creates a client and legacy instance pattern - if callable(client) and not isinstance(client, chatlas.Chat): - chat = client( - update_dashboard=update_dashboard, reset_dashboard=reset_dashboard - ) - else: - # Legacy pattern: client is Chat instance - chat = copy.deepcopy(client) + # Build the session-specific chat client through QueryChat.client(...). + chat = build_chat_client() - chat.register_tool(tool_update_dashboard(data_source, update_dashboard)) - chat.register_tool(tool_query(data_source)) - chat.register_tool(tool_reset_dashboard(reset_dashboard)) + if has_viz_tool(tools): + preload_viz_deps_server() # Execute query when SQL changes @reactive.calc @@ -211,6 +247,8 @@ def _on_bookmark(x: BookmarkState) -> None: vals["querychat_sql"] = sql.get() vals["querychat_title"] = title.get() vals["querychat_has_greeted"] = has_greeted.get() + if viz_widgets: + vals["querychat_viz_widgets"] = viz_widgets @session.bookmark.on_restore def _on_restore(x: RestoreState) -> None: @@ -221,9 +259,44 @@ def _on_restore(x: RestoreState) -> None: title.set(vals["querychat_title"]) if "querychat_has_greeted" in vals: has_greeted.set(vals["querychat_has_greeted"]) + if "querychat_viz_widgets" in vals: + restored = restore_viz_widgets( + data_source, vals["querychat_viz_widgets"] + ) + viz_widgets[:] = restored return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) class GreetWarning(Warning): """Warning raised when no greeting is provided to QueryChat.""" + + +def restore_viz_widgets( + data_source: DataSource[IntoFrameT], + saved_widgets: list[VizWidgetEntry], +) -> list[VizWidgetEntry]: + """Re-execute ggsql queries, register widgets, and return restored entries.""" + from ggsql import validate + from shinywidgets 
import register_widget + + restored: list[VizWidgetEntry] = [] + + for entry in saved_widgets: + widget_id = entry["widget_id"] + ggsql_str = entry["ggsql"] + try: + validated = validate(ggsql_str) + spec = execute_ggsql(data_source, validated) + altair_widget = AltairWidget.from_ggsql(spec, widget_id=widget_id) + register_widget(widget_id, altair_widget.widget) + restored.append(entry) + except Exception: + # If a query fails on restore (e.g. data changed), skip it. + # The placeholder will remain empty but the rest of the chat restores. + warnings.warn( + f"Failed to restore visualization widget '{widget_id}' on bookmark restore.", + stacklevel=2, + ) + + return restored diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index 5a8445e93..0a57a70ba 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -6,6 +6,8 @@ import chevron +from ._viz_utils import has_viz_tool + _SCHEMA_TAG_RE = re.compile(r"\{\{[{#^/]?\s*schema\b") if TYPE_CHECKING: @@ -83,7 +85,14 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] 
| None) -> str: "extra_instructions": self.extra_instructions, "has_tool_update": "update" in tools if tools else False, "has_tool_query": "query" in tools if tools else False, + "has_tool_visualize": has_viz_tool(tools), "include_query_guidelines": len(tools or ()) > 0, } - return chevron.render(self.template, context) + prompts_dir = str(Path(__file__).parent / "prompts") + return chevron.render( + self.template, + context, + partials_path=prompts_dir, + partials_ext="md", + ) diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 555e8e376..1c8f9f31b 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -4,8 +4,10 @@ import re import warnings from contextlib import contextmanager +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, overload +import chevron import narwhals.stable.v1 as nw from great_tables import GT @@ -14,6 +16,50 @@ import ibis import pandas as pd + import polars as pl + from narwhals.stable.v1.typing import IntoFrame + + +_SCHEMA_DUMP_PATTERN = re.compile( + r"^\s*[\{\[]|'additionalProperties'|\"additionalProperties\"", +) + + +def truncate_error(error_msg: str, max_chars: int = 500) -> str: + if len(error_msg) <= max_chars: + return error_msg + + lines = error_msg.split("\n") + meaningful: list[str] = [] + truncated_by_schema = False + for line in lines: + if not line.strip(): + truncated_by_schema = True + break + if _SCHEMA_DUMP_PATTERN.search(line): + truncated_by_schema = True + break + meaningful.append(line) + + if truncated_by_schema and meaningful: + prefix = "\n".join(meaningful) + if len(prefix) > max_chars: + cut = prefix[:max_chars] + last_space = cut.rfind(" ") + if last_space > max_chars // 2: + cut = cut[:last_space] + prefix = cut + return prefix.rstrip() + "\n\n(error truncated)" + + # No schema markers found (or nothing before them) — apply hard cap if needed + if len(error_msg) <= max_chars: + return error_msg + + truncated = 
error_msg[:max_chars] + last_space = truncated.rfind(" ") + if last_space > max_chars // 2: + truncated = truncated[:last_space] + return truncated.rstrip() + "\n\n(error truncated)" class MISSING_TYPE: # noqa: N801 @@ -171,14 +217,18 @@ def get_tool_details_setting() -> Optional[Literal["expanded", "collapsed", "def return setting_lower -def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> bool: +def querychat_tool_starts_open( + action: Literal[ + "update", "query", "reset", "visualize" + ], +) -> bool: """ Determine whether a tool card should be open based on action and setting. Parameters ---------- action : str - The action type ('update', 'query', or 'reset') + The action type ('update', 'query', 'reset', or 'visualize') Returns ------- @@ -290,3 +340,15 @@ def df_to_html(df, maxrows: int = 5) -> str: table_html += f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n" return table_html + + +def to_polars(data: IntoFrame) -> pl.DataFrame: + """Convert any narwhals-compatible frame to a polars DataFrame.""" + return as_narwhals(data).to_polars() + + +def read_prompt_template(filename: str, **kwargs: object) -> str: + """Read and interpolate a prompt template file.""" + template_path = Path(__file__).parent / "prompts" / filename + template = template_path.read_text() + return chevron.render(template, kwargs) diff --git a/pkg-py/src/querychat/_viz_altair_widget.py b/pkg-py/src/querychat/_viz_altair_widget.py new file mode 100644 index 000000000..00d40347d --- /dev/null +++ b/pkg-py/src/querychat/_viz_altair_widget.py @@ -0,0 +1,187 @@ +"""Altair chart wrapper for responsive display in Shiny.""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Any, cast +from uuid import uuid4 + +from shiny.session import get_current_session + +from shiny import reactive + +if TYPE_CHECKING: + import altair as alt + import ggsql + +class AltairWidget: + """ + An Altair chart wrapped in ``alt.JupyterChart`` for display 
in Shiny. + + Always produces a ``JupyterChart`` so that ``shinywidgets`` receives + a consistent widget type and doesn't call ``chart.properties(width=...)`` + (which fails on compound specs). + + Simple charts use native ``width/height: "container"`` sizing. + Compound charts (facet, concat) get calculated cell dimensions + that are reactively updated when the output container resizes. + """ + + widget: alt.JupyterChart + widget_id: str + + def __init__( + self, + chart: alt.TopLevelMixin, + *, + widget_id: str | None = None, + ) -> None: + import altair as alt + + is_compound = isinstance( + chart, + (alt.FacetChart, alt.ConcatChart, alt.HConcatChart, alt.VConcatChart), + ) + + # Workaround: Vega-Lite's width/height: "container" doesn't work for + # compound specs (facet, concat, etc.), so we inject pixel dimensions + # and reconstruct the chart. Remove this branch when ggsql handles it + # natively: https://github.com/posit-dev/ggsql/issues/238 + if is_compound: + chart = fit_chart_to_container( + chart, DEFAULT_COMPOUND_WIDTH, DEFAULT_COMPOUND_HEIGHT + ) + else: + chart = chart.properties(width="container", height="container") + + self.widget = alt.JupyterChart(chart) + self.widget_id = widget_id or f"querychat_viz_{uuid4().hex[:8]}" + + # Reactively update compound cell sizes when the container resizes. + # Also part of the compound sizing workaround (issue #238). 
+ if is_compound: + self._setup_reactive_sizing(self.widget, self.widget_id) + + @classmethod + def from_ggsql( + cls, spec: ggsql.Spec, *, widget_id: str | None = None + ) -> AltairWidget: + from ggsql import VegaLiteWriter + + writer = VegaLiteWriter() + return cls(writer.render_chart(spec), widget_id=widget_id) + + @staticmethod + def _setup_reactive_sizing(widget: alt.JupyterChart, widget_id: str) -> None: + session = get_current_session() + if session is None: + return + + @reactive.effect + def _sizing_effect(): + width = session.clientdata.output_width(widget_id) + height = session.clientdata.output_height(widget_id) + if width is None or height is None: + return + chart = widget.chart + if chart is None: + return + chart = cast("alt.Chart", chart) + chart2 = fit_chart_to_container(chart, int(width), int(height)) + # Must set widget.spec (a new dict) rather than widget.chart, + # because traitlets won't fire change events when the same + # chart object is assigned back after in-place mutation. + widget.spec = chart2.to_dict() + + # Clean up the effect when the session ends to avoid memory leaks + session.on_ended(_sizing_effect.destroy) + + +# --------------------------------------------------------------------------- +# Compound chart sizing helpers +# +# Vega-Lite's `width/height: "container"` doesn't work for compound specs +# (facet, concat, etc.), so we manually inject cell dimensions. 
Ideally ggsql +# will handle this natively: https://github.com/posit-dev/ggsql/issues/238 +# --------------------------------------------------------------------------- + +DEFAULT_COMPOUND_WIDTH = 900 +DEFAULT_COMPOUND_HEIGHT = 450 + +LEGEND_CHANNELS = frozenset( + {"color", "fill", "stroke", "shape", "size", "opacity"} +) +LEGEND_WIDTH = 120 # approximate space for a right-side legend + + +def fit_chart_to_container( + chart: alt.TopLevelMixin, + container_width: int, + container_height: int, +) -> alt.TopLevelMixin: + """ + Return a copy of ``chart`` with cell ``width``/``height`` set. + + The original chart is never mutated. + + For faceted charts, divides the container width by the number of columns. + For hconcat/concat, divides by the number of sub-specs. + For vconcat, each sub-spec gets the full width. + + Subtracts padding estimates so the rendered cells fill the container, + including space for legends when present. + """ + import altair as alt + + chart = copy.deepcopy(chart) + + # Approximate padding; will be replaced when ggsql handles compound sizing + # natively (https://github.com/posit-dev/ggsql/issues/238). 
+ padding_x = 80 # y-axis labels + title padding + padding_y = 120 # facet headers, x-axis labels + title, bottom padding + if has_legend(chart.to_dict()): + padding_x += LEGEND_WIDTH + usable_w = max(container_width - padding_x, 100) + usable_h = max(container_height - padding_y, 100) + + if isinstance(chart, alt.FacetChart): + ncol = chart.columns if isinstance(chart.columns, int) else 1 + cell_w = usable_w // max(ncol, 1) + chart.spec.width = cell_w + chart.spec.height = usable_h + elif isinstance(chart, alt.HConcatChart): + cell_w = usable_w // max(len(chart.hconcat), 1) + for sub in chart.hconcat: + sub.width = cell_w + sub.height = usable_h + elif isinstance(chart, alt.ConcatChart): + ncol = chart.columns if isinstance(chart.columns, int) else len(chart.concat) + cell_w = usable_w // max(ncol, 1) + for sub in chart.concat: + sub.width = cell_w + sub.height = usable_h + elif isinstance(chart, alt.VConcatChart): + cell_h = usable_h // max(len(chart.vconcat), 1) + for sub in chart.vconcat: + sub.width = usable_w + sub.height = cell_h + + return chart + + +def has_legend(vl: dict[str, object]) -> bool: + """Check if any encoding in the VL spec uses a legend-producing channel with a field.""" + specs: list[dict[str, Any]] = [] + if "spec" in vl: + specs.append(vl["spec"]) # type: ignore[arg-type] + for key in ("hconcat", "vconcat", "concat"): + if key in vl: + specs.extend(vl[key]) # type: ignore[arg-type] + + for spec in specs: + for layer in spec.get("layer", [spec]): # type: ignore[union-attr] + enc = layer.get("encoding", {}) # type: ignore[union-attr] + for ch in LEGEND_CHANNELS: + if ch in enc and "field" in enc[ch]: # type: ignore[operator] + return True + return False diff --git a/pkg-py/src/querychat/_viz_ggsql.py b/pkg-py/src/querychat/_viz_ggsql.py new file mode 100644 index 000000000..b1a950363 --- /dev/null +++ b/pkg-py/src/querychat/_viz_ggsql.py @@ -0,0 +1,106 @@ +"""Helpers for ggsql integration.""" + +from __future__ import annotations + +import 
re +from typing import TYPE_CHECKING + +from ._utils import to_polars + +if TYPE_CHECKING: + import ggsql + + from ._datasource import DataSource + + +def execute_ggsql(data_source: DataSource, validated: ggsql.Validated) -> ggsql.Spec: + """ + Execute a pre-validated ggsql query against a DataSource, returning a Spec. + + Executes the SQL portion through DataSource (preserving database pushdown), + then feeds the result into a ggsql DuckDBReader to produce a Spec. + + Parameters + ---------- + data_source + The querychat DataSource to execute the SQL portion against. + validated + A pre-validated ggsql query (from ``ggsql.validate()``). + + Returns + ------- + ggsql.Spec + The writer-independent plot specification. + + """ + from ggsql import DuckDBReader + + visual = validated.visual() + if has_layer_level_source(visual): + # Short term, querychat only supports visual layers that can be replayed + # from one final SQL result. Long term, the cleaner fix is likely to use + # ggsql's native remote-reader execution path (for example via ODBC-backed + # Readers) instead of reconstructing multi-relation scope here. + raise ValueError( + "Layer-specific sources are not currently supported in querychat visual " + "queries. Rewrite the query so that all layers come from the final SQL " + "result." + ) + + pl_df = to_polars(data_source.execute_query(validated.sql())) + # Snowflake (and some other backends) uppercase unquoted identifiers, + # but the LLM writes lowercase aliases in the VISUALISE clause. + # DuckDB is case-insensitive, so lowercasing here lets both match. + pl_df.columns = [c.lower() for c in pl_df.columns] + + reader = DuckDBReader("duckdb://memory") + table = extract_visualise_table(visual) + + if table is not None: + # VISUALISE [mappings] FROM — register data under the + # referenced table name and execute the visual part directly. 
+ name = table[1:-1] if table.startswith('"') and table.endswith('"') else table + reader.register(name, pl_df) + return reader.execute(visual) + else: + # SELECT ... VISUALISE — no FROM in VISUALISE clause, so register + # under a synthetic name and prepend a SELECT. + reader.register("_data", pl_df) + return reader.execute(f"SELECT * FROM _data {visual}") + + +def extract_visualise_table(visual: str) -> str | None: + """ + Extract the table name from ``VISUALISE ... FROM
`` if present. + + This reimplements a small part of ggsql's parsing because the current + Python bindings do not expose the top-level VISUALISE source directly. + """ + draw_pos = re.search(r"\bDRAW\b", visual, re.IGNORECASE) + vis_clause = visual[: draw_pos.start()] if draw_pos else visual + m = re.search(r'\bFROM\s+("[^"]+?"|\S+)', vis_clause, re.IGNORECASE) + return m.group(1) if m else None + + +def has_layer_level_source(visual: str) -> bool: + """ + Return ``True`` when a DRAW clause defines its own ``FROM ``. + + Querychat currently replays the VISUALISE portion against a single local + relation, so layer-specific sources cannot be preserved reliably. + """ + clauses = re.split( + r"(?=\b(?:DRAW|SCALE|PROJECT|FACET|PLACE|LABEL|THEME)\b)", + visual, + flags=re.IGNORECASE, + ) + for clause in clauses: + if not re.match(r"^\s*DRAW\b", clause, re.IGNORECASE): + continue + if re.search( + r'\bMAPPING\b[\s\S]*?\bFROM\s+("[^"]+?"|\S+)', + clause, + re.IGNORECASE, + ): + return True + return False diff --git a/pkg-py/src/querychat/_viz_tools.py b/pkg-py/src/querychat/_viz_tools.py new file mode 100644 index 000000000..d273cf354 --- /dev/null +++ b/pkg-py/src/querychat/_viz_tools.py @@ -0,0 +1,328 @@ +"""Visualization tool definitions for querychat.""" + +from __future__ import annotations + +import base64 +import copy +import io +from typing import TYPE_CHECKING, Any, TypedDict +from uuid import uuid4 + +from chatlas import ContentToolResult, Tool, content_image_url +from htmltools import HTMLDependency, TagList, tags +from shinychat.types import ToolResultDisplay + +from shiny import ui + +from .__version import __version__ +from ._icons import bs_icon +from ._utils import querychat_tool_starts_open, read_prompt_template, truncate_error +from ._viz_altair_widget import AltairWidget, fit_chart_to_container +from ._viz_ggsql import execute_ggsql + +if TYPE_CHECKING: + from collections.abc import Callable + + import altair as alt + from ipywidgets.widgets.widget 
import Widget + + from ._datasource import DataSource + + +class VisualizeData(TypedDict): + """ + Data passed to visualize callback. + + This TypedDict defines the structure of data passed to the + `tool_visualize` callback function when the LLM creates an + exploratory visualization from a ggsql query. + + Attributes + ---------- + ggsql + The full ggsql query string (SQL + VISUALISE). + title + A descriptive title for the visualization. + widget_id + The unique widget ID used to register the visualization with shinywidgets. + + """ + + ggsql: str + title: str + widget_id: str + + +def tool_visualize( + data_source: DataSource, + update_fn: Callable[[VisualizeData], None], +) -> Tool: + """ + Create a tool that executes a ggsql query and renders the visualization. + + Parameters + ---------- + data_source + The data source to query against + update_fn + Callback function to call with VisualizeData when visualization succeeds + + Returns + ------- + Tool + A tool that can be registered with chatlas + + """ + impl = visualize_impl(data_source, update_fn) + impl.__doc__ = read_prompt_template( + "tool-visualize.md", + db_type=data_source.get_db_type(), + ) + + return Tool.from_func( + impl, + name="querychat_visualize", + annotations={"title": "Query Visualization"}, + ) + + +class VisualizeResult(ContentToolResult): + """Tool result that registers an ipywidget and embeds it inline via shinywidgets.""" + + def __init__( + self, + widget_id: str, + widget: Widget, + ggsql_str: str, + title: str, + png_bytes: bytes | None = None, + **kwargs: Any, + ): + from shinywidgets import output_widget, register_widget + + register_widget(widget_id, widget) + + title_display = f" with title '{title}'" if title else "" + text = f"Chart displayed{title_display}." 
+ + if png_bytes is not None: + png_b64 = base64.b64encode(png_bytes).decode("ascii") + value = [ + text, + content_image_url(f"data:image/png;base64,{png_b64}"), + ] + else: + value = text + + footer = build_viz_footer(ggsql_str, title, widget_id) + + widget_html = output_widget(widget_id, fill=True, fillable=True) + widget_html.add_class("querychat-viz-container") + widget_html.append(viz_dep()) + + extra = { + "display": ToolResultDisplay( + html=widget_html, + title=title or "Query Visualization", + show_request=False, + open=querychat_tool_starts_open("visualize"), + full_screen=True, + icon=bs_icon("graph-up"), + footer=footer, + ), + } + + super().__init__(value=value, model_format="as_is", extra=extra, **kwargs) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def visualize_impl( + data_source: DataSource, + update_fn: Callable[[VisualizeData], None], +) -> Callable[[str, str], ContentToolResult]: + """Create the visualize implementation function.""" + from ggsql import VegaLiteWriter, validate + + def visualize( + ggsql: str, + title: str, + ) -> ContentToolResult: + """Execute a ggsql query and render the visualization.""" + markdown = f"```sql\n{ggsql}\n```" + + try: + validated = validate(ggsql) + if not validated.has_visual(): + # When VISUALISE contains SQL expressions (e.g., CAST()), + # ggsql silently treats the entire query as plain SQL: + # valid()=True, has_visual()=False, no errors. This + # heuristic catches that case so we can guide the LLM. + # Remove when ggsql reports this as a parse error: + # https://github.com/posit-dev/ggsql/issues/256 + has_keyword = ( + "VISUALISE" in ggsql.upper() or "VISUALIZE" in ggsql.upper() + ) + if has_keyword: + raise ValueError( + "VISUALISE clause was not recognized. " + "VISUALISE and MAPPING accept column names only — " + "no SQL expressions, CAST(), or functions. 
" + "Move all data transformations to the SELECT clause, " + "then reference the resulting column by name in VISUALISE." + ) + raise ValueError( + "Query must include a VISUALISE clause. " + "Use querychat_query for queries without visualization." + ) + + spec = execute_ggsql(data_source, validated) + + raw_chart = VegaLiteWriter().render_chart(spec) + altair_widget = AltairWidget(copy.deepcopy(raw_chart)) + + try: + png_bytes = render_chart_to_png(raw_chart) + except Exception: + png_bytes = None + + update_fn( + {"ggsql": ggsql, "title": title, "widget_id": altair_widget.widget_id} + ) + + return VisualizeResult( + widget_id=altair_widget.widget_id, + widget=altair_widget.widget, + ggsql_str=ggsql, + title=title, + png_bytes=png_bytes, + ) + + except Exception as e: + error_msg = truncate_error(str(e)) + markdown += f"\n\n> Error: {error_msg}" + return ContentToolResult(value=markdown, error=Exception(error_msg)) + + return visualize + + +PNG_WIDTH = 500 +PNG_HEIGHT = 300 + + +def render_chart_to_png(chart: alt.TopLevelMixin) -> bytes: + """Render an Altair chart to PNG bytes at a fixed size for LLM feedback.""" + import altair as alt + + chart = copy.deepcopy(chart) + is_compound = isinstance( + chart, + (alt.FacetChart, alt.ConcatChart, alt.HConcatChart, alt.VConcatChart), + ) + if is_compound: + chart = fit_chart_to_container(chart, PNG_WIDTH, PNG_HEIGHT) + else: + chart = chart.properties(width=PNG_WIDTH, height=PNG_HEIGHT) + + buf = io.BytesIO() + chart.save(buf, format="png", scale_factor=1) + return buf.getvalue() + + +def viz_dep() -> HTMLDependency: + """HTMLDependency for viz-specific CSS and JS assets.""" + return HTMLDependency( + "querychat-viz", + __version__, + source={ + "package": "querychat", + "subdir": "static", + }, + stylesheet=[{"href": "css/viz.css"}], + script=[{"src": "js/viz.js"}], + ) + + +def build_viz_footer( + ggsql_str: str, + title: str, + widget_id: str, +) -> TagList: + """Build footer HTML for visualization tool results.""" + 
footer_id = f"querychat_footer_{uuid4().hex[:8]}" + query_section_id = f"{footer_id}_query" + code_editor_id = f"{footer_id}_code" + + # Read-only code editor for query display + code_editor = ui.input_code_editor( + id=code_editor_id, + value=ggsql_str, + language="ggsql", + read_only=True, + line_numbers=False, + height="auto", + theme_dark="github-dark", + ) + + # Query section (hidden by default) + query_section = tags.div( + {"class": "querychat-query-section", "id": query_section_id}, + code_editor, + ) + + # Footer buttons row + buttons_row = tags.div( + {"class": "querychat-footer-buttons"}, + # Left: Show Query toggle + tags.div( + {"class": "querychat-footer-left"}, + tags.button( + { + "class": "querychat-show-query-btn", + "data-target": query_section_id, + }, + tags.span({"class": "querychat-query-chevron"}, "\u25b6"), + tags.span({"class": "querychat-query-label"}, "Show Query"), + ), + ), + # Right: Save dropdown + tags.div( + {"class": "querychat-footer-right"}, + tags.div( + {"class": "querychat-save-dropdown"}, + tags.button( + { + "class": "querychat-save-btn", + "data-widget-id": widget_id, + }, + bs_icon("download", cls="querychat-icon"), + "Save", + bs_icon("chevron-down", cls="querychat-dropdown-chevron"), + ), + tags.div( + {"class": "querychat-save-menu"}, + tags.button( + { + "class": "querychat-save-png-btn", + "data-widget-id": widget_id, + "data-title": title, + }, + "Save as PNG", + ), + tags.button( + { + "class": "querychat-save-svg-btn", + "data-widget-id": widget_id, + "data-title": title, + }, + "Save as SVG", + ), + ), + ), + ), + ) + + return TagList(buttons_row, query_section) diff --git a/pkg-py/src/querychat/_viz_utils.py b/pkg-py/src/querychat/_viz_utils.py new file mode 100644 index 000000000..57aae7a50 --- /dev/null +++ b/pkg-py/src/querychat/_viz_utils.py @@ -0,0 +1,67 @@ +"""Shared visualization utilities.""" + +from __future__ import annotations + +import importlib.util + +from htmltools import HTMLDependency, tags + 
+from .__version import __version__ + + +def has_viz_tool(tools: tuple[str, ...] | None) -> bool: + """Check if visualize is among the configured tools.""" + return tools is not None and "visualize" in tools + + +def has_viz_deps() -> bool: + """Check whether visualization dependencies (ggsql, altair, shinywidgets, vl-convert-python) are installed.""" + return all( + importlib.util.find_spec(pkg) is not None + for pkg in ("ggsql", "altair", "shinywidgets", "vl_convert") + ) + + +PRELOAD_WIDGET_ID = "__querychat_preload_viz__" + + +def preload_viz_deps_ui(): + """Return a hidden widget output that triggers eager JS dependency loading.""" + from shinywidgets import output_widget + + return tags.div( + output_widget(PRELOAD_WIDGET_ID), + viz_preload_dep(), + class_="querychat-viz-preload", + hidden="", + aria_hidden="true", + style="position:absolute; left:-9999px; width:1px; height:1px;", + ) + + +def viz_preload_dep() -> HTMLDependency: + """HTMLDependency for viz preload-specific JS.""" + return HTMLDependency( + "querychat-viz-preload", + __version__, + source={ + "package": "querychat", + "subdir": "static", + }, + script=[{"src": "js/viz-preload.js"}], + ) + + +def preload_viz_deps_server() -> None: + """Register a minimal Altair widget to trigger full JS dependency loading.""" + from shinywidgets import register_widget + + register_widget(PRELOAD_WIDGET_ID, mock_altair_widget()) + + +def mock_altair_widget(): + """Create a minimal Altair JupyterChart suitable for preloading JS dependencies.""" + import altair as alt + + chart = alt.Chart({"values": [{"x": 0}]}).mark_point().encode(x="x:Q") + return alt.JupyterChart(chart) diff --git a/pkg-py/src/querychat/prompts/ggsql-syntax.md b/pkg-py/src/querychat/prompts/ggsql-syntax.md new file mode 100644 index 000000000..e98b90e56 --- /dev/null +++ b/pkg-py/src/querychat/prompts/ggsql-syntax.md @@ -0,0 +1,561 @@ +## ggsql Syntax Reference + +### Quick Reference + +```sql +[WITH cte AS (...), ...] 
+[SELECT columns FROM table WHERE conditions] +VISUALISE [mappings] [FROM source] +DRAW geom_type + [MAPPING col AS aesthetic, ... FROM source] + [REMAPPING stat AS aesthetic, ...] + [SETTING param => value, ...] + [FILTER sql_condition] + [PARTITION BY col, ...] + [ORDER BY col [ASC|DESC], ...] +[SCALE [TYPE] aesthetic [FROM ...] [TO ...] [VIA ...] [SETTING ...] [RENAMING ...]] +[PROJECT [aesthetics] TO coord_system [SETTING ...]] +[FACET var | row_var BY col_var [SETTING free => 'x'|'y'|('x','y'), ncol => N, nrow => N]] +[PLACE geom_type SETTING param => value, ...] +[LABEL x => '...', y => '...', ...] +``` + +### VISUALISE Clause + +Entry point for visualization. Marks where SQL ends and visualization begins. Mappings in VISUALISE and MAPPING accept **column names only** — no SQL expressions, functions, or casts. All data transformations must happen in the SELECT clause. + +```sql +-- After SELECT (most common) +SELECT date, revenue, region FROM sales +VISUALISE date AS x, revenue AS y, region AS color +DRAW line + +-- Shorthand with FROM (auto-generates SELECT * FROM) +VISUALISE FROM sales +DRAW bar MAPPING region AS x, total AS y +``` + +### Mapping Styles + +| Style | Syntax | Use When | +|-------|--------|----------| +| Explicit | `date AS x` | Column name differs from aesthetic | +| Implicit | `x` | Column name equals aesthetic name | +| Wildcard | `*` | Map all matching columns automatically | +| Literal | `'string' AS color` | Use a literal value (for legend labels in multi-layer plots) | +| Null | `null AS color` | Suppress an inherited global mapping for this layer | + +### DRAW Clause (Layers) + +Multiple DRAW clauses create layered visualizations. + +```sql +DRAW geom_type + [MAPPING col AS aesthetic, ... FROM source] + [REMAPPING stat AS aesthetic, ...] + [SETTING param => value, ...] + [FILTER sql_condition] + [PARTITION BY col, ...] + [ORDER BY col [ASC|DESC], ...] 
+``` + +**Geom types:** + +| Category | Types | +|----------|-------| +| Basic | `point`, `line`, `path`, `bar`, `area`, `tile`, `polygon`, `ribbon` | +| Statistical | `histogram`, `density`, `smooth`, `boxplot`, `violin` | +| Annotation | `text`, `label`, `segment`, `arrow`, `rule`, `rect`, `errorbar` | + +- `path` is like `line` but preserves data order instead of sorting by x. +- `tile` draws rectangles for heatmaps or range indicators. Map `x`/`y` for center (defaults to width/height of 1), or use `xmin`/`xmax`/`ymin`/`ymax` for explicit bounds. +- `smooth` fits a trendline to data. Settings: `method` (`'nw'` default for kernel regression, `'ols'` for linear, `'tls'` for total least squares), `bandwidth`, `adjust`, `kernel`. +- `text` (or `label`) renders text labels. Map `label` for the text content. Settings: `format` (template string for label formatting), `offset` (pixel offset as `(x, y)`). Labels containing `\n` are automatically split into multiple lines. +- `arrow` draws arrows between two points. Requires `x`, `y`, `xend`, `yend` aesthetics. +- `rule` draws full-span reference lines. Map a value to `y` for a horizontal line or `x` for a vertical line. Optionally map `slope` to create diagonal reference lines: `y = a + slope * x` (when `y` is mapped) or `x = a + slope * y` (when `x` is mapped). +- `rect` draws rectangles. Pick 2 per axis from center (`x`/`y`), min (`xmin`/`ymin`), max (`xmax`/`ymax`), `width`, `height`. Or just map center (defaults to width/height of 1). +- `errorbar` displays interval marks. Requires `x`, `ymin`, `ymax`. Settings: `width` (hinge width in points, default 10; `null` to hide hinges). +- `line` and `path` support continuously varying `linewidth`, `stroke`, and `opacity` aesthetics within groups. 
+ +**Aesthetics (MAPPING):** + +| Category | Aesthetics | +|----------|------------| +| Position | `x`, `y`, `xmin`, `xmax`, `ymin`, `ymax`, `xend`, `yend` | +| Color | `color`/`colour`, `fill`, `stroke`, `opacity` | +| Size/Shape | `size`, `shape`, `linewidth`, `linetype`, `width`, `height` | +| Text | `label`, `typeface`, `fontweight`, `italic`, `fontsize`, `hjust`, `vjust`, `rotation` | +| Aggregation | `weight` (for histogram/bar/density/violin) | +| Rule | `slope` (for diagonal `rule` lines) | + +**PARTITION BY** groups data without visual encoding (useful for separate lines per group without color): + +```sql +DRAW line PARTITION BY category +``` + +**ORDER BY** controls row ordering within a layer: + +```sql +DRAW line ORDER BY date ASC +``` + +### PLACE Clause (Annotations) + +`PLACE` creates annotation layers with literal values only — no data mappings. Use it for reference lines, text labels, and other fixed annotations. All aesthetics are set via `SETTING` and bypass scaling. + +```sql +PLACE geom_type SETTING param => value, ... +``` + +**Examples:** +```sql +-- Horizontal reference line +PLACE rule SETTING y => 100 + +-- Vertical reference line +PLACE rule SETTING x => '2024-06-01' + +-- Multiple reference lines (array values) +PLACE rule SETTING y => (50, 75, 100) + +-- Text annotation +PLACE text SETTING x => 10, y => 50, label => 'Threshold' + +-- Diagonal reference line (y = -1 + 0.4 * x) +PLACE rule SETTING slope => 0.4, y => -1 +``` + +`PLACE` supports any geom type but is most useful with `rule`, `text`, `segment`, and `tile`. Use `PLACE` for fixed annotation values known at query time; use `DRAW` with `MAPPING` when values come from data columns. Unlike `DRAW`, `PLACE` has no `MAPPING`, `FILTER`, `PARTITION BY`, or `ORDER BY` sub-clauses. Array values in PLACE SETTING are recycled into multiple rows only for supported aesthetics; geom parameters (like `offset` on `text`) are passed through as-is. 
+ +### Statistical Layers and REMAPPING + +Some layers compute statistics. Use REMAPPING to access computed values: + +| Layer | Computed Stats | Default Remapping | +|-------|---------------|-------------------| +| `bar` (y unmapped) | `count`, `proportion` | `count AS y` | +| `histogram` | `count`, `density` | `count AS y` | +| `density` | `density`, `intensity` | `density AS y` | +| `violin` | `density`, `intensity` | `density AS offset` | +| `smooth` | `intensity` | `intensity AS y` | +| `boxplot` | `value`, `type` | `value AS y` | + +`boxplot` displays box-and-whisker plots. Settings: `outliers` (`true` default — show outlier points), `coef` (`1.5` default — whisker fence coefficient), `width` (`0.9` default — box width, 0–1). + +`smooth` fits a trendline to data. Settings: `method` (`'nw'` or `'nadaraya-watson'` default kernel regression, `'ols'` for OLS linear, `'tls'` for total least squares). NW-only settings: `bandwidth` (numeric), `adjust` (multiplier, default 1), `kernel` (`'gaussian'` default, `'epanechnikov'`, `'triangular'`, `'rectangular'`, `'uniform'`, `'biweight'`, `'quartic'`, `'cosine'`). + +`density` computes a KDE from a continuous `x`. Settings: `bandwidth` (numeric), `adjust` (multiplier, default 1), `kernel` (`'gaussian'` default, `'epanechnikov'`, `'triangular'`, `'rectangular'`, `'uniform'`, `'biweight'`, `'quartic'`, `'cosine'`). Use `REMAPPING intensity AS y` to show unnormalized density that reflects group size differences. Use `SETTING position => 'stack'` for stacked densities. + +`violin` displays mirrored KDE curves for groups. Requires both `x` (categorical) and `y` (continuous). Accepts the same bandwidth/adjust/kernel settings as density. Use `REMAPPING intensity AS offset` to reflect group size differences. Additional settings: `side` (`'both'` default, `'left'`/`'bottom'`, `'right'`/`'top'` — for half-violin/ridgeline plots), `width` (any value > 0; values > 1 enable ridgeline-style overlapping). 
+ +**Examples:** + +```sql +-- Density histogram (instead of count) +VISUALISE FROM products +DRAW histogram MAPPING price AS x REMAPPING density AS y + +-- Bar showing proportion +VISUALISE FROM sales +DRAW bar MAPPING region AS x REMAPPING proportion AS y + +-- Overlay histogram and density on the same scale +VISUALISE FROM measurements +DRAW histogram MAPPING value AS x SETTING opacity => 0.5 +DRAW density MAPPING value AS x REMAPPING intensity AS y SETTING opacity => 0.5 + +-- Violin plot +SELECT department, salary FROM employees +VISUALISE department AS x, salary AS y +DRAW violin +``` + +### SCALE Clause + +Configures how data maps to visual properties. All sub-clauses are optional; type and transform are auto-detected from data when omitted. + +```sql +SCALE [TYPE] aesthetic [FROM range] [TO output] [VIA transform] [SETTING prop => value, ...] [RENAMING ...] +``` + +**Type identifiers** (optional — auto-detected if omitted): + +| Type | Description | +|------|-------------| +| `CONTINUOUS` | Numeric data on a continuous axis | +| `DISCRETE` | Categorical/nominal data | +| `BINNED` | Pre-bucketed data | +| `ORDINAL` | Ordered categories with interpolated output | +| `IDENTITY` | Data values are already visual values (e.g., literal hex colors) | + +**Important — integer columns used as categories:** When an integer column represents categories (e.g., a 0/1 `survived` column), ggsql will treat it as continuous by default. This causes errors when mapping to `fill`, `color`, `shape`, or using it in `FACET`. 
Two fixes: +- **Preferred:** Cast to string in the SELECT clause: `SELECT CAST(survived AS VARCHAR) AS survived ...`, then map the column by name in VISUALISE: `survived AS fill` +- **Alternative:** Declare the scale: `SCALE DISCRETE fill` or `SCALE fill VIA bool` + +**FROM** — input domain: +```sql +SCALE x FROM (0, 100) -- explicit min and max +SCALE x FROM (0, null) -- explicit min, auto max +SCALE DISCRETE x FROM ('A', 'B', 'C') -- explicit category order +``` + +**TO** — output range or palette: +```sql +SCALE color TO sequential -- default continuous palette (derived from navia) +SCALE color TO viridis -- other continuous: viridis, plasma, inferno, magma, cividis, navia, batlow +SCALE color TO vik -- diverging: vik, rdbu, rdylbu, spectral, brbg, berlin, roma +SCALE DISCRETE color TO ggsql10 -- discrete (default: ggsql10): tableau10, category10, set1, set2, set3, dark2, paired, kelly +SCALE color TO ('red', 'blue') -- explicit color array +SCALE size TO (1, 10) -- numeric output range +``` + +**VIA** — transformation: +```sql +SCALE x VIA date -- date axis (auto-detected from Date columns) +SCALE x VIA datetime -- datetime axis +SCALE y VIA log10 -- base-10 logarithm +SCALE y VIA sqrt -- square root +``` + +| Category | Transforms | +|----------|------------| +| Logarithmic | `log10`, `log2`, `log` (natural) | +| Power | `sqrt`, `square` | +| Exponential | `exp`, `exp2`, `exp10` | +| Other | `asinh`, `pseudo_log` | +| Temporal | `date`, `datetime`, `time` | +| Type coercion | `integer`, `string`, `bool` | + +**SETTING** — additional properties: +```sql +SCALE x SETTING breaks => 5 -- number of tick marks +SCALE x SETTING breaks => '2 months' -- interval-based breaks +SCALE x SETTING expand => 0.05 -- expand scale range by 5% +SCALE x SETTING reverse => true -- reverse direction +SCALE y FROM (0, 100) SETTING oob => 'squish' -- squish out-of-bounds values to range boundary +``` + +`oob` (out-of-bounds) controls data outside the scale range: `'keep'` (default 
for x/y), `'censor'` (remove, default for other aesthetics), `'squish'` (clamp to boundary). + +**RENAMING** — custom axis/legend labels: +```sql +SCALE DISCRETE x RENAMING 'A' => 'Alpha', 'B' => 'Beta' +SCALE CONTINUOUS x RENAMING * => '{} units' -- template for all labels +SCALE x VIA date RENAMING * => '{:time %b %Y}' -- date label formatting +``` + +### Date/Time Axes + +Temporal transforms are auto-detected from column data types, including after `DATE_TRUNC`. + +**Break intervals:** +```sql +SCALE x SETTING breaks => 'month' -- one break per month +SCALE x SETTING breaks => '2 weeks' -- every 2 weeks +SCALE x SETTING breaks => '3 months' -- quarterly +SCALE x SETTING breaks => 'year' -- yearly +``` + +Valid units: `day`, `week`, `month`, `year` (for date); also `hour`, `minute`, `second` (for datetime/time). + +**Date label formatting** (strftime syntax): +```sql +SCALE x VIA date RENAMING * => '{:time %b %Y}' -- "Jan 2024" +SCALE x VIA date RENAMING * => '{:time %B %d, %Y}' -- "January 15, 2024" +SCALE x VIA date RENAMING * => '{:time %b %d}' -- "Jan 15" +``` + +### PROJECT Clause + +Sets coordinate system. Use `PROJECT ... TO` to specify coordinates. + +**Coordinate systems:** `cartesian` (default), `polar`. + +**Polar aesthetics:** In polar coordinates, positional aesthetics use `angle` and `radius` (instead of `x` and `y`). Variants `anglemin`, `anglemax`, `angleend`, `radiusmin`, `radiusmax`, `radiusend` are also available. Typically you map to `x`/`y` and let `PROJECT TO polar` handle the conversion, but you can use `angle`/`radius` explicitly when needed. 
+ +```sql +PROJECT TO cartesian -- explicit default (usually omitted) +PROJECT y, x TO cartesian -- flip axes (maps y to horizontal, x to vertical) +PROJECT TO polar -- pie/radial charts +PROJECT TO polar SETTING start => 90 -- start at 3 o'clock +PROJECT TO polar SETTING inner => 0.5 -- donut chart (50% hole) +PROJECT TO polar SETTING start => -90, end => 90 -- half-circle gauge +``` + +**Cartesian settings:** +- `clip` — clip out-of-bounds data (default `true`) +- `ratio` — enforce aspect ratio between axes + +**Polar settings:** +- `start` — starting angle in degrees (0 = 12 o'clock, 90 = 3 o'clock) +- `end` — ending angle in degrees (default: start + 360; use for partial arcs/gauges) +- `inner` — inner radius as proportion 0–1 (0 = full pie, 0.5 = donut with 50% hole) +- `clip` — clip out-of-bounds data (default `true`) + +**Axis flipping:** To create horizontal bar charts or flip axes, use `PROJECT y, x TO cartesian`. This maps anything on `y` to the horizontal axis and `x` to the vertical axis. + +### FACET Clause + +Creates small multiples (subplots by category). + +```sql +FACET category -- Single variable, wrapped layout +FACET row_var BY col_var -- Grid layout (rows x columns) +FACET category SETTING free => 'y' -- Independent y-axes +FACET category SETTING free => ('x', 'y') -- Independent both axes +FACET category SETTING ncol => 4 -- Control number of columns +FACET category SETTING nrow => 2 -- Control number of rows (mutually exclusive with ncol) +``` + +Custom strip labels via SCALE: +```sql +FACET region +SCALE panel RENAMING 'N' => 'North', 'S' => 'South' +``` + +Filter to specific panels via SCALE FROM: +```sql +FACET island +SCALE panel FROM ('Biscoe', 'Dream') +``` + +### LABEL Clause + +Use LABEL for axis labels, subtitles, and captions. Do NOT use `LABEL title => ...` — the tool's `title` parameter handles chart titles. Set a label to `null` to suppress it. 
+ +Available labels: any aesthetic name (`x`, `y`, `fill`, `color`, etc.), `subtitle`, `caption`. + +```sql +LABEL x => 'X Axis Label', y => 'Y Axis Label' +LABEL x => null -- suppress x-axis label +LABEL subtitle => 'Q4 2024 data', caption => 'Source: internal database' +``` + +## Complete Examples + +**Line chart with multiple series:** +```sql +SELECT date, revenue, region FROM sales WHERE year = 2024 +VISUALISE date AS x, revenue AS y, region AS color +DRAW line +SCALE x VIA date +LABEL x => 'Date', y => 'Revenue ($)' +``` + +**Bar chart (auto-count):** +```sql +VISUALISE FROM products +DRAW bar MAPPING category AS x +``` + +**Horizontal bar chart:** +```sql +SELECT region, COUNT(*) as n FROM sales GROUP BY region +VISUALISE region AS y, n AS x +DRAW bar +PROJECT y, x TO cartesian +``` + +**Scatter plot with trend line:** +```sql +SELECT mpg, hp, cylinders FROM cars +VISUALISE mpg AS x, hp AS y +DRAW point MAPPING cylinders AS color +DRAW smooth +``` + +**Histogram with density overlay:** +```sql +VISUALISE FROM measurements +DRAW histogram MAPPING value AS x SETTING bins => 20, opacity => 0.5 +DRAW density MAPPING value AS x REMAPPING intensity AS y SETTING opacity => 0.5 +``` + +**Density plot with groups:** +```sql +VISUALISE FROM measurements +DRAW density MAPPING value AS x, category AS color SETTING opacity => 0.7 +``` + +**Heatmap with tile:** +```sql +SELECT day, month, temperature FROM weather +VISUALISE day AS x, month AS y, temperature AS color +DRAW tile +``` + +**Threshold reference lines (using PLACE):** +```sql +SELECT date, temperature FROM sensor_data +VISUALISE date AS x, temperature AS y +DRAW line +PLACE rule SETTING y => 100, stroke => 'red', linetype => 'dashed' +LABEL y => 'Temperature (F)' +``` + +**Faceted chart:** +```sql +SELECT month, sales, region FROM data +VISUALISE month AS x, sales AS y +DRAW line +DRAW point +FACET region +SCALE x VIA date +``` + +**CTE with aggregation and date formatting:** +```sql +WITH monthly AS ( + SELECT 
DATE_TRUNC('month', order_date) as month, SUM(amount) as total + FROM orders GROUP BY 1 +) +VISUALISE month AS x, total AS y FROM monthly +DRAW line +DRAW point +SCALE x VIA date SETTING breaks => 'month' RENAMING * => '{:time %b %Y}' +LABEL y => 'Revenue ($)' +``` + +**Ribbon / confidence band:** +```sql +WITH daily AS ( + SELECT DATE_TRUNC('day', timestamp) as day, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM sensor_data + GROUP BY DATE_TRUNC('day', timestamp) +) +VISUALISE day AS x FROM daily +DRAW ribbon MAPPING min_temp AS ymin, max_temp AS ymax SETTING opacity => 0.3 +DRAW line MAPPING avg_temp AS y +SCALE x VIA date +LABEL y => 'Temperature' +``` + +**Text labels on bars:** +```sql +SELECT region, COUNT(*) AS n FROM sales GROUP BY region +VISUALISE region AS x, n AS y +DRAW bar +DRAW text MAPPING n AS label SETTING offset => (0, -11), fill => 'white' +``` + +**Lollipop chart:** +```sql +SELECT ROUND(bill_dep) AS bill_dep, COUNT(*) AS n FROM penguins GROUP BY 1 +VISUALISE bill_dep AS x, n AS y +DRAW segment MAPPING 0 AS yend +DRAW point +``` + +**Ridgeline / joy plot:** +```sql +VISUALISE temp AS x, month AS y FROM weather +DRAW violin SETTING width => 4, side => 'top' +SCALE ORDINAL y +``` + +**Donut chart:** +```sql +VISUALISE FROM products +DRAW bar MAPPING category AS fill +PROJECT TO polar SETTING inner => 0.5 +``` + +## Important Notes + +1. **Numeric columns as categories**: Integer columns representing categories (e.g., 0/1 `survived`) are treated as continuous by default, causing errors with `fill`, `color`, `shape`, and `FACET`. 
Fix by casting in SQL or declaring the scale: + ```sql + -- WRONG: integer fill without discrete scale — causes validation error + SELECT sex, survived FROM titanic + VISUALISE sex AS x, survived AS fill + DRAW bar + + -- CORRECT: cast to string in SQL (preferred) + SELECT sex, CAST(survived AS VARCHAR) AS survived FROM titanic + VISUALISE sex AS x, survived AS fill + DRAW bar + + -- ALSO CORRECT: declare the scale as discrete + SELECT sex, survived FROM titanic + VISUALISE sex AS x, survived AS fill + DRAW bar + SCALE DISCRETE fill + ``` +2. **Do not mix `VISUALISE FROM` with a preceding `SELECT`**: `VISUALISE FROM table` is shorthand that auto-generates `SELECT * FROM table`. If you already have a `SELECT`, use `SELECT ... VISUALISE` instead: + ```sql + -- WRONG: VISUALISE FROM after SELECT + SELECT * FROM titanic + VISUALISE FROM titanic + DRAW bar MAPPING class AS x + + -- CORRECT: use VISUALISE (without FROM) after SELECT + SELECT * FROM titanic + VISUALISE class AS x + DRAW bar + + -- ALSO CORRECT: use VISUALISE FROM without any SELECT + VISUALISE FROM titanic + DRAW bar MAPPING class AS x + ``` +3. **In querychat, all layers must come from the final SQL result**: Do not use layer-specific `FROM source` inside `DRAW ... MAPPING ...` clauses. If you need raw data and a summary in one chart, put both into one final relation and distinguish layers with a column such as `layer_type`: + ```sql + WITH raw AS ( + SELECT + date, + amount, + region, + 'raw' AS layer_type + FROM sales + ), + summary AS ( + SELECT + date, + AVG(amount) AS amount, + region, + 'summary' AS layer_type + FROM sales + GROUP BY date, region + ), + combined AS ( + SELECT * FROM raw + UNION ALL + SELECT * FROM summary + ) + SELECT * FROM combined + VISUALISE date AS x, amount AS y + DRAW point MAPPING region AS color FILTER layer_type = 'raw' + DRAW line MAPPING region AS color FILTER layer_type = 'summary' + ``` +4. 
**String values use single quotes**: In SETTING, LABEL, and RENAMING clauses, always use single quotes for string values. Double quotes cause parse errors. +5. **Column casing in VISUALISE**: DuckDB lowercases unquoted column names in query results, and VISUALISE validates column references **case-sensitively**. If your source table has uppercase column names (e.g., from Snowflake), you **must** alias them to lowercase in the SELECT clause: + ```sql + -- WRONG: VISUALISE references uppercase name, but DuckDB lowercases it in results + SELECT ROOM_TYPE, COUNT(*) AS listings FROM airbnb + VISUALISE ROOM_TYPE AS x, listings AS y + DRAW bar + + -- CORRECT: Alias to lowercase, then reference the alias + SELECT ROOM_TYPE AS room_type, COUNT(*) AS listings FROM airbnb + VISUALISE room_type AS x, listings AS y + DRAW bar + ``` + As a general rule, always use lowercase column names and aliases in both SELECT and VISUALISE clauses. +6. **Charts vs Tables**: For visualizations use VISUALISE with DRAW. For tabular data use plain SQL without VISUALISE. +7. **Statistical layers**: When using `histogram`, `bar` (without y), `density`, `smooth`, `violin`, or `boxplot`, the layer computes statistics. Use REMAPPING to access `density`, `intensity`, `proportion`, etc. +8. **No trailing commas**: SETTING, LABEL, MAPPING, and RENAMING clauses must not end with a trailing comma. A comma after the last item causes a parse error. + ```sql + -- WRONG: trailing comma after the last label + LABEL x => 'Gender', y => 'Count', + + -- CORRECT + LABEL x => 'Gender', y => 'Count' + ``` +9. **Bar position adjustments**: Bars stack automatically when `fill` is mapped. 
Use `SETTING position => 'dodge'` for side-by-side bars, or `position => 'stack', total => 1` for proportional (100%) stacking: + ```sql + DRAW bar MAPPING category AS x, subcategory AS fill -- stacked (default) + DRAW bar MAPPING category AS x, subcategory AS fill SETTING position => 'dodge' -- side-by-side + DRAW bar MAPPING category AS x, subcategory AS fill SETTING position => 'stack', total => 1 -- proportional + ``` diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 8c6ff97bc..2876bcf00 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -1,4 +1,4 @@ -You are a data dashboard chatbot that operates in a sidebar interface. Your role is to help users interact with their data through filtering, sorting, and answering questions. +You are a data dashboard chatbot that operates in a sidebar interface. Your role is to help users interact with their data through filtering, sorting, and answering questions.{{#has_tool_visualize}} You can also help them explore data visually.{{/has_tool_visualize}} You have access to a {{db_type}} SQL database with the following schema: @@ -118,11 +118,95 @@ Response: "The average revenue is $X." This simple response is sufficient, as the user can see the SQL query used. {{/has_tool_query}} +{{#has_tool_visualize}} +### Visualizing Data + +You can create visualizations using the `querychat_visualize` tool, which uses ggsql — a SQL extension for declarative data visualization. Write a ggsql query (SQL with a VISUALISE clause), and the tool executes the SQL, renders the VISUALISE clause as an Altair chart, and displays it inline in the chat. + +#### Visualization best practices + +The database schema in this prompt includes column names, types, and summary statistics. 
{{#has_tool_query}}If that context isn't sufficient for a confident visualization — e.g., you're unsure about value distributions, need to check for NULLs, or want to gauge row counts before choosing a chart type — use the `querychat_query` tool to inspect the data before visualizing. Always pass `collapsed=True` for these preparatory queries so the chart remains the focal point of the response.{{/has_tool_query}} + +Follow the principles below to produce clear, interpretable charts. + +#### Axis labels must be readable + +When the x-axis contains categorical labels (names, categories, long strings), prefer flipping axes with `PROJECT y, x TO cartesian` so labels read naturally left-to-right. Short numeric or date labels on the x-axis are fine as they are — the axis flip is only needed for text categories. + +#### Always include axis labels with units + +Charts should be interpretable without reading the surrounding prose. Always include axis labels that describe what is shown, including units when applicable (e.g., `LABEL y => 'Revenue ($M)'`, not just `LABEL y => 'Revenue'`). + +#### Maximize data-ink ratio + +Every visual element should serve a purpose: + +- Don't map columns to aesthetics (color, size, shape) unless the distinction is meaningful to the user's question. A single-series bar chart doesn't need color. +- When using color for categories, keep to 7 or fewer distinct values. Beyond that, consider filtering to the most important categories or using facets instead. +- Avoid dual-encoding the same variable (e.g., mapping the same column to both x-position and color) unless it genuinely aids interpretation. + +#### Avoid overplotting + +When a dataset has many rows, plotting one mark per row creates clutter that obscures patterns. Before generating a query, consider the row count and data characteristics visible in the schema. 
+ +**For large datasets (hundreds+ rows):** + +- **Aggregate first**: Use `GROUP BY` with `COUNT`, `AVG`, `SUM`, or other aggregates to reduce to meaningful summaries before visualizing. +- **Choose chart types that summarize naturally**: histograms for distributions, boxplots for group comparisons, line charts for trends over time. + +**For two numeric variables with many rows:** + +Bin in SQL and use `DRAW tile` to create a heatmap: + +```sql +WITH binned AS ( + SELECT ROUND(x_col / 5) * 5 AS x_bin, + ROUND(y_col / 5) * 5 AS y_bin, + COUNT(*) AS n + FROM large_table + GROUP BY x_bin, y_bin +) +SELECT * FROM binned +VISUALISE x_bin AS x, y_bin AS y, n AS fill +DRAW tile +SCALE fill TO viridis +``` + +**If individual points matter** (e.g., outlier detection): use `SETTING opacity` to reveal density through overlap. + +#### Choose chart types based on the data relationship + +Match the chart type to what the user is trying to understand: + +- **Comparison across categories**: bar chart (`DRAW bar`, with `PROJECT y, x TO cartesian` for long labels). Order bars by value, not alphabetically. +- **Trend over time**: line chart (`DRAW line`). Use `SCALE x VIA date` for date columns. +- **Distribution of a single variable**: histogram (`DRAW histogram`) or density (`DRAW density`). +- **Relationship between two numeric variables**: scatter plot (`DRAW point`), but prefer aggregation or heatmap if the dataset is large. +- **Part-of-whole**: stacked bar chart (map subcategory to `fill`). Avoid pie charts — position along a common scale is easier to decode than angle. + +#### ggsql syntax reference + + +{{> ggsql-syntax}} + +{{#has_tool_query}} + +**Avoid redundant expanded results.** If you run a preparatory query before visualizing, or if both a table and chart would show the same data, always pass `collapsed=True` on the query so the user sees the chart prominently, not a duplicate table above it. The user can still expand the table if they want the exact values. 
+{{/has_tool_query}} +{{/has_tool_visualize}} +{{^has_tool_visualize}} +### Visualization Requests + +You cannot create charts or visualizations. If users ask for a plot, chart, or visual representation of the data, explain that visualization is not currently enabled.{{#has_tool_query}} Offer to answer their question with a tabular query instead.{{/has_tool_query}} Suggest that the developer can enable visualization by installing `querychat[viz]` and adding `"visualize"` to the `tools` parameter. + +{{/has_tool_visualize}} {{^has_tool_query}} +{{^has_tool_visualize}} ### Questions About Data You cannot query or analyze the data. If users ask questions about data values, statistics, or calculations (e.g., "What is the average ____?" or "How many ____ are there?"), explain that you're not able to run queries on this data. Do not attempt to answer based on your own knowledge or assumptions about the data, even if the dataset seems familiar. +{{/has_tool_visualize}} {{/has_tool_query}} ### Providing Suggestions for Next Steps @@ -146,9 +230,16 @@ You might want to explore the advanced features **Nested lists:** ```md +{{#has_tool_query}} * Analyze the data * What's the average …? * How many …? 
+{{/has_tool_query}} +{{#has_tool_visualize}} +* Visualize the data + * Show a bar chart of … + * Plot the trend of … over time +{{/has_tool_visualize}} * Filter and sort * Show records from the year … * Sort the ____ by ____ … @@ -185,6 +276,7 @@ You might want to explore the advanced features - **Ask for clarification** if any request is unclear or ambiguous - **Be concise** due to the constrained interface - **Only answer data questions using your tools** - never use prior knowledge or assumptions about the data, even if the dataset seems familiar +- **Be skeptical of your own interpretations** - when describing chart results or data patterns, encourage the user to verify findings rather than presenting analytical conclusions as fact - **Use Markdown tables** for any tabular or structured data in your responses {{#extra_instructions}} diff --git a/pkg-py/src/querychat/prompts/tool-query.md b/pkg-py/src/querychat/prompts/tool-query.md index 0fcdec4b3..246cc90ee 100644 --- a/pkg-py/src/querychat/prompts/tool-query.md +++ b/pkg-py/src/querychat/prompts/tool-query.md @@ -25,6 +25,8 @@ Parameters ---------- query : A valid {{db_type}} SQL SELECT statement. Must follow the database schema provided in the system prompt. Use clear column aliases (e.g., 'AVG(price) AS avg_price') and include SQL comments for complex logic. Subqueries and CTEs are encouraged for readability. +collapsed : + Optional (default: false). Set to true for exploratory or preparatory queries (e.g., inspecting data before visualization, checking row counts, previewing column values) whose results aren't the primary answer. When true, the result card starts collapsed so it doesn't clutter the conversation. _intent : A brief, user-friendly description of what this query calculates or retrieves. 
diff --git a/pkg-py/src/querychat/prompts/tool-visualize.md b/pkg-py/src/querychat/prompts/tool-visualize.md new file mode 100644 index 000000000..89e46d20b --- /dev/null +++ b/pkg-py/src/querychat/prompts/tool-visualize.md @@ -0,0 +1,23 @@ +Create a data visualization + +Render a ggsql query (SQL with a VISUALISE clause) as an Altair chart displayed inline in the chat. + +**When to use:** Call this tool when the user's question involves comparisons, distributions, or trends — even for small result sets, a chart is often clearer than a table.{{#has_tool_query}} For single-value answers (averages, counts, totals, specific lookups) or when the user needs exact values, use `querychat_query` instead.{{/has_tool_query}} + +**Key constraints:** + +- All data transformations must happen in the SELECT clause — VISUALISE and MAPPING accept column names only, not SQL expressions or functions +- Do NOT include `LABEL title => ...` in the query — use the `title` parameter instead +- If a visualization fails, read the error message carefully and retry with a corrected query. Common fixes: correcting column names, adding `SCALE DISCRETE` for integer categories, using single quotes for strings, moving SQL expressions out of VISUALISE into the SELECT clause.{{#has_tool_query}} If the error persists, fall back to `querychat_query` for a tabular answer.{{/has_tool_query}} + +Parameters +---------- +ggsql : + A full ggsql query. Must include a VISUALISE clause and at least one DRAW clause. The SELECT portion uses {{db_type}} SQL; VISUALISE and MAPPING accept column names only, not expressions. Do NOT include `LABEL title => ...` in the query — use the `title` parameter instead. +title : + A brief, user-friendly title for this visualization. This is displayed as the card header above the chart. + +Returns +------- +: + If successful, a static image of the rendered plot. If not, an error message. 
diff --git a/pkg-py/src/querychat/static/css/viz.css b/pkg-py/src/querychat/static/css/viz.css new file mode 100644 index 000000000..1b5812bc1 --- /dev/null +++ b/pkg-py/src/querychat/static/css/viz.css @@ -0,0 +1,141 @@ +/* Hide Vega's built-in action dropdown (we have our own save button) */ +.querychat-viz-container details:has(> .vega-actions) { + display: none !important; +} + +/* ---- Visualization container ---- */ + +.querychat-viz-container { + aspect-ratio: 4 / 2; + width: 100%; +} + +/* In full-screen mode, let the chart fill the available space */ +.bslib-full-screen-container .querychat-viz-container { + aspect-ratio: unset; +} + +/* ---- Visualization footer ---- */ + +.querychat-footer-buttons { + display: flex; + justify-content: space-between; + align-items: center; +} + +.querychat-footer-left, +.querychat-footer-right { + display: flex; + align-items: center; + gap: 4px; +} + +.querychat-show-query-btn, +.querychat-save-btn { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 2px 8px; + height: 28px; + border: none; + border-radius: var(--bs-border-radius, 4px); + background: transparent; + color: var(--bs-secondary-color, #6c757d); + font-size: 0.75rem; + cursor: pointer; + white-space: nowrap; +} + +.querychat-show-query-btn:hover, +.querychat-save-btn:hover { + color: var(--bs-body-color, #212529); + background-color: rgba(var(--bs-emphasis-color-rgb, 0, 0, 0), 0.05); +} + +.querychat-query-chevron { + font-size: 0.625rem; + transition: transform 150ms; + display: inline-block; +} + +.querychat-query-chevron--expanded { + transform: rotate(90deg); +} + +.querychat-icon { + width: 14px; + height: 14px; +} + +.querychat-dropdown-chevron { + width: 12px; + height: 12px; + margin-left: 2px; +} + +.querychat-save-dropdown { + position: relative; +} + +.querychat-save-menu { + display: none; + position: absolute; + right: 0; + bottom: 100%; + margin-bottom: 4px; + z-index: 20; + background: var(--bs-body-bg, #fff); + border: 1px 
solid var(--bs-border-color, #dee2e6); + border-radius: var(--bs-border-radius, 4px); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15); + padding: 4px 0; + min-width: 120px; +} + +.querychat-save-menu--visible { + display: block; +} + +.querychat-save-menu button { + display: block; + width: 100%; + padding: 6px 12px; + border: none; + background: transparent; + color: var(--bs-body-color, #212529); + font-size: 0.75rem; + text-align: left; + cursor: pointer; +} + +.querychat-save-menu button:hover { + background-color: rgba(var(--bs-emphasis-color-rgb, 0, 0, 0), 0.05); +} + +.querychat-query-section { + display: none; + position: relative; + border-top: 1px solid var(--bs-border-color, #dee2e6); + margin: 8px -16px -8px; +} + +.querychat-query-section--visible { + display: block; +} + + +/* shinychat sets max-height:500px on all cards, which is too small for viz+editor */ +.shiny-tool-card:has(.querychat-viz-container) { + max-height: 700px; + overflow: hidden; +} + +.querychat-query-section bslib-code-editor .code-editor { + margin: 1em; +} + +.querychat-query-section bslib-code-editor .prism-code-editor { + background-color: var(--bs-light, #f8f8f8); + max-height: 200px; + overflow-y: auto; +} \ No newline at end of file diff --git a/pkg-py/src/querychat/static/js/viz-preload.js b/pkg-py/src/querychat/static/js/viz-preload.js new file mode 100644 index 000000000..c6ef5bf2e --- /dev/null +++ b/pkg-py/src/querychat/static/js/viz-preload.js @@ -0,0 +1,52 @@ +(() => { + // In Shiny apps, reveal the first `.querychat-viz-preload` element that appears + // and then stop watching the DOM. This is a one-time, page-level initialization: + // if a preload element already exists at startup, reveal it immediately; otherwise + // observe DOM mutations until one is added, then reveal it and disconnect. 
+ + if (!window.Shiny || window.__querychatVizPreloaded) return; + + let preloadObserver; + + const stopVizPreloadObserver = () => { + preloadObserver?.disconnect(); + preloadObserver = undefined; + }; + + const findVizPreload = (node) => { + if (!(node instanceof Element)) return null; + return node.matches(".querychat-viz-preload") + ? node + : node.querySelector(".querychat-viz-preload"); + }; + + const revealVizPreload = (root) => { + if (!root?.isConnected || window.__querychatVizPreloaded) return false; + + window.__querychatVizPreloaded = true; + root.hidden = false; + stopVizPreloadObserver(); + return true; + }; + + const processVizPreloads = (node) => { + const preloadRoot = findVizPreload(node); + if (!preloadRoot) return false; + return revealVizPreload(preloadRoot); + }; + + if (processVizPreloads(document.documentElement)) return; + + preloadObserver = new MutationObserver((mutations) => { + for (const mutation of mutations) { + for (const node of mutation.addedNodes) { + if (processVizPreloads(node)) return; + } + } + }); + + preloadObserver.observe(document.documentElement, { + childList: true, + subtree: true, + }); +})(); diff --git a/pkg-py/src/querychat/static/js/viz.js b/pkg-py/src/querychat/static/js/viz.js new file mode 100644 index 000000000..a04475173 --- /dev/null +++ b/pkg-py/src/querychat/static/js/viz.js @@ -0,0 +1,129 @@ +// Helper: find a native vega-embed action link inside a widget container. +// vega-embed renders a hidden
<details> element with <a> tags for "Save as SVG", +// "Save as PNG", etc. We find them by matching the download attribute suffix. +// +// Why not use the Vega View API (view.toSVG(), view.toImageURL()) directly? +// Altair renders charts via its anywidget ESM, which calls vegaEmbed() and +// stores the resulting View in a closure — it's never exposed on the DOM or +// any accessible object. vega-embed v7 also doesn't set __vega_embed__ on +// the element. The only code with access to the View is vega-embed's own +// action handlers, so we delegate to them. +function findVegaAction(container, extension) { + return container.querySelector( + '.vega-actions a[download$=".' + extension + '"]' + ); +} + +// Helper: find a widget container by its base ID. +// Shiny module namespacing may prefix the ID (e.g. "mod-querychat_viz_abc"), +// so we match elements whose ID ends with the base widget ID. +function findWidgetContainer(widgetId) { + return document.getElementById(widgetId) + || document.querySelector('[id$="' + CSS.escape(widgetId) + '"]'); +} + +// Helper: trigger a vega-embed export action link. +// vega-embed attaches an async mousedown handler that calls +// view.toImageURL() and sets the link's href to a data URL. +// We dispatch mousedown, then use a MutationObserver to detect +// when href changes from "#" to a data URL, and click the link. +function triggerVegaAction(link, filename) { + link.download = filename; + + // If href is already a data URL (unlikely but possible), click immediately. 
+ if (link.href && link.href !== "#" && !link.href.endsWith("#")) { + link.click(); + return; + } + + var observer = new MutationObserver(function () { + if (link.href && link.href !== "#" && !link.href.endsWith("#")) { + observer.disconnect(); + clearTimeout(timeout); + link.click(); + } + }); + + observer.observe(link, { attributes: true, attributeFilter: ["href"] }); + + var timeout = setTimeout(function () { + observer.disconnect(); + console.error("Timed out waiting for vega-embed to generate image"); + }, 5000); + + link.dispatchEvent(new MouseEvent("mousedown", { bubbles: true })); +} + +function closeAllSaveMenus() { + document.querySelectorAll(".querychat-save-menu--visible").forEach(function (menu) { + menu.classList.remove("querychat-save-menu--visible"); + }); +} + +function handleShowQuery(event, btn) { + event.stopPropagation(); + var targetId = btn.dataset.target; + var section = document.getElementById(targetId); + if (!section) return; + var isVisible = section.classList.toggle("querychat-query-section--visible"); + var label = btn.querySelector(".querychat-query-label"); + var chevron = btn.querySelector(".querychat-query-chevron"); + if (label) label.textContent = isVisible ? 
"Hide Query" : "Show Query"; + if (chevron) chevron.classList.toggle("querychat-query-chevron--expanded", isVisible); +} + +function handleSaveToggle(event, btn) { + event.stopPropagation(); + var menu = btn.parentElement.querySelector(".querychat-save-menu"); + if (menu) menu.classList.toggle("querychat-save-menu--visible"); +} + +function handleSaveExport(event, btn, extension) { + event.stopPropagation(); + var widgetId = btn.dataset.widgetId; + var title = btn.dataset.title || "chart"; + var menu = btn.closest(".querychat-save-menu"); + if (menu) menu.classList.remove("querychat-save-menu--visible"); + + var container = findWidgetContainer(widgetId); + if (!container) return; + var link = findVegaAction(container, extension); + if (!link) return; + triggerVegaAction(link, title + "." + extension); +} + +function handleCopy(event, btn) { + event.stopPropagation(); + var query = btn.dataset.query; + if (!query) return; + navigator.clipboard.writeText(query).then(function () { + var original = btn.textContent; + btn.textContent = "Copied!"; + setTimeout(function () { btn.textContent = original; }, 2000); + }).catch(function (err) { + console.error("Failed to copy:", err); + }); +} + +// Single delegated click handler for all querychat viz footer buttons. 
+window.addEventListener("click", function (event) { + var target = event.target; + + var btn = target.closest(".querychat-show-query-btn"); + if (btn) { handleShowQuery(event, btn); return; } + + btn = target.closest(".querychat-save-png-btn"); + if (btn) { handleSaveExport(event, btn, "png"); return; } + + btn = target.closest(".querychat-save-svg-btn"); + if (btn) { handleSaveExport(event, btn, "svg"); return; } + + btn = target.closest(".querychat-copy-btn"); + if (btn) { handleCopy(event, btn); return; } + + btn = target.closest(".querychat-save-btn"); + if (btn) { handleSaveToggle(event, btn); return; } + + // Click outside any button — close open save menus + closeAllSaveMenus(); +}); diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 67ea453f5..48a17b5cc 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -1,14 +1,26 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable -import chevron from chatlas import ContentToolResult, Tool from shinychat.types import ToolResultDisplay from ._icons import bs_icon -from ._utils import as_narwhals, df_to_html, querychat_tool_starts_open +from ._utils import ( + as_narwhals, + df_to_html, + querychat_tool_starts_open, + read_prompt_template, + truncate_error, +) +from ._viz_tools import tool_visualize + +__all__ = [ + "tool_query", + "tool_reset_dashboard", + "tool_update_dashboard", + "tool_visualize", +] if TYPE_CHECKING: from collections.abc import Callable @@ -69,13 +81,6 @@ def log_update(data: UpdateDashboardData): title: str -def _read_prompt_template(filename: str, **kwargs) -> str: - """Read and interpolate a prompt template file.""" - template_path = Path(__file__).parent / "prompts" / filename - template = template_path.read_text() - return chevron.render(template, kwargs) - - def _update_dashboard_impl( data_source: DataSource, update_fn: 
Callable[[UpdateDashboardData], None], @@ -103,9 +108,9 @@ def update_dashboard(query: str, title: str) -> ContentToolResult: update_fn({"query": query, "title": title}) except Exception as e: - error = str(e) + error = truncate_error(str(e)) markdown += f"\n\n> Error: {error}" - return ContentToolResult(value=markdown, error=e) + return ContentToolResult(value=markdown, error=Exception(error)) # Return ContentToolResult with display metadata return ContentToolResult( @@ -146,7 +151,7 @@ def tool_update_dashboard( """ impl = _update_dashboard_impl(data_source, update_fn) - description = _read_prompt_template( + description = read_prompt_template( "tool-update-dashboard.md", db_type=data_source.get_db_type(), ) @@ -212,7 +217,7 @@ def tool_reset_dashboard( """ impl = _reset_dashboard_impl(reset_fn) - description = _read_prompt_template("tool-reset-dashboard.md") + description = read_prompt_template("tool-reset-dashboard.md") impl.__doc__ = description return Tool.from_func( @@ -222,10 +227,14 @@ def tool_reset_dashboard( ) -def _query_impl(data_source: DataSource) -> Callable[[str, str], ContentToolResult]: +def _query_impl(data_source: DataSource) -> Callable[..., ContentToolResult]: """Create the implementation function for querying data.""" - def query(query: str, _intent: str = "") -> ContentToolResult: + def query( + query: str, + collapsed: bool | None = None, # noqa: FBT001 (LLM tool parameter) + _intent: str = "", + ) -> ContentToolResult: error = None markdown = f"```sql\n{query}\n```" value = None @@ -239,9 +248,9 @@ def query(query: str, _intent: str = "") -> ContentToolResult: markdown += "\n\n" + str(tbl_html) except Exception as e: - error = str(e) + error = truncate_error(str(e)) markdown += f"\n\n> Error: {error}" - return ContentToolResult(value=markdown, error=e) + return ContentToolResult(value=markdown, error=Exception(error)) # Return ContentToolResult with display metadata return ContentToolResult( @@ -250,7 +259,9 @@ def query(query: str, 
_intent: str = "") -> ContentToolResult: "display": ToolResultDisplay( markdown=markdown, show_request=False, - open=querychat_tool_starts_open("query"), + open=(not collapsed) + if collapsed is not None + else querychat_tool_starts_open("query"), icon=bs_icon("table"), ), }, @@ -276,7 +287,7 @@ def tool_query(data_source: DataSource) -> Tool: """ impl = _query_impl(data_source) - description = _read_prompt_template( + description = read_prompt_template( "tool-query.md", db_type=data_source.get_db_type() ) impl.__doc__ = description diff --git a/pkg-py/src/querychat/types/__init__.py b/pkg-py/src/querychat/types/__init__.py index f9a8163df..88b598326 100644 --- a/pkg-py/src/querychat/types/__init__.py +++ b/pkg-py/src/querychat/types/__init__.py @@ -9,6 +9,7 @@ from .._querychat_core import AppStateDict from .._shiny_module import ServerValues from .._utils import UnsafeQueryError +from .._viz_tools import VisualizeData, VisualizeResult from ..tools import UpdateDashboardData __all__ = ( @@ -22,4 +23,6 @@ "ServerValues", "UnsafeQueryError", "UpdateDashboardData", + "VisualizeData", + "VisualizeResult", ) diff --git a/pkg-py/tests/conftest.py b/pkg-py/tests/conftest.py new file mode 100644 index 000000000..95d586937 --- /dev/null +++ b/pkg-py/tests/conftest.py @@ -0,0 +1,32 @@ +"""Shared pytest fixtures for querychat unit tests.""" + +import polars as pl +import pytest + + +def _ggsql_render_works() -> bool: + """Check if ggsql.render_altair() is functional (build can be broken in some envs).""" + try: + import ggsql + + df = pl.DataFrame({"x": [1, 2], "y": [3, 4]}) + result = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + spec = result.to_dict() + return "$schema" in spec + except (ValueError, ImportError): + return False + + +_ggsql_available = _ggsql_render_works() + + +def pytest_collection_modifyitems(config, items): + """Auto-skip tests marked with @pytest.mark.ggsql when ggsql is broken.""" + if _ggsql_available: + return + skip = pytest.mark.skip( + 
reason="ggsql.render_altair() not functional (build environment issue)" + ) + for item in items: + if "ggsql" in item.keywords: + item.add_marker(skip) diff --git a/pkg-py/tests/playwright/apps/viz_bookmark_app.py b/pkg-py/tests/playwright/apps/viz_bookmark_app.py new file mode 100644 index 000000000..17c678797 --- /dev/null +++ b/pkg-py/tests/playwright/apps/viz_bookmark_app.py @@ -0,0 +1,25 @@ +"""Test app for viz bookmark restore: uses server-side bookmarking to avoid URL length limits.""" + +from querychat import QueryChat +from querychat.data import titanic + +from shiny import App, ui + +qc = QueryChat( + titanic(), + "titanic", + tools=("query", "visualize"), +) + + +def app_ui(request): + return ui.page_fillable( + qc.ui(), + ) + + +def server(input, output, session): + qc.server(enable_bookmarking=True) + + +app = App(app_ui, server, bookmark_store="server") diff --git a/pkg-py/tests/playwright/conftest.py b/pkg-py/tests/playwright/conftest.py index 6febfd4e8..961af01f3 100644 --- a/pkg-py/tests/playwright/conftest.py +++ b/pkg-py/tests/playwright/conftest.py @@ -592,3 +592,31 @@ def dash_cleanup(_thread, server): yield url finally: _stop_dash_server(server) + + +@pytest.fixture(scope="module") +def app_10_viz() -> Generator[str, None, None]: + """Start the 10-viz-app.py Shiny server for testing.""" + app_path = str(EXAMPLES_DIR / "10-viz-app.py") + + def start_factory(): + port = _find_free_port() + url = f"http://localhost:{port}" + return url, lambda: _start_shiny_app_threaded(app_path, port) + + def shiny_cleanup(_thread, server): + _stop_shiny_server(server) + + url, _thread, server = _start_server_with_retry( + start_factory, shiny_cleanup, timeout=30.0 + ) + try: + yield url + finally: + _stop_shiny_server(server) + + +@pytest.fixture +def chat_10_viz(page: Page) -> ChatControllerType: + """Create a ChatController for the 10-viz-app chat component.""" + return _create_chat_controller(page, "titanic") diff --git 
a/pkg-py/tests/playwright/test_10_viz_inline.py b/pkg-py/tests/playwright/test_10_viz_inline.py new file mode 100644 index 000000000..6bc746668 --- /dev/null +++ b/pkg-py/tests/playwright/test_10_viz_inline.py @@ -0,0 +1,119 @@ +""" +Playwright tests for inline visualization and fullscreen behavior. + +These tests verify that: +1. The visualize tool renders Altair charts inline in tool result cards +2. The fullscreen toggle button appears on visualization tool results +3. Fullscreen mode works (expand and collapse via button and Escape key) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + from shinychat.playwright import ChatController + + +class TestInlineVisualization: + """Tests for inline chart rendering in tool result cards.""" + + @pytest.fixture(autouse=True) + def setup( + self, page: Page, app_10_viz: str, chat_10_viz: ChatController + ) -> None: + """Navigate to the viz app before each test.""" + page.goto(app_10_viz) + page.wait_for_selector("shiny-chat-container", timeout=30000) + self.page = page + self.chat = chat_10_viz + + def test_viz_tool_renders_inline_chart(self) -> None: + """VIZ-INLINE: Visualization tool result contains an inline chart widget.""" + self.chat.set_user_input( + "Create a scatter plot of age vs fare for the titanic passengers" + ) + self.chat.send_user_input(method="click") + + # Wait for a tool result card with full-screen attribute (viz results have it) + tool_card = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_card).to_be_visible(timeout=90000) + + # The card should contain the viz container (Altair chart via shinywidgets) + viz_container = tool_card.locator(".querychat-viz-container") + expect(viz_container).to_be_visible(timeout=10000) + + def test_fullscreen_button_visible_on_viz_card(self) -> None: + """VIZ-FS-BTN: Fullscreen toggle 
button appears on visualization cards.""" + self.chat.set_user_input( + "Make a bar chart showing count of passengers by class" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_card = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_card).to_be_visible(timeout=90000) + + # Fullscreen toggle should be visible + fs_button = tool_card.locator(".tool-fullscreen-toggle") + expect(fs_button).to_be_visible() + + def test_fullscreen_toggle_expands_card(self) -> None: + """VIZ-FS-EXPAND: Clicking fullscreen button expands the card.""" + self.chat.set_user_input( + "Plot a histogram of passenger ages from the titanic data" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_result = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_result).to_be_visible(timeout=90000) + + # Click fullscreen toggle + fs_button = tool_result.locator(".tool-fullscreen-toggle") + fs_button.click() + + # The .shiny-tool-card inside should now have fullscreen attribute + card = tool_result.locator(".shiny-tool-card[fullscreen]") + expect(card).to_be_visible() + + def test_escape_closes_fullscreen(self) -> None: + """VIZ-FS-ESC: Pressing Escape closes fullscreen mode.""" + self.chat.set_user_input( + "Create a visualization of survival rate by passenger class" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_result = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_result).to_be_visible(timeout=90000) + + # Enter fullscreen + fs_button = tool_result.locator(".tool-fullscreen-toggle") + fs_button.click() + + card = tool_result.locator(".shiny-tool-card[fullscreen]") + expect(card).to_be_visible() + + # Press Escape + self.page.keyboard.press("Escape") + + # Fullscreen should be removed + expect(card).not_to_be_visible() + + def test_non_viz_tool_results_have_no_fullscreen(self) -> None: + 
"""VIZ-NO-FS: Non-visualization tool results don't have fullscreen.""" + self.chat.set_user_input("Show me passengers who survived") + self.chat.send_user_input(method="click") + + # Wait for a tool result (any) + tool_result = self.page.locator(".shiny-tool-result").first + expect(tool_result).to_be_visible(timeout=90000) + + # Non-viz tool results should NOT have fullscreen toggle + fs_results = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(fs_results).to_have_count(0) diff --git a/pkg-py/tests/playwright/test_11_viz_footer.py b/pkg-py/tests/playwright/test_11_viz_footer.py new file mode 100644 index 000000000..2cd586952 --- /dev/null +++ b/pkg-py/tests/playwright/test_11_viz_footer.py @@ -0,0 +1,154 @@ +""" +Playwright tests for visualization footer interactions (Show Query, Save dropdown). + +These tests verify the client-side JS behavior in viz.js: +1. Show Query toggle reveals/hides the query section +2. Save dropdown opens/closes on click +3. Clicking outside the Save dropdown closes it +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + from shinychat.playwright import ChatController + + +VIZ_PROMPT = "Use the visualize tool to create a scatter plot of age vs fare" +TOOL_RESULT_TIMEOUT = 90_000 + + +@pytest.fixture(autouse=True) +def _send_viz_prompt( + page: Page, app_10_viz: str, chat_10_viz: ChatController +) -> None: + """Navigate to the viz app and trigger a visualization before each test.""" + page.goto(app_10_viz) + page.wait_for_selector("shiny-chat-container", timeout=30_000) + + chat_10_viz.set_user_input(VIZ_PROMPT) + chat_10_viz.send_user_input(method="click") + + # Wait for the viz tool result card with fullscreen support + page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)").wait_for( + state="visible", timeout=TOOL_RESULT_TIMEOUT + ) + # Wait for the footer 
buttons to appear inside the card + page.locator(".querychat-footer-buttons").wait_for( + state="visible", timeout=10_000 + ) + + +class TestShowQueryToggle: + """Tests for the Show Query / Hide Query toggle button.""" + + def test_query_section_hidden_by_default(self, page: Page) -> None: + """The query section should be hidden initially.""" + section = page.locator(".querychat-query-section") + expect(section).to_be_attached() + expect(section).not_to_be_visible() + + def test_click_show_query_reveals_section(self, page: Page) -> None: + """Clicking 'Show Query' should reveal the query section.""" + btn = page.locator(".querychat-show-query-btn") + btn.click() + + section = page.locator(".querychat-query-section--visible") + expect(section).to_be_visible() + + def test_label_changes_to_hide_query(self, page: Page) -> None: + """After clicking, the label should change to 'Hide Query'.""" + btn = page.locator(".querychat-show-query-btn") + label = btn.locator(".querychat-query-label") + + expect(label).to_have_text("Show Query") + btn.click() + expect(label).to_have_text("Hide Query") + + def test_chevron_rotates_on_expand(self, page: Page) -> None: + """The chevron should get the --expanded class when query is shown.""" + btn = page.locator(".querychat-show-query-btn") + chevron = btn.locator(".querychat-query-chevron") + + expect(chevron).not_to_have_class("querychat-query-chevron--expanded") + btn.click() + expect(chevron).to_have_class("querychat-query-chevron querychat-query-chevron--expanded") + + def test_toggle_hides_section_again(self, page: Page) -> None: + """Clicking the button a second time should hide the query section.""" + btn = page.locator(".querychat-show-query-btn") + btn.click() # show + btn.click() # hide + + section = page.locator(".querychat-query-section") + expect(section).not_to_have_class("querychat-query-section--visible") + + label = btn.locator(".querychat-query-label") + expect(label).to_have_text("Show Query") + + def 
test_query_section_contains_code(self, page: Page) -> None: + """The revealed query section should contain the ggsql code.""" + btn = page.locator(".querychat-show-query-btn") + btn.click() + + section = page.locator(".querychat-query-section--visible") + expect(section).to_be_visible() + + # The code editor should contain VISUALISE (ggsql keyword) + code = section.locator(".code-editor") + expect(code).to_be_visible() + + +class TestSaveDropdown: + """Tests for the Save button dropdown menu.""" + + def test_save_menu_hidden_by_default(self, page: Page) -> None: + """The save dropdown menu should be hidden initially.""" + menu = page.locator(".querychat-save-menu") + expect(menu).to_be_attached() + expect(menu).not_to_be_visible() + + def test_click_save_opens_menu(self, page: Page) -> None: + """Clicking the Save button should reveal the dropdown menu.""" + btn = page.locator(".querychat-save-btn") + btn.click() + + menu = page.locator(".querychat-save-menu--visible") + expect(menu).to_be_visible() + + def test_menu_has_png_and_svg_options(self, page: Page) -> None: + """The save menu should contain 'Save as PNG' and 'Save as SVG' options.""" + btn = page.locator(".querychat-save-btn") + btn.click() + + menu = page.locator(".querychat-save-menu--visible") + expect(menu.locator(".querychat-save-png-btn")).to_be_visible() + expect(menu.locator(".querychat-save-svg-btn")).to_be_visible() + + def test_click_outside_closes_menu(self, page: Page) -> None: + """Clicking outside the dropdown should close it.""" + btn = page.locator(".querychat-save-btn") + btn.click() + + menu = page.locator(".querychat-save-menu") + expect(menu).to_have_class("querychat-save-menu querychat-save-menu--visible") + + # Click somewhere else on the page body + page.locator("body").click(position={"x": 10, "y": 10}) + + expect(menu).not_to_have_class("querychat-save-menu--visible") + + def test_toggle_save_menu(self, page: Page) -> None: + """Clicking Save twice should open then close the 
menu.""" + btn = page.locator(".querychat-save-btn") + btn.click() + menu = page.locator(".querychat-save-menu") + expect(menu).to_have_class("querychat-save-menu querychat-save-menu--visible") + + btn.click() + expect(menu).not_to_have_class("querychat-save-menu--visible") diff --git a/pkg-py/tests/playwright/test_12_viz_bookmark.py b/pkg-py/tests/playwright/test_12_viz_bookmark.py new file mode 100644 index 000000000..7683b4355 --- /dev/null +++ b/pkg-py/tests/playwright/test_12_viz_bookmark.py @@ -0,0 +1,136 @@ +""" +Playwright tests for visualization bookmark restore behavior. + +These tests verify that when a user creates a visualization and then +restores from a bookmark URL, the chart widget is properly re-rendered +(not just the empty HTML shell). +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from collections.abc import Generator + + from playwright.sync_api import BrowserContext, Page + from shinychat.playwright import ChatController as ChatControllerType + +import sys + +# conftest.py is not importable directly; add the test directory to sys.path +sys.path.insert(0, str(Path(__file__).parent)) +from conftest import ( + _create_chat_controller, + _find_free_port, + _start_server_with_retry, + _start_shiny_app_threaded, + _stop_shiny_server, +) + +VIZ_PROMPT = "Use the visualize tool to create a scatter plot of age vs fare" +TOOL_RESULT_TIMEOUT = 90_000 +APPS_DIR = Path(__file__).parent / "apps" + + +@pytest.fixture(scope="module") +def app_viz_bookmark() -> Generator[str, None, None]: + """Start the viz bookmark test app with server-side bookmarking.""" + app_path = str(APPS_DIR / "viz_bookmark_app.py") + + def start_factory(): + port = _find_free_port() + url = f"http://localhost:{port}" + return url, lambda: _start_shiny_app_threaded(app_path, port) + + def shiny_cleanup(_thread, server): + _stop_shiny_server(server) 
+ + url, _thread, server = _start_server_with_retry( + start_factory, shiny_cleanup, timeout=30.0 + ) + try: + yield url + finally: + _stop_shiny_server(server) + + +@pytest.fixture +def chat_viz_bookmark(page: Page) -> ChatControllerType: + return _create_chat_controller(page, "titanic") + + +class TestVizBookmarkRestore: + """Tests for visualization restoration from bookmark URLs.""" + + @pytest.fixture(autouse=True) + def setup( + self, page: Page, app_viz_bookmark: str, chat_viz_bookmark: ChatControllerType + ) -> None: + """Navigate to the viz app and create a viz before each test.""" + self.app_url = app_viz_bookmark + self.page = page + self.chat = chat_viz_bookmark + + page.goto(app_viz_bookmark) + page.wait_for_selector("shiny-chat-container", timeout=30_000) + + # Wait for the greeting bookmark URL to be set first + # (bookmark_on="response" auto-bookmarks after greeting) + page.wait_for_function( + "() => window.location.search.includes('_state_id_=')", + timeout=30_000, + ) + self.greeting_url = page.url + + # Create a visualization + chat_viz_bookmark.set_user_input(VIZ_PROMPT) + chat_viz_bookmark.send_user_input(method="click") + + # Wait for the viz tool result to fully render + page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)").wait_for( + state="visible", timeout=TOOL_RESULT_TIMEOUT + ) + page.locator(".querychat-footer-buttons").wait_for( + state="visible", timeout=10_000 + ) + + def _wait_for_viz_bookmark_url(self) -> str: + """Wait for the URL to update from the greeting bookmark to the viz bookmark.""" + greeting_search = self.greeting_url.split("?", 1)[1] if "?" in self.greeting_url else "" + self.page.wait_for_function( + "(greetingSearch) => window.location.search.includes('_state_id_=') && window.location.search !== '?' 
+ greetingSearch", + arg=greeting_search, + timeout=30_000, + ) + return self.page.url + + def test_bookmark_url_updates_after_viz(self) -> None: + """BOOKMARK-VIZ-URL: URL should update with new state ID after viz is created.""" + url = self._wait_for_viz_bookmark_url() + assert url != self.greeting_url, "URL should have changed after viz bookmarking" + + def test_viz_widget_renders_on_bookmark_restore(self, context: BrowserContext) -> None: + """BOOKMARK-VIZ-RESTORE: Restored bookmark should re-render the chart widget, not just the HTML shell.""" + bookmark_url = self._wait_for_viz_bookmark_url() + + # Open the bookmark URL in a new page (new session) + new_page = context.new_page() + new_page.goto(bookmark_url) + new_page.wait_for_selector("shiny-chat-container", timeout=30_000) + + # The viz container HTML should be restored (shinychat restores message HTML) + viz_container = new_page.locator(".querychat-viz-container") + expect(viz_container).to_be_visible(timeout=30_000) + + # The critical check: the widget should actually render a chart, + # not just be an empty output_widget div. A rendered Vega-Lite chart + # will have a canvas or SVG inside a .vega-embed container. 
+ chart_element = viz_container.locator("canvas, svg, .vega-embed") + expect(chart_element.first).to_be_visible(timeout=30_000) + + new_page.close() diff --git a/pkg-py/tests/test_deferred_shiny.py b/pkg-py/tests/test_deferred_shiny.py index 39899a772..96ba29656 100644 --- a/pkg-py/tests/test_deferred_shiny.py +++ b/pkg-py/tests/test_deferred_shiny.py @@ -2,6 +2,7 @@ import os +import chatlas import pandas as pd import pytest from chatlas import ChatOpenAI @@ -95,8 +96,9 @@ def spy_create_client(client_spec): with session_context(ExpressStubSession()): vals = qc.server(data_source=sample_df, client=override_client) - assert vals.client is None - assert recorded_specs == [] + assert isinstance(vals.client, chatlas.Chat) + assert len(recorded_specs) == 1 + assert recorded_specs[0] is override_client assert qc._client_spec is init_client def test_multiple_server_overrides_do_not_leak_into_shared_state(self, sample_df): diff --git a/pkg-py/tests/test_ggsql.py b/pkg-py/tests/test_ggsql.py new file mode 100644 index 000000000..d49a821e7 --- /dev/null +++ b/pkg-py/tests/test_ggsql.py @@ -0,0 +1,234 @@ +"""Tests for ggsql integration helpers.""" + +import ggsql +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from querychat._datasource import DataFrameSource +from querychat._viz_altair_widget import AltairWidget +from querychat._viz_ggsql import ( + execute_ggsql, + extract_visualise_table, + has_layer_level_source, +) + + +class TestExtractVisualiseTable: + """Tests for extract_visualise_table() parsing.""" + + def test_bare_identifier(self): + assert extract_visualise_table("VISUALISE x, y FROM mytable DRAW point") == "mytable" + + def test_quoted_identifier(self): + assert ( + extract_visualise_table('VISUALISE x FROM "my table" DRAW point') + == '"my table"' + ) + + def test_no_from_returns_none(self): + assert extract_visualise_table("VISUALISE x, y DRAW point") is None + + def test_ignores_draw_level_from(self): + visual = "VISUALISE x, y DRAW 
bar MAPPING z AS fill FROM summary" + assert extract_visualise_table(visual) is None + + +class TestHasLayerLevelSource: + def test_detects_draw_level_from(self): + visual = "VISUALISE x, y DRAW bar MAPPING z AS fill FROM summary" + assert has_layer_level_source(visual) + + def test_ignores_visualise_from(self): + visual = "VISUALISE x, y FROM sales DRAW point MAPPING z AS color" + assert not has_layer_level_source(visual) + + def test_ignores_scale_from(self): + visual = "VISUALISE x, y DRAW point MAPPING z AS color SCALE x FROM [0, 10]" + assert not has_layer_level_source(visual) + + +class TestGgsqlValidate: + """Tests for ggsql.validate() usage (split SQL and VISUALISE).""" + + def test_splits_query_with_visualise(self): + query = "SELECT x, y FROM data VISUALISE x, y DRAW point" + validated = ggsql.validate(query) + assert validated.sql() == "SELECT x, y FROM data" + assert validated.visual() == "VISUALISE x, y DRAW point" + assert validated.has_visual() + + def test_returns_empty_viz_without_visualise(self): + query = "SELECT x, y FROM data" + validated = ggsql.validate(query) + assert validated.sql() == "SELECT x, y FROM data" + assert validated.visual() == "" + assert not validated.has_visual() + + def test_handles_complex_query(self): + query = """ + SELECT date, SUM(revenue) as total + FROM sales + GROUP BY date + VISUALISE date AS x, total AS y + DRAW line + LABEL title => 'Revenue Over Time' + """ + validated = ggsql.validate(query) + assert "SELECT date, SUM(revenue)" in validated.sql() + assert "GROUP BY date" in validated.sql() + assert "VISUALISE date AS x" in validated.visual() + assert "LABEL title" in validated.visual() + + + +@pytest.fixture(autouse=True) +def _allow_widget_outside_session(monkeypatch): + """Allow JupyterChart (an ipywidget) to be constructed without a Shiny session.""" + from ipywidgets.widgets.widget import Widget + + monkeypatch.setattr(Widget, "_widget_construction_callback", lambda _w: None) + + +class TestAltairWidget: + 
@pytest.mark.ggsql + def test_produces_jupyter_chart(self): + import altair as alt + import ggsql + + reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + reader.register("data", df) + spec = reader.execute("SELECT * FROM data VISUALISE x, y DRAW point") + altair_widget = AltairWidget.from_ggsql(spec) + assert isinstance(altair_widget.widget, alt.JupyterChart) + result = altair_widget.widget.chart.to_dict() + assert "$schema" in result + assert "vega-lite" in result["$schema"] + + +class TestExecuteGgsql: + @pytest.mark.ggsql + def test_full_pipeline(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, ggsql.validate(query)) + altair_widget = AltairWidget.from_ggsql(spec) + result = altair_widget.widget.chart.to_dict() + assert "$schema" in result + + @pytest.mark.ggsql + def test_with_filtered_query(self): + nw_df = nw.from_native( + pl.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]}) + ) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data WHERE x > 2 VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, ggsql.validate(query)) + assert spec.metadata()["rows"] == 3 + + @pytest.mark.ggsql + def test_spec_has_visual(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2], "y": [3, 4]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, ggsql.validate(query)) + assert "VISUALISE" in spec.visual() + + @pytest.mark.ggsql + def test_visualise_from_path(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "VISUALISE x, y FROM test_data DRAW point" + spec = execute_ggsql(ds, ggsql.validate(query)) + assert spec.metadata()["rows"] == 3 + assert "VISUALISE" in 
spec.visual() + + @pytest.mark.ggsql + def test_with_pandas_dataframe(self): + import pandas as pd + + nw_df = nw.from_native(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, ggsql.validate(query)) + altair_widget = AltairWidget.from_ggsql(spec) + result = altair_widget.widget.chart.to_dict() + assert "$schema" in result + + @pytest.mark.ggsql + def test_rejects_layer_level_from_sources_with_clear_error(self): + nw_df = nw.from_native( + pl.DataFrame( + { + "date": ["2024-01", "2024-01", "2024-02", "2024-02"], + "region": ["north", "south", "north", "south"], + "amount": [10, 20, 30, 40], + } + ) + ) + ds = DataFrameSource(nw_df, "sales") + query = """ + WITH summary AS ( + SELECT region, SUM(amount) AS total + FROM sales + GROUP BY region + ) + SELECT * + FROM sales + VISUALISE date AS x, amount AS y + DRAW line + DRAW bar MAPPING region AS x, total AS y FROM summary + """ + + with pytest.raises( + ValueError, + match="Layer-specific sources are not currently supported", + ): + execute_ggsql(ds, ggsql.validate(query)) + + @pytest.mark.ggsql + def test_supports_single_relation_raw_plus_summary_overlay(self): + nw_df = nw.from_native( + pl.DataFrame( + { + "x": [1, 1, 2, 2], + "y": [10, 20, 30, 40], + "category": ["a", "b", "a", "b"], + } + ) + ) + ds = DataFrameSource(nw_df, "sales") + query = """ + WITH raw AS ( + SELECT + x, + y, + category, + 'raw' AS layer_type + FROM sales + ), + summary AS ( + SELECT + x, + AVG(y) AS y, + category, + 'summary' AS layer_type + FROM sales + GROUP BY x, category + ), + combined AS ( + SELECT * FROM raw + UNION ALL + SELECT * FROM summary + ) + SELECT * + FROM combined + VISUALISE x AS x, y AS y + DRAW point MAPPING category AS color FILTER layer_type = 'raw' + DRAW line MAPPING category AS color FILTER layer_type = 'summary' + """ + + spec = execute_ggsql(ds, ggsql.validate(query)) + assert 
spec.metadata()["rows"] == 4 + assert "VISUALISE" in spec.visual() diff --git a/pkg-py/tests/test_shiny_viz_regressions.py b/pkg-py/tests/test_shiny_viz_regressions.py new file mode 100644 index 000000000..ab3e2babe --- /dev/null +++ b/pkg-py/tests/test_shiny_viz_regressions.py @@ -0,0 +1,387 @@ +"""Regression tests for Shiny ggsql tool wiring and bookmark restore.""" + +import inspect +import os +from types import SimpleNamespace +from unittest.mock import patch + +import chatlas +import pytest +from querychat import QueryChat +from querychat._shiny import QueryChatExpress +from querychat._shiny_module import mod_server +from querychat.data import tips + +from shiny import reactive + + +@pytest.fixture(autouse=True) +def set_dummy_api_key(): + old_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "sk-dummy-api-key-for-testing" + yield + if old_api_key is not None: + os.environ["OPENAI_API_KEY"] = old_api_key + else: + del os.environ["OPENAI_API_KEY"] + + +@pytest.fixture +def sample_df(): + return tips() + + +def _identity(fn): + return fn + + +def _event(*_args, **_kwargs): + def wrapper(fn): + return fn + + return wrapper + + +def _raw_mod_server(): + return inspect.getclosurevars(mod_server).nonlocals["fn"] + + +class DummyBookmark: + def on_bookmark(self, fn): + self.bookmark_fn = fn + return fn + + def on_restore(self, fn): + self.restore_fn = fn + return fn + + +class DummySession: + def __init__(self): + self.bookmark = DummyBookmark() + + def is_stub_session(self): + return False + + +class DummyStubSession(DummySession): + def is_stub_session(self): + return True + + +class DummyChatUi: + def __init__(self, *_args, **_kwargs): + pass + + def on_user_submit(self, fn): + return fn + + async def append_message_stream(self, _stream): + return None + + async def append_message(self, _message): + return None + + def enable_bookmarking(self, _chat): + return None + + +class DummyProvider(chatlas.Provider): + def __init__(self, *, name, 
model): + super().__init__(name=name, model=model) + + def list_models(self): + return [] + + def chat_perform(self, *, stream, turns, tools, data_model, kwargs): + return () if stream else SimpleNamespace() + + async def chat_perform_async( + self, *, stream, turns, tools, data_model, kwargs + ): + return () if stream else SimpleNamespace() + + def stream_content(self, chunk): + return None + + def stream_text(self, chunk): + return None + + def stream_merge_chunks(self, completion, chunk): + return completion or {} + + def stream_turn(self, completion, has_data_model): + return SimpleNamespace() + + def value_turn(self, completion, has_data_model): + return SimpleNamespace() + + def value_tokens(self, completion): + return (0, 0, 0) + + def token_count(self, *args, tools, data_model): + return 0 + + async def token_count_async(self, *args, tools, data_model): + return 0 + + def translate_model_params(self, params): + return params + + def supported_model_params(self): + return set() + + +def test_app_passes_callable_client_to_mod_server(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + app = qc.app() + captured = {} + + def fake_mod_server(*args, **kwargs): + captured.update(kwargs) + vals = SimpleNamespace() + vals.title = lambda: None + vals.sql = lambda: None + vals.df = list + vals.title.set = lambda _value: None + vals.sql.set = lambda _value: None + return vals + + with ( + patch("querychat._shiny.mod_server", fake_mod_server), + patch("querychat._shiny.render.text", _identity), + patch("querychat._shiny.render.ui", _identity), + patch("querychat._shiny.render.data_frame", _identity), + patch("querychat._shiny.reactive.effect", _identity), + patch("querychat._shiny.reactive.event", _event), + patch("querychat._shiny.req", lambda value: value), + patch("querychat._shiny.output_markdown_stream", lambda *a, **k: None), + ): + app.server( + SimpleNamespace(reset_query=lambda: None), + SimpleNamespace(), + SimpleNamespace(), + ) + + 
assert callable(captured["client"]) + assert not isinstance(captured["client"], chatlas.Chat) + + +def test_express_passes_callable_client_to_mod_server(sample_df, monkeypatch): + captured = {} + + class CurrentSession: + pass + + monkeypatch.setattr("querychat._shiny.get_current_session", lambda: CurrentSession()) + monkeypatch.setattr( + "querychat._shiny.mod_server", + lambda *args, **kwargs: captured.update(kwargs) or SimpleNamespace(), + ) + + QueryChatExpress( + sample_df, + "tips", + tools=("query", "visualize"), + enable_bookmarking=False, + ) + + assert callable(captured["client"]) + assert not isinstance(captured["client"], chatlas.Chat) + + +def test_server_passes_callable_client_to_mod_server(sample_df, monkeypatch): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + captured = {} + + class CurrentSession: + pass + + monkeypatch.setattr("querychat._shiny.get_current_session", lambda: CurrentSession()) + monkeypatch.setattr( + "querychat._shiny.mod_server", + lambda *args, **kwargs: captured.update(kwargs) or SimpleNamespace(), + ) + + qc.server(enable_bookmarking=False) + + assert callable(captured["client"]) + assert not isinstance(captured["client"], chatlas.Chat) + + +def test_mod_server_rejects_raw_chat_instance(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + raw_chat = chatlas.Chat(provider=DummyProvider(name="dummy", model="dummy")) + + with ( + patch("querychat._shiny_module.preload_viz_deps_server", lambda: None), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + pytest.raises(TypeError, match="callable"), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummySession(), + data_source=qc.data_source, + greeting=qc.greeting, + client=raw_chat, + enable_bookmarking=False, + tools=qc.tools, + ) + + +def test_mod_server_stub_session_deferred_client_factory_does_not_raise(): + qc = QueryChat(None, "users") + + vals = _raw_mod_server()( + 
SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummyStubSession(), + data_source=None, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + + with pytest.raises(RuntimeError, match="unavailable during stub session"): + _ = vals.client.stream_async + + +def test_callable_mod_server_passes_visualize_callback_and_tools(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + captured = {} + + def client_factory(**kwargs): + captured.update(kwargs) + return qc.client(**kwargs) + + with ( + patch("querychat._shiny_module.preload_viz_deps_server", lambda: None), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummySession(), + data_source=qc.data_source, + greeting=qc.greeting, + client=client_factory, + enable_bookmarking=False, + tools=qc.tools, + ) + + assert captured["tools"] == ("query", "visualize") + assert callable(captured["visualize"]) + assert callable(captured["update_dashboard"]) + assert callable(captured["reset_dashboard"]) + + +def test_mod_server_preloads_viz_for_each_real_session_instance(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + session = DummySession() + preload_calls = [] + + with ( + patch( + "querychat._shiny_module.preload_viz_deps_server", + lambda: preload_calls.append("called"), + ), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + session, + data_source=qc.data_source, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + session, + data_source=qc.data_source, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + + assert preload_calls == 
["called", "called"] + + +def test_mod_server_stub_session_does_not_preload_viz(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + preload_calls = [] + + with ( + patch( + "querychat._shiny_module.preload_viz_deps_server", + lambda: preload_calls.append("called"), + ), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummyStubSession(), + data_source=qc.data_source, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + + assert preload_calls == [] + + +def test_restored_viz_widgets_survive_second_bookmark_cycle(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize")) + callbacks = {} + session = DummySession() + + def client_factory(**kwargs): + callbacks.update(kwargs) + return qc.client(**kwargs) + + with ( + patch("querychat._shiny_module.preload_viz_deps_server", lambda: None), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + patch( + "querychat._shiny_module.restore_viz_widgets", + lambda _data_source, saved_widgets: list(saved_widgets), + ), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + session, + data_source=qc.data_source, + greeting=qc.greeting, + client=client_factory, + enable_bookmarking=True, + tools=qc.tools, + ) + saved = [ + { + "widget_id": "querychat_viz_1", + "ggsql": "SELECT 1 VISUALISE 1 AS x DRAW point", + } + ] + callbacks["visualize"](saved[0]) + + first_bookmark = SimpleNamespace(values={}) + with reactive.isolate(): + session.bookmark.bookmark_fn(first_bookmark) + assert first_bookmark.values["querychat_viz_widgets"] == saved + + with reactive.isolate(): + session.bookmark.restore_fn(SimpleNamespace(values=first_bookmark.values)) + + second_bookmark = SimpleNamespace(values={}) + with reactive.isolate(): + session.bookmark.bookmark_fn(second_bookmark) + assert 
second_bookmark.values["querychat_viz_widgets"] == saved diff --git a/pkg-py/tests/test_system_prompt.py b/pkg-py/tests/test_system_prompt.py index 64b64c9b7..976362045 100644 --- a/pkg-py/tests/test_system_prompt.py +++ b/pkg-py/tests/test_system_prompt.py @@ -298,3 +298,109 @@ def test_schema_computed_for_conditional_section(self, sample_data_source): ) assert prompt.schema != "" + + +class TestVizPromptConditionals: + """Tests for visualization-related conditional rendering in the real prompt.""" + + def test_graceful_recovery_fallback_excluded_without_query_tool( + self, sample_data_source + ): + """ + When only visualize is enabled (no query tool), the fallback + to querychat_query should not appear in the rendered prompt. + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered = prompt.render(tools=("update", "visualize")) + + assert "fall back to" not in rendered + + def test_collapsed_guidance_included_with_both_tools( + self, sample_data_source + ): + """ + When both query and visualize are enabled, the collapsed query + guidance should appear in the system prompt. + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered = prompt.render(tools=("update", "query", "visualize")) + + assert "Avoid redundant expanded results" in rendered + + def test_viz_only_has_no_cannot_query_message(self, sample_data_source): + """ + When only visualize is enabled (no query tool), the rendered prompt + should NOT contain "cannot query or analyze" and SHOULD contain + "Visualizing Data". 
+ """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered = prompt.render(tools=("visualize",)) + + assert "cannot query or analyze" not in rendered + assert "Visualizing Data" in rendered + + def test_collapsed_guidance_only_with_both_tools(self, sample_data_source): + """ + The "Avoid redundant expanded results" guidance should only appear + when both query and visualize are enabled. + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered_both = prompt.render(tools=("query", "visualize")) + rendered_query_only = prompt.render(tools=("query",)) + rendered_viz_only = prompt.render(tools=("visualize",)) + + assert "Avoid redundant expanded results" in rendered_both + assert "Avoid redundant expanded results" not in rendered_query_only + assert "Avoid redundant expanded results" not in rendered_viz_only diff --git a/pkg-py/tests/test_tools.py b/pkg-py/tests/test_tools.py index 682f259cf..887cef548 100644 --- a/pkg-py/tests/test_tools.py +++ b/pkg-py/tests/test_tools.py @@ -2,7 +2,52 @@ import warnings +import narwhals.stable.v1 as nw +import pandas as pd +import pytest +from querychat._datasource import DataFrameSource from querychat._utils import querychat_tool_starts_open +from querychat.tools import _query_impl + + +@pytest.fixture +def data_source(): + df = nw.from_native(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + return DataFrameSource(df, "test_table") + + +class TestQueryCollapsedParameter: + """Tests for the query tool's collapsed parameter.""" + + def test_collapsed_true_sets_open_false(self, data_source, monkeypatch): + 
monkeypatch.delenv("QUERYCHAT_TOOL_DETAILS", raising=False) + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=True) + assert result.extra["display"].open is False + + def test_collapsed_false_sets_open_true(self, data_source, monkeypatch): + monkeypatch.delenv("QUERYCHAT_TOOL_DETAILS", raising=False) + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=False) + assert result.extra["display"].open is True + + def test_collapsed_none_falls_back_to_default(self, data_source, monkeypatch): + monkeypatch.delenv("QUERYCHAT_TOOL_DETAILS", raising=False) + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table") + assert result.extra["display"].open is True # default for query + + def test_collapsed_overrides_env_expanded(self, data_source, monkeypatch): + monkeypatch.setenv("QUERYCHAT_TOOL_DETAILS", "expanded") + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=True) + assert result.extra["display"].open is False + + def test_collapsed_overrides_env_collapsed(self, data_source, monkeypatch): + monkeypatch.setenv("QUERYCHAT_TOOL_DETAILS", "collapsed") + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=False) + assert result.extra["display"].open is True def test_querychat_tool_starts_open_default_behavior(monkeypatch): @@ -12,6 +57,7 @@ def test_querychat_tool_starts_open_default_behavior(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize") is True def test_querychat_tool_starts_open_expanded(monkeypatch): @@ -21,6 +67,7 @@ def test_querychat_tool_starts_open_expanded(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert 
querychat_tool_starts_open("reset") is True + assert querychat_tool_starts_open("visualize") is True def test_querychat_tool_starts_open_collapsed(monkeypatch): @@ -30,6 +77,7 @@ def test_querychat_tool_starts_open_collapsed(monkeypatch): assert querychat_tool_starts_open("query") is False assert querychat_tool_starts_open("update") is False assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize") is False def test_querychat_tool_starts_open_default_setting(monkeypatch): @@ -39,6 +87,7 @@ def test_querychat_tool_starts_open_default_setting(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize") is True def test_querychat_tool_starts_open_case_insensitive(monkeypatch): diff --git a/pkg-py/tests/test_truncate_error.py b/pkg-py/tests/test_truncate_error.py new file mode 100644 index 000000000..57b2db169 --- /dev/null +++ b/pkg-py/tests/test_truncate_error.py @@ -0,0 +1,52 @@ +"""Tests for truncate_error.""" + +from querychat._utils import truncate_error + + +class TestTruncateError: + def test_short_message_unchanged(self): + msg = "Column 'foo' not found" + assert truncate_error(msg) == msg + + def test_empty_string(self): + assert truncate_error("") == "" + + def test_short_message_with_blank_line_unchanged(self): + msg = "line1\n\nline2" + assert truncate_error(msg) == msg + + def test_truncates_at_blank_line(self): + msg = "Something went wrong\n\n" + "x" * 500 + result = truncate_error(msg) + assert result == "Something went wrong\n\n(error truncated)" + + def test_truncates_at_schema_dump_line(self): + msg = "Bad property\nFailed validating 'additionalProperties' in schema[0]:\n" + "x" * 500 + result = truncate_error(msg) + assert "Bad property" in result + assert "(error truncated)" in result + assert "{'additionalProperties'" not in result + + def 
test_hard_cap_on_long_single_line(self): + msg = "x " * 300 # 600 chars, single line, no schema markers + result = truncate_error(msg, max_chars=500) + assert len(result) <= 500 + len("\n\n(error truncated)") + assert result.endswith("\n\n(error truncated)") + + def test_hard_cap_cuts_on_word_boundary(self): + msg = "word " * 200 + result = truncate_error(msg, max_chars=100) + assert not result.split("\n\n(error truncated)")[0].endswith(" w") + + def test_preserves_first_line_of_altair_error(self): + first_line = "Additional properties are not allowed ('offset' was unexpected)" + schema_dump = "\n\nFailed validating 'additionalProperties' in schema[0]['properties']['encoding']:\n {'additionalProperties': False,\n 'properties': {'angle': " + "x" * 10000 + msg = first_line + schema_dump + result = truncate_error(msg) + assert result.startswith(first_line) + assert len(result) < 600 + + def test_custom_max_chars(self): + msg = "a" * 200 + result = truncate_error(msg, max_chars=100) + assert len(result) <= 100 + len("\n\n(error truncated)") diff --git a/pkg-py/tests/test_viz_footer.py b/pkg-py/tests/test_viz_footer.py new file mode 100644 index 000000000..7051fec43 --- /dev/null +++ b/pkg-py/tests/test_viz_footer.py @@ -0,0 +1,109 @@ +""" +Tests for visualization footer (Save dropdown, Show Query). + +The footer HTML (containing Save dropdown and Show Query toggle) is built by +_build_viz_footer() and passed as the `footer` parameter to ToolResultDisplay. +shinychat renders this in the card footer area. 
+""" + +from unittest.mock import MagicMock + +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from htmltools import TagList, tags +from querychat._datasource import DataFrameSource + + +@pytest.fixture +def sample_df(): + return pl.DataFrame( + {"x": [1, 2, 3, 4, 5], "y": [10, 20, 15, 25, 30]} + ) + + +@pytest.fixture +def data_source(sample_df): + nw_df = nw.from_native(sample_df) + return DataFrameSource(nw_df, "test_data") + + +def _mock_output_widget(widget_id, **kwargs): + return tags.div(id=widget_id) + + +@pytest.fixture(autouse=True) +def _patch_deps(monkeypatch): + monkeypatch.setattr( + "shinywidgets.register_widget", lambda _widget_id, _chart: None + ) + monkeypatch.setattr("shinywidgets.output_widget", _mock_output_widget) + + mock_spec = MagicMock() + mock_spec.metadata.return_value = {"rows": 5, "columns": ["x", "y"]} + mock_chart = MagicMock() + mock_chart.properties.return_value = mock_chart + + mock_altair_widget = MagicMock() + mock_altair_widget.widget = mock_chart + mock_altair_widget.widget_id = "querychat_viz_test1234" + mock_altair_widget.is_compound = False + + monkeypatch.setattr( + "querychat._viz_ggsql.execute_ggsql", lambda _ds, _q: mock_spec + ) + monkeypatch.setattr( + "querychat._viz_altair_widget.AltairWidget.from_ggsql", + staticmethod(lambda _spec: mock_altair_widget), + ) + + import ggsql + from querychat import _viz_tools + + mock_raw_chart = MagicMock() + mock_vl_writer = MagicMock() + mock_vl_writer.render_chart.return_value = mock_raw_chart + monkeypatch.setattr(ggsql, "VegaLiteWriter", lambda: mock_vl_writer) + monkeypatch.setattr( + _viz_tools, "render_chart_to_png", lambda _chart: b"\x89PNG\r\n\x1a\n" + ) + + +class TestVizFooterIcons: + """Verify Bootstrap icons used in viz footer are defined in _icons.py.""" + + def test_download_icon_exists(self): + from querychat._icons import bs_icon + + html = str(bs_icon("download")) + assert "svg" in html + assert "bi-download" in html + + def 
test_chevron_down_icon_exists(self): + from querychat._icons import bs_icon + + html = str(bs_icon("chevron-down")) + assert "svg" in html + assert "bi-chevron-down" in html + + def test_cls_parameter_injects_class(self): + from querychat._icons import bs_icon + + html = str(bs_icon("download", cls="querychat-icon")) + assert "querychat-icon" in html + + +class TestVizPreloadMarkup: + def test_preload_markup_has_no_inline_script(self): + from querychat._viz_utils import PRELOAD_WIDGET_ID, preload_viz_deps_ui + + rendered = TagList(preload_viz_deps_ui()).render() + preload_dep = next( + dep for dep in rendered["dependencies"] if dep.name == "querychat-viz-preload" + ) + + assert PRELOAD_WIDGET_ID in rendered["html"] + assert "querychat-viz-preload" in rendered["html"] + assert "hidden" in rendered["html"] + assert "=1.5.1", - "shinychat>=0.2.8", + "shiny @ git+https://github.com/posit-dev/py-shiny.git", + "shinychat @ git+https://github.com/posit-dev/shinychat.git", "htmltools", "chatlas>=0.13.2", "narwhals>=2.2.0", @@ -48,6 +48,8 @@ ibis = ["ibis-framework>=9.0.0", "pandas"] # pandas required for ibis .execute( streamlit = ["streamlit>=1.30"] gradio = ["gradio>=6.0"] dash = ["dash-ag-grid>=31.0", "dash[async]>=3.1", "dash-bootstrap-components>=2.0", "pandas"] +# Visualization with ggsql +viz = ["ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0"] [project.urls] Homepage = "https://github.com/posit-dev/querychat" # TODO update when we have docs @@ -55,6 +57,15 @@ Repository = "https://github.com/posit-dev/querychat" Issues = "https://github.com/posit-dev/querychat/issues" Source = "https://github.com/posit-dev/querychat/tree/main/pkg-py" +[tool.uv] +# Require locked packages to ship wheels for the platforms we target in CI +# (wheel-only dependencies could otherwise lock to versions lacking them). +# Note: this does not limit resolution to these platforms; `environments` does. 
+required-environments = [ + "sys_platform == 'linux' and platform_machine == 'x86_64'", + "sys_platform == 'darwin'", +] + [tool.hatch.metadata] allow-direct-references = true @@ -76,7 +87,7 @@ git_describe_command = "git describe --dirty --tags --long --match 'py/v*'" version-file = "pkg-py/src/querychat/__version.py" [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0", "pyarrow>=14.0.0", "ibis-framework[duckdb]>=9.0.0"] +dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0", "pyarrow>=14.0.0", "ibis-framework[duckdb]>=9.0.0", "ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0"] docs = ["quartodoc>=0.11.1", "griffe<2", "nbformat", "nbclient", "ipykernel"] examples = [ "openai", @@ -214,13 +225,14 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # disable S101 (flagging asserts) for tests [tool.ruff.lint.per-file-ignores] -"pkg-py/tests/*.py" = ["S101", "PLR2004"] # Allow assert and magic numbers in tests +"pkg-py/tests/*.py" = ["S101", "PLR2004", "ARG", "PLW0108"] # Allow assert, magic numbers, unused args, and unnecessary lambdas in tests "pkg-py/tests/playwright/*.py" = ["S101", "PLR2004", "S310", "S603", "S607", "PERF203"] # Test fixtures launch subprocesses "pkg-py/examples/tests/*.py" = ["S101", "PLR2004"] # Allow assert and magic numbers in tests "pkg-py/src/querychat/_dash.py" = ["E402"] # Backwards-compat aliases at end of file "pkg-py/src/querychat/_gradio.py" = ["E402"] # Backwards-compat aliases at end of file "pkg-py/src/querychat/_streamlit.py" = ["E402"] # Backwards-compat aliases at end of file "pkg-py/src/querychat/types/__init__.py" = ["A005"] # Deliberately shadows stdlib types module +"pkg-py/docs/_screenshots/*.py" = ["S310", "PLR2004", "PERF203"] # Dev utility scripts [tool.ruff.format] quote-style = "double" @@ -230,6 +242,9 @@ line-ending = "auto" docstring-code-format = true 
docstring-code-line-length = "dynamic" +[tool.pytest.ini_options] +markers = ["ggsql: requires working ggsql.render_altair()"] + [tool.pyright] include = ["pkg-py/src/querychat"]