Skip to content

Commit ce0d0d3

Browse files
authored
fix(pkg-py): fix/disallow setting of initialization values post-initialization (#108)
* fix(pkg-py): allow full client object to read, but not set (after initialization) * fix(pkg-py): make .data_source a read-only property * fix(pkg-py): fix setting of system prompt after server-initialization * Fork chat on init and simplify * Make client an attribute * Cleanup * Client needs to be session-specific * Update tests
1 parent d3479cb commit ce0d0d3

File tree

4 files changed

+33
-121
lines changed

4 files changed

+33
-121
lines changed

pkg-py/src/querychat/_querychat.py

Lines changed: 33 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
greeting: Optional[str | Path] = None,
4141
client: Optional[str | chatlas.Chat] = None,
4242
data_description: Optional[str | Path] = None,
43+
categorical_threshold: int = 10,
4344
extra_instructions: Optional[str | Path] = None,
4445
prompt_template: Optional[str | Path] = None,
4546
):
@@ -79,6 +80,9 @@ def __init__(
7980
Description of the data in plain text or Markdown. If a pathlib.Path
8081
object is passed, querychat will read the contents of the path into a
8182
string with `.read_text()`.
83+
categorical_threshold
84+
Threshold for determining if a column is categorical based on number of
85+
unique values.
8286
extra_instructions
8387
Additional instructions for the chat model. If a pathlib.Path object is
8488
passed, querychat will read the contents of the path into a string with
@@ -104,7 +108,7 @@ def __init__(
104108
```
105109
106110
"""
107-
self.data_source = normalize_data_source(data_source, table_name)
111+
self._data_source = normalize_data_source(data_source, table_name)
108112

109113
# Validate table name (must begin with letter, contain only letters, numbers, underscores)
110114
if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name):
@@ -114,8 +118,6 @@ def __init__(
114118

115119
self.id = id or table_name
116120

117-
self.client = normalize_client(client)
118-
119121
if greeting is None:
120122
print(
121123
"Warning: No greeting provided; the LLM will be invoked at conversation start to generate one. "
@@ -126,13 +128,20 @@ def __init__(
126128

127129
self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting
128130

129-
self.system_prompt = get_system_prompt(
130-
self.data_source,
131+
prompt = get_system_prompt(
132+
self._data_source,
131133
data_description=data_description,
132134
extra_instructions=extra_instructions,
135+
categorical_threshold=categorical_threshold,
133136
prompt_template=prompt_template,
134137
)
135138

139+
# Fork and empty chat now so the per-session forks are fast
140+
client = normalize_client(client)
141+
self._client = copy.deepcopy(client)
142+
self._client.set_turns([])
143+
self._client.system_prompt = prompt
144+
136145
# Populated when ._server() gets called (in an active session)
137146
self._server_values: ModServerResult | None = None
138147

@@ -160,7 +169,7 @@ def app(
160169
161170
"""
162171
enable_bookmarking = bookmark_store != "disable"
163-
table_name = self.data_source.table_name
172+
table_name = self._data_source.table_name
164173

165174
def app_ui(request):
166175
return ui.page_sidebar(
@@ -303,10 +312,9 @@ def _server(self, *, enable_bookmarking: bool = False) -> None:
303312
# Call the server module
304313
self._server_values = mod_server(
305314
self.id,
306-
data_source=self.data_source,
307-
system_prompt=self.system_prompt,
315+
data_source=self._data_source,
308316
greeting=self.greeting,
309-
client=self.client,
317+
client=self._client,
310318
enable_bookmarking=enable_bookmarking,
311319
)
312320

@@ -416,7 +424,7 @@ def title(self, value: Optional[str] = None) -> str | None | bool:
416424
else:
417425
return vals.title.set(value)
418426

419-
def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
427+
def generate_greeting(self, *, echo: Literal["none", "output"] = "none"):
420428
"""
421429
Generate a welcome greeting for the chat.
422430
@@ -428,7 +436,7 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
428436
Parameters
429437
----------
430438
echo
431-
If `echo = "text"`, prints the greeting to standard output. If
439+
If `echo = "output"`, prints the greeting to standard output. If
432440
`echo = "none"` (default), does not print anything.
433441
434442
Returns
@@ -437,97 +445,39 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
437445
The greeting string (in Markdown format).
438446
439447
"""
440-
client = copy.deepcopy(self.client)
441-
client.system_prompt = self.system_prompt
448+
client = copy.deepcopy(self._client)
442449
client.set_turns([])
443450
prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list."
444451
return str(client.chat(prompt, echo=echo))
445452

446-
def set_system_prompt(
447-
self,
448-
data_source: DataSource,
449-
*,
450-
data_description: Optional[str | Path] = None,
451-
extra_instructions: Optional[str | Path] = None,
452-
categorical_threshold: int = 10,
453-
prompt_template: Optional[str | Path] = None,
454-
) -> None:
455-
"""
456-
Customize the system prompt.
457-
458-
Control the logic behind how the system prompt is generated based on the
459-
data source's schema and optional additional context and instructions.
460-
461-
Note
462-
----
463-
This method is for parametrized system prompt generation only. To set a
464-
fully custom system prompt string, set the `system_prompt` attribute
465-
directly.
466-
467-
Parameters
468-
----------
469-
data_source
470-
A data source to generate schema information from
471-
data_description
472-
Optional description of the data, in plain text or Markdown format
473-
extra_instructions
474-
Optional additional instructions for the chat model, in plain text or
475-
Markdown format
476-
categorical_threshold
477-
Threshold for determining if a column is categorical based on number of
478-
unique values
479-
prompt_template
480-
Optional `Path` to or string of a custom prompt template. If not provided, the default
481-
querychat template will be used.
482-
483-
"""
484-
self.system_prompt = get_system_prompt(
485-
data_source,
486-
data_description=data_description,
487-
extra_instructions=extra_instructions,
488-
categorical_threshold=categorical_threshold,
489-
prompt_template=prompt_template,
490-
)
491-
492-
def set_data_source(
493-
self, data_source: IntoFrame | sqlalchemy.Engine | DataSource, table_name: str
494-
) -> None:
453+
@property
454+
def client(self):
495455
"""
496-
Set a new data source for the QueryChat object.
497-
498-
Parameters
499-
----------
500-
data_source
501-
The new data source to use.
502-
table_name
503-
If a data_source is a data frame, a name to use to refer to the table
456+
Get the (session-specific) chat client.
504457
505458
Returns
506459
-------
507460
:
508-
None
461+
The current chat client.
509462
510463
"""
511-
self.data_source = normalize_data_source(data_source, table_name)
464+
vals = self._server_values
465+
if vals is None:
466+
raise RuntimeError("Must call .server() before accessing .client")
467+
return vals.client
512468

513-
def set_client(self, client: str | chatlas.Chat) -> None:
469+
@property
470+
def data_source(self):
514471
"""
515-
Set a new chat client for the QueryChat object.
516-
517-
Parameters
518-
----------
519-
client
520-
A `chatlas.Chat` object or a string to be passed to
521-
`chatlas.ChatAuto()` describing the model to use (e.g.
522-
`"openai/gpt-4.1"`).
472+
Get the current data source.
523473
524474
Returns
525475
-------
526476
:
527-
None
477+
The current data source.
528478
529479
"""
530-
self.client = normalize_client(client)
480+
return self._data_source
531481

532482

533483
class QueryChat(QueryChatBase):

pkg-py/src/querychat/_querychat_module.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def mod_server(
5858
session: Session,
5959
*,
6060
data_source: DataSource,
61-
system_prompt: str,
6261
greeting: str | None,
6362
client: chatlas.Chat,
6463
enable_bookmarking: bool,
@@ -70,8 +69,6 @@ def mod_server(
7069

7170
# Set up the chat object for this session
7271
chat = copy.deepcopy(client)
73-
chat.set_turns([])
74-
chat.system_prompt = system_prompt
7572

7673
# Create the tool functions
7774
update_dashboard_tool = tool_update_dashboard(data_source, sql, title)

pkg-py/tests/test_init_with_pandas.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ def test_init_with_pandas_dataframe():
3939

4040
# Verify the result is properly configured
4141
assert qc is not None
42-
assert hasattr(qc, "data_source")
43-
assert hasattr(qc, "system_prompt")
44-
assert hasattr(qc, "greeting")
45-
assert hasattr(qc, "client")
4642

4743

4844
def test_init_with_narwhals_dataframe():
@@ -66,8 +62,6 @@ def test_init_with_narwhals_dataframe():
6662

6763
# Verify the result is correctly configured
6864
assert qc is not None
69-
assert hasattr(qc, "data_source")
70-
assert hasattr(qc, "system_prompt")
7165

7266

7367
def test_init_with_narwhals_lazyframe_direct_query():

pkg-py/tests/test_querychat.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ def test_querychat_init(sample_df):
3939

4040
# Verify basic attributes are set
4141
assert qc is not None
42-
assert hasattr(qc, "data_source")
43-
assert hasattr(qc, "system_prompt")
44-
assert hasattr(qc, "greeting")
45-
assert hasattr(qc, "client")
4642
assert qc.id == "test_table"
4743

4844
# Even without server initialization, we should be able to query the data source
@@ -66,31 +62,6 @@ def test_querychat_custom_id(sample_df):
6662
assert qc.id == "custom_id"
6763

6864

69-
def test_querychat_set_methods(sample_df):
70-
"""Test that setter methods work."""
71-
qc = QueryChat(
72-
data_source=sample_df,
73-
table_name="test_table",
74-
greeting="Hello!",
75-
)
76-
77-
# Test set_system_prompt
78-
qc.set_system_prompt(
79-
qc.data_source,
80-
data_description="A test dataset",
81-
)
82-
assert "test dataset" in qc.system_prompt.lower()
83-
84-
# Test set_data_source
85-
new_df = pd.DataFrame({"x": [1, 2, 3]})
86-
qc.set_data_source(new_df, "new_table")
87-
assert qc.data_source is not None
88-
89-
# Test set_client
90-
qc.set_client("openai/gpt-4o-mini")
91-
assert qc.client is not None
92-
93-
9465
def test_querychat_core_reactive_access_before_server_raises(sample_df):
9566
"""Test that accessing reactive properties before .server() raises error."""
9667
qc = QueryChat(

0 commit comments

Comments
 (0)