Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 33 additions & 83 deletions pkg-py/src/querychat/_querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 10,
extra_instructions: Optional[str | Path] = None,
prompt_template: Optional[str | Path] = None,
):
Expand Down Expand Up @@ -79,6 +80,9 @@ def __init__(
Description of the data in plain text or Markdown. If a pathlib.Path
object is passed, querychat will read the contents of the path into a
string with `.read_text()`.
categorical_threshold
Threshold for determining if a column is categorical based on number of
unique values.
extra_instructions
Additional instructions for the chat model. If a pathlib.Path object is
passed, querychat will read the contents of the path into a string with
Expand All @@ -104,7 +108,7 @@ def __init__(
```

"""
self.data_source = normalize_data_source(data_source, table_name)
self._data_source = normalize_data_source(data_source, table_name)

# Validate table name (must begin with letter, contain only letters, numbers, underscores)
if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name):
Expand All @@ -114,8 +118,6 @@ def __init__(

self.id = id or table_name

self.client = normalize_client(client)

if greeting is None:
print(
"Warning: No greeting provided; the LLM will be invoked at conversation start to generate one. "
Expand All @@ -126,13 +128,20 @@ def __init__(

self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting

self.system_prompt = get_system_prompt(
self.data_source,
prompt = get_system_prompt(
self._data_source,
data_description=data_description,
extra_instructions=extra_instructions,
categorical_threshold=categorical_threshold,
prompt_template=prompt_template,
)

# Fork and empty chat now so the per-session forks are fast
client = normalize_client(client)
self._client = copy.deepcopy(client)
self._client.set_turns([])
self._client.system_prompt = prompt

# Populated when ._server() gets called (in an active session)
self._server_values: ModServerResult | None = None

Expand Down Expand Up @@ -160,7 +169,7 @@ def app(

"""
enable_bookmarking = bookmark_store != "disable"
table_name = self.data_source.table_name
table_name = self._data_source.table_name

def app_ui(request):
return ui.page_sidebar(
Expand Down Expand Up @@ -303,10 +312,9 @@ def _server(self, *, enable_bookmarking: bool = False) -> None:
# Call the server module
self._server_values = mod_server(
self.id,
data_source=self.data_source,
system_prompt=self.system_prompt,
data_source=self._data_source,
greeting=self.greeting,
client=self.client,
client=self._client,
enable_bookmarking=enable_bookmarking,
)

Expand Down Expand Up @@ -416,7 +424,7 @@ def title(self, value: Optional[str] = None) -> str | None | bool:
else:
return vals.title.set(value)

def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
def generate_greeting(self, *, echo: Literal["none", "output"] = "none"):
"""
Generate a welcome greeting for the chat.

Expand All @@ -428,7 +436,7 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
Parameters
----------
echo
If `echo = "text"`, prints the greeting to standard output. If
If `echo = "output"`, prints the greeting to standard output. If
`echo = "none"` (default), does not print anything.

Returns
Expand All @@ -437,97 +445,39 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
The greeting string (in Markdown format).

"""
client = copy.deepcopy(self.client)
client.system_prompt = self.system_prompt
client = copy.deepcopy(self._client)
client.set_turns([])
prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list."
return str(client.chat(prompt, echo=echo))

def set_system_prompt(
self,
data_source: DataSource,
*,
data_description: Optional[str | Path] = None,
extra_instructions: Optional[str | Path] = None,
categorical_threshold: int = 10,
prompt_template: Optional[str | Path] = None,
) -> None:
"""
Customize the system prompt.

Control the logic behind how the system prompt is generated based on the
data source's schema and optional additional context and instructions.

Note
----
This method is for parametrized system prompt generation only. To set a
fully custom system prompt string, set the `system_prompt` attribute
directly.

Parameters
----------
data_source
A data source to generate schema information from
data_description
Optional description of the data, in plain text or Markdown format
extra_instructions
Optional additional instructions for the chat model, in plain text or
Markdown format
categorical_threshold
Threshold for determining if a column is categorical based on number of
unique values
prompt_template
Optional `Path` to or string of a custom prompt template. If not provided, the default
querychat template will be used.

"""
self.system_prompt = get_system_prompt(
data_source,
data_description=data_description,
extra_instructions=extra_instructions,
categorical_threshold=categorical_threshold,
prompt_template=prompt_template,
)

def set_data_source(
self, data_source: IntoFrame | sqlalchemy.Engine | DataSource, table_name: str
) -> None:
@property
def client(self):
"""
Set a new data source for the QueryChat object.

Parameters
----------
data_source
The new data source to use.
table_name
If a data_source is a data frame, a name to use to refer to the table
Get the (session-specific) chat client.

Returns
-------
:
None
The current chat client.

"""
self.data_source = normalize_data_source(data_source, table_name)
vals = self._server_values
if vals is None:
raise RuntimeError("Must call .server() before accessing .client")
return vals.client

def set_client(self, client: str | chatlas.Chat) -> None:
@property
def data_source(self):
"""
Set a new chat client for the QueryChat object.

Parameters
----------
client
A `chatlas.Chat` object or a string to be passed to
`chatlas.ChatAuto()` describing the model to use (e.g.
`"openai/gpt-4.1"`).
Get the current data source.

Returns
-------
:
None
The current data source.

"""
self.client = normalize_client(client)
return self._data_source


class QueryChat(QueryChatBase):
Expand Down
3 changes: 0 additions & 3 deletions pkg-py/src/querychat/_querychat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def mod_server(
session: Session,
*,
data_source: DataSource,
system_prompt: str,
greeting: str | None,
client: chatlas.Chat,
enable_bookmarking: bool,
Expand All @@ -70,8 +69,6 @@ def mod_server(

# Set up the chat object for this session
chat = copy.deepcopy(client)
chat.set_turns([])
chat.system_prompt = system_prompt

# Create the tool functions
update_dashboard_tool = tool_update_dashboard(data_source, sql, title)
Expand Down
6 changes: 0 additions & 6 deletions pkg-py/tests/test_init_with_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def test_init_with_pandas_dataframe():

# Verify the result is properly configured
assert qc is not None
assert hasattr(qc, "data_source")
assert hasattr(qc, "system_prompt")
assert hasattr(qc, "greeting")
assert hasattr(qc, "client")


def test_init_with_narwhals_dataframe():
Expand All @@ -66,8 +62,6 @@ def test_init_with_narwhals_dataframe():

# Verify the result is correctly configured
assert qc is not None
assert hasattr(qc, "data_source")
assert hasattr(qc, "system_prompt")


def test_init_with_narwhals_lazyframe_direct_query():
Expand Down
29 changes: 0 additions & 29 deletions pkg-py/tests/test_querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def test_querychat_init(sample_df):

# Verify basic attributes are set
assert qc is not None
assert hasattr(qc, "data_source")
assert hasattr(qc, "system_prompt")
assert hasattr(qc, "greeting")
assert hasattr(qc, "client")
assert qc.id == "test_table"

# Even without server initialization, we should be able to query the data source
Expand All @@ -66,31 +62,6 @@ def test_querychat_custom_id(sample_df):
assert qc.id == "custom_id"


def test_querychat_set_methods(sample_df):
"""Test that setter methods work."""
qc = QueryChat(
data_source=sample_df,
table_name="test_table",
greeting="Hello!",
)

# Test set_system_prompt
qc.set_system_prompt(
qc.data_source,
data_description="A test dataset",
)
assert "test dataset" in qc.system_prompt.lower()

# Test set_data_source
new_df = pd.DataFrame({"x": [1, 2, 3]})
qc.set_data_source(new_df, "new_table")
assert qc.data_source is not None

# Test set_client
qc.set_client("openai/gpt-4o-mini")
assert qc.client is not None


def test_querychat_core_reactive_access_before_server_raises(sample_df):
"""Test that accessing reactive properties before .server() raises error."""
qc = QueryChat(
Expand Down