@@ -40,6 +40,7 @@ def __init__(
4040 greeting : Optional [str | Path ] = None ,
4141 client : Optional [str | chatlas .Chat ] = None ,
4242 data_description : Optional [str | Path ] = None ,
43+ categorical_threshold : int = 10 ,
4344 extra_instructions : Optional [str | Path ] = None ,
4445 prompt_template : Optional [str | Path ] = None ,
4546 ):
@@ -79,6 +80,9 @@ def __init__(
7980 Description of the data in plain text or Markdown. If a pathlib.Path
8081 object is passed, querychat will read the contents of the path into a
8182 string with `.read_text()`.
83+ categorical_threshold
84+ Threshold for determining if a column is categorical based on number of
85+ unique values.
8286 extra_instructions
8387 Additional instructions for the chat model. If a pathlib.Path object is
8488 passed, querychat will read the contents of the path into a string with
@@ -104,7 +108,7 @@ def __init__(
104108 ```
105109
106110 """
107- self .data_source = normalize_data_source (data_source , table_name )
111+ self ._data_source = normalize_data_source (data_source , table_name )
108112
109113 # Validate table name (must begin with letter, contain only letters, numbers, underscores)
110114 if not re .match (r"^[a-zA-Z][a-zA-Z0-9_]*$" , table_name ):
@@ -114,8 +118,6 @@ def __init__(
114118
115119 self .id = id or table_name
116120
117- self .client = normalize_client (client )
118-
119121 if greeting is None :
120122 print (
121123 "Warning: No greeting provided; the LLM will be invoked at conversation start to generate one. "
@@ -126,13 +128,20 @@ def __init__(
126128
127129 self .greeting = greeting .read_text () if isinstance (greeting , Path ) else greeting
128130
129- self . system_prompt = get_system_prompt (
130- self .data_source ,
131+ prompt = get_system_prompt (
132+ self ._data_source ,
131133 data_description = data_description ,
132134 extra_instructions = extra_instructions ,
135+ categorical_threshold = categorical_threshold ,
133136 prompt_template = prompt_template ,
134137 )
135138
139+ # Fork and empty chat now so the per-session forks are fast
140+ client = normalize_client (client )
141+ self ._client = copy .deepcopy (client )
142+ self ._client .set_turns ([])
143+ self ._client .system_prompt = prompt
144+
136145 # Populated when ._server() gets called (in an active session)
137146 self ._server_values : ModServerResult | None = None
138147
@@ -160,7 +169,7 @@ def app(
160169
161170 """
162171 enable_bookmarking = bookmark_store != "disable"
163- table_name = self .data_source .table_name
172+ table_name = self ._data_source .table_name
164173
165174 def app_ui (request ):
166175 return ui .page_sidebar (
@@ -303,10 +312,9 @@ def _server(self, *, enable_bookmarking: bool = False) -> None:
303312 # Call the server module
304313 self ._server_values = mod_server (
305314 self .id ,
306- data_source = self .data_source ,
307- system_prompt = self .system_prompt ,
315+ data_source = self ._data_source ,
308316 greeting = self .greeting ,
309- client = self .client ,
317+ client = self ._client ,
310318 enable_bookmarking = enable_bookmarking ,
311319 )
312320
@@ -416,7 +424,7 @@ def title(self, value: Optional[str] = None) -> str | None | bool:
416424 else :
417425 return vals .title .set (value )
418426
419- def generate_greeting (self , * , echo : Literal ["none" , "text " ] = "none" ):
427+ def generate_greeting (self , * , echo : Literal ["none" , "output " ] = "none" ):
420428 """
421429 Generate a welcome greeting for the chat.
422430
@@ -428,7 +436,7 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
428436 Parameters
429437 ----------
430438 echo
431- If `echo = "text "`, prints the greeting to standard output. If
439+ If `echo = "output "`, prints the greeting to standard output. If
432440 `echo = "none"` (default), does not print anything.
433441
434442 Returns
@@ -437,97 +445,39 @@ def generate_greeting(self, *, echo: Literal["none", "text"] = "none"):
437445 The greeting string (in Markdown format).
438446
439447 """
440- client = copy .deepcopy (self .client )
441- client .system_prompt = self .system_prompt
448+ client = copy .deepcopy (self ._client )
442449 client .set_turns ([])
443450 prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list."
444451 return str (client .chat (prompt , echo = echo ))
445452
446- def set_system_prompt (
447- self ,
448- data_source : DataSource ,
449- * ,
450- data_description : Optional [str | Path ] = None ,
451- extra_instructions : Optional [str | Path ] = None ,
452- categorical_threshold : int = 10 ,
453- prompt_template : Optional [str | Path ] = None ,
454- ) -> None :
455- """
456- Customize the system prompt.
457-
458- Control the logic behind how the system prompt is generated based on the
459- data source's schema and optional additional context and instructions.
460-
461- Note
462- ----
463- This method is for parametrized system prompt generation only. To set a
464- fully custom system prompt string, set the `system_prompt` attribute
465- directly.
466-
467- Parameters
468- ----------
469- data_source
470- A data source to generate schema information from
471- data_description
472- Optional description of the data, in plain text or Markdown format
473- extra_instructions
474- Optional additional instructions for the chat model, in plain text or
475- Markdown format
476- categorical_threshold
477- Threshold for determining if a column is categorical based on number of
478- unique values
479- prompt_template
480- Optional `Path` to or string of a custom prompt template. If not provided, the default
481- querychat template will be used.
482-
483- """
484- self .system_prompt = get_system_prompt (
485- data_source ,
486- data_description = data_description ,
487- extra_instructions = extra_instructions ,
488- categorical_threshold = categorical_threshold ,
489- prompt_template = prompt_template ,
490- )
491-
492- def set_data_source (
493- self , data_source : IntoFrame | sqlalchemy .Engine | DataSource , table_name : str
494- ) -> None :
453+ @property
454+ def client (self ):
495455 """
496- Set a new data source for the QueryChat object.
497-
498- Parameters
499- ----------
500- data_source
501- The new data source to use.
502- table_name
503- If a data_source is a data frame, a name to use to refer to the table
456+ Get the (session-specific) chat client.
504457
505458 Returns
506459 -------
507460 :
508- None
461+ The current chat client.
509462
510463 """
511- self .data_source = normalize_data_source (data_source , table_name )
464+ vals = self ._server_values
465+ if vals is None :
466+ raise RuntimeError ("Must call .server() before accessing .client" )
467+ return vals .client
512468
513- def set_client (self , client : str | chatlas .Chat ) -> None :
469+ @property
470+ def data_source (self ):
514471 """
515- Set a new chat client for the QueryChat object.
516-
517- Parameters
518- ----------
519- client
520- A `chatlas.Chat` object or a string to be passed to
521- `chatlas.ChatAuto()` describing the model to use (e.g.
522- `"openai/gpt-4.1"`).
472+ Get the current data source.
523473
524474 Returns
525475 -------
526476 :
527- None
477+ The current data source.
528478
529479 """
530- self .client = normalize_client ( client )
480+ return self ._data_source
531481
532482
533483class QueryChat (QueryChatBase ):
0 commit comments