diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md new file mode 100644 index 00000000..d3b45dc0 --- /dev/null +++ b/pkg-py/CHANGELOG.md @@ -0,0 +1,24 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [UNRELEASED] + +* `querychat.init()` now accepts a `client` argument, replacing the previous `create_chat_callback` argument. (#60) + + The `client` can be: + + * a `chatlas.Chat` object, + * a function that returns a `chatlas.Chat` object, + * or a provider-model string, e.g. `"openai/gpt-4.1"`, to be passed to `chatlas.ChatAuto()`. + + If `client` is not provided, querychat will use the `QUERYCHAT_CLIENT` environment variable, which should be a provider-model string. If the envvar is not set, querychat uses OpenAI with the default model from `chatlas.ChatOpenAI()`. + + +## [0.1.0] - 2025-05-24 + +This first release of the `querychat` package. + diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py new file mode 100644 index 00000000..8d82a52e --- /dev/null +++ b/pkg-py/src/querychat/_utils.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import os +from contextlib import contextmanager +from typing import Optional + + +@contextmanager +def temp_env_vars(env_vars: dict[str, Optional[str]]): + """ + Temporarily set environment variables and restore them when exiting. + + Parameters + ---------- + env_vars : Dict[str, str] + Dictionary of environment variable names to values to set temporarily + + Example + ------- + with temp_env_vars({"FOO": "bar", "BAZ": "qux"}): + # FOO and BAZ are set to "bar" and "qux" + do_something() + # FOO and BAZ are restored to their original values (or unset if they weren't set) + + """ + original_values: dict[str, Optional[str]] = {} + for key in env_vars: + original_values[key] = os.environ.get(key) + + for key, value in env_vars.items(): + if value is None: + # If value is None, remove the variable + os.environ.pop(key, None) + else: + # Otherwise set the variable to the specified value + os.environ[key] = value + + try: + yield + finally: + # Restore original values + for key, original_value in original_values.items(): + if original_value is None: + # Variable wasn't set originally, so remove it + os.environ.pop(key, None) + else: + # Restore original value + os.environ[key] = original_value diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index 001ffee7..c9e70482 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -1,9 +1,11 @@ from __future__ import annotations +import copy +import os import re import sys +import warnings from dataclasses import dataclass -from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union @@ -13,6 +15,8 @@ import sqlalchemy from shiny import Inputs, Outputs, Session, module, reactive, ui +from ._utils import temp_env_vars + if TYPE_CHECKING: import pandas as pd from narwhals.typing import IntoFrame @@ -33,7 +37,7 @@ class QueryChatConfig: data_source: DataSource system_prompt: str greeting: Optional[str] - create_chat_callback: CreateChatCallback + client: chatlas.Chat class QueryChat: @@ -233,6 +237,74 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: return table_html + rows_notice +def _get_client_from_env() -> Optional[str]: + """Get client configuration from environment variable.""" + env_client = os.getenv("QUERYCHAT_CLIENT", "") + if not env_client: + return None + return env_client + + +def _create_client_from_string(client_str: str) -> chatlas.Chat: + """Create a chatlas.Chat client from a provider-model string.""" + provider, model = ( + client_str.split("/", 1) if "/" in client_str else (client_str, None) + ) + # We unset chatlas's envvars so we can listen to querychat's envvars instead + with temp_env_vars( + { + "CHATLAS_CHAT_PROVIDER": provider, + "CHATLAS_CHAT_MODEL": model, + "CHATLAS_CHAT_ARGS": os.environ["QUERYCHAT_CLIENT_ARGS"], + }, + ): + return chatlas.ChatAuto(provider="openai") + + +def _resolve_querychat_client( + client: Optional[Union[chatlas.Chat, CreateChatCallback, str]] = None, +) -> chatlas.Chat: + """ + Resolve the client argument into a chatlas.Chat object. + + Parameters + ---------- + client : chatlas.Chat, CreateChatCallback, str, or None + The client to resolve. Can be: + - A chatlas.Chat object (returned as-is) + - A function that returns a chatlas.Chat object + - A provider-model string (e.g., "openai/gpt-4.1") + - None (fall back to environment variable or default) + + Returns + ------- + chatlas.Chat + A resolved chatlas.Chat object + + """ + if client is None: + client = _get_client_from_env() + + if client is None: + # Default to OpenAI with using chatlas's default model + return chatlas.ChatOpenAI() + + if callable(client) and not isinstance(client, chatlas.Chat): + # Backcompat: support the old create_chat_callback style, using an empty + # system prompt as a placeholder. + client = client(system_prompt="") + + if isinstance(client, str): + client = _create_client_from_string(client) + + if not isinstance(client, chatlas.Chat): + raise TypeError( + "client must be a chatlas.Chat object or function that returns one", + ) + + return client + + def init( data_source: IntoFrame | sqlalchemy.Engine, table_name: str, @@ -242,6 +314,7 @@ def init( extra_instructions: Optional[str | Path] = None, prompt_template: Optional[str | Path] = None, system_prompt_override: Optional[str] = None, + client: Optional[Union[chatlas.Chat, CreateChatCallback, str]] = None, create_chat_callback: Optional[CreateChatCallback] = None, ) -> QueryChatConfig: """ @@ -283,8 +356,18 @@ def init( A custom system prompt to use instead of the default. If provided, `data_description`, `extra_instructions`, and `prompt_template` will be silently ignored. + client : chatlas.Chat, CreateChatCallback, str, optional + A `chatlas.Chat` object, a string to be passed to `chatlas.ChatAuto()` + describing the model to use (e.g. `"openai/gpt-4.1"`), or a function + that creates a chat client. If using a function, the function should + accept a `system_prompt` argument and return a `chatlas.Chat` object. + + If `client` is not provided, querychat consults the `QUERYCHAT_CLIENT` + environment variable, which can be set to a provider-model string. If no + option is provided, querychat defaults to using + `chatlas.ChatOpenAI(model="gpt-4.1")`. create_chat_callback : CreateChatCallback, optional - A function that creates a chat object + **Deprecated.** Use the `client` argument instead. Returns ------- @@ -292,6 +375,22 @@ def init( A QueryChatConfig object that can be passed to server() """ + # Handle deprecated create_chat_callback argument + if create_chat_callback is not None: + warnings.warn( + "The 'create_chat_callback' parameter is deprecated. Use 'client' instead.", + DeprecationWarning, + stacklevel=2, + ) + if client is not None: + raise ValueError( + "You cannot pass both `create_chat_callback` and `client` to `init()`.", + ) + client = create_chat_callback + + # Resolve the client + resolved_client = _resolve_querychat_client(client) + # Validate table name (must begin with letter, contain only letters, numbers, underscores) if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): raise ValueError( @@ -330,17 +429,11 @@ def init( prompt_template=prompt_template, ) - # Default chat function if none provided - create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, - model="gpt-4.1", - ) - return QueryChatConfig( data_source=data_source_obj, system_prompt=system_prompt_, greeting=greeting_str, - create_chat_callback=create_chat_callback, + client=resolved_client, ) @@ -441,7 +534,7 @@ def _(): data_source = querychat_config.data_source system_prompt = querychat_config.system_prompt greeting = querychat_config.greeting - create_chat_callback = querychat_config.create_chat_callback + client = querychat_config.client # Reactive values to store state current_title = reactive.value[Union[str, None]](None) @@ -517,17 +610,13 @@ async def query(query: str): chat_ui = ui.Chat("chat") - # Initialize the chat with the system prompt - # This is a placeholder - actual implementation would depend on chatlas - chat = create_chat_callback(system_prompt=system_prompt) + # Set up the chat object for this session + chat = copy.deepcopy(client) + chat.set_turns([]) + chat.system_prompt = system_prompt chat.register_tool(update_dashboard) chat.register_tool(query) - # Register tools with the chat - # This is a placeholder - actual implementation would depend on chatlas - # chat.register_tool("update_dashboard", update_dashboard) - # chat.register_tool("query", query) - # Add greeting if provided if greeting and any(len(g) > 0 for g in greeting.split("\n")): # Display greeting in chat UI diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index cacfb127..f5471586 100644 --- a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -1,6 +1,6 @@ Package: querychat Title: Filter and Query Data Frames in 'shiny' Using an LLM Chat Interface -Version: 0.0.0.9000 +Version: 0.0.1.9000 Authors@R: c( person("Joe", "Cheng", , "joe@posit.co", role = c("aut", "cre")), person("Posit Software, PBC", role = c("cph", "fnd")) @@ -18,19 +18,22 @@ Imports: bslib, DBI, duckdb, - ellmer, + ellmer (>= 0.3.0), htmltools, + lifecycle, purrr, rlang, shiny, shinychat (>= 0.2.0), whisker, xtable -Suggests: +Suggests: DT, + R6, RSQLite, shinytest2, - testthat (>= 3.0.0) + testthat (>= 3.0.0), + withr Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE) diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index 8c75247d..8fce3055 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -20,3 +20,4 @@ export(querychat_server) export(querychat_sidebar) export(querychat_ui) export(test_query) +importFrom(lifecycle,deprecated) diff --git a/pkg-r/NEWS.md b/pkg-r/NEWS.md index 81a3e390..d44c08d2 100644 --- a/pkg-r/NEWS.md +++ b/pkg-r/NEWS.md @@ -3,3 +3,17 @@ * Initial CRAN submission. * Added `prompt_template` support for `querychat_system_prompt()`. (Thank you, @oacar! #37, #45) + +* `querychat_init()` now accepts a `client`, replacing the previous `create_chat_func` argument. (#60) + + The `client` can be: + + * an `ellmer::Chat` object, + * a function that returns an `ellmer::Chat` object, + * or a provider-model string, e.g. `"openai/gpt-4.1"`, to be passed to `ellmer::chat()`. + + If `client` is not provided, querychat will use + + * the `querychat.client` R option, which can be any of the above options, + * the `QUERYCHAT_CLIENT` environment variable, which should be a provider-model string, + * or the default model from `ellmer::chat_openai()`. diff --git a/pkg-r/R/querychat-package.R b/pkg-r/R/querychat-package.R new file mode 100644 index 00000000..425b3c1c --- /dev/null +++ b/pkg-r/R/querychat-package.R @@ -0,0 +1,7 @@ +#' @keywords internal +"_PACKAGE" + +## usethis namespace: start +#' @importFrom lifecycle deprecated +## usethis namespace: end +NULL diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 5232cc79..cae09e7a 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -13,8 +13,19 @@ #' @param data_description A string containing a data description for the chat model. We have found #' that formatting the data description as a markdown bulleted list works best. #' @param extra_instructions A string containing extra instructions for the chat model. -#' @param create_chat_func A function that takes a system prompt and returns a -#' chat object. The default uses `ellmer::chat_openai()`. +#' @param client An `ellmer::Chat` object, a string to be passed to +#' [ellmer::chat()] describing the model to use (e.g. `"openai/gpt-4o"`), or a +#' function that creates a chat client. When using a function, the function +#' should take `system_prompt` as an argument and return an `ellmer::Chat` +#' object. +#' +#' If `client` is not provided, querychat consults the `querychat.client` R +#' option, which can be any of the described options, or the +#' `QUERYCHAT_CLIENT` environment variable, which can be set to a a +#' provider-model string. If no option is provided, querychat defaults to +#' using [ellmer::chat_openai()]. +#' @param create_chat_func `r lifecycle::badge('deprecated')`. Use the `client` +#' argument instead. #' @param system_prompt A string containing the system prompt for the chat model. #' The default generates a generic prompt, which you can enhance via the `data_description` and #' `extra_instructions` arguments. @@ -31,11 +42,26 @@ querychat_init <- function( greeting = NULL, data_description = NULL, extra_instructions = NULL, - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), + create_chat_func = deprecated(), system_prompt = NULL, - auto_close_data_source = TRUE + auto_close_data_source = TRUE, + client = NULL ) { - force(create_chat_func) + if (lifecycle::is_present(create_chat_func)) { + lifecycle::deprecate_warn( + "0.0.1", + "querychat_init(create_chat_func=)", + "querychat_init(client =)" + ) + if (!is.null(client)) { + rlang::abort( + "You cannot pass both `create_chat_func` and `client` to `querychat_init()`." + ) + } + client <- create_chat_func + } + + client <- querychat_client(client) # If the user passes a data.frame to data_source, create a correct data source for them if (inherits(data_source, "data.frame")) { @@ -70,10 +96,9 @@ querychat_init <- function( ) } - # Validate system prompt and create_chat_func + # Validate system prompt stopifnot( - "system_prompt must be a string" = is.character(system_prompt), - "create_chat_func must be a function" = is.function(create_chat_func) + "system_prompt must be a string" = is.character(system_prompt) ) if (!is.null(greeting)) { @@ -90,7 +115,7 @@ querychat_init <- function( data_source = data_source, system_prompt = system_prompt, greeting = greeting, - create_chat_func = create_chat_func + client = client ), class = "querychat_config" ) @@ -154,7 +179,7 @@ querychat_server <- function(id, querychat_config) { data_source <- querychat_config[["data_source"]] system_prompt <- querychat_config[["system_prompt"]] greeting <- querychat_config[["greeting"]] - create_chat_func <- querychat_config[["create_chat_func"]] + client <- querychat_config[["client"]] current_title <- shiny::reactiveVal(NULL) current_query <- shiny::reactiveVal("") @@ -223,7 +248,9 @@ querychat_server <- function(id, querychat_config) { # Preload the conversation with the system prompt. These are instructions for # the chat model, and must not be shown to the end user. - chat <- create_chat_func(system_prompt = system_prompt) + chat <- client$clone() + chat$set_turns(list()) + chat$set_system_prompt(system_prompt) chat$register_tool(ellmer::tool( update_dashboard, "Modifies the data presented in the data dashboard, based on the given SQL query, and also updates the title.", @@ -263,7 +290,7 @@ querychat_server <- function(id, querychat_config) { # Add user message to the chat history shinychat::chat_append( "chat", - chat$stream_async(input$chat_user_input) + chat$stream_async(input$chat_user_input, stream = "content") ) }) diff --git a/pkg-r/R/querychat_client.R b/pkg-r/R/querychat_client.R new file mode 100644 index 00000000..98741d8a --- /dev/null +++ b/pkg-r/R/querychat_client.R @@ -0,0 +1,42 @@ +querychat_client <- function(client = NULL) { + if (is.null(client)) { + client <- querychat_client_option() + } + + if (is.null(client)) { + # Use OpenAI with ellmer's default model + return(ellmer::chat_openai()) + } + + if (rlang::is_function(client)) { + # `client` as a function was the first interface we supported and expected + # `system_prompt` as an argument. This avoids breaking existing code. + client <- client(system_prompt = NULL) + } + + if (rlang::is_string(client)) { + client <- ellmer::chat(client) + } + + if (!inherits(client, "Chat")) { + rlang::abort( + "`client` must be an {ellmer} Chat object or a function that returns one.", + ) + } + + client +} + +querychat_client_option <- function() { + opt <- getOption("querychat.client", NULL) + if (!is.null(opt)) { + return(opt) + } + + env <- Sys.getenv("QUERYCHAT_CLIENT", "") + if (nzchar(env)) { + return(env) + } + + NULL +} diff --git a/pkg-r/README.md b/pkg-r/README.md index e73ce98a..ff75d6ec 100644 --- a/pkg-r/README.md +++ b/pkg-r/README.md @@ -220,35 +220,34 @@ You can also put these instructions in a separate file and use `readLines()` to ### Use a different LLM provider -By default, querychat uses GPT-4o via the OpenAI API. If you want to use a different model, you can provide a `create_chat_func` function that takes a `system_prompt` parameter, and returns an Ellmer chat object. A convenient way to do this is with `purrr::partial`: - -```r -library(ellmer) - -# Option 1: Define a function -my_chat_func <- function(system_prompt) { - return( - chat_claude( - model="claude-3-7-sonnet-latest", - system_prompt=system_prompt - ) - ) -} -``` +By default, querychat uses OpenAI with the default model chosen by `ellmer::chat_openai()`. If you want to use a different model, you can provide an ellmer chat object to the `client` argument of `querychat_init()`. ```r library(ellmer) library(purrr) -# Create data source first mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") -# Option 2: Use partial querychat_config <- querychat_init( data_source = mtcars_source, - create_chat_func = purrr::partial(ellmer::chat_claude, model = "claude-3-7-sonnet-latest") + client = ellmer::chat_anthropic(model = "claude-3-7-sonnet-latest") ) ``` This would use Claude 3.7 Sonnet instead, which would require you to provide an API key. -See the [instructions from Ellmer](https://ellmer.tidyverse.org/reference/chat_claude.html) for more information on how to authenticate with different providers. \ No newline at end of file +See the [instructions from Ellmer](https://ellmer.tidyverse.org/reference/chat_anthropic.html) for more information on how to authenticate with different providers. + +Alternatively, you can use a provider-model string, which will be passed to `ellmer::chat()`: + +```r +querychat_config <- querychat_init( + data_source = mtcars_source, + client = "anthropic/claude-3-7-sonnet-latest" +) +``` + +Or you can set the `querychat.client` R option to a chat object or provider-model string, which will be used as the default client for all querychat apps in your session: + +```r +option(querychat.client = "anthropic/claude-3-7-sonnet-latest") +``` diff --git a/pkg-r/man/figures/lifecycle-deprecated.svg b/pkg-r/man/figures/lifecycle-deprecated.svg new file mode 100644 index 00000000..b61c57c3 --- /dev/null +++ b/pkg-r/man/figures/lifecycle-deprecated.svg @@ -0,0 +1,21 @@ + + lifecycle: deprecated + + + + + + + + + + + + + + + lifecycle + + deprecated + + diff --git a/pkg-r/man/figures/lifecycle-experimental.svg b/pkg-r/man/figures/lifecycle-experimental.svg new file mode 100644 index 00000000..5d88fc2c --- /dev/null +++ b/pkg-r/man/figures/lifecycle-experimental.svg @@ -0,0 +1,21 @@ + + lifecycle: experimental + + + + + + + + + + + + + + + lifecycle + + experimental + + diff --git a/pkg-r/man/figures/lifecycle-stable.svg b/pkg-r/man/figures/lifecycle-stable.svg new file mode 100644 index 00000000..9bf21e76 --- /dev/null +++ b/pkg-r/man/figures/lifecycle-stable.svg @@ -0,0 +1,29 @@ + + lifecycle: stable + + + + + + + + + + + + + + + + lifecycle + + + + stable + + + diff --git a/pkg-r/man/figures/lifecycle-superseded.svg b/pkg-r/man/figures/lifecycle-superseded.svg new file mode 100644 index 00000000..db8d757f --- /dev/null +++ b/pkg-r/man/figures/lifecycle-superseded.svg @@ -0,0 +1,21 @@ + + lifecycle: superseded + + + + + + + + + + + + + + + lifecycle + + superseded + + diff --git a/pkg-r/man/querychat-package.Rd b/pkg-r/man/querychat-package.Rd new file mode 100644 index 00000000..1cae4261 --- /dev/null +++ b/pkg-r/man/querychat-package.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/querychat-package.R +\docType{package} +\name{querychat-package} +\alias{querychat} +\alias{querychat-package} +\title{querychat: Filter and Query Data Frames in 'shiny' Using an LLM Chat Interface} +\description{ +Adds an LLM-powered chatbot to your 'shiny' app, that can turn your users' natural language questions into SQL queries that run against your data, and return the result as a reactive dataframe. Use it to drive reactive calculations, visualizations, downloads, etc. +} +\seealso{ +Useful links: +\itemize{ + \item \url{https://posit-dev.github.io/querychat/pkg-r} + \item \url{https://posit-dev.github.io/querychat} +} + +} +\author{ +\strong{Maintainer}: Joe Cheng \email{joe@posit.co} + +Other contributors: +\itemize{ + \item Posit Software, PBC [copyright holder, funder] +} + +} +\keyword{internal} diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index 618d8532..bd1a43b9 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -9,9 +9,10 @@ querychat_init( greeting = NULL, data_description = NULL, extra_instructions = NULL, - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), + create_chat_func = deprecated(), system_prompt = NULL, - auto_close_data_source = TRUE + auto_close_data_source = TRUE, + client = NULL ) } \arguments{ @@ -31,8 +32,8 @@ that formatting the data description as a markdown bulleted list works best.} \item{extra_instructions}{A string containing extra instructions for the chat model.} -\item{create_chat_func}{A function that takes a system prompt and returns a -chat object. The default uses \code{ellmer::chat_openai()}.} +\item{create_chat_func}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}}. Use the \code{client} +argument instead.} \item{system_prompt}{A string containing the system prompt for the chat model. The default generates a generic prompt, which you can enhance via the \code{data_description} and @@ -40,6 +41,18 @@ The default generates a generic prompt, which you can enhance via the \code{data \item{auto_close_data_source}{Should the data source connection be automatically closed when the shiny app stops? Defaults to TRUE.} + +\item{client}{An \code{ellmer::Chat} object, a string to be passed to +\code{\link[ellmer:chat-any]{ellmer::chat()}} describing the model to use (e.g. \code{"openai/gpt-4o"}), or a +function that creates a chat client. When using a function, the function +should take \code{system_prompt} as an argument and return an \code{ellmer::Chat} +object. + +If \code{client} is not provided, querychat consults the \code{querychat.client} R +option, which can be any of the described options, or the +\code{QUERYCHAT_CLIENT} environment variable, which can be set to a a +provider-model string. If no option is provided, querychat defaults to +using \code{\link[ellmer:chat_openai]{ellmer::chat_openai()}}.} } \value{ An object that can be passed to \code{querychat_server()} as the diff --git a/pkg-r/tests/testthat/apps/basic/app.R b/pkg-r/tests/testthat/apps/basic/app.R new file mode 100644 index 00000000..7caf0253 --- /dev/null +++ b/pkg-r/tests/testthat/apps/basic/app.R @@ -0,0 +1,67 @@ +library(shiny) +library(bslib, warn.conflicts = FALSE) +library(querychat) +library(DBI) +library(RSQLite) + +# Mock chat function to avoid LLM API calls +MockChat <- R6::R6Class( + "MockChat", + inherit = asNamespace("ellmer")[["Chat"]], + public = list( + stream_async = function(message, ...) { + "Welcome! This is a mock response for testing." + } + ) +) + +# Create test database +temp_db <- tempfile(fileext = ".db") +conn <- dbConnect(RSQLite::SQLite(), temp_db) +dbWriteTable(conn, "iris", iris, overwrite = TRUE) +dbDisconnect(conn) + +# Setup database source +db_conn <- dbConnect(RSQLite::SQLite(), temp_db) +iris_source <- querychat_data_source(db_conn, "iris") + +# Configure querychat with mock +querychat_config <- querychat_init( + data_source = iris_source, + greeting = "Welcome to the test app!", + client = MockChat$new(ellmer::Provider("test", "test", "test")) +) + +ui <- page_sidebar( + title = "Test Database App", + sidebar = querychat_sidebar("chat"), + h2("Data"), + DT::DTOutput("data_table"), + h3("SQL Query"), + verbatimTextOutput("sql_query") +) + +server <- function(input, output, session) { + chat <- querychat_server("chat", querychat_config) + + output$data_table <- DT::renderDT( + { + chat$df() + }, + options = list(pageLength = 5) + ) + + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") "No filter applied" else query + }) + + session$onSessionEnded(function() { + if (DBI::dbIsValid(db_conn)) { + DBI::dbDisconnect(db_conn) + } + unlink(temp_db) + }) +} + +shinyApp(ui = ui, server = server) diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index a957aae9..a4e324ce 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -14,19 +14,19 @@ test_that("querychat_data_source.data.frame creates proper S3 object", { # Test with explicit table name source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(source)) + expect_s3_class(source, "data_frame_source") expect_s3_class(source, "querychat_data_source") expect_equal(source$table_name, "test_table") expect_true(inherits(source$conn, "DBIConnection")) - - # Clean up - cleanup_source(source) }) test_that("querychat_data_source.DBIConnection creates proper S3 object", { # Create temporary SQLite database - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) # Create test table test_data <- data.frame( @@ -44,10 +44,6 @@ test_that("querychat_data_source.DBIConnection creates proper S3 object", { expect_s3_class(db_source, "querychat_data_source") expect_equal(db_source$table_name, "users") expect_equal(db_source$categorical_threshold, 20) - - # Clean up - dbDisconnect(conn) - unlink(temp_db) }) test_that("get_schema methods return proper schema", { @@ -60,6 +56,8 @@ test_that("get_schema methods return proper schema", { ) df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) + schema <- get_schema(df_source) expect_type(schema, "character") expect_match(schema, "Table: test_table") @@ -72,8 +70,10 @@ test_that("get_schema methods return proper schema", { expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") # Test with DBI source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") @@ -85,11 +85,6 @@ test_that("get_schema methods return proper schema", { # Test min/max values in DBI source schema - specifically for the id column expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") - - # Clean up - cleanup_source(df_source) - dbDisconnect(conn) - unlink(temp_db) }) test_that("execute_query works for both source types", { @@ -101,6 +96,7 @@ test_that("execute_query works for both source types", { ) df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) result <- execute_query( df_source, "SELECT * FROM test_table WHERE value > 25" @@ -109,8 +105,9 @@ test_that("execute_query works for both source types", { expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) # Test with DBI source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") @@ -120,11 +117,6 @@ test_that("execute_query works for both source types", { ) expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) - - # Clean up - cleanup_source(df_source) - dbDisconnect(conn) - unlink(temp_db) }) test_that("execute_query works with empty/null queries", { @@ -136,6 +128,7 @@ test_that("execute_query works with empty/null queries", { ) df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) # Test with NULL query result_null <- execute_query(df_source, NULL) @@ -150,8 +143,10 @@ test_that("execute_query works with empty/null queries", { expect_equal(ncol(result_empty), 2) # Should return all columns # Test with DBI source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") @@ -167,11 +162,6 @@ test_that("execute_query works with empty/null queries", { expect_s3_class(result_empty, "data.frame") expect_equal(nrow(result_empty), 5) # Should return all rows expect_equal(ncol(result_empty), 2) # Should return all columns - - # Clean up - cleanup_source(df_source) - dbDisconnect(conn) - unlink(temp_db) }) @@ -185,6 +175,7 @@ test_that("get_schema correctly reports min/max values for numeric columns", { ) df_source <- querychat_data_source(test_df, table_name = "test_metrics") + withr::defer(cleanup_source(df_source)) schema <- get_schema(df_source) # Check that each numeric column has the correct min/max values @@ -192,9 +183,6 @@ test_that("get_schema correctly reports min/max values for numeric columns", { expect_match(schema, "- score \\(FLOAT\\)\\n Range: 10\\.5 to 30\\.1") # Note: In the test output, count was detected as FLOAT rather than INTEGER expect_match(schema, "- count \\(FLOAT\\)\\n Range: 50 to 200") - - # Clean up - cleanup_source(df_source) }) test_that("create_system_prompt generates appropriate system prompt", { @@ -205,6 +193,8 @@ test_that("create_system_prompt generates appropriate system prompt", { ) df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) + prompt <- create_system_prompt( df_source, data_description = "A test dataframe" @@ -213,9 +203,6 @@ test_that("create_system_prompt generates appropriate system prompt", { expect_true(nchar(prompt) > 0) expect_match(prompt, "A test dataframe") expect_match(prompt, "Table: test_table") - - # Clean up - cleanup_source(df_source) }) test_that("querychat_init automatically handles data.frame inputs", { @@ -224,18 +211,18 @@ test_that("querychat_init automatically handles data.frame inputs", { # Should work with data frame and auto-convert it config <- querychat_init(data_source = test_df, greeting = "Test greeting") + withr::defer(cleanup_source(config$data_source)) + expect_s3_class(config, "querychat_config") expect_s3_class(config$data_source, "querychat_data_source") expect_s3_class(config$data_source, "data_frame_source") # Should work with proper data source too df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) + config <- querychat_init(data_source = df_source, greeting = "Test greeting") expect_s3_class(config, "querychat_config") - - # Clean up - cleanup_source(df_source) - cleanup_source(config$data_source) }) test_that("querychat_init works with both source types", { @@ -248,14 +235,19 @@ test_that("querychat_init works with both source types", { # Create data source and test with querychat_init df_source <- querychat_data_source(test_df, table_name = "test_source") + withr::defer(cleanup_source(df_source)) + config <- querychat_init(data_source = df_source, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") expect_s3_class(config$data_source, "data_frame_source") expect_equal(config$data_source$table_name, "test_source") # Test with database connection - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") @@ -263,9 +255,4 @@ test_that("querychat_init works with both source types", { expect_s3_class(config, "querychat_config") expect_s3_class(config$data_source, "dbi_source") expect_equal(config$data_source$table_name, "test_table") - - # Clean up - cleanup_source(df_source) - dbDisconnect(conn) - unlink(temp_db) }) diff --git a/pkg-r/tests/testthat/test-db-type.R b/pkg-r/tests/testthat/test-db-type.R index e10967d8..72181823 100644 --- a/pkg-r/tests/testthat/test-db-type.R +++ b/pkg-r/tests/testthat/test-db-type.R @@ -13,17 +13,14 @@ test_that("get_db_type returns correct type for dbi_source with SQLite", { skip_if_not_installed("RSQLite") # Create a SQLite database source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- DBI::dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(DBI::dbDisconnect(conn)) DBI::dbWriteTable(conn, "test_table", data.frame(x = 1:5, y = letters[1:5])) db_source <- querychat_data_source(conn, "test_table") # Test that get_db_type returns the correct database type expect_equal(get_db_type(db_source), "SQLite") - - # Clean up - DBI::dbDisconnect(conn) - unlink(temp_db) }) test_that("get_db_type is correctly used in create_system_prompt", { diff --git a/pkg-r/tests/testthat/test-querychat-server.R b/pkg-r/tests/testthat/test-querychat-server.R index a44cfb08..3e6421f8 100644 --- a/pkg-r/tests/testthat/test-querychat-server.R +++ b/pkg-r/tests/testthat/test-querychat-server.R @@ -5,8 +5,9 @@ library(querychat) test_that("database source query functionality", { # Create temporary SQLite database - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) # Create test table test_data <- data.frame( @@ -39,8 +40,4 @@ test_that("database source query functionality", { "SELECT * FROM users ORDER BY age DESC" ) expect_equal(ordered_result$name[1], "Charlie") # Oldest first - - # Clean up - dbDisconnect(conn) - unlink(temp_db) }) diff --git a/pkg-r/tests/testthat/test-shiny-app.R b/pkg-r/tests/testthat/test-shiny-app.R index 8f762840..ff183c7e 100644 --- a/pkg-r/tests/testthat/test-shiny-app.R +++ b/pkg-r/tests/testthat/test-shiny-app.R @@ -6,84 +6,17 @@ test_that("app database example loads without errors", { skip_if_not_installed("shinytest2") # Create a simplified test app with mocked ellmer - test_app_file <- tempfile(fileext = ".R") + test_app_dir <- withr::local_tempdir() + test_app_file <- file.path(test_app_dir, "app.R") + dir.create(dirname(test_app_file), showWarnings = FALSE) - test_app_content <- ' -library(shiny) -library(bslib) -library(querychat) -library(DBI) -library(RSQLite) - -# Mock chat function to avoid LLM API calls -mock_chat_func <- function(system_prompt) { - list( - register_tool = function(tool) invisible(NULL), - stream_async = function(message) { - "Welcome! This is a mock response for testing." - } - ) -} - -# Create test database -temp_db <- tempfile(fileext = ".db") -conn <- dbConnect(RSQLite::SQLite(), temp_db) -dbWriteTable(conn, "iris", iris, overwrite = TRUE) -dbDisconnect(conn) - -# Setup database source -db_conn <- dbConnect(RSQLite::SQLite(), temp_db) -iris_source <- querychat_data_source(db_conn, "iris") - -# Configure querychat with mock -querychat_config <- querychat_init( - data_source = iris_source, - greeting = "Welcome to the test app!", - create_chat_func = mock_chat_func -) - -ui <- page_sidebar( - title = "Test Database App", - sidebar = querychat_sidebar("chat"), - h2("Data"), - DT::DTOutput("data_table"), - h3("SQL Query"), - verbatimTextOutput("sql_query") -) - -server <- function(input, output, session) { - chat <- querychat_server("chat", querychat_config) - - output$data_table <- DT::renderDT({ - chat$df() - }, options = list(pageLength = 5)) - - output$sql_query <- renderText({ - query <- chat$sql() - if (query == "") "No filter applied" else query - }) - - session$onSessionEnded(function() { - if (DBI::dbIsValid(db_conn)) { - DBI::dbDisconnect(db_conn) - } - unlink(temp_db) - }) -} - -shinyApp(ui = ui, server = server) -' - - writeLines(test_app_content, test_app_file) + file.copy(test_path("apps/basic/app.R"), test_app_file) # Test that the app can be loaded without immediate errors expect_no_error({ # Try to parse and evaluate the app code source(test_app_file, local = TRUE) }) - - # Clean up - unlink(test_app_file) }) test_that("database reactive functionality works correctly", { @@ -93,28 +26,25 @@ test_that("database reactive functionality works correctly", { library(RSQLite) # Create test database - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "iris", iris, overwrite = TRUE) dbDisconnect(conn) # Test database source creation db_conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(db_conn)) + iris_source <- querychat_data_source(db_conn, "iris") # Mock chat function - mock_chat_func <- function(system_prompt) { - list( - register_tool = function(tool) invisible(NULL), - stream_async = function(message) "Mock response" - ) - } + mock_client <- ellmer::chat_openai(api_key = "boop") # Test querychat_init with database source config <- querychat_init( data_source = iris_source, greeting = "Test greeting", - create_chat_func = mock_chat_func + client = mock_client ) expect_s3_class(config$data_source, "dbi_source") @@ -135,8 +65,4 @@ test_that("database reactive functionality works correctly", { expect_equal(nrow(query_result), 50) expect_equal(ncol(query_result), 2) expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(query_result))) - - # Clean up - dbDisconnect(db_conn) - unlink(temp_db) }) diff --git a/pkg-r/tests/testthat/test-sql-comments.R b/pkg-r/tests/testthat/test-sql-comments.R index e7553ad1..39a0415c 100644 --- a/pkg-r/tests/testthat/test-sql-comments.R +++ b/pkg-r/tests/testthat/test-sql-comments.R @@ -13,11 +13,12 @@ test_that("execute_query handles SQL with inline comments", { # Create data source df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) # Test with inline comments inline_comment_query <- " SELECT id, value -- This is a comment - FROM test_table + FROM test_table WHERE value > 25 -- Filter for higher values " @@ -39,9 +40,6 @@ test_that("execute_query handles SQL with inline comments", { expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) expect_equal(ncol(result), 2) - - # Clean up - cleanup_source(df_source) }) test_that("execute_query handles SQL with multiline comments", { @@ -54,10 +52,11 @@ test_that("execute_query handles SQL with multiline comments", { # Create data source df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) # Test with multiline comments multiline_comment_query <- " - /* + /* * This is a multiline comment * that spans multiple lines */ @@ -85,9 +84,6 @@ test_that("execute_query handles SQL with multiline comments", { expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) expect_equal(ncol(result), 2) - - # Clean up - cleanup_source(df_source) }) test_that("execute_query handles SQL with trailing semicolons", { @@ -100,6 +96,7 @@ test_that("execute_query handles SQL with trailing semicolons", { # Create data source df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) # Test with trailing semicolon query_with_semicolon <- " @@ -124,9 +121,6 @@ test_that("execute_query handles SQL with trailing semicolons", { expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) expect_equal(ncol(result), 2) - - # Clean up - cleanup_source(df_source) }) test_that("execute_query handles SQL with mixed comments and semicolons", { @@ -139,18 +133,19 @@ test_that("execute_query handles SQL with mixed comments and semicolons", { # Create data source df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) # Test with a mix of comment styles and semicolons complex_query <- " - /* + /* * This is a complex query with different comment styles */ - SELECT + SELECT id, -- This is the ID column value /* Value column */ - FROM + FROM test_table -- Our test table - WHERE + WHERE /* Only get higher values */ value > 25; -- End of query " @@ -174,9 +169,6 @@ test_that("execute_query handles SQL with mixed comments and semicolons", { expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) expect_equal(ncol(result), 2) - - # Clean up - cleanup_source(df_source) }) test_that("execute_query handles SQL with unusual whitespace patterns", { @@ -189,23 +181,21 @@ test_that("execute_query handles SQL with unusual whitespace patterns", { # Create data source df_source <- querychat_data_source(test_df, table_name = "test_table") + withr::defer(cleanup_source(df_source)) # Test with unusual whitespace patterns (which LLMs might generate) unusual_whitespace_query <- " - - SELECT id, value - - FROM test_table - + + SELECT id, value + + FROM test_table + WHERE value>25 - + " result <- execute_query(df_source, unusual_whitespace_query) expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) expect_equal(ncol(result), 2) - - # Clean up - cleanup_source(df_source) }) diff --git a/pkg-r/tests/testthat/test-test-query.R b/pkg-r/tests/testthat/test-test-query.R index ceac04e5..0df278ab 100644 --- a/pkg-r/tests/testthat/test-test-query.R +++ b/pkg-r/tests/testthat/test-test-query.R @@ -13,11 +13,13 @@ test_that("test_query.dbi_source correctly retrieves one row of data", { ) # Setup DBI source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") + withr::defer(cleanup_source(dbi_source)) # Test basic query - should only return one row result <- test_query(dbi_source, "SELECT * FROM test_table") @@ -44,16 +46,13 @@ test_that("test_query.dbi_source correctly retrieves one row of data", { result <- test_query(dbi_source, "SELECT * FROM test_table WHERE value > 100") expect_s3_class(result, "data.frame") expect_equal(nrow(result), 0) # Should return empty data frame - - # Clean up - cleanup_source(dbi_source) - unlink(temp_db) }) test_that("test_query.dbi_source handles errors correctly", { # Setup DBI source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) # Create a test table test_df <- data.frame( @@ -64,6 +63,7 @@ test_that("test_query.dbi_source handles errors correctly", { dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") + withr::defer(cleanup_source(dbi_source), priority = "last") # Test with invalid SQL expect_error(test_query(dbi_source, "SELECT * WRONG SYNTAX")) @@ -76,10 +76,6 @@ test_that("test_query.dbi_source handles errors correctly", { dbi_source, "SELECT non_existent_column FROM test_table" )) - - # Clean up - cleanup_source(dbi_source) - unlink(temp_db) }) test_that("test_query.dbi_source works with different data types", { @@ -94,11 +90,14 @@ test_that("test_query.dbi_source works with different data types", { ) # Setup DBI source - temp_db <- tempfile(fileext = ".db") + temp_db <- withr::local_tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) + withr::defer(dbDisconnect(conn)) + dbWriteTable(conn, "types_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "types_table") + withr::defer(cleanup_source(dbi_source), priority = "last") # Test query with different column types result <- test_query(dbi_source, "SELECT * FROM types_table") @@ -108,8 +107,4 @@ test_that("test_query.dbi_source works with different data types", { expect_type(result$num_col, "double") expect_type(result$int_col, "integer") expect_type(result$bool_col, "integer") # SQLite stores booleans as integers - - # Clean up - cleanup_source(dbi_source) - unlink(temp_db) }) diff --git a/pyproject.toml b/pyproject.toml index 4f530228..47c076b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,10 @@ include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] [dependency-groups] dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0"] docs = ["quartodoc>=0.11.1"] -examples = ["seaborn", "openai"] +examples = [ + "openai", + "seaborn", +] [tool.ruff] src = ["pkg-py/src/querychat"]