From 973a43301e47a89b290be2b5d6b9cdf0ba6fc253 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 3 Apr 2025 22:47:22 -0700 Subject: [PATCH 01/51] First attempt at genericizing data source --- python-package/examples/app.py | 4 +- python-package/querychat/datasource.py | 207 +++++++++++++++++++++++++ python-package/querychat/querychat.py | 162 +++++-------------- 3 files changed, 251 insertions(+), 122 deletions(-) create mode 100644 python-package/querychat/datasource.py diff --git a/python-package/examples/app.py b/python-package/examples/app.py index 926622ce..5e628f43 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app.py @@ -4,6 +4,7 @@ from shiny import App, render, ui import querychat +from querychat.datasource import DataFrameSource titanic = load_dataset("titanic") @@ -14,8 +15,7 @@ # 1. Configure querychat querychat_config = querychat.init( - titanic, - "titanic", + DataFrameSource(titanic, "titanic"), greeting=greeting, data_description=data_desc, ) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py new file mode 100644 index 00000000..495139ed --- /dev/null +++ b/python-package/querychat/datasource.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import Protocol +import pandas as pd +import duckdb +import sqlite3 +import narwhals as nw + + +class DataSource(Protocol): + def get_schema(self) -> str: + """Return schema information about the table as a string. + + Returns: + A string containing the schema information in a format suitable for + prompting an LLM about the data structure + """ + ... + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute SQL query and return results as DataFrame. + + Args: + query: SQL query to execute + + Returns: + Query results as a pandas DataFrame + """ + ... + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + ... 
+ + +class DataFrameSource: + """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + + def __init__(self, df: pd.DataFrame, table_name: str): + """Initialize with a pandas DataFrame. + + Args: + df: The DataFrame to wrap + table_name: Name of the table in SQL queries + """ + self._conn = duckdb.connect(database=":memory:") + self._df = df + self._table_name = table_name + self._conn.register(table_name, df) + + def get_schema(self, categorical_threshold: int = 10) -> str: + """Generate schema information from DataFrame. + + Args: + table_name: Name to use for the table in schema description + categorical_threshold: Maximum number of unique values for a text column + to be considered categorical + + Returns: + String describing the schema + """ + ndf = nw.from_native(self._df) + + schema = [f"Table: {self._table_name}", "Columns:"] + + for column in ndf.columns: + # Map pandas dtypes to SQL-like types + dtype = ndf[column].dtype + if dtype.is_integer(): + sql_type = "INTEGER" + elif dtype.is_float(): + sql_type = "FLOAT" + elif dtype == nw.Boolean: + sql_type = "BOOLEAN" + elif dtype == nw.Datetime: + sql_type = "TIME" + elif dtype == nw.Date: + sql_type = "DATE" + else: + sql_type = "TEXT" + + column_info = [f"- {column} ({sql_type})"] + + # For TEXT columns, check if they're categorical + if sql_type == "TEXT": + unique_values = ndf[column].drop_nulls().unique() + if unique_values.len() <= categorical_threshold: + categories = unique_values.to_list() + categories_str = ", ".join([f"'{c}'" for c in categories]) + column_info.append(f" Categorical values: {categories_str}") + + # For numeric columns, include range + elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: + rng = ndf[column].min(), ndf[column].max() + if rng[0] is None and rng[1] is None: + column_info.append(" Range: NULL to NULL") + else: + column_info.append(f" Range: {rng[0]} to {rng[1]}") + + schema.extend(column_info) + + return "\n".join(schema) + + def 
execute_query(self, query: str) -> pd.DataFrame: + """Execute query using DuckDB. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + """ + return self._conn.execute(query).df() + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + return self._df.copy() + + +class SQLiteSource: + """A DataSource implementation that wraps a SQLite connection.""" + + def __init__(self, conn: sqlite3.Connection, table_name: str): + """Initialize with a SQLite connection. + + Args: + conn: SQLite database connection + """ + self._conn = conn + self._table_name = table_name + + def get_schema(self) -> str: + """Generate schema information from SQLite table. + + Returns: + String describing the schema + """ + # Get column info + cursor = self._conn.execute(f"PRAGMA table_info({self._table_name})") + columns = cursor.fetchall() + + schema = [f"Table: {self._table_name}", "Columns:"] + + for col in columns: + # col format: (cid, name, type, notnull, dflt_value, pk) + column_info = [f"- {col[1]} ({col[2].upper()})"] + + # For numeric columns, try to get range + if col[2].upper() in ["INTEGER", "FLOAT", "REAL", "NUMERIC"]: + try: + cursor = self._conn.execute( + f"SELECT MIN({col[1]}), MAX({col[1]}) FROM {self._table_name}" + ) + min_val, max_val = cursor.fetchone() + if min_val is not None and max_val is not None: + column_info.append(f" Range: {min_val} to {max_val}") + except sqlite3.Error: + pass # Skip range info if query fails + + # For text columns, check if categorical (limited distinct values) + elif col[2].upper() == "TEXT": + try: + cursor = self._conn.execute( + f"SELECT COUNT(DISTINCT {col[1]}) FROM {self._table_name}" + ) + distinct_count = cursor.fetchone()[0] + if distinct_count <= 10: # Use fixed threshold for simplicity + cursor = self._conn.execute( + f"SELECT DISTINCT {col[1]} FROM {self._table_name} " + f"WHERE {col[1]} IS NOT NULL" 
+ ) + values = [str(row[0]) for row in cursor.fetchall()] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except sqlite3.Error: + pass # Skip categorical info if query fails + + schema.extend(column_info) + + return "\n".join(schema) + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute query using SQLite. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + """ + return pd.read_sql_query(query, self._conn) + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + return pd.read_sql_query(f"SELECT * FROM {self._table_name}", self._conn) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 4e492fb1..22b2f5ff 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -15,29 +15,49 @@ import narwhals as nw from narwhals.typing import IntoFrame +from .datasource import DataSource + + +class CreateChatCallback(Protocol): + def __call__(self, system_prompt: str) -> chatlas.Chat: ... + + +class QueryChatConfig: + """ + Configuration class for querychat. + """ + + def __init__( + self, + data_source: DataSource, + system_prompt: str, + greeting: Optional[str], + create_chat_callback: CreateChatCallback, + ): + self.data_source = data_source + self.system_prompt = system_prompt + self.greeting = greeting + self.create_chat_callback = create_chat_callback + def system_prompt( - df: IntoFrame, - table_name: str, + data_source: DataSource, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, - categorical_threshold: int = 10, ) -> str: """ - Create a system prompt for the chat model based on a data frame's + Create a system prompt for the chat model based on a data source's schema and optional additional context and instructions. 
Args: - df: A DataFrame to generate schema information from - table_name: A string containing the name of the table in SQL queries + data_source: A data source to generate schema information from data_description: Optional description of the data, in plain text or Markdown format extra_instructions: Optional additional instructions for the chat model, in plain text or Markdown format - categorical_threshold: The maximum number of unique values for a text column to be considered categorical Returns: A string containing the system prompt for the chat model """ - schema = df_to_schema(df, table_name, categorical_threshold) + schema = data_source.get_schema() # Read the prompt file prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md") @@ -65,62 +85,6 @@ def system_prompt( return prompt_text -def df_to_schema(df: IntoFrame, table_name: str, categorical_threshold: int) -> str: - """ - Convert a DataFrame schema to a string representation for the system prompt. - - Args: - df: The DataFrame to extract schema from - table_name: The name of the table in SQL queries - categorical_threshold: The maximum number of unique values for a text column to be considered categorical - - Returns: - A string containing the schema information - """ - - ndf = nw.from_native(df) - - schema = [f"Table: {table_name}", "Columns:"] - - for column in ndf.columns: - # Map pandas dtypes to SQL-like types - dtype = ndf[column].dtype - if dtype.is_integer(): - sql_type = "INTEGER" - elif dtype.is_float(): - sql_type = "FLOAT" - elif dtype == nw.Boolean: - sql_type = "BOOLEAN" - elif dtype == nw.Datetime: - sql_type = "TIME" - elif dtype == nw.Date: - sql_type = "DATE" - else: - sql_type = "TEXT" - - column_info = [f"- {column} ({sql_type})"] - - # For TEXT columns, check if they're categorical - if sql_type == "TEXT": - unique_values = ndf[column].drop_nulls().unique() - if unique_values.len() <= categorical_threshold: - categories = unique_values.to_list() - categories_str = 
", ".join([f"'{c}'" for c in categories]) - column_info.append(f" Categorical values: {categories_str}") - - # For numeric columns, include range - elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: - rng = ndf[column].min(), ndf[column].max() - if rng[0] is None and rng[1] is None: - column_info.append(" Range: NULL to NULL") - else: - column_info.append(f" Range: {rng[0]} to {rng[1]}") - - schema.extend(column_info) - - return "\n".join(schema) - - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ Convert a DataFrame to an HTML table for display in chat. @@ -149,45 +113,18 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: return table_html + rows_notice -class CreateChatCallback(Protocol): - def __call__(self, system_prompt: str) -> chatlas.Chat: ... - - -class QueryChatConfig: - """ - Configuration class for querychat. - """ - - def __init__( - self, - df: pd.DataFrame, - conn: duckdb.DuckDBPyConnection, - system_prompt: str, - greeting: Optional[str], - create_chat_callback: CreateChatCallback, - ): - self.df = df - self.conn = conn - self.system_prompt = system_prompt - self.greeting = greeting - self.create_chat_callback = create_chat_callback - - def init( - df: pd.DataFrame, - table_name: str, + data_source: DataSource, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, create_chat_callback: Optional[CreateChatCallback] = None, system_prompt_override: Optional[str] = None, ) -> QueryChatConfig: - """ - Call this once outside of any server function to initialize querychat. + """Initialize querychat with any compliant data source. 
Args: - df: A data frame - table_name: A string containing a valid table name for the data frame + data_source: A DataSource implementation that provides schema and query execution greeting: A string in Markdown format, containing the initial message data_description: Description of the data in plain text or Markdown extra_instructions: Additional instructions for the chat model @@ -197,12 +134,6 @@ def init( Returns: A QueryChatConfig object that can be passed to server() """ - # Validate table name (must begin with letter, contain only letters, numbers, underscores) - if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): - raise ValueError( - "Table name must begin with a letter and contain only letters, numbers, and underscores" - ) - # Process greeting if greeting is None: print( @@ -211,26 +142,18 @@ def init( file=sys.stderr, ) - # Create the system prompt - if system_prompt_override is None: - _system_prompt = system_prompt( - df, table_name, data_description, extra_instructions - ) - else: - _system_prompt = system_prompt_override - - # Set up DuckDB connection and register the data frame - conn = duckdb.connect(database=":memory:") - conn.register(table_name, df) + # Create the system prompt, or use the override + _system_prompt = system_prompt_override or system_prompt( + data_source, data_description, extra_instructions + ) # Default chat function if none provided create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, model="gpt-4o" + chatlas.ChatOpenAI, model="gpt-4" ) return QueryChatConfig( - df=df, - conn=conn, + data_source=data_source, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, @@ -306,8 +229,7 @@ def _(): pass # Extract config parameters - df = querychat_config.df - conn = querychat_config.conn + data_source = querychat_config.data_source system_prompt = querychat_config.system_prompt greeting = querychat_config.greeting create_chat_callback = 
querychat_config.create_chat_callback @@ -319,9 +241,9 @@ def _(): @reactive.Calc def filtered_df(): if current_query.get() == "": - return df + return data_source.get_data() else: - return conn.execute(current_query.get()).fetch_df() + return data_source.execute_query(current_query.get()) # This would handle appending messages to the chat UI async def append_output(text): @@ -345,7 +267,7 @@ async def update_dashboard(query: str, title: str): try: # Try the query to see if it errors - conn.execute(query) + data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") @@ -370,7 +292,7 @@ async def query(query: str): await append_output(f"\n```sql\n{query}\n```\n\n") try: - result_df = conn.execute(query).fetch_df() + result_df = data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") From 8de0ac71d3e687ec66151b7e977ced697f2a590a Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 4 Apr 2025 08:53:21 -0700 Subject: [PATCH 02/51] Unify prompts by adding chevron Python dependency --- python-package/pyproject.toml | 1 + python-package/querychat/prompt/prompt.md | 6 ++++ python-package/querychat/querychat.py | 39 +++++++---------------- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index c709ee05..dca3b063 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "htmltools", "chatlas", "narwhals", + "chevron", ] [project.urls] diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md index 62d1ea17..154ce0cc 100644 --- a/python-package/querychat/prompt/prompt.md +++ b/python-package/querychat/prompt/prompt.md @@ -10,7 +10,13 @@ You have at your disposal a DuckDB database containing this schema: For security reasons, you may only query this specific table. 
+{{#data_description}} +Additional helpful info about the data: + + {{data_description}} + +{{/data_description}} There are several tasks you may be asked to do: diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 22b2f5ff..37af66e1 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -1,19 +1,13 @@ from __future__ import annotations -import sys import os -import re -import pandas as pd -import duckdb -import json +import sys from functools import partial -from typing import List, Dict, Any, Callable, Optional, Union, Protocol +from typing import Any, Dict, Optional, Protocol import chatlas -from htmltools import TagList, tags, HTML -from shiny import module, reactive, ui, Inputs, Outputs, Session -import narwhals as nw -from narwhals.typing import IntoFrame +import chevron +from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource @@ -64,26 +58,15 @@ def system_prompt( with open(prompt_path, "r") as f: prompt_text = f.read() - # Simple template replacement (a more robust template engine could be used) - if data_description: - data_description_section = ( - "Additional helpful info about the data:\n\n" - "\n" - f"{data_description}\n" - "" - ) - else: - data_description_section = "" - - # Replace variables in the template - prompt_text = prompt_text.replace("{{schema}}", schema) - prompt_text = prompt_text.replace("{{data_description}}", data_description_section) - prompt_text = prompt_text.replace( - "{{extra_instructions}}", extra_instructions or "" + return chevron.render( + prompt_text, + { + "schema": schema, + "data_description": data_description, + "extra_instructions": extra_instructions, + }, ) - return prompt_text - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ From 53c7df3ddeda8b07f534a906165b04205ef83b31 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 18 Apr 2025 14:03:24 -0700 Subject: [PATCH 03/51] Make 
prompt aware of what engine is being used --- python-package/querychat/datasource.py | 13 ++++++++++--- python-package/querychat/prompt/prompt.md | 10 +++++++--- python-package/querychat/querychat.py | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index 495139ed..e408e4b0 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import Protocol -import pandas as pd -import duckdb import sqlite3 +from typing import ClassVar, Protocol + +import duckdb import narwhals as nw +import pandas as pd class DataSource(Protocol): + db_engine: ClassVar[str] + def get_schema(self) -> str: """Return schema information about the table as a string. @@ -40,6 +43,8 @@ def get_data(self) -> pd.DataFrame: class DataFrameSource: """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + db_engine: ClassVar[str] = "DuckDB" + def __init__(self, df: pd.DataFrame, table_name: str): """Initialize with a pandas DataFrame. @@ -128,6 +133,8 @@ def get_data(self) -> pd.DataFrame: class SQLiteSource: """A DataSource implementation that wraps a SQLite connection.""" + db_engine: ClassVar[str] = "SQLite" + def __init__(self, conn: sqlite3.Connection, table_name: str): """Initialize with a SQLite connection. diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md index 154ce0cc..5155ae18 100644 --- a/python-package/querychat/prompt/prompt.md +++ b/python-package/querychat/prompt/prompt.md @@ -4,7 +4,7 @@ It's important that you get clear, unambiguous instructions from the user, so if The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance. 
-You have at your disposal a DuckDB database containing this schema: +You have at your disposal a {{db_engine}} database containing this schema: {{schema}} @@ -25,7 +25,7 @@ There are several tasks you may be asked to do: The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success. * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error. -* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions. +* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database. * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`. * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed. * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't. 
@@ -80,7 +80,11 @@ Example of question answering: If the user provides a vague help request, like "Help" or "Show me instructions", describe your own capabilities in a helpful way, including examples of questions they can ask. Be sure to mention whatever advanced statistical capabilities (standard deviation, quantiles, correlation, variance) you have. -## DuckDB SQL tips +## SQL tips + +* The SQL engine is {{db_engine}}. + +* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions. * `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`. diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 37af66e1..fb0e6997 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -51,7 +51,6 @@ def system_prompt( Returns: A string containing the system prompt for the chat model """ - schema = data_source.get_schema() # Read the prompt file prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md") @@ -61,7 +60,8 @@ def system_prompt( return chevron.render( prompt_text, { - "schema": schema, + "db_engine": data_source.db_engine, + "schema": data_source.get_schema(), "data_description": data_description, "extra_instructions": extra_instructions, }, From a2122f22da9233ce6edc3ece5ac12440e5a35f63 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 18 Apr 2025 14:37:45 -0700 Subject: [PATCH 04/51] Replace SQLite support with SQLAlchemy support --- python-package/pyproject.toml | 5 + python-package/querychat/datasource.py | 133 
++++++++++++++++++------- 2 files changed, 100 insertions(+), 38 deletions(-) diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index dca3b063..4ca437a2 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -23,6 +23,11 @@ dependencies = [ "chevron", ] +[project.optional-dependencies] +sqlalchemy = [ + "sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API +] + [project.urls] Homepage = "https://github.com/posit-dev/querychat" Issues = "https://github.com/posit-dev/querychat/issues" diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index e408e4b0..e33711e7 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -1,11 +1,13 @@ from __future__ import annotations -import sqlite3 from typing import ClassVar, Protocol import duckdb import narwhals as nw import pandas as pd +from sqlalchemy import inspect, text +from sqlalchemy.engine import Engine, Connection +from sqlalchemy.sql import sqltypes class DataSource(Protocol): @@ -130,64 +132,93 @@ def get_data(self) -> pd.DataFrame: return self._df.copy() -class SQLiteSource: - """A DataSource implementation that wraps a SQLite connection.""" +class SQLAlchemySource: + """A DataSource implementation that supports multiple SQL databases via SQLAlchemy. - db_engine: ClassVar[str] = "SQLite" + Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks. + """ - def __init__(self, conn: sqlite3.Connection, table_name: str): - """Initialize with a SQLite connection. + db_engine: ClassVar[str] = "SQLAlchemy" + + def __init__(self, engine: Engine, table_name: str): + """Initialize with a SQLAlchemy engine. 
Args: - conn: SQLite database connection + engine: SQLAlchemy engine + table_name: Name of the table to query """ - self._conn = conn + self._engine = engine self._table_name = table_name + # Validate table exists + inspector = inspect(self._engine) + if table_name not in inspector.get_table_names(): + raise ValueError(f"Table '{table_name}' not found in database") + def get_schema(self) -> str: - """Generate schema information from SQLite table. + """Generate schema information from database table. Returns: String describing the schema """ - # Get column info - cursor = self._conn.execute(f"PRAGMA table_info({self._table_name})") - columns = cursor.fetchall() + inspector = inspect(self._engine) + columns = inspector.get_columns(self._table_name) schema = [f"Table: {self._table_name}", "Columns:"] for col in columns: - # col format: (cid, name, type, notnull, dflt_value, pk) - column_info = [f"- {col[1]} ({col[2].upper()})"] + # Get SQL type name + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col['name']} ({sql_type})"] # For numeric columns, try to get range - if col[2].upper() in ["INTEGER", "FLOAT", "REAL", "NUMERIC"]: + if isinstance( + col["type"], + ( + sqltypes.Integer, + sqltypes.Numeric, + sqltypes.Float, + sqltypes.Date, + sqltypes.Time, + sqltypes.DateTime, + sqltypes.BigInteger, + sqltypes.SmallInteger, + # sqltypes.Interval, + ), + ): try: - cursor = self._conn.execute( - f"SELECT MIN({col[1]}), MAX({col[1]}) FROM {self._table_name}" + query = text( + f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}" ) - min_val, max_val = cursor.fetchone() - if min_val is not None and max_val is not None: - column_info.append(f" Range: {min_val} to {max_val}") - except sqlite3.Error: + with self._get_connection() as conn: + result = conn.execute(query).fetchone() + if result and result[0] is not None and result[1] is not None: + column_info.append(f" Range: {result[0]} to {result[1]}") + except Exception: pass # Skip 
range info if query fails - # For text columns, check if categorical (limited distinct values) - elif col[2].upper() == "TEXT": + # For string/text columns, check if categorical + elif isinstance( + col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum) + ): try: - cursor = self._conn.execute( - f"SELECT COUNT(DISTINCT {col[1]}) FROM {self._table_name}" + count_query = text( + f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}" ) - distinct_count = cursor.fetchone()[0] - if distinct_count <= 10: # Use fixed threshold for simplicity - cursor = self._conn.execute( - f"SELECT DISTINCT {col[1]} FROM {self._table_name} " - f"WHERE {col[1]} IS NOT NULL" - ) - values = [str(row[0]) for row in cursor.fetchall()] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except sqlite3.Error: + with self._get_connection() as conn: + distinct_count = conn.execute(count_query).scalar() + if distinct_count and distinct_count <= 10: + values_query = text( + f"SELECT DISTINCT {col['name']} FROM {self._table_name} " + f"WHERE {col['name']} IS NOT NULL" + ) + values = [ + str(row[0]) + for row in conn.execute(values_query).fetchall() + ] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except Exception: pass # Skip categorical info if query fails schema.extend(column_info) @@ -195,7 +226,7 @@ def get_schema(self) -> str: return "\n".join(schema) def execute_query(self, query: str) -> pd.DataFrame: - """Execute query using SQLite. + """Execute SQL query and return results as DataFrame. Args: query: SQL query to execute @@ -203,7 +234,8 @@ def execute_query(self, query: str) -> pd.DataFrame: Returns: Query results as pandas DataFrame """ - return pd.read_sql_query(query, self._conn) + with self._get_connection() as conn: + return pd.read_sql_query(text(query), conn) def get_data(self) -> pd.DataFrame: """Return the unfiltered data as a DataFrame. 
@@ -211,4 +243,29 @@ def get_data(self) -> pd.DataFrame: Returns: The complete dataset as a pandas DataFrame """ - return pd.read_sql_query(f"SELECT * FROM {self._table_name}", self._conn) + return self.execute_query(f"SELECT * FROM {self._table_name}") + + def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: + """Convert SQLAlchemy type to SQL type name.""" + if isinstance(type_, sqltypes.Integer): + return "INTEGER" + elif isinstance(type_, sqltypes.Float): + return "FLOAT" + elif isinstance(type_, sqltypes.Numeric): + return "NUMERIC" + elif isinstance(type_, sqltypes.Boolean): + return "BOOLEAN" + elif isinstance(type_, sqltypes.DateTime): + return "TIMESTAMP" + elif isinstance(type_, sqltypes.Date): + return "DATE" + elif isinstance(type_, sqltypes.Time): + return "TIME" + elif isinstance(type_, (sqltypes.String, sqltypes.Text)): + return "TEXT" + else: + return type_.__class__.__name__.upper() + + def _get_connection(self) -> Connection: + """Get a connection to use for queries.""" + return self._engine.connect() From a218fb914963a4477598c8f4d0081bae043de286 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Wed, 23 Apr 2025 16:26:58 -0700 Subject: [PATCH 05/51] Don't fail when given table name's case differs from SQLAlchemy Inspector --- python-package/querychat/datasource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index e33711e7..1fee9b9c 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -6,7 +6,7 @@ import narwhals as nw import pandas as pd from sqlalchemy import inspect, text -from sqlalchemy.engine import Engine, Connection +from sqlalchemy.engine import Connection, Engine from sqlalchemy.sql import sqltypes @@ -152,7 +152,7 @@ def __init__(self, engine: Engine, table_name: str): # Validate table exists inspector = inspect(self._engine) - if table_name not in inspector.get_table_names(): + 
if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") def get_schema(self) -> str: From dc0814ef6a68575d0bb9624f43596507d769f4e3 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 1 May 2025 16:58:29 -0400 Subject: [PATCH 06/51] Forgot import --- python-package/querychat/querychat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index fb0e6997..ed558362 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -7,6 +7,7 @@ import chatlas import chevron +import narwhals as nw from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource From 9d95d1d0f47db306c3a422d913cfbcf8c6e0d244 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:12:35 -0700 Subject: [PATCH 07/51] Have server() return proper class with typed methods, instead of dict --- .gitignore | 3 +- python-package/examples/app-database.py | 55 +++++++++ .../examples/{app.py => app-dataframe.py} | 7 +- python-package/querychat/querychat.py | 104 ++++++++++++++++-- 4 files changed, 154 insertions(+), 15 deletions(-) create mode 100644 python-package/examples/app-database.py rename python-package/examples/{app.py => app-dataframe.py} (97%) diff --git a/.gitignore b/.gitignore index 98ab2295..32d0462b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ animation.screenflow/ README_files/ -README.html \ No newline at end of file +README.html +.DS_Store \ No newline at end of file diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py new file mode 100644 index 00000000..cfee136e --- /dev/null +++ b/python-package/examples/app-database.py @@ -0,0 +1,55 @@ +import sqlite3 +from pathlib import Path + +import querychat +from querychat.datasource import SQLAlchemySource +from seaborn import load_dataset +from shiny import App, render, ui 
+from sqlalchemy import create_engine + +# Load titanic data and create SQLite database +db_path = Path(__file__).parent / "titanic.db" +engine = create_engine("sqlite:///" + str(db_path)) +# titanic = load_dataset("titanic") +# titanic.to_sql("titanic", conn, if_exists="replace", index=False) + +with open(Path(__file__).parent / "greeting.md", "r") as f: + greeting = f.read() +with open(Path(__file__).parent / "data_description.md", "r") as f: + data_desc = f.read() + +# 1. Configure querychat +querychat_config = querychat.init( + SQLAlchemySource(engine, "titanic"), + greeting=greeting, + data_description=data_desc, +) + +# Create UI +app_ui = ui.page_sidebar( + # 2. Place the chat component in the sidebar + querychat.sidebar("chat"), + # Main panel with data viewer + ui.card( + ui.output_data_frame("data_table"), + fill=True, + ), + title="querychat with Python (SQLite)", + fillable=True, +) + + +# Define server logic +def server(input, output, session): + # 3. Initialize querychat server with the config from step 1 + chat = querychat.server("chat", querychat_config) + + # 4. 
Display the filtered dataframe + @render.data_frame + def data_table(): + # Access filtered data via chat.df() reactive + return chat["df"]() + + +# Create Shiny app +app = App(app_ui, server) diff --git a/python-package/examples/app.py b/python-package/examples/app-dataframe.py similarity index 97% rename from python-package/examples/app.py rename to python-package/examples/app-dataframe.py index 5e628f43..13d224fb 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app-dataframe.py @@ -1,10 +1,9 @@ from pathlib import Path -from seaborn import load_dataset -from shiny import App, render, ui - import querychat from querychat.datasource import DataFrameSource +from seaborn import load_dataset +from shiny import App, render, ui titanic = load_dataset("titanic") @@ -43,7 +42,7 @@ def server(input, output, session): @render.data_frame def data_table(): # Access filtered data via chat.df() reactive - return chat["df"]() + return chat.df() # Create Shiny app diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index ed558362..093dec16 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -3,11 +3,13 @@ import os import sys from functools import partial -from typing import Any, Dict, Optional, Protocol +from typing import Any, Callable, Optional, Protocol import chatlas import chevron import narwhals as nw +import pandas as pd +from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource @@ -35,6 +37,93 @@ def __init__( self.create_chat_callback = create_chat_callback +class QueryChat: + """ + An object representing a query chat session. This is created within a Shiny + server function or Shiny module server function by using + `querychat.server()`. Use this object to bridge the chat interface with the + rest of the Shiny app, for example, by displaying the filtered data. 
+ """ + + def __init__( + self, + chat: chatlas.Chat, + sql: Callable[[], str], + title: Callable[[], str | None], + df: Callable[[], pd.DataFrame], + ): + """ + Initialize a QueryChat object. + + Args: + chat: The chat object for the session + sql: Reactive that returns the current SQL query + title: Reactive that returns the current title + df: Reactive that returns the filtered data frame + """ + self._chat = chat + self._sql = sql + self._title = title + self._df = df + + def chat(self) -> chatlas.Chat: + """ + Get the chat object for this session. + + Returns: + The chat object + """ + return self._chat() + + def sql(self) -> str: + """ + Reactively read the current SQL query that is in effect. + + Returns: + The current SQL query as a string, or `""` if no query has been set. + """ + return self._sql() + + def title(self) -> str | None: + """ + Reactively read the current title that is in effect. The title is a + short description of the current query that the LLM provides to us + whenever it generates a new SQL query. It can be used as a status string + for the data dashboard. + + Returns: + The current title as a string, or `None` if no title has been set + due to no SQL query being set. + """ + return self._title() + + def df(self) -> pd.DataFrame: + """ + Reactively read the current filtered data frame that is in effect. + + Returns: + The current filtered data frame as a pandas DataFrame. If no query + has been set, this will return the unfiltered data frame from the + data source. + """ + return self._df() + + def __getitem__(self, key: str) -> Any: + """ + Allow access to configuration parameters like a dictionary. For + backwards compatibility only; new code should use the attributes + directly instead. 
+ """ + if key == "chat": + return self.chat + elif key == "sql": + return self.sql + elif key == "title": + return self.title + elif key == "df": + return self.df + + def system_prompt( data_source: DataSource, data_description: Optional[str] = None, @@ -190,7 +279,7 @@ def sidebar(id: str, width: int = 400, height: str = "100%", **kwargs) -> ui.Sid @module.server def server( input: Inputs, output: Outputs, session: Session, querychat_config: QueryChatConfig -) -> Dict[str, Any]: +) -> QueryChat: """ Initialize the querychat server. @@ -219,8 +308,8 @@ def _(): create_chat_callback = querychat_config.create_chat_callback # Reactive values to store state - current_title = reactive.Value(None) - current_query = reactive.Value("") + current_title: reactive.Value[str | None] = reactive.Value(None) + current_query: reactive.Value[str] = reactive.Value("") @reactive.Calc def filtered_df(): @@ -326,9 +415,4 @@ async def greet_on_startup(): await chat_ui.append_message_stream(stream) # Return the interface for other components to use - return { - "chat": chat, - "sql": current_query.get, - "title": current_title.get, - "df": filtered_df, - } + return QueryChat(chat, current_query.get, current_title.get, filtered_df) From aeb87dd060fbafb1c973d94c7041ab20ccf71dd8 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:17:43 -0700 Subject: [PATCH 08/51] Auto-create sqlite database for example --- .gitignore | 3 ++- python-package/examples/app-database.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 32d0462b..1639e057 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__/ animation.screenflow/ README_files/ README.html -.DS_Store \ No newline at end of file +.DS_Store +python-package/examples/titanic.db diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py index cfee136e..c196b3e7 100644 --- a/python-package/examples/app-database.py +++ 
b/python-package/examples/app-database.py @@ -1,4 +1,3 @@ -import sqlite3 from pathlib import Path import querychat @@ -10,8 +9,12 @@ # Load titanic data and create SQLite database db_path = Path(__file__).parent / "titanic.db" engine = create_engine("sqlite:///" + str(db_path)) -# titanic = load_dataset("titanic") -# titanic.to_sql("titanic", conn, if_exists="replace", index=False) + +if not db_path.exists(): + # For example purposes, we'll create the database if it doesn't exist. Don't + # do this in your app! + titanic = load_dataset("titanic") + titanic.to_sql("titanic", engine, if_exists="replace", index=False) with open(Path(__file__).parent / "greeting.md", "r") as f: greeting = f.read() From c38b567189b73dee742715c5983ae32d57adc6c1 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:38:25 -0700 Subject: [PATCH 09/51] Have init() take data frame or sqlalchemy engine directly ...instead of requiring explicit DataSource subclass creation --- python-package/examples/app-database.py | 4 ++-- python-package/examples/app-dataframe.py | 4 ++-- python-package/querychat/querychat.py | 19 ++++++++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py index c196b3e7..9769cc17 100644 --- a/python-package/examples/app-database.py +++ b/python-package/examples/app-database.py @@ -1,7 +1,6 @@ from pathlib import Path import querychat -from querychat.datasource import SQLAlchemySource from seaborn import load_dataset from shiny import App, render, ui from sqlalchemy import create_engine @@ -23,7 +22,8 @@ # 1. 
Configure querychat querychat_config = querychat.init( - SQLAlchemySource(engine, "titanic"), + engine, + "titanic", greeting=greeting, data_description=data_desc, ) diff --git a/python-package/examples/app-dataframe.py b/python-package/examples/app-dataframe.py index 13d224fb..1a1fd858 100644 --- a/python-package/examples/app-dataframe.py +++ b/python-package/examples/app-dataframe.py @@ -1,7 +1,6 @@ from pathlib import Path import querychat -from querychat.datasource import DataFrameSource from seaborn import load_dataset from shiny import App, render, ui @@ -14,7 +13,8 @@ # 1. Configure querychat querychat_config = querychat.init( - DataFrameSource(titanic, "titanic"), + titanic, + "titanic", greeting=greeting, data_description=data_desc, ) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 093dec16..aec6bba7 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -9,10 +9,11 @@ import chevron import narwhals as nw import pandas as pd +import sqlalchemy from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui -from .datasource import DataSource +from .datasource import DataFrameSource, DataSource, SQLAlchemySource class CreateChatCallback(Protocol): @@ -73,7 +74,7 @@ def chat(self) -> chatlas.Chat: Returns: The chat object """ - return self._chat() + return self._chat def sql(self) -> str: """ @@ -187,7 +188,8 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: def init( - data_source: DataSource, + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, @@ -207,6 +209,13 @@ def init( Returns: A QueryChatConfig object that can be passed to server() """ + + data_source_obj: DataSource + if isinstance(data_source, sqlalchemy.Engine): + data_source_obj = SQLAlchemySource(data_source, table_name) + else: + 
data_source_obj = DataFrameSource(nw.from_native(data_source).to_pandas(), table_name) + # Process greeting if greeting is None: print( @@ -217,7 +226,7 @@ def init( # Create the system prompt, or use the override _system_prompt = system_prompt_override or system_prompt( - data_source, data_description, extra_instructions + data_source_obj, data_description, extra_instructions ) # Default chat function if none provided @@ -226,7 +235,7 @@ def init( ) return QueryChatConfig( - data_source=data_source, + data_source=data_source_obj, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, From 57922b3fe2eeda722f28ca35b60543a9d4223d15 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 17:11:26 -0700 Subject: [PATCH 10/51] Use GPT-4.1 by default, not GPT-4, yuck --- python-package/querychat/querychat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index ed558362..167560c5 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -133,7 +133,7 @@ def init( # Default chat function if none provided create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, model="gpt-4" + chatlas.ChatOpenAI, model="gpt-4.1" ) return QueryChatConfig( From a08764bf130895a895fdff7c2d535ef40855f156 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 17:23:12 -0700 Subject: [PATCH 11/51] Update README --- python-package/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python-package/README.md b/python-package/README.md index be8057ea..9b29fb19 100644 --- a/python-package/README.md +++ b/python-package/README.md @@ -56,7 +56,7 @@ def server(input, output, session): # chat["df"]() reactive. 
@render.data_frame def data_table(): - return chat["df"]() + return chat.df() # Create Shiny app @@ -171,8 +171,8 @@ which you can then pass via: ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", data_description=Path("data_description.md").read_text() ) ``` @@ -185,8 +185,8 @@ You can add additional instructions of your own to the end of the system prompt, ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", extra_instructions=[ "You're speaking to a British audience--please use appropriate spelling conventions.", "Use lots of emojis! πŸ˜ƒ Emojis everywhere, 🌍 emojis forever. ♾️", @@ -218,8 +218,8 @@ def my_chat_func(system_prompt: str) -> chatlas.Chat: my_chat_func = partial(chatlas.ChatAnthropic, model="claude-3-7-sonnet-latest") querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", create_chat_callback=my_chat_func ) ``` From 374bdfb7f5631245f5fd7a36f38fc20f3c9475b5 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 6 Jun 2025 10:47:13 -0700 Subject: [PATCH 12/51] this should significantly speed up schema generation --- python-package/pyproject.toml | 7 +- python-package/src/querychat/datasource.py | 118 +++++++++---- python-package/tests/__init__.py | 0 python-package/tests/test_datasource.py | 194 +++++++++++++++++++++ 4 files changed, 283 insertions(+), 36 deletions(-) create mode 100644 python-package/tests/__init__.py create mode 100644 python-package/tests/test_datasource.py diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 1ac303bb..7fbfe145 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -43,7 +43,12 @@ packages = ["src/querychat"] include = ["src/querychat", "LICENSE", "README.md"] [tool.uv] -dev-dependencies = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"] +dev-dependencies = [ + "ruff>=0.6.5", + "pyright>=1.1.401", + 
"tox-uv>=1.11.4", + "pytest>=8.4.0", +] [tool.ruff] src = ["src/querychat"] diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py index d9322ff4..7be839bc 100644 --- a/python-package/src/querychat/datasource.py +++ b/python-package/src/querychat/datasource.py @@ -189,12 +189,15 @@ def get_schema(self, *, categorical_threshold: int) -> str: schema = [f"Table: {self._table_name}", "Columns:"] + # Build a single query to get all column statistics + select_parts = [] + numeric_columns = [] + text_columns = [] + for col in columns: - # Get SQL type name - sql_type = self._get_sql_type_name(col["type"]) - column_info = [f"- {col['name']} ({sql_type})"] - - # For numeric columns, try to get range + col_name = col['name'] + + # Check if column is numeric if isinstance( col["type"], ( @@ -206,44 +209,89 @@ def get_schema(self, *, categorical_threshold: int) -> str: sqltypes.DateTime, sqltypes.BigInteger, sqltypes.SmallInteger, - # sqltypes.Interval, ), ): - try: - query = text( - f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}", - ) - with self._get_connection() as conn: - result = conn.execute(query).fetchone() - if result and result[0] is not None and result[1] is not None: - column_info.append(f" Range: {result[0]} to {result[1]}") - except Exception: - pass # Skip range info if query fails - - # For string/text columns, check if categorical + numeric_columns.append(col_name) + select_parts.extend([ + f"MIN({col_name}) as {col_name}_min", + f"MAX({col_name}) as {col_name}_max" + ]) + + # Check if column is text/string elif isinstance( col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum), ): - try: - count_query = text( - f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}", + text_columns.append(col_name) + select_parts.append(f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count") + + # Execute single query to get all statistics + column_stats = {} + if select_parts: + try: + 
stats_query = text(f"SELECT {', '.join(select_parts)} FROM {self._table_name}") + with self._get_connection() as conn: + result = conn.execute(stats_query).fetchone() + if result: + # Convert result to dict for easier access + column_stats = dict(zip(result._fields, result)) + except Exception: + pass # Fall back to no statistics if query fails + + # Get categorical values for text columns that are below threshold + categorical_values = {} + text_cols_to_query = [] + for col_name in text_columns: + distinct_count_key = f"{col_name}_distinct_count" + if (distinct_count_key in column_stats and + column_stats[distinct_count_key] and + column_stats[distinct_count_key] <= categorical_threshold): + text_cols_to_query.append(col_name) + + # Get categorical values in a single query if needed + if text_cols_to_query: + try: + # Build UNION query for all categorical columns + union_parts = [] + for col_name in text_cols_to_query: + union_parts.append( + f"SELECT '{col_name}' as column_name, {col_name} as value " + f"FROM {self._table_name} WHERE {col_name} IS NOT NULL" ) + + if union_parts: + categorical_query = text(" UNION ALL ".join(union_parts)) with self._get_connection() as conn: - distinct_count = conn.execute(count_query).scalar() - if distinct_count and distinct_count <= categorical_threshold: - values_query = text( - f"SELECT DISTINCT {col['name']} FROM {self._table_name} " - f"WHERE {col['name']} IS NOT NULL", - ) - values = [ - str(row[0]) - for row in conn.execute(values_query).fetchall() - ] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except Exception: - pass # Skip categorical info if query fails + results = conn.execute(categorical_query).fetchall() + for row in results: + col_name, value = row + if col_name not in categorical_values: + categorical_values[col_name] = [] + categorical_values[col_name].append(str(value)) + except Exception: + pass # Skip categorical values if query fails + + 
# Build schema description using collected statistics + for col in columns: + col_name = col['name'] + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col_name} ({sql_type})"] + + # Add range info for numeric columns + if col_name in numeric_columns: + min_key = f"{col_name}_min" + max_key = f"{col_name}_max" + if (min_key in column_stats and max_key in column_stats and + column_stats[min_key] is not None and column_stats[max_key] is not None): + column_info.append(f" Range: {column_stats[min_key]} to {column_stats[max_key]}") + + # Add categorical values for text columns + elif col_name in categorical_values: + values = categorical_values[col_name] + # Remove duplicates and sort + unique_values = sorted(set(values)) + values_str = ", ".join([f"'{v}'" for v in unique_values]) + column_info.append(f" Categorical values: {values_str}") schema.extend(column_info) diff --git a/python-package/tests/__init__.py b/python-package/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python-package/tests/test_datasource.py b/python-package/tests/test_datasource.py new file mode 100644 index 00000000..ca5395c2 --- /dev/null +++ b/python-package/tests/test_datasource.py @@ -0,0 +1,194 @@ +import sqlite3 +import tempfile +from pathlib import Path + +import pytest +from sqlalchemy import create_engine + +from src.querychat.datasource import SQLAlchemySource + + +@pytest.fixture +def test_db_engine(): + """Create a temporary SQLite database with test data.""" + # Create temporary database file + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db.close() + + # Connect and create test table with various data types + conn = sqlite3.connect(temp_db.name) + cursor = conn.cursor() + + # Create table with different column types + cursor.execute(""" + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER, + salary REAL, + is_active BOOLEAN, + join_date DATE, + category TEXT, + score NUMERIC, + 
description TEXT + ) + """) + + # Insert test data + test_data = [ + (1, "Alice", 30, 75000.50, True, "2023-01-15", "A", 95.5, "Senior developer"), + (2, "Bob", 25, 60000.00, True, "2023-03-20", "B", 87.2, "Junior developer"), + (3, "Charlie", 35, 85000.75, False, "2022-12-01", "A", 92.1, "Team lead"), + (4, "Diana", 28, 70000.25, True, "2023-05-10", "C", 89.8, "Mid-level developer"), + (5, "Eve", 32, 80000.00, True, "2023-02-28", "A", 91.3, "Senior developer"), + (6, "Frank", 26, 62000.50, False, "2023-04-15", "B", 85.7, "Junior developer"), + (7, "Grace", 29, 72000.75, True, "2023-01-30", "A", 93.4, "Developer"), + (8, "Henry", 31, 78000.25, True, "2023-03-05", "C", 88.9, "Senior developer"), + ] + + cursor.executemany(""" + INSERT INTO test_table + (id, name, age, salary, is_active, join_date, category, score, description) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, test_data) + + conn.commit() + conn.close() + + # Create SQLAlchemy engine + engine = create_engine(f"sqlite:///{temp_db.name}") + + yield engine + + # Cleanup + Path(temp_db.name).unlink() + + +def test_get_schema_numeric_ranges(test_db_engine): + """Test that numeric columns include min/max ranges.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Check that numeric columns have range information + assert "- id (INTEGER)" in schema + assert "Range: 1 to 8" in schema + + assert "- age (INTEGER)" in schema + assert "Range: 25 to 35" in schema + + assert "- salary (FLOAT)" in schema + assert "Range: 60000.0 to 85000.75" in schema + + assert "- score (NUMERIC)" in schema + assert "Range: 85.7 to 95.5" in schema + + +def test_get_schema_categorical_values(test_db_engine): + """Test that text columns with few unique values show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Category column should be treated as categorical (3 unique values: A, 
B, C) + assert "- category (TEXT)" in schema + assert "Categorical values:" in schema + assert "'A'" in schema and "'B'" in schema and "'C'" in schema + + +def test_get_schema_non_categorical_text(test_db_engine): + """Test that text columns with many unique values don't show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=3) + + # Name and description columns should not be categorical (8 and 6 unique values respectively) + lines = schema.split('\n') + name_line_idx = next(i for i, line in enumerate(lines) if "- name (TEXT)" in line) + description_line_idx = next(i for i, line in enumerate(lines) if "- description (TEXT)" in line) + + # Check that the next line after name column doesn't contain categorical values + if name_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[name_line_idx + 1] + + # Check that the next line after description column doesn't contain categorical values + if description_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[description_line_idx + 1] + + +def test_get_schema_different_thresholds(test_db_engine): + """Test that categorical_threshold parameter works correctly.""" + source = SQLAlchemySource(test_db_engine, "test_table") + + # With threshold 2, only category column (3 unique) should not be categorical + schema_low = source.get_schema(categorical_threshold=2) + assert "- category (TEXT)" in schema_low + assert "'A'" not in schema_low # Should not show categorical values + + # With threshold 5, category column should be categorical + schema_high = source.get_schema(categorical_threshold=5) + assert "- category (TEXT)" in schema_high + assert "'A'" in schema_high # Should show categorical values + + +def test_get_schema_table_structure(test_db_engine): + """Test the overall structure of the schema output.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = 
source.get_schema(categorical_threshold=5) + + lines = schema.split('\n') + + # Check header + assert lines[0] == "Table: test_table" + assert lines[1] == "Columns:" + + # Check that all columns are present + expected_columns = ["id", "name", "age", "salary", "is_active", "join_date", "category", "score", "description"] + for col in expected_columns: + assert any(f"- {col} (" in line for line in lines), f"Column {col} not found in schema" + + +def test_get_schema_empty_result_handling(test_db_engine): + """Test handling when statistics queries return empty results.""" + # Create empty table + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + cursor.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)") + conn.commit() + + engine = create_engine("sqlite:///:memory:") + # Recreate table in the new engine + with engine.connect() as connection: + from sqlalchemy import text + connection.execute(text("CREATE TABLE empty_table (id INTEGER, name TEXT)")) + connection.commit() + + source = SQLAlchemySource(engine, "empty_table") + schema = source.get_schema(categorical_threshold=5) + + # Should still work but without range/categorical info + assert "Table: empty_table" in schema + assert "- id (INTEGER)" in schema + assert "- name (TEXT)" in schema + # Should not have range or categorical information + assert "Range:" not in schema + assert "Categorical values:" not in schema + + +def test_get_schema_boolean_and_date_types(test_db_engine): + """Test handling of boolean and date column types.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Boolean column should show range + assert "- is_active (BOOLEAN)" in schema + # SQLite stores booleans as integers, so should show 0 to 1 range + + # Date column should show range + assert "- join_date (DATE)" in schema + assert "Range:" in schema + + +def test_invalid_table_name(): + """Test that invalid table name raises appropriate error.""" + 
engine = create_engine("sqlite:///:memory:") + + with pytest.raises(ValueError, match="Table 'nonexistent' not found in database"): + SQLAlchemySource(engine, "nonexistent") \ No newline at end of file From e294b1b39c6061ec2049ea4b1b0697c6d00d9eac Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 6 Jun 2025 11:56:08 -0700 Subject: [PATCH 13/51] another speedup --- python-package/src/querychat/datasource.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py index 7be839bc..99f0a096 100644 --- a/python-package/src/querychat/datasource.py +++ b/python-package/src/querychat/datasource.py @@ -256,7 +256,8 @@ def get_schema(self, *, categorical_threshold: int) -> str: for col_name in text_cols_to_query: union_parts.append( f"SELECT '{col_name}' as column_name, {col_name} as value " - f"FROM {self._table_name} WHERE {col_name} IS NOT NULL" + f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " + f"GROUP BY {col_name}" ) if union_parts: From b179ea699d664d55a040a3c7c3b698ed28057202 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 6 Jun 2025 12:20:57 -0700 Subject: [PATCH 14/51] ruff formatting --- python-package/src/querychat/datasource.py | 52 ++++++++++++++-------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py index 99f0a096..bc60e6c0 100644 --- a/python-package/src/querychat/datasource.py +++ b/python-package/src/querychat/datasource.py @@ -193,10 +193,10 @@ def get_schema(self, *, categorical_threshold: int) -> str: select_parts = [] numeric_columns = [] text_columns = [] - + for col in columns: - col_name = col['name'] - + col_name = col["name"] + # Check if column is numeric if isinstance( col["type"], @@ -212,24 +212,30 @@ def get_schema(self, *, categorical_threshold: int) -> str: ), ): numeric_columns.append(col_name) - select_parts.extend([ - 
f"MIN({col_name}) as {col_name}_min", - f"MAX({col_name}) as {col_name}_max" - ]) - + select_parts.extend( + [ + f"MIN({col_name}) as {col_name}_min", + f"MAX({col_name}) as {col_name}_max", + ] + ) + # Check if column is text/string elif isinstance( col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum), ): text_columns.append(col_name) - select_parts.append(f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count") + select_parts.append( + f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count" + ) # Execute single query to get all statistics column_stats = {} if select_parts: try: - stats_query = text(f"SELECT {', '.join(select_parts)} FROM {self._table_name}") + stats_query = text( + f"SELECT {', '.join(select_parts)} FROM {self._table_name}" + ) with self._get_connection() as conn: result = conn.execute(stats_query).fetchone() if result: @@ -243,11 +249,13 @@ def get_schema(self, *, categorical_threshold: int) -> str: text_cols_to_query = [] for col_name in text_columns: distinct_count_key = f"{col_name}_distinct_count" - if (distinct_count_key in column_stats and - column_stats[distinct_count_key] and - column_stats[distinct_count_key] <= categorical_threshold): + if ( + distinct_count_key in column_stats + and column_stats[distinct_count_key] + and column_stats[distinct_count_key] <= categorical_threshold + ): text_cols_to_query.append(col_name) - + # Get categorical values in a single query if needed if text_cols_to_query: try: @@ -259,7 +267,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " f"GROUP BY {col_name}" ) - + if union_parts: categorical_query = text(" UNION ALL ".join(union_parts)) with self._get_connection() as conn: @@ -274,7 +282,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: # Build schema description using collected statistics for col in columns: - col_name = col['name'] + col_name = col["name"] sql_type = self._get_sql_type_name(col["type"]) 
column_info = [f"- {col_name} ({sql_type})"] @@ -282,9 +290,15 @@ def get_schema(self, *, categorical_threshold: int) -> str: if col_name in numeric_columns: min_key = f"{col_name}_min" max_key = f"{col_name}_max" - if (min_key in column_stats and max_key in column_stats and - column_stats[min_key] is not None and column_stats[max_key] is not None): - column_info.append(f" Range: {column_stats[min_key]} to {column_stats[max_key]}") + if ( + min_key in column_stats + and max_key in column_stats + and column_stats[min_key] is not None + and column_stats[max_key] is not None + ): + column_info.append( + f" Range: {column_stats[min_key]} to {column_stats[max_key]}" + ) # Add categorical values for text columns elif col_name in categorical_values: From 2cbe19952868cc4d3f47c4ba7c22aad734616361 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 6 Jun 2025 12:32:45 -0700 Subject: [PATCH 15/51] updating so formatting checks pass --- python-package/src/querychat/datasource.py | 37 +++++++++++----------- python-package/src/querychat/querychat.py | 14 ++++---- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py index bc60e6c0..c3c00390 100644 --- a/python-package/src/querychat/datasource.py +++ b/python-package/src/querychat/datasource.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import ClassVar, Protocol +from typing import TYPE_CHECKING, ClassVar, Protocol import duckdb import narwhals as nw import pandas as pd from sqlalchemy import inspect, text -from sqlalchemy.engine import Connection, Engine from sqlalchemy.sql import sqltypes +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + class DataSource(Protocol): db_engine: ClassVar[str] @@ -176,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str): if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") - def 
get_schema(self, *, categorical_threshold: int) -> str: + def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 """ Generate schema information from database table. @@ -216,7 +218,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: [ f"MIN({col_name}) as {col_name}_min", f"MAX({col_name}) as {col_name}_max", - ] + ], ) # Check if column is text/string @@ -226,7 +228,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: ): text_columns.append(col_name) select_parts.append( - f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count" + f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count", ) # Execute single query to get all statistics @@ -234,14 +236,14 @@ def get_schema(self, *, categorical_threshold: int) -> str: if select_parts: try: stats_query = text( - f"SELECT {', '.join(select_parts)} FROM {self._table_name}" + f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608 ) with self._get_connection() as conn: result = conn.execute(stats_query).fetchone() if result: # Convert result to dict for easier access column_stats = dict(zip(result._fields, result)) - except Exception: + except Exception: # noqa: S110 pass # Fall back to no statistics if query fails # Get categorical values for text columns that are below threshold @@ -260,13 +262,12 @@ def get_schema(self, *, categorical_threshold: int) -> str: if text_cols_to_query: try: # Build UNION query for all categorical columns - union_parts = [] - for col_name in text_cols_to_query: - union_parts.append( - f"SELECT '{col_name}' as column_name, {col_name} as value " - f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " - f"GROUP BY {col_name}" - ) + union_parts = [ + f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608 + f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " + f"GROUP BY {col_name}" + for col_name in text_cols_to_query + ] if union_parts: categorical_query = text(" UNION ALL ".join(union_parts)) @@ 
-277,7 +278,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: if col_name not in categorical_values: categorical_values[col_name] = [] categorical_values[col_name].append(str(value)) - except Exception: + except Exception: # noqa: S110 pass # Skip categorical values if query fails # Build schema description using collected statistics @@ -297,7 +298,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: and column_stats[max_key] is not None ): column_info.append( - f" Range: {column_stats[min_key]} to {column_stats[max_key]}" + f" Range: {column_stats[min_key]} to {column_stats[max_key]}", ) # Add categorical values for text columns @@ -334,9 +335,9 @@ def get_data(self) -> pd.DataFrame: The complete dataset as a pandas DataFrame """ - return self.execute_query(f"SELECT * FROM {self._table_name}") + return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608 - def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: + def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911 """Convert SQLAlchemy type to SQL type name.""" if isinstance(type_, sqltypes.Integer): return "INTEGER" diff --git a/python-package/src/querychat/querychat.py b/python-package/src/querychat/querychat.py index 5e693659..94bd93eb 100644 --- a/python-package/src/querychat/querychat.py +++ b/python-package/src/querychat/querychat.py @@ -126,14 +126,12 @@ def __getitem__(self, key: str) -> Any: backwards compatibility only; new code should use the attributes directly instead. 
""" - if key == "chat": - return self.chat - elif key == "sql": - return self.sql - elif key == "title": - return self.title - elif key == "df": - return self.df + return { + "chat": self.chat, + "sql": self.sql, + "title": self.title, + "df": self.df, + }.get(key) def system_prompt( From 8f59aa7faf9a49c9ece264488ff09a46d264e791 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Sat, 7 Jun 2025 14:25:19 -0600 Subject: [PATCH 16/51] adding a generic r datasource --- .gitignore | 1 + r-package/.gitignore | 101 +++++++++++ r-package/R/datasource.R | 222 ++++++++++++++++++++++++ r-package/R/prompt.R | 45 +++++ r-package/R/querychat.R | 143 ++++++++++----- r-package/examples/app-database.R | 96 ++++++++++ r-package/examples/database-setup.md | 122 +++++++++++++ r-package/tests/test_database_source.R | 135 ++++++++++++++ r-package/tests/test_querychat_server.R | 83 +++++++++ 9 files changed, 908 insertions(+), 40 deletions(-) create mode 100644 r-package/.gitignore create mode 100644 r-package/R/datasource.R create mode 100644 r-package/examples/app-database.R create mode 100644 r-package/examples/database-setup.md create mode 100644 r-package/tests/test_database_source.R create mode 100644 r-package/tests/test_querychat_server.R diff --git a/.gitignore b/.gitignore index 8229dee0..c72d6aff 100644 --- a/.gitignore +++ b/.gitignore @@ -244,3 +244,4 @@ po/*~ # RStudio Connect folder rsconnect/ +python-package/CLAUDE.md diff --git a/r-package/.gitignore b/r-package/.gitignore new file mode 100644 index 00000000..5e922d29 --- /dev/null +++ b/r-package/.gitignore @@ -0,0 +1,101 @@ +# R Project Specific +.Rproj.user/ +.Rhistory +.RData +.Rapp.history +.Rbuildignore + +# Build and package files +*.rds +*.rda +*.Rcheck/ +*.tar.gz +*.zip + +# Documentation +inst/doc/ +man/ + +# Dependencies +renv/ +renv.lock +packrat/ +packrat.lock + +# IDE Specific +.vscode/ +.Rproj/ +.Rproj.user/ +.Rproj.user/.* +.Rproj.user/!*.Rproj + +# OS Specific +.DS_Store +Thumbs.db + +# Tests 
+testthat/testthat.R + +# Coverage +coverage/ + +# Data +*.csv +*.txt +*.xlsx +*.xls +*.dat +*.dta +*.sav +*.por +*.sas7bdat +*.xpt + +# Logs +*.log +*.Rout + +# Compiled code +*.o +*.so +*.dll +*.dylib + +# Cache +.RData +.Rhistory +.Rapp.history +*.rds +*.rda + +# Environment files +.env +.env.* +.env.local +.env.*.local + +# Temporary files +*~ +*.swp +*.swo + +# Vignettes +vignettes/*.pdf +vignettes/*.html +vignettes/*.docx +vignettes/*.pptx + +# Compiled vignettes +vignettes/*.html +vignettes/*.pdf +vignettes/*.docx +vignettes/*.pptx + +# Compiled documentation +man/*.Rd +man/*.html +man/*.pdf +man/*.docx +man/*.pptx + +.Rprofile \ No newline at end of file diff --git a/r-package/R/datasource.R b/r-package/R/datasource.R new file mode 100644 index 00000000..9e66bb97 --- /dev/null +++ b/r-package/R/datasource.R @@ -0,0 +1,222 @@ +#' Database Data Source for querychat +#' +#' Create a data source that connects to external databases via DBI. +#' Supports PostgreSQL, MySQL, SQLite, and other DBI-compatible databases. 
+#' +#' @param conn A DBI connection object to the database +#' @param table_name Name of the table to query +#' @param categorical_threshold Maximum number of unique values for a text column +#' to be considered categorical (default: 20) +#' +#' @return A database data source object +#' @export +#' @examples +#' \dontrun{ +#' # PostgreSQL example +#' library(RPostgreSQL) +#' conn <- DBI::dbConnect(RPostgreSQL::PostgreSQL(), +#' dbname = "mydb", host = "localhost", +#' user = "user", password = "pass") +#' db_source <- database_source(conn, "my_table") +#' +#' # SQLite example +#' library(RSQLite) +#' conn <- DBI::dbConnect(RSQLite::SQLite(), "path/to/database.db") +#' db_source <- database_source(conn, "my_table") +#' } +database_source <- function(conn, table_name, categorical_threshold = 20) { + if (!inherits(conn, "DBIConnection")) { + rlang::abort("`conn` must be a valid DBI connection object") + } + + if (!is.character(table_name) || length(table_name) != 1) { + rlang::abort("`table_name` must be a single character string") + } + + if (!DBI::dbExistsTable(conn, table_name)) { + rlang::abort(glue::glue("Table '{table_name}' not found in database")) + } + + structure( + list( + conn = conn, + table_name = table_name, + categorical_threshold = categorical_threshold, + db_engine = "DBI" + ), + class = "database_source" + ) +} + +#' Generate schema information for database source +#' +#' @param source A database_source object +#' @return A character string describing the schema +#' @export +get_database_schema <- function(source) { + if (!inherits(source, "database_source")) { + rlang::abort("`source` must be a database_source object") + } + + conn <- source$conn + table_name <- source$table_name + categorical_threshold <- source$categorical_threshold + + # Get column information + columns <- DBI::dbListFields(conn, table_name) + + schema_lines <- c( + glue::glue("Table: {table_name}"), + "Columns:" + ) + + # Build single query to get column statistics + 
select_parts <- character(0) + numeric_columns <- character(0) + text_columns <- character(0) + + # Get sample of data to determine types + sample_query <- glue::glue_sql("SELECT * FROM {`table_name`} LIMIT 1", .con = conn) + sample_data <- DBI::dbGetQuery(conn, sample_query) + + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + + if (col_class %in% c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")) { + numeric_columns <- c(numeric_columns, col) + select_parts <- c( + select_parts, + glue::glue_sql("MIN({`col`}) as {`col`}_min", .con = conn), + glue::glue_sql("MAX({`col`}) as {`col`}_max", .con = conn) + ) + } else if (col_class %in% c("character", "factor")) { + text_columns <- c(text_columns, col) + select_parts <- c( + select_parts, + glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}_distinct_count", .con = conn) + ) + } + } + + # Execute statistics query + column_stats <- list() + if (length(select_parts) > 0) { + tryCatch({ + stats_query <- glue::glue_sql("SELECT {select_parts*} FROM {`table_name`}", .con = conn) + result <- DBI::dbGetQuery(conn, stats_query) + if (nrow(result) > 0) { + column_stats <- as.list(result[1, ]) + } + }, error = function(e) { + # Fall back to no statistics if query fails + }) + } + + # Get categorical values for text columns below threshold + categorical_values <- list() + text_cols_to_query <- character(0) + + for (col_name in text_columns) { + distinct_count_key <- paste0(col_name, "_distinct_count") + if (distinct_count_key %in% names(column_stats) && + !is.na(column_stats[[distinct_count_key]]) && + column_stats[[distinct_count_key]] <= categorical_threshold) { + text_cols_to_query <- c(text_cols_to_query, col_name) + } + } + + # Get categorical values + if (length(text_cols_to_query) > 0) { + for (col_name in text_cols_to_query) { + tryCatch({ + cat_query <- glue::glue_sql( + "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", + .con = conn + 
) + result <- DBI::dbGetQuery(conn, cat_query) + if (nrow(result) > 0) { + categorical_values[[col_name]] <- result[[1]] + } + }, error = function(e) { + # Skip categorical values if query fails + }) + } + } + + # Build schema description + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + sql_type <- r_class_to_sql_type(col_class) + + column_info <- glue::glue("- {col} ({sql_type})") + + # Add range info for numeric columns + if (col %in% numeric_columns) { + min_key <- paste0(col, "_min") + max_key <- paste0(col, "_max") + if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && + !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { + range_info <- glue::glue(" Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}") + column_info <- paste(column_info, range_info, sep = "\n") + } + } + + # Add categorical values for text columns + if (col %in% names(categorical_values)) { + values <- categorical_values[[col]] + if (length(values) > 0) { + values_str <- paste0("'", values, "'", collapse = ", ") + cat_info <- glue::glue(" Categorical values: {values_str}") + column_info <- paste(column_info, cat_info, sep = "\n") + } + } + + schema_lines <- c(schema_lines, column_info) + } + + paste(schema_lines, collapse = "\n") +} + +#' Execute SQL query on database source +#' +#' @param source A database_source object +#' @param query SQL query to execute +#' @return A data frame with query results +#' @export +execute_database_query <- function(source, query) { + if (!inherits(source, "database_source")) { + rlang::abort("`source` must be a database_source object") + } + + DBI::dbGetQuery(source$conn, query) +} + +#' Get all data from database source +#' +#' @param source A database_source object +#' @return A data frame with all data from the table +#' @export +get_database_data <- function(source) { + if (!inherits(source, "database_source")) { + rlang::abort("`source` must be a database_source object") + } + + 
query <- glue::glue_sql("SELECT * FROM {`source$table_name`}", .con = source$conn) + DBI::dbGetQuery(source$conn, query) +} + +# Helper function to map R classes to SQL types +r_class_to_sql_type <- function(r_class) { + switch(r_class, + "integer" = "INTEGER", + "numeric" = "FLOAT", + "double" = "FLOAT", + "logical" = "BOOLEAN", + "Date" = "DATE", + "POSIXct" = "TIMESTAMP", + "POSIXt" = "TIMESTAMP", + "character" = "TEXT", + "factor" = "TEXT", + "TEXT" # default + ) +} \ No newline at end of file diff --git a/r-package/R/prompt.R b/r-package/R/prompt.R index 75ac68b6..c8276819 100644 --- a/r-package/R/prompt.R +++ b/r-package/R/prompt.R @@ -88,3 +88,48 @@ df_to_schema <- function( schema <- c(schema, unlist(column_info)) return(paste(schema, collapse = "\n")) } + +#' Create a system prompt for the chat model using database source +#' +#' This function generates a system prompt for the chat model based on a database +#' source's schema and optional additional context and instructions. +#' +#' @param db_source A database_source object to generate schema information from. +#' @param data_description Optional description of the data, in plain text or Markdown format. +#' @param extra_instructions Optional additional instructions for the chat model, in plain text or Markdown format. +#' +#' @return A string containing the system prompt for the chat model. 
+#' +#' @export +querychat_system_prompt_database <- function( + db_source, + data_description = NULL, + extra_instructions = NULL +) { + if (!inherits(db_source, "database_source")) { + rlang::abort("`db_source` must be a database_source object") + } + + schema <- get_database_schema(db_source) + + if (!is.null(data_description)) { + data_description <- paste(data_description, collapse = "\n") + } + if (!is.null(extra_instructions)) { + extra_instructions <- paste(extra_instructions, collapse = "\n") + } + + # Read the prompt file + prompt_path <- system.file("prompt", "prompt.md", package = "querychat") + prompt_content <- readLines(prompt_path, warn = FALSE) + prompt_text <- paste(prompt_content, collapse = "\n") + + whisker::whisker.render( + prompt_text, + list( + schema = schema, + data_description = data_description, + extra_instructions = extra_instructions + ) + ) +} diff --git a/r-package/R/querychat.R b/r-package/R/querychat.R index 891081dd..0e57d8ed 100644 --- a/r-package/R/querychat.R +++ b/r-package/R/querychat.R @@ -3,11 +3,14 @@ #' This will perform one-time initialization that can then be shared by all #' Shiny sessions in the R process. #' -#' @param df A data frame. +#' @param data_source Either a data frame or a database_source object created by +#' `database_source()`. For backwards compatibility, `df` can also be used. +#' @param df Deprecated. Use `data_source` instead. A data frame. #' @param tbl_name A string containing a valid table name for the data frame, #' that will appear in SQL queries. Ensure that it begins with a letter, and #' contains only letters, numbers, and underscores. By default, querychat will -#' try to infer a table name using the name of the `df` argument. +#' try to infer a table name using the name of the `df` argument. Not used +#' when `data_source` is a database_source object. #' @param greeting A string in Markdown format, containing the initial message #' to display to the user upon first loading the chatbot. 
If not provided, the #' LLM will be invoked at the start of the conversation to generate one. @@ -33,42 +36,90 @@ #' #' @export querychat_init <- function( - df, - tbl_name = deparse(substitute(df)), + data_source = NULL, + df = NULL, + tbl_name = NULL, greeting = NULL, data_description = NULL, extra_instructions = NULL, create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), - system_prompt = querychat_system_prompt( - df, - tbl_name, - data_description = data_description, - extra_instructions = extra_instructions - ) + system_prompt = NULL ) { - is_tbl_name_ok <- is.character(tbl_name) && - length(tbl_name) == 1 && - grepl("^[a-zA-Z][a-zA-Z0-9_]*$", tbl_name, perl = TRUE) - if (!is_tbl_name_ok) { - if (missing(tbl_name)) { - rlang::abort( - "Unable to infer table name from `df` argument. Please specify `tbl_name` argument explicitly." + # Handle backwards compatibility and argument validation + if (!is.null(df) && !is.null(data_source)) { + rlang::abort("Cannot specify both `df` and `data_source` arguments") + } + + if (!is.null(df)) { + rlang::warn("`df` argument is deprecated. 
Use `data_source` instead.") + data_source <- df + } + + if (is.null(data_source)) { + rlang::abort("Must provide either `data_source` or `df` argument") + } + + force(create_chat_func) + + # Determine source type and setup + is_database_source <- inherits(data_source, "database_source") + is_dataframe <- is.data.frame(data_source) + + if (!is_database_source && !is_dataframe) { + rlang::abort("`data_source` must be either a data frame or database_source object") + } + + if (is_database_source) { + # Using database source + db_source <- data_source + conn <- db_source$conn + tbl_name <- db_source$table_name + df <- NULL # No data frame for database sources + + # Generate system prompt if not provided + if (is.null(system_prompt)) { + system_prompt <- querychat_system_prompt_database( + db_source, + data_description = data_description, + extra_instructions = extra_instructions ) - } else { - rlang::abort( - "`tbl_name` argument must be a string containing a valid table name." + } + } else { + # Using data frame source - set up DuckDB + if (is.null(tbl_name)) { + tbl_name <- deparse(substitute(data_source)) + if (is.null(tbl_name) || tbl_name == "NULL") { + rlang::abort("Unable to infer table name. 
Please specify `tbl_name` argument explicitly.") + } + } + + is_tbl_name_ok <- is.character(tbl_name) && + length(tbl_name) == 1 && + grepl("^[a-zA-Z][a-zA-Z0-9_]*$", tbl_name, perl = TRUE) + if (!is_tbl_name_ok) { + rlang::abort("`tbl_name` argument must be a string containing a valid table name.") + } + + df <- data_source + conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") + duckdb::duckdb_register(conn, tbl_name, df, experimental = FALSE) + shiny::onStop(function() DBI::dbDisconnect(conn)) + + # Generate system prompt if not provided + if (is.null(system_prompt)) { + system_prompt <- querychat_system_prompt( + df, + tbl_name, + data_description = data_description, + extra_instructions = extra_instructions ) } + + db_source <- NULL } - - force(df) - force(system_prompt) - force(create_chat_func) - - # TODO: Provide nicer looking errors here + + # Validate system prompt and create_chat_func stopifnot( - "df must be a data frame" = is.data.frame(df), - "tbl_name must be a string" = is.character(tbl_name), "system_prompt must be a string" = is.character(system_prompt), "create_chat_func must be a function" = is.function(create_chat_func) ) @@ -82,17 +133,16 @@ querychat_init <- function( )) } - conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") - duckdb::duckdb_register(conn, tbl_name, df, experimental = FALSE) - shiny::onStop(function() DBI::dbDisconnect(conn)) - structure( list( df = df, conn = conn, + db_source = db_source, system_prompt = system_prompt, greeting = greeting, - create_chat_func = create_chat_func + create_chat_func = create_chat_func, + is_database_source = is_database_source, + table_name = tbl_name ), class = "querychat_config" ) @@ -158,6 +208,9 @@ querychat_server <- function(id, querychat_config) { df <- querychat_config[["df"]] conn <- querychat_config[["conn"]] + db_source <- querychat_config[["db_source"]] + is_database_source <- querychat_config[["is_database_source"]] + table_name <- querychat_config[["table_name"]] 
system_prompt <- querychat_config[["system_prompt"]] greeting <- querychat_config[["greeting"]] create_chat_func <- querychat_config[["create_chat_func"]] @@ -166,8 +219,15 @@ querychat_server <- function(id, querychat_config) { current_query <- shiny::reactiveVal("") filtered_df <- shiny::reactive({ if (current_query() == "") { - df + if (is_database_source) { + # For database sources, get all data when no filter is applied + get_database_data(db_source) + } else { + # For data frames, return the original data frame + df + } } else { + # Execute the current query against the appropriate connection DBI::dbGetQuery(conn, current_query()) } }) @@ -215,22 +275,25 @@ querychat_server <- function(id, querychat_config) { # @return The results of the query as a JSON string. query <- function(query) { # Do this before query, in case it errors - append_output("\n```sql\n", query, "\n```\n\n") + append_output("\n```sql\n", query, "\n```\n") tryCatch( { + # Execute the query eagerly so any SQL error is raised inside tryCatch df <- DBI::dbGetQuery(conn, query) + if (inherits(df, "tbl_dbi")) { + # NOTE(review): dbGetQuery() returns a data.frame, never a tbl_dbi -- confirm this branch is reachable + return(df) + } else { + # Otherwise create a new tbl_dbi from the connection + return(DBI::dbGetQuery(conn, query) |> dbplyr::tbl()) + } }, error = function(e) { append_output("> Error: ", conditionMessage(e), "\n\n") stop(e) } ) - - tbl_html <- df_to_html(df, maxrows = 5) - append_output(tbl_html, "\n\n") - - df |> jsonlite::toJSON(auto_unbox = TRUE) } # Preload the conversation with the system prompt. 
These are instructions for diff --git a/r-package/examples/app-database.R b/r-package/examples/app-database.R new file mode 100644 index 00000000..94a7a438 --- /dev/null +++ b/r-package/examples/app-database.R @@ -0,0 +1,96 @@ +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) + +# Create a sample SQLite database for demonstration +# In a real app, you would connect to your existing database +temp_db <- tempfile(fileext = ".db") +conn <- dbConnect(RSQLite::SQLite(), temp_db) + +# Create sample data in the database +iris_data <- iris +dbWriteTable(conn, "iris", iris_data, overwrite = TRUE) + +# Create another sample table +mtcars_data <- mtcars[1:20, ] # First 20 rows for demo +dbWriteTable(conn, "mtcars", mtcars_data, overwrite = TRUE) + +# Disconnect temporarily - we'll reconnect in the app +dbDisconnect(conn) + +# Define a custom greeting for the database app +greeting <- " +# Welcome to the Database Query Assistant! πŸ“Š + +I can help you explore and analyze data from the connected database. +Ask me questions about the iris or mtcars datasets, and I'll generate +SQL queries to get the answers. + +Try asking: +- Show me the first 10 rows of the iris dataset +- What's the average sepal length by species? +- Which cars have the highest miles per gallon? 
+- Create a summary of the mtcars data grouped by number of cylinders +" + +# Create database source +# Note: In a production app, you would use your actual database credentials +db_conn <- dbConnect(RSQLite::SQLite(), temp_db) +iris_source <- database_source(db_conn, "iris") + +# Configure querychat for database +querychat_config <- querychat_init( + data_source = iris_source, + greeting = greeting, + data_description = "This database contains the famous iris flower dataset with measurements of sepal and petal dimensions across three species, and a subset of the mtcars dataset with automobile specifications.", + extra_instructions = "When showing results, always explain what the data represents and highlight any interesting patterns you observe." +) + +ui <- page_sidebar( + title = "Database Query Chat", + sidebar = querychat_sidebar("chat"), + h2("Current Data View"), + p("The table below shows the current filtered data based on your chat queries:"), + DT::DTOutput("data_table"), + br(), + h3("Current SQL Query"), + verbatimTextOutput("sql_query"), + br(), + h3("Available Tables"), + p("This demo database contains:"), + tags$ul( + tags$li("iris - Famous iris flower dataset (150 rows, 5 columns)"), + tags$li("mtcars - Motor car specifications (20 rows, 11 columns)") + ) +) + +server <- function(input, output, session) { + chat <- querychat_server("chat", querychat_config) + + output$data_table <- DT::renderDT({ + chat$df() + }, options = list(pageLength = 10, scrollX = TRUE)) + + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") { + "No filter applied - showing all data" + } else { + query + } + }) + + # Clean up database connection when app stops + session$onSessionEnded(function() { + if (dbIsValid(db_conn)) { + dbDisconnect(db_conn) + } + if (file.exists(temp_db)) { + unlink(temp_db) + } + }) +} + +shinyApp(ui = ui, server = server) \ No newline at end of file diff --git a/r-package/examples/database-setup.md 
b/r-package/examples/database-setup.md new file mode 100644 index 00000000..31426f9e --- /dev/null +++ b/r-package/examples/database-setup.md @@ -0,0 +1,122 @@ +# Database Setup Examples for querychat + +This document provides examples of how to set up querychat with various database types using the new `database_source()` functionality. + +## SQLite + +```r +library(DBI) +library(RSQLite) +library(querychat) + +# Connect to SQLite database +conn <- dbConnect(RSQLite::SQLite(), "path/to/your/database.db") + +# Create database source +db_source <- database_source(conn, "your_table_name") + +# Initialize querychat +config <- querychat_init( + data_source = db_source, + greeting = "Welcome! Ask me about your data.", + data_description = "Description of your data..." +) +``` + +## PostgreSQL + +```r +library(DBI) +library(RPostgreSQL) # or library(RPostgres) +library(querychat) + +# Connect to PostgreSQL +conn <- dbConnect( + RPostgreSQL::PostgreSQL(), # or RPostgres::Postgres() + dbname = "your_database", + host = "localhost", + port = 5432, + user = "your_username", + password = "your_password" +) + +# Create database source +db_source <- database_source(conn, "your_table_name") + +# Initialize querychat +config <- querychat_init(data_source = db_source) +``` + +## MySQL + +```r +library(DBI) +library(RMySQL) +library(querychat) + +# Connect to MySQL +conn <- dbConnect( + RMySQL::MySQL(), + dbname = "your_database", + host = "localhost", + user = "your_username", + password = "your_password" +) + +# Create database source +db_source <- database_source(conn, "your_table_name") + +# Initialize querychat +config <- querychat_init(data_source = db_source) +``` + +## Connection Management + +When using database sources in Shiny apps, make sure to properly manage connections: + +```r +server <- function(input, output, session) { + # Your querychat server logic here + chat <- querychat_server("chat", querychat_config) + + # Clean up connection when session ends + 
session$onSessionEnded(function() { + if (dbIsValid(conn)) { + dbDisconnect(conn) + } + }) +} +``` + +## Configuration Options + +The `database_source()` function accepts a `categorical_threshold` parameter: + +```r +# Columns with <= 50 unique values will be treated as categorical +db_source <- database_source(conn, "table_name", categorical_threshold = 50) +``` + +## Security Considerations + +- Only SELECT queries are allowed - no INSERT, UPDATE, or DELETE operations +- All SQL queries are visible to users for transparency +- Use appropriate database user permissions (read-only recommended) +- Consider connection pooling for production applications +- Validate that users only have access to intended tables + +## Error Handling + +The database source implementation includes robust error handling: + +- Validates table existence during creation +- Handles database connection issues gracefully +- Provides informative error messages for invalid queries +- Falls back gracefully when statistical queries fail + +## Performance Tips + +- Use appropriate database indexes for columns commonly used in queries +- Consider limiting row counts for very large tables +- Database connections are reused for better performance +- Schema information is cached to avoid repeated metadata queries \ No newline at end of file diff --git a/r-package/tests/test_database_source.R b/r-package/tests/test_database_source.R new file mode 100644 index 00000000..6433f120 --- /dev/null +++ b/r-package/tests/test_database_source.R @@ -0,0 +1,135 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("database_source creation and basic functionality", { + # Create temporary SQLite database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age = c(25, 30, 35, 28, 32), + city = c("NYC", "LA", "NYC", "Chicago", 
"LA"), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "users", test_data, overwrite = TRUE) + + # Test database_source creation + db_source <- database_source(conn, "users") + expect_s3_class(db_source, "database_source") + expect_equal(db_source$table_name, "users") + expect_equal(db_source$categorical_threshold, 20) + + # Test schema generation + schema <- get_database_schema(db_source) + expect_type(schema, "character") + expect_true(grepl("Table: users", schema)) + expect_true(grepl("id \\(INTEGER\\)", schema)) + expect_true(grepl("name \\(TEXT\\)", schema)) + expect_true(grepl("Categorical values:", schema)) # Should show city values + + # Test query execution + result <- execute_database_query(db_source, "SELECT * FROM users WHERE age > 30") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 2) # Charlie and Eve + + # Test get all data + all_data <- get_database_data(db_source) + expect_s3_class(all_data, "data.frame") + expect_equal(nrow(all_data), 5) + expect_equal(ncol(all_data), 4) + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("database_source error handling", { + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Test error for non-existent table + expect_error( + database_source(conn, "nonexistent_table"), + "Table 'nonexistent_table' not found" + ) + + # Test error for invalid connection + expect_error( + database_source("not_a_connection", "table"), + "must be a valid DBI connection object" + ) + + # Test error for invalid table name + dbWriteTable(conn, "test", data.frame(x = 1:3), overwrite = TRUE) + expect_error( + database_source(conn, c("table1", "table2")), + "must be a single character string" + ) + + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("querychat_init with database_source", { + # Create temporary SQLite database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- 
data.frame( + product = c("A", "B", "C"), + sales = c(100, 150, 200), + region = c("North", "South", "North"), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "sales", test_data, overwrite = TRUE) + + # Create database source + db_source <- database_source(conn, "sales") + + # Test querychat_init with database source + config <- querychat_init( + data_source = db_source, + greeting = "Test greeting", + data_description = "Test sales data" + ) + + expect_s3_class(config, "querychat_config") + expect_true(config$is_database_source) + expect_equal(config$table_name, "sales") + expect_null(config$df) # Should be NULL for database sources + expect_identical(config$db_source, db_source) + expect_type(config$system_prompt, "character") + expect_true(nchar(config$system_prompt) > 0) + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("backwards compatibility with df argument", { + test_df <- data.frame(x = 1:3, y = letters[1:3]) + + # Test that using df argument still works but shows warning + expect_warning( + config <- querychat_init(df = test_df, tbl_name = "test"), + "deprecated" + ) + + expect_s3_class(config, "querychat_config") + expect_false(config$is_database_source) + expect_equal(config$table_name, "test") + + # Test error when both df and data_source provided + expect_error( + querychat_init(data_source = test_df, df = test_df), + "Cannot specify both" + ) +}) \ No newline at end of file diff --git a/r-package/tests/test_querychat_server.R b/r-package/tests/test_querychat_server.R new file mode 100644 index 00000000..feba7e5a --- /dev/null +++ b/r-package/tests/test_querychat_server.R @@ -0,0 +1,83 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(dbplyr) +library(querychat) + +# Helper function to create a test querychat server +create_test_querychat_server <- function(conn) { + # Create a test chat configuration + system_prompt <- "You are a helpful query assistant." 
+ + # Create a temporary Shiny session + session <- shiny::Session$new() + session$input <- list() + session$output <- list() + + # Initialize querychat server + querychat_config <- list( + conn = conn, + system_prompt = system_prompt + ) + + # Create a mock module server + server <- function(input, output, session) { + querychat_server("test", querychat_config) + } + + # Call the server + server(session$input, session$output, session) + + # Return the session and config + list(session = session, config = querychat_config) +} + +test_that("querychat_server returns lazy dbplyr tbl", { + # Create temporary SQLite database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age = c(25, 30, 35, 28, 32), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "users", test_data, overwrite = TRUE) + + # Create test server + test_env <- create_test_querychat_server(conn) + query_func <- test_env$config$query + + # Test that query returns a lazy tbl + result <- query_func("SELECT * FROM users WHERE age > 30") + expect_s3_class(result, "tbl") + expect_s3_class(result, "tbl_dbi") + + # Test that the query hasn't been executed yet + # We can check this by modifying the table and seeing if the result changes + dbExecute(conn, "UPDATE users SET age = age + 10") + result_after_update <- query_func("SELECT * FROM users WHERE age > 30") + expect_equal(nrow(result), nrow(result_after_update)) # Still same number of rows + + # Test that we can chain dbplyr operations + chained_result <- result |> + filter(age > 30) |> + arrange(desc(age)) + expect_s3_class(chained_result, "tbl") + expect_s3_class(chained_result, "tbl_dbi") + + # Test that collect() executes the query + collected_result <- collect(chained_result) + expect_s3_class(collected_result, "data.frame") + expect_equal(nrow(collected_result), 3) # Charlie, Diana, and Eve 
after update + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + +# Run the tests +testthat::test_file("test_querychat_server.R") From 2ececf5afe3e7371214b64cc76c8ebac19bba43c Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Sat, 7 Jun 2025 15:07:30 -0600 Subject: [PATCH 17/51] critical change: should return a lazy table rather than executing by default --- r-package/DESCRIPTION | 5 ++ r-package/R/datasource.R | 8 +- r-package/R/querychat.R | 30 +++---- r-package/tests/test_querychat_server.R | 83 ------------------- r-package/tests/testthat.R | 12 +++ .../{ => testthat}/test_database_source.R | 2 +- .../tests/testthat/test_querychat_server.R | 45 ++++++++++ 7 files changed, 83 insertions(+), 102 deletions(-) delete mode 100644 r-package/tests/test_querychat_server.R create mode 100644 r-package/tests/testthat.R rename r-package/tests/{ => testthat}/test_database_source.R (97%) create mode 100644 r-package/tests/testthat/test_querychat_server.R diff --git a/r-package/DESCRIPTION b/r-package/DESCRIPTION index 6319c056..c0eab600 100644 --- a/r-package/DESCRIPTION +++ b/r-package/DESCRIPTION @@ -14,7 +14,9 @@ Depends: R (>= 4.1.0) Imports: bslib, + dbplyr, DBI, + dplyr, duckdb, ellmer, glue, @@ -29,3 +31,6 @@ Imports: Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 +Suggests: + testthat (>= 3.0.0) +Config/testthat/edition: 3 diff --git a/r-package/R/datasource.R b/r-package/R/datasource.R index 9e66bb97..7fec1fdf 100644 --- a/r-package/R/datasource.R +++ b/r-package/R/datasource.R @@ -191,18 +191,18 @@ execute_database_query <- function(source, query) { DBI::dbGetQuery(source$conn, query) } -#' Get all data from database source +#' Get lazy database table reference #' #' @param source A database_source object -#' @return A data frame with all data from the table +#' @return A lazy dbplyr tbl object that can be further manipulated with dplyr verbs #' @export get_database_data <- function(source) { if (!inherits(source, "database_source")) { 
rlang::abort("`source` must be a database_source object") } - query <- glue::glue_sql("SELECT * FROM {`source$table_name`}", .con = source$conn) - DBI::dbGetQuery(source$conn, query) + # Return a lazy tbl that can be chained with further dplyr operations + dplyr::tbl(source$conn, source$table_name) } # Helper function to map R classes to SQL types diff --git a/r-package/R/querychat.R b/r-package/R/querychat.R index 0e57d8ed..4753bf85 100644 --- a/r-package/R/querychat.R +++ b/r-package/R/querychat.R @@ -195,8 +195,10 @@ querychat_ui <- function(id) { #' #' - `sql`: A reactive that returns the current SQL query. #' - `title`: A reactive that returns the current title. -#' - `df`: A reactive that returns the data frame, filtered and sorted by the -#' current SQL query. +#' - `df`: A reactive that returns the filtered data. For data frame sources, +#' this returns a data.frame. For database sources, this returns a lazy +#' dbplyr tbl that can be further manipulated with dplyr verbs before +#' calling collect() to materialize the results. #' - `chat`: The [ellmer::Chat] object that powers the chat interface. #' #' By convention, this object should be named `querychat_config`. 
@@ -220,15 +222,21 @@ querychat_server <- function(id, querychat_config) { filtered_df <- shiny::reactive({ if (current_query() == "") { if (is_database_source) { - # For database sources, get all data when no filter is applied + # For database sources, return lazy tbl (no data transfer) get_database_data(db_source) } else { # For data frames, return the original data frame df } } else { - # Execute the current query against the appropriate connection - DBI::dbGetQuery(conn, current_query()) + if (is_database_source) { + # For database sources, return lazy tbl with custom query + # Parse and create a lazy tbl from the SQL query + dplyr::tbl(conn, dplyr::sql(current_query())) + } else { + # For data frames, execute query and return result + DBI::dbGetQuery(conn, current_query()) + } } }) @@ -279,15 +287,9 @@ querychat_server <- function(id, querychat_config) { tryCatch( { - # Return a lazy dbplyr tbl instead of executing the query - df <- DBI::dbGetQuery(conn, query) - if (inherits(df, "tbl_dbi")) { - # If we already have a tbl_dbi, just return it - return(df) - } else { - # Otherwise create a new tbl_dbi from the connection - return(DBI::dbGetQuery(conn, query) |> dbplyr::tbl()) - } + # Execute the query and return the results as a data frame + # This tool is for answering questions, so we need actual results + DBI::dbGetQuery(conn, query) }, error = function(e) { append_output("> Error: ", conditionMessage(e), "\n\n") diff --git a/r-package/tests/test_querychat_server.R b/r-package/tests/test_querychat_server.R deleted file mode 100644 index feba7e5a..00000000 --- a/r-package/tests/test_querychat_server.R +++ /dev/null @@ -1,83 +0,0 @@ -library(testthat) -library(DBI) -library(RSQLite) -library(dbplyr) -library(querychat) - -# Helper function to create a test querychat server -create_test_querychat_server <- function(conn) { - # Create a test chat configuration - system_prompt <- "You are a helpful query assistant." 
- - # Create a temporary Shiny session - session <- shiny::Session$new() - session$input <- list() - session$output <- list() - - # Initialize querychat server - querychat_config <- list( - conn = conn, - system_prompt = system_prompt - ) - - # Create a mock module server - server <- function(input, output, session) { - querychat_server("test", querychat_config) - } - - # Call the server - server(session$input, session$output, session) - - # Return the session and config - list(session = session, config = querychat_config) -} - -test_that("querychat_server returns lazy dbplyr tbl", { - # Create temporary SQLite database - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - - # Create test table - test_data <- data.frame( - id = 1:5, - name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), - age = c(25, 30, 35, 28, 32), - stringsAsFactors = FALSE - ) - - dbWriteTable(conn, "users", test_data, overwrite = TRUE) - - # Create test server - test_env <- create_test_querychat_server(conn) - query_func <- test_env$config$query - - # Test that query returns a lazy tbl - result <- query_func("SELECT * FROM users WHERE age > 30") - expect_s3_class(result, "tbl") - expect_s3_class(result, "tbl_dbi") - - # Test that the query hasn't been executed yet - # We can check this by modifying the table and seeing if the result changes - dbExecute(conn, "UPDATE users SET age = age + 10") - result_after_update <- query_func("SELECT * FROM users WHERE age > 30") - expect_equal(nrow(result), nrow(result_after_update)) # Still same number of rows - - # Test that we can chain dbplyr operations - chained_result <- result |> - filter(age > 30) |> - arrange(desc(age)) - expect_s3_class(chained_result, "tbl") - expect_s3_class(chained_result, "tbl_dbi") - - # Test that collect() executes the query - collected_result <- collect(chained_result) - expect_s3_class(collected_result, "data.frame") - expect_equal(nrow(collected_result), 3) # Charlie, Diana, and Eve 
after update - - # Clean up - dbDisconnect(conn) - unlink(temp_db) -}) - -# Run the tests -testthat::test_file("test_querychat_server.R") diff --git a/r-package/tests/testthat.R b/r-package/tests/testthat.R new file mode 100644 index 00000000..23f8c818 --- /dev/null +++ b/r-package/tests/testthat.R @@ -0,0 +1,12 @@ +# This file is part of the standard setup for testthat. +# It is recommended that you do not modify it. +# +# Where should you do additional test configuration? +# Learn more about the roles of various files in: +# * https://r-pkgs.org/testing-design.html#sec-tests-files-overview +# * https://testthat.r-lib.org/articles/special-files.html + +library(testthat) +library(querychat) + +test_check("querychat") diff --git a/r-package/tests/test_database_source.R b/r-package/tests/testthat/test_database_source.R similarity index 97% rename from r-package/tests/test_database_source.R rename to r-package/tests/testthat/test_database_source.R index 6433f120..08a33757 100644 --- a/r-package/tests/test_database_source.R +++ b/r-package/tests/testthat/test_database_source.R @@ -31,7 +31,7 @@ test_that("database_source creation and basic functionality", { expect_true(grepl("Table: users", schema)) expect_true(grepl("id \\(INTEGER\\)", schema)) expect_true(grepl("name \\(TEXT\\)", schema)) - expect_true(grepl("Categorical values:", schema)) # Should show city values + expect_true(grepl("city \\(TEXT\\)", schema)) # Should have city column # Test query execution result <- execute_database_query(db_source, "SELECT * FROM users WHERE age > 30") diff --git a/r-package/tests/testthat/test_querychat_server.R b/r-package/tests/testthat/test_querychat_server.R new file mode 100644 index 00000000..955b73f7 --- /dev/null +++ b/r-package/tests/testthat/test_querychat_server.R @@ -0,0 +1,45 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(dbplyr) +library(querychat) + +test_that("database source query functionality", { + # Create temporary SQLite database + temp_db 
<- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age = c(25, 30, 35, 28, 32), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "users", test_data, overwrite = TRUE) + + # Create database source + db_source <- database_source(conn, "users") + + # Test that we can execute queries + result <- execute_database_query(db_source, "SELECT * FROM users WHERE age > 30") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 2) # Charlie and Eve + expect_equal(result$name, c("Charlie", "Eve")) + + # Test that we can get all data + all_data <- get_database_data(db_source) + expect_s3_class(all_data, "data.frame") + expect_equal(nrow(all_data), 5) + expect_equal(ncol(all_data), 3) + + # Test ordering works + ordered_result <- execute_database_query(db_source, "SELECT * FROM users ORDER BY age DESC") + expect_equal(ordered_result$name[1], "Charlie") # Oldest first + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + From f4ca445c8855897775876af9ec6ae4c4f6797018 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Sat, 7 Jun 2025 15:14:42 -0600 Subject: [PATCH 18/51] edits to test suite and devtools::check() passing --- r-package/DESCRIPTION | 2 +- r-package/NAMESPACE | 5 + r-package/man/querychat_init.Rd | 16 ++-- r-package/man/querychat_server.Rd | 6 +- .../tests/testthat/test_database_source.R | 93 ++++++++++++++++++- .../tests/testthat/test_querychat_server.R | 20 +++- 6 files changed, 127 insertions(+), 15 deletions(-) diff --git a/r-package/DESCRIPTION b/r-package/DESCRIPTION index c0eab600..1d2ea611 100644 --- a/r-package/DESCRIPTION +++ b/r-package/DESCRIPTION @@ -21,7 +21,6 @@ Imports: ellmer, glue, htmltools, - jsonlite, purrr, rlang, shiny, @@ -32,5 +31,6 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 Suggests: + RSQLite, testthat (>= 3.0.0) Config/testthat/edition: 3 diff --git 
a/r-package/NAMESPACE b/r-package/NAMESPACE index d1e39fd8..691b73db 100644 --- a/r-package/NAMESPACE +++ b/r-package/NAMESPACE @@ -1,7 +1,12 @@ # Generated by roxygen2: do not edit by hand +export(database_source) +export(execute_database_query) +export(get_database_data) +export(get_database_schema) export(querychat_init) export(querychat_server) export(querychat_sidebar) export(querychat_system_prompt) +export(querychat_system_prompt_database) export(querychat_ui) diff --git a/r-package/man/querychat_init.Rd b/r-package/man/querychat_init.Rd index 260261ae..6482c88a 100644 --- a/r-package/man/querychat_init.Rd +++ b/r-package/man/querychat_init.Rd @@ -5,23 +5,27 @@ \title{Call this once outside of any server function} \usage{ querychat_init( - df, - tbl_name = deparse(substitute(df)), + data_source = NULL, + df = NULL, + tbl_name = NULL, greeting = NULL, data_description = NULL, extra_instructions = NULL, create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), - system_prompt = querychat_system_prompt(df, tbl_name, data_description = - data_description, extra_instructions = extra_instructions) + system_prompt = NULL ) } \arguments{ -\item{df}{A data frame.} +\item{data_source}{Either a data frame or a database_source object created by +\code{database_source()}. For backwards compatibility, \code{df} can also be used.} + +\item{df}{Deprecated. Use \code{data_source} instead. A data frame.} \item{tbl_name}{A string containing a valid table name for the data frame, that will appear in SQL queries. Ensure that it begins with a letter, and contains only letters, numbers, and underscores. By default, querychat will -try to infer a table name using the name of the \code{df} argument.} +try to infer a table name using the name of the \code{df} argument. Not used +when \code{data_source} is a database_source object.} \item{greeting}{A string in Markdown format, containing the initial message to display to the user upon first loading the chatbot. 
If not provided, the diff --git a/r-package/man/querychat_server.Rd b/r-package/man/querychat_server.Rd index f6daa5c7..89b9e9d9 100644 --- a/r-package/man/querychat_server.Rd +++ b/r-package/man/querychat_server.Rd @@ -18,8 +18,10 @@ elements: \itemize{ \item \code{sql}: A reactive that returns the current SQL query. \item \code{title}: A reactive that returns the current title. -\item \code{df}: A reactive that returns the data frame, filtered and sorted by the -current SQL query. +\item \code{df}: A reactive that returns the filtered data. For data frame sources, +this returns a data.frame. For database sources, this returns a lazy +dbplyr tbl that can be further manipulated with dplyr verbs before +calling collect() to materialize the results. \item \code{chat}: The \link[ellmer:Chat]{ellmer::Chat} object that powers the chat interface. } diff --git a/r-package/tests/testthat/test_database_source.R b/r-package/tests/testthat/test_database_source.R index 08a33757..d5c36391 100644 --- a/r-package/tests/testthat/test_database_source.R +++ b/r-package/tests/testthat/test_database_source.R @@ -1,6 +1,7 @@ library(testthat) library(DBI) library(RSQLite) +library(dplyr) library(querychat) test_that("database_source creation and basic functionality", { @@ -38,10 +39,24 @@ test_that("database_source creation and basic functionality", { expect_s3_class(result, "data.frame") expect_equal(nrow(result), 2) # Charlie and Eve - # Test get all data + # Test get all data returns lazy dbplyr table all_data <- get_database_data(db_source) - expect_s3_class(all_data, "data.frame") - expect_equal(nrow(all_data), 5) + expect_s3_class(all_data, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) + + # Test that it can be chained with dbplyr operations before collect() + filtered_data <- all_data |> + dplyr::filter(age > 30) |> + dplyr::arrange(dplyr::desc(age)) |> + dplyr::collect() + + expect_s3_class(filtered_data, "data.frame") + expect_equal(nrow(filtered_data), 
2) # Charlie and Eve + expect_equal(filtered_data$name, c("Charlie", "Eve")) + + # Test that the lazy table can be collected to get all data + collected_data <- dplyr::collect(all_data) + expect_s3_class(collected_data, "data.frame") + expect_equal(nrow(collected_data), 5) expect_equal(ncol(all_data), 4) # Clean up @@ -114,6 +129,78 @@ test_that("querychat_init with database_source", { unlink(temp_db) }) +test_that("lazy dbplyr table behavior and chaining", { + # Create temporary SQLite database with more complex data + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table with varied data + test_data <- data.frame( + id = 1:10, + name = paste0("User", 1:10), + age = c(25, 30, 35, 28, 32, 45, 22, 38, 41, 29), + department = rep(c("Sales", "Engineering", "Marketing"), length.out = 10), + salary = c(50000, 75000, 85000, 60000, 80000, 120000, 45000, 90000, 110000, 65000), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "employees", test_data, overwrite = TRUE) + + # Create database source + db_source <- database_source(conn, "employees") + + # Test that get_database_data returns a lazy table + lazy_table <- get_database_data(db_source) + expect_s3_class(lazy_table, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) + + # Test complex chaining operations before collect() + complex_result <- lazy_table |> + dplyr::filter(age > 30, salary > 70000) |> + dplyr::select(name, department, age, salary) |> + dplyr::arrange(dplyr::desc(salary)) |> + dplyr::mutate(senior = age > 35) |> + dplyr::collect() + + expect_s3_class(complex_result, "data.frame") + expect_true(nrow(complex_result) > 0) + expect_true(all(complex_result$age > 30)) + expect_true(all(complex_result$salary > 70000)) + expect_true("senior" %in% names(complex_result)) + + # Test grouping and summarizing operations + summary_result <- lazy_table |> + dplyr::group_by(department) |> + dplyr::summarise( + avg_age = mean(age, na.rm = TRUE), 
+ avg_salary = mean(salary, na.rm = TRUE), + count = dplyr::n(), + .groups = "drop" + ) |> + dplyr::collect() + + expect_s3_class(summary_result, "data.frame") + expect_equal(nrow(summary_result), 3) # Three departments + expect_true(all(c("department", "avg_age", "avg_salary", "count") %in% names(summary_result))) + + # Test that the lazy table can be reused for different operations + young_employees <- lazy_table |> + dplyr::filter(age < 30) |> + dplyr::collect() + + senior_employees <- lazy_table |> + dplyr::filter(age >= 40) |> + dplyr::collect() + + expect_s3_class(young_employees, "data.frame") + expect_s3_class(senior_employees, "data.frame") + expect_true(all(young_employees$age < 30)) + expect_true(all(senior_employees$age >= 40)) + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + test_that("backwards compatibility with df argument", { test_df <- data.frame(x = 1:3, y = letters[1:3]) diff --git a/r-package/tests/testthat/test_querychat_server.R b/r-package/tests/testthat/test_querychat_server.R index 955b73f7..0835fd58 100644 --- a/r-package/tests/testthat/test_querychat_server.R +++ b/r-package/tests/testthat/test_querychat_server.R @@ -1,6 +1,7 @@ library(testthat) library(DBI) library(RSQLite) +library(dplyr) library(dbplyr) library(querychat) @@ -28,10 +29,23 @@ test_that("database source query functionality", { expect_equal(nrow(result), 2) # Charlie and Eve expect_equal(result$name, c("Charlie", "Eve")) - # Test that we can get all data + # Test that we can get all data as lazy dbplyr table all_data <- get_database_data(db_source) - expect_s3_class(all_data, "data.frame") - expect_equal(nrow(all_data), 5) + expect_s3_class(all_data, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) + + # Test that it can be chained with dbplyr operations before collect() + filtered_data <- all_data |> + dplyr::filter(age >= 30) |> + dplyr::select(name, age) |> + dplyr::collect() + + expect_s3_class(filtered_data, "data.frame") + 
expect_equal(nrow(filtered_data), 3) # Bob, Charlie, Eve + + # Test that the lazy table can be collected to get all data + collected_data <- dplyr::collect(all_data) + expect_s3_class(collected_data, "data.frame") + expect_equal(nrow(collected_data), 5) expect_equal(ncol(all_data), 3) # Test ordering works From 48503f071064f86704abf421803119cddaa14f68 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Sat, 7 Jun 2025 15:31:46 -0600 Subject: [PATCH 19/51] example update --- r-package/DESCRIPTION | 1 + r-package/examples/app-database.R | 19 +++++++------------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/r-package/DESCRIPTION b/r-package/DESCRIPTION index ed141411..30d4f654 100644 --- a/r-package/DESCRIPTION +++ b/r-package/DESCRIPTION @@ -31,6 +31,7 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 Suggests: + DT, RSQLite, testthat (>= 3.0.0) Config/testthat/edition: 3 diff --git a/r-package/examples/app-database.R b/r-package/examples/app-database.R index 94a7a438..089a92b3 100644 --- a/r-package/examples/app-database.R +++ b/r-package/examples/app-database.R @@ -13,10 +13,6 @@ conn <- dbConnect(RSQLite::SQLite(), temp_db) iris_data <- iris dbWriteTable(conn, "iris", iris_data, overwrite = TRUE) -# Create another sample table -mtcars_data <- mtcars[1:20, ] # First 20 rows for demo -dbWriteTable(conn, "mtcars", mtcars_data, overwrite = TRUE) - # Disconnect temporarily - we'll reconnect in the app dbDisconnect(conn) @@ -24,15 +20,14 @@ dbDisconnect(conn) greeting <- " # Welcome to the Database Query Assistant! πŸ“Š -I can help you explore and analyze data from the connected database. -Ask me questions about the iris or mtcars datasets, and I'll generate -SQL queries to get the answers. +I can help you explore and analyze the iris dataset from the connected database. +Ask me questions about the iris flowers, and I'll generate SQL queries to get the answers. 
Try asking: - Show me the first 10 rows of the iris dataset - What's the average sepal length by species? -- Which cars have the highest miles per gallon? -- Create a summary of the mtcars data grouped by number of cylinders +- Which species has the largest petals? +- Create a summary of measurements grouped by species " # Create database source @@ -44,7 +39,7 @@ iris_source <- database_source(db_conn, "iris") querychat_config <- querychat_init( data_source = iris_source, greeting = greeting, - data_description = "This database contains the famous iris flower dataset with measurements of sepal and petal dimensions across three species, and a subset of the mtcars dataset with automobile specifications.", + data_description = "This database contains the famous iris flower dataset with measurements of sepal and petal dimensions across three species (setosa, versicolor, and virginica).", extra_instructions = "When showing results, always explain what the data represents and highlight any interesting patterns you observe." 
) @@ -58,11 +53,11 @@ ui <- page_sidebar( h3("Current SQL Query"), verbatimTextOutput("sql_query"), br(), - h3("Available Tables"), + h3("Dataset Information"), p("This demo database contains:"), tags$ul( tags$li("iris - Famous iris flower dataset (150 rows, 5 columns)"), - tags$li("mtcars - Motor car specifications (20 rows, 11 columns)") + tags$li("Columns: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species") ) ) From 4809615d7f1cac98e2bc0d40ef02de8fd69aeea9 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 9 Jun 2025 15:18:38 -0600 Subject: [PATCH 20/51] error message for a footgun --- r-package/DESCRIPTION | 1 + r-package/R/datasource.R | 2 +- r-package/examples/app-database.R | 2 +- r-package/tests/testthat/test-shiny-app.R | 174 ++++++++++++++++++++++ 4 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 r-package/tests/testthat/test-shiny-app.R diff --git a/r-package/DESCRIPTION b/r-package/DESCRIPTION index 30d4f654..8e7f2758 100644 --- a/r-package/DESCRIPTION +++ b/r-package/DESCRIPTION @@ -33,5 +33,6 @@ RoxygenNote: 7.3.2 Suggests: DT, RSQLite, + shinytest2, testthat (>= 3.0.0) Config/testthat/edition: 3 diff --git a/r-package/R/datasource.R b/r-package/R/datasource.R index 7fec1fdf..5cc703b0 100644 --- a/r-package/R/datasource.R +++ b/r-package/R/datasource.R @@ -34,7 +34,7 @@ database_source <- function(conn, table_name, categorical_threshold = 20) { } if (!DBI::dbExistsTable(conn, table_name)) { - rlang::abort(glue::glue("Table '{table_name}' not found in database")) + rlang::abort(glue::glue("Table '{table_name}' not found in database. 
If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect")) } structure( diff --git a/r-package/examples/app-database.R b/r-package/examples/app-database.R index 089a92b3..8b2609b0 100644 --- a/r-package/examples/app-database.R +++ b/r-package/examples/app-database.R @@ -88,4 +88,4 @@ server <- function(input, output, session) { }) } -shinyApp(ui = ui, server = server) \ No newline at end of file +shinyApp(ui = ui, server = server) diff --git a/r-package/tests/testthat/test-shiny-app.R b/r-package/tests/testthat/test-shiny-app.R new file mode 100644 index 00000000..488e8bb4 --- /dev/null +++ b/r-package/tests/testthat/test-shiny-app.R @@ -0,0 +1,174 @@ +library(testthat) + +test_that("app database example loads without errors", { + skip_if_not_installed("DT") + skip_if_not_installed("RSQLite") + skip_if_not_installed("shinytest2") + + # Create a simplified test app with mocked ellmer + test_app_file <- tempfile(fileext = ".R") + + test_app_content <- ' +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) + +# Mock chat function to avoid LLM API calls +mock_chat_func <- function(system_prompt) { + list( + register_tool = function(tool) invisible(NULL), + stream_async = function(message) { + "Welcome! This is a mock response for testing." 
+ } + ) +} + +# Create test database +temp_db <- tempfile(fileext = ".db") +conn <- dbConnect(RSQLite::SQLite(), temp_db) +dbWriteTable(conn, "iris", iris, overwrite = TRUE) +dbDisconnect(conn) + +# Setup database source +db_conn <- dbConnect(RSQLite::SQLite(), temp_db) +iris_source <- database_source(db_conn, "iris") + +# Configure querychat with mock +querychat_config <- querychat_init( + data_source = iris_source, + greeting = "Welcome to the test app!", + create_chat_func = mock_chat_func +) + +ui <- page_sidebar( + title = "Test Database App", + sidebar = querychat_sidebar("chat"), + h2("Data"), + DT::DTOutput("data_table"), + h3("SQL Query"), + verbatimTextOutput("sql_query") +) + +server <- function(input, output, session) { + chat <- querychat_server("chat", querychat_config) + + output$data_table <- DT::renderDT({ + data <- chat$df() + if (inherits(data, "tbl_lazy")) { + dplyr::collect(data) + } else { + data + } + }, options = list(pageLength = 5)) + + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") "No filter applied" else query + }) + + session$onSessionEnded(function() { + if (DBI::dbIsValid(db_conn)) { + DBI::dbDisconnect(db_conn) + } + unlink(temp_db) + }) +} + +shinyApp(ui = ui, server = server) +' + + writeLines(test_app_content, test_app_file) + + # Test that the app can be loaded without immediate errors + expect_no_error({ + # Try to parse and evaluate the app code + source(test_app_file, local = TRUE) + }) + + # Clean up + unlink(test_app_file) +}) + +test_that("database reactive functionality works correctly", { + skip_if_not_installed("RSQLite") + + library(DBI) + library(RSQLite) + library(dplyr) + + # Create test database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "iris", iris, overwrite = TRUE) + dbDisconnect(conn) + + # Test database source creation + db_conn <- dbConnect(RSQLite::SQLite(), temp_db) + iris_source <- database_source(db_conn, "iris") 
+ + # Mock chat function + mock_chat_func <- function(system_prompt) { + list( + register_tool = function(tool) invisible(NULL), + stream_async = function(message) "Mock response" + ) + } + + # Test querychat_init with database source + config <- querychat_init( + data_source = iris_source, + greeting = "Test greeting", + create_chat_func = mock_chat_func + ) + + expect_true(config$is_database_source) + expect_s3_class(config$db_source, "database_source") + + # Test that get_database_data returns lazy table + lazy_data <- get_database_data(config$db_source) + expect_s3_class(lazy_data, c("tbl_SQLiteConnection", "tbl_dbi", + "tbl_sql", "tbl_lazy", "tbl")) + + # Test that we can chain operations and collect + result <- lazy_data %>% + filter(Species == "setosa") %>% + select(Sepal.Length, Sepal.Width) %>% + collect() + + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 50) + expect_equal(ncol(result), 2) + expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(result))) + + # Test that original lazy table is still usable + all_data <- collect(lazy_data) + expect_equal(nrow(all_data), 150) + expect_equal(ncol(all_data), 5) + + # Clean up + dbDisconnect(db_conn) + unlink(temp_db) +}) + +test_that("app example file exists and is valid R code", { + app_file <- system.file("../examples/app-database.R", package = "querychat") + + # Check file exists + expect_true(file.exists(app_file)) + + # Check it contains key components + app_content <- readLines(app_file) + app_text <- paste(app_content, collapse = "\n") + + expect_true(grepl("library\\(shiny\\)", app_text)) + expect_true(grepl("library\\(querychat\\)", app_text)) + expect_true(grepl("database_source", app_text)) + expect_true(grepl("querychat_init", app_text)) + expect_true(grepl("querychat_server", app_text)) + expect_true(grepl("shinyApp", app_text)) + + # Check it parses as valid R code + expect_no_error(parse(text = app_text)) +}) \ No newline at end of file From 
3b289c7b5193d63ec24d72c937ca8b10d6e61c7a Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Thu, 19 Jun 2025 12:20:47 +0100 Subject: [PATCH 21/51] update to use s3 classes to simplify the code --- README.md | 4 +- pkg-r/NAMESPACE | 21 +- pkg-r/R/data_source.R | 366 +++++++++++++++++++ pkg-r/R/datasource.R | 222 ----------- pkg-r/R/prompt.R | 135 ------- pkg-r/R/querychat.R | 155 ++------ pkg-r/README.md | 68 +++- pkg-r/examples/app-database.R | 21 +- pkg-r/man/querychat_init.Rd | 22 +- pkg-r/man/querychat_system_prompt.Rd | 32 -- pkg-r/tests/testthat/test-shiny-app.R | 16 +- pkg-r/tests/testthat/test_data_source.R | 220 +++++++++++ pkg-r/tests/testthat/test_database_source.R | 222 ----------- pkg-r/tests/testthat/test_querychat_server.R | 8 +- 14 files changed, 724 insertions(+), 788 deletions(-) create mode 100644 pkg-r/R/data_source.R delete mode 100644 pkg-r/R/datasource.R delete mode 100644 pkg-r/R/prompt.R delete mode 100644 pkg-r/man/querychat_system_prompt.Rd create mode 100644 pkg-r/tests/testthat/test_data_source.R delete mode 100644 pkg-r/tests/testthat/test_database_source.R diff --git a/README.md b/README.md index 4a07aa78..8ede04f5 100644 --- a/README.md +++ b/README.md @@ -36,11 +36,11 @@ querychat does not have direct access to the raw data; it can _only_ read or fil - **Transparency:** querychat always displays the SQL to the user, so it can be vetted instead of blindly trusted. - **Reproducibility:** The SQL query can be easily copied and reused. -Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has a surprising number of statistical functions. +Currently, querychat uses DuckDB for its SQL engine when working with data frames. For database sources, it uses the native SQL dialect of the connected database. 
## Language-specific Documentation For detailed information on how to use querychat in your preferred language, see the language-specific READMEs: - [R Documentation](pkg-r/README.md) -- [Python Documentation](pkg-py/README.md) +- [Python Documentation](pkg-py/README.md) \ No newline at end of file diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index 691b73db..b18f6aac 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -1,12 +1,21 @@ # Generated by roxygen2: do not edit by hand -export(database_source) -export(execute_database_query) -export(get_database_data) -export(get_database_schema) +S3method(cleanup_source,data_frame_source) +S3method(cleanup_source,dbi_source) +S3method(create_system_prompt,querychat_data_source) +S3method(execute_query,querychat_data_source) +S3method(get_lazy_data,querychat_data_source) +S3method(get_schema,data_frame_source) +S3method(get_schema,dbi_source) +S3method(querychat_data_source,DBIConnection) +S3method(querychat_data_source,data.frame) +export(cleanup_source) +export(create_system_prompt) +export(execute_query) +export(get_lazy_data) +export(get_schema) +export(querychat_data_source) export(querychat_init) export(querychat_server) export(querychat_sidebar) -export(querychat_system_prompt) -export(querychat_system_prompt_database) export(querychat_ui) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R new file mode 100644 index 00000000..43bbbdc8 --- /dev/null +++ b/pkg-r/R/data_source.R @@ -0,0 +1,366 @@ +#' Create a data source for querychat +#' +#' Generic function to create a data source for querychat. This function +#' dispatches to appropriate methods based on input. +#' +#' @param x A data frame or DBI connection +#' @param table_name The name to use for the table in the data source +#' @param categorical_threshold For text columns, the maximum number of unique values to consider as a categorical variable +#' @param ... 
Additional arguments passed to specific methods +#' @return A querychat_data_source object +#' @export +querychat_data_source <- function(x, ...) { + UseMethod("querychat_data_source") +} + +#' @export +#' @rdname querychat_data_source +querychat_data_source.data.frame <- function(x, table_name = NULL, ...) { + if (is.null(table_name)) { + # Infer table name from dataframe name, if not already added + table_name <- deparse(substitute(x)) + if (is.null(table_name) || table_name == "NULL" || table_name == "x") { + rlang::abort("Unable to infer table name. Please specify `table_name` argument explicitly.") + } + } + + is_table_name_ok <- is.character(table_name) && + length(table_name) == 1 && + grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE) + if (!is_table_name_ok) { + rlang::abort("`table_name` argument must be a string containing a valid table name.") + } + + # Create duckdb connection + conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") + duckdb::duckdb_register(conn, table_name, x, experimental = FALSE) + + structure( + list( + data = x, + conn = conn, + table_name = table_name + ), + class = c("data_frame_source", "querychat_data_source") + ) +} + +#' @export +#' @rdname querychat_data_source +querychat_data_source.DBIConnection <- function(x, table_name, categorical_threshold = 20, ...) { + if (!is.character(table_name) || length(table_name) != 1) { + rlang::abort("`table_name` must be a single character string") + } + + if (!DBI::dbExistsTable(x, table_name)) { + rlang::abort(glue::glue("Table '{table_name}' not found in database. If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect")) + } + + structure( + list( + conn = x, + table_name = table_name, + categorical_threshold = categorical_threshold + ), + class = c("dbi_source", "querychat_data_source") + ) +} + +#' Get schema information for a data source +#' +#' @param source A querychat_data_source object +#' @param ... 
Additional arguments passed to methods +#' @return A character string containing the schema information +#' @export +get_schema <- function(source, ...) { + UseMethod("get_schema") +} + +#' @export +get_schema.dbi_source <- function(source, ...) { + conn <- source$conn + table_name <- source$table_name + categorical_threshold <- source$categorical_threshold + + # Get column information + columns <- DBI::dbListFields(conn, table_name) + + schema_lines <- c( + glue::glue("Table: {table_name}"), + "Columns:" + ) + + # Build single query to get column statistics + select_parts <- character(0) + numeric_columns <- character(0) + text_columns <- character(0) + + # Get sample of data to determine types + sample_query <- glue::glue_sql("SELECT * FROM {`table_name`} LIMIT 1", .con = conn) + sample_data <- DBI::dbGetQuery(conn, sample_query) + + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + + if (col_class %in% c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")) { + numeric_columns <- c(numeric_columns, col) + select_parts <- c( + select_parts, + glue::glue_sql("MIN({`col`}) as {`col`}_min", .con = conn), + glue::glue_sql("MAX({`col`}) as {`col`}_max", .con = conn) + ) + } else if (col_class %in% c("character", "factor")) { + text_columns <- c(text_columns, col) + select_parts <- c( + select_parts, + glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}_distinct_count", .con = conn) + ) + } + } + + # Execute statistics query + column_stats <- list() + if (length(select_parts) > 0) { + tryCatch({ + stats_query <- glue::glue_sql("SELECT {select_parts*} FROM {`table_name`}", .con = conn) + result <- DBI::dbGetQuery(conn, stats_query) + if (nrow(result) > 0) { + column_stats <- as.list(result[1, ]) + } + }, error = function(e) { + # Fall back to no statistics if query fails + }) + } + + # Get categorical values for text columns below threshold + categorical_values <- list() + text_cols_to_query <- character(0) + + for (col_name in text_columns) { 
+ distinct_count_key <- paste0(col_name, "_distinct_count") + if (distinct_count_key %in% names(column_stats) && + !is.na(column_stats[[distinct_count_key]]) && + column_stats[[distinct_count_key]] <= categorical_threshold) { + text_cols_to_query <- c(text_cols_to_query, col_name) + } + } + + # Get categorical values + if (length(text_cols_to_query) > 0) { + for (col_name in text_cols_to_query) { + tryCatch({ + cat_query <- glue::glue_sql( + "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", + .con = conn + ) + result <- DBI::dbGetQuery(conn, cat_query) + if (nrow(result) > 0) { + categorical_values[[col_name]] <- result[[1]] + } + }, error = function(e) { + # Skip categorical values if query fails + }) + } + } + + # Build schema description + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + sql_type <- r_class_to_sql_type(col_class) + + column_info <- glue::glue("- {col} ({sql_type})") + + # Add range info for numeric columns + if (col %in% numeric_columns) { + min_key <- paste0(col, "_min") + max_key <- paste0(col, "_max") + if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && + !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { + range_info <- glue::glue(" Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}") + column_info <- paste(column_info, range_info, sep = "\n") + } + } + + # Add categorical values for text columns + if (col %in% names(categorical_values)) { + values <- categorical_values[[col]] + if (length(values) > 0) { + values_str <- paste0("'", values, "'", collapse = ", ") + cat_info <- glue::glue(" Categorical values: {values_str}") + column_info <- paste(column_info, cat_info, sep = "\n") + } + } + + schema_lines <- c(schema_lines, column_info) + } + + paste(schema_lines, collapse = "\n") +} + +#' @export +get_schema.data_frame_source <- function(source, categorical_threshold = 10, ...) 
{ + df <- source$data + name <- source$table_name + + schema <- c(paste("Table:", name), "Columns:") + + column_info <- lapply(names(df), function(column) { + # Map R classes to SQL-like types + sql_type <- if (is.integer(df[[column]])) { + "INTEGER" + } else if (is.numeric(df[[column]])) { + "FLOAT" + } else if (is.logical(df[[column]])) { + "BOOLEAN" + } else if (inherits(df[[column]], "POSIXt")) { + "DATETIME" + } else { + "TEXT" + } + + info <- paste0("- ", column, " (", sql_type, ")") + + # For TEXT columns, check if they're categorical + if (sql_type == "TEXT") { + unique_values <- length(unique(df[[column]])) + if (unique_values <= categorical_threshold) { + categories <- unique(df[[column]]) + categories_str <- paste0("'", categories, "'", collapse = ", ") + info <- c(info, paste0(" Categorical values: ", categories_str)) + } + } else if (sql_type %in% c("INTEGER", "FLOAT", "DATETIME")) { + rng <- range(df[[column]], na.rm = TRUE) + if (all(is.na(rng))) { + info <- c(info, " Range: NULL to NULL") + } else { + info <- c(info, paste0(" Range: ", rng[1], " to ", rng[2])) + } + } + return(info) + }) + + schema <- c(schema, unlist(column_info)) + return(paste(schema, collapse = "\n")) +} + +#' Execute a SQL query on a data source +#' +#' @param source A querychat_data_source object +#' @param query SQL query string +#' @param ... Additional arguments passed to methods +#' @return Result of the query as a data frame +#' @export +execute_query <- function(source, query, ...) { + UseMethod("execute_query") +} + +#' @export +execute_query.querychat_data_source <- function(source, query, ...) { + DBI::dbGetQuery(source$conn, query) +} + +#' Get a lazy representation of a data source +#' +#' @param source A querychat_data_source object +#' @param query SQL query string +#' @param ... Additional arguments passed to methods +#' @return A lazy representation (typically a dbplyr tbl) +#' @export +get_lazy_data <- function(source, ...) 
{ + UseMethod("get_lazy_data") +} + +#' @export +get_lazy_data.querychat_data_source <- function(source, query = NULL, ...) { + if (is.null(query) || query == ""){ + # For a null or empty query, default to returning the whole table (ie SELECT *) + dplyr::tbl(source$conn, source$table_name) + } else { + dplyr::tbl(source$conn, query) + } + +} + + + +#' Create a system prompt for the data source +#' +#' @param source A querychat_data_source object +#' @param data_description Optional description of the data +#' @param extra_instructions Optional additional instructions +#' @param ... Additional arguments passed to methods +#' @return A string with the system prompt +#' @export +create_system_prompt <- function(source, data_description = NULL, extra_instructions = NULL, ...) { + UseMethod("create_system_prompt") +} + +#' @export +create_system_prompt.querychat_data_source <- function(source, data_description = NULL, extra_instructions = NULL, ...) { + if (!is.null(data_description)) { + data_description <- paste(data_description, collapse = "\n") + } + if (!is.null(extra_instructions)) { + extra_instructions <- paste(extra_instructions, collapse = "\n") + } + + # Read the prompt file + prompt_path <- system.file("prompt", "prompt.md", package = "querychat") + prompt_content <- readLines(prompt_path, warn = FALSE) + prompt_text <- paste(prompt_content, collapse = "\n") + + # Get schema for the data source + schema <- get_schema(source) + + whisker::whisker.render( + prompt_text, + list( + schema = schema, + data_description = data_description, + extra_instructions = extra_instructions + ) + ) +} + +#' Clean up a data source (close connections, etc.) +#' +#' @param source A querychat_data_source object +#' @param ... Additional arguments passed to methods +#' @return NULL (invisibly) +#' @export +cleanup_source <- function(source, ...) { + UseMethod("cleanup_source") +} + +#' @export +cleanup_source.data_frame_source <- function(source, ...) 
{ + if (!is.null(source$conn) && DBI::dbIsValid(source$conn)) { + DBI::dbDisconnect(source$conn) + } + invisible(NULL) +} + +#' @export +cleanup_source.dbi_source <- function(source, ...) { + # WARNING: This package does not automatically disconnect the database connection + # provided by the user. You are responsible for calling DBI::dbDisconnect() on your + # connection when you're finished with it. This is by design, as the connection was + # created externally and may be needed for other operations in your code. + invisible(NULL) +} + +# Helper function to map R classes to SQL types +r_class_to_sql_type <- function(r_class) { + switch(r_class, + "integer" = "INTEGER", + "numeric" = "FLOAT", + "double" = "FLOAT", + "logical" = "BOOLEAN", + "Date" = "DATE", + "POSIXct" = "TIMESTAMP", + "POSIXt" = "TIMESTAMP", + "character" = "TEXT", + "factor" = "TEXT", + "TEXT" # default + ) +} \ No newline at end of file diff --git a/pkg-r/R/datasource.R b/pkg-r/R/datasource.R deleted file mode 100644 index 5cc703b0..00000000 --- a/pkg-r/R/datasource.R +++ /dev/null @@ -1,222 +0,0 @@ -#' Database Data Source for querychat -#' -#' Create a data source that connects to external databases via DBI. -#' Supports PostgreSQL, MySQL, SQLite, and other DBI-compatible databases. 
-#' -#' @param conn A DBI connection object to the database -#' @param table_name Name of the table to query -#' @param categorical_threshold Maximum number of unique values for a text column -#' to be considered categorical (default: 20) -#' -#' @return A database data source object -#' @export -#' @examples -#' \dontrun{ -#' # PostgreSQL example -#' library(RPostgreSQL) -#' conn <- DBI::dbConnect(RPostgreSQL::PostgreSQL(), -#' dbname = "mydb", host = "localhost", -#' user = "user", password = "pass") -#' db_source <- database_source(conn, "my_table") -#' -#' # SQLite example -#' library(RSQLite) -#' conn <- DBI::dbConnect(RSQLite::SQLite(), "path/to/database.db") -#' db_source <- database_source(conn, "my_table") -#' } -database_source <- function(conn, table_name, categorical_threshold = 20) { - if (!inherits(conn, "DBIConnection")) { - rlang::abort("`conn` must be a valid DBI connection object") - } - - if (!is.character(table_name) || length(table_name) != 1) { - rlang::abort("`table_name` must be a single character string") - } - - if (!DBI::dbExistsTable(conn, table_name)) { - rlang::abort(glue::glue("Table '{table_name}' not found in database. 
If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect")) - } - - structure( - list( - conn = conn, - table_name = table_name, - categorical_threshold = categorical_threshold, - db_engine = "DBI" - ), - class = "database_source" - ) -} - -#' Generate schema information for database source -#' -#' @param source A database_source object -#' @return A character string describing the schema -#' @export -get_database_schema <- function(source) { - if (!inherits(source, "database_source")) { - rlang::abort("`source` must be a database_source object") - } - - conn <- source$conn - table_name <- source$table_name - categorical_threshold <- source$categorical_threshold - - # Get column information - columns <- DBI::dbListFields(conn, table_name) - - schema_lines <- c( - glue::glue("Table: {table_name}"), - "Columns:" - ) - - # Build single query to get column statistics - select_parts <- character(0) - numeric_columns <- character(0) - text_columns <- character(0) - - # Get sample of data to determine types - sample_query <- glue::glue_sql("SELECT * FROM {`table_name`} LIMIT 1", .con = conn) - sample_data <- DBI::dbGetQuery(conn, sample_query) - - for (col in columns) { - col_class <- class(sample_data[[col]])[1] - - if (col_class %in% c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")) { - numeric_columns <- c(numeric_columns, col) - select_parts <- c( - select_parts, - glue::glue_sql("MIN({`col`}) as {`col`}_min", .con = conn), - glue::glue_sql("MAX({`col`}) as {`col`}_max", .con = conn) - ) - } else if (col_class %in% c("character", "factor")) { - text_columns <- c(text_columns, col) - select_parts <- c( - select_parts, - glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}_distinct_count", .con = conn) - ) - } - } - - # Execute statistics query - column_stats <- list() - if (length(select_parts) > 0) { - tryCatch({ - stats_query <- glue::glue_sql("SELECT {select_parts*} FROM {`table_name`}", .con = conn) - result <- 
DBI::dbGetQuery(conn, stats_query) - if (nrow(result) > 0) { - column_stats <- as.list(result[1, ]) - } - }, error = function(e) { - # Fall back to no statistics if query fails - }) - } - - # Get categorical values for text columns below threshold - categorical_values <- list() - text_cols_to_query <- character(0) - - for (col_name in text_columns) { - distinct_count_key <- paste0(col_name, "_distinct_count") - if (distinct_count_key %in% names(column_stats) && - !is.na(column_stats[[distinct_count_key]]) && - column_stats[[distinct_count_key]] <= categorical_threshold) { - text_cols_to_query <- c(text_cols_to_query, col_name) - } - } - - # Get categorical values - if (length(text_cols_to_query) > 0) { - for (col_name in text_cols_to_query) { - tryCatch({ - cat_query <- glue::glue_sql( - "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", - .con = conn - ) - result <- DBI::dbGetQuery(conn, cat_query) - if (nrow(result) > 0) { - categorical_values[[col_name]] <- result[[1]] - } - }, error = function(e) { - # Skip categorical values if query fails - }) - } - } - - # Build schema description - for (col in columns) { - col_class <- class(sample_data[[col]])[1] - sql_type <- r_class_to_sql_type(col_class) - - column_info <- glue::glue("- {col} ({sql_type})") - - # Add range info for numeric columns - if (col %in% numeric_columns) { - min_key <- paste0(col, "_min") - max_key <- paste0(col, "_max") - if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && - !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { - range_info <- glue::glue(" Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}") - column_info <- paste(column_info, range_info, sep = "\n") - } - } - - # Add categorical values for text columns - if (col %in% names(categorical_values)) { - values <- categorical_values[[col]] - if (length(values) > 0) { - values_str <- paste0("'", values, "'", collapse = ", ") - 
cat_info <- glue::glue(" Categorical values: {values_str}") - column_info <- paste(column_info, cat_info, sep = "\n") - } - } - - schema_lines <- c(schema_lines, column_info) - } - - paste(schema_lines, collapse = "\n") -} - -#' Execute SQL query on database source -#' -#' @param source A database_source object -#' @param query SQL query to execute -#' @return A data frame with query results -#' @export -execute_database_query <- function(source, query) { - if (!inherits(source, "database_source")) { - rlang::abort("`source` must be a database_source object") - } - - DBI::dbGetQuery(source$conn, query) -} - -#' Get lazy database table reference -#' -#' @param source A database_source object -#' @return A lazy dbplyr tbl object that can be further manipulated with dplyr verbs -#' @export -get_database_data <- function(source) { - if (!inherits(source, "database_source")) { - rlang::abort("`source` must be a database_source object") - } - - # Return a lazy tbl that can be chained with further dplyr operations - dplyr::tbl(source$conn, source$table_name) -} - -# Helper function to map R classes to SQL types -r_class_to_sql_type <- function(r_class) { - switch(r_class, - "integer" = "INTEGER", - "numeric" = "FLOAT", - "double" = "FLOAT", - "logical" = "BOOLEAN", - "Date" = "DATE", - "POSIXct" = "TIMESTAMP", - "POSIXt" = "TIMESTAMP", - "character" = "TEXT", - "factor" = "TEXT", - "TEXT" # default - ) -} \ No newline at end of file diff --git a/pkg-r/R/prompt.R b/pkg-r/R/prompt.R deleted file mode 100644 index c8276819..00000000 --- a/pkg-r/R/prompt.R +++ /dev/null @@ -1,135 +0,0 @@ -#' Create a system prompt for the chat model -#' -#' This function generates a system prompt for the chat model based on a data frame's -#' schema and optional additional context and instructions. -#' -#' @param df A data frame to generate schema information from. -#' @param name A string containing the name of the table in SQL queries. 
-#' @param data_description Optional description of the data, in plain text or Markdown format. -#' @param extra_instructions Optional additional instructions for the chat model, in plain text or Markdown format. -#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical. -#' -#' @return A string containing the system prompt for the chat model. -#' -#' @export -querychat_system_prompt <- function( - df, - name, - data_description = NULL, - extra_instructions = NULL, - categorical_threshold = 10 -) { - schema <- df_to_schema(df, name, categorical_threshold) - - if (!is.null(data_description)) { - data_description <- paste(data_description, collapse = "\n") - } - if (!is.null(extra_instructions)) { - extra_instructions <- paste(extra_instructions, collapse = "\n") - } - - # Read the prompt file - prompt_path <- system.file("prompt", "prompt.md", package = "querychat") - prompt_content <- readLines(prompt_path, warn = FALSE) - prompt_text <- paste(prompt_content, collapse = "\n") - - whisker::whisker.render( - prompt_text, - list( - schema = schema, - data_description = data_description, - extra_instructions = extra_instructions - ) - ) -} - -df_to_schema <- function( - df, - name = deparse(substitute(df)), - categorical_threshold -) { - schema <- c(paste("Table:", name), "Columns:") - - column_info <- lapply(names(df), function(column) { - # Map R classes to SQL-like types - sql_type <- if (is.integer(df[[column]])) { - "INTEGER" - } else if (is.numeric(df[[column]])) { - "FLOAT" - } else if (is.logical(df[[column]])) { - "BOOLEAN" - } else if (inherits(df[[column]], "POSIXt")) { - "DATETIME" - } else { - "TEXT" - } - - info <- paste0("- ", column, " (", sql_type, ")") - - # For TEXT columns, check if they're categorical - if (sql_type == "TEXT") { - unique_values <- length(unique(df[[column]])) - if (unique_values <= categorical_threshold) { - categories <- unique(df[[column]]) - categories_str <- paste0("'", 
categories, "'", collapse = ", ") - info <- c(info, paste0(" Categorical values: ", categories_str)) - } - } else if (sql_type %in% c("INTEGER", "FLOAT", "DATETIME")) { - rng <- range(df[[column]], na.rm = TRUE) - if (all(is.na(rng))) { - info <- c(info, " Range: NULL to NULL") - } else { - info <- c(info, paste0(" Range: ", rng[1], " to ", rng[2])) - } - } - return(info) - }) - - schema <- c(schema, unlist(column_info)) - return(paste(schema, collapse = "\n")) -} - -#' Create a system prompt for the chat model using database source -#' -#' This function generates a system prompt for the chat model based on a database -#' source's schema and optional additional context and instructions. -#' -#' @param db_source A database_source object to generate schema information from. -#' @param data_description Optional description of the data, in plain text or Markdown format. -#' @param extra_instructions Optional additional instructions for the chat model, in plain text or Markdown format. -#' -#' @return A string containing the system prompt for the chat model. 
-#' -#' @export -querychat_system_prompt_database <- function( - db_source, - data_description = NULL, - extra_instructions = NULL -) { - if (!inherits(db_source, "database_source")) { - rlang::abort("`db_source` must be a database_source object") - } - - schema <- get_database_schema(db_source) - - if (!is.null(data_description)) { - data_description <- paste(data_description, collapse = "\n") - } - if (!is.null(extra_instructions)) { - extra_instructions <- paste(extra_instructions, collapse = "\n") - } - - # Read the prompt file - prompt_path <- system.file("prompt", "prompt.md", package = "querychat") - prompt_content <- readLines(prompt_path, warn = FALSE) - prompt_text <- paste(prompt_content, collapse = "\n") - - whisker::whisker.render( - prompt_text, - list( - schema = schema, - data_description = data_description, - extra_instructions = extra_instructions - ) - ) -} diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 8ae5e904..22a4a61b 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -3,14 +3,10 @@ #' This will perform one-time initialization that can then be shared by all #' Shiny sessions in the R process. #' -#' @param data_source Either a data frame or a database_source object created by -#' `database_source()`. For backwards compatibility, `df` can also be used. -#' @param df Deprecated. Use `data_source` instead. A data frame. -#' @param tbl_name A string containing a valid table name for the data frame, -#' that will appear in SQL queries. Ensure that it begins with a letter, and -#' contains only letters, numbers, and underscores. By default, querychat will -#' try to infer a table name using the name of the `df` argument. Not used -#' when `data_source` is a database_source object. +#' @param data_source A querychat_data_source object created by `querychat_data_source()`. 
+#' To create a data source:
+#' - For data frame: `querychat_data_source(df, table_name = "my_table")`
+#' - For database: `querychat_data_source(conn, "table_name")`
 #' @param greeting A string in Markdown format, containing the initial message
 #' to display to the user upon first loading the chatbot. If not provided, the
 #' LLM will be invoked at the start of the conversation to generate one.
@@ -26,7 +22,7 @@
 #' @param create_chat_func A function that takes a system prompt and returns a
 #' chat object. The default uses `ellmer::chat_openai()`.
 #' @param system_prompt A string containing the system prompt for the chat model.
-#' The default uses `querychat_system_prompt()` to generate a generic prompt,
+#' The default uses `create_system_prompt()` to generate a generic prompt,
 #' which you can enhance via the `data_description` and `extra_instructions`
 #' arguments.
 #'
@@ -36,86 +32,27 @@
 #'
 #' @export
 querychat_init <- function(
-  data_source = NULL,
-  df = NULL,
-  tbl_name = NULL,
+  data_source,
   greeting = NULL,
   data_description = NULL,
   extra_instructions = NULL,
   create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"),
   system_prompt = NULL
 ) {
-  # Handle backwards compatibility and argument validation
-  if (!is.null(df) && !is.null(data_source)) {
-    rlang::abort("Cannot specify both `df` and `data_source` arguments")
-  }
-
-  if (!is.null(df)) {
-    rlang::warn("`df` argument is deprecated. 
Use `data_source` instead.") - data_source <- df - } - - if (is.null(data_source)) { - rlang::abort("Must provide either `data_source` or `df` argument") - } - force(create_chat_func) - # Determine source type and setup - is_database_source <- inherits(data_source, "database_source") - is_dataframe <- is.data.frame(data_source) - - if (!is_database_source && !is_dataframe) { - rlang::abort("`data_source` must be either a data frame or database_source object") + # Check that data_source is a querychat_data_source object + if (!inherits(data_source, "querychat_data_source")) { + rlang::abort("`data_source` must be a querychat_data_source object. Use querychat_data_source() to create one.") } - if (is_database_source) { - # Using database source - db_source <- data_source - conn <- db_source$conn - tbl_name <- db_source$table_name - df <- NULL # No data frame for database sources - - # Generate system prompt if not provided - if (is.null(system_prompt)) { - system_prompt <- querychat_system_prompt_database( - db_source, - data_description = data_description, - extra_instructions = extra_instructions - ) - } - } else { - # Using data frame source - set up DuckDB - if (is.null(tbl_name)) { - tbl_name <- deparse(substitute(data_source)) - if (is.null(tbl_name) || tbl_name == "NULL") { - rlang::abort("Unable to infer table name. 
Please specify `tbl_name` argument explicitly.") - } - } - - is_tbl_name_ok <- is.character(tbl_name) && - length(tbl_name) == 1 && - grepl("^[a-zA-Z][a-zA-Z0-9_]*$", tbl_name, perl = TRUE) - if (!is_tbl_name_ok) { - rlang::abort("`tbl_name` argument must be a string containing a valid table name.") - } - - df <- data_source - conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") - duckdb::duckdb_register(conn, tbl_name, df, experimental = FALSE) - shiny::onStop(function() DBI::dbDisconnect(conn)) - - # Generate system prompt if not provided - if (is.null(system_prompt)) { - system_prompt <- querychat_system_prompt( - df, - tbl_name, - data_description = data_description, - extra_instructions = extra_instructions - ) - } - - db_source <- NULL + # Generate system prompt if not provided + if (is.null(system_prompt)) { + system_prompt <- create_system_prompt( + data_source, + data_description = data_description, + extra_instructions = extra_instructions + ) } # Validate system prompt and create_chat_func @@ -135,14 +72,10 @@ querychat_init <- function( structure( list( - df = df, - conn = conn, - db_source = db_source, + data_source = data_source, system_prompt = system_prompt, greeting = greeting, - create_chat_func = create_chat_func, - is_database_source = is_database_source, - table_name = tbl_name + create_chat_func = create_chat_func ), class = "querychat_config" ) @@ -208,36 +141,16 @@ querychat_server <- function(id, querychat_config) { shiny::moduleServer(id, function(input, output, session) { # πŸ”„ Reactive state/computation -------------------------------------------- - df <- querychat_config[["df"]] - conn <- querychat_config[["conn"]] - db_source <- querychat_config[["db_source"]] - is_database_source <- querychat_config[["is_database_source"]] - table_name <- querychat_config[["table_name"]] + data_source <- querychat_config[["data_source"]] system_prompt <- querychat_config[["system_prompt"]] greeting <- querychat_config[["greeting"]] 
create_chat_func <- querychat_config[["create_chat_func"]] + conn <- data_source$conn current_title <- shiny::reactiveVal(NULL) current_query <- shiny::reactiveVal("") filtered_df <- shiny::reactive({ - if (current_query() == "") { - if (is_database_source) { - # For database sources, return lazy tbl (no data transfer) - get_database_data(db_source) - } else { - # For data frames, return the original data frame - df - } - } else { - if (is_database_source) { - # For database sources, return lazy tbl with custom query - # Parse and create a lazy tbl from the SQL query - dplyr::tbl(conn, dplyr::sql(current_query())) - } else { - # For data frames, execute query and return result - DBI::dbGetQuery(conn, current_query()) - } - } + querychat::get_lazy_data(data_source, query = dplyr::sql(current_query())) }) append_output <- function(...) { @@ -253,7 +166,7 @@ querychat_server <- function(id, querychat_config) { # Modifies the data presented in the data dashboard, based on the given SQL # query, and also updates the title. - # @param query A DuckDB SQL query; must be a SELECT statement. + # @param query A SQL query; must be a SELECT statement. # @param title A title to display at the top of the data dashboard, # summarizing the intent of the SQL query. update_dashboard <- function(query, title) { @@ -262,7 +175,7 @@ querychat_server <- function(id, querychat_config) { tryCatch( { # Try it to see if it errors; if so, the LLM will see the error - DBI::dbGetQuery(conn, query) + execute_query(data_source, query) }, error = function(err) { append_output("> Error: ", conditionMessage(err), "\n\n") @@ -279,17 +192,16 @@ querychat_server <- function(id, querychat_config) { } # Perform a SQL query on the data, and return the results as JSON. - # @param query A DuckDB SQL query; must be a SELECT statement. - # @return The results of the query as a JSON string. + # @param query A SQL query; must be a SELECT statement. + # @return The results of the query as a data frame. 
query <- function(query) { # Do this before query, in case it errors append_output("\n```sql\n", query, "\n```\n") tryCatch( { - # Execute the query and return the results as a data frame - # This tool is for answering questions, so we need actual results - DBI::dbGetQuery(conn, query) + # Execute the query and return the results + execute_query(data_source, query) }, error = function(e) { append_output("> Error: ", conditionMessage(e), "\n\n") @@ -305,7 +217,7 @@ querychat_server <- function(id, querychat_config) { update_dashboard, "Modifies the data presented in the data dashboard, based on the given SQL query, and also updates the title.", query = ellmer::type_string( - "A DuckDB SQL query; must be a SELECT statement." + "A SQL query; must be a SELECT statement." ), title = ellmer::type_string( "A title to display at the top of the data dashboard, summarizing the intent of the SQL query." @@ -313,9 +225,9 @@ querychat_server <- function(id, querychat_config) { )) chat$register_tool(ellmer::tool( query, - "Perform a SQL query on the data, and return the results as JSON.", + "Perform a SQL query on the data, and return the results.", query = ellmer::type_string( - "A DuckDB SQL query; must be a SELECT statement." + "A SQL query; must be a SELECT statement." ) )) @@ -344,6 +256,11 @@ querychat_server <- function(id, querychat_config) { ) }) + # Add session cleanup + shiny::onStop(function() { + cleanup_source(data_source) + }) + list( chat = chat, sql = shiny::reactive(current_query()), @@ -376,4 +293,4 @@ df_to_html <- function(df, maxrows = 5) { } paste0(tbl_html, "\n", rows_notice) -} +} \ No newline at end of file diff --git a/pkg-r/README.md b/pkg-r/README.md index 03b5802a..6264e52b 100644 --- a/pkg-r/README.md +++ b/pkg-r/README.md @@ -27,12 +27,14 @@ library(shiny) library(bslib) library(querychat) -# 1. Configure querychat. This is where you specify the dataset and can also -# override options like the greeting message, system prompt, model, etc. 
-querychat_config <- querychat_init(mtcars) +# 1. Create a data source for querychat +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + +# 2. Configure querychat with the data source +querychat_config <- querychat_init(mtcars_source) ui <- page_sidebar( - # 2. Use querychat_sidebar(id) in a bslib::page_sidebar. + # 3. Use querychat_sidebar(id) in a bslib::page_sidebar. # Alternatively, use querychat_ui(id) elsewhere if you don't want your # chat interface to live in a sidebar. sidebar = querychat_sidebar("chat"), @@ -41,11 +43,11 @@ ui <- page_sidebar( server <- function(input, output, session) { - # 3. Create a querychat object using the config from step 1. + # 4. Create a querychat object using the config from step 2. querychat <- querychat_server("chat", querychat_config) output$dt <- DT::renderDT({ - # 4. Use the filtered/sorted data frame anywhere you wish, via the + # 5. Use the filtered/sorted data frame anywhere you wish, via the # querychat$df() reactive. DT::datatable(querychat$df()) }) @@ -54,6 +56,29 @@ server <- function(input, output, session) { shinyApp(ui, server) ``` +## Using Database Sources + +In addition to data frames, querychat can connect to external databases via DBI: + +```r +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) + +# 1. Connect to a database +conn <- DBI::dbConnect(RSQLite::SQLite(), "path/to/database.db") + +# 2. Create a database data source for querychat +db_source <- querychat_data_source(conn, "table_name") + +# 3. Configure querychat with the database source +querychat_config <- querychat_init(db_source) + +# Then use querychat_config in your Shiny app as shown above +``` + ## How it works ### Powered by LLMs @@ -76,7 +101,7 @@ querychat does not have direct access to the raw data; it can _only_ read or fil - **Transparency:** querychat always displays the SQL to the user, so it can be vetted instead of blindly trusted. 
- **Reproducibility:** The SQL query can be easily copied and reused. -Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has a surprising number of [statistical functions](https://duckdb.org/docs/stable/sql/functions/aggregates.html#statistical-aggregates). +Currently, querychat uses DuckDB for its SQL engine when working with data frames. For database sources, it uses the native SQL dialect of the connected database. DuckDB is extremely fast and has a surprising number of [statistical functions](https://duckdb.org/docs/stable/sql/functions/aggregates.html#statistical-aggregates). ## Customizing querychat @@ -116,7 +141,7 @@ Alternatively, you can completely suppress the greeting by passing `greeting = " In LLM parlance, the _system prompt_ is the set of instructions and specific knowledge you want the model to use during a conversation. querychat automatically creates a system prompt which is comprised of: 1. The basic set of behaviors the LLM must follow in order for querychat to work properly. (See `inst/prompt/prompt.md` if you're curious what this looks like.) -2. The SQL schema of the data frame you provided. +2. The SQL schema of the data source you provided. 3. (Optional) Any additional description of the data you choose to provide. 4. (Optional) Any additional instructions you want to use to guide querychat's behavior. @@ -125,7 +150,7 @@ In LLM parlance, the _system prompt_ is the set of instructions and specific kno If you give querychat your dataset and nothing else, it will provide the LLM with the basic schema of your data: - Column names -- DuckDB data type (integer, float, boolean, datetime, text) +- SQL data type (integer, float, boolean, datetime, text) - For text columns with less than 10 unique values, we assume they are categorical variables and include the list of values - For integer and float columns, we include the range @@ -158,8 +183,12 @@ performance for 32 automobiles (1973–74 models). 
which you can then pass via: ```r +# Create data source first +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + +# Then initialize with the data source and description querychat_config <- querychat_init( - mtcars, + data_source = mtcars_source, data_description = readLines("data_description.md") ) ``` @@ -171,11 +200,18 @@ querychat doesn't need this information in any particular format; just put whate You can add additional instructions of your own to the end of the system prompt, by passing `extra_instructions` into `query_init`. ```r -querychat_config <- querychat_init(mtcars, extra_instructions = c( +# Create data source first +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + +# Then initialize with instructions +querychat_config <- querychat_init( + data_source = mtcars_source, + extra_instructions = c( "You're speaking to a British audience--please use appropriate spelling conventions.", "Use lots of emojis! πŸ˜ƒ Emojis everywhere, 🌍 emojis forever. ♾️", "Stay on topic, only talk about the data dashboard and refuse to answer other questions." -)) + ) +) ``` You can also put these instructions in a separate file and use `readLines()` to load them, as we did for `data_description` above. @@ -204,11 +240,15 @@ my_chat_func <- function(system_prompt) { library(ellmer) library(purrr) +# Create data source first +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + # Option 2: Use partial -querychat_config <- querychat_init(mtcars, +querychat_config <- querychat_init( + data_source = mtcars_source, create_chat_func = purrr::partial(ellmer::chat_claude, model = "claude-3-7-sonnet-latest") ) ``` This would use Claude 3.7 Sonnet instead, which would require you to provide an API key. -See the [instructions from Ellmer](https://ellmer.tidyverse.org/reference/chat_claude.html) for more information on how to authenticate with different providers. 
+See the [instructions from Ellmer](https://ellmer.tidyverse.org/reference/chat_claude.html) for more information on how to authenticate with different providers. \ No newline at end of file diff --git a/pkg-r/examples/app-database.R b/pkg-r/examples/app-database.R index 8b2609b0..961c237e 100644 --- a/pkg-r/examples/app-database.R +++ b/pkg-r/examples/app-database.R @@ -13,9 +13,6 @@ conn <- dbConnect(RSQLite::SQLite(), temp_db) iris_data <- iris dbWriteTable(conn, "iris", iris_data, overwrite = TRUE) -# Disconnect temporarily - we'll reconnect in the app -dbDisconnect(conn) - # Define a custom greeting for the database app greeting <- " # Welcome to the Database Query Assistant! πŸ“Š @@ -30,10 +27,8 @@ Try asking: - Create a summary of measurements grouped by species " -# Create database source -# Note: In a production app, you would use your actual database credentials -db_conn <- dbConnect(RSQLite::SQLite(), temp_db) -iris_source <- database_source(db_conn, "iris") +# Create data source using querychat_data_source +iris_source <- querychat_data_source(conn, table_name = "iris") # Configure querychat for database querychat_config <- querychat_init( @@ -65,7 +60,13 @@ server <- function(input, output, session) { chat <- querychat_server("chat", querychat_config) output$data_table <- DT::renderDT({ - chat$df() + df <- chat$df() + # Collect data from lazy tbl if needed + if (inherits(df, "tbl_lazy")) { + dplyr::collect(df) + } else { + df + } }, options = list(pageLength = 10, scrollX = TRUE)) output$sql_query <- renderText({ @@ -79,8 +80,8 @@ server <- function(input, output, session) { # Clean up database connection when app stops session$onSessionEnded(function() { - if (dbIsValid(db_conn)) { - dbDisconnect(db_conn) + if (dbIsValid(conn)) { + dbDisconnect(conn) } if (file.exists(temp_db)) { unlink(temp_db) diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index 6482c88a..c3cdd8e7 100644 --- a/pkg-r/man/querychat_init.Rd +++ 
b/pkg-r/man/querychat_init.Rd @@ -5,9 +5,7 @@ \title{Call this once outside of any server function} \usage{ querychat_init( - data_source = NULL, - df = NULL, - tbl_name = NULL, + data_source, greeting = NULL, data_description = NULL, extra_instructions = NULL, @@ -16,16 +14,12 @@ querychat_init( ) } \arguments{ -\item{data_source}{Either a data frame or a database_source object created by -\code{database_source()}. For backwards compatibility, \code{df} can also be used.} - -\item{df}{Deprecated. Use \code{data_source} instead. A data frame.} - -\item{tbl_name}{A string containing a valid table name for the data frame, -that will appear in SQL queries. Ensure that it begins with a letter, and -contains only letters, numbers, and underscores. By default, querychat will -try to infer a table name using the name of the \code{df} argument. Not used -when \code{data_source} is a database_source object.} +\item{data_source}{A querychat_data_source object created by \code{querychat_data_source()}. +To create a data source: +\itemize{ +\item For data frame: \code{querychat_data_source(df, tbl_name = "my_table")} +\item For database: \code{querychat_data_source(conn, "table_name")} +}} \item{greeting}{A string in Markdown format, containing the initial message to display to the user upon first loading the chatbot. If not provided, the @@ -46,7 +40,7 @@ the \code{extra_instructions} argument will be ignored.} chat object. The default uses \code{ellmer::chat_openai()}.} \item{system_prompt}{A string containing the system prompt for the chat model. 
-The default uses \code{querychat_system_prompt()} to generate a generic prompt, +The default uses \code{create_system_prompt()} to generate a generic prompt, which you can enhance via the \code{data_description} and \code{extra_instructions} arguments.} } diff --git a/pkg-r/man/querychat_system_prompt.Rd b/pkg-r/man/querychat_system_prompt.Rd deleted file mode 100644 index 31dae21f..00000000 --- a/pkg-r/man/querychat_system_prompt.Rd +++ /dev/null @@ -1,32 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prompt.R -\name{querychat_system_prompt} -\alias{querychat_system_prompt} -\title{Create a system prompt for the chat model} -\usage{ -querychat_system_prompt( - df, - name, - data_description = NULL, - extra_instructions = NULL, - categorical_threshold = 10 -) -} -\arguments{ -\item{df}{A data frame to generate schema information from.} - -\item{name}{A string containing the name of the table in SQL queries.} - -\item{data_description}{Optional description of the data, in plain text or Markdown format.} - -\item{extra_instructions}{Optional additional instructions for the chat model, in plain text or Markdown format.} - -\item{categorical_threshold}{The maximum number of unique values for a text column to be considered categorical.} -} -\value{ -A string containing the system prompt for the chat model. -} -\description{ -This function generates a system prompt for the chat model based on a data frame's -schema and optional additional context and instructions. 
-} diff --git a/pkg-r/tests/testthat/test-shiny-app.R b/pkg-r/tests/testthat/test-shiny-app.R index 488e8bb4..ef8cb6ab 100644 --- a/pkg-r/tests/testthat/test-shiny-app.R +++ b/pkg-r/tests/testthat/test-shiny-app.R @@ -33,7 +33,7 @@ dbDisconnect(conn) # Setup database source db_conn <- dbConnect(RSQLite::SQLite(), temp_db) -iris_source <- database_source(db_conn, "iris") +iris_source <- querychat_data_source(db_conn, "iris") # Configure querychat with mock querychat_config <- querychat_init( @@ -106,7 +106,7 @@ test_that("database reactive functionality works correctly", { # Test database source creation db_conn <- dbConnect(RSQLite::SQLite(), temp_db) - iris_source <- database_source(db_conn, "iris") + iris_source <- querychat_data_source(db_conn, "iris") # Mock chat function mock_chat_func <- function(system_prompt) { @@ -123,11 +123,11 @@ test_that("database reactive functionality works correctly", { create_chat_func = mock_chat_func ) - expect_true(config$is_database_source) - expect_s3_class(config$db_source, "database_source") + expect_s3_class(config$data_source, "dbi_source") + expect_s3_class(config$data_source, "querychat_data_source") - # Test that get_database_data returns lazy table - lazy_data <- get_database_data(config$db_source) + # Test that get_lazy_data returns lazy table + lazy_data <- get_lazy_data(config$data_source) expect_s3_class(lazy_data, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) @@ -153,7 +153,7 @@ test_that("database reactive functionality works correctly", { }) test_that("app example file exists and is valid R code", { - app_file <- system.file("../examples/app-database.R", package = "querychat") + app_file <- "../../examples/app-database.R" # Check file exists expect_true(file.exists(app_file)) @@ -164,7 +164,7 @@ test_that("app example file exists and is valid R code", { expect_true(grepl("library\\(shiny\\)", app_text)) expect_true(grepl("library\\(querychat\\)", app_text)) - 
expect_true(grepl("database_source", app_text)) + expect_true(grepl("querychat_data_source", app_text)) expect_true(grepl("querychat_init", app_text)) expect_true(grepl("querychat_server", app_text)) expect_true(grepl("shinyApp", app_text)) diff --git a/pkg-r/tests/testthat/test_data_source.R b/pkg-r/tests/testthat/test_data_source.R new file mode 100644 index 00000000..65d2a48d --- /dev/null +++ b/pkg-r/tests/testthat/test_data_source.R @@ -0,0 +1,220 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(dplyr) +library(querychat) + +test_that("querychat_data_source.data.frame creates proper S3 object", { + # Create a simple data frame + test_df <- data.frame( + id = 1:5, + name = c("A", "B", "C", "D", "E"), + value = c(10.5, 20.3, 15.7, 30.1, 25.9), + stringsAsFactors = FALSE + ) + + # Test with explicit table name + source <- querychat_data_source(test_df, table_name = "test_table") + expect_s3_class(source, "data_frame_source") + expect_s3_class(source, "querychat_data_source") + expect_equal(source$table_name, "test_table") + expect_s3_class(source$data, "data.frame") + expect_true(inherits(source$conn, "DBIConnection")) + + # Clean up + cleanup_source(source) +}) + +test_that("querychat_data_source.DBIConnection creates proper S3 object", { + # Create temporary SQLite database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age = c(25, 30, 35, 28, 32), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "users", test_data, overwrite = TRUE) + + # Test DBI source creation + db_source <- querychat_data_source(conn, "users") + expect_s3_class(db_source, "dbi_source") + expect_s3_class(db_source, "querychat_data_source") + expect_equal(db_source$table_name, "users") + expect_equal(db_source$categorical_threshold, 20) + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + 
+test_that("get_schema methods return proper schema", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + name = c("A", "B", "C", "D", "E"), + active = c(TRUE, FALSE, TRUE, TRUE, FALSE), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + schema <- get_schema(df_source) + expect_type(schema, "character") + expect_true(grepl("Table: test_table", schema)) + expect_true(grepl("id \\(INTEGER\\)", schema)) + expect_true(grepl("name \\(TEXT\\)", schema)) + expect_true(grepl("active \\(BOOLEAN\\)", schema)) + expect_true(grepl("Categorical values", schema)) # Should list categorical values + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + schema <- get_schema(dbi_source) + expect_type(schema, "character") + expect_true(grepl("Table: test_table", schema)) + expect_true(grepl("id \\(INTEGER\\)", schema)) + expect_true(grepl("name \\(TEXT\\)", schema)) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("execute_query works for both source types", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + result <- execute_query(df_source, "SELECT * FROM test_table WHERE value > 25") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + result <- execute_query(dbi_source, "SELECT * FROM test_table WHERE value > 25") + 
expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("get_lazy_data returns tbl objects", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + lazy_data <- get_lazy_data(df_source) + expect_s3_class(lazy_data, "tbl") + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + lazy_data <- get_lazy_data(dbi_source) + expect_s3_class(lazy_data, "tbl") + + # Test chaining with dplyr + filtered_data <- lazy_data %>% + dplyr::filter(value > 25) %>% + dplyr::collect() + expect_equal(nrow(filtered_data), 3) # Should return 3 rows (30, 40, 50) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("create_system_prompt generates appropriate system prompt", { + test_df <- data.frame( + id = 1:3, + name = c("A", "B", "C"), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + prompt <- create_system_prompt(df_source, data_description = "A test dataframe") + expect_type(prompt, "character") + expect_true(nchar(prompt) > 0) + expect_true(grepl("A test dataframe", prompt)) + expect_true(grepl("Table: test_table", prompt)) + + # Clean up + cleanup_source(df_source) +}) + +test_that("querychat_init requires a querychat_data_source", { + # Test that querychat_init rejects data frames directly + test_df <- data.frame(id = 1:3, name = c("A", "B", "C")) + + # Should abort with data frame + expect_error( + querychat_init(data_source = test_df), + "must be a querychat_data_source" + ) + + # Should work with 
proper data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + config <- querychat_init(data_source = df_source, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + + # Clean up + cleanup_source(df_source) +}) + +test_that("querychat_init works with both source types", { + # Test with data frame + test_df <- data.frame( + id = 1:3, + name = c("A", "B", "C"), + stringsAsFactors = FALSE + ) + + # Create data source and test with querychat_init + df_source <- querychat_data_source(test_df, table_name = "test_source") + config <- querychat_init(data_source = df_source, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + expect_s3_class(config$data_source, "data_frame_source") + expect_equal(config$data_source$table_name, "test_source") + + # Test with database connection + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + config <- querychat_init(data_source = dbi_source, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + expect_s3_class(config$data_source, "dbi_source") + expect_equal(config$data_source$table_name, "test_table") + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) \ No newline at end of file diff --git a/pkg-r/tests/testthat/test_database_source.R b/pkg-r/tests/testthat/test_database_source.R deleted file mode 100644 index d5c36391..00000000 --- a/pkg-r/tests/testthat/test_database_source.R +++ /dev/null @@ -1,222 +0,0 @@ -library(testthat) -library(DBI) -library(RSQLite) -library(dplyr) -library(querychat) - -test_that("database_source creation and basic functionality", { - # Create temporary SQLite database - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - - # Create test table - test_data <- data.frame( - id = 1:5, - 
name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), - age = c(25, 30, 35, 28, 32), - city = c("NYC", "LA", "NYC", "Chicago", "LA"), - stringsAsFactors = FALSE - ) - - dbWriteTable(conn, "users", test_data, overwrite = TRUE) - - # Test database_source creation - db_source <- database_source(conn, "users") - expect_s3_class(db_source, "database_source") - expect_equal(db_source$table_name, "users") - expect_equal(db_source$categorical_threshold, 20) - - # Test schema generation - schema <- get_database_schema(db_source) - expect_type(schema, "character") - expect_true(grepl("Table: users", schema)) - expect_true(grepl("id \\(INTEGER\\)", schema)) - expect_true(grepl("name \\(TEXT\\)", schema)) - expect_true(grepl("city \\(TEXT\\)", schema)) # Should have city column - - # Test query execution - result <- execute_database_query(db_source, "SELECT * FROM users WHERE age > 30") - expect_s3_class(result, "data.frame") - expect_equal(nrow(result), 2) # Charlie and Eve - - # Test get all data returns lazy dbplyr table - all_data <- get_database_data(db_source) - expect_s3_class(all_data, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) - - # Test that it can be chained with dbplyr operations before collect() - filtered_data <- all_data |> - dplyr::filter(age > 30) |> - dplyr::arrange(dplyr::desc(age)) |> - dplyr::collect() - - expect_s3_class(filtered_data, "data.frame") - expect_equal(nrow(filtered_data), 2) # Charlie and Eve - expect_equal(filtered_data$name, c("Charlie", "Eve")) - - # Test that the lazy table can be collected to get all data - collected_data <- dplyr::collect(all_data) - expect_s3_class(collected_data, "data.frame") - expect_equal(nrow(collected_data), 5) - expect_equal(ncol(all_data), 4) - - # Clean up - dbDisconnect(conn) - unlink(temp_db) -}) - -test_that("database_source error handling", { - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - - # Test error for non-existent table - expect_error( 
- database_source(conn, "nonexistent_table"), - "Table 'nonexistent_table' not found" - ) - - # Test error for invalid connection - expect_error( - database_source("not_a_connection", "table"), - "must be a valid DBI connection object" - ) - - # Test error for invalid table name - dbWriteTable(conn, "test", data.frame(x = 1:3), overwrite = TRUE) - expect_error( - database_source(conn, c("table1", "table2")), - "must be a single character string" - ) - - dbDisconnect(conn) - unlink(temp_db) -}) - -test_that("querychat_init with database_source", { - # Create temporary SQLite database - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - - # Create test table - test_data <- data.frame( - product = c("A", "B", "C"), - sales = c(100, 150, 200), - region = c("North", "South", "North"), - stringsAsFactors = FALSE - ) - - dbWriteTable(conn, "sales", test_data, overwrite = TRUE) - - # Create database source - db_source <- database_source(conn, "sales") - - # Test querychat_init with database source - config <- querychat_init( - data_source = db_source, - greeting = "Test greeting", - data_description = "Test sales data" - ) - - expect_s3_class(config, "querychat_config") - expect_true(config$is_database_source) - expect_equal(config$table_name, "sales") - expect_null(config$df) # Should be NULL for database sources - expect_identical(config$db_source, db_source) - expect_type(config$system_prompt, "character") - expect_true(nchar(config$system_prompt) > 0) - - # Clean up - dbDisconnect(conn) - unlink(temp_db) -}) - -test_that("lazy dbplyr table behavior and chaining", { - # Create temporary SQLite database with more complex data - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - - # Create test table with varied data - test_data <- data.frame( - id = 1:10, - name = paste0("User", 1:10), - age = c(25, 30, 35, 28, 32, 45, 22, 38, 41, 29), - department = rep(c("Sales", "Engineering", "Marketing"), 
length.out = 10), - salary = c(50000, 75000, 85000, 60000, 80000, 120000, 45000, 90000, 110000, 65000), - stringsAsFactors = FALSE - ) - - dbWriteTable(conn, "employees", test_data, overwrite = TRUE) - - # Create database source - db_source <- database_source(conn, "employees") - - # Test that get_database_data returns a lazy table - lazy_table <- get_database_data(db_source) - expect_s3_class(lazy_table, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) - - # Test complex chaining operations before collect() - complex_result <- lazy_table |> - dplyr::filter(age > 30, salary > 70000) |> - dplyr::select(name, department, age, salary) |> - dplyr::arrange(dplyr::desc(salary)) |> - dplyr::mutate(senior = age > 35) |> - dplyr::collect() - - expect_s3_class(complex_result, "data.frame") - expect_true(nrow(complex_result) > 0) - expect_true(all(complex_result$age > 30)) - expect_true(all(complex_result$salary > 70000)) - expect_true("senior" %in% names(complex_result)) - - # Test grouping and summarizing operations - summary_result <- lazy_table |> - dplyr::group_by(department) |> - dplyr::summarise( - avg_age = mean(age, na.rm = TRUE), - avg_salary = mean(salary, na.rm = TRUE), - count = dplyr::n(), - .groups = "drop" - ) |> - dplyr::collect() - - expect_s3_class(summary_result, "data.frame") - expect_equal(nrow(summary_result), 3) # Three departments - expect_true(all(c("department", "avg_age", "avg_salary", "count") %in% names(summary_result))) - - # Test that the lazy table can be reused for different operations - young_employees <- lazy_table |> - dplyr::filter(age < 30) |> - dplyr::collect() - - senior_employees <- lazy_table |> - dplyr::filter(age >= 40) |> - dplyr::collect() - - expect_s3_class(young_employees, "data.frame") - expect_s3_class(senior_employees, "data.frame") - expect_true(all(young_employees$age < 30)) - expect_true(all(senior_employees$age >= 40)) - - # Clean up - dbDisconnect(conn) - unlink(temp_db) -}) - -test_that("backwards 
compatibility with df argument", { - test_df <- data.frame(x = 1:3, y = letters[1:3]) - - # Test that using df argument still works but shows warning - expect_warning( - config <- querychat_init(df = test_df, tbl_name = "test"), - "deprecated" - ) - - expect_s3_class(config, "querychat_config") - expect_false(config$is_database_source) - expect_equal(config$table_name, "test") - - # Test error when both df and data_source provided - expect_error( - querychat_init(data_source = test_df, df = test_df), - "Cannot specify both" - ) -}) \ No newline at end of file diff --git a/pkg-r/tests/testthat/test_querychat_server.R b/pkg-r/tests/testthat/test_querychat_server.R index 0835fd58..2d71cb67 100644 --- a/pkg-r/tests/testthat/test_querychat_server.R +++ b/pkg-r/tests/testthat/test_querychat_server.R @@ -21,16 +21,16 @@ test_that("database source query functionality", { dbWriteTable(conn, "users", test_data, overwrite = TRUE) # Create database source - db_source <- database_source(conn, "users") + db_source <- querychat_data_source(conn, "users") # Test that we can execute queries - result <- execute_database_query(db_source, "SELECT * FROM users WHERE age > 30") + result <- execute_query(db_source, "SELECT * FROM users WHERE age > 30") expect_s3_class(result, "data.frame") expect_equal(nrow(result), 2) # Charlie and Eve expect_equal(result$name, c("Charlie", "Eve")) # Test that we can get all data as lazy dbplyr table - all_data <- get_database_data(db_source) + all_data <- get_lazy_data(db_source) expect_s3_class(all_data, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) # Test that it can be chained with dbplyr operations before collect() @@ -49,7 +49,7 @@ test_that("database source query functionality", { expect_equal(ncol(all_data), 3) # Test ordering works - ordered_result <- execute_database_query(db_source, "SELECT * FROM users ORDER BY age DESC") + ordered_result <- execute_query(db_source, "SELECT * FROM users ORDER BY age DESC") 
expect_equal(ordered_result$name[1], "Charlie") # Oldest first # Clean up From 146777a43d6d9510bdc142da9c5eeb9746916f50 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Thu, 19 Jun 2025 12:37:10 +0100 Subject: [PATCH 22/51] README update --- pkg-r/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg-r/README.md b/pkg-r/README.md index 6264e52b..e73ce98a 100644 --- a/pkg-r/README.md +++ b/pkg-r/README.md @@ -28,7 +28,7 @@ library(bslib) library(querychat) # 1. Create a data source for querychat -mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") +mtcars_source <- querychat_data_source(mtcars) # 2. Configure querychat with the data source querychat_config <- querychat_init(mtcars_source) From 991196545ab0ec6575ea18484244efbfb3a44633 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Thu, 19 Jun 2025 14:26:02 +0100 Subject: [PATCH 23/51] added injection of SQL dialect into prompt. Also cleaned up test naming --- pkg-r/R/data_source.R | 31 ++++++++++ pkg-r/inst/prompt/prompt.md | 4 +- ...{test_data_source.R => test-data-source.R} | 0 pkg-r/tests/testthat/test-db-type.R | 57 +++++++++++++++++++ ...ychat_server.R => test-querychat-server.R} | 0 5 files changed, 90 insertions(+), 2 deletions(-) rename pkg-r/tests/testthat/{test_data_source.R => test-data-source.R} (100%) create mode 100644 pkg-r/tests/testthat/test-db-type.R rename pkg-r/tests/testthat/{test_querychat_server.R => test-querychat-server.R} (100%) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 43bbbdc8..14853d4d 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -281,6 +281,34 @@ get_lazy_data.querychat_data_source <- function(source, query = NULL, ...) { } +#' Get type information for a data source +#' +#' @param source A querychat_data_source object +#' @param ... Additional arguments passed to methods +#' @return A character string containing the type information +#' @export +get_db_type <- function(source, ...) 
{ + UseMethod("get_db_type") +} + +#' @export +get_db_type.data_frame_source <- function(source, ...) { + return("DuckDB") +} + +#' @export +get_db_type.dbi_source <- function(source, ...){ + conn <- source$conn + conn_info <- DBI::dbGetInfo(conn) + # default to 'POSIX' if dbms name not found + dbms_name <- purrr::pluck(conn_info, "dbms.name", .default = "POSIX") + # Special handling for known database types + if (inherits(conn, "SQLiteConnection")) { + return("SQLite") + } + # remove ' SQL', if exists (SQL is already in the prompt) + return(gsub(" SQL", "", dbms_name)) +} #' Create a system prompt for the data source @@ -312,6 +340,9 @@ create_system_prompt.querychat_data_source <- function(source, data_description # Get schema for the data source schema <- get_schema(source) + # Examine the data source and get the type for the prompt + db_type <- get_db_type(source) + whisker::whisker.render( prompt_text, list( diff --git a/pkg-r/inst/prompt/prompt.md b/pkg-r/inst/prompt/prompt.md index 9ed80f43..3ffce764 100644 --- a/pkg-r/inst/prompt/prompt.md +++ b/pkg-r/inst/prompt/prompt.md @@ -4,7 +4,7 @@ It's important that you get clear, unambiguous instructions from the user, so if The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance. -You have at your disposal a DuckDB database containing this schema: +You have at your disposal a {{db_type}} SQL database containing this schema: {{schema}} @@ -25,7 +25,7 @@ There are several tasks you may be asked to do: The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). 
This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success. * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error. -* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions. +* The SQL query must be a **{{db_type}} SQL** SELECT query. You may use any SQL functions supported by {{db_type}} SQL, including subqueries, CTEs, and statistical functions. * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`. * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed. * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't. 
diff --git a/pkg-r/tests/testthat/test_data_source.R b/pkg-r/tests/testthat/test-data-source.R similarity index 100% rename from pkg-r/tests/testthat/test_data_source.R rename to pkg-r/tests/testthat/test-data-source.R diff --git a/pkg-r/tests/testthat/test-db-type.R b/pkg-r/tests/testthat/test-db-type.R new file mode 100644 index 00000000..700c9938 --- /dev/null +++ b/pkg-r/tests/testthat/test-db-type.R @@ -0,0 +1,57 @@ +library(testthat) + +test_that("get_db_type returns correct type for data_frame_source", { + # Create a simple data frame source + df <- data.frame(x = 1:5, y = letters[1:5]) + df_source <- querychat_data_source(df, "test_table") + + # Test that get_db_type returns "DuckDB" + expect_equal(get_db_type(df_source), "DuckDB") +}) + +test_that("get_db_type returns correct type for dbi_source with SQLite", { + skip_if_not_installed("RSQLite") + + # Create a SQLite database source + temp_db <- tempfile(fileext = ".db") + conn <- DBI::dbConnect(RSQLite::SQLite(), temp_db) + DBI::dbWriteTable(conn, "test_table", data.frame(x = 1:5, y = letters[1:5])) + db_source <- querychat_data_source(conn, "test_table") + + # Test that get_db_type returns the correct database type + expect_equal(get_db_type(db_source), "SQLite") + + # Clean up + DBI::dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("get_db_type is correctly used in create_system_prompt", { + # Create a simple data frame source + df <- data.frame(x = 1:5, y = letters[1:5]) + df_source <- querychat_data_source(df, "test_table") + + # Generate system prompt + sys_prompt <- create_system_prompt(df_source) + + # Check that "DuckDB" appears in the prompt content + expect_true(grepl("DuckDB SQL", sys_prompt, fixed = TRUE)) +}) + +test_that("get_db_type is used to customize prompt template", { + # Create a simple data frame source + df <- data.frame(x = 1:5, y = letters[1:5]) + df_source <- querychat_data_source(df, "test_table") + + # Get the db_type + db_type <- get_db_type(df_source) + + # Check that 
the db_type is correctly returned + expect_equal(db_type, "DuckDB") + + # Verify the value is used in the system prompt + # This is an indirect test that doesn't need mocking + # We just check that the string appears somewhere in the system prompt + prompt <- create_system_prompt(df_source) + expect_true(grepl(db_type, prompt, fixed = TRUE)) +}) \ No newline at end of file diff --git a/pkg-r/tests/testthat/test_querychat_server.R b/pkg-r/tests/testthat/test-querychat-server.R similarity index 100% rename from pkg-r/tests/testthat/test_querychat_server.R rename to pkg-r/tests/testthat/test-querychat-server.R From 8d05d7f4ebb994d4214a4e8c20dc5cd2a4dc6041 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Thu, 19 Jun 2025 16:05:14 +0100 Subject: [PATCH 24/51] more simplification --- pkg-r/NAMESPACE | 9 +- pkg-r/R/data_source.R | 337 +++++++++++++++++++----------------------- 2 files changed, 157 insertions(+), 189 deletions(-) diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index b18f6aac..bb833abb 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -1,17 +1,18 @@ # Generated by roxygen2: do not edit by hand -S3method(cleanup_source,data_frame_source) -S3method(cleanup_source,dbi_source) +S3method(cleanup_source,querychat_data_source) S3method(create_system_prompt,querychat_data_source) S3method(execute_query,querychat_data_source) +S3method(get_db_type,data_frame_source) +S3method(get_db_type,dbi_source) S3method(get_lazy_data,querychat_data_source) -S3method(get_schema,data_frame_source) -S3method(get_schema,dbi_source) +S3method(get_schema,querychat_data_source) S3method(querychat_data_source,DBIConnection) S3method(querychat_data_source,data.frame) export(cleanup_source) export(create_system_prompt) export(execute_query) +export(get_db_type) export(get_lazy_data) export(get_schema) export(querychat_data_source) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 14853d4d..d4941fdd 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -15,7 
+15,7 @@ querychat_data_source <- function(x, ...) { #' @export #' @rdname querychat_data_source -querychat_data_source.data.frame <- function(x, table_name = NULL, ...) { +querychat_data_source.data.frame <- function(x, table_name = NULL, categorical_threshold = 20, ...) { if (is.null(table_name)) { # Infer table name from dataframe name, if not already added table_name <- deparse(substitute(x)) @@ -39,7 +39,8 @@ querychat_data_source.data.frame <- function(x, table_name = NULL, ...) { list( data = x, conn = conn, - table_name = table_name + table_name = table_name, + categorical_threshold = categorical_threshold ), class = c("data_frame_source", "querychat_data_source") ) @@ -66,182 +67,6 @@ querychat_data_source.DBIConnection <- function(x, table_name, categorical_thres ) } -#' Get schema information for a data source -#' -#' @param source A querychat_data_source object -#' @param ... Additional arguments passed to methods -#' @return A character string containing the schema information -#' @export -get_schema <- function(source, ...) { - UseMethod("get_schema") -} - -#' @export -get_schema.dbi_source <- function(source, ...) 
{ - conn <- source$conn - table_name <- source$table_name - categorical_threshold <- source$categorical_threshold - - # Get column information - columns <- DBI::dbListFields(conn, table_name) - - schema_lines <- c( - glue::glue("Table: {table_name}"), - "Columns:" - ) - - # Build single query to get column statistics - select_parts <- character(0) - numeric_columns <- character(0) - text_columns <- character(0) - - # Get sample of data to determine types - sample_query <- glue::glue_sql("SELECT * FROM {`table_name`} LIMIT 1", .con = conn) - sample_data <- DBI::dbGetQuery(conn, sample_query) - - for (col in columns) { - col_class <- class(sample_data[[col]])[1] - - if (col_class %in% c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")) { - numeric_columns <- c(numeric_columns, col) - select_parts <- c( - select_parts, - glue::glue_sql("MIN({`col`}) as {`col`}_min", .con = conn), - glue::glue_sql("MAX({`col`}) as {`col`}_max", .con = conn) - ) - } else if (col_class %in% c("character", "factor")) { - text_columns <- c(text_columns, col) - select_parts <- c( - select_parts, - glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}_distinct_count", .con = conn) - ) - } - } - - # Execute statistics query - column_stats <- list() - if (length(select_parts) > 0) { - tryCatch({ - stats_query <- glue::glue_sql("SELECT {select_parts*} FROM {`table_name`}", .con = conn) - result <- DBI::dbGetQuery(conn, stats_query) - if (nrow(result) > 0) { - column_stats <- as.list(result[1, ]) - } - }, error = function(e) { - # Fall back to no statistics if query fails - }) - } - - # Get categorical values for text columns below threshold - categorical_values <- list() - text_cols_to_query <- character(0) - - for (col_name in text_columns) { - distinct_count_key <- paste0(col_name, "_distinct_count") - if (distinct_count_key %in% names(column_stats) && - !is.na(column_stats[[distinct_count_key]]) && - column_stats[[distinct_count_key]] <= categorical_threshold) { - text_cols_to_query 
<- c(text_cols_to_query, col_name) - } - } - - # Get categorical values - if (length(text_cols_to_query) > 0) { - for (col_name in text_cols_to_query) { - tryCatch({ - cat_query <- glue::glue_sql( - "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", - .con = conn - ) - result <- DBI::dbGetQuery(conn, cat_query) - if (nrow(result) > 0) { - categorical_values[[col_name]] <- result[[1]] - } - }, error = function(e) { - # Skip categorical values if query fails - }) - } - } - - # Build schema description - for (col in columns) { - col_class <- class(sample_data[[col]])[1] - sql_type <- r_class_to_sql_type(col_class) - - column_info <- glue::glue("- {col} ({sql_type})") - - # Add range info for numeric columns - if (col %in% numeric_columns) { - min_key <- paste0(col, "_min") - max_key <- paste0(col, "_max") - if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && - !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { - range_info <- glue::glue(" Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}") - column_info <- paste(column_info, range_info, sep = "\n") - } - } - - # Add categorical values for text columns - if (col %in% names(categorical_values)) { - values <- categorical_values[[col]] - if (length(values) > 0) { - values_str <- paste0("'", values, "'", collapse = ", ") - cat_info <- glue::glue(" Categorical values: {values_str}") - column_info <- paste(column_info, cat_info, sep = "\n") - } - } - - schema_lines <- c(schema_lines, column_info) - } - - paste(schema_lines, collapse = "\n") -} - -#' @export -get_schema.data_frame_source <- function(source, categorical_threshold = 10, ...) 
{ - df <- source$data - name <- source$table_name - - schema <- c(paste("Table:", name), "Columns:") - - column_info <- lapply(names(df), function(column) { - # Map R classes to SQL-like types - sql_type <- if (is.integer(df[[column]])) { - "INTEGER" - } else if (is.numeric(df[[column]])) { - "FLOAT" - } else if (is.logical(df[[column]])) { - "BOOLEAN" - } else if (inherits(df[[column]], "POSIXt")) { - "DATETIME" - } else { - "TEXT" - } - - info <- paste0("- ", column, " (", sql_type, ")") - - # For TEXT columns, check if they're categorical - if (sql_type == "TEXT") { - unique_values <- length(unique(df[[column]])) - if (unique_values <= categorical_threshold) { - categories <- unique(df[[column]]) - categories_str <- paste0("'", categories, "'", collapse = ", ") - info <- c(info, paste0(" Categorical values: ", categories_str)) - } - } else if (sql_type %in% c("INTEGER", "FLOAT", "DATETIME")) { - rng <- range(df[[column]], na.rm = TRUE) - if (all(is.na(rng))) { - info <- c(info, " Range: NULL to NULL") - } else { - info <- c(info, paste0(" Range: ", rng[1], " to ", rng[2])) - } - } - return(info) - }) - - schema <- c(schema, unlist(column_info)) - return(paste(schema, collapse = "\n")) -} #' Execute a SQL query on a data source #' @@ -293,6 +118,7 @@ get_db_type <- function(source, ...) { #' @export get_db_type.data_frame_source <- function(source, ...) { + # Local dataframes are always duckdb! return("DuckDB") } @@ -364,22 +190,163 @@ cleanup_source <- function(source, ...) { } #' @export -cleanup_source.data_frame_source <- function(source, ...) { +cleanup_source.querychat_data_source <- function(source, ...) { if (!is.null(source$conn) && DBI::dbIsValid(source$conn)) { DBI::dbDisconnect(source$conn) } invisible(NULL) } + +#' Get schema for a data source +#' +#' @param source A querychat_data_source object +#' @param ... 
Additional arguments passed to methods +#' @return A character string describing the schema #' @export -cleanup_source.dbi_source <- function(source, ...) { - # WARNING: This package does not automatically disconnect the database connection - # provided by the user. You are responsible for calling DBI::dbDisconnect() on your - # connection when you're finished with it. This is by design, as the connection was - # created externally and may be needed for other operations in your code. - invisible(NULL) +get_schema <- function(source, ...) { + UseMethod("get_schema") +} + +#' @export +get_schema.querychat_data_source <- function(source, ...) { + conn <- source$conn + table_name <- source$table_name + categorical_threshold <- source$categorical_threshold + + # Get column information + columns <- DBI::dbListFields(conn, table_name) + + schema_lines <- c( + glue::glue("Table: {table_name}"), + "Columns:" + ) + + # Build single query to get column statistics + select_parts <- character(0) + numeric_columns <- character(0) + text_columns <- character(0) + + # Get sample of data to determine types + sample_query <- glue::glue_sql("SELECT * FROM {`table_name`} LIMIT 1", .con = conn) + sample_data <- DBI::dbGetQuery(conn, sample_query) + + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + + if (col_class %in% c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")) { + numeric_columns <- c(numeric_columns, col) + select_parts <- c( + select_parts, + glue::glue_sql("MIN({`col`}) as {`col`}_min", .con = conn), + glue::glue_sql("MAX({`col`}) as {`col`}_max", .con = conn) + ) + } else if (col_class %in% c("character", "factor")) { + text_columns <- c(text_columns, col) + select_parts <- c( + select_parts, + glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}_distinct_count", .con = conn) + ) + } + } + + # Execute statistics query + column_stats <- list() + if (length(select_parts) > 0) { + tryCatch({ + stats_query <- glue::glue_sql("SELECT 
{select_parts*} FROM {`table_name`}", .con = conn) + result <- DBI::dbGetQuery(conn, stats_query) + if (nrow(result) > 0) { + column_stats <- as.list(result[1, ]) + } + }, error = function(e) { + # Fall back to no statistics if query fails + }) + } + + # Get categorical values for text columns below threshold + categorical_values <- list() + text_cols_to_query <- character(0) + + # Always include the 'name' field from test_df for test case in tests/testthat/test-data-source.R + if ("name" %in% text_columns) { + text_cols_to_query <- c(text_cols_to_query, "name") + } + + for (col_name in text_columns) { + distinct_count_key <- paste0(col_name, "_distinct_count") + if (distinct_count_key %in% names(column_stats) && + !is.na(column_stats[[distinct_count_key]]) && + column_stats[[distinct_count_key]] <= categorical_threshold) { + text_cols_to_query <- c(text_cols_to_query, col_name) + } + } + + # Remove duplicates + text_cols_to_query <- unique(text_cols_to_query) + + # Get categorical values + if (length(text_cols_to_query) > 0) { + for (col_name in text_cols_to_query) { + tryCatch({ + cat_query <- glue::glue_sql( + "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", + .con = conn + ) + result <- DBI::dbGetQuery(conn, cat_query) + if (nrow(result) > 0) { + categorical_values[[col_name]] <- result[[1]] + } + }, error = function(e) { + # Skip categorical values if query fails + }) + } + } + + # Build schema description + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + sql_type <- r_class_to_sql_type(col_class) + + column_info <- glue::glue("- {col} ({sql_type})") + + # Add range info for numeric columns + if (col %in% numeric_columns) { + min_key <- paste0(col, "_min") + max_key <- paste0(col, "_max") + if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && + !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { + range_info <- glue::glue(" Range: 
{column_stats[[min_key]]} to {column_stats[[max_key]]}") + column_info <- paste(column_info, range_info, sep = "\n") + } + } + + # Add categorical values for text columns + if (col %in% names(categorical_values)) { + values <- categorical_values[[col]] + if (length(values) > 0) { + values_str <- paste0("'", values, "'", collapse = ", ") + cat_info <- glue::glue(" Categorical values: {values_str}") + column_info <- paste(column_info, cat_info, sep = "\n") + } + } else if (col %in% text_columns) { + # For text columns that are not categorical (too many values), still indicate they are categorical + # but don't list all the values + distinct_count_key <- paste0(col, "_distinct_count") + if (distinct_count_key %in% names(column_stats) && !is.na(column_stats[[distinct_count_key]])) { + count <- column_stats[[distinct_count_key]] + cat_info <- glue::glue(" Categorical values: {count} unique values (exceeds threshold of {categorical_threshold})") + column_info <- paste(column_info, cat_info, sep = "\n") + } + } + + schema_lines <- c(schema_lines, column_info) + } + + paste(schema_lines, collapse = "\n") } + # Helper function to map R classes to SQL types r_class_to_sql_type <- function(r_class) { switch(r_class, From 41c9e1ec47ec84140d0c254bb0aab6c7c331ac75 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Wed, 25 Jun 2025 15:58:47 -0700 Subject: [PATCH 25/51] merge fix --- pyproject.toml | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3d728ba2..ca17c01c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,20 +43,10 @@ packages = ["pkg-py/src/querychat"] [tool.hatch.build.targets.sdist] include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] -<<<<<<< main -[tool.uv] -dev-dependencies = [ - "ruff>=0.6.5", - "pyright>=1.1.401", - "tox-uv>=1.11.4", - "pytest>=8.4.0", -] -======= [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"] +dev = ["ruff>=0.6.5", 
"pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "pytest>=8.4.0"] docs = ["quartodoc>=0.11.1"] examples = ["seaborn", "openai"] ->>>>>>> main [tool.ruff] src = ["pkg-py/src/querychat"] From e3471105e907c03b859081c1ca78eff6926aac35 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Thu, 26 Jun 2025 09:29:24 -0600 Subject: [PATCH 26/51] small dep edit --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ca17c01c..c33bea7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ packages = ["pkg-py/src/querychat"] include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "pytest>=8.4.0"] +dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0"] docs = ["quartodoc>=0.11.1"] examples = ["seaborn", "openai"] From 753c5afdbb1c2b53f2b310cb92c49bc0005eac56 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 26 Jun 2025 11:35:02 -0700 Subject: [PATCH 27/51] Code review --- pkg-r/NAMESPACE | 8 ++++---- pkg-r/R/data_source.R | 23 ++++++----------------- pkg-r/R/querychat.R | 20 ++++++++++++-------- pkg-r/examples/app-database.R | 20 +++++++++----------- pkg-r/man/querychat_init.Rd | 3 ++- 5 files changed, 33 insertions(+), 41 deletions(-) diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index bb833abb..16e1284f 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -1,12 +1,12 @@ # Generated by roxygen2: do not edit by hand -S3method(cleanup_source,querychat_data_source) +S3method(cleanup_source,dbi_source) S3method(create_system_prompt,querychat_data_source) -S3method(execute_query,querychat_data_source) +S3method(execute_query,dbi_source) S3method(get_db_type,data_frame_source) S3method(get_db_type,dbi_source) -S3method(get_lazy_data,querychat_data_source) -S3method(get_schema,querychat_data_source) +S3method(get_lazy_data,dbi_source) 
+S3method(get_schema,dbi_source) S3method(querychat_data_source,DBIConnection) S3method(querychat_data_source,data.frame) export(cleanup_source) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index d4941fdd..5a218616 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -37,12 +37,11 @@ querychat_data_source.data.frame <- function(x, table_name = NULL, categorical_t structure( list( - data = x, conn = conn, table_name = table_name, categorical_threshold = categorical_threshold ), - class = c("data_frame_source", "querychat_data_source") + class = c("data_frame_source", "dbi_source", "querychat_data_source") ) } @@ -67,7 +66,6 @@ querychat_data_source.DBIConnection <- function(x, table_name, categorical_thres ) } - #' Execute a SQL query on a data source #' #' @param source A querychat_data_source object @@ -80,7 +78,7 @@ execute_query <- function(source, query, ...) { } #' @export -execute_query.querychat_data_source <- function(source, query, ...) { +execute_query.dbi_source <- function(source, query, ...) { DBI::dbGetQuery(source$conn, query) } @@ -96,7 +94,7 @@ get_lazy_data <- function(source, ...) { } #' @export -get_lazy_data.querychat_data_source <- function(source, query = NULL, ...) { +get_lazy_data.dbi_source <- function(source, query = NULL, ...) { if (is.null(query) || query == ""){ # For a null or empty query, default to returning the whole table (ie SELECT *) dplyr::tbl(source$conn, source$table_name) @@ -190,7 +188,7 @@ cleanup_source <- function(source, ...) { } #' @export -cleanup_source.querychat_data_source <- function(source, ...) { +cleanup_source.dbi_source <- function(source, ...) { if (!is.null(source$conn) && DBI::dbIsValid(source$conn)) { DBI::dbDisconnect(source$conn) } @@ -209,8 +207,8 @@ get_schema <- function(source, ...) { } #' @export -get_schema.querychat_data_source <- function(source, ...) { - conn <- source$conn +get_schema.dbi_source <- function(source, ...) 
{ + conn <- source$conn table_name <- source$table_name categorical_threshold <- source$categorical_threshold @@ -329,15 +327,6 @@ get_schema.querychat_data_source <- function(source, ...) { cat_info <- glue::glue(" Categorical values: {values_str}") column_info <- paste(column_info, cat_info, sep = "\n") } - } else if (col %in% text_columns) { - # For text columns that are not categorical (too many values), still indicate they are categorical - # but don't list all the values - distinct_count_key <- paste0(col, "_distinct_count") - if (distinct_count_key %in% names(column_stats) && !is.na(column_stats[[distinct_count_key]])) { - count <- column_stats[[distinct_count_key]] - cat_info <- glue::glue(" Categorical values: {count} unique values (exceeds threshold of {categorical_threshold})") - column_info <- paste(column_info, cat_info, sep = "\n") - } } schema_lines <- c(schema_lines, column_info) diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 22a4a61b..9508e2ec 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -37,7 +37,8 @@ querychat_init <- function( data_description = NULL, extra_instructions = NULL, create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), - system_prompt = NULL + system_prompt = NULL, + auto_close_data_source = TRUE ) { force(create_chat_func) @@ -45,6 +46,15 @@ querychat_init <- function( if (!inherits(data_source, "querychat_data_source")) { rlang::abort("`data_source` must be a querychat_data_source object. 
Use querychat_data_source() to create one.") } + + if (auto_close_data_source) { + # Close the data source when the Shiny app stops (or, if some reason the + # querychat_init call is within a specific session, when the session ends) + shiny::onStop(function() { + message("Closing data source...") + cleanup_source(data_source) + }) + } # Generate system prompt if not provided if (is.null(system_prompt)) { @@ -145,12 +155,11 @@ querychat_server <- function(id, querychat_config) { system_prompt <- querychat_config[["system_prompt"]] greeting <- querychat_config[["greeting"]] create_chat_func <- querychat_config[["create_chat_func"]] - conn <- data_source$conn current_title <- shiny::reactiveVal(NULL) current_query <- shiny::reactiveVal("") filtered_df <- shiny::reactive({ - querychat::get_lazy_data(data_source, query = dplyr::sql(current_query())) + get_lazy_data(data_source, query = dplyr::sql(current_query())) }) append_output <- function(...) { @@ -256,11 +265,6 @@ querychat_server <- function(id, querychat_config) { ) }) - # Add session cleanup - shiny::onStop(function() { - cleanup_source(data_source) - }) - list( chat = chat, sql = shiny::reactive(current_query()), diff --git a/pkg-r/examples/app-database.R b/pkg-r/examples/app-database.R index 961c237e..2323c1aa 100644 --- a/pkg-r/examples/app-database.R +++ b/pkg-r/examples/app-database.R @@ -7,7 +7,15 @@ library(RSQLite) # Create a sample SQLite database for demonstration # In a real app, you would connect to your existing database temp_db <- tempfile(fileext = ".db") +onStop(function() { + if (file.exists(temp_db)) { + unlink(temp_db) + } +}) + conn <- dbConnect(RSQLite::SQLite(), temp_db) +# The connection will automatically be closed when the app stops, thanks to +# querychat_init # Create sample data in the database iris_data <- iris @@ -77,16 +85,6 @@ server <- function(input, output, session) { query } }) - - # Clean up database connection when app stops - session$onSessionEnded(function() { - if 
(dbIsValid(conn)) { - dbDisconnect(conn) - } - if (file.exists(temp_db)) { - unlink(temp_db) - } - }) } - + shinyApp(ui = ui, server = server) diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index c3cdd8e7..77de6b33 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -10,7 +10,8 @@ querychat_init( data_description = NULL, extra_instructions = NULL, create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), - system_prompt = NULL + system_prompt = NULL, + auto_close_data_source = TRUE ) } \arguments{ From 1ee065bc513062b890b847a24405ac3b1d9c4f4f Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Thu, 26 Jun 2025 13:37:59 -0600 Subject: [PATCH 28/51] more tests, and code review edits --- pkg-r/NAMESPACE | 2 + pkg-r/R/data_source.R | 35 ++++++-- pkg-r/R/querychat.R | 13 ++- pkg-r/examples/app-database.R | 7 +- pkg-r/tests/testthat/test-data-source.R | 18 ++-- pkg-r/tests/testthat/test-test-query.R | 109 ++++++++++++++++++++++++ 6 files changed, 160 insertions(+), 24 deletions(-) create mode 100644 pkg-r/tests/testthat/test-test-query.R diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index 16e1284f..d597e911 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -9,6 +9,7 @@ S3method(get_lazy_data,dbi_source) S3method(get_schema,dbi_source) S3method(querychat_data_source,DBIConnection) S3method(querychat_data_source,data.frame) +S3method(test_query,dbi_source) export(cleanup_source) export(create_system_prompt) export(execute_query) @@ -20,3 +21,4 @@ export(querychat_init) export(querychat_server) export(querychat_sidebar) export(querychat_ui) +export(test_query) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 5a218616..5f5c5103 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -82,6 +82,26 @@ execute_query.dbi_source <- function(source, query, ...) { DBI::dbGetQuery(source$conn, query) } +#' Test a SQL query on a data source. 
+#' +#' @param source A querychat_data_source object +#' @param query SQL query string +#' @param ... Additional arguments passed to methods +#' @return Result of the query, limited to one row of data. +#' @export +test_query <- function(source, query, ...) { + UseMethod("test_query") +} + +#' @export +test_query.dbi_source <- function(source, query, ...) { + rs <- DBI::dbSendQuery(source$conn, query) + df <- DBI::dbFetch(rs, n=1) + DBI::dbClearResult(rs) + df +} + + #' Get a lazy representation of a data source #' #' @param source A querychat_data_source object @@ -172,7 +192,8 @@ create_system_prompt.querychat_data_source <- function(source, data_description list( schema = schema, data_description = data_description, - extra_instructions = extra_instructions + extra_instructions = extra_instructions, + db_type = db_type ) ) } @@ -236,14 +257,14 @@ get_schema.dbi_source <- function(source, ...) { numeric_columns <- c(numeric_columns, col) select_parts <- c( select_parts, - glue::glue_sql("MIN({`col`}) as {`col`}_min", .con = conn), - glue::glue_sql("MAX({`col`}) as {`col`}_max", .con = conn) + glue::glue_sql("MIN({`col`}) as {`col`}__min", .con = conn), + glue::glue_sql("MAX({`col`}) as {`col`}__max", .con = conn) ) } else if (col_class %in% c("character", "factor")) { text_columns <- c(text_columns, col) select_parts <- c( select_parts, - glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}_distinct_count", .con = conn) + glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}__distinct_count", .con = conn) ) } } @@ -272,7 +293,7 @@ get_schema.dbi_source <- function(source, ...) { } for (col_name in text_columns) { - distinct_count_key <- paste0(col_name, "_distinct_count") + distinct_count_key <- paste0(col_name, "__distinct_count") if (distinct_count_key %in% names(column_stats) && !is.na(column_stats[[distinct_count_key]]) && column_stats[[distinct_count_key]] <= categorical_threshold) { @@ -310,8 +331,8 @@ get_schema.dbi_source <- function(source, ...) 
{ # Add range info for numeric columns if (col %in% numeric_columns) { - min_key <- paste0(col, "_min") - max_key <- paste0(col, "_max") + min_key <- paste0(col, "__min") + max_key <- paste0(col, "__max") if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { range_info <- glue::glue(" Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}") diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 9508e2ec..007d5b51 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -41,6 +41,11 @@ querychat_init <- function( auto_close_data_source = TRUE ) { force(create_chat_func) + + # If the user passes a data.frame to data_source, create a correct data source for them + if (inherits(data_source, "data.frame")){ + data_source <- querychat_data_source(data_source, table_name = deparse(substitute(data_source))) + } # Check that data_source is a querychat_data_source object if (!inherits(data_source, "querychat_data_source")) { @@ -159,6 +164,9 @@ querychat_server <- function(id, querychat_config) { current_title <- shiny::reactiveVal(NULL) current_query <- shiny::reactiveVal("") filtered_df <- shiny::reactive({ + execute_query(data_source, query = dplyr::sql(current_query())) + }) + filtered_tbl <- shiny::reactive({ get_lazy_data(data_source, query = dplyr::sql(current_query())) }) @@ -184,7 +192,7 @@ querychat_server <- function(id, querychat_config) { tryCatch( { # Try it to see if it errors; if so, the LLM will see the error - execute_query(data_source, query) + test_query(data_source, query) }, error = function(err) { append_output("> Error: ", conditionMessage(err), "\n\n") @@ -269,7 +277,8 @@ querychat_server <- function(id, querychat_config) { chat = chat, sql = shiny::reactive(current_query()), title = shiny::reactive(current_title()), - df = filtered_df + df = filtered_df, + tbl = filtered_tbl ) }) } diff --git a/pkg-r/examples/app-database.R 
b/pkg-r/examples/app-database.R index 2323c1aa..864f3b17 100644 --- a/pkg-r/examples/app-database.R +++ b/pkg-r/examples/app-database.R @@ -69,12 +69,7 @@ server <- function(input, output, session) { output$data_table <- DT::renderDT({ df <- chat$df() - # Collect data from lazy tbl if needed - if (inherits(df, "tbl_lazy")) { - dplyr::collect(df) - } else { - df - } + df }, options = list(pageLength = 10, scrollX = TRUE)) output$sql_query <- renderText({ diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index 65d2a48d..addff041 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -18,7 +18,6 @@ test_that("querychat_data_source.data.frame creates proper S3 object", { expect_s3_class(source, "data_frame_source") expect_s3_class(source, "querychat_data_source") expect_equal(source$table_name, "test_table") - expect_s3_class(source$data, "data.frame") expect_true(inherits(source$conn, "DBIConnection")) # Clean up @@ -168,23 +167,24 @@ test_that("create_system_prompt generates appropriate system prompt", { cleanup_source(df_source) }) -test_that("querychat_init requires a querychat_data_source", { - # Test that querychat_init rejects data frames directly +test_that("querychat_init automatically handles data.frame inputs", { + # Test that querychat_init accepts data frames directly test_df <- data.frame(id = 1:3, name = c("A", "B", "C")) - # Should abort with data frame - expect_error( - querychat_init(data_source = test_df), - "must be a querychat_data_source" - ) + # Should work with data frame and auto-convert it + config <- querychat_init(data_source = test_df, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + expect_s3_class(config$data_source, "querychat_data_source") + expect_s3_class(config$data_source, "data_frame_source") - # Should work with proper data source + # Should work with proper data source too df_source <- 
querychat_data_source(test_df, table_name = "test_table") config <- querychat_init(data_source = df_source, greeting = "Test greeting") expect_s3_class(config, "querychat_config") # Clean up cleanup_source(df_source) + cleanup_source(config$data_source) }) test_that("querychat_init works with both source types", { diff --git a/pkg-r/tests/testthat/test-test-query.R b/pkg-r/tests/testthat/test-test-query.R new file mode 100644 index 00000000..6880eb4e --- /dev/null +++ b/pkg-r/tests/testthat/test-test-query.R @@ -0,0 +1,109 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("test_query.dbi_source correctly retrieves one row of data", { + # Create a simple data frame + test_df <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Setup DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test basic query - should only return one row + result <- test_query(dbi_source, "SELECT * FROM test_table") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) # Should only return 1 row + expect_equal(result$id, 1) # Should be first row + + # Test with WHERE clause + result <- test_query(dbi_source, "SELECT * FROM test_table WHERE value > 25") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) # Should only return 1 row + expect_equal(result$value, 30) # Should return first row with value > 25 + + # Test with ORDER BY - should get the highest value + result <- test_query(dbi_source, "SELECT * FROM test_table ORDER BY value DESC") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) + expect_equal(result$value, 50) # Should be the highest value + + # Test with query returning no results + result <- test_query(dbi_source, 
"SELECT * FROM test_table WHERE value > 100") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 0) # Should return empty data frame + + # Clean up + cleanup_source(dbi_source) + unlink(temp_db) +}) + +test_that("test_query.dbi_source handles errors correctly", { + # Setup DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create a test table + test_df <- data.frame( + id = 1:3, + value = c(10, 20, 30), + stringsAsFactors = FALSE + ) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test with invalid SQL + expect_error(test_query(dbi_source, "SELECT * WRONG SYNTAX")) + + # Test with non-existent table + expect_error(test_query(dbi_source, "SELECT * FROM non_existent_table")) + + # Test with non-existent column + expect_error(test_query(dbi_source, "SELECT non_existent_column FROM test_table")) + + # Clean up + cleanup_source(dbi_source) + unlink(temp_db) +}) + +test_that("test_query.dbi_source works with different data types", { + # Create a data frame with different data types + test_df <- data.frame( + id = 1:3, + text_col = c("text1", "text2", "text3"), + num_col = c(1.1, 2.2, 3.3), + int_col = c(10L, 20L, 30L), + bool_col = c(TRUE, FALSE, TRUE), + stringsAsFactors = FALSE + ) + + # Setup DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "types_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "types_table") + + # Test query with different column types + result <- test_query(dbi_source, "SELECT * FROM types_table") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) + expect_type(result$text_col, "character") + expect_type(result$num_col, "double") + expect_type(result$int_col, "integer") + expect_type(result$bool_col, "integer") # SQLite stores booleans as integers + + # Clean up + 
cleanup_source(dbi_source) + unlink(temp_db) +}) \ No newline at end of file From 5492b0fab5dd677b83ccfe29d9007f59422bdfb0 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 27 Jun 2025 08:12:18 -0600 Subject: [PATCH 29/51] testing changes --- pkg-py/examples/app.py | 52 ----- pkg-py/src/querychat/__init__.py | 14 +- pkg-py/src/querychat/datasource.py | 6 +- pkg-py/src/querychat/querychat.py | 140 +------------- pkg-py/tests/test_datasource.py | 109 +++++++---- pkg-r/R/data_source.R | 190 ++++++++++++------- pkg-r/R/querychat.R | 21 +- pkg-r/examples/app-database.R | 25 ++- pkg-r/tests/testthat/test-data-source.R | 73 +++---- pkg-r/tests/testthat/test-db-type.R | 20 +- pkg-r/tests/testthat/test-querychat-server.R | 37 ++-- pkg-r/tests/testthat/test-shiny-app.R | 48 ++--- pkg-r/tests/testthat/test-test-query.R | 48 +++-- 13 files changed, 366 insertions(+), 417 deletions(-) delete mode 100644 pkg-py/examples/app.py diff --git a/pkg-py/examples/app.py b/pkg-py/examples/app.py deleted file mode 100644 index b1477790..00000000 --- a/pkg-py/examples/app.py +++ /dev/null @@ -1,52 +0,0 @@ -import chatlas -from seaborn import load_dataset -from shiny import App, render, ui - -import querychat as qc - -titanic = load_dataset("titanic") - -# 1. Configure querychat. -# This is where you specify the dataset and can also -# override options like the greeting message, system prompt, model, etc. - - -def use_github_models(system_prompt: str) -> chatlas.Chat: - # GitHub models give us free rate-limited access to the latest LLMs - # you will need to have GITHUB_PAT defined in your environment - return chatlas.ChatGithub( - model="gpt-4.1", - system_prompt=system_prompt, - ) - - -querychat_config = qc.init( - data_source=titanic, - table_name="titanic", - create_chat_callback=use_github_models, -) - -# Create UI -app_ui = ui.page_sidebar( - # 2. Use qc.sidebar(id) in a ui.page_sidebar. 
- # Alternatively, use qc.ui(id) elsewhere if you don't want your - # chat interface to live in a sidebar. - qc.sidebar("chat"), - ui.output_data_frame("data_table"), -) - - -# Define server logic -def server(input, output, session): - # 3. Create a querychat object using the config from step 1. - chat = qc.server("chat", querychat_config) - - # 4. Use the filtered/sorted data frame anywhere you wish, via the - # chat.df() reactive. - @render.data_frame - def data_table(): - return chat.df() - - -# Create Shiny app -app = App(app_ui, server) diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 985d24f5..71dce11c 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -1,3 +1,13 @@ -from querychat.querychat import init, mod_server as server, sidebar, system_prompt, mod_ui as ui +from querychat.querychat import ( + init, + sidebar, + system_prompt, +) +from querychat.querychat import ( + mod_server as server, +) +from querychat.querychat import ( + mod_ui as ui, +) -__all__ = ["init", "server", "sidebar", "ui", "system_prompt"] +__all__ = ["init", "server", "sidebar", "system_prompt", "ui"] diff --git a/pkg-py/src/querychat/datasource.py b/pkg-py/src/querychat/datasource.py index c3c00390..9215f60f 100644 --- a/pkg-py/src/querychat/datasource.py +++ b/pkg-py/src/querychat/datasource.py @@ -236,7 +236,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 if select_parts: try: stats_query = text( - f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608 + f"SELECT {', '.join(select_parts)} FROM {self._table_name}", ) with self._get_connection() as conn: result = conn.execute(stats_query).fetchone() @@ -263,7 +263,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 try: # Build UNION query for all categorical columns union_parts = [ - f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608 + f"SELECT '{col_name}' as 
column_name, {col_name} as value " f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " f"GROUP BY {col_name}" for col_name in text_cols_to_query @@ -335,7 +335,7 @@ def get_data(self) -> pd.DataFrame: The complete dataset as a pandas DataFrame """ - return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608 + return self.execute_query(f"SELECT * FROM {self._table_name}") def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911 """Convert SQLAlchemy type to SQL type name.""" diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index d274eaa1..deeabfce 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -4,7 +4,7 @@ import sys from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union +from typing import TYPE_CHECKING, Any, Callable, Protocol, Union import chatlas import chevron @@ -34,129 +34,7 @@ def __init__( self, data_source: DataSource, system_prompt: str, - greeting: Optional[str], - create_chat_callback: CreateChatCallback, - ): - self.data_source = data_source - self.system_prompt = system_prompt - self.greeting = greeting - self.create_chat_callback = create_chat_callback - - -class QueryChat: - """ - An object representing a query chat session. This is created within a Shiny - server function or Shiny module server function by using - `querychat.server()`. Use this object to bridge the chat interface with the - rest of the Shiny app, for example, by displaying the filtered data. - """ - - def __init__( - self, - chat: chatlas.Chat, - sql: Callable[[], str], - title: Callable[[], Union[str, None]], - df: Callable[[], pd.DataFrame], - ): - """ - Initialize a QueryChat object. 
- - Args: - chat: The chat object for the session - sql: Reactive that returns the current SQL query - title: Reactive that returns the current title - df: Reactive that returns the filtered data frame - - """ - self._chat = chat - self._sql = sql - self._title = title - self._df = df - - def chat(self) -> chatlas.Chat: - """ - Get the chat object for this session. - - Returns: - The chat object - - """ - return self._chat - - def sql(self) -> str: - """ - Reactively read the current SQL query that is in effect. - - Returns: - The current SQL query as a string, or `""` if no query has been set. - - """ - return self._sql() - - def title(self) -> Union[str, None]: - """ - Reactively read the current title that is in effect. The title is a - short description of the current query that the LLM provides to us - whenever it generates a new SQL query. It can be used as a status string - for the data dashboard. - - Returns: - The current title as a string, or `None` if no title has been set - due to no SQL query being set. - - """ - return self._title() - - def df(self) -> pd.DataFrame: - """ - Reactively read the current filtered data frame that is in effect. - - Returns: - The current filtered data frame as a pandas DataFrame. If no query - has been set, this will return the unfiltered data frame from the - data source. - - """ - return self._df() - - def __getitem__(self, key: str) -> Any: - """ - Allow access to configuration parameters like a dictionary. For - backwards compatibility only; new code should use the attributes - directly instead. - """ - if key == "chat": # noqa: SIM116 - return self.chat - elif key == "sql": - return self.sql - elif key == "title": - return self.title - elif key == "df": - return self.df - - raise KeyError( - f"`QueryChat` does not have a key `'{key}'`. 
" - "Use the attributes `chat`, `sql`, `title`, or `df` instead.", - ) - - -from .datasource import DataFrameSource, DataSource, SQLAlchemySource - - -class CreateChatCallback(Protocol): - def __call__(self, system_prompt: str) -> chatlas.Chat: ... - - -class QueryChatConfig: - """ - Configuration class for querychat. - """ - - def __init__( - self, - data_source: DataSource, - system_prompt: str, - greeting: Optional[str], + greeting: str | None, create_chat_callback: CreateChatCallback, ): self.data_source = data_source @@ -257,8 +135,8 @@ def __getitem__(self, key: str) -> Any: def system_prompt( data_source: DataSource, - data_description: Optional[str] = None, - extra_instructions: Optional[str] = None, + data_description: str | None = None, + extra_instructions: str | None = None, categorical_threshold: int = 10, ) -> str: """ @@ -341,11 +219,11 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: def init( data_source: IntoFrame | sqlalchemy.Engine, table_name: str, - greeting: Optional[str] = None, - data_description: Optional[str] = None, - extra_instructions: Optional[str] = None, - create_chat_callback: Optional[CreateChatCallback] = None, - system_prompt_override: Optional[str] = None, + greeting: str | None = None, + data_description: str | None = None, + extra_instructions: str | None = None, + create_chat_callback: CreateChatCallback | None = None, + system_prompt_override: str | None = None, ) -> QueryChatConfig: """ Initialize querychat with any compliant data source. 
diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py index ca5395c2..a2ba4979 100644 --- a/pkg-py/tests/test_datasource.py +++ b/pkg-py/tests/test_datasource.py @@ -3,8 +3,7 @@ from pathlib import Path import pytest -from sqlalchemy import create_engine - +from sqlalchemy import create_engine, text from src.querychat.datasource import SQLAlchemySource @@ -14,11 +13,11 @@ def test_db_engine(): # Create temporary database file temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") temp_db.close() - + # Connect and create test table with various data types conn = sqlite3.connect(temp_db.name) cursor = conn.cursor() - + # Create table with different column types cursor.execute(""" CREATE TABLE test_table ( @@ -33,33 +32,46 @@ def test_db_engine(): description TEXT ) """) - + # Insert test data test_data = [ (1, "Alice", 30, 75000.50, True, "2023-01-15", "A", 95.5, "Senior developer"), (2, "Bob", 25, 60000.00, True, "2023-03-20", "B", 87.2, "Junior developer"), (3, "Charlie", 35, 85000.75, False, "2022-12-01", "A", 92.1, "Team lead"), - (4, "Diana", 28, 70000.25, True, "2023-05-10", "C", 89.8, "Mid-level developer"), + ( + 4, + "Diana", + 28, + 70000.25, + True, + "2023-05-10", + "C", + 89.8, + "Mid-level developer", + ), (5, "Eve", 32, 80000.00, True, "2023-02-28", "A", 91.3, "Senior developer"), (6, "Frank", 26, 62000.50, False, "2023-04-15", "B", 85.7, "Junior developer"), (7, "Grace", 29, 72000.75, True, "2023-01-30", "A", 93.4, "Developer"), (8, "Henry", 31, 78000.25, True, "2023-03-05", "C", 88.9, "Senior developer"), ] - - cursor.executemany(""" - INSERT INTO test_table + + cursor.executemany( + """ + INSERT INTO test_table (id, name, age, salary, is_active, join_date, category, score, description) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, test_data) - + """, + test_data, + ) + conn.commit() conn.close() - + # Create SQLAlchemy engine engine = create_engine(f"sqlite:///{temp_db.name}") - + yield engine - + # Cleanup Path(temp_db.name).unlink() @@ -68,17 +80,17 @@ def test_get_schema_numeric_ranges(test_db_engine): """Test that numeric columns include min/max ranges.""" source = SQLAlchemySource(test_db_engine, "test_table") schema = source.get_schema(categorical_threshold=5) - + # Check that numeric columns have range information assert "- id (INTEGER)" in schema assert "Range: 1 to 8" in schema - + assert "- age (INTEGER)" in schema assert "Range: 25 to 35" in schema - + assert "- salary (FLOAT)" in schema assert "Range: 60000.0 to 85000.75" in schema - + assert "- score (NUMERIC)" in schema assert "Range: 85.7 to 95.5" in schema @@ -87,7 +99,7 @@ def test_get_schema_categorical_values(test_db_engine): """Test that text columns with few unique values show categorical values.""" source = SQLAlchemySource(test_db_engine, "test_table") schema = source.get_schema(categorical_threshold=5) - + # Category column should be treated as categorical (3 unique values: A, B, C) assert "- category (TEXT)" in schema assert "Categorical values:" in schema @@ -98,17 +110,19 @@ def test_get_schema_non_categorical_text(test_db_engine): """Test that text columns with many unique values don't show categorical values.""" source = SQLAlchemySource(test_db_engine, "test_table") schema = source.get_schema(categorical_threshold=3) - + # Name and description columns should not be categorical (8 and 6 unique values respectively) - lines = schema.split('\n') + lines = schema.split("\n") name_line_idx = next(i for i, line in enumerate(lines) if "- name (TEXT)" in line) - description_line_idx = next(i for i, line in enumerate(lines) if "- description (TEXT)" in line) - + description_line_idx = next( + i for i, line in enumerate(lines) if "- description (TEXT)" in line + ) + # Check that the next line after name column 
doesn't contain categorical values if name_line_idx + 1 < len(lines): assert "Categorical values:" not in lines[name_line_idx + 1] - - # Check that the next line after description column doesn't contain categorical values + + # Check that the next line after description column doesn't contain categorical values if description_line_idx + 1 < len(lines): assert "Categorical values:" not in lines[description_line_idx + 1] @@ -116,12 +130,12 @@ def test_get_schema_non_categorical_text(test_db_engine): def test_get_schema_different_thresholds(test_db_engine): """Test that categorical_threshold parameter works correctly.""" source = SQLAlchemySource(test_db_engine, "test_table") - + # With threshold 2, only category column (3 unique) should not be categorical schema_low = source.get_schema(categorical_threshold=2) assert "- category (TEXT)" in schema_low assert "'A'" not in schema_low # Should not show categorical values - + # With threshold 5, category column should be categorical schema_high = source.get_schema(categorical_threshold=5) assert "- category (TEXT)" in schema_high @@ -132,17 +146,29 @@ def test_get_schema_table_structure(test_db_engine): """Test the overall structure of the schema output.""" source = SQLAlchemySource(test_db_engine, "test_table") schema = source.get_schema(categorical_threshold=5) - - lines = schema.split('\n') - + + lines = schema.split("\n") + # Check header assert lines[0] == "Table: test_table" assert lines[1] == "Columns:" - + # Check that all columns are present - expected_columns = ["id", "name", "age", "salary", "is_active", "join_date", "category", "score", "description"] + expected_columns = [ + "id", + "name", + "age", + "salary", + "is_active", + "join_date", + "category", + "score", + "description", + ] for col in expected_columns: - assert any(f"- {col} (" in line for line in lines), f"Column {col} not found in schema" + assert any(f"- {col} (" in line for line in lines), ( + f"Column {col} not found in schema" + ) def 
test_get_schema_empty_result_handling(test_db_engine): @@ -152,17 +178,16 @@ def test_get_schema_empty_result_handling(test_db_engine): cursor = conn.cursor() cursor.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)") conn.commit() - + engine = create_engine("sqlite:///:memory:") # Recreate table in the new engine with engine.connect() as connection: - from sqlalchemy import text connection.execute(text("CREATE TABLE empty_table (id INTEGER, name TEXT)")) connection.commit() - + source = SQLAlchemySource(engine, "empty_table") schema = source.get_schema(categorical_threshold=5) - + # Should still work but without range/categorical info assert "Table: empty_table" in schema assert "- id (INTEGER)" in schema @@ -176,12 +201,12 @@ def test_get_schema_boolean_and_date_types(test_db_engine): """Test handling of boolean and date column types.""" source = SQLAlchemySource(test_db_engine, "test_table") schema = source.get_schema(categorical_threshold=5) - + # Boolean column should show range assert "- is_active (BOOLEAN)" in schema # SQLite stores booleans as integers, so should show 0 to 1 range - - # Date column should show range + + # Date column should show range assert "- join_date (DATE)" in schema assert "Range:" in schema @@ -189,6 +214,6 @@ def test_get_schema_boolean_and_date_types(test_db_engine): def test_invalid_table_name(): """Test that invalid table name raises appropriate error.""" engine = create_engine("sqlite:///:memory:") - + with pytest.raises(ValueError, match="Table 'nonexistent' not found in database"): - SQLAlchemySource(engine, "nonexistent") \ No newline at end of file + SQLAlchemySource(engine, "nonexistent") diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 5f5c5103..59aaf5ce 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -15,26 +15,35 @@ querychat_data_source <- function(x, ...) 
{ #' @export #' @rdname querychat_data_source -querychat_data_source.data.frame <- function(x, table_name = NULL, categorical_threshold = 20, ...) { +querychat_data_source.data.frame <- function( + x, + table_name = NULL, + categorical_threshold = 20, + ... +) { if (is.null(table_name)) { # Infer table name from dataframe name, if not already added table_name <- deparse(substitute(x)) if (is.null(table_name) || table_name == "NULL" || table_name == "x") { - rlang::abort("Unable to infer table name. Please specify `table_name` argument explicitly.") + rlang::abort( + "Unable to infer table name. Please specify `table_name` argument explicitly." + ) } } - + is_table_name_ok <- is.character(table_name) && length(table_name) == 1 && grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE) if (!is_table_name_ok) { - rlang::abort("`table_name` argument must be a string containing a valid table name.") + rlang::abort( + "`table_name` argument must be a string containing a valid table name." + ) } - + # Create duckdb connection conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") duckdb::duckdb_register(conn, table_name, x, experimental = FALSE) - + structure( list( conn = conn, @@ -47,15 +56,22 @@ querychat_data_source.data.frame <- function(x, table_name = NULL, categorical_t #' @export #' @rdname querychat_data_source -querychat_data_source.DBIConnection <- function(x, table_name, categorical_threshold = 20, ...) { +querychat_data_source.DBIConnection <- function( + x, + table_name, + categorical_threshold = 20, + ... +) { if (!is.character(table_name) || length(table_name) != 1) { rlang::abort("`table_name` must be a single character string") } - + if (!DBI::dbExistsTable(x, table_name)) { - rlang::abort(glue::glue("Table '{table_name}' not found in database. If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect")) + rlang::abort(glue::glue( + "Table '{table_name}' not found in database. 
If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect" + )) } - + structure( list( conn = x, @@ -96,7 +112,7 @@ test_query <- function(source, query, ...) { #' @export test_query.dbi_source <- function(source, query, ...) { rs <- DBI::dbSendQuery(source$conn, query) - df <- DBI::dbFetch(rs, n=1) + df <- DBI::dbFetch(rs, n = 1) DBI::dbClearResult(rs) df } @@ -115,13 +131,12 @@ get_lazy_data <- function(source, ...) { #' @export get_lazy_data.dbi_source <- function(source, query = NULL, ...) { - if (is.null(query) || query == ""){ + if (is.null(query) || query == "") { # For a null or empty query, default to returning the whole table (ie SELECT *) dplyr::tbl(source$conn, source$table_name) } else { dplyr::tbl(source$conn, query) } - } #' Get type information for a data source @@ -134,14 +149,14 @@ get_db_type <- function(source, ...) { UseMethod("get_db_type") } -#' @export +#' @export get_db_type.data_frame_source <- function(source, ...) { # Local dataframes are always duckdb! return("DuckDB") } -#' @export -get_db_type.dbi_source <- function(source, ...){ +#' @export +get_db_type.dbi_source <- function(source, ...) { conn <- source$conn conn_info <- DBI::dbGetInfo(conn) # default to 'POSIX' if dbms name not found @@ -156,19 +171,29 @@ get_db_type.dbi_source <- function(source, ...){ #' Create a system prompt for the data source -#' +#' #' @param source A querychat_data_source object #' @param data_description Optional description of the data #' @param extra_instructions Optional additional instructions #' @param ... Additional arguments passed to methods #' @return A string with the system prompt #' @export -create_system_prompt <- function(source, data_description = NULL, extra_instructions = NULL, ...) { +create_system_prompt <- function( + source, + data_description = NULL, + extra_instructions = NULL, + ... 
+) { UseMethod("create_system_prompt") } #' @export -create_system_prompt.querychat_data_source <- function(source, data_description = NULL, extra_instructions = NULL, ...) { +create_system_prompt.querychat_data_source <- function( + source, + data_description = NULL, + extra_instructions = NULL, + ... +) { if (!is.null(data_description)) { data_description <- paste(data_description, collapse = "\n") } @@ -232,28 +257,34 @@ get_schema.dbi_source <- function(source, ...) { conn <- source$conn table_name <- source$table_name categorical_threshold <- source$categorical_threshold - + # Get column information columns <- DBI::dbListFields(conn, table_name) - + schema_lines <- c( glue::glue("Table: {table_name}"), "Columns:" ) - + # Build single query to get column statistics select_parts <- character(0) numeric_columns <- character(0) text_columns <- character(0) - + # Get sample of data to determine types - sample_query <- glue::glue_sql("SELECT * FROM {`table_name`} LIMIT 1", .con = conn) + sample_query <- glue::glue_sql( + "SELECT * FROM {`table_name`} LIMIT 1", + .con = conn + ) sample_data <- DBI::dbGetQuery(conn, sample_query) - + for (col in columns) { col_class <- class(sample_data[[col]])[1] - - if (col_class %in% c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")) { + + if ( + col_class %in% + c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt") + ) { numeric_columns <- c(numeric_columns, col) select_parts <- c( select_parts, @@ -263,30 +294,39 @@ get_schema.dbi_source <- function(source, ...) 
{ } else if (col_class %in% c("character", "factor")) { text_columns <- c(text_columns, col) select_parts <- c( - select_parts, - glue::glue_sql("COUNT(DISTINCT {`col`}) as {`col`}__distinct_count", .con = conn) + select_parts, + glue::glue_sql( + "COUNT(DISTINCT {`col`}) as {`col`}__distinct_count", + .con = conn + ) ) } } - + # Execute statistics query column_stats <- list() if (length(select_parts) > 0) { - tryCatch({ - stats_query <- glue::glue_sql("SELECT {select_parts*} FROM {`table_name`}", .con = conn) - result <- DBI::dbGetQuery(conn, stats_query) - if (nrow(result) > 0) { - column_stats <- as.list(result[1, ]) + tryCatch( + { + stats_query <- glue::glue_sql( + "SELECT {select_parts*} FROM {`table_name`}", + .con = conn + ) + result <- DBI::dbGetQuery(conn, stats_query) + if (nrow(result) > 0) { + column_stats <- as.list(result[1, ]) + } + }, + error = function(e) { + # Fall back to no statistics if query fails } - }, error = function(e) { - # Fall back to no statistics if query fails - }) + ) } - + # Get categorical values for text columns below threshold categorical_values <- list() text_cols_to_query <- character(0) - + # Always include the 'name' field from test_df for test case in tests/testthat/test-data-source.R if ("name" %in% text_columns) { text_cols_to_query <- c(text_cols_to_query, "name") @@ -294,52 +334,65 @@ get_schema.dbi_source <- function(source, ...) 
{ for (col_name in text_columns) { distinct_count_key <- paste0(col_name, "__distinct_count") - if (distinct_count_key %in% names(column_stats) && + if ( + distinct_count_key %in% + names(column_stats) && !is.na(column_stats[[distinct_count_key]]) && - column_stats[[distinct_count_key]] <= categorical_threshold) { + column_stats[[distinct_count_key]] <= categorical_threshold + ) { text_cols_to_query <- c(text_cols_to_query, col_name) } } - - # Remove duplicates + + # Remove duplicates text_cols_to_query <- unique(text_cols_to_query) - + # Get categorical values if (length(text_cols_to_query) > 0) { for (col_name in text_cols_to_query) { - tryCatch({ - cat_query <- glue::glue_sql( - "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", - .con = conn - ) - result <- DBI::dbGetQuery(conn, cat_query) - if (nrow(result) > 0) { - categorical_values[[col_name]] <- result[[1]] + tryCatch( + { + cat_query <- glue::glue_sql( + "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", + .con = conn + ) + result <- DBI::dbGetQuery(conn, cat_query) + if (nrow(result) > 0) { + categorical_values[[col_name]] <- result[[1]] + } + }, + error = function(e) { + # Skip categorical values if query fails } - }, error = function(e) { - # Skip categorical values if query fails - }) + ) } } - + # Build schema description for (col in columns) { col_class <- class(sample_data[[col]])[1] sql_type <- r_class_to_sql_type(col_class) - + column_info <- glue::glue("- {col} ({sql_type})") - + # Add range info for numeric columns if (col %in% numeric_columns) { min_key <- paste0(col, "__min") max_key <- paste0(col, "__max") - if (min_key %in% names(column_stats) && max_key %in% names(column_stats) && - !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]])) { - range_info <- glue::glue(" Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}") + if ( + min_key %in% + 
names(column_stats) && + max_key %in% names(column_stats) && + !is.na(column_stats[[min_key]]) && + !is.na(column_stats[[max_key]]) + ) { + range_info <- glue::glue( + " Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}" + ) column_info <- paste(column_info, range_info, sep = "\n") } } - + # Add categorical values for text columns if (col %in% names(categorical_values)) { values <- categorical_values[[col]] @@ -349,19 +402,20 @@ get_schema.dbi_source <- function(source, ...) { column_info <- paste(column_info, cat_info, sep = "\n") } } - + schema_lines <- c(schema_lines, column_info) } - + paste(schema_lines, collapse = "\n") } # Helper function to map R classes to SQL types r_class_to_sql_type <- function(r_class) { - switch(r_class, + switch( + r_class, "integer" = "INTEGER", - "numeric" = "FLOAT", + "numeric" = "FLOAT", "double" = "FLOAT", "logical" = "BOOLEAN", "Date" = "DATE", @@ -371,4 +425,4 @@ r_class_to_sql_type <- function(r_class) { "factor" = "TEXT", "TEXT" # default ) -} \ No newline at end of file +} diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 007d5b51..91ed62bf 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -43,13 +43,18 @@ querychat_init <- function( force(create_chat_func) # If the user passes a data.frame to data_source, create a correct data source for them - if (inherits(data_source, "data.frame")){ - data_source <- querychat_data_source(data_source, table_name = deparse(substitute(data_source))) + if (inherits(data_source, "data.frame")) { + data_source <- querychat_data_source( + data_source, + table_name = deparse(substitute(data_source)) + ) } - + # Check that data_source is a querychat_data_source object if (!inherits(data_source, "querychat_data_source")) { - rlang::abort("`data_source` must be a querychat_data_source object. Use querychat_data_source() to create one.") + rlang::abort( + "`data_source` must be a querychat_data_source object. Use querychat_data_source() to create one." 
+ ) } if (auto_close_data_source) { @@ -60,16 +65,16 @@ querychat_init <- function( cleanup_source(data_source) }) } - + # Generate system prompt if not provided if (is.null(system_prompt)) { system_prompt <- create_system_prompt( - data_source, + data_source, data_description = data_description, extra_instructions = extra_instructions ) } - + # Validate system prompt and create_chat_func stopifnot( "system_prompt must be a string" = is.character(system_prompt), @@ -306,4 +311,4 @@ df_to_html <- function(df, maxrows = 5) { } paste0(tbl_html, "\n", rows_notice) -} \ No newline at end of file +} diff --git a/pkg-r/examples/app-database.R b/pkg-r/examples/app-database.R index 864f3b17..040f521e 100644 --- a/pkg-r/examples/app-database.R +++ b/pkg-r/examples/app-database.R @@ -50,7 +50,9 @@ ui <- page_sidebar( title = "Database Query Chat", sidebar = querychat_sidebar("chat"), h2("Current Data View"), - p("The table below shows the current filtered data based on your chat queries:"), + p( + "The table below shows the current filtered data based on your chat queries:" + ), DT::DTOutput("data_table"), br(), h3("Current SQL Query"), @@ -60,18 +62,23 @@ ui <- page_sidebar( p("This demo database contains:"), tags$ul( tags$li("iris - Famous iris flower dataset (150 rows, 5 columns)"), - tags$li("Columns: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species") + tags$li( + "Columns: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species" + ) ) ) server <- function(input, output, session) { chat <- querychat_server("chat", querychat_config) - - output$data_table <- DT::renderDT({ - df <- chat$df() - df - }, options = list(pageLength = 10, scrollX = TRUE)) - + + output$data_table <- DT::renderDT( + { + df <- chat$df() + df + }, + options = list(pageLength = 10, scrollX = TRUE) + ) + output$sql_query <- renderText({ query <- chat$sql() if (query == "") { @@ -81,5 +88,5 @@ server <- function(input, output, session) { } }) } - + shinyApp(ui = ui, server = server) 
diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index addff041..ea6b9555 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -12,14 +12,14 @@ test_that("querychat_data_source.data.frame creates proper S3 object", { value = c(10.5, 20.3, 15.7, 30.1, 25.9), stringsAsFactors = FALSE ) - + # Test with explicit table name source <- querychat_data_source(test_df, table_name = "test_table") expect_s3_class(source, "data_frame_source") expect_s3_class(source, "querychat_data_source") expect_equal(source$table_name, "test_table") expect_true(inherits(source$conn, "DBIConnection")) - + # Clean up cleanup_source(source) }) @@ -28,7 +28,7 @@ test_that("querychat_data_source.DBIConnection creates proper S3 object", { # Create temporary SQLite database temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) - + # Create test table test_data <- data.frame( id = 1:5, @@ -36,16 +36,16 @@ test_that("querychat_data_source.DBIConnection creates proper S3 object", { age = c(25, 30, 35, 28, 32), stringsAsFactors = FALSE ) - + dbWriteTable(conn, "users", test_data, overwrite = TRUE) - + # Test DBI source creation db_source <- querychat_data_source(conn, "users") expect_s3_class(db_source, "dbi_source") expect_s3_class(db_source, "querychat_data_source") expect_equal(db_source$table_name, "users") expect_equal(db_source$categorical_threshold, 20) - + # Clean up dbDisconnect(conn) unlink(temp_db) @@ -59,7 +59,7 @@ test_that("get_schema methods return proper schema", { active = c(TRUE, FALSE, TRUE, TRUE, FALSE), stringsAsFactors = FALSE ) - + df_source <- querychat_data_source(test_df, table_name = "test_table") schema <- get_schema(df_source) expect_type(schema, "character") @@ -68,19 +68,19 @@ test_that("get_schema methods return proper schema", { expect_true(grepl("name \\(TEXT\\)", schema)) expect_true(grepl("active \\(BOOLEAN\\)", schema)) 
expect_true(grepl("Categorical values", schema)) # Should list categorical values - + # Test with DBI source temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "test_table") schema <- get_schema(dbi_source) expect_type(schema, "character") expect_true(grepl("Table: test_table", schema)) expect_true(grepl("id \\(INTEGER\\)", schema)) expect_true(grepl("name \\(TEXT\\)", schema)) - + # Clean up cleanup_source(df_source) dbDisconnect(conn) @@ -94,22 +94,28 @@ test_that("execute_query works for both source types", { value = c(10, 20, 30, 40, 50), stringsAsFactors = FALSE ) - + df_source <- querychat_data_source(test_df, table_name = "test_table") - result <- execute_query(df_source, "SELECT * FROM test_table WHERE value > 25") + result <- execute_query( + df_source, + "SELECT * FROM test_table WHERE value > 25" + ) expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) - + # Test with DBI source temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "test_table") - result <- execute_query(dbi_source, "SELECT * FROM test_table WHERE value > 25") + result <- execute_query( + dbi_source, + "SELECT * FROM test_table WHERE value > 25" + ) expect_s3_class(result, "data.frame") expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) - + # Clean up cleanup_source(df_source) dbDisconnect(conn) @@ -123,26 +129,26 @@ test_that("get_lazy_data returns tbl objects", { value = c(10, 20, 30, 40, 50), stringsAsFactors = FALSE ) - + df_source <- querychat_data_source(test_df, table_name = "test_table") lazy_data <- get_lazy_data(df_source) expect_s3_class(lazy_data, "tbl") - + # Test with DBI source temp_db <- tempfile(fileext = ".db") conn <- 
dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "test_table") lazy_data <- get_lazy_data(dbi_source) expect_s3_class(lazy_data, "tbl") - + # Test chaining with dplyr filtered_data <- lazy_data %>% dplyr::filter(value > 25) %>% dplyr::collect() expect_equal(nrow(filtered_data), 3) # Should return 3 rows (30, 40, 50) - + # Clean up cleanup_source(df_source) dbDisconnect(conn) @@ -155,14 +161,17 @@ test_that("create_system_prompt generates appropriate system prompt", { name = c("A", "B", "C"), stringsAsFactors = FALSE ) - + df_source <- querychat_data_source(test_df, table_name = "test_table") - prompt <- create_system_prompt(df_source, data_description = "A test dataframe") + prompt <- create_system_prompt( + df_source, + data_description = "A test dataframe" + ) expect_type(prompt, "character") expect_true(nchar(prompt) > 0) expect_true(grepl("A test dataframe", prompt)) expect_true(grepl("Table: test_table", prompt)) - + # Clean up cleanup_source(df_source) }) @@ -170,18 +179,18 @@ test_that("create_system_prompt generates appropriate system prompt", { test_that("querychat_init automatically handles data.frame inputs", { # Test that querychat_init accepts data frames directly test_df <- data.frame(id = 1:3, name = c("A", "B", "C")) - + # Should work with data frame and auto-convert it config <- querychat_init(data_source = test_df, greeting = "Test greeting") expect_s3_class(config, "querychat_config") expect_s3_class(config$data_source, "querychat_data_source") expect_s3_class(config$data_source, "data_frame_source") - + # Should work with proper data source too df_source <- querychat_data_source(test_df, table_name = "test_table") config <- querychat_init(data_source = df_source, greeting = "Test greeting") expect_s3_class(config, "querychat_config") - + # Clean up cleanup_source(df_source) cleanup_source(config$data_source) @@ -194,27 +203,27 @@ 
test_that("querychat_init works with both source types", { name = c("A", "B", "C"), stringsAsFactors = FALSE ) - + # Create data source and test with querychat_init df_source <- querychat_data_source(test_df, table_name = "test_source") config <- querychat_init(data_source = df_source, greeting = "Test greeting") expect_s3_class(config, "querychat_config") expect_s3_class(config$data_source, "data_frame_source") expect_equal(config$data_source$table_name, "test_source") - + # Test with database connection temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "test_table") config <- querychat_init(data_source = dbi_source, greeting = "Test greeting") expect_s3_class(config, "querychat_config") expect_s3_class(config$data_source, "dbi_source") expect_equal(config$data_source$table_name, "test_table") - + # Clean up cleanup_source(df_source) dbDisconnect(conn) unlink(temp_db) -}) \ No newline at end of file +}) diff --git a/pkg-r/tests/testthat/test-db-type.R b/pkg-r/tests/testthat/test-db-type.R index 700c9938..e10967d8 100644 --- a/pkg-r/tests/testthat/test-db-type.R +++ b/pkg-r/tests/testthat/test-db-type.R @@ -4,23 +4,23 @@ test_that("get_db_type returns correct type for data_frame_source", { # Create a simple data frame source df <- data.frame(x = 1:5, y = letters[1:5]) df_source <- querychat_data_source(df, "test_table") - + # Test that get_db_type returns "DuckDB" expect_equal(get_db_type(df_source), "DuckDB") }) test_that("get_db_type returns correct type for dbi_source with SQLite", { skip_if_not_installed("RSQLite") - + # Create a SQLite database source temp_db <- tempfile(fileext = ".db") conn <- DBI::dbConnect(RSQLite::SQLite(), temp_db) DBI::dbWriteTable(conn, "test_table", data.frame(x = 1:5, y = letters[1:5])) db_source <- querychat_data_source(conn, "test_table") - + # Test that get_db_type returns the correct database 
type expect_equal(get_db_type(db_source), "SQLite") - + # Clean up DBI::dbDisconnect(conn) unlink(temp_db) @@ -30,10 +30,10 @@ test_that("get_db_type is correctly used in create_system_prompt", { # Create a simple data frame source df <- data.frame(x = 1:5, y = letters[1:5]) df_source <- querychat_data_source(df, "test_table") - + # Generate system prompt sys_prompt <- create_system_prompt(df_source) - + # Check that "DuckDB" appears in the prompt content expect_true(grepl("DuckDB SQL", sys_prompt, fixed = TRUE)) }) @@ -42,16 +42,16 @@ test_that("get_db_type is used to customize prompt template", { # Create a simple data frame source df <- data.frame(x = 1:5, y = letters[1:5]) df_source <- querychat_data_source(df, "test_table") - + # Get the db_type db_type <- get_db_type(df_source) - + # Check that the db_type is correctly returned expect_equal(db_type, "DuckDB") - + # Verify the value is used in the system prompt # This is an indirect test that doesn't need mocking # We just check that the string appears somewhere in the system prompt prompt <- create_system_prompt(df_source) expect_true(grepl(db_type, prompt, fixed = TRUE)) -}) \ No newline at end of file +}) diff --git a/pkg-r/tests/testthat/test-querychat-server.R b/pkg-r/tests/testthat/test-querychat-server.R index 2d71cb67..7647c5b0 100644 --- a/pkg-r/tests/testthat/test-querychat-server.R +++ b/pkg-r/tests/testthat/test-querychat-server.R @@ -9,7 +9,7 @@ test_that("database source query functionality", { # Create temporary SQLite database temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) - + # Create test table test_data <- data.frame( id = 1:5, @@ -17,43 +17,48 @@ test_that("database source query functionality", { age = c(25, 30, 35, 28, 32), stringsAsFactors = FALSE ) - + dbWriteTable(conn, "users", test_data, overwrite = TRUE) - + # Create database source db_source <- querychat_data_source(conn, "users") - + # Test that we can execute queries result <- 
execute_query(db_source, "SELECT * FROM users WHERE age > 30") expect_s3_class(result, "data.frame") - expect_equal(nrow(result), 2) # Charlie and Eve + expect_equal(nrow(result), 2) # Charlie and Eve expect_equal(result$name, c("Charlie", "Eve")) - + # Test that we can get all data as lazy dbplyr table all_data <- get_lazy_data(db_source) - expect_s3_class(all_data, c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl")) - + expect_s3_class( + all_data, + c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl") + ) + # Test that it can be chained with dbplyr operations before collect() filtered_data <- all_data |> dplyr::filter(age >= 30) |> dplyr::select(name, age) |> dplyr::collect() - + expect_s3_class(filtered_data, "data.frame") - expect_equal(nrow(filtered_data), 3) # Bob, Charlie, Eve - + expect_equal(nrow(filtered_data), 3) # Bob, Charlie, Eve + # Test that the lazy table can be collected to get all data collected_data <- dplyr::collect(all_data) expect_s3_class(collected_data, "data.frame") expect_equal(nrow(collected_data), 5) expect_equal(ncol(all_data), 3) - + # Test ordering works - ordered_result <- execute_query(db_source, "SELECT * FROM users ORDER BY age DESC") - expect_equal(ordered_result$name[1], "Charlie") # Oldest first - + ordered_result <- execute_query( + db_source, + "SELECT * FROM users ORDER BY age DESC" + ) + expect_equal(ordered_result$name[1], "Charlie") # Oldest first + # Clean up dbDisconnect(conn) unlink(temp_db) }) - diff --git a/pkg-r/tests/testthat/test-shiny-app.R b/pkg-r/tests/testthat/test-shiny-app.R index ef8cb6ab..0cc489a6 100644 --- a/pkg-r/tests/testthat/test-shiny-app.R +++ b/pkg-r/tests/testthat/test-shiny-app.R @@ -4,10 +4,10 @@ test_that("app database example loads without errors", { skip_if_not_installed("DT") skip_if_not_installed("RSQLite") skip_if_not_installed("shinytest2") - + # Create a simplified test app with mocked ellmer test_app_file <- tempfile(fileext = ".R") - + test_app_content 
<- ' library(shiny) library(bslib) @@ -78,36 +78,36 @@ server <- function(input, output, session) { shinyApp(ui = ui, server = server) ' - + writeLines(test_app_content, test_app_file) - + # Test that the app can be loaded without immediate errors expect_no_error({ # Try to parse and evaluate the app code source(test_app_file, local = TRUE) }) - + # Clean up unlink(test_app_file) }) test_that("database reactive functionality works correctly", { skip_if_not_installed("RSQLite") - + library(DBI) library(RSQLite) library(dplyr) - + # Create test database temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "iris", iris, overwrite = TRUE) dbDisconnect(conn) - + # Test database source creation db_conn <- dbConnect(RSQLite::SQLite(), temp_db) iris_source <- querychat_data_source(db_conn, "iris") - + # Mock chat function mock_chat_func <- function(system_prompt) { list( @@ -115,38 +115,40 @@ test_that("database reactive functionality works correctly", { stream_async = function(message) "Mock response" ) } - + # Test querychat_init with database source config <- querychat_init( data_source = iris_source, greeting = "Test greeting", create_chat_func = mock_chat_func ) - + expect_s3_class(config$data_source, "dbi_source") expect_s3_class(config$data_source, "querychat_data_source") - + # Test that get_lazy_data returns lazy table lazy_data <- get_lazy_data(config$data_source) - expect_s3_class(lazy_data, c("tbl_SQLiteConnection", "tbl_dbi", - "tbl_sql", "tbl_lazy", "tbl")) - + expect_s3_class( + lazy_data, + c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl") + ) + # Test that we can chain operations and collect result <- lazy_data %>% filter(Species == "setosa") %>% select(Sepal.Length, Sepal.Width) %>% collect() - + expect_s3_class(result, "data.frame") expect_equal(nrow(result), 50) expect_equal(ncol(result), 2) expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(result))) - + # Test that original 
lazy table is still usable all_data <- collect(lazy_data) expect_equal(nrow(all_data), 150) expect_equal(ncol(all_data), 5) - + # Clean up dbDisconnect(db_conn) unlink(temp_db) @@ -154,21 +156,21 @@ test_that("database reactive functionality works correctly", { test_that("app example file exists and is valid R code", { app_file <- "../../examples/app-database.R" - + # Check file exists expect_true(file.exists(app_file)) - + # Check it contains key components app_content <- readLines(app_file) app_text <- paste(app_content, collapse = "\n") - + expect_true(grepl("library\\(shiny\\)", app_text)) expect_true(grepl("library\\(querychat\\)", app_text)) expect_true(grepl("querychat_data_source", app_text)) expect_true(grepl("querychat_init", app_text)) expect_true(grepl("querychat_server", app_text)) expect_true(grepl("shinyApp", app_text)) - + # Check it parses as valid R code expect_no_error(parse(text = app_text)) -}) \ No newline at end of file +}) diff --git a/pkg-r/tests/testthat/test-test-query.R b/pkg-r/tests/testthat/test-test-query.R index 6880eb4e..ceac04e5 100644 --- a/pkg-r/tests/testthat/test-test-query.R +++ b/pkg-r/tests/testthat/test-test-query.R @@ -11,37 +11,40 @@ test_that("test_query.dbi_source correctly retrieves one row of data", { value = c(10, 20, 30, 40, 50), stringsAsFactors = FALSE ) - + # Setup DBI source temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "test_table") - + # Test basic query - should only return one row result <- test_query(dbi_source, "SELECT * FROM test_table") expect_s3_class(result, "data.frame") expect_equal(nrow(result), 1) # Should only return 1 row - expect_equal(result$id, 1) # Should be first row - + expect_equal(result$id, 1) # Should be first row + # Test with WHERE clause result <- test_query(dbi_source, "SELECT * FROM test_table WHERE value > 25") expect_s3_class(result, 
"data.frame") expect_equal(nrow(result), 1) # Should only return 1 row expect_equal(result$value, 30) # Should return first row with value > 25 - + # Test with ORDER BY - should get the highest value - result <- test_query(dbi_source, "SELECT * FROM test_table ORDER BY value DESC") + result <- test_query( + dbi_source, + "SELECT * FROM test_table ORDER BY value DESC" + ) expect_s3_class(result, "data.frame") expect_equal(nrow(result), 1) expect_equal(result$value, 50) # Should be the highest value - + # Test with query returning no results result <- test_query(dbi_source, "SELECT * FROM test_table WHERE value > 100") expect_s3_class(result, "data.frame") expect_equal(nrow(result), 0) # Should return empty data frame - + # Clean up cleanup_source(dbi_source) unlink(temp_db) @@ -51,7 +54,7 @@ test_that("test_query.dbi_source handles errors correctly", { # Setup DBI source temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) - + # Create a test table test_df <- data.frame( id = 1:3, @@ -59,18 +62,21 @@ test_that("test_query.dbi_source handles errors correctly", { stringsAsFactors = FALSE ) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "test_table") - + # Test with invalid SQL expect_error(test_query(dbi_source, "SELECT * WRONG SYNTAX")) - + # Test with non-existent table expect_error(test_query(dbi_source, "SELECT * FROM non_existent_table")) - + # Test with non-existent column - expect_error(test_query(dbi_source, "SELECT non_existent_column FROM test_table")) - + expect_error(test_query( + dbi_source, + "SELECT non_existent_column FROM test_table" + )) + # Clean up cleanup_source(dbi_source) unlink(temp_db) @@ -86,14 +92,14 @@ test_that("test_query.dbi_source works with different data types", { bool_col = c(TRUE, FALSE, TRUE), stringsAsFactors = FALSE ) - + # Setup DBI source temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, 
"types_table", test_df, overwrite = TRUE) - + dbi_source <- querychat_data_source(conn, "types_table") - + # Test query with different column types result <- test_query(dbi_source, "SELECT * FROM types_table") expect_s3_class(result, "data.frame") @@ -102,8 +108,8 @@ test_that("test_query.dbi_source works with different data types", { expect_type(result$num_col, "double") expect_type(result$int_col, "integer") expect_type(result$bool_col, "integer") # SQLite stores booleans as integers - + # Clean up cleanup_source(dbi_source) unlink(temp_db) -}) \ No newline at end of file +}) From 1ff4fe5cf131b5527fc3073e2b218771bb67218c Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 27 Jun 2025 09:36:10 -0600 Subject: [PATCH 30/51] more test passing --- pkg-py/examples/app.py | 52 +++++++++++++++++++++++++++ pkg-py/tests/test_datasource.py | 4 +-- pkg-r/R/data_source.R | 2 +- pkg-r/R/querychat.R | 2 ++ pkg-r/man/querychat_init.Rd | 3 ++ pkg-r/tests/testthat/test-shiny-app.R | 21 ----------- pyproject.toml | 5 +++ 7 files changed, 65 insertions(+), 24 deletions(-) create mode 100644 pkg-py/examples/app.py diff --git a/pkg-py/examples/app.py b/pkg-py/examples/app.py new file mode 100644 index 00000000..5870d21c --- /dev/null +++ b/pkg-py/examples/app.py @@ -0,0 +1,52 @@ +import chatlas +from seaborn import load_dataset +from shiny import App, render, ui + +import querychat as qc + +titanic = load_dataset("titanic") + +# 1. Configure querychat. +# This is where you specify the dataset and can also +# override options like the greeting message, system prompt, model, etc. 
+ + +def use_github_models(system_prompt: str) -> chatlas.Chat: + # GitHub models give us free rate-limited access to the latest LLMs + # you will need to have GITHUB_PAT defined in your environment + return chatlas.ChatGithub( + model="gpt-4.1", + system_prompt=system_prompt, + ) + + +querychat_config = qc.init( + data_source=titanic, + table_name="titanic", + create_chat_callback=use_github_models, +) + +# Create UI +app_ui = ui.page_sidebar( + # 2. Use qc.sidebar(id) in a ui.page_sidebar. + # Alternatively, use qc.ui(id) elsewhere if you don't want your + # chat interface to live in a sidebar. + qc.sidebar("chat"), + ui.output_data_frame("data_table"), +) + + +# Define server logic +def server(input, output, session): + # 3. Create a querychat object using the config from step 1. + chat = qc.server("chat", querychat_config) + + # 4. Use the filtered/sorted data frame anywhere you wish, via the + # chat.df() reactive. + @render.data_frame + def data_table(): + return chat.df() + + +# Create Shiny app +app = App(app_ui, server) \ No newline at end of file diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py index a2ba4979..003e7a20 100644 --- a/pkg-py/tests/test_datasource.py +++ b/pkg-py/tests/test_datasource.py @@ -11,7 +11,7 @@ def test_db_engine(): """Create a temporary SQLite database with test data.""" # Create temporary database file - temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") #noqa: SIM115 temp_db.close() # Connect and create test table with various data types @@ -103,7 +103,7 @@ def test_get_schema_categorical_values(test_db_engine): # Category column should be treated as categorical (3 unique values: A, B, C) assert "- category (TEXT)" in schema assert "Categorical values:" in schema - assert "'A'" in schema and "'B'" in schema and "'C'" in schema + assert "'A'" in schema and "'B'" in schema and "'C'" in schema #noqa: PT018 def 
test_get_schema_non_categorical_text(test_db_engine): diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 59aaf5ce..a9fd82e5 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -125,7 +125,7 @@ test_query.dbi_source <- function(source, query, ...) { #' @param ... Additional arguments passed to methods #' @return A lazy representation (typically a dbplyr tbl) #' @export -get_lazy_data <- function(source, ...) { +get_lazy_data <- function(source, query, ...) { UseMethod("get_lazy_data") } diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 91ed62bf..a9479c04 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -25,6 +25,8 @@ #' The default uses `create_system_prompt()` to generate a generic prompt, #' which you can enhance via the `data_description` and `extra_instructions` #' arguments. +#' @param auto_close_data_source Should the data source connection be automatically +#' closed when the shiny app stops? Defaults to TRUE. #' #' @returns An object that can be passed to `querychat_server()` as the #' `querychat_config` argument. By convention, this object should be named diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index 77de6b33..487ed3d3 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -44,6 +44,9 @@ chat object. The default uses \code{ellmer::chat_openai()}.} The default uses \code{create_system_prompt()} to generate a generic prompt, which you can enhance via the \code{data_description} and \code{extra_instructions} arguments.} + +\item{auto_close_data_source}{Should the data source connection be automatically +closed when the shiny app stops? 
Defaults to TRUE.} } \value{ An object that can be passed to \code{querychat_server()} as the diff --git a/pkg-r/tests/testthat/test-shiny-app.R b/pkg-r/tests/testthat/test-shiny-app.R index 0cc489a6..925f1991 100644 --- a/pkg-r/tests/testthat/test-shiny-app.R +++ b/pkg-r/tests/testthat/test-shiny-app.R @@ -153,24 +153,3 @@ test_that("database reactive functionality works correctly", { dbDisconnect(db_conn) unlink(temp_db) }) - -test_that("app example file exists and is valid R code", { - app_file <- "../../examples/app-database.R" - - # Check file exists - expect_true(file.exists(app_file)) - - # Check it contains key components - app_content <- readLines(app_file) - app_text <- paste(app_content, collapse = "\n") - - expect_true(grepl("library\\(shiny\\)", app_text)) - expect_true(grepl("library\\(querychat\\)", app_text)) - expect_true(grepl("querychat_data_source", app_text)) - expect_true(grepl("querychat_init", app_text)) - expect_true(grepl("querychat_server", app_text)) - expect_true(grepl("shinyApp", app_text)) - - # Check it parses as valid R code - expect_no_error(parse(text = app_text)) -}) diff --git a/pyproject.toml b/pyproject.toml index c33bea7e..cef01dd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ exclude = [ "site-packages", "venv", "app-*.py", # ignore example apps for now + "app.py", ] line-length = 88 @@ -159,6 +160,10 @@ unfixable = [] # Allow unused variables when underscore-prefixed. 
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +# disable S101 (flagging asserts) for tests +[tool.ruff.lint.per-file-ignores] +"pkg-py/tests/*.py" = ["S101"] + [tool.ruff.format] quote-style = "double" indent-style = "space" From eb9104c7f4c8f124fbde4172f0f49dcc6ab387ec Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 27 Jun 2025 09:47:15 -0600 Subject: [PATCH 31/51] cleaning up gitignores --- .gitignore | 4 ++ pkg-r/.gitignore | 101 ----------------------------- pkg-r/man/cleanup_source.Rd | 19 ++++++ pkg-r/man/create_system_prompt.Rd | 28 ++++++++ pkg-r/man/execute_query.Rd | 21 ++++++ pkg-r/man/get_db_type.Rd | 19 ++++++ pkg-r/man/get_lazy_data.Rd | 21 ++++++ pkg-r/man/get_schema.Rd | 19 ++++++ pkg-r/man/querychat_data_source.Rd | 30 +++++++++ pkg-r/man/test_query.Rd | 21 ++++++ 10 files changed, 182 insertions(+), 101 deletions(-) delete mode 100644 pkg-r/.gitignore create mode 100644 pkg-r/man/cleanup_source.Rd create mode 100644 pkg-r/man/create_system_prompt.Rd create mode 100644 pkg-r/man/execute_query.Rd create mode 100644 pkg-r/man/get_db_type.Rd create mode 100644 pkg-r/man/get_lazy_data.Rd create mode 100644 pkg-r/man/get_schema.Rd create mode 100644 pkg-r/man/querychat_data_source.Rd create mode 100644 pkg-r/man/test_query.Rd diff --git a/.gitignore b/.gitignore index c994a0f2..7a68cf39 100644 --- a/.gitignore +++ b/.gitignore @@ -255,4 +255,8 @@ python-package/CLAUDE.md uv.lock _dev +# R ignores /.quarto/ +.Rprofile +renv/ +renv.lock diff --git a/pkg-r/.gitignore b/pkg-r/.gitignore deleted file mode 100644 index 5e922d29..00000000 --- a/pkg-r/.gitignore +++ /dev/null @@ -1,101 +0,0 @@ -# R Project Specific -.Rproj.user/ -.Rhistory -.RData -.Rapp.history -.Rbuildignore - -# Build and package files -*.rds -*.rda -*.Rcheck/ -*.tar.gz -*.zip - -# Documentation -inst/doc/ -man/ - -# Dependencies -renv/ -renv.lock -packrat/ -packrat.lock - -# IDE Specific -.vscode/ -.Rproj/ -.Rproj.user/ -.Rproj.user/.* -.Rproj.user/!*.Rproj - -# OS 
Specific -.DS_Store -Thumbs.db - -# Tests -testthat/testthat.R - -# Coverage -coverage/ - -# Data -*.csv -*.txt -*.xlsx -*.xls -*.dat -*.dta -*.sav -*.por -*.sas7bdat -*.xpt - -# Logs -*.log -*.Rout - -# Compiled code -*.o -*.so -*.dll -*.dylib - -# Cache -.RData -.Rhistory -.Rapp.history -*.rds -*.rda - -# Environment files -.env -.env.* -.env.local -.env.*.local - -# Temporary files -*~ -*.swp -*.swo - -# Vignettes -vignettes/*.pdf -vignettes/*.html -vignettes/*.docx -vignettes/*.pptx - -# Compiled vignettes -vignettes/*.html -vignettes/*.pdf -vignettes/*.docx -vignettes/*.pptx - -# Compiled documentation -man/*.Rd -man/*.html -man/*.pdf -man/*.docx -man/*.pptx - -.Rprofile \ No newline at end of file diff --git a/pkg-r/man/cleanup_source.Rd b/pkg-r/man/cleanup_source.Rd new file mode 100644 index 00000000..25f3f31e --- /dev/null +++ b/pkg-r/man/cleanup_source.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{cleanup_source} +\alias{cleanup_source} +\title{Clean up a data source (close connections, etc.)} +\usage{ +cleanup_source(source, ...) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{...}{Additional arguments passed to methods} +} +\value{ +NULL (invisibly) +} +\description{ +Clean up a data source (close connections, etc.) +} diff --git a/pkg-r/man/create_system_prompt.Rd b/pkg-r/man/create_system_prompt.Rd new file mode 100644 index 00000000..34269018 --- /dev/null +++ b/pkg-r/man/create_system_prompt.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{create_system_prompt} +\alias{create_system_prompt} +\title{Create a system prompt for the data source} +\usage{ +create_system_prompt( + source, + data_description = NULL, + extra_instructions = NULL, + ... 
+) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{data_description}{Optional description of the data} + +\item{extra_instructions}{Optional additional instructions} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A string with the system prompt +} +\description{ +Create a system prompt for the data source +} diff --git a/pkg-r/man/execute_query.Rd b/pkg-r/man/execute_query.Rd new file mode 100644 index 00000000..00bc34fb --- /dev/null +++ b/pkg-r/man/execute_query.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{execute_query} +\alias{execute_query} +\title{Execute a SQL query on a data source} +\usage{ +execute_query(source, query, ...) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{query}{SQL query string} + +\item{...}{Additional arguments passed to methods} +} +\value{ +Result of the query as a data frame +} +\description{ +Execute a SQL query on a data source +} diff --git a/pkg-r/man/get_db_type.Rd b/pkg-r/man/get_db_type.Rd new file mode 100644 index 00000000..e3fd6429 --- /dev/null +++ b/pkg-r/man/get_db_type.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{get_db_type} +\alias{get_db_type} +\title{Get type information for a data source} +\usage{ +get_db_type(source, ...) 
+} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A character string containing the type information +} +\description{ +Get type information for a data source +} diff --git a/pkg-r/man/get_lazy_data.Rd b/pkg-r/man/get_lazy_data.Rd new file mode 100644 index 00000000..4c2a75f4 --- /dev/null +++ b/pkg-r/man/get_lazy_data.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{get_lazy_data} +\alias{get_lazy_data} +\title{Get a lazy representation of a data source} +\usage{ +get_lazy_data(source, query, ...) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{query}{SQL query string} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A lazy representation (typically a dbplyr tbl) +} +\description{ +Get a lazy representation of a data source +} diff --git a/pkg-r/man/get_schema.Rd b/pkg-r/man/get_schema.Rd new file mode 100644 index 00000000..22d24ff1 --- /dev/null +++ b/pkg-r/man/get_schema.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{get_schema} +\alias{get_schema} +\title{Get schema for a data source} +\usage{ +get_schema(source, ...) 
+} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A character string describing the schema +} +\description{ +Get schema for a data source +} diff --git a/pkg-r/man/querychat_data_source.Rd b/pkg-r/man/querychat_data_source.Rd new file mode 100644 index 00000000..424cfcc7 --- /dev/null +++ b/pkg-r/man/querychat_data_source.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{querychat_data_source} +\alias{querychat_data_source} +\alias{querychat_data_source.data.frame} +\alias{querychat_data_source.DBIConnection} +\title{Create a data source for querychat} +\usage{ +querychat_data_source(x, ...) + +\method{querychat_data_source}{data.frame}(x, table_name = NULL, categorical_threshold = 20, ...) + +\method{querychat_data_source}{DBIConnection}(x, table_name, categorical_threshold = 20, ...) +} +\arguments{ +\item{x}{A data frame or DBI connection} + +\item{...}{Additional arguments passed to specific methods} + +\item{table_name}{The name to use for the table in the data source} + +\item{categorical_threshold}{For text columns, the maximum number of unique values to consider as a categorical variable} +} +\value{ +A querychat_data_source object +} +\description{ +Generic function to create a data source for querychat. This function +dispatches to appropriate methods based on input. +} diff --git a/pkg-r/man/test_query.Rd b/pkg-r/man/test_query.Rd new file mode 100644 index 00000000..ec3411de --- /dev/null +++ b/pkg-r/man/test_query.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{test_query} +\alias{test_query} +\title{Test a SQL query on a data source.} +\usage{ +test_query(source, query, ...) 
+} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{query}{SQL query string} + +\item{...}{Additional arguments passed to methods} +} +\value{ +Result of the query, limited to one row of data. +} +\description{ +Test a SQL query on a data source. +} From 09231fa1e37642401c224b1ec63c672706f0d536 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 27 Jun 2025 09:56:56 -0600 Subject: [PATCH 32/51] updating python datasource to prevent collisions --- pkg-py/src/querychat/datasource.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg-py/src/querychat/datasource.py b/pkg-py/src/querychat/datasource.py index 9215f60f..3261e0a9 100644 --- a/pkg-py/src/querychat/datasource.py +++ b/pkg-py/src/querychat/datasource.py @@ -216,8 +216,8 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 numeric_columns.append(col_name) select_parts.extend( [ - f"MIN({col_name}) as {col_name}_min", - f"MAX({col_name}) as {col_name}_max", + f"MIN({col_name}) as {col_name}__min", + f"MAX({col_name}) as {col_name}__max", ], ) @@ -228,7 +228,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 ): text_columns.append(col_name) select_parts.append( - f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count", + f"COUNT(DISTINCT {col_name}) as {col_name}__distinct_count", ) # Execute single query to get all statistics @@ -250,7 +250,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 categorical_values = {} text_cols_to_query = [] for col_name in text_columns: - distinct_count_key = f"{col_name}_distinct_count" + distinct_count_key = f"{col_name}__distinct_count" if ( distinct_count_key in column_stats and column_stats[distinct_count_key] @@ -289,8 +289,8 @@ def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 # Add range info for numeric columns if col_name in numeric_columns: - min_key = f"{col_name}_min" - max_key = f"{col_name}_max" + 
min_key = f"{col_name}__min" + max_key = f"{col_name}__max" if ( min_key in column_stats and max_key in column_stats From 150e5506290a0bb1de34aec4de1da14230e16c9d Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Tue, 1 Jul 2025 17:49:31 -0600 Subject: [PATCH 33/51] fix for github actions --- pkg-py/src/querychat/querychat.py | 2 +- pkg-py/tests/test_datasource.py | 4 ++-- pkg-r/R/querychat.R | 9 ++------- pkg-r/man/df_to_schema.Rd | 29 ----------------------------- pkg-r/man/querychat_init.Rd | 19 ++----------------- 5 files changed, 7 insertions(+), 56 deletions(-) delete mode 100644 pkg-r/man/df_to_schema.Rd diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index 1d97e795..7bca2079 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Protocol, Union, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union import chatlas import chevron diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py index 003e7a20..734cc4c7 100644 --- a/pkg-py/tests/test_datasource.py +++ b/pkg-py/tests/test_datasource.py @@ -11,7 +11,7 @@ def test_db_engine(): """Create a temporary SQLite database with test data.""" # Create temporary database file - temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") #noqa: SIM115 + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") # noqa: SIM115 temp_db.close() # Connect and create test table with various data types @@ -103,7 +103,7 @@ def test_get_schema_categorical_values(test_db_engine): # Category column should be treated as categorical (3 unique values: A, B, C) assert "- category (TEXT)" in schema assert "Categorical values:" in schema - assert "'A'" in schema and "'B'" in schema and "'C'" in schema #noqa: PT018 + assert "'A'" in schema and "'B'" 
in schema and "'C'" in schema # noqa: PT018 def test_get_schema_non_categorical_text(test_db_engine): diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index a10e7d0c..60b61bba 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -10,14 +10,9 @@ #' @param greeting A string in Markdown format, containing the initial message #' to display to the user upon first loading the chatbot. If not provided, the #' LLM will be invoked at the start of the conversation to generate one. -#' @param ... Additional arguments passed to the `querychat_system_prompt()` -#' function, such as `categorical_threshold`. If a -#' `system_prompt` argument is provided, the `...` arguments will be silently -#' ignored. -#' @inheritParams querychat_system_prompt #' @param system_prompt A string containing the system prompt for the chat model. -#' The default uses `create_system_prompt()` to generate a generic prompt, -#' which you can enhance via the `data_description` and `extra_instructions` +#' The default generates a generic prompt, which you can enhance via the `data_description` and +#' `extra_instructions` #' arguments. #' @param auto_close_data_source Should the data source connection be automatically #' closed when the shiny app stops? Defaults to TRUE. 
diff --git a/pkg-r/man/df_to_schema.Rd b/pkg-r/man/df_to_schema.Rd deleted file mode 100644 index d6060c4c..00000000 --- a/pkg-r/man/df_to_schema.Rd +++ /dev/null @@ -1,29 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prompt.R -\name{df_to_schema} -\alias{df_to_schema} -\title{Generate a schema description from a data frame} -\usage{ -df_to_schema( - df, - table_name = deparse(substitute(df)), - categorical_threshold = 10 -) -} -\arguments{ -\item{df}{A data frame to generate schema information from.} - -\item{table_name}{A string containing the name of the table in SQL queries.} - -\item{categorical_threshold}{The maximum number of unique values for a text column to be considered categorical.} -} -\value{ -A string containing the schema description for the data frame. -The schema includes the table name, column names, their types, and additional -information such as ranges for numeric columns and unique values for text columns. -} -\description{ -This function generates a schema description for a data frame, including -the column names, their types, and additional information such as ranges for -numeric columns and unique values for text columns. -} diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index 02c5448a..63494b3d 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -26,24 +26,9 @@ To create a data source: to display to the user upon first loading the chatbot. If not provided, the LLM will be invoked at the start of the conversation to generate one.} -\item{data_description}{Optional string or existing file path. The contents -should be in plain text or Markdown format, containing a description of the -data frame or any additional context that might be helpful in understanding -the data. This will be included in the system prompt for the chat model.} - -\item{extra_instructions}{Optional string or existing file path. 
The contents -should be in plain text or Markdown format, containing any additional -instructions for the chat model. These will be appended at the end of the -system prompt.} - -\item{prompt_template}{Optional string or existing file path. If \code{NULL}, the -default prompt file in the package will be used. The contents should -contain a whisker template for the system prompt, with placeholders for -\code{{{schema}}}, \code{{{data_description}}}, and \code{{{extra_instructions}}}.} - \item{system_prompt}{A string containing the system prompt for the chat model. -The default uses \code{create_system_prompt()} to generate a generic prompt, -which you can enhance via the \code{data_description} and \code{extra_instructions} +The default generates a generic prompt, which you can enhance via the \code{data_description} and +\code{extra_instructions} arguments.} \item{auto_close_data_source}{Should the data source connection be automatically From c58944432b280d5ca2425a78a58b9a96dfcc24ce Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Tue, 1 Jul 2025 17:51:15 -0600 Subject: [PATCH 34/51] adding tests to python github action (as we have some tests now!) 
--- .github/workflows/py-test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/py-test.yml b/.github/workflows/py-test.yml index 4486a805..2e574233 100644 --- a/.github/workflows/py-test.yml +++ b/.github/workflows/py-test.yml @@ -37,8 +37,8 @@ jobs: - name: πŸ“¦ Install the project run: uv sync --python ${{matrix.config.python-version }} --all-extras --all-groups - # - name: πŸ§ͺ Check tests - # run: make py-check-tests + - name: πŸ§ͺ Check tests + run: make py-check-tests - name: πŸ“ Check types run: make py-check-types From 98b2f290c4a967866ba70e6c0d1ec88ba64885f5 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Tue, 1 Jul 2025 17:57:31 -0600 Subject: [PATCH 35/51] edits for gha --- pkg-r/R/querychat.R | 8 ++++++-- pkg-r/man/querychat_init.Rd | 11 +++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 60b61bba..577ed743 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -10,10 +10,14 @@ #' @param greeting A string in Markdown format, containing the initial message #' to display to the user upon first loading the chatbot. If not provided, the #' LLM will be invoked at the start of the conversation to generate one. +#' @param data_description A string containing a data description for the chat model. We have found +#' that formatting the data description as a markdown bulleted list works best. +#' @param extra_instructions A string containing extra instructions for the chat model. +#' @param create_chat_func A function that takes a system prompt and returns a +#' chat object. The default uses `ellmer::chat_openai()`. #' @param system_prompt A string containing the system prompt for the chat model. #' The default generates a generic prompt, which you can enhance via the `data_description` and -#' `extra_instructions` -#' arguments. +#' `extra_instructions` arguments. 
#' @param auto_close_data_source Should the data source connection be automatically #' closed when the shiny app stops? Defaults to TRUE. #' diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index 63494b3d..618d8532 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -26,10 +26,17 @@ To create a data source: to display to the user upon first loading the chatbot. If not provided, the LLM will be invoked at the start of the conversation to generate one.} +\item{data_description}{A string containing a data description for the chat model. We have found +that formatting the data description as a markdown bulleted list works best.} + +\item{extra_instructions}{A string containing extra instructions for the chat model.} + +\item{create_chat_func}{A function that takes a system prompt and returns a +chat object. The default uses \code{ellmer::chat_openai()}.} + \item{system_prompt}{A string containing the system prompt for the chat model. The default generates a generic prompt, which you can enhance via the \code{data_description} and -\code{extra_instructions} -arguments.} +\code{extra_instructions} arguments.} \item{auto_close_data_source}{Should the data source connection be automatically closed when the shiny app stops? 
Defaults to TRUE.} From 3fd17e45f6e1de646835ed53c42dc8962480203f Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Tue, 1 Jul 2025 17:59:24 -0600 Subject: [PATCH 36/51] makefile edit --- Makefile | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 0a93a448..a82b5b61 100644 --- a/Makefile +++ b/Makefile @@ -123,12 +123,11 @@ py-check-tox: ## [py] Run python 3.9 - 3.12 checks with tox @echo "πŸ”„ Running tests and type checking with tox for Python 3.9--3.12" uv run tox run-parallel -# .PHONY: py-check-tests -# py-check-tests: ## [py] Run python tests -# @echo "" -# @echo "πŸ§ͺ Running tests with pytest" -# uv run playwright install -# uv run pytest +.PHONY: py-check-tests +py-check-tests: ## [py] Run python tests + @echo "" + @echo "πŸ§ͺ Running tests with pytest" + uv run pytest .PHONY: py-check-types py-check-types: ## [py] Run python type checks From e6731be7cb76033163e11bd72998ac63bd537ac1 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Tue, 8 Jul 2025 09:36:26 -0600 Subject: [PATCH 37/51] air format --- pkg-r/R/querychat.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 577ed743..18ae8ea2 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -16,7 +16,7 @@ #' @param create_chat_func A function that takes a system prompt and returns a #' chat object. The default uses `ellmer::chat_openai()`. #' @param system_prompt A string containing the system prompt for the chat model. -#' The default generates a generic prompt, which you can enhance via the `data_description` and +#' The default generates a generic prompt, which you can enhance via the `data_description` and #' `extra_instructions` arguments. #' @param auto_close_data_source Should the data source connection be automatically #' closed when the shiny app stops? Defaults to TRUE. 
From d45820fbe4478870f14b719fc8ac1b3c1b896b70 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Wed, 9 Jul 2025 16:53:28 -0600 Subject: [PATCH 38/51] code cleanup, better tests, and dropping `glue` dependency --- pkg-r/DESCRIPTION | 1 - pkg-r/R/data_source.R | 73 ++++++++++++++++--------- pkg-r/R/querychat.R | 17 ++---- pkg-r/tests/testthat/test-data-source.R | 48 ++++++++++++---- 4 files changed, 91 insertions(+), 48 deletions(-) diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index 2459d5f4..70f91b91 100644 --- a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -21,7 +21,6 @@ Imports: dplyr, duckdb, ellmer, - glue, htmltools, purrr, rlang, diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index a9fd82e5..7a0e68eb 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -67,8 +67,10 @@ querychat_data_source.DBIConnection <- function( } if (!DBI::dbExistsTable(x, table_name)) { - rlang::abort(glue::glue( - "Table '{table_name}' not found in database. If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect" + rlang::abort(paste0( + "Table '", + table_name, + "' not found in database. If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect" )) } @@ -262,7 +264,7 @@ get_schema.dbi_source <- function(source, ...) { columns <- DBI::dbListFields(conn, table_name) schema_lines <- c( - glue::glue("Table: {table_name}"), + paste("Table:", table_name), "Columns:" ) @@ -272,9 +274,10 @@ get_schema.dbi_source <- function(source, ...) { text_columns <- character(0) # Get sample of data to determine types - sample_query <- glue::glue_sql( - "SELECT * FROM {`table_name`} LIMIT 1", - .con = conn + sample_query <- paste0( + "SELECT * FROM ", + DBI::dbQuoteIdentifier(conn, table_name), + " LIMIT 1" ) sample_data <- DBI::dbGetQuery(conn, sample_query) @@ -288,16 +291,28 @@ get_schema.dbi_source <- function(source, ...) 
{ numeric_columns <- c(numeric_columns, col) select_parts <- c( select_parts, - glue::glue_sql("MIN({`col`}) as {`col`}__min", .con = conn), - glue::glue_sql("MAX({`col`}) as {`col`}__max", .con = conn) + paste0( + "MIN(", + DBI::dbQuoteIdentifier(conn, col), + ") as ", + DBI::dbQuoteIdentifier(conn, paste0(col, '__min')) + ), + paste0( + "MAX(", + DBI::dbQuoteIdentifier(conn, col), + ") as ", + DBI::dbQuoteIdentifier(conn, paste0(col, '__max')) + ) ) } else if (col_class %in% c("character", "factor")) { text_columns <- c(text_columns, col) select_parts <- c( select_parts, - glue::glue_sql( - "COUNT(DISTINCT {`col`}) as {`col`}__distinct_count", - .con = conn + paste0( + "COUNT(DISTINCT ", + DBI::dbQuoteIdentifier(conn, col), + ") as ", + DBI::dbQuoteIdentifier(conn, paste0(col, '__distinct_count')) ) ) } @@ -308,9 +323,11 @@ get_schema.dbi_source <- function(source, ...) { if (length(select_parts) > 0) { tryCatch( { - stats_query <- glue::glue_sql( - "SELECT {select_parts*} FROM {`table_name`}", - .con = conn + stats_query <- paste0( + "SELECT ", + paste0(select_parts, collapse = ", "), + " FROM ", + DBI::dbQuoteIdentifier(conn, table_name) ) result <- DBI::dbGetQuery(conn, stats_query) if (nrow(result) > 0) { @@ -327,11 +344,6 @@ get_schema.dbi_source <- function(source, ...) { categorical_values <- list() text_cols_to_query <- character(0) - # Always include the 'name' field from test_df for test case in tests/testthat/test-data-source.R - if ("name" %in% text_columns) { - text_cols_to_query <- c(text_cols_to_query, "name") - } - for (col_name in text_columns) { distinct_count_key <- paste0(col_name, "__distinct_count") if ( @@ -352,9 +364,15 @@ get_schema.dbi_source <- function(source, ...) 
{ for (col_name in text_cols_to_query) { tryCatch( { - cat_query <- glue::glue_sql( - "SELECT DISTINCT {`col_name`} FROM {`table_name`} WHERE {`col_name`} IS NOT NULL ORDER BY {`col_name`}", - .con = conn + cat_query <- paste0( + "SELECT DISTINCT ", + DBI::dbQuoteIdentifier(conn, col_name), + " FROM ", + DBI::dbQuoteIdentifier(conn, table_name), + " WHERE ", + DBI::dbQuoteIdentifier(conn, col_name), + " IS NOT NULL ORDER BY ", + DBI::dbQuoteIdentifier(conn, col_name) ) result <- DBI::dbGetQuery(conn, cat_query) if (nrow(result) > 0) { @@ -373,7 +391,7 @@ get_schema.dbi_source <- function(source, ...) { col_class <- class(sample_data[[col]])[1] sql_type <- r_class_to_sql_type(col_class) - column_info <- glue::glue("- {col} ({sql_type})") + column_info <- paste0("- ", col, " (", sql_type, ")") # Add range info for numeric columns if (col %in% numeric_columns) { @@ -386,8 +404,11 @@ get_schema.dbi_source <- function(source, ...) { !is.na(column_stats[[min_key]]) && !is.na(column_stats[[max_key]]) ) { - range_info <- glue::glue( - " Range: {column_stats[[min_key]]} to {column_stats[[max_key]]}" + range_info <- paste0( + " Range: ", + column_stats[[min_key]], + " to ", + column_stats[[max_key]] ) column_info <- paste(column_info, range_info, sep = "\n") } @@ -398,7 +419,7 @@ get_schema.dbi_source <- function(source, ...) 
{ values <- categorical_values[[col]] if (length(values) > 0) { values_str <- paste0("'", values, "'", collapse = ", ") - cat_info <- glue::glue(" Categorical values: {values_str}") + cat_info <- paste0(" Categorical values: ", values_str) column_info <- paste(column_info, cat_info, sep = "\n") } } diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 18ae8ea2..c1c26f23 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -76,15 +76,6 @@ querychat_init <- function( "create_chat_func must be a function" = is.function(create_chat_func) ) - if ("table_name" %in% names(attributes(system_prompt))) { - # If available, be sure to use the `table_name` argument to `querychat_init()` - # matches the one supplied to the system prompt - if (table_name != attr(system_prompt, "table_name")) { - rlang::abort( - "`querychat_init(table_name=)` must match system prompt `table_name` supplied to `querychat_system_prompt()`." - ) - } - } if (!is.null(greeting)) { greeting <- paste(collapse = "\n", greeting) } else { @@ -307,8 +298,12 @@ df_to_html <- function(df, maxrows = 5) { paste(collapse = "\n") if (nrow(df_short) != nrow(df)) { - rows_notice <- glue::glue( - "\n\n(Showing only the first {maxrows} rows out of {nrow(df)}.)\n" + rows_notice <- paste0( + "\n\n(Showing only the first ", + maxrows, + " rows out of ", + nrow(df), + ".)\n" ) } else { rows_notice <- "" diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index ea6b9555..c106d056 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -63,11 +63,14 @@ test_that("get_schema methods return proper schema", { df_source <- querychat_data_source(test_df, table_name = "test_table") schema <- get_schema(df_source) expect_type(schema, "character") - expect_true(grepl("Table: test_table", schema)) - expect_true(grepl("id \\(INTEGER\\)", schema)) - expect_true(grepl("name \\(TEXT\\)", schema)) - expect_true(grepl("active 
\\(BOOLEAN\\)", schema)) - expect_true(grepl("Categorical values", schema)) # Should list categorical values + expect_match(schema, "Table: test_table") + expect_match(schema, "id \\(INTEGER\\)") + expect_match(schema, "name \\(TEXT\\)") + expect_match(schema, "active \\(BOOLEAN\\)") + expect_match(schema, "Categorical values") # Should list categorical values + + # Test min/max values in schema - specifically for the id column + expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") # Test with DBI source temp_db <- tempfile(fileext = ".db") @@ -77,9 +80,12 @@ test_that("get_schema methods return proper schema", { dbi_source <- querychat_data_source(conn, "test_table") schema <- get_schema(dbi_source) expect_type(schema, "character") - expect_true(grepl("Table: test_table", schema)) - expect_true(grepl("id \\(INTEGER\\)", schema)) - expect_true(grepl("name \\(TEXT\\)", schema)) + expect_match(schema, "Table: test_table") + expect_match(schema, "id \\(INTEGER\\)") + expect_match(schema, "name \\(TEXT\\)") + + # Test min/max values in DBI source schema - specifically for the id column + expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") # Clean up cleanup_source(df_source) @@ -155,6 +161,28 @@ test_that("get_lazy_data returns tbl objects", { unlink(temp_db) }) +test_that("get_schema correctly reports min/max values for numeric columns", { + # Create a dataframe with multiple numeric columns + test_df <- data.frame( + id = 1:5, + score = c(10.5, 20.3, 15.7, 30.1, 25.9), + count = c(100, 200, 150, 50, 75), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_metrics") + schema <- get_schema(df_source) + + # Check that each numeric column has the correct min/max values + expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") + expect_match(schema, "- score \\(FLOAT\\)\\n Range: 10\\.5 to 30\\.1") + # Note: In the test output, count was detected as FLOAT rather than INTEGER + expect_match(schema, "- count 
\\(FLOAT\\)\\n Range: 50 to 200") + + # Clean up + cleanup_source(df_source) +}) + test_that("create_system_prompt generates appropriate system prompt", { test_df <- data.frame( id = 1:3, @@ -169,8 +197,8 @@ test_that("create_system_prompt generates appropriate system prompt", { ) expect_type(prompt, "character") expect_true(nchar(prompt) > 0) - expect_true(grepl("A test dataframe", prompt)) - expect_true(grepl("Table: test_table", prompt)) + expect_match(prompt, "A test dataframe") + expect_match(prompt, "Table: test_table") # Clean up cleanup_source(df_source) From 3f559749e0b0a1dc0df60fa02b5958dbd58ba042 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Wed, 16 Jul 2025 13:23:14 -0700 Subject: [PATCH 39/51] Fix error in qc.df() when no query is active Previously, the examples/app-database.R would shown an error on startup because the initial query was "", which was then sent as a SQL query to RSQLite. The get_lazy_data code path accounted for the "" query, so we decided to make the eager code path just call the lazy code path, then collect(). Also fixed a formatting issue with the table. --- pkg-r/R/data_source.R | 2 +- pkg-r/examples/app-database.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 7a0e68eb..dfb2d12b 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -97,7 +97,7 @@ execute_query <- function(source, query, ...) { #' @export execute_query.dbi_source <- function(source, query, ...) { - DBI::dbGetQuery(source$conn, query) + dplyr::collect(get_lazy_data(source, query)) } #' Test a SQL query on a data source. 
diff --git a/pkg-r/examples/app-database.R b/pkg-r/examples/app-database.R index 040f521e..668b32c5 100644 --- a/pkg-r/examples/app-database.R +++ b/pkg-r/examples/app-database.R @@ -53,7 +53,7 @@ ui <- page_sidebar( p( "The table below shows the current filtered data based on your chat queries:" ), - DT::DTOutput("data_table"), + DT::DTOutput("data_table", fill = FALSE), br(), h3("Current SQL Query"), verbatimTextOutput("sql_query"), From 395e116cfe8aa0eea09e4ac7592fd5cd035c8c70 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Wed, 16 Jul 2025 19:59:10 -0600 Subject: [PATCH 40/51] Adding dplyr::sql() identifier to get_lazy_query() to fix failing tests. It seems like dbplyr tables-as-queries can be a bit... temperamental. This should fix that by explicitly declaring sql always. --- pkg-r/R/data_source.R | 3 +- pkg-r/tests/testthat/test-data-source.R | 47 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index dfb2d12b..adbbbcad 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -137,7 +137,8 @@ get_lazy_data.dbi_source <- function(source, query = NULL, ...) 
{ # For a null or empty query, default to returning the whole table (ie SELECT *) dplyr::tbl(source$conn, source$table_name) } else { - dplyr::tbl(source$conn, query) + # Use dbplyr::sql to create a safe SQL query object + dplyr::tbl(source$conn, dbplyr::sql(query)) } } diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index c106d056..bcf8cd3e 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -161,6 +161,53 @@ test_that("get_lazy_data returns tbl objects", { unlink(temp_db) }) +test_that("get_lazy_data works with empty query", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with NULL query + lazy_data_null <- get_lazy_data(df_source, NULL) + expect_s3_class(lazy_data_null, "tbl") + result_null <- dplyr::collect(lazy_data_null) + expect_equal(nrow(result_null), 5) + + # Test with empty string query + lazy_data_empty <- get_lazy_data(df_source, "") + expect_s3_class(lazy_data_empty, "tbl") + result_empty <- dplyr::collect(lazy_data_empty) + expect_equal(nrow(result_empty), 5) + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test with NULL query + lazy_data_null <- get_lazy_data(dbi_source, NULL) + expect_s3_class(lazy_data_null, "tbl") + result_null <- dplyr::collect(lazy_data_null) + expect_equal(nrow(result_null), 5) + + # Test with empty string query + lazy_data_empty <- get_lazy_data(dbi_source, "") + expect_s3_class(lazy_data_empty, "tbl") + result_empty <- dplyr::collect(lazy_data_empty) + expect_equal(nrow(result_empty), 5) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) 
+}) + test_that("get_schema correctly reports min/max values for numeric columns", { # Create a dataframe with multiple numeric columns test_df <- data.frame( From d86888d4ece86ac2ff45154501c756b3becd59e3 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Wed, 16 Jul 2025 20:03:08 -0600 Subject: [PATCH 41/51] adding more tests to cover the empty execute_data query use case and air formatting --- pkg-r/tests/testthat/test-data-source.R | 57 ++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index bcf8cd3e..4000dc59 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -128,6 +128,53 @@ test_that("execute_query works for both source types", { unlink(temp_db) }) +test_that("execute_query works with empty/null queries", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with NULL query + result_null <- execute_query(df_source, NULL) + expect_s3_class(result_null, "data.frame") + expect_equal(nrow(result_null), 5) # Should return all rows + expect_equal(ncol(result_null), 2) # Should return all columns + + # Test with empty string query + result_empty <- execute_query(df_source, "") + expect_s3_class(result_empty, "data.frame") + expect_equal(nrow(result_empty), 5) # Should return all rows + expect_equal(ncol(result_empty), 2) # Should return all columns + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test with NULL query + result_null <- execute_query(dbi_source, NULL) + expect_s3_class(result_null, "data.frame") + expect_equal(nrow(result_null), 5) 
# Should return all rows + expect_equal(ncol(result_null), 2) # Should return all columns + + # Test with empty string query + result_empty <- execute_query(dbi_source, "") + expect_s3_class(result_empty, "data.frame") + expect_equal(nrow(result_empty), 5) # Should return all rows + expect_equal(ncol(result_empty), 2) # Should return all columns + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + test_that("get_lazy_data returns tbl objects", { # Test with data frame source test_df <- data.frame( @@ -170,32 +217,32 @@ test_that("get_lazy_data works with empty query", { ) df_source <- querychat_data_source(test_df, table_name = "test_table") - + # Test with NULL query lazy_data_null <- get_lazy_data(df_source, NULL) expect_s3_class(lazy_data_null, "tbl") result_null <- dplyr::collect(lazy_data_null) expect_equal(nrow(result_null), 5) - + # Test with empty string query lazy_data_empty <- get_lazy_data(df_source, "") expect_s3_class(lazy_data_empty, "tbl") result_empty <- dplyr::collect(lazy_data_empty) expect_equal(nrow(result_empty), 5) - + # Test with DBI source temp_db <- tempfile(fileext = ".db") conn <- dbConnect(RSQLite::SQLite(), temp_db) dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) dbi_source <- querychat_data_source(conn, "test_table") - + # Test with NULL query lazy_data_null <- get_lazy_data(dbi_source, NULL) expect_s3_class(lazy_data_null, "tbl") result_null <- dplyr::collect(lazy_data_null) expect_equal(nrow(result_null), 5) - + # Test with empty string query lazy_data_empty <- get_lazy_data(dbi_source, "") expect_s3_class(lazy_data_empty, "tbl") From 765250e2c4db0b2164d59d93da0cb129b469cec1 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Wed, 16 Jul 2025 20:07:05 -0600 Subject: [PATCH 42/51] description edit to pass routine test --- pkg-r/DESCRIPTION | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index 70f91b91..3d8552cc 100644 --- 
a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -16,8 +16,8 @@ Depends: R (>= 4.1.0) Imports: bslib, - dbplyr, DBI, + dbplyr, dplyr, duckdb, ellmer, @@ -28,12 +28,12 @@ Imports: shinychat (>= 0.2.0), whisker, xtable -Encoding: UTF-8 -Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 Suggests: DT, RSQLite, shinytest2, testthat (>= 3.0.0) Config/testthat/edition: 3 +Encoding: UTF-8 +Roxygen: list(markdown = TRUE) +RoxygenNote: 7.3.2 From 6432fa16c2360914d2528ff06829270c87095a6d Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 28 Jul 2025 15:24:45 -0600 Subject: [PATCH 43/51] edit to remove `tbl` output per discussion on #28 --- pkg-r/DESCRIPTION | 2 - pkg-r/NAMESPACE | 2 - pkg-r/R/data_source.R | 56 ++--- pkg-r/R/querychat.R | 15 +- pkg-r/man/get_lazy_data.Rd | 21 -- pkg-r/man/querychat_data_source.Rd | 10 +- pkg-r/man/querychat_server.Rd | 7 +- pkg-r/tests/testthat/test-data-source.R | 80 ------- pkg-r/tests/testthat/test-querychat-server.R | 26 +-- pkg-r/tests/testthat/test-shiny-app.R | 43 ++-- pkg-r/tests/testthat/test-sql-comments.R | 211 +++++++++++++++++++ 11 files changed, 272 insertions(+), 201 deletions(-) delete mode 100644 pkg-r/man/get_lazy_data.Rd create mode 100644 pkg-r/tests/testthat/test-sql-comments.R diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index 3d8552cc..cacfb127 100644 --- a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -17,8 +17,6 @@ Depends: Imports: bslib, DBI, - dbplyr, - dplyr, duckdb, ellmer, htmltools, diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index d597e911..8c75247d 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -5,7 +5,6 @@ S3method(create_system_prompt,querychat_data_source) S3method(execute_query,dbi_source) S3method(get_db_type,data_frame_source) S3method(get_db_type,dbi_source) -S3method(get_lazy_data,dbi_source) S3method(get_schema,dbi_source) S3method(querychat_data_source,DBIConnection) S3method(querychat_data_source,data.frame) @@ -14,7 +13,6 @@ export(cleanup_source) export(create_system_prompt) 
export(execute_query) export(get_db_type) -export(get_lazy_data) export(get_schema) export(querychat_data_source) export(querychat_init) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index adbbbcad..1fe0f79d 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -4,7 +4,11 @@ #' dispatches to appropriate methods based on input. #' #' @param x A data frame or DBI connection -#' @param table_name The name to use for the table in the data source +#' @param table_name The name to use for the table in the data source. Can be: +#' - A character string (e.g., "table_name") +#' - Or, for tables contained within catalogs or schemas: +#' - A DBI::Id object (e.g., `DBI::Id(schema = "schema_name", table = "table_name")`) +#' - An AsIs object created with I() (e.g., `I("schema_name.table_name")`) #' @param categorical_threshold For text columns, the maximum number of unique values to consider as a categorical variable #' @param ... Additional arguments passed to specific methods #' @return A querychat_data_source object @@ -62,14 +66,26 @@ querychat_data_source.DBIConnection <- function( categorical_threshold = 20, ... ) { - if (!is.character(table_name) || length(table_name) != 1) { - rlang::abort("`table_name` must be a single character string") + # Handle different types of table_name inputs + if (inherits(table_name, "Id")) { + # DBI::Id object - keep as is + } else if (inherits(table_name, "AsIs")) { + # AsIs object - convert to character + table_name <- as.character(table_name) + } else if (is.character(table_name) && length(table_name) == 1) { + # Character string - keep as is + } else { + # Invalid input + rlang::abort( + "`table_name` must be a single character string, a DBI::Id object, or an AsIs object" + ) } + # Check if table exists if (!DBI::dbExistsTable(x, table_name)) { rlang::abort(paste0( "Table '", - table_name, + as.character(table_name), "' not found in database. 
If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect" )) } @@ -97,7 +113,15 @@ execute_query <- function(source, query, ...) { #' @export execute_query.dbi_source <- function(source, query, ...) { - dplyr::collect(get_lazy_data(source, query)) + if (is.null(query) || query == "") { + # For a null or empty query, default to returning the whole table (ie SELECT *) + query <- paste0( + "SELECT * FROM ", + DBI::dbQuoteIdentifier(source$conn, source$table_name) + ) + } + # Execute the query directly + DBI::dbGetQuery(source$conn, query) } #' Test a SQL query on a data source. @@ -120,28 +144,6 @@ test_query.dbi_source <- function(source, query, ...) { } -#' Get a lazy representation of a data source -#' -#' @param source A querychat_data_source object -#' @param query SQL query string -#' @param ... Additional arguments passed to methods -#' @return A lazy representation (typically a dbplyr tbl) -#' @export -get_lazy_data <- function(source, query, ...) { - UseMethod("get_lazy_data") -} - -#' @export -get_lazy_data.dbi_source <- function(source, query = NULL, ...) { - if (is.null(query) || query == "") { - # For a null or empty query, default to returning the whole table (ie SELECT *) - dplyr::tbl(source$conn, source$table_name) - } else { - # Use dbplyr::sql to create a safe SQL query object - dplyr::tbl(source$conn, dbplyr::sql(query)) - } -} - #' Get type information for a data source #' #' @param source A querychat_data_source object diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index c1c26f23..7a1c76f6 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -143,14 +143,9 @@ querychat_ui <- function(id) { #' #' - `sql`: A reactive that returns the current SQL query. #' - `title`: A reactive that returns the current title. -#' - `df`: A reactive that returns the filtered data. For data frame sources, -#' this returns a data.frame. 
For database sources, this returns a lazy -#' dbplyr tbl that can be further manipulated with dplyr verbs before -#' calling collect() to materialize the results. +#' - `df`: A reactive that returns the filtered data as a data.frame. #' - `chat`: The [ellmer::Chat] object that powers the chat interface. #' -#' By convention, this object should be named `querychat_config`. -#' #' @export querychat_server <- function(id, querychat_config) { shiny::moduleServer(id, function(input, output, session) { @@ -164,10 +159,7 @@ querychat_server <- function(id, querychat_config) { current_title <- shiny::reactiveVal(NULL) current_query <- shiny::reactiveVal("") filtered_df <- shiny::reactive({ - execute_query(data_source, query = dplyr::sql(current_query())) - }) - filtered_tbl <- shiny::reactive({ - get_lazy_data(data_source, query = dplyr::sql(current_query())) + execute_query(data_source, query = DBI::SQL(current_query())) }) append_output <- function(...) { @@ -277,8 +269,7 @@ querychat_server <- function(id, querychat_config) { chat = chat, sql = shiny::reactive(current_query()), title = shiny::reactive(current_title()), - df = filtered_df, - tbl = filtered_tbl + df = filtered_df ) }) } diff --git a/pkg-r/man/get_lazy_data.Rd b/pkg-r/man/get_lazy_data.Rd deleted file mode 100644 index 4c2a75f4..00000000 --- a/pkg-r/man/get_lazy_data.Rd +++ /dev/null @@ -1,21 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data_source.R -\name{get_lazy_data} -\alias{get_lazy_data} -\title{Get a lazy representation of a data source} -\usage{ -get_lazy_data(source, query, ...) 
-} -\arguments{ -\item{source}{A querychat_data_source object} - -\item{query}{SQL query string} - -\item{...}{Additional arguments passed to methods} -} -\value{ -A lazy representation (typically a dbplyr tbl) -} -\description{ -Get a lazy representation of a data source -} diff --git a/pkg-r/man/querychat_data_source.Rd b/pkg-r/man/querychat_data_source.Rd index 424cfcc7..128e4160 100644 --- a/pkg-r/man/querychat_data_source.Rd +++ b/pkg-r/man/querychat_data_source.Rd @@ -17,7 +17,15 @@ querychat_data_source(x, ...) \item{...}{Additional arguments passed to specific methods} -\item{table_name}{The name to use for the table in the data source} +\item{table_name}{The name to use for the table in the data source. Can be: +\itemize{ +\item A character string (e.g., "table_name") +\item Or, for tables contained within catalogs or schemas: +\itemize{ +\item A DBI::Id object (e.g., \code{DBI::Id(schema = "schema_name", table = "table_name")}) +\item An AsIs object created with I() (e.g., \code{I("schema_name.table_name")}) +} +}} \item{categorical_threshold}{For text columns, the maximum number of unique values to consider as a categorical variable} } diff --git a/pkg-r/man/querychat_server.Rd b/pkg-r/man/querychat_server.Rd index 89b9e9d9..eec8f892 100644 --- a/pkg-r/man/querychat_server.Rd +++ b/pkg-r/man/querychat_server.Rd @@ -18,14 +18,9 @@ elements: \itemize{ \item \code{sql}: A reactive that returns the current SQL query. \item \code{title}: A reactive that returns the current title. -\item \code{df}: A reactive that returns the filtered data. For data frame sources, -this returns a data.frame. For database sources, this returns a lazy -dbplyr tbl that can be further manipulated with dplyr verbs before -calling collect() to materialize the results. +\item \code{df}: A reactive that returns the filtered data as a data.frame. \item \code{chat}: The \link[ellmer:Chat]{ellmer::Chat} object that powers the chat interface. 
} - -By convention, this object should be named \code{querychat_config}. } \description{ Initalize the querychat server diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index 4000dc59..9085c23f 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -1,7 +1,6 @@ library(testthat) library(DBI) library(RSQLite) -library(dplyr) library(querychat) test_that("querychat_data_source.data.frame creates proper S3 object", { @@ -175,85 +174,6 @@ test_that("execute_query works with empty/null queries", { unlink(temp_db) }) -test_that("get_lazy_data returns tbl objects", { - # Test with data frame source - test_df <- data.frame( - id = 1:5, - value = c(10, 20, 30, 40, 50), - stringsAsFactors = FALSE - ) - - df_source <- querychat_data_source(test_df, table_name = "test_table") - lazy_data <- get_lazy_data(df_source) - expect_s3_class(lazy_data, "tbl") - - # Test with DBI source - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - - dbi_source <- querychat_data_source(conn, "test_table") - lazy_data <- get_lazy_data(dbi_source) - expect_s3_class(lazy_data, "tbl") - - # Test chaining with dplyr - filtered_data <- lazy_data %>% - dplyr::filter(value > 25) %>% - dplyr::collect() - expect_equal(nrow(filtered_data), 3) # Should return 3 rows (30, 40, 50) - - # Clean up - cleanup_source(df_source) - dbDisconnect(conn) - unlink(temp_db) -}) - -test_that("get_lazy_data works with empty query", { - # Test with data frame source - test_df <- data.frame( - id = 1:5, - value = c(10, 20, 30, 40, 50), - stringsAsFactors = FALSE - ) - - df_source <- querychat_data_source(test_df, table_name = "test_table") - - # Test with NULL query - lazy_data_null <- get_lazy_data(df_source, NULL) - expect_s3_class(lazy_data_null, "tbl") - result_null <- dplyr::collect(lazy_data_null) - expect_equal(nrow(result_null), 5) 
- - # Test with empty string query - lazy_data_empty <- get_lazy_data(df_source, "") - expect_s3_class(lazy_data_empty, "tbl") - result_empty <- dplyr::collect(lazy_data_empty) - expect_equal(nrow(result_empty), 5) - - # Test with DBI source - temp_db <- tempfile(fileext = ".db") - conn <- dbConnect(RSQLite::SQLite(), temp_db) - dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) - - dbi_source <- querychat_data_source(conn, "test_table") - - # Test with NULL query - lazy_data_null <- get_lazy_data(dbi_source, NULL) - expect_s3_class(lazy_data_null, "tbl") - result_null <- dplyr::collect(lazy_data_null) - expect_equal(nrow(result_null), 5) - - # Test with empty string query - lazy_data_empty <- get_lazy_data(dbi_source, "") - expect_s3_class(lazy_data_empty, "tbl") - result_empty <- dplyr::collect(lazy_data_empty) - expect_equal(nrow(result_empty), 5) - - # Clean up - cleanup_source(df_source) - dbDisconnect(conn) - unlink(temp_db) -}) test_that("get_schema correctly reports min/max values for numeric columns", { # Create a dataframe with multiple numeric columns diff --git a/pkg-r/tests/testthat/test-querychat-server.R b/pkg-r/tests/testthat/test-querychat-server.R index 7647c5b0..a44cfb08 100644 --- a/pkg-r/tests/testthat/test-querychat-server.R +++ b/pkg-r/tests/testthat/test-querychat-server.R @@ -1,8 +1,6 @@ library(testthat) library(DBI) library(RSQLite) -library(dplyr) -library(dbplyr) library(querychat) test_that("database source query functionality", { @@ -29,26 +27,10 @@ test_that("database source query functionality", { expect_equal(nrow(result), 2) # Charlie and Eve expect_equal(result$name, c("Charlie", "Eve")) - # Test that we can get all data as lazy dbplyr table - all_data <- get_lazy_data(db_source) - expect_s3_class( - all_data, - c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl") - ) - - # Test that it can be chained with dbplyr operations before collect() - filtered_data <- all_data |> - dplyr::filter(age >= 30) |> - 
dplyr::select(name, age) |> - dplyr::collect() - - expect_s3_class(filtered_data, "data.frame") - expect_equal(nrow(filtered_data), 3) # Bob, Charlie, Eve - - # Test that the lazy table can be collected to get all data - collected_data <- dplyr::collect(all_data) - expect_s3_class(collected_data, "data.frame") - expect_equal(nrow(collected_data), 5) + # Test that we can get all data + all_data <- execute_query(db_source, NULL) + expect_s3_class(all_data, "data.frame") + expect_equal(nrow(all_data), 5) expect_equal(ncol(all_data), 3) # Test ordering works diff --git a/pkg-r/tests/testthat/test-shiny-app.R b/pkg-r/tests/testthat/test-shiny-app.R index 925f1991..8f762840 100644 --- a/pkg-r/tests/testthat/test-shiny-app.R +++ b/pkg-r/tests/testthat/test-shiny-app.R @@ -55,12 +55,7 @@ server <- function(input, output, session) { chat <- querychat_server("chat", querychat_config) output$data_table <- DT::renderDT({ - data <- chat$df() - if (inherits(data, "tbl_lazy")) { - dplyr::collect(data) - } else { - data - } + chat$df() }, options = list(pageLength = 5)) output$sql_query <- renderText({ @@ -96,7 +91,6 @@ test_that("database reactive functionality works correctly", { library(DBI) library(RSQLite) - library(dplyr) # Create test database temp_db <- tempfile(fileext = ".db") @@ -126,28 +120,21 @@ test_that("database reactive functionality works correctly", { expect_s3_class(config$data_source, "dbi_source") expect_s3_class(config$data_source, "querychat_data_source") - # Test that get_lazy_data returns lazy table - lazy_data <- get_lazy_data(config$data_source) - expect_s3_class( - lazy_data, - c("tbl_SQLiteConnection", "tbl_dbi", "tbl_sql", "tbl_lazy", "tbl") - ) + # Test that we can get all data + result_data <- execute_query(config$data_source, NULL) + expect_s3_class(result_data, "data.frame") + expect_equal(nrow(result_data), 150) + expect_equal(ncol(result_data), 5) - # Test that we can chain operations and collect - result <- lazy_data %>% - filter(Species == 
"setosa") %>% - select(Sepal.Length, Sepal.Width) %>% - collect() - - expect_s3_class(result, "data.frame") - expect_equal(nrow(result), 50) - expect_equal(ncol(result), 2) - expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(result))) - - # Test that original lazy table is still usable - all_data <- collect(lazy_data) - expect_equal(nrow(all_data), 150) - expect_equal(ncol(all_data), 5) + # Test with a specific query + query_result <- execute_query( + config$data_source, + "SELECT \"Sepal.Length\", \"Sepal.Width\" FROM iris WHERE \"Species\" = 'setosa'" + ) + expect_s3_class(query_result, "data.frame") + expect_equal(nrow(query_result), 50) + expect_equal(ncol(query_result), 2) + expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(query_result))) # Clean up dbDisconnect(db_conn) diff --git a/pkg-r/tests/testthat/test-sql-comments.R b/pkg-r/tests/testthat/test-sql-comments.R new file mode 100644 index 00000000..e7553ad1 --- /dev/null +++ b/pkg-r/tests/testthat/test-sql-comments.R @@ -0,0 +1,211 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("execute_query handles SQL with inline comments", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with inline comments + inline_comment_query <- " + SELECT id, value -- This is a comment + FROM test_table + WHERE value > 25 -- Filter for higher values + " + + result <- execute_query(df_source, inline_comment_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + expect_equal(ncol(result), 2) + + # Test with multiple inline comments + multiple_comments_query <- " + SELECT -- Get only these columns + id, -- ID column + value -- Value column + FROM test_table -- Our test table + WHERE value > 25 -- Only higher values + 
" + + result <- execute_query(df_source, multiple_comments_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with multiline comments", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with multiline comments + multiline_comment_query <- " + /* + * This is a multiline comment + * that spans multiple lines + */ + SELECT id, value + FROM test_table + WHERE value > 25 + " + + result <- execute_query(df_source, multiline_comment_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with embedded multiline comments + embedded_multiline_query <- " + SELECT id, /* comment between columns */ value + FROM /* this is + * a multiline + * comment + */ test_table + WHERE value /* another comment */ > 25 + " + + result <- execute_query(df_source, embedded_multiline_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with trailing semicolons", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with trailing semicolon + query_with_semicolon <- " + SELECT id, value + FROM test_table + WHERE value > 25; + " + + result <- execute_query(df_source, query_with_semicolon) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with multiple semicolons (which could happen with LLM-generated 
SQL) + query_with_multiple_semicolons <- " + SELECT id, value + FROM test_table + WHERE value > 25;;;; + " + + result <- execute_query(df_source, query_with_multiple_semicolons) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with mixed comments and semicolons", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with a mix of comment styles and semicolons + complex_query <- " + /* + * This is a complex query with different comment styles + */ + SELECT + id, -- This is the ID column + value /* Value column */ + FROM + test_table -- Our test table + WHERE + /* Only get higher values */ + value > 25; -- End of query + " + + result <- execute_query(df_source, complex_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with comments that contain SQL-like syntax + tricky_comment_query <- " + SELECT id, value + FROM test_table + /* Comment with SQL-like syntax: + * SELECT * FROM another_table; + */ + WHERE value > 25 -- WHERE id = 'value; DROP TABLE test;' + " + + result <- execute_query(df_source, tricky_comment_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with unusual whitespace patterns", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with unusual whitespace patterns (which LLMs might generate) + unusual_whitespace_query 
<- " + + SELECT id, value + + FROM test_table + + WHERE value>25 + + " + + result <- execute_query(df_source, unusual_whitespace_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) From de0a31e3db04ada919e7c633c12d1f689106f9ed Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 28 Jul 2025 18:14:49 -0600 Subject: [PATCH 44/51] better data source nested identifier handling --- pkg-r/R/data_source.R | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 1fe0f79d..189ba26f 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -67,11 +67,8 @@ querychat_data_source.DBIConnection <- function( ... ) { # Handle different types of table_name inputs - if (inherits(table_name, "Id")) { + if (inherits(table_name, "Id") || inherits(table_name, "AsIs")) { # DBI::Id object - keep as is - } else if (inherits(table_name, "AsIs")) { - # AsIs object - convert to character - table_name <- as.character(table_name) } else if (is.character(table_name) && length(table_name) == 1) { # Character string - keep as is } else { From b6eeb4ad402fce41e4c7d03d5301ab70e9d41df6 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 28 Jul 2025 19:05:28 -0600 Subject: [PATCH 45/51] fixing a missing quote identifier --- pkg-r/R/data_source.R | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 189ba26f..5a8176ca 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -6,9 +6,7 @@ #' @param x A data frame or DBI connection #' @param table_name The name to use for the table in the data source. 
Can be: #' - A character string (e.g., "table_name") -#' - Or, for tables contained within catalogs or schemas: -#' - A DBI::Id object (e.g., `DBI::Id(schema = "schema_name", table = "table_name")`) -#' - An AsIs object created with I() (e.g., `I("schema_name.table_name")`) +#' - Or, for tables contained within catalogs or schemas, a DBI::Id object (e.g., `DBI::Id(schema = "schema_name", table = "table_name")`) #' @param categorical_threshold For text columns, the maximum number of unique values to consider as a categorical variable #' @param ... Additional arguments passed to specific methods #' @return A querychat_data_source object @@ -264,7 +262,7 @@ get_schema.dbi_source <- function(source, ...) { columns <- DBI::dbListFields(conn, table_name) schema_lines <- c( - paste("Table:", table_name), + paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)), "Columns:" ) From 32a65fccc0ccca3bcc5e55bd172e836892c7283c Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 28 Jul 2025 19:14:34 -0600 Subject: [PATCH 46/51] doc cleanup --- pkg-r/R/data_source.R | 2 +- pkg-r/man/querychat_data_source.Rd | 6 +----- pkg-r/tests/testthat/test-data-source.R | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 5a8176ca..f1fc1f6d 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -6,7 +6,7 @@ #' @param x A data frame or DBI connection #' @param table_name The name to use for the table in the data source. Can be: #' - A character string (e.g., "table_name") -#' - Or, for tables contained within catalogs or schemas, a DBI::Id object (e.g., `DBI::Id(schema = "schema_name", table = "table_name")`) +#' - Or, for tables contained within catalogs or schemas, a [DBI::Id()] object (e.g., `DBI::Id(schema = "schema_name", table = "table_name")`) #' @param categorical_threshold For text columns, the maximum number of unique values to consider as a categorical variable #' @param ... 
Additional arguments passed to specific methods #' @return A querychat_data_source object diff --git a/pkg-r/man/querychat_data_source.Rd b/pkg-r/man/querychat_data_source.Rd index 128e4160..7d99ac5a 100644 --- a/pkg-r/man/querychat_data_source.Rd +++ b/pkg-r/man/querychat_data_source.Rd @@ -20,11 +20,7 @@ querychat_data_source(x, ...) \item{table_name}{The name to use for the table in the data source. Can be: \itemize{ \item A character string (e.g., "table_name") -\item Or, for tables contained within catalogs or schemas: -\itemize{ -\item A DBI::Id object (e.g., \code{DBI::Id(schema = "schema_name", table = "table_name")}) -\item An AsIs object created with I() (e.g., \code{I("schema_name.table_name")}) -} +\item Or, for tables contained within catalogs or schemas, a \code{\link[DBI:Id]{DBI::Id()}} object (e.g., \code{DBI::Id(schema = "schema_name", table = "table_name")}) }} \item{categorical_threshold}{For text columns, the maximum number of unique values to consider as a categorical variable} diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R index 9085c23f..a957aae9 100644 --- a/pkg-r/tests/testthat/test-data-source.R +++ b/pkg-r/tests/testthat/test-data-source.R @@ -79,7 +79,7 @@ test_that("get_schema methods return proper schema", { dbi_source <- querychat_data_source(conn, "test_table") schema <- get_schema(dbi_source) expect_type(schema, "character") - expect_match(schema, "Table: test_table") + expect_match(schema, "Table: `test_table`") expect_match(schema, "id \\(INTEGER\\)") expect_match(schema, "name \\(TEXT\\)") From 1325ed12992515267dffa3a1eb67e9cb76df7d92 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 28 Jul 2025 19:22:17 -0600 Subject: [PATCH 47/51] a bit more helpful error message --- pkg-r/R/data_source.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index f1fc1f6d..b2b6b54c 100644 --- a/pkg-r/R/data_source.R +++ 
b/pkg-r/R/data_source.R @@ -79,9 +79,10 @@ querychat_data_source.DBIConnection <- function( # Check if table exists if (!DBI::dbExistsTable(x, table_name)) { rlang::abort(paste0( - "Table '", - as.character(table_name), - "' not found in database. If you're using databricks, try setting the 'Catalog' and 'Schema' arguments to DBI::dbConnect" + "Table ", + DBI::dbQuoteIdentifier(x, table_name), + " not found in database. If you're using a table in a catalog or schema, pass a DBI::Id", + " object to `table_name`" )) } From 0d01d8293ab54be56259628207451830f8e31f41 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Mon, 28 Jul 2025 19:23:19 -0600 Subject: [PATCH 48/51] even more helpful erroring --- pkg-r/R/data_source.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index b2b6b54c..9a1282ef 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -65,14 +65,14 @@ querychat_data_source.DBIConnection <- function( ... 
) { # Handle different types of table_name inputs - if (inherits(table_name, "Id") || inherits(table_name, "AsIs")) { + if (inherits(table_name, "Id")) { # DBI::Id object - keep as is } else if (is.character(table_name) && length(table_name) == 1) { # Character string - keep as is } else { # Invalid input rlang::abort( - "`table_name` must be a single character string, a DBI::Id object, or an AsIs object" + "`table_name` must be a single character string or a DBI::Id object" ) } From bc2ce5a21463846606936b47d4ac35a0d2eec084 Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Wed, 3 Sep 2025 16:46:00 -0600 Subject: [PATCH 49/51] fix to df_to_html --- pkg-py/src/querychat/querychat.py | 13 +++- pkg-py/tests/test_query_function.py | 113 ++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 pkg-py/tests/test_query_function.py diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index 3f6ea616..a2e5cc1b 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -224,11 +224,18 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: HTML string representation of the table """ + # Convert to Narwhals DataFrame if it's not already one if isinstance(df, (nw.LazyFrame, nw.DataFrame)): - df_short = df.lazy().head(maxrows).collect() - nrow_full = df.lazy().select(nw.len()).collect().item() + nw_df = df else: - raise TypeError("df must be a Narwhals DataFrame or LazyFrame") + # Try to convert using nw.from_native (supports pandas and other formats) + try: + nw_df = nw.from_native(df) + except Exception as e: + raise TypeError("df must be a Narwhals DataFrame, LazyFrame, or compatible DataFrame (e.g., pandas)") from e + + df_short = nw_df.lazy().head(maxrows).collect() + nrow_full = nw_df.lazy().select(nw.len()).collect().item() # Generate HTML table table_html = df_short.to_pandas().to_html( diff --git a/pkg-py/tests/test_query_function.py 
b/pkg-py/tests/test_query_function.py new file mode 100644 index 00000000..6dddb281 --- /dev/null +++ b/pkg-py/tests/test_query_function.py @@ -0,0 +1,113 @@ +import pandas as pd +import pytest +from src.querychat.datasource import DataFrameSource, SQLAlchemySource +from src.querychat.querychat import df_to_html +from sqlalchemy import create_engine +import sqlite3 +import tempfile +from pathlib import Path + + +@pytest.fixture +def sample_dataframe(): + """Create a sample pandas DataFrame for testing.""" + return pd.DataFrame({ + 'id': [1, 2, 3, 4, 5], + 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], + 'age': [25, 30, 35, 28, 32], + 'salary': [50000, 60000, 70000, 55000, 65000] + }) + + +@pytest.fixture +def test_db_engine_with_data(): + """Create a temporary SQLite database with test data.""" + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db.close() + + conn = sqlite3.connect(temp_db.name) + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE employees ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER, + salary REAL + ) + """) + + test_data = [ + (1, "Alice", 25, 50000), + (2, "Bob", 30, 60000), + (3, "Charlie", 35, 70000), + (4, "Diana", 28, 55000), + (5, "Eve", 32, 65000) + ] + + cursor.executemany( + "INSERT INTO employees (id, name, age, salary) VALUES (?, ?, ?, ?)", + test_data + ) + + conn.commit() + conn.close() + + engine = create_engine(f"sqlite:///{temp_db.name}") + yield engine + + # Cleanup + Path(temp_db.name).unlink() + + +def test_df_to_html_with_dataframe_source_result(sample_dataframe): + """Test that df_to_html() works with results from DataFrameSource.execute_query().""" + source = DataFrameSource(sample_dataframe, "employees") + + # Execute query to get pandas DataFrame + result_df = source.execute_query("SELECT * FROM employees WHERE age > 25") + + # This should succeed after the fix + html_output = df_to_html(result_df) + + # Verify the HTML contains expected content + assert 
isinstance(html_output, str) + assert ' 25") + + # This should succeed after the fix + html_output = df_to_html(result_df) + + # Verify the HTML contains expected content + assert isinstance(html_output, str) + assert ' Date: Wed, 3 Sep 2025 16:54:52 -0600 Subject: [PATCH 50/51] formatting --- pkg-py/src/querychat/querychat.py | 6 ++- pkg-py/tests/test_query_function.py | 83 +++++++++++++++-------------- 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index a2e5cc1b..994223fd 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -232,8 +232,10 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: try: nw_df = nw.from_native(df) except Exception as e: - raise TypeError("df must be a Narwhals DataFrame, LazyFrame, or compatible DataFrame (e.g., pandas)") from e - + raise TypeError( + "df must be a Narwhals DataFrame, LazyFrame, or compatible DataFrame (e.g., pandas)" + ) from e + df_short = nw_df.lazy().head(maxrows).collect() nrow_full = nw_df.lazy().select(nw.len()).collect().item() diff --git a/pkg-py/tests/test_query_function.py b/pkg-py/tests/test_query_function.py index 6dddb281..982efacd 100644 --- a/pkg-py/tests/test_query_function.py +++ b/pkg-py/tests/test_query_function.py @@ -1,33 +1,36 @@ +import sqlite3 +import tempfile +from pathlib import Path + import pandas as pd import pytest +from sqlalchemy import create_engine from src.querychat.datasource import DataFrameSource, SQLAlchemySource from src.querychat.querychat import df_to_html -from sqlalchemy import create_engine -import sqlite3 -import tempfile -from pathlib import Path @pytest.fixture def sample_dataframe(): """Create a sample pandas DataFrame for testing.""" - return pd.DataFrame({ - 'id': [1, 2, 3, 4, 5], - 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], - 'age': [25, 30, 35, 28, 32], - 'salary': [50000, 60000, 70000, 55000, 65000] - }) + return pd.DataFrame( + 
{ + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], + "age": [25, 30, 35, 28, 32], + "salary": [50000, 60000, 70000, 55000, 65000], + }, + ) @pytest.fixture def test_db_engine_with_data(): """Create a temporary SQLite database with test data.""" - temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") # noqa: SIM115 temp_db.close() - + conn = sqlite3.connect(temp_db.name) cursor = conn.cursor() - + cursor.execute(""" CREATE TABLE employees ( id INTEGER PRIMARY KEY, @@ -36,26 +39,26 @@ def test_db_engine_with_data(): salary REAL ) """) - + test_data = [ (1, "Alice", 25, 50000), (2, "Bob", 30, 60000), (3, "Charlie", 35, 70000), (4, "Diana", 28, 55000), - (5, "Eve", 32, 65000) + (5, "Eve", 32, 65000), ] - + cursor.executemany( "INSERT INTO employees (id, name, age, salary) VALUES (?, ?, ?, ?)", - test_data + test_data, ) - + conn.commit() conn.close() - + engine = create_engine(f"sqlite:///{temp_db.name}") yield engine - + # Cleanup Path(temp_db.name).unlink() @@ -63,51 +66,51 @@ def test_db_engine_with_data(): def test_df_to_html_with_dataframe_source_result(sample_dataframe): """Test that df_to_html() works with results from DataFrameSource.execute_query().""" source = DataFrameSource(sample_dataframe, "employees") - + # Execute query to get pandas DataFrame result_df = source.execute_query("SELECT * FROM employees WHERE age > 25") - + # This should succeed after the fix html_output = df_to_html(result_df) - + # Verify the HTML contains expected content assert isinstance(html_output, str) - assert ' 25") - + # This should succeed after the fix html_output = df_to_html(result_df) - + # Verify the HTML contains expected content assert isinstance(html_output, str) - assert ' Date: Wed, 3 Sep 2025 16:56:58 -0600 Subject: [PATCH 51/51] more formatting --- pkg-py/src/querychat/querychat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index 994223fd..5b142576 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -233,7 +233,7 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: nw_df = nw.from_native(df) except Exception as e: raise TypeError( - "df must be a Narwhals DataFrame, LazyFrame, or compatible DataFrame (e.g., pandas)" + "df must be a Narwhals DataFrame, LazyFrame, or compatible DataFrame (e.g., pandas)", ) from e df_short = nw_df.lazy().head(maxrows).collect()